DISCO-DJ: Differentiable Cosmology in Jax
- DISCO-DJ is a fully differentiable simulation framework for cosmology that integrates a linear Einstein–Boltzmann solver with a non-linear particle–mesh module.
- It leverages Jax’s automatic differentiation, JIT compilation, and adjoint methods to deliver high-fidelity, efficient computations for robust parameter estimation.
- The framework supports hybrid simulation–machine learning workflows and scalable Bayesian inference, enhancing precision in extracting cosmological observables.
DISCO-DJ (DIfferentiable Simulations for COsmology – Done with Jax) is a software ecosystem for cosmological simulation and inference, providing a comprehensive, fully differentiable framework for the forward modeling of both linear and non-linear cosmic structure formation. Written in Python using the Jax library, DISCO-DJ leverages automatic differentiation (autodiff), just-in-time (JIT) compilation for GPU acceleration, and adjoint methods to deliver memory-efficient, high-fidelity, and fast computations for cosmological observables, ranging from the linear transfer function to the non-linear matter density field. The system is designed for field-level inference, robust parameter estimation, and seamless integration into modern Bayesian inference pipelines, enabling the extraction of maximal information from large-scale structure surveys and facilitating the development of hybrid simulation–machine learning workflows.
1. Architecture and Design Principles
DISCO-DJ comprises modular components centered on differentiable solutions of cosmological evolution equations: a linear Einstein–Boltzmann solver and a non-linear particle–mesh (PM) N-body module. The code is implemented in Jax, a high-performance array library that enables:
- Full automatic differentiation for all numerical computations
- Efficient JIT compilation for GPU/TPU acceleration
- Forward- and reverse-mode differentiation (crucial for field-level inference)
- Custom adjoint evolution for memory efficiency in high-dimensional simulations
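The Jax capabilities listed above can be illustrated generically, independent of DISCO-DJ itself. The minimal sketch below JIT-compiles a hypothetical two-parameter toy observable (`toy_observable` is made up for illustration and is not a physical model) and evaluates its gradient both in reverse mode with `jax.grad` and in forward mode with `jax.jacfwd`; the two must agree.

```python
import jax
import jax.numpy as jnp

def toy_observable(params):
    """Hypothetical scalar observable of (omega_m, sigma8); not a physical model."""
    omega_m, sigma8 = params
    return sigma8**2 * jnp.sqrt(omega_m)

observable = jax.jit(toy_observable)                # compiled on first call, reused afterwards
grad_reverse = jax.jit(jax.grad(toy_observable))    # reverse mode (backpropagation)
grad_forward = jax.jit(jax.jacfwd(toy_observable))  # forward mode (tangent propagation)

params = jnp.array([0.3, 0.8])
value = observable(params)
g_rev = grad_reverse(params)
g_fwd = grad_forward(params)
```

Reverse mode costs a small constant multiple of one model evaluation regardless of the number of inputs, which is why it suits a scalar likelihood of a high-dimensional field; forward mode is cheaper when there are few inputs and many outputs.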
This design marks a departure from legacy cosmological codes (e.g., CAMB, CLASS, Gadget), which are non-differentiable “black boxes” with limited applicability to advanced inference schemes requiring derivatives of output observables with respect to high-dimensional parameter spaces (e.g., initial conditions, cosmological parameters).
2. Differentiable Einstein–Boltzmann Solver (DISCO-DJ I)
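The Einstein–Boltzmann module in DISCO-DJ numerically integrates the linearized Einstein–Boltzmann equations for cosmological perturbations in a fully differentiable manner (Hahn et al., 2023). In synchronous gauge, for example, the scalar metric perturbations $h$ and $\eta$ and the cold dark matter and baryon perturbations evolve, in conformal time with $\mathcal{H} = a'/a$ and comoving wavenumber $k$, as

$$k^2\eta - \tfrac{1}{2}\mathcal{H}h' = -4\pi G a^2\,\delta\rho, \qquad k^2\eta' = 4\pi G a^2\,(\bar\rho + \bar P)\,\theta,$$

$$\delta_c' = -\tfrac{1}{2}h', \qquad \delta_b' = -\theta_b - \tfrac{1}{2}h', \qquad \theta_b' = -\mathcal{H}\theta_b + c_s^2 k^2 \delta_b + \frac{4\bar\rho_\gamma}{3\bar\rho_b}\, a n_e \sigma_T\,(\theta_\gamma - \theta_b),$$

where $\delta\rho$ and $(\bar\rho + \bar P)\theta$ denote sums over all species. Massive neutrinos are evolved through a momentum-dependent Boltzmann hierarchy for the multipoles of their phase-space perturbation, with similar hierarchies for photon and massless neutrino perturbations.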
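The full set of Einstein–Boltzmann equations is too long for a short example, but the key capability, differentiating the solution of a perturbation ODE with respect to a cosmological parameter, can be shown on a much simpler linear equation. The sketch below is not the DISCO-DJ solver and does not use Diffrax: it assumes a flat ΛCDM background with radiation and neutrinos neglected, integrates the standard linear growth equation $\frac{d^2D}{d\ln a^2} + \left(2 - \tfrac{3}{2}\Omega_m(a)\right)\frac{dD}{d\ln a} - \tfrac{3}{2}\Omega_m(a)\,D = 0$ with a hand-written fixed-step RK4 loop, and obtains $dD/d\Omega_m$ by reverse-mode autodiff through every step. All function names are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for a clean derivative check

def omega_m_of_a(x, omega_m):
    """Matter fraction Omega_m(a) at x = ln a in flat LCDM (radiation neglected)."""
    a3 = jnp.exp(-3.0 * x)  # a^{-3}
    return omega_m * a3 / (omega_m * a3 + 1.0 - omega_m)

def rhs(x, y, omega_m):
    """y = (D, dD/dln a); returns d y / d ln a for the linear growth equation."""
    D, dD = y
    om = omega_m_of_a(x, omega_m)
    return jnp.array([dD, -(2.0 - 1.5 * om) * dD + 1.5 * om * D])

def growth_factor_today(omega_m, a_init=1e-3, n_steps=2000):
    """Integrate from a_init to a = 1 with fixed-step RK4; D = a deep in matter domination."""
    x0 = jnp.log(a_init)
    h = -x0 / n_steps                 # uniform step in ln a, ending at ln a = 0
    y0 = jnp.array([a_init, a_init])  # growing mode: D = dD/dln a = a

    def rk4_step(y, i):
        x = x0 + i * h
        k1 = rhs(x, y, omega_m)
        k2 = rhs(x + 0.5 * h, y + 0.5 * h * k1, omega_m)
        k3 = rhs(x + 0.5 * h, y + 0.5 * h * k2, omega_m)
        k4 = rhs(x + h, y + h * k3, omega_m)
        return y + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    y_final, _ = jax.lax.scan(rk4_step, y0, jnp.arange(n_steps))
    return y_final[0]

D1 = jax.jit(growth_factor_today)(0.3)
dD1 = jax.jit(jax.grad(growth_factor_today))(0.3)  # reverse mode through all RK4 steps
```

For $\Omega_m = 0.3$ the result should be close to the familiar late-time growth suppression $D(a{=}1) \approx 0.78$ relative to the Einstein–de Sitter value $D = a$, and the autodiff derivative is the exact derivative of the discretized solution, so it agrees with a finite-difference estimate of the same solver.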
- Automatic Differentiation: By using the Diffrax package (a differential-equation solver library built on Jax), both discretize-then-optimize (forward-mode) and adjoint (reverse-mode) derivatives of the solution with respect to all input cosmological parameters are accessible.
- Validation: The module produces matter power spectra, transfer functions, and other observables in per-mille agreement with CAMB and CLASS—including for massive neutrinos and general dark energy parameterizations.
- Fisher Forecasting: Exact Jacobians of power spectra with respect to parameters are used to construct Fisher matrices for rigorous survey forecasts.
- Modularity: The structure readily allows for extensibility to non-standard physics (e.g., modified gravity, additional neutrino species).
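For a Gaussian likelihood with a parameter-independent covariance $C$, the Fisher forecasts described above reduce to $F = J^{\mathsf T} C^{-1} J$ with $J_{ni} = \partial P(k_n)/\partial\theta_i$. The minimal sketch below uses a hypothetical power-law spectrum (`toy_power_spectrum`, with a made-up pivot scale) in place of the Einstein–Boltzmann output, and assumes 10% diagonal errors purely for illustration; only the Jacobian step reflects the approach described in the text.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

k = jnp.logspace(-2, 0, 50)  # wavenumbers (illustrative)
K_PIVOT = 0.05               # hypothetical pivot scale

def toy_power_spectrum(theta):
    """Hypothetical stand-in for a solver output: P(k) = A (k / k_pivot)^n."""
    amplitude, slope = theta
    return amplitude * (k / K_PIVOT) ** slope

def fisher_matrix(model, theta_fid, cov):
    """Gaussian Fisher matrix F = J^T C^{-1} J from an exact autodiff Jacobian."""
    jac = jax.jacfwd(model)(theta_fid)  # shape (n_k, n_params)
    return jac.T @ jnp.linalg.solve(cov, jac)

theta_fid = jnp.array([1.0, -1.5])
p_fid = toy_power_spectrum(theta_fid)
cov = jnp.diag((0.1 * p_fid) ** 2)     # assumed 10% diagonal errors
F = fisher_matrix(toy_power_spectrum, theta_fid, cov)
sigma_marg = jnp.sqrt(jnp.diag(jnp.linalg.inv(F)))  # marginalized 1-sigma forecasts
```

With a real differentiable solver, `toy_power_spectrum` would be replaced by its power-spectrum output; `jax.jacfwd` is a natural choice here because there are few parameters and many wavenumbers.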
3. Differentiable Particle–Mesh N-body Simulations (DISCO-DJ II)
The PM module provides a fully differentiable simulation of mildly non-linear cosmic structure formation (List et al., 6 Oct 2025). Key attributes include:
- Theory-informed time integrators, notably the BullFrog method, which reproduces 2LPT trajectories in the pre-shell-crossing regime. BullFrog is a leapfrog-type scheme whose drift and kick coefficients are designed to match analytic perturbative growth.
- Non-uniform FFT (NUFFT) for direct Fourier transforms from particle positions, suppressing aliasing and controlling discreteness and spectral leakage.
- Custom autodiff routines for particle–mesh gridding and force assignment, supporting forward-mode, reverse-mode, and full adjoint time-stepping.
- Adjoint formulation: Reverse-mode differentiation is performed with memory cost independent of the number of time steps.
- Scalability: Simulations with particles reach percent-level accuracy for the power spectrum at using only 6 BullFrog steps (runtime: a few seconds on modern GPUs).
- Field-level inference: Enables direct optimization or sampling of high-dimensional initial condition fields and cosmological parameters by comparing full simulated density fields with data, e.g., via Hamiltonian Monte Carlo.
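The memory savings of an adjoint formulation rest on the time reversibility of leapfrog-type integrators: instead of storing every intermediate state for backpropagation, the backward pass can reconstruct each earlier state by stepping with $-\Delta t$. The sketch below illustrates this with `jax.custom_vjp`. It assumes a plain kick-drift-kick leapfrog (not BullFrog) and a toy harmonic force standing in for particle–mesh gravity; `kdk_step` and `integrate_adjoint` are hypothetical names, and this is not DISCO-DJ's implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def force(x, k):
    """Toy restoring force standing in for PM gravity (assumption)."""
    return -k * x

def kdk_step(state, k, dt):
    """One kick-drift-kick leapfrog step; exactly reversible under dt -> -dt."""
    x, v = state
    v = v + 0.5 * dt * force(x, k)
    x = x + dt * v
    v = v + 0.5 * dt * force(x, k)
    return (x, v)

@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def integrate_adjoint(state, k, dt, n_steps):
    """Take n_steps leapfrog steps; gradients use the reversible adjoint below."""
    def body(s, _):
        return kdk_step(s, k, dt), None
    final, _ = jax.lax.scan(body, state, None, length=n_steps)
    return final

def _fwd(state, k, dt, n_steps):
    final = integrate_adjoint(state, k, dt, n_steps)
    return final, (final, k)  # residuals: final state only, O(1) in n_steps

def _bwd(dt, n_steps, residuals, g_final):
    final, k = residuals

    def body(carry, _):
        s, g_s, g_k = carry
        prev = kdk_step(s, k, -dt)  # reconstruct the previous state by stepping backwards
        _, step_vjp = jax.vjp(lambda st, kk: kdk_step(st, kk, dt), prev, k)
        g_prev, g_k_step = step_vjp(g_s)  # pull the cotangent back through one step
        return (prev, g_prev, g_k + g_k_step), None

    init = (final, g_final, jnp.zeros_like(k))
    (_, g_state, g_k), _ = jax.lax.scan(body, init, None, length=n_steps)
    return g_state, g_k

integrate_adjoint.defvjp(_fwd, _bwd)

# Usage: gradients of a scalar loss w.r.t. initial state and force parameter.
x0 = jnp.linspace(-1.0, 1.0, 8)
v0 = jnp.zeros(8)
k_spring = jnp.array(1.0)

def loss(state, k):
    x, v = integrate_adjoint(state, k, 0.01, 1000)
    return jnp.sum(x**2) + jnp.sum(x * v)

g_state, g_k = jax.grad(loss, argnums=(0, 1))((x0, v0), k_spring)
```

Only the final state is kept as a residual, so memory does not grow with `n_steps`; the price is one extra step evaluation per time step in the backward pass. In floating point the reconstructed trajectory accumulates round-off, but for this reversible toy problem the gradients match those from differentiating a plain `jax.lax.scan` loop that stores every step.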
4. Numerical and Algorithmic Innovations
DISCO-DJ achieves high performance and accuracy by integrating several advanced numerical strategies:
- Custom kernels for mass assignment and grid interpolation—supporting higher-order schemes (e.g., TSC, PCS), de-aliasing via interlacing, and Lagrangian sheet-based resampling.
- Adjoint time integration for memory-efficient backpropagation through hundreds of time steps—critical for reverse-mode autodiff in field-level inference.
- GPU-accelerated computation throughout, including FFTs, gridding, and vectorized numerical operations, yielding order-of-magnitude speed gains over comparable codes (List et al., 6 Oct 2025).
- Detailed accuracy studies: Extensive benchmarks quantify the impact of time-stepping, grid density, force resolution, and discreteness corrections, ensuring reliability for cosmological analyses.
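The basic differentiable building block behind such mass-assignment kernels can be sketched in a few lines. The example below assumes unit-mass particles, a periodic cubic box, and grid points at integer grid coordinates; it implements standard cloud-in-cell (CIC) assignment with `.at[].add` scatter-adds, so gradients with respect to particle positions flow through the trilinear weights. Higher-order kernels such as TSC and PCS follow the same pattern with 27 and 64 neighboring cells per particle; interlacing and sheet-based resampling are omitted, and `cic_deposit` is an illustrative name, not DISCO-DJ's API.

```python
import itertools

import jax
import jax.numpy as jnp

def cic_deposit(pos, n_grid, box_size):
    """Cloud-in-cell deposit of unit-mass particles onto a periodic n_grid^3 mesh.

    pos: (n_part, 3) positions in [0, box_size); returns particle counts per grid point.
    """
    cell = pos / box_size * n_grid          # positions in grid units
    i0 = jnp.floor(cell).astype(jnp.int32)  # lower neighboring grid point
    frac = cell - i0                        # offset from it, in [0, 1)
    grid = jnp.zeros((n_grid, n_grid, n_grid), dtype=pos.dtype)
    for offset in itertools.product((0, 1), repeat=3):  # the 8 neighboring grid points
        off = jnp.array(offset)
        # trilinear weight: frac along axes where off == 1, (1 - frac) otherwise
        w = jnp.prod(jnp.where(off == 1, frac, 1.0 - frac), axis=1)
        idx = (i0 + off) % n_grid            # periodic wrap-around
        grid = grid.at[idx[:, 0], idx[:, 1], idx[:, 2]].add(w)
    return grid

# Usage: mass is conserved, and gradients w.r.t. positions are available.
pos = jnp.array([[0.3, 0.3, 0.3], [0.5, 0.75, 0.1]])
rho = cic_deposit(pos, 4, 1.0)

ramp = jnp.broadcast_to(jnp.arange(4.0)[:, None, None], (4, 4, 4))  # field equal to the x-index

def interpolated_ramp(p):
    return jnp.sum(cic_deposit(p, 4, 1.0) * ramp)

dpos = jax.grad(interpolated_ramp)(pos)
```

Because `n_grid` sets the mesh shape, it must be static under `jax.jit`, e.g. `jax.jit(cic_deposit, static_argnums=(1,))`.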
5. Applications: Field-level Bayesian Inference and Self-consistent Cosmological Pipelines
A primary motivation for DISCO-DJ is field-level Bayesian inference, an approach that utilizes the full spatial information in density fields and galaxy distributions for cosmological parameter estimation. The differentiable nature of the DISCO-DJ pipeline is essential for:
- Efficient gradient-based sampling (e.g., HMC) in high-dimensional parameter spaces, where typical likelihoods depend on hundreds of thousands of parameters (initial condition amplitudes plus cosmological variables).
- Loss functions that can involve arbitrary field-level summary statistics or the full three-dimensional field, with gradients propagated through all simulation components.
- Seamless coupling with differentiable Einstein–Boltzmann solvers, yielding a fully self-consistent mapping from primordial parameters to late-time observables.
- Training and emulation: The framework generates large, self-consistent simulation suites for neural network–based emulators, or directly supports hybrid pipelines (e.g., with PyBird-JAX (Reeves et al., 28 Jul 2025), CosmoPower-JAX (Piras et al., 2023)).
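Gradient-based sampling of this kind can be sketched on a deliberately tiny problem. The example assumes a one-dimensional "initial condition" field with a unit Gaussian prior, a hypothetical linear forward model (`forward_model`, a Gaussian smoothing in Fourier space standing in for a differentiable simulator), and Gaussian pixel noise; `jax.grad` of the log-posterior drives a standard Hamiltonian Monte Carlo transition with a leapfrog integrator and unit mass matrix. Step size, trajectory length, and all names are illustrative choices, not DISCO-DJ settings.

```python
import jax
import jax.numpy as jnp

N_PIX = 64          # size of the toy 1D field (illustrative)
NOISE_SIGMA = 0.1   # assumed Gaussian pixel noise
STEP_SIZE = 0.02    # leapfrog step size (illustrative)
N_LEAPFROG = 20     # leapfrog steps per HMC trajectory (illustrative)

freqs = jnp.fft.rfftfreq(N_PIX)
TRANSFER = jnp.exp(-0.5 * (freqs / 0.1) ** 2)  # hypothetical smoothing kernel

def forward_model(s):
    """Hypothetical stand-in for a differentiable simulator: smooth the initial field."""
    return jnp.fft.irfft(jnp.fft.rfft(s) * TRANSFER, n=N_PIX)

def log_posterior(s, data):
    """Unit Gaussian prior on the initial field plus Gaussian likelihood of the data."""
    resid = data - forward_model(s)
    return -0.5 * jnp.sum(s**2) - 0.5 * jnp.sum(resid**2) / NOISE_SIGMA**2

grad_log_post = jax.grad(log_posterior)  # one reverse pass gives all N_PIX derivatives

@jax.jit
def hmc_step(key, s, data):
    """One HMC transition (unit mass matrix, leapfrog, Metropolis accept/reject)."""
    key_mom, key_acc = jax.random.split(key)
    p0 = jax.random.normal(key_mom, s.shape)

    def drift_kick(_, qp):
        q, p = qp
        q = q + STEP_SIZE * p
        return q, p + STEP_SIZE * grad_log_post(q, data)

    p = p0 + 0.5 * STEP_SIZE * grad_log_post(s, data)  # opening half kick
    q, p = jax.lax.fori_loop(0, N_LEAPFROG, drift_kick, (s, p))
    p = p - 0.5 * STEP_SIZE * grad_log_post(q, data)   # turn the last full kick into a half kick
    h_old = -log_posterior(s, data) + 0.5 * jnp.sum(p0**2)
    h_new = -log_posterior(q, data) + 0.5 * jnp.sum(p**2)
    accept = jnp.log(jax.random.uniform(key_acc)) < h_old - h_new
    return jnp.where(accept, q, s), accept

# Usage: infer the initial field from noisy, smoothed synthetic data.
key_true, key_noise, key_chain = jax.random.split(jax.random.PRNGKey(0), 3)
s_true = jax.random.normal(key_true, (N_PIX,))
data = forward_model(s_true) + NOISE_SIGMA * jax.random.normal(key_noise, (N_PIX,))

s, n_accept = jnp.zeros(N_PIX), 0
for step_key in jax.random.split(key_chain, 200):
    s, accepted = hmc_step(step_key, s, data)
    n_accept += int(accepted)
```

In a real field-level analysis the forward model would be the full Einstein–Boltzmann plus particle–mesh pipeline and the field would have far more entries; the structure of the sampler is unchanged because one reverse-mode pass still returns the whole gradient.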
6. Comparison with Other Differentiable and Non-differentiable Codes
- Legacy codes (CAMB, CLASS, standard N-body codes): Cannot supply exact derivatives of output observables with respect to input parameters, limiting their use in modern gradient-based inference and making survey optimization and studies of parameter dependence expensive and numerically fragile.
- Related Jax-based efforts (pmwd, CosmoPower-JAX, JAX-cosmo, GODMAX, SwiftC, halox): DISCO-DJ distinguishes itself by providing both a differentiable linear solver and a high-fidelity, memory-efficient, and differentiable PM code with end-to-end pipeline integration, adjoint differentiation, and scale-appropriate numerical accuracy.
- Innovative features: The BullFrog integrator, NUFFT force computation, and robust handling of forward- and reverse-mode autodiff set DISCO-DJ apart in terms of both speed and theoretical soundness.
7. Open-source Availability and Future Directions
DISCO-DJ is open-source and available at https://github.com/cosmo-sims/DISCO-DJ, facilitating transparent cross-comparison, extension, and integration into broader cosmological data analysis infrastructures (List et al., 6 Oct 2025). Its modular Python/Jax codebase simplifies the development of new physical modules (e.g., baryonic effects, alternative gravity, biased tracer models) and enables the rapid prototyping of differentiable scientific simulators. Planned extensions include:
- Incorporation of higher-order N-body solvers and baryonic subgrid physics with differentiable surrogates
- Efficient hybridization with machine learning approaches for improved inference in high-dimensional observational data spaces
- End-to-end pipelines covering Einstein–Boltzmann, mildly non-linear, and strongly non-linear regimes within a consistent autodiff environment
DISCO-DJ thus provides a technically rigorous and extensible foundation for next-generation, gradient-based cosmological inference, emphasizing physical fidelity, scalability, and full differentiability at every stage in the modeling process.