PyBird-JAX: Accelerated inference in large-scale structure with model-independent emulation of one-loop galaxy power spectra (2507.20990v1)
Abstract: We present $\texttt{PyBird-JAX}$, a differentiable, $\texttt{JAX}$-based implementation of $\texttt{PyBird}$, using internal neural network emulators to accelerate computationally costly operations for rapid large-scale structure (LSS) analysis. $\texttt{PyBird-JAX}$ computes one-loop EFTofLSS predictions for redshift-space galaxy power spectrum multipoles in 1.2 ms on a CPU and 0.2 ms on a GPU, achieving 3-4 orders of magnitude speed-up over $\texttt{PyBird}$. The emulators take a compact spline-based representation of the input linear power spectrum $P(k)$ as feature vectors, making the approach applicable to a wide range of cosmological models. We rigorously validate its accuracy against large-volume simulations and on BOSS data, including cosmologies not explicitly represented in the training set. Leveraging automatic differentiation, $\texttt{PyBird-JAX}$ supports Fisher forecasting, Taylor expansion of model predictions, gradient-based searches, and vectorised ensemble sampling. Interfaced with a variety of samplers and Boltzmann solvers, $\texttt{PyBird-JAX}$ provides a high-performance, end-to-end inference pipeline. Combined with a symbolic-$P(k)$ generator, a typical Stage-4 LSS MCMC converges in minutes on a GPU. Our results demonstrate that $\texttt{PyBird-JAX}$ delivers the precision and speed required for upcoming LSS surveys, opening the door to accelerated cosmological inference with minimal accuracy loss and no pretraining. In a companion paper [1], we put $\texttt{PyBird-JAX}$ to use in achieving LSS marginalised constraints free from volume projection effects through non-flat measures.
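As an illustration of how automatic differentiation supports Fisher forecasting, here is a minimal JAX sketch. It does not use the PyBird-JAX API: `toy_model` is a hypothetical power-law stand-in for the one-loop prediction, and the pivot scale, fiducial parameters, and 5% diagonal errors are assumptions made only for this example.

```python
import jax
import jax.numpy as jnp


def toy_model(theta, k):
    """Hypothetical stand-in for a power-spectrum prediction: A * (k / k0)**n."""
    amp, tilt = theta
    k_pivot = 0.1  # assumed pivot scale, for illustration only
    return amp * (k / k_pivot) ** tilt


def fisher_matrix(model, theta, k, sigma):
    """Fisher matrix for a Gaussian likelihood with diagonal covariance."""
    # Jacobian of the prediction with respect to theta, shape (n_k, n_params).
    jac = jax.jacfwd(model)(theta, k)
    # F = J^T C^{-1} J with C = diag(sigma**2).
    return jac.T @ (jac / sigma[:, None] ** 2)


k = jnp.linspace(0.01, 0.2, 20)          # assumed wavenumber grid
theta_fid = jnp.array([1.0, -1.5])       # assumed fiducial (amplitude, tilt)
sigma = 0.05 * toy_model(theta_fid, k)   # assumed 5% error per k-bin

fisher = fisher_matrix(toy_model, theta_fid, k, sigma)
param_cov = jnp.linalg.inv(fisher)       # forecast parameter covariance
```

Because the Jacobian comes from `jax.jacfwd`, no finite-difference step sizes need to be tuned, and the same pattern would apply to any differentiable prediction function.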