
Rollout Routing Replay (R3)

Updated 9 May 2026
  • Rollout Routing Replay (R3) is a method that stabilizes reinforcement learning in Mixture-of-Experts models by replaying the inference routing mask during training.
  • The technique halves the train–infer KL divergence by aligning the training and inference routing distributions, thereby reducing variance in importance weights.
  • Empirical results on math reasoning tasks show that R3 prevents PPO policy collapse and improves performance by 1–6 points compared to standard methods.

Rollout Routing Replay (R3) is a method for stabilizing reinforcement learning (RL) in Mixture-of-Experts (MoE) transformer models by eliminating discrepancies between the routing decisions made during inference (rollout generation) and those made during training. It addresses the instability caused by expert selection that diverges between these two phases, which is acute in RL settings and leads to inflated policy KL divergence, high-variance importance weights, and frequent policy collapse. R3 logs the expert selection mask during inference and deterministically replays it in training, which keeps the training policy closely aligned with inference-time decisions and improves RL stability and final task performance (Ma et al., 13 Oct 2025).

1. Training–Inference Discrepancy in MoE Routing

Reinforcement learning for LLMs commonly separates the rollout (inference) and gradient (training) computation engines. In dense LMs, these pipelines produce nearly identical token-level probability distributions. However, MoE models use a sparse expert gating mechanism: at each token position, the output $y = \sum_{i=1}^M g_i \cdot E_i(x)$ depends on a sparse gate vector $g \in \mathbb{R}^M$ determined by a routing mask applied to the router logits. Minor numerical or implementation differences across inference and training can flip which experts are selected, causing dramatically different token outputs.
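To make the flipping effect concrete, the following minimal sketch applies top-2 gating to two sets of router logits that differ only in the fourth decimal place. The logits, the scalar "expert outputs", and the helper names (`top_k_indices`, `sparse_gate`, `moe_output`) are hypothetical illustrations, not values or code from the paper.

```python
import math

def top_k_indices(logits, k):
    """Return the indices of the k largest logits, in ascending index order."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return sorted(order[:k])

def sparse_gate(logits, k):
    """Softmax restricted to the top-k experts; every other gate is zero."""
    selected = top_k_indices(logits, k)
    peak = max(logits[i] for i in selected)
    weights = {i: math.exp(logits[i] - peak) for i in selected}
    total = sum(weights.values())
    return [weights[i] / total if i in weights else 0.0 for i in range(len(logits))]

def moe_output(gate, expert_outputs):
    """y = sum_i g_i * E_i(x), with each expert output reduced to a scalar."""
    return sum(g * e for g, e in zip(gate, expert_outputs))

# Hypothetical logits for M = 4 experts; experts 1 and 2 are nearly tied.
infer_logits = [2.0, 1.0001, 1.0000, -1.0]
# The same logits after a tiny numerical difference between engines.
train_logits = [2.0, 1.0000, 1.0001, -1.0]
# Hypothetical scalar expert outputs E_i(x); experts 1 and 2 disagree strongly.
expert_outputs = [1.0, 5.0, -5.0, 0.0]

infer_experts = top_k_indices(infer_logits, k=2)
train_experts = top_k_indices(train_logits, k=2)
y_infer = moe_output(sparse_gate(infer_logits, k=2), expert_outputs)
y_train = moe_output(sparse_gate(train_logits, k=2), expert_outputs)

print(infer_experts, train_experts)
print(y_infer, y_train)
```

A logit change of 0.0001 swaps expert 1 for expert 2, and the layer output changes sign in this toy setup, even though a dense layer would barely notice a perturbation of that size.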

Empirical observations quantify this effect:

  • On Qwen3-30B-A3B, $\mathrm{D_{KL}}(\pi_\mathrm{train} \| \pi_\mathrm{infer}) \approx 1.5 \times 10^{-3}$ for MoE (nearly $2\times$ that of dense baselines).
  • 10% of router calls select different experts between training and inference; at the token level, 94% of tokens differ in at least one layer.

This expert selection mismatch inflates the variance of the PPO importance weights $w_t = \pi_\mathrm{train}(y_t|x, y_{<t}) / \pi_\mathrm{train,old}(y_t|x, y_{<t})$; high variance in $w_t$ leads to unstable policy optimization, clipped ratios, and frequent RL collapse, where policy updates drift to trivial or degenerate behaviors (Ma et al., 13 Oct 2025).

2. Formal Specification of Rollout Routing Replay

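At each token step $t$ in MoE layer $\ell$, denote (suppressing the indices $t$ and $\ell$):

  • $x_\mathrm{infer}$: activation in the inference engine
  • $s_\mathrm{infer}$: router logits computed from $x_\mathrm{infer}$
  • $I_\mathrm{infer}$: top-$K$ mask of selected experts

The inference-time routing distribution:

$$g_{\mathrm{infer},i} = \frac{I_{\mathrm{infer},i}\,\exp(s_{\mathrm{infer},i})}{\sum_{j=1}^{M} I_{\mathrm{infer},j}\,\exp(s_{\mathrm{infer},j})}$$

During training, the system traditionally recomputes router logits $s_\mathrm{train}$ from the training-engine activation $x_\mathrm{train}$ and applies its own mask $I_\mathrm{train}$. Rollout Routing Replay instead overrides the training mask with $I_\mathrm{infer}$ and computes a replayed gate:

$$g_{\mathrm{replay},i} = \frac{I_{\mathrm{infer},i}\,\exp(s_{\mathrm{train},i})}{\sum_{j=1}^{M} I_{\mathrm{infer},j}\,\exp(s_{\mathrm{train},j})}$$

so that the training-side layer output is routed over the same experts as at inference:

$$y_\mathrm{replay} = \sum_{i=1}^{M} g_{\mathrm{replay},i} \cdot E_i(x_\mathrm{train})$$

This framework views R3 as minimizing the train–infer divergence:

$$\mathrm{D_{KL}}(\pi_\mathrm{train} \| \pi_\mathrm{infer})$$

Empirical measurements show that R3 roughly halves the MoE train–infer KL divergence (from $\approx 1.5 \times 10^{-3}$), bringing it close to the level observed for dense models.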
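The replayed gate can be sketched in a few lines: a softmax over the training logits, restricted to the experts that the inference engine selected. The function names (`top_k_mask`, `masked_softmax`) and the toy logits are assumptions made for illustration and do not come from the paper's implementation.

```python
import math

def top_k_mask(logits, k):
    """0/1 mask marking the k largest logits."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    chosen = set(order[:k])
    return [1 if i in chosen else 0 for i in range(len(logits))]

def masked_softmax(logits, mask):
    """Softmax over the entries where mask == 1; masked-out entries get 0."""
    peak = max(s for s, m in zip(logits, mask) if m)
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical router logits from the two engines for the same token and layer.
s_infer = [2.0, 1.0001, 1.0000, -1.0]
s_train = [2.0, 1.0000, 1.0001, -1.0]

mask_infer = top_k_mask(s_infer, k=2)   # what the rollout engine selected
mask_train = top_k_mask(s_train, k=2)   # what training would select on its own

g_infer = masked_softmax(s_infer, mask_infer)    # inference gate
g_train = masked_softmax(s_train, mask_train)    # training gate without replay
g_replay = masked_softmax(s_train, mask_infer)   # training logits, inference mask

gap_without_replay = max(abs(a - b) for a, b in zip(g_train, g_infer))
gap_with_replay = max(abs(a - b) for a, b in zip(g_replay, g_infer))
print(gap_without_replay, gap_with_replay)
```

Because the logits still come from the training engine, the replayed gate stays a function of the training-side router parameters; only the choice of which entries participate in the softmax is fixed.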

3. Algorithmic Workflow and Implementation

R3 executes two phases per global RL step:

1. Rollout (Inference Engine):

  • For each sample in a batch, sequence tokens are generated autoregressively.
  • At each token and MoE layer, compute the router logits $s_\mathrm{infer}$ and the mask $I_\mathrm{infer}$ (top-$K$ experts), and cache the mask per token per layer.
  • Store the cached masks $I_\mathrm{infer}$ for training.

2. Training (Training Engine with R3):

  • For each minibatch, at each token and layer:
    • Compute the training router logits $s_\mathrm{train}$ and retrieve the cached mask $I_\mathrm{infer}$
    • Form the replayed gate $g_\mathrm{replay}$ by applying the softmax to $s_\mathrm{train}$ over the experts selected by $I_\mathrm{infer}$
    • Evaluate the policy and PPO loss using $g_\mathrm{replay}$
    • Backpropagate gradients

Crucially, only the mask is overridden; gradients still flow into the training logits $s_\mathrm{train}$ and therefore into the router. R3 is applied in both the "old" and "new" policy branches during on-policy PPO. Pseudocode and further details can be found in (Ma et al., 13 Oct 2025).
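The two phases above can be sketched with a small cache keyed by (sequence, token, layer). Everything here is a simplified assumption for illustration: `RoutingCache`, `rollout_route`, and `training_route` are hypothetical names, the router is a plain matrix-vector product, and no real inference or training framework is involved.

```python
import math

class RoutingCache:
    """Hypothetical store: one tuple of expert ids per (sequence, token, layer)."""

    def __init__(self):
        self._experts = {}

    def record(self, seq_id, token_pos, layer, expert_ids):
        self._experts[(seq_id, token_pos, layer)] = tuple(sorted(expert_ids))

    def lookup(self, seq_id, token_pos, layer):
        return self._experts[(seq_id, token_pos, layer)]

def router_logits(weights, activation):
    """One logit per expert: dot product of its weight row with the activation."""
    return [sum(w * x for w, x in zip(row, activation)) for row in weights]

def top_k(logits, k):
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return sorted(order[:k])

def gate_over(logits, expert_ids):
    """Softmax of the logits restricted to expert_ids, as {expert_id: weight}."""
    peak = max(logits[i] for i in expert_ids)
    exps = {i: math.exp(logits[i] - peak) for i in expert_ids}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}

def rollout_route(cache, weights, activation, key, k):
    """Rollout phase: pick top-k experts and record the selection."""
    logits = router_logits(weights, activation)
    experts = top_k(logits, k)
    cache.record(*key, experts)
    return gate_over(logits, experts)

def training_route(cache, weights, activation, key):
    """Training phase: recompute logits, but reuse the recorded selection."""
    logits = router_logits(weights, activation)
    experts = cache.lookup(*key)
    return gate_over(logits, experts)

# Hypothetical router weights (3 experts, 2 features) and slightly
# different activations in the two engines for the same token.
weights = [[2.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
x_infer = [1.0, 1.0001]
x_train = [1.0, 0.9999]
key = (0, 0, 0)  # (sequence id, token position, layer index)

cache = RoutingCache()
infer_gate = rollout_route(cache, weights, x_infer, key, k=2)
replay_gate = training_route(cache, weights, x_train, key)
unreplayed_experts = top_k(router_logits(weights, x_train), k=2)
print(sorted(infer_gate), sorted(replay_gate), unreplayed_experts)
```

In an actual system the lookup would feed a differentiable masked softmax inside the training framework; this sketch only shows the bookkeeping that moves routing decisions from one engine to the other.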

4. Theoretical Properties and Stability Effects

R3 enforces near-identity between the training and inference routing distributions under small parameter updates: because the replayed gate uses the inference mask, it approaches the inference gate as the training parameters approach those used for rollout, which drives $\mathrm{D_{KL}}(\pi_\mathrm{train} \| \pi_\mathrm{infer})$ toward zero. Empirically, the fraction of "extreme tokens," those whose training and inference probabilities differ by more than a threshold factor (reported as $F(2)$ for a factor of 2), drops by an order of magnitude with R3.

As a result, importance weights $w_t$ become tightly centered around 1, correlating with stable PPO policy updates and eliminating the policy collapse observed in non-replayed MoE RL (Ma et al., 13 Oct 2025).
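A diagnostic of this kind can be computed directly from per-token log-probabilities. The sketch below assumes that an "extreme token" is one whose training and inference probabilities differ by more than a factor `tau` in either direction; that definition, the function name `extreme_token_fraction`, and the probabilities are illustrative assumptions and may not match the paper's exact formulation.

```python
import math

def extreme_token_fraction(logp_train, logp_infer, tau):
    """Fraction of tokens with max(p_train/p_infer, p_infer/p_train) > tau."""
    threshold = math.log(tau)
    extreme = sum(
        1 for a, b in zip(logp_train, logp_infer) if abs(a - b) > threshold
    )
    return extreme / len(logp_train)

# Hypothetical probabilities of the sampled tokens under each engine.
p_infer = [0.50, 0.20, 0.90, 0.40]
p_train = [0.50, 0.05, 0.88, 0.41]  # the second token disagrees by a factor of 4

logp_infer = [math.log(p) for p in p_infer]
logp_train = [math.log(p) for p in p_train]

ratios = [pt / pi for pt, pi in zip(p_train, p_infer)]
f2 = extreme_token_fraction(logp_train, logp_infer, tau=2.0)
print(ratios, f2)
```

Working in log space avoids dividing by very small probabilities and makes the symmetric threshold a single absolute-value comparison.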

5. Empirical Evaluation and Benchmarks

R3 has been evaluated on mathematical reasoning tasks over roughly 100,000 problems, with metrics reported on AIME24/25 (Avg@32), AMC23 (Avg@16), and MATH500 Lv5 (Avg@4). Using Qwen3-30B-A3B (Base and SFT), R3 was compared against GRPO, GSPO, and GRPO+TIS baselines.

Results indicate:

  • Multi-step SFT: GSPO+R3 improves the average score over the GSPO baseline by 2.2 points; GRPO+R3 shows a gain of 1.3 points.
  • Single-step SFT: GRPO collapses at step 60 (Avg 62.2); GRPO+TIS reaches Avg 66.2; GRPO+R3 reaches Avg 71.8 (+5.6 over TIS), with no collapse.
  • Similar gains are observed with the Base model.

Stability curves show that $\mathrm{D_{KL}}$ and $F(2)$ escalate during collapsing runs but remain low with R3. Training dynamics show smoother gradient norms, steadier entropy, and faster reward gains with R3 (Ma et al., 13 Oct 2025).

6. Ablations and Sensitivity Analysis

Key ablation findings include:

  • R3 in combination with TIS provides no additive benefit and may slightly degrade performance, suggesting that R3 alone nearly closes the off-policy gap.
  • R3 stabilizes both single- and multi-mini-step PPO, with especially pronounced advantages in the fragile single-step regime.
  • Caching routing masks (vs. recomputing) yields no fidelity loss but improves computational efficiency for long-context or agent scenarios.
  • The default configuration is robust, with preliminary tests showing similar gains under alternative settings.

7. Practical Implications and Limitations

Implementing R3 entails certain costs and constraints:

  • Memory/storage overhead grows with the number of tokens times the number of MoE layers, due to per-token, per-layer caching of router masks; prefix caching is recommended for long sequences or dialogs.
  • Framework support must enable extraction and injection of routing masks between rollout and training engines; this is straightforward if router code is shared but may require engineering otherwise.
  • R3 addresses routing-induced discrepancies alone; other nondeterministic elements (e.g., kernel-level or architectural mismatches) may still produce residual train–infer gaps.
  • While demonstrated primarily for MoE Transformers on math reasoning tasks, the methodology generalizes to any sparse routing module (including conditional computation) and other RL applications (e.g., code generation, logical reasoning).
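As a rough back-of-the-envelope check on the storage cost mentioned in the list above, the sketch below estimates the size of a mask cache that stores the selected expert indices for every token and MoE layer. The function name `mask_cache_bytes` and all numeric settings (sequence length, layer count, number of selected experts, one byte per index) are illustrative assumptions, not figures from the paper.

```python
def mask_cache_bytes(num_tokens, num_moe_layers, top_k, bytes_per_index=1):
    """Bytes needed to store top_k expert indices per token per MoE layer.

    One byte per index is enough when the layer has at most 256 experts.
    """
    return num_tokens * num_moe_layers * top_k * bytes_per_index

# Hypothetical configuration: an 8192-token rollout, 48 MoE layers,
# 8 selected experts per token and layer.
per_sequence = mask_cache_bytes(num_tokens=8192, num_moe_layers=48, top_k=8)
per_batch = 256 * per_sequence  # hypothetical batch of 256 rollouts

print(per_sequence / 2**20, "MiB per sequence")
print(per_batch / 2**20, "MiB per batch")
```

Under these assumptions the cache is small next to activations or KV caches, but it grows linearly with rollout length, which is why sharing masks across a common prefix becomes attractive for long dialogs.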

Future work may adapt R3 to diverse router architectures (e.g., noisy top-$K$, Gumbel-Softmax), low-precision kernels, or distributed multi-node inference, but the protocol (record routing decisions at inference, replay them in training) remains conceptually constant.

In summary, Rollout Routing Replay (R3) directly targets MoE router-induced train–infer mismatch, delivering a two-fold reduction in train–infer KL divergence, a substantial decrease in outlier tokens, the elimination of PPO collapse, and consistent 1–6 point improvements in downstream RL task performance (Ma et al., 13 Oct 2025).
