JaxPP: Accelerated JAX-Based Computation
- JaxPP is the name of several independent JAX-based frameworks that accelerate scientific and engineering computation, including a pipeline-parallel training library for deep learning.
- One JaxPP is a differentiable population balance equation solver that enables end-to-end gradient computation for rapid parameter estimation in process engineering.
- Another is a batched AC power flow solver that delivers large GPU speed-ups over single-threaded CPU baselines, integrates with AI pipelines, and scales to multi-thousand-bus networks.
JaxPP is the name shared by several specialized software frameworks that use the JAX numerical computation library to accelerate and automate scientific and engineering workflows at scale. It refers to: (1) a high-performance library for pipeline parallelism in distributed deep learning; (2) a differentiable solver for population balance equations (PBE); and (3) a batched AC power flow solver for power grid applications. All three exploit JAX primitives (JIT compilation, automatic vectorization, automatic differentiation (AD), and GPU/TPU support) to deliver domain-specific acceleration and differentiability, and to integrate with the broader JAX-based AI stack.
1. High-Performance Pipeline Parallelism for Deep Learning
JaxPP, as introduced in "Scaling Deep Learning Training with MPMD Pipeline Parallelism" (Xhebraj et al., 2024), addresses the scaling of deep learning model training through flexible pipeline parallelism. It enables users to annotate ordinary JAX model code with minimal intrusion, specifying only pipeline stage boundaries via pipeline_yield and pipeline schedule objects. Pipeline schedules—including GPipe, 1F1B, and interleaved variants—are defined explicitly, while all intra-stage parallelism (data, tensor, expert) remains managed by JAX's SPMD compiler stack.
The JaxPP system introduces a Multiple-Program Multiple-Data (MPMD) runtime for asynchronous execution of statically defined Single-Program Multiple-Data (SPMD) tasks. It uses a single controller process to trace the user’s training loop into JAX’s IR, applying pipeline transformations to produce a global task graph. Each actor process is assigned a fused local SPMD executable incorporating forward/backward passes, inter-stage send/recv operations employing NCCL P2P communication, and buffer management routines. Communication patterns and stage placements are automatically inferred through dependency analysis of the JAX computation graph.
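The dependency analysis described above can be illustrated with a toy, pure-Python sketch: given operations already assigned to pipeline stages, every value consumed on a stage other than the one that produced it needs a send/recv pair. The Op record and infer_transfers function are hypothetical illustrations, not JaxPP's internal API, which works on the traced JAX IR rather than on hand-written records.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Op:
    """Hypothetical record for one stage-assigned operation in a traced step."""
    name: str
    stage: int
    inputs: tuple
    outputs: tuple

def infer_transfers(ops):
    """Return (value, src_stage, dst_stage) for each value consumed off its producer's stage."""
    producer = {v: op.stage for op in ops for v in op.outputs}
    transfers = []
    for op in ops:
        for v in op.inputs:
            src = producer.get(v)  # None for program inputs such as the data batch
            if src is not None and src != op.stage and (v, src, op.stage) not in transfers:
                transfers.append((v, src, op.stage))
    return transfers

# Two-stage forward/backward step for a single microbatch.
OPS = [
    Op("embed", 0, ("x",), ("h0",)),
    Op("block1", 0, ("h0",), ("h1",)),
    Op("block2", 1, ("h1",), ("h2",)),
    Op("loss", 1, ("h2",), ("l",)),
    Op("block2_bwd", 1, ("l", "h1"), ("dh1",)),
    Op("block1_bwd", 0, ("dh1", "h0"), ("dW1",)),
]
print(infer_transfers(OPS))  # [('h1', 0, 1), ('dh1', 1, 0)]
```

The activation h1 flows forward from stage 0 to stage 1 and its gradient dh1 flows back; a real runtime would additionally decide buffer lifetimes and overlap these transfers with compute.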
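Gradient accumulation is handled by the primitive accumulate_grads, which abstracts the microbatch loop away from the user. For parameters shared across stages (e.g., tied embeddings), JaxPP implements a loop-commuting rewrite. Writing $g_i^{(a)}$ and $g_i^{(b)}$ for the gradient contributions to a shared parameter from stages $a$ and $b$ on microbatch $i$ (of $M$ microbatches), it rewrites

$$\sum_{i=1}^{M}\left(g_i^{(a)} + g_i^{(b)}\right)$$

as

$$\sum_{i=1}^{M} g_i^{(a)} + \sum_{i=1}^{M} g_i^{(b)},$$

so that each stage accumulates its own contribution locally and the cross-stage gradient sum is communicated once per step rather than once per microbatch.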
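A plain-JAX sketch of the loop-commuting idea, assuming a hypothetical two-stage model whose stages share one weight matrix W (stage a uses it as an input projection, stage b reuses its transpose as an output projection). It is not JaxPP's accumulate_grads; it accumulates gradients over microbatches with jax.lax.scan and checks that summing per-stage contributions after the loop gives the same gradient as combining them inside the loop.

```python
import jax
import jax.numpy as jnp

def microbatch_loss(W_a, W_b, x, y):
    # Stage a consumes the shared weight as W_a, stage b as W_b (tied: W_a = W_b = W).
    h = jnp.tanh(x @ W_a)   # "stage a"
    out = h @ W_b.T         # "stage b"
    return jnp.mean((out - y) ** 2)

# Per-microbatch gradient contributions of the two stages to the shared weight.
per_stage_grads = jax.grad(microbatch_loss, argnums=(0, 1))

def accumulate_naive(W, xs, ys):
    # Combines g_a + g_b inside the loop: one cross-stage exchange per microbatch.
    def body(acc, batch):
        g_a, g_b = per_stage_grads(W, W, *batch)
        return acc + (g_a + g_b), None
    acc, _ = jax.lax.scan(body, jnp.zeros_like(W), (xs, ys))
    return acc

def accumulate_commuted(W, xs, ys):
    # Each stage accumulates locally; the cross-stage sum happens once after the loop.
    def body(carry, batch):
        acc_a, acc_b = carry
        g_a, g_b = per_stage_grads(W, W, *batch)
        return (acc_a + g_a, acc_b + g_b), None
    init = (jnp.zeros_like(W), jnp.zeros_like(W))
    (acc_a, acc_b), _ = jax.lax.scan(body, init, (xs, ys))
    return acc_a + acc_b

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W = 0.1 * jax.random.normal(k1, (8, 8))
xs = jax.random.normal(k2, (4, 16, 8))   # 4 microbatches of 16 examples
ys = jax.random.normal(k3, (4, 16, 8))
g_naive = accumulate_naive(W, xs, ys)
g_commuted = accumulate_commuted(W, xs, ys)
```

Both functions return the gradient of the summed microbatch loss with respect to the tied weight; only where the addition of the two stage contributions happens differs.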
Benchmarks on GPT-3 175B and Llama-2 70B show that JaxPP is 44.6% faster than JAX's SPMD GPipe pipeline implementation (9.64 s vs. 13.96 s per step; 457 vs. 316 TFLOPS per device) and achieves up to 1.11× higher hardware utilization than JAX FSDP, with weak-scaling efficiency above 92% up to 1,024 GPUs. JaxPP reaches over 91% of NeMo's throughput in comparable configurations while relying only on cuDNN-backed attention kernels rather than custom CUDA kernels (Xhebraj et al., 2024).
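As a quick consistency check on the reported figures, the 44.6% speedup matches the ratio of per-device throughputs, and the step-time ratio gives nearly the same value:

$$\frac{457\ \text{TFLOPS}}{316\ \text{TFLOPS}} \approx 1.446, \qquad \frac{13.96\ \text{s}}{9.64\ \text{s}} \approx 1.448$$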
Key trade-offs lie on a frontier between dispatch overhead and bubble reduction: highly interleaved schedules shrink idle pipeline time (bubbles) but incur more kernel launches and finer microbatch granularity. For 1F1B and interleaved schedules, activation memory per stage scales with the number of stages, O(stages), rather than with the number of microbatches, O(microbatches), as in GPipe. Current limitations include reliance on explicit schedule annotation and on a single controller process, which may impose latency constraints at exascale. Planned extensions target schedule autotuning, MLIR dialect integration, and support for dynamic or conditional pipelines.
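The bubble and activation-memory trade-offs above can be made concrete with an idealized, pure-Python model (equal per-stage compute times, free communication). The helper names are hypothetical and this is not JaxPP's scheduler; it only generates per-stage operation orders for GPipe and non-interleaved 1F1B, counts peak stored activations, and evaluates the textbook bubble fraction, with n_virtual > 1 modelling an interleaved schedule.

```python
def gpipe_order(n_micro):
    """GPipe on any stage: all forward passes, then all backward passes."""
    return [("F", m) for m in range(n_micro)] + [("B", m) for m in reversed(range(n_micro))]

def one_f_one_b_order(stage, n_stages, n_micro):
    """Non-interleaved 1F1B: warm-up forwards, then alternate one forward / one backward."""
    warmup = min(n_stages - stage - 1, n_micro)
    ops = [("F", m) for m in range(warmup)]
    f, b = warmup, 0
    while b < n_micro:
        if f < n_micro:
            ops.append(("F", f))
            f += 1
        ops.append(("B", b))
        b += 1
    return ops

def peak_live_activations(ops):
    """Max number of microbatches whose activations are held at once on a stage."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak

def bubble_fraction(n_stages, n_micro, n_virtual=1):
    """Idle fraction of the step; n_virtual model chunks per device shrink the bubble."""
    return (n_stages - 1) / (n_virtual * n_micro + n_stages - 1)

def dispatches_per_device(n_micro, n_virtual=1):
    """Forward plus backward executable launches per device per step."""
    return 2 * n_virtual * n_micro

p, m = 4, 16
print(peak_live_activations(gpipe_order(m)))               # 16
print(peak_live_activations(one_f_one_b_order(0, p, m)))   # 4
print(round(bubble_fraction(p, m), 3))                     # 0.158
print(round(bubble_fraction(p, m, n_virtual=4), 3))        # 0.045
```

Interleaving with four chunks per device cuts the bubble from about 16% to about 4.5% in this model, at the cost of four times as many dispatches per device.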
2. Differentiable Population Balance Equation Solver
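In process engineering and scientific ML contexts, JaxPP denotes a differentiable, JAX-native PBE solver for multidimensional transport equations, chiefly for crystallization modeling (Alsubeihi et al., 2024). The framework couples a 2D population balance for a number density $n(L_1, L_2, t)$ (describing, e.g., a joint particle size/shape distribution) to a liquid-phase mass balance, and discretizes the system with a finite volume method (FVM). In the growth-dominated case, the population balance takes the form

$$\frac{\partial n}{\partial t} + \frac{\partial (G_1 n)}{\partial L_1} + \frac{\partial (G_2 n)}{\partial L_2} = 0,$$

where $G_1$ and $G_2$ are the growth rates along the two internal coordinates; they depend on the solute concentration, whose evolution is governed by the mass balance.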
Discretization employs high-resolution FVM with Godunov directional splitting and van Leer flux limiters, with mass balance updates and numerical fluxes computed at each substep. All routines are jax.jit-compiled for accelerator efficiency and employ jax.vmap for batch processing. Reverse-mode AD enables end-to-end differentiable simulation of the PDE solver, supporting parameter identification and hybrid physical/ML modeling, such as embedding neural networks as submodels for growth rates.
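A minimal sketch of the discretization ideas above, under simplifying assumptions that are ours rather than the paper's: constant positive growth rates, no nucleation (zero inflow at the smallest size), no mass-balance coupling, and a uniform grid. Each directional sweep uses a flux-limited upwind (Lax–Wendroff type) flux with the van Leer limiter; sweeps along L1 and L2 are applied one after the other (Godunov splitting). The names sweep, simulate, and loss and all grid constants are hypothetical.

```python
import jax
import jax.numpy as jnp

H, DT, STEPS = 1.0, 0.4, 50   # hypothetical grid spacing, time step, number of steps

def van_leer_slope(a, b, eps=1e-12):
    # van Leer phi(r) * b with r = a / b, written without dividing by b:
    # harmonic mean of the one-sided differences when they share a sign, else 0.
    return (a * jnp.abs(b) + jnp.abs(a) * b) / (jnp.abs(a) + jnp.abs(b) + eps)

def sweep(n, g, axis):
    """Flux-limited upwind update along one internal coordinate for growth rate g > 0.
    Ghost cells are zero: no inflow at the smallest size, and the grid is assumed
    long enough that n vanishes at the largest size."""
    n = jnp.moveaxis(n, axis, -1)
    p = jnp.pad(n, [(0, 0)] * (n.ndim - 1) + [(2, 1)])   # p[..., j] = n[..., j - 2]
    n_im1, n_i, n_ip1 = p[..., :-2], p[..., 1:-1], p[..., 2:]
    nu = g * DT / H                                       # Courant number, needs nu <= 1
    slope = van_leer_slope(n_i - n_im1, n_ip1 - n_i)
    flux = g * (n_i + 0.5 * (1.0 - nu) * slope)          # F_{i+1/2} for i = -1..N-1
    n = n - DT / H * (flux[..., 1:] - flux[..., :-1])
    return jnp.moveaxis(n, -1, axis)

@jax.jit
def simulate(g, n0):
    # Godunov (first-order) directional splitting: full step along L1, then along L2.
    def step(n, _):
        n = sweep(n, g[0], axis=0)
        n = sweep(n, g[1], axis=1)
        return n, None
    n, _ = jax.lax.scan(step, n0, None, length=STEPS)
    return n

def loss(g, n0, target):
    return jnp.mean((simulate(g, n0) - target) ** 2)

grad_loss = jax.jit(jax.grad(loss))                     # reverse-mode AD through the solver
simulate_batch = jax.vmap(simulate, in_axes=(0, None))  # many growth-rate hypotheses at once

L = jnp.arange(64.0)
n0 = jnp.exp(-((L[:, None] - 15.0) ** 2 + (L[None, :] - 15.0) ** 2) / 18.0)
target = simulate(jnp.array([1.0, 0.5]), n0)            # synthetic "measured" distribution
g_grad = grad_loss(jnp.array([0.9, 0.5]), n0, target)   # gradient w.r.t. growth rates
```

With g = (1.0, 0.5) the distribution is transported 20 cells along L1 and 10 along L2 while the total number of particles is conserved; the gradient at g = (0.9, 0.5) points toward the rates that generated the synthetic data, which is what gradient-based parameter estimation relies on.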
Benchmarks demonstrate speedups of up to 300× over NumPy baselines (JAX-GPU), with JAX-CPU delivering ∼3× acceleration. Automatic differentiation for gradient computation enables model parameter estimation more than 40× faster than finite-difference approaches for large models. This AD property is critical for scalable scientific machine learning, including "discovery" of physical models from data and in-the-loop correction via neural surrogates.
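The cost argument behind the AD speedup can be illustrated on a toy problem: a forward-difference gradient needs P + 1 full simulations for P parameters, whereas reverse-mode AD needs one forward and one backward pass regardless of P. The exponential-decay "simulator" and the helper fd_grad are hypothetical stand-ins, not the PBE model; the sketch only checks that both approaches agree.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def simulate(theta, t):
    # Hypothetical stand-in for an expensive simulator: a sum of exponential decays.
    rates = 1.0 + jnp.arange(theta.size)
    return jnp.sum(theta[:, None] * jnp.exp(-rates[:, None] * t[None, :]), axis=0)

def loss(theta, t, data):
    return jnp.mean((simulate(theta, t) - data) ** 2)

def fd_grad(f, theta, eps=1e-6):
    # Forward differences: one baseline call plus one perturbed call per parameter,
    # i.e. P + 1 full simulations for P parameters.
    f0 = f(theta)
    return jnp.array([(f(theta + eps * e) - f0) / eps for e in jnp.eye(theta.size)])

t = jnp.linspace(0.0, 2.0, 50)
data = simulate(jnp.array([1.2, 0.4, 0.3, 1.8]), t)   # synthetic measurements
theta = jnp.array([1.0, 0.5, 0.25, 2.0])

g_ad = jax.grad(loss)(theta, t, data)                  # one forward + one backward pass
g_fd = fd_grad(lambda th: loss(th, t, data), theta)    # P + 1 = 5 forward passes
```

For a simulator with hundreds of parameters the finite-difference cost grows linearly with P, while the AD cost stays roughly constant, which is consistent with the gap widening for large models.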
Limitations include scalability to higher-dimensional PBEs and constraints set by GPU memory consumption during AD. Research is ongoing into adaptive mesh refinement, operator-splitting for acceleration, and physics-regularized hybridization to prevent unphysical ML submodel behaviors (Alsubeihi et al., 2024).
3. Batched AC Power Flow with JAX Acceleration
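JaxPP in power systems denotes a JAX-based library for high-throughput batched AC power flow computations (Zhou et al., 13 May 2026). Its transmission network solver applies the Newton–Raphson method to the nonlinear power balance equations that relate the complex power injection at each bus to the bus voltages:

$$S_i = P_i + \mathrm{j}\,Q_i = V_i \sum_{k=1}^{N} \overline{Y_{ik}}\,\overline{V_k}, \qquad i = 1, \dots, N,$$

where $V_i$ is the complex voltage at bus $i$, $Y_{ik}$ are the entries of the bus admittance matrix $Y_{\text{bus}}$, and the overbar denotes complex conjugation.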
Solvers store the bus admittance matrix $Y_{\text{bus}}$ in sparse BCOO format, compute Jacobian-vector products with jax.linearize, and handle linear solves with batched GMRES using block-triangular sparse preconditioners. For unbalanced distribution networks, a Z-Bus fixed-point iteration scheme is provided that supports three-phase loads.
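A minimal sketch of matrix-free Newton–Raphson power flow in this style, on a hypothetical 3-bus network (bus 0 slack, buses 1 and 2 PQ). Jacobian-vector products come from jax.linearize and each Newton step is solved with jax.scipy.sparse.linalg.gmres, so the Jacobian is never formed. For brevity the admittance matrix is dense and no preconditioner is used; the library reportedly uses BCOO storage and block-triangular preconditioners instead. The network data and function names are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

jax.config.update("jax_enable_x64", True)

# Hypothetical 3-bus network: (from, to, series impedance in p.u.); bus 0 is the slack.
LINES = [(0, 1, 0.01 + 0.05j), (1, 2, 0.01 + 0.05j), (0, 2, 0.02 + 0.08j)]
V_SLACK = 1.0 + 0.0j

def build_ybus(lines, n=3):
    # Dense for brevity; a large network would keep this as a sparse matrix.
    Y = jnp.zeros((n, n), dtype=jnp.complex128)
    for i, k, z in lines:
        y = 1.0 / z
        Y = Y.at[i, i].add(y).at[k, k].add(y).at[i, k].add(-y).at[k, i].add(-y)
    return Y

Y_BUS = build_ybus(LINES)

def mismatch(x, s_spec):
    """Real-valued power mismatch at the PQ buses; x = [theta_1, theta_2, |V_1|, |V_2|]."""
    theta, vm = x[:2], x[2:]
    v = jnp.concatenate([jnp.array([V_SLACK]), vm * jnp.exp(1j * theta)])
    ds = (v * jnp.conj(Y_BUS @ v))[1:] - s_spec   # S = diag(V) conj(Y V), minus spec
    return jnp.concatenate([ds.real, ds.imag])

@jax.jit
def newton_raphson(s_spec):
    x = jnp.array([0.0, 0.0, 1.0, 1.0])            # flat start
    for _ in range(8):                              # fixed iteration count (unrolled)
        f, jvp = jax.linearize(lambda z: mismatch(z, s_spec), x)
        dx, _ = gmres(jvp, -f)                      # matrix-free solve of J dx = -F
        x = x + dx
    return x

s_spec = -jnp.array([0.5 + 0.2j, 0.3 + 0.1j])      # injections = minus the loads
x = newton_raphson(s_spec)
```

Passing the linearized function directly to GMRES keeps memory proportional to the number of buses, which is the property that makes the approach attractive for large sparse networks.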
The solver interface imports networks directly from standard power engineering tools (e.g., Network.from_pandapower, Network.from_dss) and accepts batched scenario inputs for voltages, power injections, and load profiles. All core loops are vectorized with jax.vmap and compiled with jax.jit.
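To illustrate how batching composes with the Z-Bus scheme mentioned earlier, here is a single-phase sketch (the library's version handles three-phase unbalanced networks) on a hypothetical 3-bus network. The solver is written for one load scenario and lifted to a batch of load profiles with jax.vmap. Partitioning $Y_{\text{bus}}$ into slack and PQ blocks gives the fixed point $V = w + Z\,\overline{S / V}$ with $Z = Y_{LL}^{-1}$ and $w = -Z Y_{L0} V_0$; the names zbus_solve and solve_batch are assumptions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical single-phase 3-bus network; bus 0 is the slack bus at V0.
LINES = [(0, 1, 0.01 + 0.05j), (1, 2, 0.01 + 0.05j), (0, 2, 0.02 + 0.08j)]
V0 = 1.0 + 0.0j

Y = jnp.zeros((3, 3), dtype=jnp.complex128)
for i, k, z in LINES:
    y = 1.0 / z
    Y = Y.at[i, i].add(y).at[k, k].add(y).at[i, k].add(-y).at[k, i].add(-y)

Z = jnp.linalg.inv(Y[1:, 1:])   # Z-bus of the non-slack (PQ) buses
W = -Z @ Y[1:, 0] * V0          # no-load voltages w = -Z Y_L0 V0

@jax.jit
def zbus_solve(s_inj):
    """Fixed-point iteration V <- w + Z conj(S / V) for one load scenario."""
    return jax.lax.fori_loop(0, 30, lambda _, v: W + Z @ jnp.conj(s_inj / v), W)

# Written for a single scenario; vmap lifts it to a batch of load profiles.
solve_batch = jax.vmap(zbus_solve)

s_base = -jnp.array([0.5 + 0.2j, 0.3 + 0.1j])           # injections = minus the loads
S_batch = jnp.array([0.5, 1.0, 1.5])[:, None] * s_base   # three loading levels
V = solve_batch(S_batch)                                 # shape (3, 2)
```

The per-scenario function contains no batch dimension at all, which is what keeps the interface batch-agnostic; on a GPU the vmapped loop processes all scenarios in the same kernels.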
On practical cases ranging from the IEEE 118-bus to multi-thousand bus networks, reported GPU speed-ups are 2.5–66.8× (RTX5000Pro) and 2.2–1046× (NVIDIA H200) compared to single-threaded CPU. GPU occupancy is maximized at batch sizes above 1,000, with runtime essentially independent of batch size. The framework supports embedding within AI and RL pipelines, enabling end-to-end differentiable power system analytics and backpropagation through power flow layers for optimal setpoint learning (Zhou et al., 13 May 2026).
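Backpropagation through a power flow layer can be sketched by differentiating through the iterations of a Z-Bus fixed point. In this hypothetical single-phase 3-bus example, a controllable reactive injection q at bus 2 is tuned by gradient descent to minimize voltage deviation from 1.0 p.u.; the network data, the objective, and the step size are assumptions chosen for illustration, not the library's setpoint-learning setup.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical 3-bus network (bus 0 = slack at 1.0 p.u.) with fixed loads at buses 1 and 2.
LINES = [(0, 1, 0.01 + 0.05j), (1, 2, 0.01 + 0.05j), (0, 2, 0.02 + 0.08j)]
S_LOAD = -jnp.array([0.5 + 0.2j, 0.3 + 0.1j])   # injections = minus the loads

Y = jnp.zeros((3, 3), dtype=jnp.complex128)
for i, k, z in LINES:
    y = 1.0 / z
    Y = Y.at[i, i].add(y).at[k, k].add(y).at[i, k].add(-y).at[k, i].add(-y)
Z = jnp.linalg.inv(Y[1:, 1:])
W = -Z @ Y[1:, 0]                                # no-load voltages for V0 = 1

def voltages(q):
    """Z-Bus fixed point with a controllable reactive injection q at bus 2."""
    s = S_LOAD + jnp.array([0.0, 1.0j]) * q
    return jax.lax.fori_loop(0, 30, lambda _, v: W + Z @ jnp.conj(s / v), W)

def deviation(q):
    return jnp.sum((jnp.abs(voltages(q)) - 1.0) ** 2)

grad_dev = jax.jit(jax.grad(deviation))          # backpropagates through all 30 iterations

q = 0.0
for _ in range(100):                             # plain gradient descent on the setpoint
    q = q - 10.0 * grad_dev(q)
```

Unrolling the fixed-point loop is the simplest way to obtain gradients; for long iterations an implicit-differentiation approach would avoid storing every iterate.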
4. Architecture, JAX Integration, and Programming Model
Across all domains, the unifying paradigm of JaxPP is deep integration with the JAX computational graph and XLA compilation pipeline. Code bases are written in pure Python, with all heavy computation executed via JAX primitives. Domains utilize jit-compilation for per-solver acceleration, vmap and pmap for batched or distributed operation, and grad/linearize for machine-precision AD.
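The composition pattern described above can be shown in a few lines: a function written for a single scenario is differentiated with jax.grad, batched over scenarios with jax.vmap, and compiled with jax.jit. The objective residual_norm is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def residual_norm(params, x):
    # Hypothetical scalar objective for one scenario.
    return jnp.sum((params * x - 1.0) ** 2)

# grad w.r.t. params, batched over scenarios (axis 0 of x), compiled with XLA.
per_scenario_grad = jax.jit(jax.vmap(jax.grad(residual_norm), in_axes=(None, 0)))

params = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0], [2.0, 0.5]])
print(per_scenario_grad(params, xs))   # [[0. 2.] [4. 0.]]
```

Because the transformations compose, the same scalar function serves single-scenario debugging, batched evaluation, and gradient-based optimization without modification.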
For MPMD pipeline parallelism, the architecture comprises a global controller tracing and transforming program IR, actor runtimes managing device-local executable pools, and automatic insertion of point-to-point NCCL communication primitives. Population balance and power flow solvers similarly exploit JAX’s immutable array semantics and static computation graphs for efficient, parallelizable numerical kernels.
Programming models emphasize minimal code modification—e.g., pipeline stage marking via lightweight annotations, batch-agnostic interfaces for solvers, and transparent use of AD and device placement. Where explicit scheduling or configuration is necessary (as in pipelined training), user annotations remain the only necessary extensions to vanilla JAX code.
5. Quantitative Performance, Benchmarks, and Use Cases
Reported results show large gains for every JaxPP framework: up to two to three orders of magnitude over NumPy or single-threaded CPU baselines for the scientific solvers, and up to 44.6% over existing JAX pipeline and sharding strategies for deep learning training. Key reported results include:
| Application Domain | Hardware | Speedup vs Baseline | Reference |
|---|---|---|---|
| Deep learning pipelines | 64–1024 H100 GPUs | 1.11× utilization vs. JAX FSDP<br>44.6% faster per step vs. GPipe<br>92.87% weak-scaling efficiency | (Xhebraj et al., 2024) |
| Population balance eqns. | RTX 4090 GPU, Ryzen 3900 CPU | 300× (JAX-GPU), 3× (JAX-CPU) | (Alsubeihi et al., 2024) |
| Batched AC power flow | RTX5000Pro/H200 | 2.5–4,768× (CPU→GPU batched) | (Zhou et al., 13 May 2026) |
Differentiable end-to-end computation enables seamless embedding into optimization pipelines, scientific ML, and control/RL routines. Example usages span parameter estimation, hybrid physics-ML discovery, batch scenario simulation, and large-scale distributed training.
6. Limitations, Scaling, and Future Directions
Scaling challenges include balancing kernel launch granularity versus pipeline stall reduction (bubble reduction), memory/throughput trade-offs in pipeline schedules, and scaling controller latency for massive cluster sizes. Some frameworks require explicit annotation for schedules or batch processing. Power flow and PBE solvers encounter GPU memory or AD overheads for extremely large batch sizes or high parameter counts.
Proposed research directions include schedule autotuning, deeper integration with upstream compiler frameworks (MLIR), dynamic or conditional execution across distributed pipelines, adaptive mesh methods for PBEs, and further AI/physics hybridization. All JaxPP tools are model-agnostic and, by building on JAX and XLA, inherit JAX's extensibility, backend portability, and ongoing development; the pipeline-parallel runtime, however, depends on NCCL point-to-point communication, which targets NVIDIA GPUs.
7. Code Repositories and Open-Source Availability
The batched AC power flow solver is available at [https://github.com/oxfordcontrol/jaxpp], with full documentation and APIs for transmission and distribution network analysis, prepared for use with standard Python (3.8–3.11) and JAX (0.9+) installations (Zhou et al., 13 May 2026).
In summary, JaxPP frameworks deliver deeply optimized, highly parallel, and differentiable pipelines for large-scale scientific and engineering computation, exploiting the flexibility and composability of the JAX stack while abstracting away parallelization, communication, and hardware management for domain specialists.