
JaxARC: High-Throughput RL for ARC

Updated 31 January 2026
  • JaxARC is an open-source, high-throughput reinforcement learning environment for ARC that leverages stateless, functional JAX APIs.
  • It replaces stateful Python designs with composable vectorized operations, achieving up to 5,439× speedup over legacy frameworks.
  • The platform supports diverse ARC benchmarks with flexible action wrappers, accelerating research in program induction and reasoning.

JaxARC is an open-source, high-throughput reinforcement learning (RL) environment for the Abstraction and Reasoning Corpus (ARC), implemented in JAX, designed to facilitate large-scale AI research on human-like program induction and reasoning tasks. Unlike previous ARC environments that relied on stateful, object-oriented Python APIs with severe performance bottlenecks, JaxARC employs a functional, stateless architecture enabling unprecedented parallelism, flexibility, and reproducibility for RL workflows in abstraction and reasoning (Aadam et al., 24 Jan 2026).

1. Motivation and Historical Context

The ARC benchmark targets program synthesis via inductive reasoning from demonstration pairs, serving as a standard for evaluating AI systems' capacity for novel transformation discovery. Earlier ARC RL platforms (e.g., ARCLE built on OpenAI Gymnasium) suffered from the following limitations:

  • Pure-Python, stateful, object-oriented design, with environment state mutated on step/reset.
  • Parallelism achieved via Python-level multiprocessing, incurring prohibitive IPC and interpreter overhead.
  • Throughput saturating at O(10^4)–O(10^5) steps/second; experiments spanning billions of steps required multi-day runtimes, obstructing algorithmic innovation.

Modern RL frameworks (PPO, meta-learning, population-based training) require massive concurrent rollouts, with scale targets in the 10^8–10^9 steps/second range. JaxARC was designed to address these computational barriers through JAX/XLA acceleration, a fully functional API, and composable wrappers (Aadam et al., 24 Jan 2026).

2. Core Functional Architecture

JaxARC models the environment as two stateless pure functions:

(state_new, timestep) = env.reset(key, env_params)

(state_next, timestep) = env.step(state, action, env_params)

Key architectural features:

  • All environment state (active grid, clipboard, counters, PRNG key) propagates explicitly as a JAX pytree, eliminating hidden mutability and favoring pure functional code.
  • Variable grid sizes (1×1–30×30) are padded to 30×30, with a validity mask for efficient JIT and vectorization.
  • TaskBuffer loads all puzzles into JAX arrays at environment start; stochastic task selection via PRNG-indexed JAX ops.
  • Actions are consistently represented as tuples (operation_id, binary_mask), with external wrappers supporting point actions (row, col, op) and bounding-box actions (r₁, c₁, r₂, c₂, op). Wrappers map any parameterization to the unified mask format.
  • The entire environment logic, parameterized as a pytree, is compatible with JAX's transformation stack: jax.jit (single-kernel compilation), jax.vmap (vectorization across batches), and jax.pmap (multi-device/sharded parallelism).
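The padding scheme described above can be sketched in plain JAX. Here `pad_grid` is a hypothetical helper for illustration, not part of the JaxARC API; it shows how a variable-size grid becomes a static-shape array plus validity mask, which is what makes jit/vmap over mixed task sizes possible:

```python
import jax.numpy as jnp

MAX_SIZE = 30  # ARC grids range from 1x1 to 30x30; all are padded to 30x30


def pad_grid(grid):
    """Pad a variable-size grid to MAX_SIZE x MAX_SIZE and return a validity mask."""
    h, w = grid.shape
    padded = jnp.zeros((MAX_SIZE, MAX_SIZE), dtype=grid.dtype)
    padded = padded.at[:h, :w].set(grid)
    mask = jnp.zeros((MAX_SIZE, MAX_SIZE), dtype=bool)
    mask = mask.at[:h, :w].set(True)
    return padded, mask


grid = jnp.array([[1, 2], [3, 4]])  # a 2x2 ARC grid
padded, mask = pad_grid(grid)
```

Because every grid now has the same static shape, a single compiled kernel serves all tasks; the mask lets downstream logic (rewards, comparisons) ignore the padded region.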

3. Parallelism and Performance Benchmarks

JaxARC's stateless JAX core drives exceptional parallel performance:

| Device       | Batch Size | Gymnasium ARCLE (steps/s) | JaxARC (steps/s) | Speedup |
| Apple M2 Pro | 65,536     | 21,000                    | 789,000          | 38×     |
| RTX 3090     | 131,072    | 36,000                    | 32,600,000       | 903×    |
| Nvidia H100  | 131,072    | 26,000                    | 142,000,000      | 5,439×  |

At maximal scale (2,097,152 envs on H100), JaxARC reaches 790 million steps/second, completing 10^9 steps in under 1.3 seconds. Gymnasium environments crash above ~131K concurrent envs. This throughput enables full ARC-scale RL experiments and exhaustive hyperparameter sweeps that were previously infeasible (Aadam et al., 24 Jan 2026).
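The jit + vmap measurement pattern behind such benchmarks can be sketched on a toy step function (a stand-in, not the real JaxARC env; absolute numbers depend entirely on hardware). The key details are the warm-up call, which triggers XLA compilation, and `block_until_ready()`, which forces JAX's async dispatch to finish before timing:

```python
import time
import jax
import jax.numpy as jnp


def toy_step(state, key):
    # Stand-in for env.step: a cheap pure function of (state, key).
    return state + jax.random.uniform(key, state.shape)


num_envs = 4096
states = jnp.zeros((num_envs, 30, 30))
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)

step = jax.jit(jax.vmap(toy_step))
step(states, keys).block_until_ready()  # warm-up: compile before timing

t0 = time.perf_counter()
out = step(states, keys).block_until_ready()
elapsed = time.perf_counter() - t0
steps_per_sec = num_envs / elapsed
```

Omitting the warm-up call or `block_until_ready()` would measure compilation time or dispatch latency rather than steady-state throughput.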

4. Supported Datasets and Environment Features

JaxARC supports multiple ARC benchmarks:

  • Full ARC: 400 training + 600 test tasks
  • MiniARC: 149 tasks
  • User-defined ARC formats (JSON, YAML) via parser interfaces
  • Task splits, named subsets, and lazy parsing

Flexible action space designs include:

  • Core: (operation_id ∈ [0..34], 30×30 mask)
  • Wrappers for:
    • Point actions: sparse selections (r, c, op)
    • Bounding-box actions: regional selections (r₁, c₁, r₂, c₂, op)
    • All wrappers unify to binary mask format for the core step logic
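A point-action wrapper, for instance, reduces to a pure function mapping (r, c, op) to the core (operation_id, mask) format. The helper below is illustrative, not the library's actual wrapper:

```python
import jax.numpy as jnp


def point_to_mask(r, c, op):
    """Map a point action (r, c, op) to the core (operation_id, 30x30 mask) format."""
    mask = jnp.zeros((30, 30), dtype=bool).at[r, c].set(True)
    return op, mask


op, mask = point_to_mask(3, 7, 12)
```

A bounding-box wrapper would follow the same pattern, setting a rectangular region of the mask instead of a single cell; the core step logic is unchanged either way.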

Compositional wrappers facilitate:

  • Observations: concatenated channels of grid(s), I/O demonstrations, targets
  • Actions: custom parameterizations mapped via mask logic
  • Rewards: user-definable via sparse bonuses, penalties, and optional pixel similarity shaping

Configuration-driven reproducibility:

  • Hydra-based YAML config files enumerate datasets, wrappers, reward shaping, curriculum schedules, and PRNG seeds
  • Deterministic construction: identical seed + config yields identical rollouts on CPU, GPU, or TPU
  • Curriculum support, auto-reset, metric logging (Aadam et al., 24 Jan 2026)
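The determinism guarantee follows directly from JAX's splittable, counter-based PRNG: the same seed produces the same key stream on any backend, so identical seed + config reproduces rollouts exactly. A minimal illustration with a hypothetical action-sampling function:

```python
import jax


def sample_action(seed):
    # Same seed -> same key stream -> same sample, on CPU, GPU, or TPU.
    key = jax.random.PRNGKey(seed)
    key, subkey = jax.random.split(key)
    return jax.random.randint(subkey, (), 0, 35)  # e.g. an operation_id in [0..34]


a = sample_action(42)
b = sample_action(42)
```

Because the key is part of the environment state pytree rather than hidden global state, replaying a rollout is just re-running the same pure functions with the same inputs.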

5. Usage Example and API Design

A minimal workflow leverages pure JAX functions for environment interaction, batch vectorization, and JIT compilation:

import jax
import jax.numpy as jnp
from jaxarc import make

# Instantiate an environment and its parameter pytree for a MiniARC task.
env, env_params = make("Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", auto_download=True)

# Single-environment interaction: reset and step are pure functions of
# explicit state, so there is no hidden mutation between calls.
key = jax.random.PRNGKey(0)
state, timestep = env.reset(key, env_params=env_params)
action_space = env.action_space(env_params)
key, subkey = jax.random.split(key)
action = action_space.sample(subkey)
next_state, next_timestep = env.step(state, action, env_params=env_params)

print("Reward:", next_timestep.reward)

# Vectorize reset across a batch of PRNG keys; env_params is shared (in_axes=None).
num_envs = 1024
keys = jax.random.split(key, num_envs)
vm_reset = jax.vmap(env.reset, in_axes=(0, None))
states, timesteps = vm_reset(keys, env_params)

# JIT-compile a single rollout step, then vmap it across the batch of states.
@jax.jit
def rollout_step(state, key):
    act = action_space.sample(key)
    return env.step(state, act, env_params)

keys2 = jax.random.split(key, num_envs)
vm_step = jax.vmap(rollout_step)
next_states, next_timesteps = vm_step(states, keys2)

Multidevice support via jax.pmap requires no API changes—batch computation seamlessly scales from single-node to distributed hardware (Aadam et al., 24 Jan 2026).
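The pmap-over-vmap composition can be sketched on a toy step function (again a stand-in for env.step, not the real API); the only change from the single-device pattern is reshaping the batch's leading axis to [num_devices, per_device]:

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # 1 on a typical CPU host
per_device = 256


def toy_step(state, key):
    # Stand-in for env.step; any pure function composes with pmap(vmap(...)).
    return state + jax.random.uniform(key, state.shape)


# Leading axes shaped [n_devices, per_device, ...]: pmap shards across devices,
# vmap vectorizes within each device.
states = jnp.zeros((n_devices, per_device, 30, 30))
keys = jax.random.split(jax.random.PRNGKey(0), n_devices * per_device)
keys = keys.reshape(n_devices, per_device, 2)

p_step = jax.pmap(jax.vmap(toy_step))
next_states = p_step(states, keys)
```

Since the environment carries no hidden state, the same function runs unmodified whether wrapped in vmap alone or in pmap(vmap(...)).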

6. Impact on Reinforcement Learning Research

JaxARC enables experimental regimes previously out of reach:

  • PPO training with >1,000 envs at >30M steps/sec per GPU
  • Population-based training (evoRL), architecture search, and meta-learning (MAML, RL²) on ARC datasets at billions of steps per experiment, with walltimes reduced from weeks to hours
  • Deterministic rollouts for reproducible ablation studies and sweeps across observation/reward/curriculum wrappers
  • Efficient transformer and attention-model policy exploration over demonstration examples
  • Neuro-symbolic hybrids integrating RL with program synthesis
  • Curriculum schedules unlocking progressively harder ARC tasks
  • Massive-scale RL discovery of transformation primitives

JaxARC's performance (up to 5,439× speedup vs. Gymnasium ARCLE), functional design, and reproducibility infrastructure reposition ARC as a first-class benchmark for deep RL, meta-learning, neuro-symbolic research, and program induction at experimental scale (Aadam et al., 24 Jan 2026).

7. Relation to JAX-based Scientific Computing and Future Directions

JaxARC leverages core JAX design lessons—stateless pytrees, XLA-fused ops, batched computation, and function transforms—as seen in high-performance frameworks such as Carbox (astrochemistry simulation in JAX), which emphasizes end-to-end differentiability, accelerator throughput, and composability (Vermariën et al., 13 Nov 2025). The abstraction strategies and infrastructure in JaxARC (padded shapes, PRNG-index sampling, wrapper architecture) parallel JAX applications in scientific ML and interval reachability analysis (Harapanahalli et al., 2024).

A plausible extension is further integration with scientific ML stacks, e.g., replacing environment wrappers/step logic with learned surrogates or embedding RL environments within differentiable simulation pipelines. The underlying architectural principles of JaxARC suggest adaptability to other program induction, reasoning, and RL benchmarks where throughput, reproducibility, and flexible action-observation interfacing are critical.
