
Differentiable OT Layer via Neural ODEs

Updated 6 November 2025
  • A differentiable optimal transport layer is a module that parameterizes and solves OT problems with neural ODE-based rectified flow while preserving the marginal constraints exactly.
  • It chains neural ODE steps, each trained by unconstrained SGD regression, that progressively reduce convex transport costs while every intermediate coupling remains a valid coupling of the two distributions.
  • This design avoids the instability of traditional dual or penalty methods and applies to tasks such as generative modeling, domain adaptation, and structured alignment.

A differentiable optimal transport (OT) layer is a module within a computational graph that parameterizes, solves, and backpropagates through an optimal transport problem, enabling integration into deep learning architectures. This approach exploits recent advancements in OT theory, computational optimal transport, neural parameterizations, and scientific computing to achieve end-to-end training of models that require distributional alignment, transport, or coupling. Differentiable OT layers are employed for tasks such as generative modeling, domain adaptation, structured alignment, and neural architecture regularization.

1. Fundamental Principles and Problem Setting

The central object computed by a differentiable OT layer is a coupling $\gamma$ that minimizes a given transport cost while matching prescribed marginals. In the canonical Monge–Kantorovich problem between distributions $\pi_0$ and $\pi_1$ on $\mathbb{R}^d$ and a convex cost $c$,

$$\min_{\gamma \in \Pi(\pi_0, \pi_1)} \mathbb{E}_{(X_0, X_1) \sim \gamma} [ c(X_1 - X_0) ]$$

where $\Pi(\pi_0, \pi_1)$ denotes the set of couplings with marginals $\pi_0$ and $\pi_1$. The challenge is to make the OT solution fully differentiable with respect to inputs, neural network weights, or any other parameters of the computational graph, while preserving the marginal constraints exactly at every step.
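For empirical distributions with $n$ equally weighted points on each side, the problem above becomes a finite linear program whose optimum can be taken to be a permutation (by the Birkhoff–von Neumann theorem). The following plain-Python sketch, with hypothetical helper names chosen for illustration, evaluates the cost of a discrete coupling, checks its marginals, and finds an optimal coupling by brute force, which is only feasible for tiny $n$.

```python
import itertools

def transport_cost(gamma, xs, ys, cost):
    """E_{(X0, X1) ~ gamma}[c(X1 - X0)] for a discrete coupling matrix gamma."""
    return sum(gamma[i][j] * cost(ys[j] - xs[i])
               for i in range(len(xs)) for j in range(len(ys)))

def marginals(gamma):
    """Row sums give the law of X0; column sums give the law of X1."""
    return [sum(row) for row in gamma], [sum(col) for col in zip(*gamma)]

def permutation_coupling(perm):
    """Coupling that sends point i to point perm[i], each with mass 1/n."""
    n = len(perm)
    return [[1.0 / n if perm[i] == j else 0.0 for j in range(n)] for i in range(n)]

def brute_force_ot(xs, ys, cost):
    """Exact OT for uniform weights by enumerating all n! permutations."""
    return min((transport_cost(permutation_coupling(p), xs, ys, cost), p)
               for p in itertools.permutations(range(len(xs))))

xs, ys = [0.0, 1.0, 2.0], [2.5, 0.5, 1.5]
sq = lambda d: d * d                           # convex cost c(d) = d^2
best_cost, best_perm = brute_force_ot(xs, ys, sq)
product = [[1.0 / 9] * 3 for _ in range(3)]    # independent (product) coupling
print(best_cost, best_perm, transport_cost(product, xs, ys, sq))
```

In this one-dimensional example the optimal permutation is the sorted (monotone) matching, as expected for a convex cost, and the independent coupling is strictly more expensive.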

The differentiable OT layer, as realized in the rectified flow framework (Liu, 2022), constructs the coupling not by explicit projection or dual variable optimization, but rather as the solution of a flow parameterized by neural ordinary differential equations (ODEs), ensuring each step is strictly within the set of valid couplings and supports efficient and stable automatic differentiation.

2. Rectified Flow: Marginal-Preserving Neural ODE Construction

Rectified flow defines a sequence of neural ODEs, each trained via unconstrained regression, that progressively reduces the transport cost while automatically preserving marginal constraints throughout the flow. Given an initial valid coupling $(X_0, X_1)$, a time-dependent neural ODE is constructed:

$$\frac{dZ_t}{dt} = v^X_t(Z_t), \quad Z_0 = X_0,$$

where the time index $t \in [0,1]$ interpolates between source and target through $X_t = t X_1 + (1-t) X_0$.
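The interpolation path and the ODE can be sketched as follows. This assumes scalar states and a fixed-step forward Euler solver for brevity (practical neural ODE implementations typically use higher-order or adaptive solvers), and `velocity` is a hypothetical callable standing in for the learned field $v^X_t$.

```python
def interpolate(x0, x1, t):
    """Straight-line interpolation X_t = t * X1 + (1 - t) * X0."""
    return t * x1 + (1.0 - t) * x0

def euler_integrate(velocity, z0, n_steps=100):
    """Solve dZ_t/dt = velocity(Z_t, t), Z_0 = z0, on [0, 1] with forward Euler."""
    z, dt = z0, 1.0 / n_steps
    for k in range(n_steps):
        z = z + dt * velocity(z, k * dt)
    return z

# A deterministic coupling X1 = X0 + 2 has constant displacement, so the constant
# field v = 2 carries every X0 along its straight path to X1.
z1 = euler_integrate(lambda z, t: 2.0, 1.0)
# A state-dependent field: dZ/dt = Z gives Z_1 = e * Z_0 up to Euler error.
z_exp = euler_integrate(lambda z, t: z, 1.0, n_steps=1000)
print(interpolate(0.0, 4.0, 0.25), z1, z_exp)
```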

The velocity field $v^X_t(z)$ is learned by regressing the displacement $X_1 - X_0$ on the interpolated point $X_t$:

$$\inf_{v} \int_{0}^{1} \mathbb{E}\left[\|X_1 - X_0 - v(X_t, t)\|^2\right] dt.$$

The velocity $v$ is parameterized by a neural network and trained with stochastic gradient descent (SGD) on samples $(X_0, X_1)$, requiring no inner optimization loop, adversarial game, or projection step.
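A Monte Carlo estimate of this objective for one mini-batch might look like the sketch below, assuming scalar samples and one $t$ drawn uniformly from $[0, 1]$ per pair; the function name is hypothetical. In training, this estimate would be minimized by SGD over the weights of the network implementing `velocity`; because it is a least-squares objective, its minimizer is the conditional expectation $\mathbb{E}[X_1 - X_0 \mid X_t = z]$.

```python
import random

def rectified_flow_loss(velocity, pairs, rng):
    """Unbiased estimate of int_0^1 E[(X1 - X0 - v(X_t, t))^2] dt using one
    uniform t per pair, with X_t on the straight line between X0 and X1."""
    total = 0.0
    for x0, x1 in pairs:
        t = rng.random()
        xt = t * x1 + (1.0 - t) * x0
        residual = (x1 - x0) - velocity(xt, t)
        total += residual * residual
    return total / len(pairs)

rng = random.Random(0)
pairs = [(x0, x0 + 2.0) for x0 in (rng.gauss(0.0, 1.0) for _ in range(1000))]
loss_true = rectified_flow_loss(lambda x, t: 2.0, pairs, rng)   # matches X1 - X0
loss_zero = rectified_flow_loss(lambda x, t: 0.0, pairs, rng)   # ignores the displacement
print(loss_true, loss_zero)
```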

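Solving each ODE transforms the current coupling $(X_0, X_1)$ into a new coupling $(Z_0, Z_1)$ with the same marginals. The marginal-preserving property holds by construction: when $v^X_t$ attains the regression optimum $v^X_t(z) = \mathbb{E}[X_1 - X_0 \mid X_t = z]$, $Z_t$ has the same distribution as $X_t$ for every $t \in [0,1]$. Because the new coupling can itself be rectified, the construction can be applied recursively (multiple flow steps), each step not increasing the transport cost, which yields a monotonic, interior trajectory within the space of valid couplings.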
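The marginal-preserving property can be checked numerically in a toy setting. The sketch below assumes one-dimensional Gaussian marginals $\pi_0 = \mathcal{N}(0, 1)$ and $\pi_1 = \mathcal{N}(3, 1)$, starts from the independent coupling, and replaces the neural network with a per-time-step affine least-squares fit, which suffices here because $\mathbb{E}[X_1 - X_0 \mid X_t]$ is affine in $X_t$ for jointly Gaussian pairs. After one rectification, $Z_1$ should still be approximately $\mathcal{N}(3, 1)$, while the quadratic transport cost drops from about $11$ to about $9$.

```python
import math
import random

def rectify(pairs, n_steps=100):
    """One rectified-flow step for scalar pairs (X0, X1).

    At each Euler time t_k = k / n_steps the drift v(., t_k) is the affine
    least-squares fit of X1 - X0 on X_t (a stand-in for the neural network);
    each X0 is then pushed through dZ/dt = v(Z, t) with forward Euler.
    Returns the new coupling (X0, Z1)."""
    n, dt = len(pairs), 1.0 / n_steps
    ds = [x1 - x0 for x0, x1 in pairs]
    md = sum(ds) / n
    zs = [x0 for x0, _ in pairs]
    for k in range(n_steps):
        t = k * dt
        xs = [t * x1 + (1.0 - t) * x0 for x0, x1 in pairs]
        mx = sum(xs) / n
        cov = sum((x - mx) * (d - md) for x, d in zip(xs, ds))
        var = sum((x - mx) ** 2 for x in xs)
        b = cov / var
        a = md - b * mx
        zs = [z + dt * (a + b * z) for z in zs]
    return [(x0, z) for (x0, _), z in zip(pairs, zs)]

def quadratic_cost(pairs):
    """Empirical E[(X1 - X0)^2] of a coupling."""
    return sum((x1 - x0) ** 2 for x0, x1 in pairs) / len(pairs)

rng = random.Random(0)
n = 4000
pairs = [(rng.gauss(0.0, 1.0), rng.gauss(3.0, 1.0)) for _ in range(n)]  # independent
rectified = rectify(pairs)

z1 = [z for _, z in rectified]
mean_z1 = sum(z1) / n
std_z1 = math.sqrt(sum((z - mean_z1) ** 2 for z in z1) / n)
print(f"cost before {quadratic_cost(pairs):.3f}, after {quadratic_cost(rectified):.3f}")
print(f"Z1 mean {mean_z1:.3f}, std {std_z1:.3f}")
```

The first coordinate is left untouched and the second is, up to sampling and Euler error, again distributed as $\pi_1$, so the output is still a valid coupling, only a cheaper one.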

3. User-Specified Cost Functions and Bregman Loss Extension

While the original rectified flow reduces (more precisely, does not increase) the transport cost simultaneously for every convex cost function, the c-rectified flow variant solves the OT problem for one specific convex cost $c$. The key mechanism is:

  • Constraining the velocity field to the form $v_t(x) = \nabla c^*(\nabla f_t(x))$, where $c^*$ is the convex conjugate of $c$ and $f_t$ is parameterized by a (possibly deep) neural network.
  • Minimizing
    $$\inf_{f} \int_0^1 \mathbb{E}\left[ c^*(\nabla f_t(X_t)) - (X_1-X_0)^\top \nabla f_t(X_t) + c(X_1-X_0) \right] dt.$$
    This loss is a Bregman divergence-like term (a Fenchel–Young gap): the integrand is nonnegative, and for strictly convex, differentiable $c$ it vanishes exactly when $\nabla c^*(\nabla f_t(X_t)) = X_1 - X_0$, so the learned flow aligns the increments $X_1 - X_0$ with the structure imposed by the convex cost.
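To see why this loss encodes the cost, the sketch below evaluates the integrand pointwise for an illustrative scalar power cost $c(x) = |x|^p / p$ with $p > 1$, whose convex conjugate is $c^*(y) = |y|^q / q$ with $1/p + 1/q = 1$. The choice of cost and the helper names are assumptions for illustration. For $p = 2$ the integrand equals $\tfrac12 (x - y)^2$, so with $v_t = \nabla f_t$ the loss reduces to half the squared-error objective of the original rectified flow.

```python
import math

def power_cost(p):
    """Scalar cost c(x) = |x|^p / p (p > 1), its conjugate c*(y) = |y|^q / q with
    1/p + 1/q = 1, and their gradients, which are inverse maps of each other."""
    q = p / (p - 1.0)
    c = lambda x: abs(x) ** p / p
    c_star = lambda y: abs(y) ** q / q
    grad_c = lambda x: math.copysign(abs(x) ** (p - 1.0), x)
    grad_c_star = lambda y: math.copysign(abs(y) ** (q - 1.0), y)
    return c, c_star, grad_c, grad_c_star

def fenchel_young_gap(c, c_star, x, y):
    """Integrand of the c-rectified flow loss with x = X1 - X0 and y = grad f_t(X_t).
    Nonnegative by the Fenchel-Young inequality; zero iff y = grad c(x)."""
    return c_star(y) - x * y + c(x)

c, c_star, grad_c, grad_c_star = power_cost(3.0)
x = 1.5
print(fenchel_young_gap(c, c_star, x, grad_c(x)))   # ~0: y matches the cost gradient
print(fenchel_young_gap(c, c_star, x, 0.0))          # > 0: a mismatched y is penalized
print(grad_c_star(grad_c(x)))                         # recovers x, the velocity target
```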

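Recursive application of c-rectified flow steps, each parameterized and optimized in this way, produces a sequence of couplings with non-increasing $c$-transport cost; since the $c$-optimal couplings are exactly the fixed points of this procedure, the recursion drives the coupling toward the OT solution.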

4. Differentiability and Backpropagation

The full computation pipeline—from sample input pairs, through neural ODE evaluation, to cost-based regression—is differentiable:

  • Gradients with respect to all parameters (neural ODE weights) are propagated via the adjoint sensitivity method for neural ODEs (Chen et al., 2018), which solves an augmented adjoint ODE backward in time to obtain vector-Jacobian products without storing intermediate solver states.
  • Differentiation backpropagates through the regression losses, neural network architectures, and any downstream objective.
  • Because marginal matching is enforced at every step internally (rather than via projection or penalty), the gradients are stable and not subject to marginals "drifting" off-manifold.
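The adjoint computation can be illustrated on a scalar ODE small enough to check by hand. This sketch assumes a toy linear field $f(z, \theta) = \theta z$ in place of a neural network, a loss $L = z(1)^2 / 2$, and fixed-step Euler in both directions; with a real network, the products $a\,\partial f/\partial z$ and $a\,\partial f/\partial \theta$ would be vector-Jacobian products from automatic differentiation. The regression loss of Section 2 is evaluated on interpolated samples and does not itself require backpropagating through an ODE solve; the adjoint matters when a downstream objective depends on the ODE output.

```python
import math

def forward_euler(theta, z0, n_steps):
    """Forward pass: integrate dz/dt = f(z, theta) = theta * z from t = 0 to 1."""
    z, dt = z0, 1.0 / n_steps
    for _ in range(n_steps):
        z += dt * theta * z
    return z

def adjoint_gradient(theta, z1, dloss_dz1, n_steps):
    """dL/dtheta via the continuous adjoint method for f(z, theta) = theta * z.

    Integrates the augmented state (z, a, g) backward from t = 1 to t = 0:
      dz/dt = f = theta * z                  (recomputes the trajectory)
      da/dt = -a * df/dz = -a * theta        (adjoint a(t) = dL/dz(t))
      dg/dt = -a * df/dtheta = -a * z        (accumulates the parameter gradient)
    with z(1) = z1, a(1) = dL/dz(1), g(1) = 0; the gradient is g(0)."""
    z, a, g, dt = z1, dloss_dz1, 0.0, 1.0 / n_steps
    for _ in range(n_steps):
        dz, da, dg = theta * z, -a * theta, -a * z
        # One Euler step backward in time: x(t - dt) ~ x(t) - dt * dx/dt
        z, a, g = z - dt * dz, a - dt * da, g - dt * dg
    return g

theta, z0, n_steps = 0.5, 1.5, 1000
z1 = forward_euler(theta, z0, n_steps)
grad = adjoint_gradient(theta, z1, dloss_dz1=z1, n_steps=n_steps)  # L = z(1)^2 / 2
exact = z0 ** 2 * math.exp(2 * theta)   # from z(1) = z0 * exp(theta)
print(grad, exact)
```

The analytic gradient follows from $z(1) = z_0 e^{\theta}$, so $dL/d\theta = z_0^2 e^{2\theta}$; the adjoint estimate agrees up to discretization error, and the backward pass stores no intermediate states.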

This endows the rectified flow OT layer with exact, sample-level, and efficient gradient flow, making it directly usable as a differentiable component in deep learning.

5. Design Comparison with Prior Differentiable OT Layers

Previous approaches to differentiable OT layers fall into several categories:

  • Lagrangian or dual methods (e.g., Sinkhorn, entropic regularization): Require inner minimax/dual optimization or iterative projection, which can be unstable in high-dimensional or neural settings, and generally only approximately satisfy marginal constraints.
  • Penalty-based or projection methods: Rely on soft constraints or penalizing marginal violations, which can introduce biases or inefficiencies, and can result in drift away from exact matching.
  • Input convex neural network (ICNN) approaches: Focus on OT maps (Monge maps) directly rather than couplings, and may require stronger supervision or additional structure.
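The approximate-marginal behavior of Sinkhorn-type layers can be seen in a few lines. This is a standard Sinkhorn scaling loop included as an illustrative baseline (the example weights, costs, and regularization $\varepsilon$ are arbitrary choices): each iteration rescales rows and then columns, so after finitely many iterations the column marginal holds to rounding error while the row marginal is only approximately satisfied, and the returned plan solves the entropy-regularized rather than the unregularized problem.

```python
import math

def sinkhorn(a, b, cost, eps, n_iters):
    """Entropic OT by Sinkhorn scaling: plan gamma_ij = u_i * K_ij * v_j with
    K = exp(-cost / eps). Each iteration matches rows, then columns."""
    K = [[math.exp(-c / eps) for c in row] for row in cost]
    u, v = [1.0] * len(a), [1.0] * len(b)
    for _ in range(n_iters):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(len(b))) for i in range(len(a))]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(len(a))) for j in range(len(b))]
    return [[u[i] * K[i][j] * v[j] for j in range(len(b))] for i in range(len(a))]

def marginal_errors(gamma, a, b):
    """Largest absolute violation of the row and column marginal constraints."""
    rows = [sum(row) for row in gamma]
    cols = [sum(col) for col in zip(*gamma)]
    return (max(abs(r - ai) for r, ai in zip(rows, a)),
            max(abs(c - bj) for c, bj in zip(cols, b)))

a, b = [0.5, 0.5], [0.2, 0.8]
cost = [[0.0, 1.0], [1.0, 0.0]]
for n_iters in (1, 5, 50):
    row_err, col_err = marginal_errors(sinkhorn(a, b, cost, eps=1.0, n_iters=n_iters), a, b)
    print(f"{n_iters:>2} iterations: row error {row_err:.2e}, column error {col_err:.2e}")
```

The row error shrinks geometrically with more iterations but is only driven to zero in the limit, which is the trade-off between iteration count and constraint accuracy that distinguishes these methods from a construction that preserves marginals by design.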

In contrast, the rectified flow differentiable OT layer:

  • Requires only unconstrained SGD and regression at each step (no inner game).
  • Preserves marginals exactly by construction at every stage.
  • Is robustly differentiable due to the absence of hard constraints or projection steps.
  • Admits any convex cost function $c$, rather than being restricted to quadratic or Euclidean costs.

A summary comparison is given below.

| Aspect | Rectified Flow OT Layer | Dual/Sinkhorn/Projection Methods |
|---|---|---|
| Constraint | Exact marginal matching | Approximate via penalty/projection |
| Training | Unconstrained regression | Minimax or iterative dual/projection |
| Cost flexibility | Any convex cost | Varies (neural dual methods are often quadratic only) |
| Differentiability | Full (ODE + NN) | Sometimes unstable or backprop-heavy |
| Efficiency | High (no inner loop) | Slower (inner optimization) |

6. Architectural Integration and Applications

A differentiable OT layer based on rectified flow can be used within deep learning pipelines as follows:

  • Forward pass:
  1. Obtain batched samples (or features) distributed according to $\pi_0$ and $\pi_1$ (e.g., outputs of encoders or generative models).
  2. Form an initial coupling, for example the independent (product) coupling $\pi_0 \times \pi_1$ obtained by randomly pairing the two batches.
  3. Apply $1$ to $K$ steps of the rectified flow neural ODE, each with its own separately parameterized and trained drift field.
  4. Output the final (transformed) coupling, now closer to the minimal transport cost dictated by the user’s objective.
  • Backward pass:
    • Automatic differentiation flows through all neural ODEs and their loss functions, enabling end-to-end training.
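The forward pass can be sketched end to end in one dimension. The sketch assumes scalar samples, uses random shuffling to realize the product coupling of step 2, and replaces each trained drift network with a hypothetical per-time affine least-squares fit (`rectify_step`), so it illustrates the control flow rather than a production layer. Each rectification fits its drift on the coupling produced by the previous one, and the quadratic cost of every intermediate coupling is recorded.

```python
import random

def product_coupling(x0s, x1s, rng):
    """Step 2: pair sources with a random permutation of targets. The empirical
    marginals of the resulting pairs are exactly the two input samples."""
    shuffled = list(x1s)
    rng.shuffle(shuffled)
    return list(zip(x0s, shuffled))

def rectify_step(pairs, n_steps=100):
    """One rectification: per-time affine fit of X1 - X0 on X_t (hypothetical
    stand-in for a trained drift network), then forward Euler from X0."""
    n, dt = len(pairs), 1.0 / n_steps
    ds = [x1 - x0 for x0, x1 in pairs]
    md = sum(ds) / n
    zs = [x0 for x0, _ in pairs]
    for k in range(n_steps):
        t = k * dt
        xs = [t * x1 + (1.0 - t) * x0 for x0, x1 in pairs]
        mx = sum(xs) / n
        b = (sum((x - mx) * (d - md) for x, d in zip(xs, ds))
             / sum((x - mx) ** 2 for x in xs))
        zs = [z + dt * (md + b * (z - mx)) for z in zs]
    return [(x0, z) for (x0, _), z in zip(pairs, zs)]

def ot_layer_forward(x0s, x1s, n_rectifications, rng):
    """Steps 1-4: initial coupling, K rectifications, final coupling + cost trace."""
    cost = lambda ps: sum((x1 - x0) ** 2 for x0, x1 in ps) / len(ps)
    pairs = product_coupling(x0s, x1s, rng)
    costs = [cost(pairs)]
    for _ in range(n_rectifications):
        pairs = rectify_step(pairs)
        costs.append(cost(pairs))
    return pairs, costs

rng = random.Random(1)
x0s = [rng.gauss(0.0, 1.0) for _ in range(2000)]   # samples of pi_0
x1s = [rng.gauss(3.0, 1.0) for _ in range(2000)]   # samples of pi_1
final_pairs, costs = ot_layer_forward(x0s, x1s, n_rectifications=2, rng=rng)
print([round(c, 3) for c in costs])
```

In this one-dimensional toy the first rectification already produces a straight (monotone) coupling, so the second leaves the cost essentially unchanged; in higher dimensions later steps would typically continue to straighten the flow.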

Practical applications include:

  • Generative models: Transport from prior noise or latent representations to the data distribution, enforcing user-specified costs (beyond standard normalizing flows or diffusion models).
  • Domain adaptation: Aligning feature distributions or embeddings exactly at the distributional level, rather than via approximate penalties.
  • Variational inference: Using OT layers for reparameterization or constructing couplings between prior and posterior distributions.
  • GANs: Structure regularization or alignment between generator and target distributions via OT, with stable, exact marginal matching.
  • Transfer learning and matching tasks: Providing a mechanism to guarantee exact distributional matching across modalities or domains.

7. Computational and Scaling Properties

The key computational advantages are:

  • No inner-loop optimization: all steps are SGD-based regression losses.
  • Neural ODE integration: modern adjoint methods offer scalable and memory-efficient backpropagation.
  • Sample- and batch-level parallelizability: the method is directly applicable in distributed or GPU-accelerated environments.
  • Exact constraint enforcement: no post-hoc correction or projection, so no trade-off between approximation quality and efficiency.
  • Any convex cost: parameterization of $c$ enables adaptation to complex problem requirements.

In high-dimensional settings, the method compares favorably with dual or penalty-based methods, particularly in stability, differentiability, and sample efficiency, according to the empirical evaluations and convergence analysis reported by Liu (2022).


Key Takeaway

The differentiable optimal transport layer enabled by rectified flow achieves a combination of efficiency, flexibility, and theoretical rigor not available in previous approaches. It constructs a neural ODE-based interior path in the space of couplings between distributions, strictly preserving marginals, minimizing user-specified convex transport costs, and supporting stable, end-to-end differentiation—all critical capabilities for modern machine learning systems requiring principled distributional alignment or matching.
