Deep JKO: Neural Methods for Gradient Flows
- Deep JKO is a variational framework that integrates the implicit Euler JKO scheme with deep neural networks to solve high-dimensional PDEs.
- It employs neural ODEs, residual networks, and particle methods to preserve mass conservation and energy dissipation while approximating gradient flows.
- Deep JKO offers scalable solutions for generative modeling, sampling, and inverse problems through structured network designs and efficient optimization.
Deep JKO refers to the integration of the Jordan–Kinderlehrer–Otto (JKO) time-implicit variational discretization for Wasserstein gradient flows with deep neural network approximations. The resulting methodologies—known as Deep JKO schemes or algorithms—replace classical PDE-based or optimization-theoretic solvers for JKO subproblems with network parameterizations, most commonly neural ordinary differential equations (neural ODEs), residual networks, or particle methods. This paradigm enables scalable and flexible numerical solutions of high-dimensional PDEs, generative modeling, sampling algorithms, and learning of dynamics or energies from data, while preserving fundamental structural properties such as mass conservation and energy dissipation.
1. The JKO Scheme and Deep Neural Parameterization
The JKO scheme provides an implicit Euler discretization for gradient flows in the space of probability measures equipped with the 2-Wasserstein metric. For a chosen energy functional $F$, the JKO update with time step $\tau > 0$ is
$$\rho_{k+1} = \operatorname*{arg\,min}_{\rho} \left\{ F(\rho) + \frac{1}{2\tau} W_2^2(\rho, \rho_k) \right\},$$
where $W_2$ denotes the Wasserstein-2 distance. In the limit $\tau \to 0$, these iterates converge to the Wasserstein gradient flow of $F$ (Xu et al., 2022, Vidal et al., 2022).
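As a minimal sketch of the implicit Euler character of the JKO update, consider the potential energy $F(\rho) = \int V \, d\rho$ with $V(x) = x^2/2$ (a toy choice made here for illustration, not taken from the cited works). For a convex potential the JKO minimizer is the pushforward of $\rho_k$ under the proximal map of $\tau V$, which for this $V$ is $x \mapsto x/(1+\tau)$. The helper names below are hypothetical.

```python
import numpy as np

def jko_step_quadratic(x, tau):
    """One JKO step for F(rho) = integral of V d(rho), with V(x) = x^2 / 2.

    For a potential energy the JKO minimizer is the pushforward of rho_k
    under the proximal map of tau * V, which here is x -> x / (1 + tau).
    """
    return x / (1.0 + tau)

def run_jko(x0, tau, t_final):
    """Apply JKO steps up to time t_final (assumed to be a multiple of tau)."""
    x = x0.copy()
    for _ in range(int(round(t_final / tau))):
        x = jko_step_quadratic(x, tau)
    return x

rng = np.random.default_rng(0)
x0 = rng.normal(size=1000)        # particles sampled from rho_0
exact = x0 * np.exp(-1.0)         # exact gradient flow x' = -x evaluated at t = 1

# Largest particle-wise error at t = 1 for decreasing step sizes.
errors = {tau: np.max(np.abs(run_jko(x0, tau, 1.0) - exact))
          for tau in (0.1, 0.01, 0.001)}
```

In this toy setting the error shrinks as $\tau$ decreases, consistent with the convergence of the iterates to the continuous gradient flow.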
Deep JKO algorithms solve each proximal step by parameterizing the transport map or the velocity field using a neural network (e.g., MLP, ResNet, or attention-based operator), and train these parameters to minimize the discrete JKO objective. This enables tackling high-dimensional, nonlinear, and nonlocal PDEs and generative modeling tasks previously inaccessible to traditional grid- or particle-based methods (Lee et al., 2023, Lee et al., 25 Mar 2026, Georgoulis et al., 2022).
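The idea of training a parameterized map against the discrete JKO objective can be sketched in miniature. In the example below a two-parameter affine map $T(x) = a x + b$ stands in for the neural network, the energy is the toy potential energy with $V(x) = x^2/2$, and the objective is the particle average of $V(T(x)) + |T(x) - x|^2 / (2\tau)$. All of these choices, and the function name, are assumptions made for illustration; practical Deep JKO methods use far richer parameterizations and stochastic optimizers.

```python
import numpy as np

def jko_objective_grad(a, b, x, tau):
    """Objective and gradient for T(x) = a*x + b with V(x) = x^2 / 2."""
    y = a * x + b                     # pushed-forward particles T(x)
    move = y - x                      # displacement T(x) - x
    value = np.mean(0.5 * y**2) + np.mean(move**2) / (2.0 * tau)
    dy = y + move / tau               # derivative of the integrand w.r.t. T(x)
    return value, np.mean(dy * x), np.mean(dy)

rng = np.random.default_rng(0)
x = rng.normal(size=5000)             # particles representing rho_k
tau, lr = 0.1, 0.05
a, b = 1.0, 0.0                       # start from the identity map

# Plain gradient descent on the two map parameters.
for _ in range(500):
    value, grad_a, grad_b = jko_objective_grad(a, b, x, tau)
    a -= lr * grad_a
    b -= lr * grad_b
```

Because the pointwise minimizer $x/(1+\tau)$ lies in the affine class, gradient descent should recover $a \approx 1/(1+\tau)$ and $b \approx 0$ here. With a neural network the same loop structure applies, but the optimization is nonconvex.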
2. Core Algorithmic Concepts: Variational Structure and Network Architecture
Deep JKO algorithms consistently combine three structural components:
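- JKO variational step: For the current measure $\rho_k$, energy $F$, and step size $\tau$, parameterize a class of maps $T_\theta$ or velocity fields $v_\theta$ and define pushforwards or flows via a neural ODE or residual block discretization. The optimal parameters minimize
$$\min_{\theta} \; F\big((T_\theta)_\# \rho_k\big) + \frac{1}{2\tau} \int \|T_\theta(x) - x\|^2 \, d\rho_k(x).$$
Constraints such as the continuity equation $\partial_t \rho + \nabla \cdot (\rho v_\theta) = 0$ or their Hamiltonian/KFP generalizations are imposed (Vidal et al., 2022, Lee et al., 2023, Lee et al., 25 Mar 2026).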
- Neural representation: The velocity $v_\theta$ or map $T_\theta$ may be
  - a time-dependent MLP or ResNet (for general Euclidean datasets) (Xu et al., 2022, Hertrich et al., 2024),
  - a residual block (for block-wise JKO iteration) (Xu et al., 2022),
  - a scalar potential $\phi_\theta$ whose gradient defines the velocity field, enforcing a conservative flow (Lee et al., 2023),
  - a Transformer or attention-based set operator for whole-population updates (Feng et al., 9 Jan 2026).
- Density computation and change-of-variables: Densities after each JKO step are computed via the neural ODE's change-of-variables formula, using the trace of the velocity Jacobian along the flow. Efficient estimators, including Hutchinson trace estimators, are employed in high dimensions (Vidal et al., 2022, Hertrich et al., 2024). Pushforward densities are updated as $\rho_{k+1} = (T_\theta)_\# \rho_k$.
- Training loop and optimization: For each JKO time step, network parameters are optimized via stochastic gradient methods (Adam, SGD), using either full-step or block-wise approaches. For population learning problems (e.g., iJKOnet), adversarial min–max or min–min optimization is used to jointly update potentials and transport maps (Persiianov et al., 2 Jun 2025).
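The Hutchinson trace estimator mentioned above can be sketched as follows. Along a flow $\dot{x} = v(x)$ the log-density evolves as $\frac{d}{dt} \log \rho(x(t)) = -\nabla \cdot v(x(t))$, and the divergence equals the trace of the velocity Jacobian $J$, which can be estimated as the average of $z^\top J z$ over random probes $z$. The sketch assumes Rademacher probes, a central finite difference in place of automatic differentiation for the Jacobian-vector product, and a linear test field; the function name is hypothetical.

```python
import numpy as np

def hutchinson_divergence(velocity, x, n_probes, rng, eps=1e-4):
    """Estimate div v(x) = tr(J_v(x)) as the mean of z^T J z over Rademacher z."""
    d = x.shape[0]
    total = 0.0
    for _ in range(n_probes):
        z = rng.choice([-1.0, 1.0], size=d)
        # Central finite difference approximates the Jacobian-vector product J z.
        jz = (velocity(x + eps * z) - velocity(x - eps * z)) / (2.0 * eps)
        total += z @ jz
    return total / n_probes

rng = np.random.default_rng(0)
d = 20
A = rng.normal(size=(d, d)) / np.sqrt(d)

def velocity(x):
    """Linear test field v(x) = A x, whose divergence is exactly tr(A)."""
    return A @ x

x = rng.normal(size=d)
estimate = hutchinson_divergence(velocity, x, n_probes=5000, rng=rng)
exact = np.trace(A)
```

Each probe costs one Jacobian-vector product rather than a full Jacobian, which is what makes the estimator attractive in high dimensions. The price is Monte Carlo noise that decays with the number of probes.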
3. Advanced Architectures and Learning Strategies
Several variants extend the Deep JKO paradigm:
- Blockwise JKO flows: JKO-iFlow introduces a residual block architecture, allowing per-block training with adaptability and scalable memory usage. An adaptive time reparameterization ensures each block performs approximately equal transport in $W_2$ distance, and progressive block refinement improves accuracy and invertibility (Xu et al., 2022).
- Self-supervised neural JKO operators: Rather than sequentially solving each JKO subproblem, one can fit a single operator mapping input densities directly to the JKO minimizer, trained in a performative “learn-to-evolve” loop (Feng et al., 9 Jan 2026). This alternates trajectory generation (using the current operator) with operator fitting on generated trajectories, implicitly bootstrapping toward the true JKO flow.
- Particle- and kinetic-based Deep JKO: For kinetic PDEs such as Vlasov–Fokker–Planck, the JKO step is extended to phase space, and the velocity field parameterized by deep networks acts as a control in a kinetic ODE driven by both Hamiltonian (conservative) and dissipative (JKO) structure (Lee et al., 25 Mar 2026).
- Inverse optimization and adversarial learning: iJKOnet reconstructs unknown energies (potentials, interactions) from unpaired time snapshots of evolving densities by adversarially minimizing the JKO gap over energy candidates and transport maps. This yields strong empirical and theoretical guarantees for system identification and population dynamics learning (Persiianov et al., 2 Jun 2025).
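The notion of blocks that each perform roughly equal Wasserstein transport can be illustrated with a simplified one-dimensional sketch. This is not the JKO-iFlow procedure itself. It assumes 1D particles (where $W_2$ between equal-size samples is obtained by sorted matching), a toy contracting trajectory, and hypothetical helper names, and it only shows how block end times could be chosen from cumulative $W_2$ arc length.

```python
import numpy as np

def w2_1d(x, y):
    """Exact W2 distance between two equal-size 1D samples (sorted matching)."""
    return np.sqrt(np.mean((np.sort(x) - np.sort(y)) ** 2))

def equal_w2_times(snapshots, times, n_blocks):
    """Pick block end times so each block covers roughly equal W2 arc length."""
    steps = [w2_1d(a, b) for a, b in zip(snapshots[:-1], snapshots[1:])]
    arc = np.concatenate([[0.0], np.cumsum(steps)])      # cumulative arc length
    targets = np.linspace(0.0, arc[-1], n_blocks + 1)    # equal arc-length marks
    return np.interp(targets, arc, times)                # invert arc length -> time

rng = np.random.default_rng(0)
x0 = rng.normal(size=2000)

def flow(t):
    """Toy trajectory: particles contracting exponentially toward the origin."""
    return x0 * np.exp(-t)

fine_times = np.linspace(0.0, 3.0, 301)
block_times = equal_w2_times([flow(t) for t in fine_times], fine_times, n_blocks=5)

# W2 movement actually performed by each block under the new time grid.
block_moves = [w2_1d(flow(a), flow(b))
               for a, b in zip(block_times[:-1], block_times[1:])]
```

In this example the early blocks cover short time intervals, where particles move quickly, and later blocks cover longer ones, so the per-block $W_2$ movement comes out nearly uniform.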
4. Convergence Guarantees and Structural Properties
Deep JKO methods inherit several theoretical guarantees from the classical JKO scheme:
- Convergence: For convex energies (e.g., KL divergence), the sequence $\{\rho_k\}$ generated by any fixed-step implicit JKO scheme converges weakly to the minimizer of $F$, with explicit convergence rates available for KL flows (Vidal et al., 2022).
- Consistency: Provided the network parameterization and ODE solver can solve each subproblem to global optimality and with vanishing error, Deep JKO converges to the true continuous Wasserstein gradient flow (Vidal et al., 2022, Lee et al., 2023, Lee et al., 25 Mar 2026).
- Mass conservation and dissipation: The variational structure ensures that solutions conserve total mass and dissipate the chosen energy functional at each step (Lee et al., 2023, Lee et al., 25 Mar 2026).
- Approximation error: Realistic networks yield nonzero residuals due to capacity or optimization limitations, but empirical error, measured for example by mode-weight MSE and log-normalizer error in high-dimensional sampling, can often be reduced below statistical fluctuation thresholds (Hertrich et al., 2024).
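The energy dissipation property follows from a one-line comparison argument. Writing $F$ for the energy, $\rho_k$ for the JKO iterates, and $\tau$ for the step size, the minimizer $\rho_{k+1}$ does at least as well as the candidate $\rho = \rho_k$:

$$F(\rho_{k+1}) + \frac{1}{2\tau} W_2^2(\rho_{k+1}, \rho_k) \;\le\; F(\rho_k) + \frac{1}{2\tau} W_2^2(\rho_k, \rho_k) \;=\; F(\rho_k).$$

Hence $F(\rho_{k+1}) \le F(\rho_k)$, and summing over steps gives

$$\sum_{k=0}^{K-1} W_2^2(\rho_{k+1}, \rho_k) \;\le\; 2\tau \big( F(\rho_0) - F(\rho_K) \big).$$

The same comparison applies to a parameterized map class, provided the class contains the identity map and the optimizer returns an objective value no worse than that of the identity.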
5. Practical Implementation and Empirical Performance
Choices in implementation affect scalability, accuracy, and efficiency:
- Network architectures: Shallow ResNets (width 8–16, depth 2–4), MLPs (widths up to 1024), and attention-based blocks for set- or particle-based data. Special convolutional (image) and graph (ChebNet) variants address structured datasets (Xu et al., 2022).
- ODE integrators: Dormand–Prince (RK45), classical RK4, or custom symplectic/PIC solvers for kinetic equations. Solver tolerances are set separately for sampling and for density evolution.
- Regularization: An ODE trace penalty, together with a further regularization term, is used to enforce convexity and control overfitting or variance.
- Hyperparameters: The JKO step size $\tau$ acts as the time step; performance shows little sensitivity to it beyond the trade-off between the difficulty of each step and the total number of iterations (Vidal et al., 2022).
- Computational cost: Memory usage per block is 1–2 GB with batch sizes up to 1024; per-block training time and total runtime scale with network size and problem dimension (Xu et al., 2022, Hertrich et al., 2024).
- Parallelization: Blockwise or per-time-step architectures allow distributed training, while attention-based operators afford set-level parallel pushforward.
- Numerical performance: Reported results include state-of-the-art error in high-dimensional density modeling and sampling (measured by energy distance), low relative error in PDE solutions, and strong generalization in inverse population dynamics (Persiianov et al., 2 Jun 2025, Hertrich et al., 2024, Georgoulis et al., 2022).
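Energy distance, one of the sample-quality metrics named above, can be computed directly from two sample sets. The sketch below uses the standard definition $2\,\mathbb{E}\|X - Y\| - \mathbb{E}\|X - X'\| - \mathbb{E}\|Y - Y'\|$ in its plain V-statistic form (all pairs, including self-pairs). The function name, sample sizes, and test distributions are illustrative assumptions, and the cited works may use a different estimator or normalization.

```python
import numpy as np

def energy_distance_sq(x, y):
    """Squared energy distance 2 E|X-Y| - E|X-X'| - E|Y-Y'| (V-statistic)."""
    def mean_pairwise(a, b):
        # Mean Euclidean distance over all pairs (a_i, b_j).
        diff = a[:, None, :] - b[None, :, :]
        return np.mean(np.linalg.norm(diff, axis=-1))
    return 2.0 * mean_pairwise(x, y) - mean_pairwise(x, x) - mean_pairwise(y, y)

rng = np.random.default_rng(0)
reference = rng.normal(size=(300, 5))        # samples from the target law
same_law = rng.normal(size=(300, 5))         # independent samples, same law
shifted = rng.normal(size=(300, 5)) + 1.0    # samples from a shifted law

d_self = energy_distance_sq(reference, reference)
d_same = energy_distance_sq(reference, same_law)
d_shift = energy_distance_sq(reference, shifted)
```

The pairwise computation costs $O(n^2)$ memory and time in the number of samples, so for large sample sets a batched or subsampled variant would be needed.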
6. Applications and Scope of Deep JKO
Deep JKO frameworks have been deployed in diverse scientific and machine learning contexts:
- Generative modeling and normalizing flows: Deep JKO yields scalable, hyperparameter-robust density models via neural ODEs, with improved fitting metrics and reduced tuning overhead compared to single-shot CNFs and diffusion models (Vidal et al., 2022, Xu et al., 2022).
- Unnormalized density sampling: Importance-corrected Deep JKO combines CNFs with rejection-resampling to guarantee i.i.d. samples and accurate normalization, outperforming MALA, HMC, and diffusion samplers in high dimension and multimodal regimes (Hertrich et al., 2024).
- Learning from evolutionary data: iJKOnet reconstructs unknown potentials/interactions governing observed population dynamics (e.g., cell population flows), recovering model parameters and trajectories from discrete sample snapshots (Persiianov et al., 2 Jun 2025).
- PDE solvers: Deep JKO approaches solve Fokker–Planck, porous-medium, kinetic Fokker–Planck, and multicomponent aggregation equations in high dimension with accuracy and efficiency unobtainable by traditional grid or particle methods (Georgoulis et al., 2022, Lee et al., 2023, Lee et al., 25 Mar 2026).
- Bayesian inverse problems: Kalman–Wasserstein gradient flows for Bayesian posterior exploration can be implemented via Deep JKO Lagrangian particle formulations, with empirical convergence to the posterior (Lee et al., 2023).
7. Limitations, Open Challenges, and Future Directions
While offering major practical and theoretical advances, Deep JKO methods present challenges:
- Optimization landscape: Each JKO subproblem involves nonconvex optimization for network parameters, potentially requiring careful tuning, initialization, and stabilization (e.g., gradient clipping, learning rate schedules) (Lee et al., 2023, Lee et al., 25 Mar 2026).
- Approximation and generalization: Finite network capacity, sampling noise, and highly nonconvex/degenerate energies may limit long-horizon accuracy or generalization. Statistical theory for learned JKO operators remains an open topic (Feng et al., 9 Jan 2026).
- Scalability: Extremely high dimensions, long time horizons, or stiff dynamics may strain memory and computational resources, especially for kinetic or blockwise schemes (Hertrich et al., 2024, Lee et al., 25 Mar 2026).
- Structural preservation: While JKO structure guarantees mass conservation and energy dissipation, full preservation of higher-order invariants (e.g., symplecticity in Hamiltonian systems) is only approximate due to discretization and optimization error (Lee et al., 25 Mar 2026).
- Extensions: Deep JKO methods can plausibly be extended to other proximal operators (e.g., Moreau–Yosida), composite splitting schemes, or time-adaptive steps. Applications to mean-field games, data-driven prior learning, and hybrid variational–score models are active research frontiers (Feng et al., 9 Jan 2026).
Deep JKO thus constitutes a foundational bridge between measure-theoretic variational principles and scalable data-driven modeling, with a rapidly expanding range of theoretical and applied impacts across machine learning, PDEs, Bayesian inference, and population dynamics (Vidal et al., 2022, Xu et al., 2022, Lee et al., 25 Mar 2026, Hertrich et al., 2024, Persiianov et al., 2 Jun 2025, Feng et al., 9 Jan 2026, Georgoulis et al., 2022, Lee et al., 2023).