JaxARC: High-Throughput RL for ARC
- JaxARC is an open-source, high-throughput reinforcement learning environment for ARC that leverages stateless, functional JAX APIs.
- It replaces stateful Python designs with composable vectorized operations, achieving up to 5,439× speedup over legacy frameworks.
- The platform supports diverse ARC benchmarks with flexible action wrappers, accelerating research in program induction and reasoning.
JaxARC is an open-source, high-throughput reinforcement learning (RL) environment for the Abstraction and Reasoning Corpus (ARC), implemented in JAX and designed to support large-scale AI research on human-like program induction and reasoning tasks. Unlike previous ARC environments, which relied on stateful, object-oriented Python APIs with severe performance bottlenecks, JaxARC employs a functional, stateless architecture that enables unprecedented parallelism, flexibility, and reproducibility for RL workflows in abstraction and reasoning (Aadam et al., 24 Jan 2026).
1. Motivation and Historical Context
The ARC benchmark targets program synthesis via inductive reasoning from demonstration pairs, serving as a standard for evaluating AI systems' capacity for novel transformation discovery. Earlier ARC RL platforms (e.g., ARCLE built on OpenAI Gymnasium) suffered from the following limitations:
- Pure-Python, stateful, object-oriented design, with environment state mutated on step/reset.
- Parallelism achieved via Python-level multiprocessing, incurring prohibitive IPC and interpreter overhead.
- Throughput saturating in the tens of thousands of steps per second; experiments spanning billions of steps required multi-day runtimes, obstructing algorithmic innovation.
Modern RL frameworks (PPO, meta-learning, population-based training) require massive concurrent rollouts, with scale targets in the hundreds of millions of steps per second. JaxARC was designed to address these computational barriers through JAX/XLA acceleration, a fully functional API, and composable wrappers (Aadam et al., 24 Jan 2026).
2. Core Functional Architecture
JaxARC models the environment as two stateless pure functions: `reset(key, env_params) -> (state, timestep)` and `step(state, action, env_params) -> (next_state, next_timestep)`.
Key architectural features:
- All environment state (active grid, clipboard, counters, PRNG key) propagates explicitly as a JAX pytree, eliminating hidden mutability and favoring pure functional code.
- Variable grid sizes (1×1–30×30) are padded to 30×30, with a validity mask for efficient JIT and vectorization.
- TaskBuffer loads all puzzles into JAX arrays at environment start; stochastic task selection via PRNG-indexed JAX ops.
- Actions are consistently represented as tuples of a binary selection mask and an operation identifier, with external wrappers supporting point actions (row, col, op) and bounding-box actions (r₁, c₁, r₂, c₂, op). Wrappers map any parameterization to the unified mask format.
- The entire environment logic, parameterized as a pytree, is compatible with JAX's transformation stack:
`jax.jit` (single-kernel compilation), `jax.vmap` (vectorization across batches), and `jax.pmap` (multi-device/sharded parallelism).
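To make the stateless pattern concrete, the following sketch shows an explicit state container with a 30×30 padded grid and validity mask, and pure `reset`/`step` functions that return new state rather than mutating. This is illustrative only, not JaxARC's actual code: the names are hypothetical and NumPy stands in for `jax.numpy`.

```python
from typing import NamedTuple
import numpy as np

MAX_SIZE = 30  # grids from 1x1 up to 30x30 are padded to a fixed 30x30 shape

class EnvState(NamedTuple):
    """Explicit environment state; in JaxARC this would be a JAX pytree."""
    grid: np.ndarray   # (30, 30) int array, padded
    mask: np.ndarray   # (30, 30) bool array, True on valid cells
    step_count: int

def pad_grid(grid: np.ndarray):
    """Pad a variable-size grid to 30x30 and build its validity mask."""
    h, w = grid.shape
    padded = np.zeros((MAX_SIZE, MAX_SIZE), dtype=grid.dtype)
    padded[:h, :w] = grid
    mask = np.zeros((MAX_SIZE, MAX_SIZE), dtype=bool)
    mask[:h, :w] = True
    return padded, mask

def reset(task_grid: np.ndarray) -> EnvState:
    """Pure reset: state is constructed fresh, never mutated in place."""
    padded, mask = pad_grid(task_grid)
    return EnvState(grid=padded, mask=mask, step_count=0)

def step(state: EnvState, selection: np.ndarray, color: int) -> EnvState:
    """Pure step: fill the selected (and valid) cells with `color`,
    returning a new EnvState and leaving the input state untouched."""
    fill = selection & state.mask
    new_grid = np.where(fill, color, state.grid)
    return EnvState(grid=new_grid, mask=state.mask,
                    step_count=state.step_count + 1)
```

Because `step` depends only on its arguments, the same pattern compiles under `jax.jit` and batches under `jax.vmap` once written with `jax.numpy` arrays.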
3. Parallelism and Performance Benchmarks
JaxARC's stateless JAX core drives exceptional parallel performance:
| Device | Batch Size | Gymnasium ARCLE (steps/s) | JaxARC (steps/s) | Speedup |
|---|---|---|---|---|
| Apple M2 Pro | 65,536 | 21,000 | 789,000 | 38× |
| RTX 3090 | 131,072 | 36,000 | 32,600,000 | 903× |
| Nvidia H100 | 131,072 | 26,000 | 142,000,000 | 5,439× |
At maximal scale (2,097,152 envs on H100), JaxARC reaches 790 million steps/second, enabling one billion environment steps in roughly 1.3 seconds. Gymnasium environments crash above 131K concurrent envs. This throughput enables full ARC-scale RL experiments and exhaustive hyperparameter sweeps that were previously infeasible (Aadam et al., 24 Jan 2026).
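A quick back-of-envelope check confirms the headline wall-time figure from the quoted peak throughput:

```python
# Throughput figure quoted for the H100 at 2,097,152 concurrent envs
h100_steps_per_sec = 790_000_000
one_billion_steps = 1_000_000_000

# Wall time for one billion environment steps at peak throughput
seconds = one_billion_steps / h100_steps_per_sec
print(f"{seconds:.2f} s")  # ~1.27 s, consistent with the quoted ~1.3 s
```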
4. Supported Datasets and Environment Features
JaxARC supports multiple ARC benchmarks:
- Full ARC: 400 training + 600 test tasks
- MiniARC: 149 tasks
- User-defined ARC formats (JSON, YAML) via parser interfaces
- Task splits, named subsets, and lazy parsing
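For user-defined formats, a parser only needs to map the standard ARC JSON layout (`{"train": [{"input": ..., "output": ...}], "test": [...]}`, where each grid is a list of rows of color integers) into arrays. The sketch below is illustrative; the function name and return structure are assumptions, not JaxARC's actual parser interface.

```python
import json
import numpy as np

def parse_arc_task(raw: str) -> dict:
    """Parse a task in the standard ARC JSON layout into NumPy arrays.

    Illustrative sketch only; JaxARC's real parser interface may differ.
    """
    task = json.loads(raw)
    return {
        split: [
            {"input": np.array(pair["input"], dtype=np.int8),
             "output": np.array(pair["output"], dtype=np.int8)}
            for pair in task[split]
        ]
        for split in ("train", "test")
    }

# A minimal hand-written task in the standard ARC JSON format
raw = json.dumps({
    "train": [{"input": [[0, 1], [1, 0]], "output": [[1, 0], [0, 1]]}],
    "test": [{"input": [[2, 2]], "output": [[3, 3]]}],
})
task = parse_arc_task(raw)
```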
Flexible action space designs include:
- Core: a binary selection mask over the grid plus an operation identifier
- Wrappers for:
  - Point actions: sparse selections (row, col, op)
  - Bounding-box actions: regional selections (r₁, c₁, r₂, c₂, op)
- All wrappers unify to the binary mask format for the core step logic
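The mapping a wrapper performs can be sketched in a few lines (illustrative NumPy code; the function names are hypothetical, not JaxARC's API). The operation id passes through unchanged, so the core step logic only ever sees (mask, op):

```python
import numpy as np

GRID_SHAPE = (30, 30)  # the fixed padded grid shape

def point_to_mask(row: int, col: int) -> np.ndarray:
    """Map a point action (row, col) to the unified binary mask format."""
    mask = np.zeros(GRID_SHAPE, dtype=bool)
    mask[row, col] = True
    return mask

def bbox_to_mask(r1: int, c1: int, r2: int, c2: int) -> np.ndarray:
    """Map a bounding-box action to the same mask format (inclusive bounds)."""
    mask = np.zeros(GRID_SHAPE, dtype=bool)
    mask[r1:r2 + 1, c1:c2 + 1] = True
    return mask
```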
Compositional wrappers facilitate:
- Observations: concatenated channels of grid(s), I/O demonstrations, targets
- Actions: custom parameterizations mapped via mask logic
- Rewards: user-definable via sparse bonuses, penalties, and optional pixel similarity shaping
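One plausible implementation of the optional pixel-similarity shaping combined with a sparse solved bonus looks like this (a sketch of the idea, not necessarily JaxARC's exact reward formula):

```python
import numpy as np

def pixel_similarity(grid: np.ndarray, target: np.ndarray,
                     mask: np.ndarray) -> float:
    """Fraction of valid cells where the working grid matches the target."""
    matches = (grid == target) & mask
    return float(matches.sum()) / float(mask.sum())

def shaped_reward(grid: np.ndarray, target: np.ndarray, mask: np.ndarray,
                  solved_bonus: float = 1.0) -> float:
    """Dense similarity shaping plus a sparse bonus on exact solution."""
    sim = pixel_similarity(grid, target, mask)
    return sim + (solved_bonus if sim == 1.0 else 0.0)
```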
Configuration-driven reproducibility:
- Hydra-based YAML config files enumerate datasets, wrappers, reward shaping, curriculum schedules, and PRNG seeds
- Deterministic construction: identical seed + config yields identical rollouts on CPU, GPU, or TPU
- Curriculum support, auto-reset, metric logging (Aadam et al., 24 Jan 2026)
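A Hydra-style experiment config might look like the following. This is a hypothetical sketch of the concepts listed above; the actual key names and schema in JaxARC's YAML files may differ.

```yaml
# experiment.yaml -- illustrative, not JaxARC's real schema
dataset:
  name: miniarc            # or full ARC, or a user-defined JSON/YAML format
  split: train
wrappers:
  action: point            # point | bbox | mask
  observation: [grid, demonstrations, target]
reward:
  solved_bonus: 1.0
  pixel_similarity_shaping: true
curriculum:
  schedule: linear
seed: 0                    # identical seed + config => identical rollouts
```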
5. Usage Example and API Design
A minimal workflow leverages pure JAX functions for environment interaction, batch vectorization, and JIT compilation:
```python
import jax
import jax.numpy as jnp

from jaxarc import make

# Build the environment and its parameters for a single MiniARC task
env, env_params = make("Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", auto_download=True)

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key, env_params=env_params)

# Sample a random action and take one step
action_space = env.action_space(env_params)
key, subkey = jax.random.split(key)
action = action_space.sample(subkey)
next_state, next_timestep = env.step(state, action, env_params=env_params)
print("Reward:", next_timestep.reward)

# Vectorize reset across 1,024 environments
num_envs = 1024
keys = jax.random.split(key, num_envs)
vm_reset = jax.vmap(env.reset, in_axes=(0, None))
states, timesteps = vm_reset(keys, env_params)

# JIT-compile a single rollout step, then vectorize it
@jax.jit
def rollout_step(state, key):
    act = action_space.sample(key)
    return env.step(state, act, env_params)

keys2 = jax.random.split(key, num_envs)
vm_step = jax.vmap(rollout_step)
next_states, next_timesteps = vm_step(states, keys2)
```
Multi-device support via `jax.pmap` requires no API changes; batch computation scales seamlessly from single-node to distributed hardware (Aadam et al., 24 Jan 2026).
6. Impact on Reinforcement Learning Research
JaxARC enables experimental regimes previously out of reach:
- PPO training with 1,000 envs at 30M steps/sec per GPU
- Population-based training (evoRL), architecture search, and meta-learning (MAML, RL²) on ARC datasets at billions of steps per experiment, with walltimes reduced from weeks to hours
- Deterministic rollouts for reproducible ablation studies and sweeps across observation/reward/curriculum wrappers
- Efficient transformer and attention-model policy exploration over demonstration examples
- Neuro-symbolic hybrids integrating RL with program synthesis
- Curriculum schedules unlocking progressively harder ARC tasks
- Massive-scale RL discovery of transformation primitives
JaxARC's performance (up to 5,439× speedup vs. Gymnasium ARCLE), functional design, and reproducibility infrastructure reposition ARC as a first-class benchmark for deep RL, meta-learning, neuro-symbolic research, and program induction at experimental scale (Aadam et al., 24 Jan 2026).
7. Relation to JAX-based Scientific Computing and Future Directions
JaxARC leverages core JAX design lessons—stateless pytrees, XLA-fused ops, batched computation, and function transforms—as seen in high-performance frameworks such as Carbox (astrochemistry simulation in JAX), which emphasizes end-to-end differentiability, accelerator throughput, and composability (Vermariën et al., 13 Nov 2025). The abstraction strategies and infrastructure in JaxARC (padded shapes, PRNG-index sampling, wrapper architecture) parallel JAX applications in scientific ML and interval reachability analysis (Harapanahalli et al., 2024).
A plausible extension is further integration with scientific ML stacks, e.g., replacing environment wrappers/step logic with learned surrogates or embedding RL environments within differentiable simulation pipelines. The underlying architectural principles of JaxARC suggest adaptability to other program induction, reasoning, and RL benchmarks where throughput, reproducibility, and flexible action-observation interfacing are critical.
References
- "JaxARC: A High-Performance JAX-based Environment for Abstraction and Reasoning Research" (Aadam et al., 24 Jan 2026)
- "Carbox: an end-to-end differentiable astrochemical simulation framework" (Vermariën et al., 13 Nov 2025)
- "immrax: A Parallelizable and Differentiable Toolbox for Interval Analysis and Mixed Monotone Reachability in JAX" (Harapanahalli et al., 2024)