JAX-Based Neural Emulators
- JAX-based neural emulators are differentiable surrogate models that integrate neural networks with JAX's automatic differentiation, JIT, and vectorization for efficient simulation.
- They leverage symbolic graph abstractions and hybrid physical-neural approaches to deliver order-of-magnitude acceleration and end-to-end, gradient-based model optimization across diverse applications.
- Their applications span fields like cosmological inference, structural mechanics, and spiking neural systems, enabling robust, scalable scientific computations.
JAX-based neural emulators are differentiable, high-performance surrogate models built on the JAX library and its associated accelerated computing stack (XLA, GPU/TPU support). These emulators combine machine learning architectures—predominantly neural networks—with JAX’s automatic differentiation (AD), just-in-time (JIT) compilation, and vectorization to achieve rapid, scalable, and hardware-portable solutions for tasks that traditionally require computationally intensive simulations or non-differentiable models. JAX-based neural emulators are found across domains including deep learning training abstractions, scientific machine learning, simulation surrogates for cosmology, structural mechanics, brain modeling, economic simulation, and non-standard neural network optimization paradigms.
1. Foundations: JAX Infrastructure and Symbolic Extensions
JAX provides the computational substrate for neural emulators through a functional programming paradigm emphasizing immutable state, explicit argument-passing, and support for AD and JIT compilation, thus ensuring both differentiability and hardware efficiency. Symbolic graph abstractions such as those in SymJAX (Balestriero, 2020) bring Theano-style model building on top of JAX, allowing users to manipulate variables, placeholders, and computational graphs directly.
Compared to idiomatic JAX—which requires explicit parameter threading and functional model definitions—SymJAX enables symbolic graph construction, automatic state updates, and higher-level module APIs (e.g., with sj.function), streamlining model building and update management. This approach is compiled down to JAX and XLA-native representations, fusing kernel operations and transparently supporting CPU, GPU, or TPU backends without sacrificing performance.
| Feature | JAX | SymJAX |
|---|---|---|
| State handling | Explicit, immutable | Symbolic, abstracted |
| Model updates | User-defined in code | Via `updates` dict |
| Compilation | JIT (explicit) | JIT (graph-compiled) |
| Programming style | Functional | Symbolic graph |
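As a reference point for the JAX column above, the minimal sketch below shows the idiomatic functional style that SymJAX abstracts away: parameters are an explicit, immutable pytree threaded through pure functions, with jax.grad and jax.jit applied directly. The model, shapes, and learning rate are illustrative placeholders, not code from the SymJAX paper.

```python
import jax
import jax.numpy as jnp

# Idiomatic JAX: model state is an explicit, immutable pytree of parameters
# threaded through pure functions; there is no hidden graph or global state.
def init_params(key, in_dim=8, hidden=32, out_dim=1):
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (in_dim, hidden)), "b1": jnp.zeros(hidden),
        "w2": 0.1 * jax.random.normal(k2, (hidden, out_dim)), "b2": jnp.zeros(out_dim),
    }

def predict(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=1e-2):
    grads = jax.grad(loss)(params, x, y)
    # The "update" is just another pure function returning a new parameter pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

In SymJAX, the update rule above would instead be declared in an updates dictionary and compiled once through sj.function, with equivalent XLA-fused code generated underneath.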
2. Hardware-Accelerated and Differentiable Surrogate Modeling
JAX’s AD and JIT lower the barrier for constructing differentiable surrogate models (neural emulators) that can both match the accuracy of traditional solvers and facilitate end-to-end optimization. Applications span:
- Particle and transport equation modeling, e.g., Population Balance Equations (PBEs), where the entire simulation code—including finite-volume method (FVM) updates and potential neural or hybrid corrections—is differentiated via reverse-mode AD (Alsubeihi et al., 1 Nov 2024). JAX-based solvers exhibit up to 300× acceleration over NumPy and support hybrid physical/neural models by embedding NNs into solver routines.
- Inelastic constitutive modeling, where neural architectures parameterize dual potentials constrained by thermodynamic theory; explicit convexity and non-negativity constraints are imposed at the architectural level, with all gradients computed via JAX AD (Holthusen et al., 19 Feb 2025).
- Structural equilibrium and inverse form-finding, e.g., JAX FDM, where the entire sparse linear system solution chain remains differentiable, enabling backpropagation through physics solvers or their coupling with data-driven neural networks (Pastrana et al., 2023).
These frameworks leverage JAX’s compositionality, hardware acceleration, and differentiability to support both traditional gradient-based training and hybrid approaches that discover physics-informed surrogate models directly from data.
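The hybrid physical/neural pattern above amounts, in code, to embedding a small network inside a differentiable time-stepping loop and differentiating the rollout end to end. The sketch below illustrates that pattern generically; the toy decay dynamics, closure network, and shapes are placeholders rather than the PBE solver of Alsubeihi et al. or the FDM solver of Pastrana et al.

```python
import jax
import jax.numpy as jnp

def neural_closure(params, u):
    # Small MLP correction term added to the known physics (illustrative sizes).
    h = jnp.tanh(u @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def hybrid_step(params, u, dt=0.01, decay=1.0):
    # Known physics (here a toy linear decay term) plus a learned correction.
    physics = -decay * u
    return u + dt * (physics + neural_closure(params, u))

def rollout_loss(params, u0, target, n_steps=50):
    def body(u, _):
        return hybrid_step(params, u), None
    u_final, _ = jax.lax.scan(body, u0, None, length=n_steps)
    return jnp.mean((u_final - target) ** 2)

# Reverse-mode AD differentiates through the entire time-stepping loop,
# including the embedded network, mirroring the hybrid solver pattern above.
grad_fn = jax.jit(jax.grad(rollout_loss))

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
dim, hidden = 4, 16
params = {
    "w1": 0.1 * jax.random.normal(k1, (dim, hidden)), "b1": jnp.zeros(hidden),
    "w2": 0.1 * jax.random.normal(k2, (hidden, dim)), "b2": jnp.zeros(dim),
}
grads = grad_fn(params, jnp.ones(dim), jnp.zeros(dim))
```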
3. Scientific and Cosmological Inference at Scale
JAX-based neural emulators are fundamental accelerators for Bayesian parameter estimation and model comparison in high-dimensional scientific inference. The paradigm is seen in:
- CosmoPower-JAX and related pipelines (Piras et al., 2023, Reeves et al., 28 Jul 2025, Carrion, 14 Aug 2025, Jin et al., 9 Oct 2024, Lovick et al., 16 Sep 2025), where neural networks are trained on precomputed grids of simulation outputs (e.g., cosmological power spectra, Ly-α flux correlations, baryonic feedback boosts). These emulators:
- Map high-dimensional parameter vectors to observable fields rapidly—e.g., a 37- or 157-parameter cosmic shear or “3x2pt” inference completed in days on a GPU cluster versus years on CPU (Piras et al., 2023).
- Replace the most computationally expensive operations (e.g., loop integrals, IR resummation in LSS, nonlinear spectra) with emulated forward passes that are typically orders of magnitude faster while retaining sub-percent accuracy.
- Support advanced inference engines—automatic gradients enable Hamiltonian Monte Carlo (NumPyro/NUTS (Piras et al., 2023)), vectorized ensemble sampling, and GPU-accelerated nested sampling (Lovick et al., 16 Sep 2025).
- Propagate emulator uncertainties directly into likelihoods by incorporating emulator error covariances into Bayesian fits (Jin et al., 9 Oct 2024).
- The general architecture comprises a feed-forward MLP with task-specific nonlinearity (e.g., “hard_tanh,” adaptive surrogates (Jin et al., 9 Oct 2024, Carrion, 14 Aug 2025)), regularization (weight decay, gradient clipping), and explicit support for gradient propagation end-to-end through the emulator and the scientific likelihood function.
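A minimal pure-JAX sketch of this end-to-end gradient flow appears below: a pretrained MLP emulator maps a parameter vector to a binned observable inside a Gaussian log-likelihood, and jax.grad supplies the parameter gradients consumed by HMC/NUTS. The layer sizes, activation choice, and placeholder data are illustrative and not taken from CosmoPower-JAX or the other cited pipelines.

```python
import jax
import jax.numpy as jnp

def emulator(weights, theta):
    # Feed-forward MLP mapping parameters theta to a binned observable (illustrative).
    x = theta
    for w, b in weights[:-1]:
        x = jax.nn.hard_tanh(x @ w + b)
    w, b = weights[-1]
    return x @ w + b

def log_likelihood(theta, weights, data, inv_cov):
    # Gaussian likelihood with the emulator standing in for the theory prediction.
    resid = emulator(weights, theta) - data
    return -0.5 * resid @ inv_cov @ resid

# Gradients with respect to the cosmological parameters flow straight through
# the emulator; this is the quantity gradient-based samplers (HMC/NUTS) consume.
grad_loglike = jax.jit(jax.grad(log_likelihood, argnums=0))

# Placeholder weights and data, purely for shape-checking the sketch.
sizes = [6, 64, 64, 20]  # 6 parameters -> 20 observable bins
keys = jax.random.split(jax.random.PRNGKey(0), len(sizes) - 1)
weights = [(0.1 * jax.random.normal(k, (m, n)), jnp.zeros(n))
           for k, m, n in zip(keys, sizes[:-1], sizes[1:])]
g = grad_loglike(jnp.zeros(6), weights, jnp.zeros(20), jnp.eye(20))
```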
This approach enables effective exploration and model discrimination for next-generation surveys, providing orders of magnitude improvements in computational cost and the ability to tackle parameter spaces beyond the reach of traditional numerical solvers.
4. Neural Emulators for Spiking and Biophysical Neural Systems
Biologically plausible neural emulation is supported by libraries such as BrainPy (Wang et al., 2023), Slax (Summe et al., 8 Apr 2024), and SNNAX (Lohoff et al., 4 Sep 2024), which extend JAX’s capabilities for spiking neural networks (SNNs) and dynamical brain systems. Key features include:
- Event-driven and sparse operators (e.g., event-driven matrix-vector multiplication, synaptic projection types) for efficiency and memory coalescing in large-scale network simulations.
- Modular neuron and synapse construction, supporting both detailed conductance-based dynamics (e.g., exponential synapse update equations) and high-level abstraction over network assembly.
- Differentiable simulation of event-driven dynamics (e.g., conductance updates discretized in a form compatible with JAX AD), unlocking end-to-end training of spiking models.
- A rich ecosystem of training algorithms—in Slax, BPTT with surrogate gradients, online learning (OSTL, OTTT), forward-mode learning (FPTT), and RTRL are all implemented, with benchmarking tools based on loss landscapes and gradient cosine similarities.
- Integration with Flax and JAX-native optimizers allows seamless model design and deployment for SNNs.
- Scalability benchmarks (e.g., simulation of >4 million neurons on a single GPU (Wang et al., 2023)) and domain interoperability (e.g., conversion between JAX/Flax models and spiking modules).
These frameworks provide a toolkit for large-scale spiking neural modeling, as well as a research platform for studying the tradeoff between bio-plausibility, efficiency, and expressivity in neural emulators.
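To make the surrogate-gradient idea concrete, the following plain-JAX sketch (not the BrainPy, Slax, or SNNAX API) gives a leaky integrate-and-fire step whose non-differentiable spike function is assigned a custom backward pass via jax.custom_vjp, so a simulated spike train can be trained end to end. The time constant, threshold, surrogate shape, and rate-matching loss are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike(v):
    # Forward pass: hard threshold (non-differentiable Heaviside step).
    return (v > 0.0).astype(jnp.float32)

def spike_fwd(v):
    return spike(v), v

def spike_bwd(v, g):
    # Backward pass: smooth surrogate derivative (a fast-sigmoid-style choice).
    return (g / (1.0 + 10.0 * jnp.abs(v)) ** 2,)

spike.defvjp(spike_fwd, spike_bwd)

def lif_step(v, x, w, alpha=0.9, v_th=1.0):
    v = alpha * v + x @ w      # leaky integration of weighted input current
    s = spike(v - v_th)        # emit spikes where the membrane crosses threshold
    return v - s * v_th, s     # soft reset of spiking units; return (carry, spikes)

def rate_loss(w, inputs, target_rate):
    v0 = jnp.zeros(w.shape[1])
    _, spikes = jax.lax.scan(lambda v, x: lif_step(v, x, w), v0, inputs)
    return jnp.mean((spikes.mean(axis=0) - target_rate) ** 2)

# Surrogate gradients flow through the spike nonlinearity during BPTT.
grad_w = jax.jit(jax.grad(rate_loss))
grads = grad_w(0.5 * jax.random.normal(jax.random.PRNGKey(0), (4, 8)),
               jax.random.normal(jax.random.PRNGKey(1), (100, 4)), 0.2)
```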
5. Non-standard Optimization and Gradient-free Architectures
JAX-based neural emulation extends beyond standard gradient-based optimization. Notably:
- PJAX (Bergmeister et al., 6 Jun 2025) reframes neural network training as a feasibility problem, eschewing loss minimization for iterative projections onto local constraint sets defined by the computation graph of network primitives (e.g., dot, sum, ReLU). Training involves composing projection operators and iteratively finding a state lying in the intersection of all constraints via alternating projections or Douglas–Rachford splitting.
- This approach allows parallel updates over bipartite computation graphs and accommodates non-differentiable, non-smooth primitives (e.g., hard quantization, max pooling) by direct projection, bypassing the need for smooth surrogates or gradient approximations. It is natively vectorized via JAX, supports arbitrary model architectures (MLPs, CNNs, RNNs), and is benchmarked against gradient-based methods.
The feasibility-seeking projection method thus represents a rigorously parallel, local-update alternative to gradient-based optimization for neural emulation in JAX.
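As a toy illustration of the feasibility-seeking viewpoint (not PJAX's own operators, which are defined per network primitive), the sketch below uses JAX to find a point in the intersection of two simple convex constraint sets by alternating Euclidean projections; the sets and iteration count are arbitrary choices.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Two toy constraint sets: an affine subspace A = {x : a @ x = b}
# and the nonnegative orthant B = {x : x >= 0}.
a = jnp.array([1.0, 2.0, -1.0])
b = 3.0

def project_affine(x):
    # Euclidean projection onto the hyperplane a @ x = b.
    return x - (a @ x - b) / (a @ a) * a

def project_orthant(x):
    return jnp.maximum(x, 0.0)

@partial(jax.jit, static_argnames="n_iters")
def alternating_projections(x0, n_iters=200):
    def body(x, _):
        return project_orthant(project_affine(x)), None
    x, _ = jax.lax.scan(body, x0, None, length=n_iters)
    return x

x_star = alternating_projections(jnp.array([-1.0, 0.5, 2.0]))
# x_star lies (approximately) in the intersection: a @ x_star is close to b and x_star >= 0.
```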
6. Benchmarks, Model Coverage, and Application Domains
Benchmarks such as APEBench (Koehler et al., 31 Oct 2024) provide comprehensive suites for evaluating JAX-based emulators in scientific contexts—in this case, time-dependent PDEs. Features include:
- Fully differentiable pseudo-spectral solvers embedded in JAX enable comparison across 46 distinct PDEs in 1D–3D, with a unified unrolling taxonomy for neural training (supervised, unrolled, hybrid corrector–predictor classes).
- Normalized PDE identifiers (a dynamics coefficient and a difficulty factor) relate simulation settings to classical stability criteria (CFL, etc.), enabling explicit comparison of emulator performance versus finite-volume or spectral numerical methods.
- Rollout metrics such as aggregated geometric mean nRMSE quantify temporal generalization, with rigorous error tracking over long-horizon unrolls. Experiments demonstrate that architectures trained with longer unrolling steps can recover standard upwind numerical stencils, providing cross-validation of neural versus numerical solvers.
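For concreteness, one plausible formulation of such a rollout metric is sketched below: a per-step RMSE normalized by the reference magnitude, aggregated across the rollout with a geometric mean. The exact normalization and aggregation used by APEBench may differ, so this is an assumption-laden sketch rather than the benchmark's definition.

```python
import jax
import jax.numpy as jnp

def nrmse(pred, ref, eps=1e-8):
    # RMSE of one snapshot normalized by the RMS magnitude of the reference.
    return jnp.sqrt(jnp.mean((pred - ref) ** 2)) / (jnp.sqrt(jnp.mean(ref ** 2)) + eps)

def geometric_mean_nrmse(pred_rollout, ref_rollout, eps=1e-8):
    # Per-step nRMSE over a rollout of shape (T, ...), aggregated by geometric mean.
    errors = jax.vmap(nrmse)(pred_rollout, ref_rollout)
    return jnp.exp(jnp.mean(jnp.log(errors + eps)))
```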
JAX-based neural emulators are thus benchmarked not merely for pointwise prediction, but for dynamical stability and long-term autoregressive rollouts, bridging data-driven modeling and numerical simulation.
7. Future Implications and Trajectories
The integration of JAX-based neural emulators into scientific computation, simulation-based design, and large-scale inference has several clear future implications, all grounded in reported data:
- Acceleration by orders of magnitude (10³–10⁵) allows for exploration of high-dimensional Bayesian spaces (up to 157 nuisance parameters (Piras et al., 2023)) and rapid nested sampling (Lovick et al., 16 Sep 2025).
- Fully differentiable pipelines will enable robust uncertainty propagation, end-to-end optimization, and routine application of gradient-based samplers (e.g., HMC, NUTS).
- The modularity and open-source release of several libraries (SymJAX, EvoJAX, JAX FDM, BrainPy, jinns, PJAX) will drive community adoption and further innovation in domain-specific neural emulator architectures.
- Hybridization of traditional physical modeling with machine-learned components—facilitated by JAX’s AD and JIT—enables “in-the-loop” correction, physics discovery, and meta-modeling across engineering and applied science (Alsubeihi et al., 1 Nov 2024, Gangloff et al., 18 Dec 2024).
- Neural emulators with convexity and thermodynamic constraints (Holthusen et al., 19 Feb 2025) represent a movement toward interpretable, physically consistent data-driven models, with implications for automated discovery in materials science and continuum mechanics.
A plausible implication is the increasing convergence between traditional scientific computing, machine-learned surrogate modeling, and automated discovery—enabled by JAX’s unique differentiable, compositional, and hardware-portable model of computation.
JAX-based neural emulators thus encompass a methodological and software ecosystem permitting the construction, training, acceleration, and deployment of differentiable surrogate models across scientific, engineering, and AI domains. Their rapid adoption and documented performance metrics signal their growing foundational role in modern computational science.