MuJoCo-XLA Simulator (MJX)
- MuJoCo-XLA (MJX) is a differentiable simulator that re-implements MuJoCo using JAX-compatible XLA primitives for scalable, parallel physics simulation.
- It uses XLA optimizations such as operation fusion, together with JIT compilation, to reach 10–50× higher throughput than single-threaded CPU MuJoCo, which speeds up RL training.
- MJX supports gradient-based system identification, and extensions such as DiffMJX add adaptive integrators and refined contact models, enabling precise control and rapid prototyping for robotics and learning tasks.
MuJoCo-XLA (MJX) is a JAX-native, differentiable re-implementation of the MuJoCo physics simulator. It expresses all simulation elements (rigid-body dynamics, collision detection, constraint solving, and discrete-time integration) as composable graphs of XLA (Accelerated Linear Algebra) primitives. This architecture gives highly scalable, parallelizable, and differentiable physical simulation for learning and control tasks in robotics, with reported performance gains of one to two orders of magnitude over traditional CPU simulators and even some GPU-accelerated ones. Because the simulator is embedded in the JAX/XLA ecosystem, MJX supports seamless GPU/TPU execution, reverse-mode automatic differentiation, and joint optimization over policy and environment parameters (Thibault et al., 6 Jul 2024, Tunçay et al., 15 Dec 2025, Paulus et al., 17 Jun 2025).
1. Architecture and XLA Integration
MJX is fundamentally a rewrite of the classical MuJoCo simulation engine in which every computational block is expressed in JAX-compatible primitives rather than C/C++ or hand-written CUDA. The simulation API and modelling practices remain consistent with MuJoCo. These include XML-based (MJCF) model descriptions, body/joint/actuator specification, and the step-function interface, so existing MuJoCo workflows translate directly.
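Internally, JAX traces each MJX simulation step and lowers it to HLO (High Level Optimizer) operations. XLA then JIT-compiles the result into one or a few highly fused GPU kernels. The core computational flow for a physics step is the transition $s_{t+1} = f_\theta(s_t, a_t)$ from state $s_t$ under action $a_t$, where $f_\theta$ represents the simulator dynamics with parameters $\theta$. This transition becomes an XLA graph. For a batch of $N$ environments, XLA produces $s^{(i)}_{t+1} = f_\theta\big(s^{(i)}_t, a^{(i)}_t\big)$ for $i = 1, \dots, N$ as a single vectorized computation, fully supporting parallel rollouts.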
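The batched transition can be sketched in plain JAX with a toy stand-in for $f_\theta$. The point-mass dynamics, the parameter layout θ = (mass, dt), and the batch size are illustrative assumptions, not MJX internals. The sketch only shows how jax.vmap with a shared θ evaluates the transition for every environment in one compiled call.

```python
import jax
import jax.numpy as jnp

def f_theta(theta, state, action):
    """Hypothetical dynamics s_{t+1} = f_theta(s_t, a_t) for a 1-D point mass.

    theta = (mass, dt); state = (position, velocity); action = applied force.
    """
    mass, dt = theta[0], theta[1]
    pos, vel = state[0], state[1]
    vel_next = vel + dt * (action / mass - 9.81)   # semi-implicit Euler: velocity first
    pos_next = pos + dt * vel_next
    return jnp.stack([pos_next, vel_next])

# in_axes=(None, 0, 0): theta is shared, state and action are batched over environments.
batched_step = jax.jit(jax.vmap(f_theta, in_axes=(None, 0, 0)))

num_envs = 4096
theta = jnp.array([1.0, 0.01])
states = jnp.zeros((num_envs, 2))
actions = jnp.full((num_envs,), 9.81)               # force that cancels gravity for unit mass
next_states = batched_step(theta, states, actions)  # shape (num_envs, 2)
```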
Key XLA optimizations in MJX include: operation fusion (eliminating redundant memory traffic), constant propagation and dead-code elimination (pruning fixed elements), batched memory layout for efficient coalesced loads/stores, and JIT/ahead-of-time kernel compilation. The forward and reverse-mode differentiation rules for each primitive are explicitly registered, allowing policy gradients and system gradients to flow through both control and physical interactions (Thibault et al., 6 Jul 2024, Tunçay et al., 15 Dec 2025).
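Every primitive carries both forward-mode and reverse-mode rules, so jax.jvp and jax.vjp can both be applied to a step function. The minimal sketch below uses a hypothetical one-step velocity update and differentiates with respect to the mass, which plays the role of a system parameter. Both modes should agree with the analytic derivative.

```python
import jax
import jax.numpy as jnp

DT = 0.01

def velocity_step(mass, vel=0.5, force=10.0):
    # Hypothetical one-step velocity update of a point mass under gravity.
    return vel + DT * (force / mass - 9.81)

mass = jnp.array(2.0)

# Forward mode (JVP): push the tangent d(mass) = 1 through the step.
_, dv_dm_forward = jax.jvp(velocity_step, (mass,), (jnp.array(1.0),))

# Reverse mode (VJP): pull the cotangent d(v_next) = 1 back to the mass.
v_next, pullback = jax.vjp(velocity_step, mass)
(dv_dm_reverse,) = pullback(jnp.ones_like(v_next))

# Analytic derivative: -DT * force / mass**2 = -0.01 * 10 / 4 = -0.025
# (both results match it up to float32 rounding).
```

Forward mode is cheaper when there are few inputs, such as a handful of physical parameters. Reverse mode is cheaper when a scalar loss must be differentiated with respect to many inputs, such as policy weights.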
2. Differentiable Physics and Contact Modeling
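MJX implements rigid-body dynamics, tree-structured inertia computation, and friction/drag forces using JAX's computational graph. Collision detection and constraint solving use smooth, penalty-based methods compatible with reverse-mode differentiation. The contact solver resolves collisions with a soft-constrained version of Gauss's principle that includes Baumgarte stabilization. The constrained acceleration $\dot v$ and the constraint forces $f$ are
$$\dot v = \arg\min_{x}\ \tfrac{1}{2}(x - a_0)^\top M (x - a_0) + s\big(Jx - a_{\mathrm{ref}}\big), \qquad f = -\nabla s\big(J\dot v - a_{\mathrm{ref}}\big),$$
The symbols are as follows:
- $M$ is the inertia matrix.
- $a_0 = M^{-1}(\tau - c)$ is the unconstrained acceleration, with applied forces $\tau$ and bias forces $c$.
- $J$ is the constraint Jacobian.
- $a_{\mathrm{ref}} = -b\,Jv - k\,r$ is a reference acceleration depending on the penetration depth $r$ and velocity $v$. This is the Baumgarte term, with $b$ and $k$ set by “solref”.
- $s$ is a smooth convex penalty whose stiffness is shaped by the impedance spline (“solimp”).

Collision geometric tests are smoothed (e.g., soft-clamp on signed distances), so gradients are well defined wherever contacts are active. For bodies that are not in contact, however, the contact forces and their gradients vanish, a limitation addressed by Contacts-From-Distance (Section 6) (Paulus et al., 17 Jun 2025).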
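To make the Gauss-principle formulation concrete, here is a hypothetical one-dimensional sketch. It models a single normal contact with J = 1 and a quadratic penalty s(y) = min(y, 0)²/(2R). A fixed regularizer R stands in for the solimp impedance spline, and a_ref = -b v - k r is the Baumgarte-style reference acceleration. The parameter values are illustrative. This is not MJX's solver, which handles many contacts, friction cones, and impedance shaping.

```python
import jax
import jax.numpy as jnp

def contact_force(a0, v, r, m=1.0, R=1e-3, k=2500.0, b=100.0):
    """Hypothetical 1-D soft contact (J = 1, quadratic penalty, fixed regularizer R).

    a0: unconstrained acceleration, v: normal velocity,
    r: signed distance, negative when the bodies penetrate.
    Solves argmin_x 0.5*m*(x - a0)**2 + 0.5/R * min(x - a_ref, 0)**2 in closed form
    and returns the contact force f = m * (x - a0).
    """
    a_ref = -b * v - k * r                               # Baumgarte-style reference acceleration
    x_contact = (m * a0 + a_ref / R) / (m + 1.0 / R)     # stationary point when the penalty is active
    active = jnp.logical_and(r < 0.0, a0 < a_ref)        # contact exists and a_ref would be violated
    x = jnp.where(active, x_contact, a0)
    return m * (x - a0)

# A resting unit mass penetrating 1 mm: the force exceeds gravity and pushes it out.
f_penetrating = contact_force(-9.81, 0.0, -0.001)
# Derivative of the force with respect to the signed distance (nonzero in contact).
df_dr_contact = jax.grad(contact_force, argnums=2)(-9.81, 0.0, -0.001)
# Separated bodies: no force and, in this sketch, no gradient either.
df_dr_separated = jax.grad(contact_force, argnums=2)(-9.81, 0.0, 0.01)
```

The last line shows the limitation that Contacts-From-Distance targets. When the bodies are apart, both the force and its gradient are exactly zero.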
3. Large-Scale Parallelization and RL Integration
MJX is designed for extreme parallelism on GPU/TPU, with all state, velocity, and force tensors represented as contiguous batched arrays in device memory. Batch parallelism comes from two JAX constructs: vmap, which vectorizes over environments, and lax.scan, which loops over time steps inside compiled code. Both are JIT-compiled into a small number of fused XLA kernels. In practice, this lets MJX simulate thousands of environments in lock-step, batching over the agent population (for distributed RL) while unrolling time on-device (for temporal rollouts).
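The following minimal sketch shows the two axes of parallelism, using toy point-mass dynamics in place of an MJX step. jax.lax.scan unrolls the time loop inside compiled code, and jax.vmap batches whole rollouts across environments.

```python
import jax
import jax.numpy as jnp

DT = 0.01

def step(state, action):
    # Hypothetical point-mass dynamics; state = [position, velocity].
    pos, vel = state[0], state[1]
    vel = vel + DT * (action - 9.81)
    pos = pos + DT * vel
    return jnp.stack([pos, vel])

def rollout(init_state, actions):
    # lax.scan carries the current state through time and stacks every next state.
    def body(state, action):
        next_state = step(state, action)
        return next_state, next_state
    final_state, trajectory = jax.lax.scan(body, init_state, actions)
    return final_state, trajectory

# vmap over the environment axis, then JIT the whole batched rollout.
batched_rollout = jax.jit(jax.vmap(rollout))

num_envs, horizon = 1024, 200
init_states = jnp.zeros((num_envs, 2))
actions = jnp.zeros((num_envs, horizon))            # zero force: free fall
final_states, trajectories = batched_rollout(init_states, actions)
# trajectories has shape (num_envs, horizon, 2)
```

Because the time loop lives inside lax.scan, the whole 200-step rollout for all 1,024 environments is one compiled program. No per-step host round-trip is needed.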
MJX’s main ecosystem integration is via Brax. There, RL environment wrappers provide high-level policy and value function management while MJX handles the core physical simulation. API differences from MuJoCo are minor. Step functions are pure, returning a new state rather than mutating one in place, and they are JIT-compiled (e.g., step_fn = jax.jit(model.step)). All simulation and learning logic can execute inside one device-bound function, eliminating host-device synchronization overhead.
A representative code structure follows:
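```python
import jax
from brax import envs

# Placeholders from the original sketch: batch_size, num_updates, params,
# policy_apply, ppo_loss and update_params are supplied by the training code.
# 'mjx_reemc' is assumed to be a custom environment already registered with Brax.
model = envs.get_environment('mjx_reemc')

# vmap batches the single-environment step/reset over many environments;
# jit compiles each batched function with XLA.
step_fn = jax.jit(jax.vmap(model.step))
reset_fn = jax.jit(jax.vmap(model.reset))

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, batch_size)   # one PRNG key per environment
states = reset_fn(keys)

for update in range(num_updates):
    actions = policy_apply(params, states.obs)            # batched policy forward pass
    next_states = step_fn(states, actions)                # advance all environments
    grads = jax.grad(ppo_loss)(params, states, actions, next_states)
    params = update_params(params, grads)                 # e.g. an optimizer update
    states = next_states
```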
4. Performance Benchmarks
Experimental results report dramatic acceleration in RL training. For a 30-DoF REEM-C humanoid, 8,192 parallel environments were advanced for 200 million PPO steps in 56 minutes on a single RTX 4090. That corresponds to approximately 60,000 steps per second ($2\times10^{8} / 3360\ \mathrm{s} \approx 5.95\times10^{4}$). Compared to single-threaded MuJoCo (O(1,000) steps/s), MJX offers a 10–50× throughput improvement. For 6-DoF underwater vehicles, 4,096 environments each ran 100 PPO episodes (512 steps each) in less than 2 minutes on an RTX 4060, so the total environment throughput is at least
$$\frac{4096 \times 100 \times 512\ \text{steps}}{120\ \text{s}} \approx 1.7\times10^{6}\ \text{steps/s},$$
with a correspondingly large relative speedup (Tunçay et al., 15 Dec 2025, Thibault et al., 6 Jul 2024).
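The throughput figures can be checked with simple arithmetic on the reported run sizes and wall-clock times. The underwater-vehicle number is a lower bound because the run took less than 2 minutes. It also assumes that each of the 4,096 environments ran all 100 episodes.

```python
def steps_per_second(total_steps, wall_clock_seconds):
    """Average environment steps per second over a training run."""
    return total_steps / wall_clock_seconds

# REEM-C humanoid: 200 million PPO steps in 56 minutes.
reemc_sps = steps_per_second(200e6, 56 * 60)

# Underwater vehicle: 4,096 envs x 100 episodes x 512 steps in under 2 minutes
# (assumes every environment ran all 100 episodes), giving a lower bound.
auv_sps_lower_bound = steps_per_second(4096 * 100 * 512, 2 * 60)

print(f"REEM-C: {reemc_sps:,.0f} steps/s")              # → REEM-C: 59,524 steps/s
print(f"AUV:    >= {auv_sps_lower_bound:,.0f} steps/s")  # → AUV:    >= 1,747,627 steps/s
```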
GPU utilization during the main kernels is consistently above 90%. JIT tracing and compilation incur a one-time overhead (10–20 s) that is amortized over the run. Compared against multi-threaded Gazebo or Isaac Sim, MJX supports 20–40× faster policy training. A plausible implication is that MJX enables research previously bottlenecked by slow simulation, particularly in sim-to-real RL and system identification.
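The one-time compilation cost can be measured separately from steady-state execution with JAX's ahead-of-time API, jax.jit(...).lower(...).compile(). The step function below is a hypothetical stand-in for a simulator step, and the measured times depend entirely on the hardware. The optimized HLO text of the compiled executable can also be inspected for the fused kernels discussed in Section 1.

```python
import time
import jax
import jax.numpy as jnp

def physics_step(state):
    # Hypothetical stand-in for a simulator step: a chain of elementwise ops
    # (gravity plus linear drag) that XLA can fuse.
    pos, vel = state
    vel = vel + 0.01 * (-9.81 - 0.1 * vel)
    pos = pos + 0.01 * vel
    return pos, vel

state = (jnp.zeros(8192), jnp.zeros(8192))

t0 = time.perf_counter()
compiled = jax.jit(physics_step).lower(state).compile()   # trace + XLA compile, paid once
compile_seconds = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(1000):
    state = compiled(state)                                # reuses the compiled executable
state[0].block_until_ready()                               # wait for asynchronous dispatch
run_seconds = time.perf_counter() - t0

hlo_text = compiled.as_text()   # optimized HLO; fused regions appear as "fusion" ops
print(f"compile: {compile_seconds:.3f} s, 1000 steps: {run_seconds:.3f} s")
```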
5. Gradient-Based System Identification and Trajectory Optimization
MJX supports differentiable simulation for end-to-end learning and system identification, because all physics and contact computations admit vjp and jvp rules traceable by JAX. This allows direct joint optimization of physical parameters and policy weights with standard first-order optimizers (e.g., SGD or Adam, as implemented in the Optax library). In system identification, MJX can recover key physical properties such as mass and inertia, as well as complex nonlinearities (e.g., friction via neural approximations). It does so from purely trajectory-level (state and action) data, without requiring explicit torque ground truth (Kovalev et al., 6 Aug 2025).
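The following minimal sketch illustrates trajectory-level system identification for a hypothetical one-dimensional point mass whose only unknown is its mass. The simulator is written with jax.lax.scan, and the loss compares simulated and "measured" velocities without any force or torque labels. Plain gradient descent on the log-mass stands in for an Optax optimizer such as Adam. The synthetic data, learning rate, and iteration count are illustrative choices.

```python
import jax
import jax.numpy as jnp

DT, G = 0.01, 9.81

def simulate(log_mass, forces):
    """Hypothetical 1-D system: vertical velocity of a point mass under an applied force."""
    mass = jnp.exp(log_mass)                    # log-parameterization keeps the mass positive
    def step(vel, force):
        vel_next = vel + DT * (force / mass - G)
        return vel_next, vel_next
    _, velocities = jax.lax.scan(step, jnp.zeros((), jnp.float32), forces)
    return velocities

def loss(log_mass, forces, observed):
    # Trajectory-level loss: only states and actions are used, no force/torque labels.
    return jnp.mean((simulate(log_mass, forces) - observed) ** 2)

# Synthetic "measured" trajectory generated with the true (unknown) mass.
true_mass = 2.5
forces = 20.0 + 5.0 * jnp.sin(jnp.linspace(0.0, 6.0, 200))
observed = simulate(jnp.log(true_mass), forces)

grad_fn = jax.jit(jax.grad(loss))
log_mass = jnp.log(1.0)                         # initial guess: 1 kg
learning_rate = 5e-4
for _ in range(500):
    log_mass = log_mass - learning_rate * grad_fn(log_mass, forces, observed)

estimated_mass = float(jnp.exp(log_mass))       # converges toward true_mass
```

Optimizing the logarithm of the mass rather than the mass itself keeps the parameter physically valid without a projection step. With several parameters or noisy data, an adaptive optimizer would typically be more robust than this fixed-step loop.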
Accurate system identification in bipedal locomotion and underwater AUV tasks demonstrably improves trajectory tracking, reduces sim-to-real drift, and increases robustness to disturbances. The transparent interplay between gradient-based identification and RL is a hallmark of the MJX framework, and has been realized in on-policy RL, model-predictive control, and zero-shot sim-to-real transfer (Kovalev et al., 6 Aug 2025, Tunçay et al., 15 Dec 2025).
6. Refinements: DiffMJX and Contacts-From-Distance (CFD)
DiffMJX addresses a core challenge in penalty-based simulators. Hard (stiff) contacts induce gradient oscillations and even sign errors in automatic differentiation, because coarse integration steps produce large discretization errors. DiffMJX augments MJX with an adaptive ODE integrator, for example Tsit5 (the Tsitouras 5(4) Runge–Kutta method) with a PID step-size controller. The integrator uses error estimation and step rejection to maintain gradient quality. Backpropagation is supported by two schemes: discretize-then-optimize (a checkpointed tape to save memory) and optimize-then-discretize (adjoint sensitivity). Empirical results show that adaptive integration Pareto-dominates fixed-step methods in both gradient quality and runtime (Paulus et al., 17 Jun 2025).
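The error-estimation and step-rejection mechanism can be illustrated with a deliberately small, non-differentiable sketch. It swaps in a Heun–Euler embedded pair and a simple clamped step-size controller in place of Tsit5 with PID control. The scenario is a hypothetical stiff penalty contact: a unit-mass ball dropped onto a spring floor. The sketch shows how step sizes shrink and steps are rejected around contact. It does not show how DiffMJX backpropagates through the solve.

```python
import math

def dynamics(z, v, k=1e4, g=9.81):
    """Hypothetical stiff penalty contact: unit-mass ball above a floor at z = 0."""
    contact = k * max(-z, 0.0)          # spring force only while penetrating
    return v, -g + contact              # (dz/dt, dv/dt)

def integrate_adaptive(z, v, t1, rtol=1e-4, atol=1e-6, dt=1e-3, max_steps=100_000):
    t, accepted, rejected = 0.0, 0, 0
    for _ in range(max_steps):
        if t >= t1:
            break
        dt = min(dt, t1 - t)
        dz1, dv1 = dynamics(z, v)
        dz2, dv2 = dynamics(z + dt * dz1, v + dt * dv1)
        # Heun (2nd order) solution and Euler (1st order) embedded estimate.
        z_hi, v_hi = z + 0.5 * dt * (dz1 + dz2), v + 0.5 * dt * (dv1 + dv2)
        z_lo, v_lo = z + dt * dz1, v + dt * dv1
        # Scaled RMS error norm; err <= 1 means the step meets the tolerance.
        ez = (z_hi - z_lo) / (atol + rtol * max(abs(z), abs(z_hi)))
        ev = (v_hi - v_lo) / (atol + rtol * max(abs(v), abs(v_hi)))
        err = math.sqrt(0.5 * (ez * ez + ev * ev))
        if err <= 1.0:
            t, z, v = t + dt, z_hi, v_hi
            accepted += 1
        else:
            rejected += 1               # reject: keep (t, z, v) and retry with a smaller dt
        # Step-size update for an error estimate of order dt**2, clamped to [0.2, 5].
        factor = 0.9 * (1.0 / max(err, 1e-12)) ** 0.5
        dt *= min(5.0, max(0.2, factor))
    return t, z, v, accepted, rejected

# Drop the ball from 10 cm and integrate through one bounce.
t_end, z_end, v_end, n_accepted, n_rejected = integrate_adaptive(0.1, 0.0, t1=0.3)
print(f"accepted={n_accepted}, rejected={n_rejected}, z(0.3)={z_end:.4f}")
```

The rejections cluster at the moment of impact, where the stiff spring force switches on within a single step. A fixed-step integrator would instead take that step with a large, unchecked error.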
Contacts-From-Distance (CFD) solves the “zero-gradient barrier” in contact-rich learning. Ordinary contact forces are identically zero when two bodies do not touch, so the gradients available for policy improvement vanish. CFD introduces artificial, gradient-informative contact forces at small positive separations, built from a spline-extended impedance and reference acceleration. These forces are active only in the backward pass via the straight-through estimator, schematically $\hat f = f_{\mathrm{CFD}} + \operatorname{sg}\big(f - f_{\mathrm{CFD}}\big)$, where $\operatorname{sg}$ denotes stop-gradient. The forward pass uses the physical force $f$, while the backward pass differentiates the CFD force $f_{\mathrm{CFD}}$. Real rollouts therefore remain physically accurate while gradients inform action “at a distance.” Combined with DiffMJX, CFD improves convergence in system identification and policy optimization for dexterous manipulation tasks. It notably reduces the sim-to-real gap and enables model-predictive control in settings where sampling-based planners cannot find solutions in tractable time (Paulus et al., 17 Jun 2025).
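The straight-through trick itself is a one-liner with jax.lax.stop_gradient. In this hypothetical sketch, f_real is a hard penalty force that vanishes for positive signed distance. f_cfd is a softplus surrogate that stays positive within a small band of width EPS. Both stand in for CFD's spline-extended impedance and reference acceleration, which are not reproduced here.

```python
import jax
import jax.numpy as jnp

K = 1000.0    # hypothetical contact stiffness
EPS = 0.01    # hypothetical width of the "contact-from-distance" band

def f_real(r):
    # Hard penalty force: zero whenever the signed distance r is positive.
    return K * jnp.maximum(-r, 0.0)

def f_cfd(r):
    # Smooth surrogate that stays positive at small separations (softplus).
    return K * EPS * jax.nn.softplus(-r / EPS)

def f_ste(r):
    # Straight-through: forward value of f_real, gradient of f_cfd.
    return f_cfd(r) + jax.lax.stop_gradient(f_real(r) - f_cfd(r))

r = 0.005                            # bodies separated by 5 mm
force_value = f_ste(r)               # equals the physical force: zero at this separation
grad_physical = jax.grad(f_real)(r)  # zero: the zero-gradient barrier
grad_ste = jax.grad(f_ste)(r)        # negative: reducing the gap would increase the force
```

At a 5 mm separation the forward value matches the physical force, which is zero. The gradient, however, comes from the surrogate, so an optimizer still receives a signal that closing the gap would produce contact.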
7. Limitations and Future Directions
Despite pronounced performance gains, MJX has several acknowledged limitations. The default actuator model is simplified (typically position-control); realistic tendon or hydraulic dynamics are not differentiated or JIT-compiled. Contact solver fidelity is traded for performance—harder solvers (e.g., iterative TREP) may offer higher accuracy at significant computational cost. Hyperparameter tuning for contact and integration settings is manual; future releases may introduce automated gradient-based tuning. Multi-rate integration (to support sub-millisecond joint controllers) is under development (Thibault et al., 6 Jul 2024).
DiffMJX introduces computational overhead (2–5× slower than vanilla MJX) and requires integrator tolerance tuning. CFD incurs increased per-step costs with dense contact scenes or large near-contact regions. Gradient-based system ID remains vulnerable to local minima and nonphysical optima (especially with CFD enabled). Combining CFD with sampling-based or hybrid planners is a prospective direction.
MJX, DiffMJX, and CFD jointly establish a GPU-native, autodifferentiable, and semantically transparent foundation for large-scale simulation-based robotics learning, rapidly closing the practical gap between fast GPU simulators and high-fidelity real-world control (Thibault et al., 6 Jul 2024, Tunçay et al., 15 Dec 2025, Paulus et al., 17 Jun 2025).