Rollout Routing Replay (R3)
- Rollout Routing Replay (R3) is a method that stabilizes reinforcement learning in Mixture-of-Experts models by replaying the inference routing mask during training.
- The technique halves the train–infer KL divergence by aligning the training and inference routing distributions, thereby reducing variance in importance weights.
- Empirical results on math reasoning tasks show that R3 prevents PPO policy collapse and improves performance by 1–6 points compared to standard methods.
Rollout Routing Replay (R3) is a methodology for stabilizing reinforcement learning (RL) in Mixture-of-Experts (MoE) transformer models by eliminating discrepancies between routing decisions made during inference (rollout generation) and those made during training. The approach addresses the fundamental instability induced by non-deterministic or diverging expert selection across these phases. This instability is especially acute in RL settings, where it inflates the policy KL divergence, produces high-variance importance weights, and frequently causes policy collapse. R3 rectifies the mismatch by logging the expert selection mask during inference and deterministically replaying it in training. The resulting training policy closely matches inference-time decisions, which substantially improves RL stability and final task performance (Ma et al., 13 Oct 2025).
1. Training–Inference Discrepancy in MoE Routing
Reinforcement learning for LLMs commonly separates the rollout (inference) and gradient (training) computation engines. In dense LMs, these pipelines produce nearly identical token-level probability distributions. However, MoE models use a sparse expert gating mechanism: at each token position, the output depends on a sparse gate vector determined by a routing mask applied to the router logits. Minor numerical or implementation differences across inference and training can flip which experts are selected, causing dramatically different token outputs.
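To make the boundary effect concrete, the following minimal Python sketch uses hypothetical router logits (not taken from any real model) in which two experts nearly tie at the top-2 boundary. A perturbation of 2e-4, standing in for kernel or reduction-order differences between engines, is enough to change the selected set; `top_k_indices` is an illustrative helper name.

```python
def top_k_indices(logits, k):
    """Indices of the k largest logits; ties are broken toward the lower index."""
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    return set(order[:k])

# Hypothetical router logits for 6 experts; experts 1 and 3 nearly tie for 2nd place.
s_infer = [0.10, 1.2000, 2.50, 1.1999, -0.30, 0.05]
# Tiny engine-level drift (illustrative values only).
drift = [0.0, -2e-4, 0.0, 2e-4, 0.0, 0.0]
s_train = [s + d for s, d in zip(s_infer, drift)]

print(sorted(top_k_indices(s_infer, 2)))  # [1, 2]
print(sorted(top_k_indices(s_train, 2)))  # [2, 3]
```

A drift far smaller than the logits themselves swaps expert 1 for expert 3, so the token is processed by a different expert even though the router weights are identical.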
Empirical observations quantify this effect:
- On Qwen3-30B-A3B, the estimated train–infer KL divergence of the MoE model is nearly 2× that of dense baselines.
- 10% of router calls select different experts between training and inference; at the token level, 94% of tokens differ in at least one layer.
This expert-selection mismatch inflates the variance of the PPO importance weights $w_t = \pi_{\text{train}}(y_t \mid x, y_{<t}) / \pi_{\text{infer}}(y_t \mid x, y_{<t})$. High variance in $w_t$ leads to unstable policy optimization, frequently clipped ratios, and RL collapse, in which policy updates drift toward trivial or degenerate behaviors (Ma et al., 13 Oct 2025).
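The two mismatch statistics above (per router call and per token) can be computed from logged expert selections as in the sketch below. The input layout (one list of per-layer `frozenset`s of expert ids per token) and the function name `routing_mismatch_stats` are assumptions for illustration, not the paper's tooling.

```python
def routing_mismatch_stats(infer_masks, train_masks):
    """
    infer_masks, train_masks: one entry per token; each entry is a list (one per
    MoE layer) of frozensets holding the selected expert ids.
    Returns (fraction of router calls that differ,
             fraction of tokens that differ in at least one layer).
    """
    calls = differing_calls = differing_tokens = 0
    for tok_infer, tok_train in zip(infer_masks, train_masks):
        token_differs = False
        for layer_infer, layer_train in zip(tok_infer, tok_train):
            calls += 1
            if layer_infer != layer_train:
                differing_calls += 1
                token_differs = True
        differing_tokens += int(token_differs)
    return differing_calls / calls, differing_tokens / len(infer_masks)

# Toy log: 2 tokens x 3 MoE layers, top-2 routing; exactly one router call differs.
infer = [[frozenset({0, 1}), frozenset({2, 5}), frozenset({3, 4})],
         [frozenset({1, 2}), frozenset({0, 7}), frozenset({4, 6})]]
train = [[frozenset({0, 1}), frozenset({2, 6}), frozenset({3, 4})],
         [frozenset({1, 2}), frozenset({0, 7}), frozenset({4, 6})]]

call_frac, token_frac = routing_mismatch_stats(infer, train)
print(call_frac, token_frac)  # 0.16666666666666666 0.5
```

Because a single differing layer is enough to flag a token, the token-level rate can never be lower than the per-call rate, which is consistent with the gap between the 10% and 94% figures above.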
2. Formal Specification of Rollout Routing Replay
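At each token step $t$ in MoE layer $\ell$ (indices suppressed below), denote:
- $x_{\text{infer}}$: the router input activation in the inference engine
- $s_{\text{infer}} = x_{\text{infer}} W_r$: the router logits over the $M$ experts, where $W_r$ is the router weight matrix
- $I_{\text{infer}} = \mathrm{TopKMask}(s_{\text{infer}}, K)$: the binary top-$K$ mask of selected experts
The inference-time routing distribution is:
$$g_{\text{infer},i} = \frac{I_{\text{infer},i}\,\exp(s_{\text{infer},i})}{\sum_{j=1}^{M} I_{\text{infer},j}\,\exp(s_{\text{infer},j})}$$
During training, the system traditionally recomputes router logits $s_{\text{train}} = x_{\text{train}} W_r$ and applies its own mask $I_{\text{train}} = \mathrm{TopKMask}(s_{\text{train}}, K)$. Rollout Routing Replay instead overrides the training mask with $I_{\text{infer}}$ and computes a replayed gate:
$$g_{\text{replay},i} = \frac{I_{\text{infer},i}\,\exp(s_{\text{train},i})}{\sum_{j=1}^{M} I_{\text{infer},j}\,\exp(s_{\text{train},j})}$$
resulting in the training-time MoE layer output, where $E_i$ denotes the $i$-th expert:
$$y_{\text{replay}} = \sum_{i=1}^{M} g_{\text{replay},i}\, E_i(x_{\text{train}})$$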
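The gates described in this section can be sketched in pure Python as follows, reusing hypothetical logits with a near-tie at the top-2 boundary. `top_k_mask` and `masked_softmax` are illustrative helper names, not functions from any MoE framework.

```python
import math

def top_k_mask(logits, k):
    """Binary TopKMask: 1 for the k largest logits (ties -> lower index), else 0."""
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    chosen = set(order[:k])
    return [1 if i in chosen else 0 for i in range(len(logits))]

def masked_softmax(logits, mask):
    """g_i = I_i exp(s_i) / sum_j I_j exp(s_j), shifted by the max for stability."""
    m = max(s for s, keep in zip(logits, mask) if keep)
    exps = [math.exp(s - m) if keep else 0.0 for s, keep in zip(logits, mask)]
    z = sum(exps)
    return [e / z for e in exps]

K = 2
s_infer = [0.10, 1.2000, 2.50, 1.1999, -0.30, 0.05]  # hypothetical logits
s_train = [0.10, 1.1998, 2.50, 1.2001, -0.30, 0.05]  # same, with tiny drift

I_infer = top_k_mask(s_infer, K)
I_train = top_k_mask(s_train, K)

g_infer = masked_softmax(s_infer, I_infer)       # inference gate
g_recomputed = masked_softmax(s_train, I_train)  # conventional training gate
g_replay = masked_softmax(s_train, I_infer)      # R3: cached mask, training logits

print(I_infer, I_train)  # [0, 1, 1, 0, 0, 0] [0, 0, 1, 1, 0, 0]
print(max(abs(a - b) for a, b in zip(g_replay, g_infer)) < 1e-3)     # True
print(max(abs(a - b) for a, b in zip(g_recomputed, g_infer)) > 0.1)  # True
```

In a real training engine the same computation would run on tensors inside an autodiff framework, so `g_replay` stays differentiable with respect to `s_train`; only the discrete mask is taken from the rollout.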
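This framework views R3 as minimizing the train–infer KL divergence between the inference and training policies:
$$\mathbb{D}_{\text{KL}}\big(\pi_{\text{infer}} \,\|\, \pi_{\text{train}}\big) = \mathbb{E}_{y_t \sim \pi_{\text{infer}}}\left[\log \frac{\pi_{\text{infer}}(y_t \mid x, y_{<t})}{\pi_{\text{train}}(y_t \mid x, y_{<t})}\right]$$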
Empirical measurements show that R3 roughly halves the MoE train–infer KL divergence, bringing it close to the level measured for dense models.
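A train–infer KL divergence of this kind is usually estimated from tokens sampled by the inference engine, using per-token log-probabilities computed by both engines. The sketch below shows two standard Monte Carlo estimators: the plain log-ratio mean and the per-token non-negative form $r - 1 - \log r$ with $r = \pi_{\text{train}}/\pi_{\text{infer}}$. Which estimator Ma et al. use is not stated here, so treat the choice as an assumption; the log-probabilities are made up.

```python
import math

def kl_log_ratio(logp_infer, logp_train):
    """Unbiased estimate of KL(pi_infer || pi_train): mean of log pi_infer - log pi_train."""
    return sum(a - b for a, b in zip(logp_infer, logp_train)) / len(logp_infer)

def kl_nonneg(logp_infer, logp_train):
    """Unbiased, per-token non-negative estimate: mean of r - 1 - log r, r = pi_train / pi_infer."""
    total = 0.0
    for a, b in zip(logp_infer, logp_train):
        log_r = b - a
        total += math.exp(log_r) - 1.0 - log_r
    return total / len(logp_infer)

# Hypothetical log-probabilities of three sampled tokens under each engine.
logp_infer = [-0.5, -1.0, -2.0]
logp_train = [-0.6, -1.1, -1.9]

print(round(kl_log_ratio(logp_infer, logp_train), 6))  # 0.033333
print(round(kl_nonneg(logp_infer, logp_train), 6))     # 0.004949
```

Both estimators are unbiased for samples drawn from $\pi_{\text{infer}}$; the second never produces negative per-token terms, which typically makes it less noisy on small batches.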
3. Algorithmic Workflow and Implementation
R3 executes two phases per global RL step:
1. Rollout (Inference Engine):
- For each prompt in a batch, response tokens are generated autoregressively.
- At each token and MoE layer, compute $s_{\text{infer}}$ and $I_{\text{infer}}$ (top-$K$ experts), and cache the mask per token per layer.
- Store the cached masks $I_{\text{infer}}$ alongside the sampled trajectories for training.
2. Training (Training Engine with R3):
- For each minibatch, at each token and layer:
  - Compute $s_{\text{train}}$ and retrieve the cached $I_{\text{infer}}$
  - Form $g_{\text{replay}}$ as above
  - Evaluate the policy and PPO loss using $g_{\text{replay}}$
- Backpropagate gradients through the minibatch loss
Crucially, only the mask is overridden; gradients still flow into the training logits $s_{\text{train}}$ and hence into the router parameters. R3 is applied in both the "old" and "new" policy branches during on-policy PPO. Pseudocode and further details are given in the original paper (Ma et al., 13 Oct 2025).
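The two phases can be illustrated with a toy single-layer MoE on scalar inputs. Everything here is hypothetical: `moe_layer`, the router weights, the six linear experts, and the `JITTER` vector that stands in for numerical drift between engines. Real systems operate on tensors across many layers and would extract and inject masks through framework-specific hooks.

```python
import math

def top_k_mask(logits, k):
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    chosen = set(order[:k])
    return [1 if i in chosen else 0 for i in range(len(logits))]

def masked_softmax(logits, mask):
    m = max(s for s, keep in zip(logits, mask) if keep)
    exps = [math.exp(s - m) if keep else 0.0 for s, keep in zip(logits, mask)]
    z = sum(exps)
    return [e / z for e in exps]

def moe_layer(x, router_w, experts, k, replay_mask=None, jitter=None):
    """Toy MoE layer. If replay_mask is given (R3), it overrides the top-k selection."""
    s = [x * w for w in router_w]                 # router logits
    if jitter is not None:                        # simulated engine drift
        s = [si + d for si, d in zip(s, jitter)]
    mask = replay_mask if replay_mask is not None else top_k_mask(s, k)
    g = masked_softmax(s, mask)                   # gate over selected experts
    y = sum(g[i] * experts[i](x) for i in range(len(experts)) if mask[i])
    return y, mask

ROUTER_W = [0.10, 1.2000, 2.50, 1.1999, -0.30, 0.05]
EXPERTS = [lambda x, c=i + 1: c * x for i in range(6)]  # expert i scales x by i + 1
K = 2
JITTER = [0.0, -2e-4, 0.0, 2e-4, 0.0, 0.0]
tokens = [1.0, 0.5]

# Phase 1 (rollout): run the inference engine and cache one mask per token.
mask_cache, y_infer = [], []
for x in tokens:
    y, mask = moe_layer(x, ROUTER_W, EXPERTS, K)
    y_infer.append(y)
    mask_cache.append(mask)

# Phase 2 (training): drifted logits, recomputed masks vs. replayed masks.
y_recomputed = [moe_layer(x, ROUTER_W, EXPERTS, K, jitter=JITTER)[0] for x in tokens]
y_replay = [moe_layer(x, ROUTER_W, EXPERTS, K, replay_mask=m, jitter=JITTER)[0]
            for x, m in zip(tokens, mask_cache)]

for yi, yr, yp in zip(y_infer, y_recomputed, y_replay):
    print(f"infer={yi:.4f}  recomputed={yr:.4f}  replay={yp:.4f}")
```

With recomputed masks the drift flips expert 1 to expert 3 and the outputs diverge; with replayed masks the outputs differ only by the tiny logit drift. In a full PPO step the cached masks would be reused for both the old-policy and new-policy forward passes.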
4. Theoretical Properties and Stability Effects
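R3 enforces near-identity between the training and inference routing distributions under small perturbations: because both gates share the mask $I_{\text{infer}}$, $g_{\text{replay}} \to g_{\text{infer}}$ as $s_{\text{train}} \to s_{\text{infer}}$, so small numerical discrepancies no longer change which experts are active, and $\pi_{\text{train}} \approx \pi_{\text{infer}}$. Empirically, the fraction of "extreme tokens" $F(\tau)$, defined as the share of tokens satisfying $\max\left(\pi_{\text{train}}/\pi_{\text{infer}},\ \pi_{\text{infer}}/\pi_{\text{train}}\right) > \tau$, drops by roughly an order of magnitude with R3.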
As a result, the importance weights $\pi_{\text{train}}/\pi_{\text{infer}}$ become tightly centered around 1, which correlates with stable PPO policy updates and eliminates the policy collapse observed in non-replayed MoE RL (Ma et al., 13 Oct 2025).
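The extreme-token fraction $F(\tau)$ counts tokens whose train/infer probability ratio $r$ satisfies $\max(r, 1/r) > \tau$, which for $\tau > 1$ is equivalent to $|\log r| > \log \tau$. A minimal sketch over hypothetical per-token log-probabilities (function names are illustrative):

```python
import math

def importance_weights(logp_train, logp_infer):
    """Per-token ratios pi_train / pi_infer."""
    return [math.exp(t - i) for t, i in zip(logp_train, logp_infer)]

def extreme_token_fraction(logp_train, logp_infer, tau):
    """F(tau): share of tokens with max(r, 1/r) > tau, i.e. |log r| > log tau (tau > 1)."""
    log_tau = math.log(tau)
    extreme = sum(1 for t, i in zip(logp_train, logp_infer) if abs(t - i) > log_tau)
    return extreme / len(logp_train)

# Hypothetical per-token log-probabilities from the two engines.
logp_infer = [-1.0, -1.0, -1.0, -1.0]
logp_train = [-1.0, -0.95, -1.8, 0.0]

print([round(w, 3) for w in importance_weights(logp_train, logp_infer)])  # [1.0, 1.051, 0.449, 2.718]
print(extreme_token_fraction(logp_train, logp_infer, tau=2.0))            # 0.5
```

Working in log space avoids overflow for very confident tokens and makes the symmetric condition a single absolute-value comparison.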
5. Empirical Evaluation and Benchmarks
R3 has been evaluated on mathematical reasoning tasks over roughly 100,000 problems, with metrics reported on AIME24/25 (Avg@32), AMC23 (Avg@16), and MATH500 Lv5 (Avg@4). Using Qwen3-30B-A3B (Base and SFT), the following baselines were compared:
- GRPO: token-level clipped PPO
- GSPO: sequence-level importance sampling
- TIS: truncated importance sampling
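To make the distinction between these baselines concrete, the sketch below gives simplified per-sequence forms: a token-level clipped surrogate (GRPO-style), a sequence-level surrogate built from the length-normalized ratio (GSPO-style), and truncated importance weights (TIS). Advantage estimation, group normalization, and the exact clipping ranges and truncation cap used in the experiments are omitted; `eps=0.2` and `cap=2.0` are illustrative assumptions.

```python
import math

def clip(x, lo, hi):
    return min(max(x, lo), hi)

def grpo_token_objective(logp_new, logp_old, advantage, eps=0.2):
    """Token-level clipped PPO surrogate (GRPO-style), averaged over the sequence."""
    terms = []
    for n, o in zip(logp_new, logp_old):
        r = math.exp(n - o)  # per-token ratio pi_new / pi_old
        terms.append(min(r * advantage, clip(r, 1 - eps, 1 + eps) * advantage))
    return sum(terms) / len(terms)

def gspo_sequence_objective(logp_new, logp_old, advantage, eps=0.2):
    """Sequence-level surrogate (GSPO-style): one length-normalized ratio per sequence."""
    s = math.exp(sum(n - o for n, o in zip(logp_new, logp_old)) / len(logp_new))
    return min(s * advantage, clip(s, 1 - eps, 1 + eps) * advantage)

def tis_weights(logp_train, logp_infer, cap=2.0):
    """Truncated importance weights min(pi_train / pi_infer, C) for the train-infer gap."""
    return [min(math.exp(t - i), cap) for t, i in zip(logp_train, logp_infer)]

logp_new, logp_old = [-1.0, -0.5], [-1.2, -0.3]
print(grpo_token_objective(logp_new, logp_old, advantage=1.0))
print(gspo_sequence_objective(logp_new, logp_old, advantage=1.0))
print(tis_weights([0.0, -1.0], [-1.0, -1.0]))  # [2.0, 1.0]
```

In this example the per-token log-ratios (+0.2 and -0.2) cancel at the sequence level, so the GSPO-style ratio is about 1, while the token-level surrogate clips the first token and averages slightly above 1.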
Results indicate:
- Multi-step SFT: GSPO+R3 gains +2.2 average points over the GSPO baseline; GRPO+R3 gains +1.3
- Single-step SFT: GRPO collapses at step 60 (Avg 62.2); GRPO+TIS reaches Avg 66.2; GRPO+R3 reaches Avg 71.8 (+5.6 over TIS) with no collapse
- Similar gains observed with Base model
Stability curves show that the train–infer KL divergence and the extreme-token fraction F(2) escalate during collapsing runs but remain low with R3. Generation dynamics show smoother gradient norms, steadier entropy, and faster reward gains with R3 (Ma et al., 13 Oct 2025).
6. Ablations and Sensitivity Analysis
Key ablation findings include:
- Combining R3 with TIS provides no additive benefit and may slightly degrade performance, suggesting that R3 alone nearly closes the off-policy gap.
- R3 stabilizes both single- and multi-mini-step PPO, with especially pronounced advantages in the fragile single-step regime.
- Caching the routing masks $I_{\text{infer}}$ (rather than recomputing them) causes no fidelity loss and improves computational efficiency in long-context or agent scenarios.
- The default top-$K$ setting is robust, with preliminary tests showing similar gains for other values of $K$.
7. Practical Implications and Limitations
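Implementing R3 entails certain costs and constraints:
- Memory/storage overhead grows linearly with the number of tokens and MoE layers, because router masks are cached per token and per layer (on the order of $T \cdot L \cdot K$ expert indices for $T$ tokens, $L$ MoE layers, and top-$K$ routing); prefix caching is recommended for long sequences or dialogs.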
- Framework support must enable extraction and injection of routing masks between rollout and training engines; this is straightforward if router code is shared but may require engineering otherwise.
- R3 addresses routing-induced discrepancies alone; other nondeterministic elements (e.g., kernel-level or architectural mismatches) may still produce residual train–infer gaps.
- While demonstrated primarily for MoE Transformers on math reasoning tasks, the methodology generalizes to any sparse routing module (including conditional computation) and other RL applications (e.g., code generation, logical reasoning).
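A back-of-the-envelope sketch of the caching cost in the first bullet above. The model dimensions are assumptions for illustration (48 MoE layers with top-8 routing over 128 experts, believed to match the publicly released Qwen3-30B-A3B configuration), and each expert id is assumed to be stored in one byte, which suffices for up to 256 experts; `routing_cache_bytes` is a hypothetical helper.

```python
def routing_cache_bytes(num_tokens, num_moe_layers, top_k, bytes_per_index=1):
    """Bytes to cache top-k expert ids for every token and MoE layer."""
    return num_tokens * num_moe_layers * top_k * bytes_per_index

# Assumed dimensions (illustrative): 48 MoE layers, top-8 of 128 experts.
per_token = routing_cache_bytes(1, 48, 8)
per_sequence = routing_cache_bytes(32_768, 48, 8)

print(per_token)             # 384
print(per_sequence / 2**20)  # 12.0  (MiB for a 32k-token trajectory)
```

The cost grows linearly with trajectory length and batch size, so reusing masks for shared prefixes, as recommended for long dialogs, avoids storing the same routing decisions repeatedly.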
Future work may adapt R3 to diverse router architectures (e.g., noisy top-$K$, Gumbel-Softmax), low-precision kernels, or distributed multi-node inference, but the protocol (record routing decisions at inference, replay them in training) remains conceptually the same.
In summary, Rollout Routing Replay (R3) directly targets MoE router-induced train–infer mismatch, delivering a two-fold reduction in train–infer KL divergence, a substantial decrease in outlier tokens, the elimination of PPO collapse, and consistent 1–6 point improvements in downstream RL task performance (Ma et al., 13 Oct 2025).