
RAST: Reasoning Activation in LLMs via Small-model Transfer (2506.15710v1)

Published 30 May 2025 in cs.LG and cs.AI

Abstract: Reinforcement learning (RL) has become a powerful approach for improving the reasoning capabilities of LLMs, as evidenced by recent successes such as OpenAI's o1 and Deepseek-R1. However, applying RL at scale remains intimidatingly resource-intensive, requiring multiple model copies and extensive GPU workloads. On the other hand, while being powerful, recent studies suggest that RL does not fundamentally endow models with new knowledge; rather, it primarily reshapes the model's output distribution to activate reasoning capabilities latent in the base model. Building on this insight, we hypothesize that the changes in output probabilities induced by RL are largely model-size invariant, opening the door to a more efficient paradigm: training a small model with RL and transferring its induced probability shifts to larger base models. To verify our hypothesis, we conduct a token-level analysis of decoding trajectories and find high alignment in RL-induced output distributions across model scales, validating our hypothesis. Motivated by this, we propose RAST, a simple yet effective method that transfers reasoning behaviors by injecting RL-induced probability adjustments from a small RL-trained model into larger models. Experiments across multiple mathematical reasoning benchmarks show that RAST substantially and consistently enhances the reasoning capabilities of base models while requiring significantly lower GPU memory than direct RL training, sometimes even yielding better performance than the RL-trained counterparts. Our findings offer new insights into the nature of RL-driven reasoning and practical strategies for scaling its benefits without incurring its full computational cost. The project page of RAST is available at https://ozyyshr.github.io/RAST/.

Summary

  • The paper presents RAST, a novel method that transfers reasoning corrections from a small RL-trained model to a large LLM without full RL fine-tuning.
  • It computes delta logits from the small model’s RL and base versions to adjust the large model's token probabilities and activate latent reasoning behaviors.
  • Experiments demonstrate significant gains on math and code tasks with improved recovery rates and reduced GPU memory needs compared to direct RL training.

The paper "RAST: Reasoning Activation in LLMs via Small-model Transfer" (2506.15710) introduces a method to enhance the reasoning capabilities of LLMs without the high computational cost typically associated with Reinforcement Learning (RL) fine-tuning. The core idea is that RL primarily activates latent reasoning abilities already present in a base model by selectively adjusting the probabilities of a few key tokens, rather than teaching new knowledge. The authors hypothesize that these RL-induced probability shifts are largely invariant to model size.

Building on this, they propose RAST (Reasoning Activation in LLMs via Small-model Transfer), a decoding-time method. RAST transfers the "reasoning patterns" learned by a small RL-trained model ($\mathcal{S}_{\text{RL}}$) to a larger base model ($\mathcal{M}_{\text{base}}$). This is achieved by calculating the difference in logit outputs between the small RL-trained model and its own base version ($\mathcal{S}_{\text{base}}$). This difference, denoted $\Delta R$, represents the reasoning-oriented adjustments. At each decoding step, $\Delta R$ is added to the logits of the large base model $\mathcal{M}_{\text{base}}$ before applying softmax to obtain the final token probabilities. The probability distribution of the enhanced model $\tilde{\mathcal{M}}$ is:

$$P_{\tilde{\mathcal{M}}}(X_t \mid x_{<t}) = \mathrm{softmax}\left[\text{logits}_{\mathcal{M}_{\text{base}}}(X_t \mid x_{<t}) + \lambda\left(\text{logits}_{\mathcal{S}_{\text{RL}}}(X_t \mid x_{<t}) - \text{logits}_{\mathcal{S}_{\text{base}}}(X_t \mid x_{<t})\right)\right]$$

where $\lambda$ is a hyperparameter controlling the strength of the adjustment.
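
As a toy numerical illustration of this combination rule (the four-token vocabulary and all logit values below are made up, not from the paper), the following sketch applies the RAST adjustment at a single decoding step:

import numpy as np

def rast_adjust(logits_large_base, logits_small_rl, logits_small_base, lam=1.0):
    # Delta R: how RL shifted the small model's logits at this step
    delta_r = logits_small_rl - logits_small_base
    # Inject the shift into the large base model's logits, then normalize
    adjusted = logits_large_base + lam * delta_r
    exp = np.exp(adjusted - adjusted.max())
    return exp / exp.sum()

# Hypothetical 4-token vocabulary ["so", "wait", "therefore", "42"] with made-up logits
probs = rast_adjust(
    np.array([2.0, 0.5, 1.0, 0.2]),  # large base model
    np.array([1.0, 2.5, 0.8, 0.1]),  # small RL-trained model
    np.array([1.2, 0.4, 0.9, 0.3]),  # small base model
)
print(probs)  # probability mass shifts toward the token the RL model promotes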

A preliminary analysis supports this hypothesis, showing a high Path Coverage Rate (PCR > 95%): given the token sequence generated by an RL-trained model, the base model would have predicted the same next token most of the time. The tokens that do differ often correspond to key reasoning behaviors such as self-verification or backtracking.
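
A minimal sketch of how such a path coverage check could be computed, assuming a base_model.get_logits(prefix) helper that returns next-token logits and treating the base model's greedy (top-1) prediction as the match criterion; the paper's exact protocol may differ:

def path_coverage_rate(base_model, prompt_tokens, rl_generated_tokens):
    # Fraction of positions where the base model's top-1 prediction matches
    # the token the RL-trained model actually produced along its own trajectory.
    context = list(prompt_tokens)
    matches = 0
    for token in rl_generated_tokens:
        logits = base_model.get_logits(context)  # assumed interface
        if int(logits.argmax()) == token:
            matches += 1
        context.append(token)
    return matches / len(rl_generated_tokens)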

Implementation and Experiments:

  • Models: The primary models are from the Qwen-2.5 family (1.5B, 7B, 14B, 32B) and their RL-trained versions from SimpleRL-Zoo. Llama-3.1 (8B, 70B) and Code-R1 models are also used for generalization experiments.
  • Tasks:
    • Mathematical Reasoning: MATH500, AIME24, AMC23, Minerva, OlympiadBench, GSM8K.
    • Code Reasoning: HumanEval+, MBPP+, LiveCodeBench.
  • Evaluation Metrics:
    • Pass@k: whether at least one correct solution is found among $k$ samples.
    • Recovery Rate: the fraction of the performance gap between the base model and its RL-trained version that is recovered by RAST, i.e. $(\text{Score}_{\text{RAST}} - \text{Score}_{\text{base}}) / (\text{Score}_{\text{RL}} - \text{Score}_{\text{base}})$.
  • Decoding: For math, temperature 1.0, top-p 0.95, and a maximum length of 16,384 tokens, with $\lambda$ set to 1.0. For code, greedy decoding is used (see the sketch after this list).
  • Infrastructure: A modified version of vLLM is used for RAST inference. Experiments run on 8 NVIDIA A6000 GPUs.
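
For reference, the reported sampling settings map onto stock vLLM SamplingParams roughly as follows; the RAST logit combination itself requires the authors' modified vLLM or an equivalent custom decoding loop, and the max_tokens value for code is an assumption:

from vllm import SamplingParams

# Math benchmarks: sampled decoding as reported in the paper
math_params = SamplingParams(temperature=1.0, top_p=0.95, max_tokens=16384)
# Code benchmarks: greedy decoding (temperature 0); max_tokens here is an assumption
code_params = SamplingParams(temperature=0.0, max_tokens=16384)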

Key Findings:

  1. Consistent Performance Gains: RAST substantially improves the reasoning capabilities of base models across various mathematical benchmarks. For example, Qwen-2.5-32B + $\Delta R_{14B}$ (delta logits from a 14B RL model) achieves 80.7% on MATH500, nearing the 81.3% of the fully RL-trained SimpleRL-32B and well above the 32B base model's 68.6%.
  2. Scalability of $\Delta R$: Using $\Delta R$ from stronger (larger) small RL models generally leads to greater improvements in the large base model.
  3. Trade-offs:
    • Stronger base models ($\mathcal{M}_{\text{base}}$) tend to have higher recovery rates, indicating better receptiveness to transferred signals.
    • However, a very large capability gap between the base model and the source of $\Delta R$ can hinder effective transfer.
  4. Increased Reasoning Diversity: RAST improves pass@k performance, often matching or even exceeding that of the fully RL-trained counterpart ($\mathcal{M}_{\text{RL}}$). This suggests RAST helps explore a more diverse solution space.
  5. Generalization: RAST shows positive results when applied to the Llama-3.1 model family and also improves performance on code reasoning tasks.
  6. Efficiency: RAST significantly reduces GPU memory requirements compared to direct RL training. For instance, enhancing a 32B model with $\Delta R_{14B}$ requires roughly 160 GB of GPU memory, while training a 32B model with RL (GRPO) needs roughly 350 GB. It achieves high recovery rates (e.g., 84.8% for 32B + $\Delta R_{14B}$) with this reduced overhead.
  7. Signal Alignment: Higher cosine similarity between $\Delta R$ signals from different model scales correlates with better transferability (higher recovery rates); see the sketch after this list.
  8. Behavioral Shift: RAST guides the base model to exhibit reasoning behaviors similar to RL-trained models, such as increased self-verification. This is evidenced by higher KL divergence on specific reasoning-related tokens (e.g., "check") when comparing RAST-enhanced model outputs to base model outputs.
  9. Robustness: The method is shown to be robust to moderate variations in decoding hyperparameters such as temperature ($\tau$) and the adjustment strength ($\lambda$).
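
A minimal sketch of the signal-alignment check in finding 7, using random placeholder vectors in place of real delta logits and comparing a single decoding step rather than the paper's full aggregation:

import numpy as np

def cosine_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Delta R vectors from two small-model scales at the same decoding step
# (random placeholders standing in for real logit differences)
delta_r_7b = np.random.randn(32000)
delta_r_14b = np.random.randn(32000)
print(cosine_similarity(delta_r_7b, delta_r_14b))
# Higher similarity across scales correlates with higher recovery rates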

Practical Implementation Steps for RAST:

  1. Select Models:
    • Choose a large base model ($\mathcal{M}_{\text{base}}$) you want to enhance (e.g., Qwen-2.5-32B).
    • Choose a smaller base model ($\mathcal{S}_{\text{base}}$) and its RL-fine-tuned version ($\mathcal{S}_{\text{RL}}$), e.g., Qwen-2.5-14B as $\mathcal{S}_{\text{base}}$ and SimpleRL-14B as $\mathcal{S}_{\text{RL}}$. Ensure $\mathcal{S}_{\text{base}}$ and $\mathcal{S}_{\text{RL}}$ share the same tokenizer and architecture.
  2. Inference Setup:
    • Load all three models: $\mathcal{M}_{\text{base}}$, $\mathcal{S}_{\text{base}}$, and $\mathcal{S}_{\text{RL}}$. This increases memory requirements compared to running inference with $\mathcal{M}_{\text{base}}$ alone, but remains significantly cheaper than full RL training. The paper mentions using a modified vLLM for efficient inference.
  3. Decoding Process (per token generation):
    • For a given input prefix $x_{<t}$:
      a. Obtain logits from $\mathcal{M}_{\text{base}}$: $L_{\mathcal{M}_{\text{base}}} = \text{logits}_{\mathcal{M}_{\text{base}}}(X_t \mid x_{<t})$.
      b. Obtain logits from $\mathcal{S}_{\text{base}}$: $L_{\mathcal{S}_{\text{base}}} = \text{logits}_{\mathcal{S}_{\text{base}}}(X_t \mid x_{<t})$.
      c. Obtain logits from $\mathcal{S}_{\text{RL}}$: $L_{\mathcal{S}_{\text{RL}}} = \text{logits}_{\mathcal{S}_{\text{RL}}}(X_t \mid x_{<t})$.
      d. Calculate the delta logits: $\Delta R = L_{\mathcal{S}_{\text{RL}}} - L_{\mathcal{S}_{\text{base}}}$.
      e. Calculate the modified logits for the enhanced model $\tilde{\mathcal{M}}$: $L_{\tilde{\mathcal{M}}} = L_{\mathcal{M}_{\text{base}}} + \lambda \cdot \Delta R$.
      f. Apply softmax to $L_{\tilde{\mathcal{M}}}$ to get the probability distribution $P_{\tilde{\mathcal{M}}}(X_t \mid x_{<t})$.
      g. Sample the next token $X_t$ from this distribution using the desired decoding strategy (e.g., nucleus sampling, greedy).
  4. Hyperparameter Tuning:
    • The paper primarily uses $\lambda = 1.0$; Figure 9 shows robust performance for $\lambda \in [0.3, 1.5]$. This may need tuning for different model combinations or tasks.
    • Standard decoding parameters like temperature and top-p should also be considered.

Illustrative code for RAST decoding (a minimal sketch; get_logits is a placeholder for the underlying model forward pass):

import torch
import torch.nn.functional as F

def rast_generate_token(m_base, s_base, s_rl, current_prefix, lambda_val=1.0):
    # Each model is assumed to expose get_logits(prefix) -> 1D tensor of
    # next-token logits over the shared vocabulary.

    # Get logits from the large base model
    logits_m_base = m_base.get_logits(current_prefix)

    # Get logits from the small base model
    logits_s_base = s_base.get_logits(current_prefix)

    # Get logits from the small RL-tuned model
    logits_s_rl = s_rl.get_logits(current_prefix)

    # Calculate delta R (the RL-induced reasoning correction signal)
    delta_r = logits_s_rl - logits_s_base

    # Apply the correction to the large base model's logits
    final_logits = logits_m_base + lambda_val * delta_r

    # Convert logits to probabilities
    probabilities = F.softmax(final_logits, dim=-1)

    # Sample the next token from the adjusted distribution
    next_token = torch.multinomial(probabilities, num_samples=1).item()

    return next_token
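
A hedged sketch of the surrounding generation loop built on the function above (the eos_token_id handling and token-list interface are assumptions; the paper runs this inside a modified vLLM rather than a plain Python loop):

def rast_generate(m_base, s_base, s_rl, prompt_tokens, max_new_tokens=16384,
                  lambda_val=1.0, eos_token_id=None):
    # Autoregressively extend the prompt with RAST-adjusted sampling at every step
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        next_token = rast_generate_token(m_base, s_base, s_rl, tokens, lambda_val)
        tokens.append(next_token)
        if eos_token_id is not None and next_token == eos_token_id:
            break
    return tokens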

Considerations:

  • Computational Cost: While more efficient than full RL training, RAST requires running three models simultaneously during inference, which increases VRAM usage and latency compared to using only the large base model. The sizes of $\mathcal{S}_{\text{base}}$ and $\mathcal{S}_{\text{RL}}$ determine this overhead.
  • Availability of RL-tuned Small Models: The method relies on having access to a capable smaller RL-tuned model ($\mathcal{S}_{\text{RL}}$) and its base version ($\mathcal{S}_{\text{base}}$).
  • Prompt Consistency: As noted in Table 1 (footnote †), if the prompt template used for RAST inference (for $\mathcal{M}_{\text{base}}$) doesn't match the one used to train $\mathcal{S}_{\text{RL}}$, the gains may be smaller. This highlights the importance of aligning the context in which $\Delta R$ is generated and applied; a small illustration follows this list.
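
As a hedged illustration of the prompt-consistency point (the template string below is a generic placeholder, not the SimpleRL-Zoo template):

def build_prompt(question):
    # Use the same template that S_RL was trained with for all three models,
    # so Delta R is computed and applied in a matching context.
    template = "Question: {q}\nAnswer: Let's think step by step."
    return template.format(q=question)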

In conclusion, RAST offers a practical and significantly more resource-efficient alternative to direct RL training for enhancing the reasoning abilities of large LLMs by transferring learned reasoning adjustments from smaller, RL-tuned models at inference time. It leverages the insight that RL often refines existing capabilities rather than imparting entirely new ones.