- The paper presents RAST, a novel method that transfers reasoning corrections from a small RL-trained model to a large LLM without full RL fine-tuning.
- It computes delta logits between a small model's RL-trained and base versions and adds them to the large model's logits, adjusting token probabilities to activate latent reasoning behaviors.
- Experiments demonstrate significant gains on math and code tasks with improved recovery rates and reduced GPU memory needs compared to direct RL training.
The paper "RAST: Reasoning Activation in LLMs via Small-model Transfer" (2506.15710) introduces a method to enhance the reasoning capabilities of LLMs without the high computational cost typically associated with Reinforcement Learning (RL) fine-tuning. The core idea is that RL primarily activates latent reasoning abilities already present in a base model by selectively adjusting the probabilities of a few key tokens, rather than teaching new knowledge. The authors hypothesize that these RL-induced probability shifts are largely invariant to model size.
Building on this, they propose RAST (Reasoning Activation in LLMs via Small-model Transfer), a decoding-time method. RAST transfers the "reasoning patterns" learned by a small RL-trained model (SRL) to a larger base model (Mbase). This is achieved by calculating the difference in logit outputs between the small RL-trained model and its own base version (Sbase). This difference, denoted as ΔR, represents the reasoning-oriented adjustments. At each decoding step, this ΔR is added to the logits of the large base model Mbase before applying softmax to get the final token probabilities. The formula for the enhanced model M~'s probability distribution is:
$$P_{\tilde{M}}(X_t \mid x_{<t}) = \mathrm{softmax}\left[\mathrm{logits}_{M_{\mathrm{base}}}(X_t \mid x_{<t}) + \lambda\left(\mathrm{logits}_{S_{\mathrm{RL}}}(X_t \mid x_{<t}) - \mathrm{logits}_{S_{\mathrm{base}}}(X_t \mid x_{<t})\right)\right]$$
where λ is a hyperparameter controlling the strength of the adjustment.
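As a toy illustration of this formula, the snippet below (with made-up logits over a four-token vocabulary) shows how the delta from the small model pair shifts the large base model's distribution toward the token the RL-trained model boosts:

```python
import torch

# Made-up logits over a toy 4-token vocabulary, purely for illustration.
logits_M_base = torch.tensor([2.0, 1.0, 0.5, 0.0])   # large base model
logits_S_base = torch.tensor([1.5, 1.0, 0.2, 0.0])   # small base model
logits_S_RL   = torch.tensor([1.5, 1.0, 1.4, 0.0])   # small RL-trained model (boosts token 2)

lam = 1.0
delta_R = logits_S_RL - logits_S_base                # [0.0, 0.0, 1.2, 0.0]
p_base = torch.softmax(logits_M_base, dim=-1)
p_rast = torch.softmax(logits_M_base + lam * delta_R, dim=-1)
print(p_base)   # token 2 gets ~0.13 probability under the base model alone
print(p_rast)   # token 2 rises to ~0.33, mirroring the RL adjustment
```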
A preliminary analysis supports this hypothesis, showing a high Path Coverage Rate (PCR > 95%): when given the token sequence generated by an RL-trained model, the base model would have predicted the same next token at most positions. The positions where the two models differ often correspond to key reasoning behaviors such as self-verification or backtracking.
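A rough sketch of how such a coverage rate could be computed, assuming a Hugging Face-style causal LM that returns `.logits` (the paper's exact PCR computation may differ in its details):

```python
import torch

@torch.no_grad()
def path_coverage_rate(base_model, rl_generated_ids: torch.Tensor) -> float:
    """Fraction of positions where the base model's greedy next-token prediction
    matches the token the RL-trained model actually generated."""
    # Teacher-force the RL model's token sequence through the base model.
    logits = base_model(rl_generated_ids.unsqueeze(0)).logits[0]   # shape [T, vocab]
    greedy_next = logits[:-1].argmax(dim=-1)                       # predictions for positions 1..T-1
    matches = (greedy_next == rl_generated_ids[1:]).float()
    return matches.mean().item()
```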
Implementation and Experiments:
- Models: The primary models are from the Qwen-2.5 family (1.5B, 7B, 14B, 32B) and their RL-trained versions from SimpleRL-Zoo. Llama-3.1 (8B, 70B) and Code-R1 models are also used for generalization experiments.
- Tasks:
- Mathematical Reasoning: MATH500, AIME24, AMC23, Minerva, OlympiadBench, GSM8K.
- Code Reasoning: HumanEval+, MBPP+, LiveCodeBench.
- Evaluation Metrics (a code sketch of both follows this list):
- Pass@k: the fraction of problems for which at least one of k sampled solutions is correct.
- Recovery Rate: How much of the performance gap between the base model and its RL-trained version is recovered by RAST.
- Decoding: For math, temperature 1.0, top-p 0.95, max length 16,384 tokens. λ is set to 1.0. For code, greedy decoding is used.
- Infrastructure: RAST is implemented on top of a modified version of vLLM. Experiments run on 8 NVIDIA A6000 GPUs.
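The two evaluation metrics above can be sketched as follows, using the standard unbiased pass@k estimator (whether the paper uses exactly this estimator is not restated here); the recovery-rate scores below are placeholders, not figures from the paper:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate from n samples, of which c are correct."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

def recovery_rate(score_base: float, score_rast: float, score_rl: float) -> float:
    """Fraction of the base-to-RL performance gap that RAST closes."""
    return (score_rast - score_base) / (score_rl - score_base)

print(pass_at_k(n=16, c=3, k=8))                                       # 0.9
print(recovery_rate(score_base=60.0, score_rast=70.0, score_rl=72.0))  # ~0.83 (placeholder scores)
```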
Key Findings:
- Consistent Performance Gains: RAST substantially improves the reasoning capabilities of base models across various mathematical benchmarks. For example, Qwen-2.5-32B + ΔR14B (delta logits from the 14B RL/base pair) achieves 80.7% on MATH500, nearing the 81.3% of the fully RL-trained SimpleRL-32B and significantly up from the 32B base model's 68.6%.
- Scalability of ΔR: Using ΔR from stronger (larger) small RL models generally leads to greater improvements in the large base model.
- Trade-offs:
- Stronger base models (Mbase) tend to have higher recovery rates, indicating better receptiveness to transferred signals.
- However, a very large capability gap between the base model and the source of ΔR might hinder effective transfer.
- Increased Reasoning Diversity: RAST improves pass@k performance, often matching or even exceeding that of the fully RL-trained counterpart (MRL). This suggests RAST helps explore a more diverse solution space.
- Generalization: RAST shows positive results when applied to the Llama-3.1 model family and also improves performance on code reasoning tasks.
- Efficiency: RAST significantly reduces GPU memory requirements compared to direct RL training. For instance, enhancing a 32B model with ΔR14B requires ~160GB GPU memory, while training a 32B model with RL (GRPO) needs ~350GB. It achieves high recovery rates (e.g., 84.8% for 32B + ΔR14B) with this reduced overhead.
- Signal Alignment: Higher cosine similarity between ΔR signals from different model scales correlates with better transferability (higher recovery rates); a minimal sketch follows after this list.
- Behavioral Shift: RAST guides the base model to exhibit reasoning behaviors similar to RL-trained models, such as increased self-verification. This is evidenced by higher KL divergence on specific reasoning-related tokens (e.g., "check") when comparing RAST-enhanced model outputs to base model outputs.
- Robustness: The method is shown to be robust to moderate variations in decoding hyperparameters like temperature (τ) and the adjustment strength (λ).
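On the signal-alignment point, one way to quantify agreement between ΔR vectors produced by two differently sized small-model pairs at the same decoding step is plain cosine similarity; this is a hedged sketch, and the paper's exact aggregation over steps and prompts may differ:

```python
import torch
import torch.nn.functional as F

def delta_r(logits_rl: torch.Tensor, logits_base: torch.Tensor) -> torch.Tensor:
    """Delta logits for one small RL/base model pair at a single decoding step."""
    return logits_rl - logits_base

def delta_r_alignment(delta_a: torch.Tensor, delta_b: torch.Tensor) -> float:
    """Cosine similarity between two delta-logit vectors over a shared vocabulary."""
    return F.cosine_similarity(delta_a, delta_b, dim=-1).item()
```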
Practical Implementation Steps for RAST:
- Select Models:
- Choose a large base model (Mbase) you want to enhance (e.g., Qwen-2.5-32B).
- Choose a smaller base model (Sbase) and its RL-fine-tuned version (SRL) (e.g., Qwen-2.5-14B as Sbase and SimpleRL-14B as SRL). Ensure Sbase and SRL share the same tokenizer and architecture.
- Inference Setup:
- Load all three models: Mbase, Sbase, and SRL. This increases memory requirements compared to inference with Mbase alone, but remains far below the cost of full RL training. The paper mentions using a modified vLLM for efficient inference.
- Decoding Process (per token generation):
- For a given input prefix x<t:
a. Obtain logits from Mbase: L_Mbase = logits_Mbase(Xt ∣ x<t).
b. Obtain logits from Sbase: L_Sbase = logits_Sbase(Xt ∣ x<t).
c. Obtain logits from SRL: L_SRL = logits_SRL(Xt ∣ x<t).
d. Calculate the delta logits: ΔR = L_SRL − L_Sbase.
e. Calculate the modified logits for the enhanced model M~: L_M~ = L_Mbase + λ · ΔR.
f. Apply softmax to L_M~ to obtain the probability distribution P_M~(Xt ∣ x<t).
g. Sample the next token Xt from this distribution using the desired decoding strategy (e.g., nucleus sampling or greedy decoding).
- Hyperparameter Tuning:
- The paper primarily uses λ=1.0. However, Figure 9 shows robust performance for λ∈[0.3,1.5]. This might need tuning for different model combinations or tasks.
- Standard decoding parameters like temperature and top-p should also be considered.
Pseudocode for RAST Decoding:
```python
import torch

def rast_generate_token(M_base, S_base, S_RL, current_prefix, lambda_val=1.0):
    # `get_logits` is an assumed per-step interface returning next-token logits.
    # Get next-token logits from the large base model
    logits_M_base = M_base.get_logits(current_prefix)
    # Get next-token logits from the small base model
    logits_S_base = S_base.get_logits(current_prefix)
    # Get next-token logits from the small RL-tuned model
    logits_S_RL = S_RL.get_logits(current_prefix)
    # Calculate delta R (the reasoning correction signal)
    delta_R = logits_S_RL - logits_S_base
    # Apply the correction to the large base model's logits
    final_logits = logits_M_base + lambda_val * delta_R
    # Convert logits to probabilities
    probabilities = torch.softmax(final_logits, dim=-1)
    # Sample the next token from the adjusted distribution
    next_token = torch.multinomial(probabilities, num_samples=1).item()
    return next_token
```
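A hypothetical driver loop around this function might look like the following; `tokenizer`, `prompt`, `max_new_tokens`, and the three model handles are assumed to exist and are not APIs from the paper's codebase:

```python
prompt_ids = tokenizer.encode(prompt)     # tokenizer shared by all three models
generated = list(prompt_ids)
for _ in range(max_new_tokens):
    next_token = rast_generate_token(M_base, S_base, S_RL, generated, lambda_val=1.0)
    if next_token == tokenizer.eos_token_id:
        break
    generated.append(next_token)
print(tokenizer.decode(generated))
```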
Considerations:
- Computational Cost: While more efficient than full RL, RAST requires running three models simultaneously during inference, which increases VRAM usage and latency compared to using only the large base model. The size of Sbase and SRL will impact this overhead.
- Availability of RL-tuned Small Models: The method relies on having access to a capable smaller RL-tuned model (SRL) and its base version (Sbase).
- Prompt Consistency: As noted in Table 1 (footnote †), if the prompt template used for RAST inference (for Mbase) doesn't match the one used to train SRL, the gains might be smaller. This highlights the importance of aligning the context in which ΔR is generated and applied.
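On the prompt-consistency point, a minimal sketch (assuming a Hugging Face-style chat tokenizer shared by all three models) of formatting the prompt once with the template that SRL was trained on, so that ΔR is produced in the same context in which it is applied:

```python
# Assumed: a Hugging Face-style tokenizer shared by M_base, S_base, and S_RL.
messages = [{"role": "user", "content": "Solve for x: 2x + 3 = 11."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# The same formatted prompt is tokenized once and fed to all three models, so the
# delta logits are computed under the same template used to train S_RL.
```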
In conclusion, RAST offers a practical and significantly more resource-efficient alternative to direct RL training for enhancing the reasoning abilities of large LLMs by transferring learned reasoning adjustments from smaller, RL-tuned models at inference time. It leverages the insight that RL often refines existing capabilities rather than imparting entirely new ones.