Thinking Slow, Fast: Scaling Inference Compute with Distilled Reasoners (2502.20339v1)

Published 27 Feb 2025 in cs.CL and cs.AI

Abstract: Recent advancements have demonstrated that the performance of LLMs can be significantly enhanced by scaling computational resources at test time. A common strategy involves generating multiple Chain-of-Thought (CoT) trajectories and aggregating their outputs through various selection mechanisms. This raises a fundamental question: can models with lower complexity leverage their superior generation throughput to outperform similarly sized Transformers for a fixed computational budget? To address this question and overcome the lack of strong subquadratic reasoners, we distill pure and hybrid Mamba models from pretrained Transformers. Trained on only 8 billion tokens, our distilled models show strong performance and scaling on mathematical reasoning datasets while being much faster at inference for large batches and long sequences. Despite the zero-shot performance hit due to distillation, both pure and hybrid Mamba models can scale their coverage and accuracy performance past their Transformer teacher models under fixed time budgets, opening a new direction for scaling inference compute.

Summary

  • The paper presents a method to scale LLM inference compute by distilling knowledge from Transformer models into faster subquadratic architectures like Mamba.
  • Distilled models show significant speedups (up to 4.2x), allowing them to achieve better accuracy and coverage than baseline Transformers on reasoning tasks within fixed computation budgets.
  • The results indicate that subquadratic architectures like Mamba are effective and scalable alternatives to Transformers for tasks that benefit from high-throughput inference.

The paper introduces a method for scaling inference compute in LLMs by distilling knowledge from Transformer models into subquadratic architectures, specifically Mamba models. The central idea is to leverage the faster generation throughput of Mamba to compensate for any potential loss in per-sample accuracy compared to Transformers, thus achieving better performance under fixed computational budgets.

The paper addresses the challenge that while scaling inference compute via techniques like Chain-of-Thought (CoT) and majority voting has proven effective for improving reasoning in LLMs, the memory-bound nature of Transformers during generation limits their scalability. Subquadratic architectures like Mamba offer linear time complexity during training or prefill and constant memory requirements during inference, which enables higher inference throughput. However, the lack of large-scale pretrained subquadratic models hinders their application in reasoning tasks.

To overcome this, the authors propose distilling knowledge from pretrained Transformer models (Llama 3.2-1B-Instruct and Llama 3.2-3B-Instruct) into pure Mamba models (dubbed Llamba) and hybrid Mamba-Transformer architectures (dubbed MambaInLlama). The distillation process involves three stages (a sketch follows the list):

  • Matrix Orientation: Aligning the Mamba model's SSM matrix mixer with the teacher's self-attention matrix by minimizing the distance between the two matrices.
  • Hidden State Alignment: Matching the hidden-state outputs of each student layer to those of the corresponding teacher layer.
  • Weight Transfer and Knowledge Distillation: Transferring the remaining, unoptimized parameters (e.g., MLPs, embeddings, and norms), and finetuning the complete end-to-end student model using a distillation loss on the student and teacher logits.
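
Below is a minimal PyTorch-style sketch of these three stages under simplifying assumptions: the `attention_matrix`/`mixing_matrix` accessors and the `.logits` attribute are illustrative stand-ins, not the authors' actual interfaces.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch of the three distillation stages; module and attribute
# names (attention_matrix, mixing_matrix, .logits) are assumptions, not the
# paper's code.

def matrix_orientation_loss(student_layer, teacher_layer, x):
    # Stage 1: align the student's SSM matrix mixer with the teacher's
    # self-attention matrix by minimizing the distance between them.
    attn_matrix = teacher_layer.attention_matrix(x)  # assumed accessor
    ssm_matrix = student_layer.mixing_matrix(x)      # assumed accessor
    return torch.linalg.matrix_norm(attn_matrix - ssm_matrix).mean()

def hidden_state_loss(student_layer, teacher_layer, x):
    # Stage 2: match each student layer's hidden-state output to the
    # corresponding teacher layer's output.
    with torch.no_grad():
        target = teacher_layer(x)
    return F.mse_loss(student_layer(x), target)

def logit_distillation_loss(student, teacher, input_ids, T=1.0):
    # Stage 3: after transferring MLPs, embeddings, and norms, finetune the
    # full student end to end against the teacher's output distribution.
    with torch.no_grad():
        teacher_logits = teacher(input_ids).logits
    student_logits = student(input_ids).logits
    return F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
```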

For the hybrid models, the authors adapt the protocol proposed in prior work to distill specific capabilities. In this approach, the linear projections for Q, K, V, and O are initialized from the teacher's corresponding projections W_q, W_k, W_v, and W_o, respectively. The only additional learned parameters in the new layers are the sampling rate Δ and the dynamic A. The reverse Kullback–Leibler (KL) divergence is used as the loss function, and the Mamba-1 architecture is used for the hybrid models.
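
The following sketch illustrates the initialization and the reverse-KL objective described above; the projection attribute names (`q_proj`, `k_proj`, `v_proj`, `o_proj`) are hypothetical stand-ins for whatever the actual modules expose.

```python
import torch.nn.functional as F

# Hypothetical sketch: projection attribute names are assumptions.

def init_from_attention(hybrid_layer, attn_layer):
    # Initialize the new layer's Q/K/V/O projections from the teacher's
    # W_q, W_k, W_v, W_o; only the sampling rate Δ and the dynamic A are
    # newly learned parameters.
    hybrid_layer.q_proj.weight.data.copy_(attn_layer.q_proj.weight.data)
    hybrid_layer.k_proj.weight.data.copy_(attn_layer.k_proj.weight.data)
    hybrid_layer.v_proj.weight.data.copy_(attn_layer.v_proj.weight.data)
    hybrid_layer.o_proj.weight.data.copy_(attn_layer.o_proj.weight.data)

def reverse_kl_loss(student_logits, teacher_logits):
    # Reverse KL, i.e. KL(student || teacher): the expectation is taken under
    # the student's own distribution, which makes the loss mode-seeking.
    log_p_s = F.log_softmax(student_logits, dim=-1)
    log_p_t = F.log_softmax(teacher_logits, dim=-1)
    return (log_p_s.exp() * (log_p_s - log_p_t)).sum(dim=-1).mean()
```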

The distilled models are then evaluated on mathematical reasoning tasks (MATH and GSM8K) using multiple CoT completions, and their performance is analyzed under fixed compute and memory constraints. The evaluation metrics include coverage (pass@k) and accuracy, with aggregation strategies like majority voting and weighted Best-of-N (using a reward model trained with process supervision to select the best response).
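
As a concrete reference, the sketch below shows the standard unbiased pass@k (coverage) estimator and simple implementations of the two aggregation strategies, assuming final answers have already been extracted from the CoT completions and that reward-model scores are provided externally.

```python
from collections import Counter
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    # Unbiased coverage estimate: probability that at least one of k samples,
    # drawn without replacement from n completions of which c are correct,
    # solves the problem.
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

def majority_vote(answers: list[str]) -> str:
    # Return the most frequent final answer across completions.
    return Counter(answers).most_common(1)[0][0]

def weighted_best_of_n(answers: list[str], reward_scores: list[float]) -> str:
    # Weighted Best-of-N: sum reward-model scores per distinct answer and
    # return the answer with the highest total score.
    totals: dict[str, float] = {}
    for ans, score in zip(answers, reward_scores):
        totals[ans] = totals.get(ans, 0.0) + score
    return max(totals, key=totals.get)
```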

The results demonstrate that the distilled models exhibit significantly faster inference speeds compared to their Transformer counterparts. For example, the distilled models are shown to be up to 3.7× and 4.2× faster than their respective Llama 1B and 3B baselines. This speedup enables the distilled models to generate more completions within a given time budget, which translates to better coverage and accuracy on the reasoning tasks. The hybrid Mamba-Transformer models (MambaInLlama) are found to be slightly faster than the pure Mamba models (Llamba), potentially due to the smaller SSM state size.

The paper also highlights the lack of correlation between common multiple-choice benchmarks and mathematical reasoning performance, as well as the significant impact of the distillation dataset on the final capabilities of the distilled models. Furthermore, the authors show that supervised fine-tuning (SFT) after distillation can further improve the accuracy and coverage of the distilled models.

Specifically, the paper finds that the distilled models can reach the same coverage in nearly half the time of their respective teachers. The lighter and faster batch inference of the distilled models also yields a better accuracy/time Pareto front at several completion scales. The authors conjecture that distilling an even larger base model would yield a student that outperforms Llama-3B, providing better accuracy for a given time budget.

In conclusion, the paper provides evidence that subquadratic architectures like Mamba can be effectively used to scale inference compute in LLMs, achieving better performance than Transformers under fixed computational budgets. The findings suggest that Mamba and other attention alternatives are strong substitutes for Transformers on tasks that benefit from scalable inference compute.
