- The paper presents a method to scale LLM inference compute by distilling knowledge from Transformer models into faster subquadratic architectures like Mamba.
- Distilled models show significant speedups (up to 4.2×), allowing them to achieve better accuracy and coverage than baseline Transformers on reasoning tasks under fixed computation budgets.
- The results indicate that subquadratic architectures like Mamba are effective and scalable alternatives to Transformers for tasks that benefit from high-throughput inference.
The paper introduces a method for scaling inference compute in LLMs by distilling knowledge from Transformer models into subquadratic architectures, specifically Mamba models. The central idea is to leverage the faster generation throughput of Mamba to compensate for any potential loss in per-sample accuracy compared to Transformers, thus achieving better performance under fixed computational budgets.
The paper addresses the challenge that while scaling inference compute via techniques like Chain-of-Thought (CoT) and majority voting has proven effective for improving reasoning in LLMs, the memory-bound nature of Transformers during generation limits their scalability. Subquadratic architectures like Mamba offer linear time complexity during training or prefill and constant memory requirements during inference, which enables higher inference throughput. However, the lack of large-scale pretrained subquadratic models hinders their application in reasoning tasks.
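To make the memory argument concrete, the rough sketch below compares the decode-time memory of a Transformer's per-sequence KV cache, which grows linearly with sequence length, against a fixed-size Mamba-style SSM state. The layer counts, head dimensions, and state size are illustrative assumptions (and Mamba's small convolutional state is ignored); they are not the exact configurations studied in the paper.

```python
# Back-of-the-envelope comparison of per-sequence decode-time memory.
# All configuration values below are illustrative assumptions.

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # Transformer KV cache: K and V tensors per layer, growing with seq_len.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

def ssm_state_bytes(n_layers, d_inner, d_state, bytes_per_elem=2):
    # Mamba-style SSM state: a fixed-size recurrent state per layer,
    # independent of how many tokens have been generated so far.
    return n_layers * d_inner * d_state * bytes_per_elem

if __name__ == "__main__":
    for seq_len in (1_024, 8_192, 65_536):
        kv = kv_cache_bytes(n_layers=28, n_kv_heads=8, head_dim=128, seq_len=seq_len)
        ssm = ssm_state_bytes(n_layers=28, d_inner=4_096, d_state=16)
        print(f"seq_len={seq_len:>6}: KV cache ~ {kv / 2**20:8.1f} MiB, "
              f"SSM state ~ {ssm / 2**20:6.1f} MiB")
```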
To overcome this, the authors propose distilling knowledge from pretrained Transformer models (Llama 3.2-1B-Instruct and Llama 3.2-3B-Instruct) into pure Mamba (dubbed Llamba) and hybrid Mamba-Transformer architectures (dubbed MambaInLlama). The distillation process involves:
- Matrix Orientation: Aligning the Mamba model's SSM matrix mixer with the teacher's self-attention matrix by minimizing the distance between the two matrices.
- Hidden State Alignment: Matching the hidden-state output of each student layer to that of the corresponding teacher layer.
- Weight Transfer and Knowledge Distillation: Transferring the remaining, unoptimized parameters (e.g., MLPs, embeddings, and norms), then fine-tuning the complete student model end-to-end using a distillation loss between the student and teacher logits (a sketch of these objectives follows the list).
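The following is a minimal PyTorch sketch of these three objectives; the tensor names, shapes, temperature, and loss weighting are illustrative assumptions and do not reproduce the paper's exact training recipe.

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the three distillation objectives; the weighting and
# scheduling of the stages are not shown.

def matrix_orientation_loss(student_mixer: torch.Tensor,
                            teacher_attn: torch.Tensor) -> torch.Tensor:
    """Frobenius distance between the materialized SSM matrix mixer and the
    teacher's self-attention matrix (both ... x seq_len x seq_len)."""
    return torch.linalg.matrix_norm(student_mixer - teacher_attn).mean()

def hidden_state_alignment_loss(student_hidden: torch.Tensor,
                                teacher_hidden: torch.Tensor) -> torch.Tensor:
    """L2 distance between a student layer's output and the matching teacher
    layer's output (batch x seq_len x d_model)."""
    return F.mse_loss(student_hidden, teacher_hidden)

def logit_distillation_loss(student_logits: torch.Tensor,
                            teacher_logits: torch.Tensor,
                            temperature: float = 1.0) -> torch.Tensor:
    """End-to-end distillation on the output logits: forward KL from the
    teacher distribution to the student distribution."""
    s = F.log_softmax(student_logits / temperature, dim=-1).flatten(0, -2)
    t = F.softmax(teacher_logits / temperature, dim=-1).flatten(0, -2)
    return F.kl_div(s, t, reduction="batchmean") * temperature ** 2
```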
For the hybrid models, the authors modify the protocol proposed in prior work in order to distill specific capabilities. In this approach, the linear projections playing the roles of Q, K, V, and O in the new layers are initialized from the teacher's corresponding attention weights Wq, Wk, Wv, and Wo. The only additional learned parameters in the new layers are the sampling rate Δ and the dynamic A. The reverse Kullback–Leibler (KL) divergence is used as the loss function, and the Mamba-1 architecture is used for the hybrid models.
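A minimal sketch of the reverse KL objective is shown below; per-token logits are assumed, and temperature and padding-mask handling are omitted for brevity.

```python
import torch
import torch.nn.functional as F

def reverse_kl_loss(student_logits: torch.Tensor,
                    teacher_logits: torch.Tensor) -> torch.Tensor:
    """Reverse KL, KL(student || teacher): the expectation is taken under the
    student's distribution, penalizing the student for putting probability
    mass where the teacher assigns very little (mode-seeking behavior)."""
    log_p_s = F.log_softmax(student_logits, dim=-1)
    log_p_t = F.log_softmax(teacher_logits, dim=-1)
    return (log_p_s.exp() * (log_p_s - log_p_t)).sum(dim=-1).mean()
```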
The distilled models are then evaluated on mathematical reasoning tasks (MATH and GSM8K) using multiple CoT completions, and their performance is analyzed under fixed compute and memory constraints. The evaluation metrics include coverage (pass@k) and accuracy, with aggregation strategies like majority voting and weighted Best-of-N (using a reward model trained with process supervision to select the best response).
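The sketch below illustrates these metrics and aggregation strategies in a common formulation: the unbiased pass@k estimator, majority voting over final answers, and a weighted Best-of-N that sums reward-model scores over completions sharing the same final answer. Function names are hypothetical, and the paper's exact aggregation details may differ.

```python
from collections import defaultdict
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased coverage estimator: probability that at least one of k
    completions, drawn from n samples of which c are correct, is correct."""
    assert 1 <= k <= n and 0 <= c <= n
    return 1.0 - comb(n - c, k) / comb(n, k)

def majority_vote(answers: list[str]) -> str:
    """Return the most frequent final answer among the completions."""
    counts = defaultdict(int)
    for a in answers:
        counts[a] += 1
    return max(counts, key=counts.get)

def weighted_best_of_n(answers: list[str], rewards: list[float]) -> str:
    """Sum reward-model scores over completions that share the same final
    answer and return the answer with the highest total score."""
    scores = defaultdict(float)
    for a, r in zip(answers, rewards):
        scores[a] += r
    return max(scores, key=scores.get)
```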
The results demonstrate that the distilled models exhibit significantly faster inference than their Transformer counterparts. For example, the distilled models are shown to be up to 3.7× and 4.2× faster than the Llama 1B and 3B baselines, respectively. This speedup enables the distilled models to generate more completions within a given time budget, which translates into better coverage and accuracy on the reasoning tasks. The hybrid Mamba-Transformer models (MambaInLlama) are found to be slightly faster than the pure Mamba models (Llamba), potentially due to their smaller SSM state size.
The paper also highlights the lack of correlation between performance on common multiple-choice benchmarks and mathematical reasoning performance, as well as the significant impact of the distillation dataset on the final capabilities of the distilled models. Furthermore, the authors show that supervised fine-tuning (SFT) after distillation can further improve the accuracy and coverage of the distilled models.
Specifically, the paper finds that the distilled models can reach the same coverage in nearly half the time of their respective teachers. In addition, their lighter and faster batch inference yields a better accuracy/time Pareto front across several completion scales. The authors conjecture that distilling from a larger teacher model would produce a student that outperforms Llama-3B, providing better accuracy for a given time budget.
In conclusion, the paper provides evidence that subquadratic architectures like Mamba can be effectively used for scaling inference compute in LLMs, achieving better performance than Transformers under fixed computational budgets. The findings suggest that Mamba and other attention alternatives are strong substitutes for Transformers in tasks that benefit from scalable inference compute.