PyBird-JAX: Fast Cosmological Inference
- PyBird-JAX is a differentiable, JAX-based cosmology tool that integrates neural network emulators for rapid and precise one-loop EFT predictions.
- Its emulator strategy uses spline representations of the linear matter power spectrum to decouple predictions from specific cosmological parameters.
- Performance tests reveal speed-ups of 3–4 orders of magnitude, achieving sub-millisecond evaluations on GPUs for real-time Bayesian inference.
PyBird-JAX is a differentiable, JAX-based implementation of the PyBird code, developed to enable rapid, precision cosmological inference from large-scale structure (LSS) data. By embedding neural network emulators within a fully JAX-compiled environment, PyBird-JAX achieves millisecond-scale evaluations of one-loop effective field theory (EFT) predictions for redshift-space galaxy power spectrum multipoles, with speed-ups of 3–4 orders of magnitude compared to the original PyBird architecture. This extensive acceleration, coupled with rigorous accuracy validation, enables advanced scientific workflows such as Fisher forecasting, gradient-based sampling, and real-time Bayesian inference for Stage-4 galaxy surveys (Reeves et al., 28 Jul 2025).
1. JAX-Based Architecture and Differentiable Pipeline
PyBird-JAX is a comprehensive recoding of the original PyBird pipeline, implemented using JAX, a numerical computing library optimized for automatic differentiation (AD), just-in-time (JIT) compilation, and transparent dispatch to CPU and GPU. Every computational component—from one-loop EFT integrals to likelihood evaluation—is either written directly in JAX or adapted as a custom JAX primitive, ensuring full differentiability across the computation graph.
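The end-to-end differentiability described above can be illustrated with a minimal sketch: a JIT-compiled chi-squared over a toy power-law "model" (a placeholder, not the actual EFT prediction or PyBird-JAX API), through which JAX's reverse-mode AD propagates gradients in a single call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def chi2(params, k, data, cov_inv):
    """Toy chi^2 of a stand-in power-law model (placeholder for the EFT prediction)."""
    amp, slope = params
    model = amp * k ** slope      # in PyBird-JAX this would be the one-loop pipeline
    r = model - data
    return r @ cov_inv @ r

k = jnp.linspace(0.01, 0.2, 50)
data = 2.0 * k ** -1.5            # synthetic "observations" at amp=2.0, slope=-1.5
cov_inv = jnp.eye(50)

# Reverse-mode AD through the whole JIT-compiled computation graph:
grad_chi2 = jax.jit(jax.grad(chi2))
g = grad_chi2(jnp.array([2.0, -1.5]), k, data, cov_inv)
```

Because every operation sits inside JAX's transformation system, the same function can be further wrapped in `vmap` for batching or `hessian` for curvature information without any code changes.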
To address performance bottlenecks in expensive nonlinear calculations, such as two-point integrals evaluated via FFTLog or IR-resummation steps, PyBird-JAX utilizes embedded neural network (NN) emulators. These NNs are trained on the coefficients of a compact spline-based encoding of the linear matter power spectrum P_lin(k), ensuring strong generalization across cosmological models. Models originally trained with TensorFlow are ported to the JAX ecosystem using flax, enabling seamless integration into vectorized and differentiated pipelines.
2. Emulator Strategy and Model-Independence
A central innovation is the cosmology-independent emulator design: neural emulators operate not on cosmological parameters directly, but on feature vectors comprising the weights of a spline decomposition of P_lin(k). This construction abstracts away specific cosmological assumptions, allowing the same NN to provide accurate surrogate predictions across a broad manifold of cosmologies—including those not present in the training set. By decoupling the emulator from parameter space, PyBird-JAX eliminates the need for continual retraining or emulator updates for new theoretical models, subject only to sufficient coverage of the spline-feature space.
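The idea of summarizing P_lin(k) by a small feature vector can be sketched as follows. This simplified stand-in samples log P_lin at fixed log-k knots and reconstructs the curve by log-log interpolation; the knot count, knot placement, and interpolation basis here are illustrative assumptions, not the paper's actual spline scheme.

```python
import jax.numpy as jnp

def encode(k, pk, knots):
    # Feature vector: log P_lin sampled at fixed log-k knots
    # (a simplified stand-in for the paper's spline-coefficient encoding).
    return jnp.interp(jnp.log(knots), jnp.log(k), jnp.log(pk))

def decode(features, knots, k_out):
    # Reconstruct P_lin(k_out) from the features by log-log interpolation.
    return jnp.exp(jnp.interp(jnp.log(k_out), jnp.log(knots), features))

k = jnp.geomspace(0.01, 0.2, 200)
pk = 3.0 * k ** -2.0                     # toy power-law spectrum
knots = jnp.geomspace(0.011, 0.19, 16)

features = encode(k, pk, knots)          # 16 numbers summarize the whole curve
k_eval = jnp.geomspace(0.02, 0.15, 50)
pk_rec = decode(features, knots, k_eval)
```

Any cosmology that produces a P_lin(k) within the span of such features is handled by the same emulator, which is what makes the design model-independent.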
In the computation flow, the NN emulator replaces full direct evaluation of loop integrals and IR-resummation computation, with a forward pass through the network yielding the required EFT outputs. Differentiability is preserved since all operations, including the emulator’s internal logic, reside in JAX’s transformation system.
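A surrogate forward pass of this kind stays differentiable for free when written in JAX. The sketch below is a generic pure-JAX MLP (the layer sizes and initialization are illustrative assumptions, not PyBird-JAX's trained architecture); `jax.jacobian` differentiates straight through it.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    # Random (untrained) weights, purely for illustration of the forward pass.
    params = []
    for din, dout in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        W = jax.random.normal(sub, (din, dout)) / jnp.sqrt(din)
        params.append((W, jnp.zeros(dout)))
    return params

def mlp(params, x):
    # Tanh hidden layers, linear output: maps spline features -> EFT outputs.
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b

params = init_mlp(jax.random.PRNGKey(0), [20, 64, 64, 10])
features = jnp.ones(20)                    # stand-in spline-feature vector
out = mlp(params, features)
# AD flows through the emulator: sensitivity of outputs to input features.
jac = jax.jacobian(mlp, argnums=1)(params, features)
```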
3. Performance and Computational Scaling
Empirical performance metrics demonstrate that the full prediction for all three redshift-space galaxy power spectrum multipoles (monopole, quadrupole, and hexadecapole) is achieved in approximately 1.2 milliseconds on an AMD EPYC CPU, and 0.19 milliseconds on an NVIDIA A100 GPU. When leveraging JAX’s vmap for batch-mode evaluation (e.g., batch size = 128), effective evaluation time drops to 0.023 ms per realization.
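The batch-mode pattern behind these numbers is a single `vmap`-transformed, JIT-compiled prediction function evaluated over a stack of parameter vectors. A minimal sketch, again with a placeholder model rather than the real emulator call:

```python
import jax
import jax.numpy as jnp

def predict(theta):
    # Stand-in per-cosmology prediction (the real call would be the NN emulator).
    k = jnp.linspace(0.01, 0.2, 50)
    return theta[0] * k ** theta[1]

# One compiled kernel evaluates the whole batch in a single dispatch.
batched = jax.jit(jax.vmap(predict))

thetas = jnp.stack([jnp.linspace(1.0, 2.0, 128),
                    jnp.full(128, -1.5)], axis=1)   # batch of 128 parameter sets
preds = batched(thetas)                             # shape (128, 50)
```

Because the batch is fused into one kernel, per-realization cost drops well below the single-evaluation time, which is the effect reported above for batch size 128.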
These improvements correspond to a computational speed-up of 3–4 orders of magnitude relative to the original PyBird, attributed primarily to (a) the dramatic reduction in floating-point operations when full loop integrals are replaced by an emulator forward pass, and (b) the elimination of Python overhead in favor of JIT-compiled, vectorized code. The scaling permits likelihood-based inference (MCMC, nested sampling, HMC) to reach convergence in a matter of minutes, including high-dimensional nuisance parameter spaces, when deployed on a GPU.
4. Accuracy Validation and Robustness
Validation is performed against two classes of data:
- Large-volume simulated data, such as the PT challenge suite with very large simulation volumes (hundreds of cubic gigaparsecs), provide a testbed for assessing accuracy of LSS observables. PyBird-JAX exhibits residual errors below 0.1σ relative to Stage-4 observational uncertainties over the bulk of the parameter space.
- Real galaxy data, including BOSS and eBOSS samples, where analyses encompass both standard ΛCDM and extensions (e.g., Early Dark Energy, self-interacting neutrino cosmologies). Posterior constraints and best-fit predictions from PyBird-JAX (with emulator substitution) are consistent with those from non-emulated PyBird at the subpercent level, regardless of whether the specific cosmology was present in the NN training set.
Reliability in emulation is ensured by the spline-based representation; even for previously unseen parameter regimes or extended physics, interpolation in the observable space is robust, and no accuracy loss is observed within the tested metric regime.
5. Advanced Functional Capabilities
Leveraging JAX’s automatic differentiation, PyBird-JAX exposes the following advanced features:
- Fisher matrix and Hessian analysis: The code supports exact, single-command computation of the Fisher information matrix or likelihood Hessian, via reverse-mode AD, in the full observable parameterization.
- Taylor expansion acceleration: Model predictions can be rapidly expanded about fiducial points for efficient parameter inference, particularly valuable when integrating less-differentiable downstream solvers (e.g., non-JAX Boltzmann codes).
- Gradient-based optimization and sampling: Owing to full differentiability, gradient-based minimizers (Adam, Newton, L-BFGS) and advanced samplers (HMC with the NUTS algorithm) can be employed end-to-end, increasing efficiency in high-dimensional parameter searches.
- Vectorization and ensemble inference: PyBird-JAX seamlessly integrates with vmap and pmap for parallelized sampling over large ensembles (e.g., batched emcee or zeus walkers), maximizing GPU/TPU throughput.
- Toolkit interoperability: The pipeline is compatible with external samplers (emcee, zeus, nautilus, BlackJAX), Boltzmann solvers (CLASS, emulator surrogates), and symbolic generators, establishing an end-to-end likelihood evaluation platform.
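The first capability above, single-command Fisher analysis via reverse-mode AD, can be sketched as follows. The Fisher matrix is minus the Hessian of the log-likelihood at the fiducial point; the Gaussian likelihood and power-law model here are illustrative assumptions standing in for the real EFT pipeline.

```python
import jax
import jax.numpy as jnp

def loglike(theta, k, data, sigma):
    # Gaussian log-likelihood of a placeholder model (not the EFT pipeline).
    model = theta[0] * k ** theta[1]
    return -0.5 * jnp.sum(((model - data) / sigma) ** 2)

k = jnp.linspace(0.01, 0.2, 50)
theta0 = jnp.array([2.0, -1.5])           # fiducial parameters
data = 2.0 * k ** -1.5                    # data generated at the fiducial point
sigma = 0.05 * data                       # 5% errors

# Fisher matrix = minus the Hessian of the log-likelihood at the fiducial point,
# obtained exactly via reverse-mode AD in one call.
fisher = -jax.hessian(loglike)(theta0, k, data, sigma)
cov = jnp.linalg.inv(fisher)              # forecast parameter covariance
```

No finite-difference step sizes are involved; the second derivatives are exact up to floating-point precision, which is what makes single-command Fisher forecasting reliable.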
6. Cosmological Applications and Results
Applied to practical scenarios, PyBird-JAX is used to obtain marginalized LSS constraints, including constraints free from volume projection biases, obtained via non-flat integration measures made tractable by AD. This pipeline is applied to both simulated and real survey data, providing rapid summary statistics, parameter estimates, and error propagation. For simulated datasets (e.g., the PT challenge), and for real Stage-4 survey analyses (e.g., DESI, Euclid), the tool demonstrates the capacity to marginalize over dozens of nuisance parameters in minutes, with compatibility for direct inclusion of extended cosmological or astrophysical models.
A table summarizing timing and accuracy (as reported in (Reeves et al., 28 Jul 2025)):
| Pipeline | CPU time (ms) | GPU time (ms) | Relative accuracy (σ, max) |
|---|---|---|---|
| PyBird (original) | 3000 | n/a | Subpercent |
| PyBird-JAX (NN) | 1.2 | 0.19 | <0.1 |
| PyBird-JAX (NN, vmap) | 0.16 | 0.023 | <0.1 |
7. Future Prospects and Generalization
PyBird-JAX’s design, centered on JAX-based differentiable programming and surrogate emulation via cosmology-agnostic spline representations, is particularly well suited for next-generation cosmological analyses characterized by large data volume and model diversity. The methodological framework is directly extensible to higher-order loop corrections (two-loop, bispectrum), additional LSS observables, and joint analyses (e.g., with CMB likelihoods or lensing statistics).
The minimization of computational overhead, coupled with broad applicability of NN emulation, suggests the approach will continue to scale with evolving survey demands. A plausible implication is that real-time or iterative updating of cosmological constraints—and even feedback into survey design or systematics mitigation—becomes feasible within standard computational resources. Moreover, open compatibility with symbolic or emulator-based generators positions PyBird-JAX as a core component in modular, reproducible inference pipelines for cosmological data analysis.