Deep JKO: Neural Methods for Gradient Flows

Updated 2 April 2026
  • Deep JKO is a variational framework that integrates the implicit Euler JKO scheme with deep neural networks to solve high-dimensional PDEs.
  • It approximates gradient flows with neural ODEs, residual networks, and particle methods while retaining the mass conservation and energy dissipation of the JKO scheme.
  • Deep JKO offers scalable solutions for generative modeling, sampling, and inverse problems through structured network designs and efficient optimization.

Deep JKO refers to the integration of the Jordan–Kinderlehrer–Otto (JKO) time-implicit variational discretization for Wasserstein gradient flows with deep neural network approximations. The resulting methods, known as Deep JKO schemes or algorithms, replace classical PDE-based or optimization-theoretic solvers for the JKO subproblems with network parameterizations, most commonly neural ordinary differential equations (neural ODEs), residual networks, or particle methods. This paradigm enables scalable and flexible numerical solution of high-dimensional PDEs, generative modeling, sampling, and learning of dynamics or energies from data, while preserving fundamental structural properties such as mass conservation and energy dissipation.

1. The JKO Scheme and Deep Neural Parameterization

The JKO scheme provides an implicit Euler discretization for gradient flows in the space of probability measures equipped with the 2-Wasserstein metric. For an energy functional $\mathcal{E}$ and time step $\tau$, the JKO update is

$$\rho_{k+1} = \arg\min_{\rho} \left\{ \frac{1}{2\tau} W_2^2(\rho, \rho_k) + \mathcal{E}(\rho) \right\},$$

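where $W_2$ denotes the 2-Wasserstein distance. As $\tau \to 0$, the piecewise-constant interpolation of these iterates converges to the Wasserstein gradient flow $\partial_t \rho = -\nabla_{W_2} \mathcal{E}(\rho)$, i.e. $\partial_t \rho = \nabla \cdot \big(\rho \, \nabla \tfrac{\delta \mathcal{E}}{\delta \rho}\big)$ (Xu et al., 2022, Vidal et al., 2022).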
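A one-dimensional case where the JKO step can be solved exactly makes the definition concrete. Take the Fokker–Planck energy $\mathcal{E}(\rho)=\int \tfrac{x^2}{2}\rho\,dx+\int \rho\log\rho\,dx$, whose Wasserstein gradient flow is the Ornstein–Uhlenbeck equation $\partial_t\rho=\partial_x(x\rho)+\partial_{xx}\rho$. Starting from $\rho_k=\mathcal{N}(m_k,s_k^2)$, the first-order optimality condition of the JKO problem is satisfied by a Gaussian $\mathcal{N}(m_{k+1},s_{k+1}^2)$ with $m_{k+1}=m_k/(1+\tau)$ and $(1+\tau)s_{k+1}^2-s_k s_{k+1}-\tau=0$. This closed form is our own derivation for illustration, not taken from the cited papers, and the function names are illustrative. The sketch iterates the step and compares it with the exact solution $m(t)=m_0e^{-t}$, $s(t)^2=1+(s_0^2-1)e^{-2t}$.

```python
import math


def jko_step_gaussian(m, s, tau):
    """One exact JKO step for E(rho) = int x^2/2 rho + int rho log rho in 1-D.

    Takes the mean m and std s of a Gaussian rho_k and returns those of rho_{k+1}.
    """
    m_next = m / (1.0 + tau)  # implicit Euler for dm/dt = -m
    # Positive root of (1 + tau) s'^2 - s s' - tau = 0.
    s_next = (s + math.sqrt(s * s + 4.0 * tau * (1.0 + tau))) / (2.0 * (1.0 + tau))
    return m_next, s_next


def exact_ou(m0, s0, t):
    """Exact Ornstein-Uhlenbeck solution started from N(m0, s0^2)."""
    return m0 * math.exp(-t), math.sqrt(1.0 + (s0 * s0 - 1.0) * math.exp(-2.0 * t))


def run_jko(m0, s0, tau, t_final):
    """Apply round(t_final / tau) JKO steps of size tau."""
    m, s = m0, s0
    for _ in range(round(t_final / tau)):
        m, s = jko_step_gaussian(m, s, tau)
    return m, s


if __name__ == "__main__":
    m_ex, s_ex = exact_ou(2.0, 0.5, 1.0)
    for tau in (0.1, 0.01, 0.001):
        m, s = run_jko(2.0, 0.5, tau, 1.0)
        print(f"tau={tau:<6} |m - m_exact|={abs(m - m_ex):.2e}  |s - s_exact|={abs(s - s_ex):.2e}")
```

As $\tau$ shrinks, the errors decrease roughly linearly in $\tau$, as expected for an implicit Euler scheme, and $\mathcal{N}(0,1)$ is a fixed point of the step.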

Deep JKO algorithms solve each proximal step by parameterizing the transport map or the velocity field using a neural network (e.g., MLP, ResNet, or attention-based operator), and train these parameters to minimize the discrete JKO objective. This enables tackling high-dimensional, nonlinear, and nonlocal PDEs and generative modeling tasks previously inaccessible to traditional grid- or particle-based methods (Lee et al., 2023, Lee et al., 25 Mar 2026, Georgoulis et al., 2022).

2. Core Algorithmic Concepts: Variational Structure and Network Architecture

Deep JKO algorithms typically combine four structural components:

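  1. JKO variational step: Given the current measure $\rho_k$, parameterize a class of velocity fields $v_\theta$ (or transport maps) and define the flow $z(x,t)$ via a neural ODE or a residual-block discretization. The optimal parameters minimize

$$\mathcal{L}(\theta) = \frac{1}{2\tau} \int_0^1 \int \|v_\theta(z(x,t), t)\|^2 \rho_k(x) \,dx\,dt + \mathcal{E}\big(z(\cdot,1)_\sharp \rho_k\big),$$

subject to the flow constraint $\partial_t z(x,t) = v_\theta(z(x,t),t)$ with $z(x,0)=x$, or its Hamiltonian/KFP generalizations (Vidal et al., 2022, Lee et al., 2023, Lee et al., 25 Mar 2026). By the Benamou–Brenier formula, the first term is an upper bound on $\frac{1}{2\tau} W_2^2\big(z(\cdot,1)_\sharp \rho_k, \rho_k\big)$, with equality for the optimal velocity field, so minimizing $\mathcal{L}$ over a sufficiently rich class recovers the JKO step.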

  2. Neural representation: The velocity field $v_\theta$ or transport map $T_\theta$ may be parameterized by an MLP, a ResNet, or an attention-based operator.
  3. Density computation and change of variables: Densities after each JKO step are computed with the neural ODE's instantaneous change-of-variables formula, which integrates the trace of the velocity Jacobian along the flow; in high dimension this trace is approximated with stochastic estimators such as Hutchinson's (Vidal et al., 2022, Hertrich et al., 2024). The pushforward density is updated as $\log \rho_{k+1}\big(z(x,1)\big) = \log \rho_k(x) - \int_0^1 \nabla \cdot v_\theta\big(z(x,t), t\big)\,dt$.
  4. Training loop and optimization: For each JKO time step, network parameters are optimized with stochastic gradient methods (Adam, SGD), either for the full step or block-wise. For population-learning problems (e.g., iJKOnet), adversarial min–max or min–min optimization jointly updates potentials and transport maps (Persiianov et al., 2 Jun 2025).
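The variational step and training loop can be sketched with particles. This toy assumes a hypothetical one-parameter velocity field $v_\theta(z,t)=\theta$ (a rigid translation) in place of a neural network, and the potential energy $\mathcal{E}(\rho)=\int \tfrac12 (x-\mu)^2\rho\,dx$. Samples $x_i \sim \rho_k$ are pushed through forward-Euler steps of $\dot z = v_\theta(z,t)$ while the kinetic term is accumulated, and plain gradient descent on finite-difference gradients stands in for Adam/SGD with automatic differentiation. For this toy the loss is $\theta^2/(2\tau) + \tfrac12\,\overline{(x+\theta-\mu)^2}$, so the result can be checked against the minimizer $\theta^\star = \tau(\mu - \bar x)/(1+\tau)$.

```python
import numpy as np


def flow_and_kinetic(theta, x, velocity, n_steps=20):
    """Forward-Euler integration of dz/dt = velocity(theta, z, t) on [0, 1].

    Returns z(x, 1) and a sample estimate of int_0^1 E_x |v(z(x, t), t)|^2 dt
    (left-endpoint rule on the Euler grid).
    """
    dt = 1.0 / n_steps
    z = np.array(x, dtype=float)
    kinetic = 0.0
    for n in range(n_steps):
        v = velocity(theta, z, n * dt)
        kinetic += dt * np.mean(v ** 2)
        z = z + dt * v
    return z, kinetic


def jko_loss(theta, x, tau, velocity, energy, n_steps=20):
    """Discrete JKO objective: kinetic / (2 tau) + energy of the pushed samples."""
    z1, kinetic = flow_and_kinetic(theta, x, velocity, n_steps)
    return kinetic / (2.0 * tau) + energy(z1)


def translation_velocity(theta, z, t):
    # Hypothetical one-parameter field: every particle moves with speed theta.
    return np.full_like(z, theta)


def make_potential_energy(mu):
    # Sample estimate of E(rho) = int (x - mu)^2 / 2 rho(x) dx.
    return lambda z: 0.5 * np.mean((z - mu) ** 2)


def train(x, tau, mu, lr=0.1, n_iters=200, eps=1e-4):
    """Gradient descent on theta; central differences stand in for autodiff."""
    energy = make_potential_energy(mu)
    theta = 0.0
    for _ in range(n_iters):
        grad = (jko_loss(theta + eps, x, tau, translation_velocity, energy)
                - jko_loss(theta - eps, x, tau, translation_velocity, energy)) / (2.0 * eps)
        theta -= lr * grad
    return theta


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=1000)  # samples from rho_k
    tau, mu = 0.5, 3.0
    theta = train(x, tau, mu)
    print(theta, tau * (mu - x.mean()) / (1 + tau))
```

Swapping `translation_velocity` for a network and the finite-difference gradient for autodiff gives the general structure; all names here are hypothetical.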

3. Advanced Architectures and Learning Strategies

Several variants extend the Deep JKO paradigm:

  • Blockwise JKO flows: JKO-iFlow uses a residual-block architecture in which blocks are trained one at a time, which keeps memory usage scalable. An adaptive time reparameterization makes each block perform approximately the same amount of transport, and progressive block refinement improves accuracy and invertibility (Xu et al., 2022).
  • Self-supervised neural JKO operators: Rather than solving each JKO subproblem in sequence, one can fit a single neural operator that maps an input density directly to its JKO minimizer, trained in a performative “learn-to-evolve” loop (Feng et al., 9 Jan 2026). The loop alternates between generating trajectories with the current operator and refitting the operator on those trajectories, implicitly bootstrapping toward the true JKO flow.
  • Particle- and kinetic-based Deep JKO: For kinetic PDEs such as Vlasov–Fokker–Planck, the JKO step is extended to phase space, and the velocity field parameterized by deep networks acts as a control in a kinetic ODE driven by both Hamiltonian (conservative) and dissipative (JKO) structure (Lee et al., 25 Mar 2026).
  • Inverse optimization and adversarial learning: iJKOnet reconstructs unknown energies (potentials, interactions) from unpaired time snapshots of evolving densities by adversarially minimizing the JKO gap over candidate energies and transport maps. The authors report both theoretical guarantees and strong empirical results for system identification and population-dynamics learning (Persiianov et al., 2 Jun 2025).
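The equal-transport idea behind blockwise training can be illustrated on a case with a closed-form JKO step: for $\mathcal{E}(\rho)=\int \tfrac{x^2}{2}\rho\,dx+\int\rho\log\rho\,dx$ in one dimension, a Gaussian $\mathcal{N}(m,s^2)$ is mapped to $\mathcal{N}\big(m/(1+\tau), s'^2\big)$ with $(1+\tau)s'^2 - s s' - \tau = 0$, and the 1-D Gaussian distance is $W_2=\sqrt{(m_1-m_2)^2+(s_1-s_2)^2}$. The rule below is a hypothetical stand-in, not the reparameterization used in JKO-iFlow: for each block it bisects on the step size $\tau_k$ until that block moves the density by a target $W_2$ distance.

```python
import math


def jko_step_gaussian(m, s, tau):
    # Exact JKO step for E(rho) = int x^2/2 rho + int rho log rho, Gaussian rho.
    m_next = m / (1.0 + tau)
    s_next = (s + math.sqrt(s * s + 4.0 * tau * (1.0 + tau))) / (2.0 * (1.0 + tau))
    return m_next, s_next


def w2_gaussian_1d(m1, s1, m2, s2):
    # W2 distance between N(m1, s1^2) and N(m2, s2^2) in one dimension.
    return math.hypot(m1 - m2, s1 - s2)


def step_for_target(m, s, target, tau_max=1e3, iters=60):
    """Hypothetical rule: bisect on tau so one block moves rho by `target` in W2."""
    def moved(tau):
        return w2_gaussian_1d(m, s, *jko_step_gaussian(m, s, tau))

    if moved(tau_max) <= target:  # remaining distance to equilibrium is below target
        return tau_max
    lo, hi = 0.0, tau_max  # moved(tau) increases monotonically with tau
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if moved(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def equal_transport_schedule(m0, s0, target, n_blocks):
    """Per-block step sizes and the W2 distance actually moved by each block."""
    m, s, taus, moves = m0, s0, [], []
    for _ in range(n_blocks):
        tau = step_for_target(m, s, target)
        m_next, s_next = jko_step_gaussian(m, s, tau)
        moves.append(w2_gaussian_1d(m, s, m_next, s_next))
        taus.append(tau)
        m, s = m_next, s_next
    return taus, moves


if __name__ == "__main__":
    taus, moves = equal_transport_schedule(3.0, 0.3, target=0.5, n_blocks=4)
    for k, (tau, d) in enumerate(zip(taus, moves)):
        print(f"block {k}: tau = {tau:.4f}, W2 moved = {d:.4f}")
```

Because the dynamics slow down near equilibrium, equal transport per block generally calls for unequal step sizes, which is the motivation for reparameterizing time rather than fixing $\tau$.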

4. Convergence Guarantees and Structural Properties

Deep JKO methods inherit several theoretical guarantees from the classical JKO scheme:

  • Convergence: For convex energies (e.g., the KL divergence), the sequence $(\rho_k)$ generated by the fixed-step implicit JKO scheme converges weakly to the minimizer of $\mathcal{E}$, with explicit convergence rates for KL flows (Vidal et al., 2022).
  • Consistency: Provided the network parameterization and ODE solver can solve each subproblem to global optimality and with vanishing error, Deep JKO converges to the true continuous Wasserstein gradient flow (Vidal et al., 2022, Lee et al., 2023, Lee et al., 25 Mar 2026).
  • Mass conservation and dissipation: The variational structure ensures that solutions conserve total mass and dissipate the chosen energy functional at each step (Lee et al., 2023, Lee et al., 25 Mar 2026).
  • Approximation error: Practical networks leave nonzero residuals because of limited capacity and imperfect optimization, but the empirical error can often be pushed below the level of statistical fluctuations (e.g., mode-weight MSE and log-normalizer error in high-dimensional sampling; Hertrich et al., 2024).
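The dissipation property follows in one line from the JKO definition, assuming each step is solved exactly: $\rho_k$ is itself an admissible candidate and $W_2(\rho_k,\rho_k)=0$, so

$$
\begin{aligned}
\mathcal{E}(\rho_{k+1}) + \frac{1}{2\tau} W_2^2(\rho_{k+1}, \rho_k)
  &\le \mathcal{E}(\rho_k) + \frac{1}{2\tau} W_2^2(\rho_k, \rho_k) = \mathcal{E}(\rho_k), \\
\text{hence}\quad \mathcal{E}(\rho_{k+1})
  &\le \mathcal{E}(\rho_k) - \frac{1}{2\tau} W_2^2(\rho_{k+1}, \rho_k) \le \mathcal{E}(\rho_k).
\end{aligned}
$$

Summing over $k$ bounds the total squared transport, $\sum_{k<N} \frac{1}{2\tau} W_2^2(\rho_{k+1},\rho_k) \le \mathcal{E}(\rho_0) - \mathcal{E}(\rho_N)$. Mass conservation is built into pushforward parameterizations, since $(T_\sharp\rho_k)(\mathbb{R}^d) = \rho_k\big(T^{-1}(\mathbb{R}^d)\big) = \rho_k(\mathbb{R}^d)$ for any measurable map $T$. With inexact network solutions the first inequality can fail, which is the source of the approximation error discussed in the last bullet.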

5. Practical Implementation and Empirical Performance

Choices in implementation affect scalability, accuracy, and efficiency:

  • Network architectures: Shallow ResNets (width 8–16, depth 2–4), MLPs (widths up to 1024), and attention-based blocks for set- or particle-based data. Convolutional (image) and graph (ChebNet) variants address structured datasets (Xu et al., 2022).
  • ODE integrators: Dormand–Prince (RK45), classical RK4, or custom symplectic or particle-in-cell (PIC) solvers for kinetic equations. Solver tolerances are set separately for sampling and for density-evolution tasks.
  • Regularization: An ODE trace penalty, together with additional regularization terms, is used to enforce convexity and to control overfitting or variance.
  • Hyperparameters: The JKO step size $\tau$ acts as the time step; results are reported to be insensitive to it apart from the trade-off between the difficulty of each step and the total number of iterations (Vidal et al., 2022).
  • Computational cost: Memory usage of 1–2 GB per block, batch sizes up to 1024, and per-block training times reported in seconds; total runtime scales with network size and problem dimension (Xu et al., 2022, Hertrich et al., 2024).
  • Parallelization: Blockwise or per-time-step architectures allow distributed training, while attention-based operators afford set-level parallel pushforward.
  • Numerical performance: Reported results include state-of-the-art energy distances in high-dimensional density modeling and sampling, low relative errors in PDE solutions, and strong generalization in inverse population dynamics (Persiianov et al., 2 Jun 2025, Hertrich et al., 2024, Georgoulis et al., 2022).
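The energy distance used as an evaluation metric can be estimated directly from samples. A minimal NumPy sketch of the plug-in (V-statistic) estimator of $D^2(P,Q) = 2\,\mathbb{E}\|X-Y\| - \mathbb{E}\|X-X'\| - \mathbb{E}\|Y-Y'\|$ follows; whether the cited papers report $D$ or $D^2$, and with which estimator, is not stated here, so treat it as illustrative.

```python
import numpy as np


def _as_points(a):
    """Return samples as an (n, d) float array; 1-D input is treated as n scalars."""
    a = np.asarray(a, dtype=float)
    return a.reshape(-1, 1) if a.ndim == 1 else a


def _mean_pairwise_distance(a, b):
    # Mean Euclidean distance over all pairs (a_i, b_j), diagonal pairs included.
    diff = a[:, None, :] - b[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1)).mean()


def energy_distance_sq(x, y):
    """Plug-in (V-statistic) estimate of D^2 = 2E|X-Y| - E|X-X'| - E|Y-Y'|."""
    x, y = _as_points(x), _as_points(y)
    return (2.0 * _mean_pairwise_distance(x, y)
            - _mean_pairwise_distance(x, x)
            - _mean_pairwise_distance(y, y))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    p = rng.normal(size=(500, 2))
    q_same = rng.normal(size=(500, 2))
    q_shifted = rng.normal(loc=1.0, size=(500, 2))
    print(energy_distance_sq(p, q_same), energy_distance_sq(p, q_shifted))
```

The pairwise computation needs $O(nm)$ memory, so large sample sets are usually subsampled or processed in chunks.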

6. Applications and Scope of Deep JKO

Deep JKO frameworks have been deployed in diverse scientific and machine learning contexts:

  • Generative modeling and normalizing flows: Deep JKO yields scalable, hyperparameter-robust density models via neural ODEs, with improved fitting metrics and reduced tuning overhead compared to single-shot CNFs and diffusion models (Vidal et al., 2022, Xu et al., 2022).
  • Unnormalized density sampling: Importance-corrected Deep JKO combines CNFs with rejection-resampling to guarantee i.i.d. samples and accurate normalization, outperforming MALA, HMC, and diffusion samplers in high dimension and multimodal regimes (Hertrich et al., 2024).
  • Learning from evolutionary data: iJKOnet reconstructs unknown potentials/interactions governing observed population dynamics (e.g., cell population flows), recovering model parameters and trajectories from discrete sample snapshots (Persiianov et al., 2 Jun 2025).
  • PDE solvers: Deep JKO approaches solve Fokker–Planck, porous-medium, kinetic Fokker–Planck, and multicomponent aggregation equations in high dimension with accuracy and efficiency unobtainable by traditional grid or particle methods (Georgoulis et al., 2022, Lee et al., 2023, Lee et al., 25 Mar 2026).
  • Bayesian inverse problems: Kalman–Wasserstein gradient flows for Bayesian posterior exploration can be implemented via Deep JKO Lagrangian particle formulations, with empirical convergence to the posterior (Lee et al., 2023).
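The importance-correction idea behind unnormalized density sampling can be illustrated independently of any flow. Given samples from a proposal $q$ with tractable density (in Deep JKO, the CNF output) and an unnormalized target $e^{-U}$, the weights $w_i = e^{-U(x_i)}/q(x_i)$ give an unbiased estimate $\hat Z = \frac1n \sum_i w_i$ of the normalizing constant, and resampling in proportion to $w_i$ reweights the sample toward the target. The sketch uses a Gaussian proposal as a stand-in for a trained flow and a Gaussian target so the answer $\log Z = \log\sqrt{2\pi}$ is known. It shows plain importance resampling, a simpler relative of the rejection-resampling scheme of Hertrich et al. (2024) that does not by itself yield exactly i.i.d. samples.

```python
import numpy as np


def log_normalizer_and_resample(log_target, log_q, samples, rng):
    """Importance-sampling estimate of log Z plus multinomial resampling.

    log_target: unnormalized log-density -U(x); log_q: proposal log-density.
    """
    log_w = log_target(samples) - log_q(samples)
    shift = log_w.max()  # log-sum-exp shift for numerical stability
    w = np.exp(log_w - shift)
    log_z = shift + np.log(w.mean())
    idx = rng.choice(len(samples), size=len(samples), p=w / w.sum())
    return log_z, samples[idx]


def gaussian_logpdf(x, scale):
    # Log-density of N(0, scale^2).
    return -0.5 * (x / scale) ** 2 - np.log(scale * np.sqrt(2.0 * np.pi))


def log_std_normal_unnormalized(x):
    # Unnormalized target exp(-x^2 / 2); its true log Z is log(sqrt(2 pi)).
    return -0.5 * x ** 2


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    scale_q = 1.5  # proposal N(0, 1.5^2), standing in for a trained flow
    x = rng.normal(scale=scale_q, size=100_000)
    log_z, resampled = log_normalizer_and_resample(
        log_std_normal_unnormalized, lambda t: gaussian_logpdf(t, scale_q), x, rng)
    print(log_z, 0.5 * np.log(2.0 * np.pi), resampled.std())
```

When the proposal matches the target exactly, all weights are equal and the estimate of $\log Z$ is exact; a heavier-tailed proposal, as here, keeps the weight variance finite.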

7. Limitations, Open Challenges, and Future Directions

While offering major practical and theoretical advances, Deep JKO methods present challenges:

  • Optimization landscape: Each JKO subproblem involves nonconvex optimization for network parameters, potentially requiring careful tuning, initialization, and stabilization (e.g., gradient clipping, learning rate schedules) (Lee et al., 2023, Lee et al., 25 Mar 2026).
  • Approximation and generalization: Finite network capacity, sampling noise, and highly nonconvex/degenerate energies may limit long-horizon accuracy or generalization. Statistical theory for learned JKO operators remains an open topic (Feng et al., 9 Jan 2026).
  • Scalability: Extremely high dimensions, long time horizons, or stiff dynamics may strain memory and computational resources, especially for kinetic or blockwise schemes (Hertrich et al., 2024, Lee et al., 25 Mar 2026).
  • Structural preservation: While JKO structure guarantees mass conservation and energy dissipation, full preservation of higher-order invariants (e.g., symplecticity in Hamiltonian systems) is only approximate due to discretization and optimization error (Lee et al., 25 Mar 2026).
  • Extensions: Deep JKO methods can plausibly be extended to other proximal operators (e.g., Moreau–Yosida), composite splitting schemes, or time-adaptive steps. Applications to mean-field games, data-driven prior learning, and hybrid variational–score models are active research frontiers (Feng et al., 9 Jan 2026).

Deep JKO thus constitutes a foundational bridge between measure-theoretic variational principles and scalable data-driven modeling, with a rapidly expanding range of theoretical and applied impacts across machine learning, PDEs, Bayesian inference, and population dynamics (Vidal et al., 2022, Xu et al., 2022, Lee et al., 25 Mar 2026, Hertrich et al., 2024, Persiianov et al., 2 Jun 2025, Feng et al., 9 Jan 2026, Georgoulis et al., 2022, Lee et al., 2023).
