
JaxPP System: Deep Learning & Plasma Simulation

Updated 17 December 2025
  • JaxPP denotes two frameworks built on JAX: a scalable pipeline-parallel deep learning system for transformer-scale models and a fully differentiable particle-in-cell (PIC) plasma simulation code.
  • The deep learning component uses an MPMD architecture with custom gradient-accumulation schedules and asynchronous communication to improve device utilization and end-to-end throughput.
  • The PIC framework, known as JAX-in-Cell, provides both an explicit Boris pusher and an implicit Crank–Nicolson solver (the latter conserving energy to machine precision) and is fully differentiable end to end.

The JaxPP System encompasses two distinct state-of-the-art frameworks in computational science, both built on top of JAX but targeting fundamentally different domains: scalable deep learning via pipeline parallelism for transformer-scale models (Xhebraj et al., 18 Dec 2024), and fully differentiable particle-in-cell (PIC) plasma simulations (Ma et al., 13 Dec 2025). The term “JaxPP” thus denotes (1) a multi-program, multi-device (MPMD) pipeline-parallel training system for large neural networks, and (2) a 1D3V electromagnetic PIC code known as JAX-in-Cell, both leveraging JAX’s just-in-time (JIT) compilation, vectorization, and functional programming paradigm for performance and composability.

1. JaxPP for MPMD Pipeline-Parallel Deep Learning

Overview and Motivation

JaxPP addresses the challenge of training neural networks with hundreds of billions of parameters, where device-local memory is inadequate for model replication and communication bandwidth limits preclude scaling via conventional data or tensor parallelism alone. Traditional SPMD (single-program multiple-data) infrastructures such as JAX’s GSPMD and XLA facilitate intra-operator partitioning but lack efficient support for pipeline parallelism across heterogeneous stages. The JaxPP system lifts pipeline parallelism into a first-class MPMD abstraction, decoupling per-stage execution and enabling asynchrony and overlap of compute and communication (Xhebraj et al., 18 Dec 2024).

Programming Model and Schedules

JaxPP extends the JAX training API with two primitives: pipeline_yield(x), which partitions the model into logical pipeline stages, and accumulate_grads(f, schedule)(batch), which applies user-defined gradient accumulation schedules. Schedules encode the mapping of forward and backward tasks on microbatches and stages to independent SPMD actors (sets of GPUs). Standard schedules such as GPipe, 1F1B (one-forward-one-backward), and interleaved variants are supported, alongside custom task orderings. Gradient accumulation is optimized via a loop-commuting transformation to reduce inter-stage communication:

G = \sum_{j} \left( \sum_{i=0}^{M-1} g_j^{(i)} \right)

where g_j^{(i)} is the gradient for parameter j on microbatch i.
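
The following minimal, plain-JAX sketch illustrates this accumulation pattern over microbatches. It is only an illustration of the formula above, not JaxPP's loop-commuting transformation (which operates on the traced JAXPR), and the names loss_fn, params, and microbatches are hypothetical.

```python
# Plain-JAX illustration of summing per-microbatch gradients, G = sum_i g^(i),
# for every parameter j. JaxPP performs the equivalent rewrite on the JAXPR.
import jax
import jax.numpy as jnp

def loss_fn(params, batch):                 # hypothetical least-squares loss
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def accumulated_grads(params, microbatches):
    def body(acc, mb):
        g = jax.grad(loss_fn)(params, mb)   # gradient on one microbatch
        return jax.tree_util.tree_map(jnp.add, acc, g), None
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    acc, _ = jax.lax.scan(body, zeros, microbatches)
    return acc

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
# M = 8 microbatches of 16 samples, stacked along the leading (scanned) axis.
microbatches = (jnp.ones((8, 16, 4)), jnp.zeros((8, 16, 1)))
G = jax.jit(accumulated_grads)(params, microbatches)
```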

Task Distribution and Communication Inference

Post-autodiff, JaxPP constructs a JAXPR where each pipeline_yield demarcates a new stage. Tasks are distributed such that all forward and backward operations belonging to a given stage are assigned to a dedicated actor. Communication dependencies are automatically determined: for each cross-actor data dependency, matched asynchronous NCCL.Send/NCCL.Recv calls are inserted, ordered topologically to guarantee deadlock avoidance.
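
As a toy illustration of this ordering (standard-library only; this is not JaxPP's internal representation), the cross-actor data dependencies can be treated as a DAG whose topological order fixes when each matched send/recv pair is issued:

```python
# Hypothetical two-stage pipeline: emit paired send/recv in topological order so
# that no actor waits on a message whose sender is itself blocked (no deadlock).
from graphlib import TopologicalSorter

deps = {
    "fwd_stage2": {"fwd_stage1"},   # stage-2 forward consumes stage-1 activations
    "bwd_stage2": {"fwd_stage2"},   # local ordering inside the stage-2 actor
    "bwd_stage1": {"bwd_stage2"},   # stage-1 backward consumes stage-2 gradients
}
cross_actor = {("fwd_stage1", "fwd_stage2"), ("bwd_stage2", "bwd_stage1")}

for task in TopologicalSorter(deps).static_order():
    for producer in deps.get(task, ()):
        if (producer, task) in cross_actor:
            print(f"send({producer} -> {task})  /  recv({task} <- {producer})")
```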

MPMD Runtime and Execution

The runtime comprises a single Python controller process (with Ray for RPC orchestration) and actors—each a GPU pool executing XLA-compiled SPMD kernels. Local tasks within each actor (forward/backward execution, send/recv operations) are fused and dispatched as a single RPC. Critically, all P2P transfers are non-blocking, allowing compute-communication overlap; the makespan is the sum of per-actor compute times plus the critical communication path, less any overlap. Actor-local object stores and buffer liveness analysis minimize memory usage and synchronize buffer disposal post-send.
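
A schematic of this single-controller, multi-actor pattern using Ray's public API might look as follows; StageActor and its method are hypothetical stand-ins, not JaxPP's runtime interfaces.

```python
# Sketch of one controller dispatching fused per-stage tasks to Ray actors.
# A real deployment would pin GPUs per actor (e.g. via num_gpus) and launch
# XLA executables plus asynchronous NCCL send/recv; here we only echo the dispatch.
import ray

ray.init()

@ray.remote
class StageActor:
    def __init__(self, stage_id):
        self.stage_id = stage_id

    def run_fused_tasks(self, microbatch_ids):
        return f"stage {self.stage_id}: executed tasks for microbatches {microbatch_ids}"

actors = [StageActor.remote(i) for i in range(4)]                    # one actor per stage
futures = [a.run_fused_tasks.remote([0, 1, 2, 3]) for a in actors]   # non-blocking RPCs
print(ray.get(futures))                                              # gather results
```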

Performance Evaluation

Benchmarks on NVIDIA DGX H100 clusters show that, for GPT-3 175B on 128 GPUs, JaxPP achieves a step time of 9.64 s (457 TFLOPS/GPU), representing a 1.11× speedup over JAX FSDP (10.70 s, 412 TFLOPS/GPU) and outperforming JAX SPMD-PP's 13.96 s (316 TFLOPS/GPU). Weak scaling efficiency remains above 92.9% up to 1024 GPUs. JaxPP's interleaved pipeline schedules yield up to ~20% savings from reduced rematerialization and ~10% from overlapping P2P transfers, with dispatch overhead as the main residual bottleneck. The system allows fine-tuning of microbatch and interleaving parameters to trade off pipeline bubbles against device utilization (Xhebraj et al., 18 Dec 2024).

Implementation and Integration

JaxPP integrates without forking JAX or XLA; transformed JAXPRs are compiled into separate XLA executables per stage. All data/tensor/expert sharding semantics are preserved. Recommendations include modest circular repeat values and microbatch sizes tailored to saturate hardware throughput, exploiting pipeline parallelism for bandwidth-constrained deployments, and flexible device assignment per stage for load balancing.

2. JaxPP as JAX-in-Cell: Differentiable PIC Framework

System Architecture

JAX-in-Cell, referenced as “JaxPP System,” implements an electromagnetic multispecies 1D3V PIC algorithm encompassing both explicit (Boris) and implicit (Crank–Nicolson) integrators (Ma et al., 13 Dec 2025). The entire simulation advances an immutable carry tuple (E, B, x_p, v_p, q, m) via a single JIT-compiled JAX function utilizing lax.scan for time-stepping. State is stored as monolithic arrays (fields, particles, species charges/masses) to ensure vectorization, efficient device utilization, and purity of functional state transitions.
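
A structural sketch of this design (physics omitted, all array shapes and names illustrative; this is not the actual JAX-in-Cell source) is shown below.

```python
# Skeleton of the JIT-compiled time loop: an immutable carry tuple of monolithic
# arrays is threaded through jax.lax.scan.
from functools import partial
import jax
import jax.numpy as jnp

def step(carry, _):
    E, B, x_p, v_p, q, m = carry
    # ... field update, particle push, and current deposition would go here ...
    return (E, B, x_p, v_p, q, m), None        # new carry, no per-step outputs kept

@partial(jax.jit, static_argnums=1)
def run(carry, n_steps):
    carry, _ = jax.lax.scan(step, carry, None, length=n_steps)
    return carry

n_grid, n_particles = 64, 10_000
carry0 = (
    jnp.zeros((n_grid, 3)),        # E on the grid
    jnp.zeros((n_grid, 3)),        # B on the grid
    jnp.zeros((n_particles,)),     # x_p: 1D positions
    jnp.zeros((n_particles, 3)),   # v_p: 3V velocities
    jnp.ones((n_particles,)),      # q: per-particle charge
    jnp.ones((n_particles,)),      # m: per-particle mass
)
final = run(carry0, 100)
```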

Governing Equations and Discretization

The core dynamical system is the Vlasov–Maxwell system:

  • \partial_t f_s + v \cdot \nabla_x f_s + (q_s/m_s)(E + v \times B) \cdot \nabla_v f_s = 0
  • \partial_t E = c^2 \nabla \times B - J/\varepsilon_0
  • \partial_t B = -\nabla \times E
  • dp/dt = q (E + v \times B)

Particle charge and current densities are deposited onto a staggered Yee mesh using quadratic spline shape functions, while fields are advanced with second-order central differences. Field-to-particle interpolation applies the same spline kernel to reduce noise and preserve momentum.
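
As a simplified illustration of the field update only (a collocated grid and unit constants are assumed for brevity, unlike the staggered Yee layout described above), the 1D reduction of the Maxwell update with periodic boundaries can be written as:

```python
# Second-order central-difference update for the 1D (variation only in x) system
#   dEz/dt = c^2 dBy/dx - Jz/eps0,   dBy/dt = dEz/dx
# with periodic boundaries handled via jnp.roll.
import jax.numpy as jnp

def update_fields(E_z, B_y, J_z, dx, dt, c=1.0, eps0=1.0):
    dBy_dx = (jnp.roll(B_y, -1) - jnp.roll(B_y, 1)) / (2.0 * dx)
    dEz_dx = (jnp.roll(E_z, -1) - jnp.roll(E_z, 1)) / (2.0 * dx)
    E_new = E_z + dt * (c**2 * dBy_dx - J_z / eps0)
    B_new = B_y + dt * dEz_dx
    return E_new, B_new
```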

Particle-Pushing Algorithms

Two primary integrators are implemented:

  1. Explicit Boris Solver: Second-order and volume-preserving, but can exhibit energy drift under strong electric fields. Each step comprises a half-acceleration by E, a rotation in B, a second half-acceleration, and a position update (see the sketch after this list).
  2. Implicit Crank–Nicolson with Picard Iteration: Time-centering enables discrete energy conservation. Fixed-point iterations solve the nonlinear velocity update until convergence, keeping total system energy constant up to machine precision in implicit runs.
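
The sketch below spells out one standard (non-relativistic) Boris push for all particles at once; it follows the textbook half-kick/rotate/half-kick sequence with an assumed signature, not JAX-in-Cell's exact implementation.

```python
# Vectorized Boris push: v- = v + (q/m)E dt/2, rotate in B, second half-kick, drift.
import jax.numpy as jnp

def boris_push(x, v, E, B, q, m, dt):
    """x: (N,), v/E/B: (N, 3), q/m: (N,); E and B already gathered at the particles."""
    qm = (q / m)[:, None]
    v_minus = v + 0.5 * dt * qm * E                              # half-acceleration by E
    t = 0.5 * dt * qm * B                                        # rotation vector
    s = 2.0 * t / (1.0 + jnp.sum(t * t, axis=1, keepdims=True))
    v_prime = v_minus + jnp.cross(v_minus, t)                    # rotation in B
    v_plus = v_minus + jnp.cross(v_prime, s)
    v_new = v_plus + 0.5 * dt * qm * E                           # second half-acceleration
    x_new = x + dt * v_new[:, 0]                                 # 1D position update
    return x_new, v_new
```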

Boundary Conditions and Multi-species Support

JaxPP supports periodic, reflective, and absorbing boundary conditions for both particles and fields. Multi-species handling is achieved by initialization-stage concatenation of species attribute arrays, with all subsequent operations vectorized over the global phase-space representation, optimizing for GPU/TPU execution efficiency.
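
A minimal sketch of vectorized periodic and reflective particle boundaries follows; the function names and signatures are assumptions for illustration, not the library's API.

```python
# Apply boundary conditions to all particles at once with pure array operations.
import jax.numpy as jnp

def apply_periodic(x, L):
    return jnp.mod(x, L)                       # wrap positions into [0, L)

def apply_reflective(x, v, L):
    left, right = x < 0.0, x > L               # particles that left the domain
    x_new = jnp.where(left, -x, jnp.where(right, 2.0 * L - x, x))
    v_x = jnp.where(left | right, -v[:, 0], v[:, 0])   # flip x-velocity of leavers
    return x_new, v.at[:, 0].set(v_x)
```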

Differentiable Programming

Every step in the simulation pipeline, from particle push to field update, is composed of JAX primitives, so the entire time-stepping operation is traceable by JAX's forward- and reverse-mode automatic differentiation. This enables Jacobian and gradient computation through many timesteps without hand-written adjoints. Sample use cases include inverse design problems and ML-in-the-loop plasma optimization.
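
The toy example below differentiates a scalar diagnostic through a scanned time loop; a harmonic oscillator stands in for the full field/particle update, but the AD mechanics (jax.grad applied through jax.lax.scan) are the same ones the framework relies on.

```python
# Reverse-mode gradient of a final-time diagnostic with respect to an initial
# condition, propagated through every timestep without hand-written adjoints.
import jax

def simulate(v0, n_steps=100, dt=0.01, omega=2.0):
    def step(carry, _):
        x, v = carry
        return (x + dt * v, v - dt * omega**2 * x), None
    (x, v), _ = jax.lax.scan(step, (0.0, v0), None, length=n_steps)
    return x**2 + v**2            # scalar objective at the final time

grad_fn = jax.jit(jax.grad(simulate))
print(grad_fn(1.0))               # d(objective)/d(initial velocity)
```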

Performance and Validation

On standard benchmark problems in 1D plasma physics (Landau damping, the two-stream instability, the Weibel instability, and bump-on-tail), JaxPP reproduces analytic and theoretical rates with bounded numerical errors (e.g., energy conservation to better than 10^{-11} in implicit mode). On NERSC Perlmutter, an NVIDIA A100 GPU achieves up to 100× speedup relative to an AMD EPYC CPU for runs with more than 60,000 particles, with similar performance at 32- and 64-bit precision except at the strictest accuracy thresholds (Ma et al., 13 Dec 2025).

3. Extensibility and Use Cases

The modular, functionally pure, and vectorized design of both JaxPP systems supports straightforward extensibility:

  • For deep learning, users can customize pipeline decomposition, scheduling, and device assignment, integrating with heterogeneous hardware or exploring atypical pipeline configurations.
  • In plasma simulation, extensions include higher spatial dimensions (2D3V, 3D3V), alternative deposit/interpolation kernels, new integrators, collisional terms, or hybrid fluid-kinetic schemes.
  • The differentiable nature of JAX-in-Cell allows embedding into Physics-Informed Neural Networks or adjoint-based design optimization pipelines, unlocking joint physics and machine learning applications.

4. Comparative Context

In deep learning, JaxPP's pipeline abstraction and explicit asynchronous MPMD execution contrast with GPipe-style SPMD loops and NeMo's pipeline/tensor/data hybrid. JaxPP matches or exceeds FSDP and SPMD-PP in scaling efficiency (up to a 1.11× improvement in hardware utilization), with remaining performance gaps attributable to aggregate kernel sizes and dispatch overhead. The system operates as a drop-in JAX extension, without forking or diverging from primary upstream libraries (Xhebraj et al., 18 Dec 2024).

JAX-in-Cell (JaxPP) distinguishes itself from traditional PIC codes (e.g., Fortran, C++ FDTD codes) by offering conciseness, full autodiff, and seamless execution across Python, JIT-compiled CPU, GPU, and TPU targets. The architecture bridges educational and production use, favoring maintainability, extensibility, and coupling to differentiable programming (Ma et al., 13 Dec 2025).

5. Practical and Implementation Considerations

In pipeline-parallel training, practitioners are advised to select microbatch size and interleaving factor to optimize throughput, balancing pipeline bubbles (idle device time) and kernel launch efficiency. Buffer liveness and memory management are automated via actor-local object stores and explicit delete directives post-communication. Heterogeneous workloads are accommodated by manual device assignment per stage.
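
As a rough, back-of-the-envelope illustration of that trade-off (using the standard GPipe-style estimate of the idle fraction, roughly (p-1)/(m+p-1) for p stages and m microbatches; this is a generic approximation, not a JaxPP-specific formula):

```python
# More microbatches shrink the pipeline bubble, but each microbatch must stay large
# enough that the per-stage kernels keep the GPUs busy.
def bubble_fraction(num_stages, num_microbatches):
    return (num_stages - 1) / (num_microbatches + num_stages - 1)

for m in (4, 8, 16, 32):
    print(f"p=8, m={m:2d}: idle fraction ~ {bubble_fraction(8, m):.2f}")
```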

For JAX-in-Cell, users can implement new physical models by extending or composing the modular routines (e.g., new field solvers, boundary types, deposition kernels) and can exploit the test suite and continuous integration (GitHub Actions) for workflow robustness. A plausible implication is that the structure supports rapid research prototyping as well as production-scale multi-GPU execution.

6. Summary of Key Features

  • Deep learning JaxPP: transformer-scale models; MPMD pipeline parallelism; core innovations include asynchronous P2P communication, custom schedules, and fine-grained stage/executable placement (Xhebraj et al., 18 Dec 2024).
  • JAX-in-Cell (PIC): plasma physics; JAX-vectorized, JIT-compiled execution; core innovations include full differentiability, explicit/implicit solvers, modular composition, and GPU/TPU targeting (Ma et al., 13 Dec 2025).

Each system demonstrates the performance, composability, and extensibility of pure-JAX scientific computing under demanding real-world workloads.
