Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models (2408.10189v1)

Published 19 Aug 2024 in cs.LG and cs.AI

Abstract: Transformer architectures have become a dominant paradigm for domains like language modeling but suffer in many inference settings due to their quadratic-time self-attention. Recently proposed subquadratic architectures, such as Mamba, have shown promise, but have been pretrained with substantially less computational resources than the strongest Transformer models. In this work, we present a method that is able to distill a pretrained Transformer architecture into alternative architectures such as state space models (SSMs). The key idea to our approach is that we can view both Transformers and SSMs as applying different forms of mixing matrices over the token sequences. We can thus progressively distill the Transformer architecture by matching different degrees of granularity in the SSM: first matching the mixing matrices themselves, then the hidden units at each block, and finally the end-to-end predictions. Our method, called MOHAWK, is able to distill a Mamba-2 variant based on the Phi-1.5 architecture (Phi-Mamba) using only 3B tokens and a hybrid version (Hybrid Phi-Mamba) using 5B tokens. Despite using less than 1% of the training data typically used to train models from scratch, Phi-Mamba boasts substantially stronger performance compared to all past open-source non-Transformer models. MOHAWK allows models like SSMs to leverage computational resources invested in training Transformer-based architectures, highlighting a new avenue for building such models.
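To make the "mixing matrix" view concrete, here is a minimal Python sketch, assuming a single head, scalar (one-dimensional) per-token features, and SSM parameter sequences `a`, `b`, `c` that are passed in directly rather than computed from the input. Causal softmax attention and a scalar SSM each produce a lower-triangular T×T matrix `M`, and in both cases the layer output is `y = M x`. The function names are hypothetical and not taken from the paper's code.

```python
import math


def causal_attention_matrix(q, k):
    """Causal softmax attention matrix for scalar queries/keys (single head).

    Row i holds softmax(q_i * k_j) over positions j <= i; entries with j > i are 0.
    """
    T = len(q)
    M = [[0.0] * T for _ in range(T)]
    for i in range(T):
        scores = [q[i] * k[j] for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        for j in range(i + 1):
            M[i][j] = exps[j] / z
    return M


def ssm_mixing_matrix(a, b, c):
    """Mixing matrix of a scalar SSM: M[i][j] = c_i * (a_{j+1} * ... * a_i) * b_j for j <= i."""
    T = len(a)
    M = [[0.0] * T for _ in range(T)]
    for i in range(T):
        decay = 1.0  # running product a_{j+1} * ... * a_i (empty product when j == i)
        for j in range(i, -1, -1):
            M[i][j] = c[i] * decay * b[j]
            decay *= a[j]
    return M


def mix(M, x):
    """Apply a mixing matrix to a sequence: y = M x."""
    return [sum(M[i][j] * x[j] for j in range(len(x))) for i in range(len(M))]


x = [1.0, 2.0, 3.0, 4.0]
A = causal_attention_matrix([0.5, -0.2, 0.1, 0.3], [0.4, 0.0, -0.3, 0.2])
S = ssm_mixing_matrix([0.9, 0.8, 0.7, 0.6], [1.0, 0.5, 0.5, 1.0], [1.0, 1.0, 0.5, 0.5])
print(mix(A, x))  # attention output
print(mix(S, x))  # SSM output
```

In a real model the queries, keys, and SSM parameters come from learned projections of the input; fixing them here simply lets the two matrices be compared side by side, which is the comparison the first distillation stage relies on.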

Transforming Quadratic-Time Transformers into Subquadratic Models via MOHAWK

In the paper "Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models," the authors introduce MOHAWK, a method that distills the knowledge of a pretrained Transformer into a subquadratic architecture such as a state space model (SSM). The goal is to reuse the compute already invested in pretraining Transformers while avoiding the quadratic cost of self-attention at inference time.

Main Contributions

  1. MOHAWK Distillation Method: MOHAWK transfers a pretrained Transformer's capabilities to a subquadratic architecture such as an SSM in three stages: Matrix Orientation, which matches the SSM's mixing matrices to the self-attention matrices; Hidden-State Alignment, which matches the hidden states produced by each block; and Weight-Transfer and Knowledge Distillation, which matches the end-to-end predictions. Moving from fine-grained to end-to-end matching gives the student architecture a strong initialization and helps it converge to the teacher's behavior.
  2. Phi-Mamba Architecture: The authors present Phi-Mamba, a Mamba-2 variant based on the Phi-1.5 architecture and distilled from the pretrained Phi-1.5 Transformer with MOHAWK. It is trained on only 3B tokens (5B tokens for the hybrid variant, Hybrid Phi-Mamba), less than 1% of the data typically used to train models from scratch, yet it performs strongly on standard benchmarks.
  3. Performance Evaluation: Phi-Mamba is evaluated on standard benchmarks including WinoGrande, HellaSwag, and PIQA, among others, where it outperforms prior open-source non-Transformer models of comparable size.
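The three MOHAWK stages can be read as three successively coarser matching objectives. The sketch below writes one plausible loss per stage in plain Python. The specific choices (Frobenius distance between mixing matrices, L2 distance between block outputs, KL divergence between next-token distributions) are assumptions for illustration rather than the paper's exact formulation, and the weight-transfer step of the third stage is not shown.

```python
import math


def frobenius_distance(M_teacher, M_student):
    """Stage 1 (Matrix Orientation), assumed loss: Frobenius distance between mixing matrices."""
    return math.sqrt(sum((t - s) ** 2
                         for row_t, row_s in zip(M_teacher, M_student)
                         for t, s in zip(row_t, row_s)))


def l2_distance(h_teacher, h_student):
    """Stage 2 (Hidden-State Alignment), assumed loss: L2 distance between the outputs of a
    teacher block and a student block, assumed here to be fed the same input hidden states."""
    return math.sqrt(sum((t - s) ** 2 for t, s in zip(h_teacher, h_student)))


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl_divergence(teacher_logits, student_logits):
    """Stage 3 (Knowledge Distillation), assumed loss: KL(p_teacher || p_student)
    between next-token distributions."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


print(frobenius_distance([[1.0, 0.0], [0.5, 0.5]], [[1.0, 0.0], [0.4, 0.6]]))
print(l2_distance([0.2, -0.1, 0.3], [0.25, -0.1, 0.2]))
print(kl_divergence([2.0, 0.5, -1.0], [1.5, 0.7, -0.8]))
```

Each loss would be minimized over the student's parameters in its own stage; the training loops and per-layer bookkeeping are omitted to keep the structure of the objectives visible.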

Technical Depth and Results

The paper examines how SSM mixing matrices can be aligned with self-attention matrices, drawing on the structured-matrix view of state-space duality, under which an SSM's sequence transformation can be written as multiplication by a structured lower-triangular matrix that can approximate the token mixing performed by self-attention. Phi-Mamba reaches competitive accuracy using a fraction of the training tokens required by other state-of-the-art models; for instance, it achieves 71.7% accuracy on WinoGrande.
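The duality can be checked numerically in the simplest setting. Assuming a single channel with a scalar state (a simplification of Mamba-2, whose parameters are input-dependent and whose state is larger), the linear-time recurrence `h_t = a_t h_{t-1} + b_t x_t`, `y_t = c_t h_t` gives the same outputs as multiplying `x` by the lower-triangular matrix with entries `M[i][j] = c_i (a_{j+1} ... a_i) b_j` for `j <= i`. The function names below are illustrative only.

```python
def ssm_recurrent(a, b, c, x):
    """Linear-time scan: h_t = a_t * h_{t-1} + b_t * x_t, y_t = c_t * h_t, with h_{-1} = 0."""
    h, ys = 0.0, []
    for a_t, b_t, c_t, x_t in zip(a, b, c, x):
        h = a_t * h + b_t * x_t
        ys.append(c_t * h)
    return ys


def ssm_matrix_form(a, b, c, x):
    """Quadratic, attention-like form: y_i = sum_{j<=i} c_i * (a_{j+1} * ... * a_i) * b_j * x_j."""
    ys = []
    for i in range(len(x)):
        total, decay = 0.0, 1.0  # decay holds a_{j+1} * ... * a_i
        for j in range(i, -1, -1):
            total += c[i] * decay * b[j] * x[j]
            decay *= a[j]
        ys.append(total)
    return ys


a = [0.9, 0.5, 0.8, 0.7]
b = [1.0, 2.0, 0.5, 1.0]
c = [0.3, 1.0, 1.5, -0.5]
x = [1.0, -1.0, 2.0, 0.5]
print(ssm_recurrent(a, b, c, x))
print(ssm_matrix_form(a, b, c, x))  # same values up to floating-point rounding
```

The matrix form costs O(T^2) operations, like attention, which is what makes a direct comparison with an attention matrix possible; the recurrence computes the same outputs in O(T), which is why the distilled model can run subquadratically at inference.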

Ablation studies indicate that every MOHAWK stage contributes: each one brings the distilled model closer to the teacher Transformer, and all three are needed to reach the final results. The method also extends to hybrid architectures that interleave self-attention and SSM layers, as in Hybrid Phi-Mamba.
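A hybrid stack can be pictured as a per-layer choice of token mixer. The toy sketch below uses scalar features and residual connections, with no MLPs or normalization, and applies either causal attention or a scalar SSM at each layer according to a plan. The plan, the fixed SSM parameters, and all function names are hypothetical and do not describe the actual layer layout of Hybrid Phi-Mamba.

```python
import math


def attention_mixer(x):
    """Causal softmax attention with scalar features, using x as query, key and value."""
    out = []
    for i in range(len(x)):
        scores = [x[i] * x[j] for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append(sum(w_j * x[j] for j, w_j in enumerate(w)) / z)
    return out


def ssm_mixer(x, a=0.9, b=1.0, c=1.0):
    """Scalar SSM scan with fixed, hypothetical parameters a, b, c."""
    h, out = 0.0, []
    for x_t in x:
        h = a * h + b * x_t
        out.append(c * h)
    return out


def hybrid_forward(x, layer_types):
    """Apply a stack of mixers with residual connections.

    layer_types is a hypothetical per-layer plan, e.g. ["ssm", "attn", "ssm", "ssm"].
    """
    mixers = {"attn": attention_mixer, "ssm": ssm_mixer}
    for kind in layer_types:
        y = mixers[kind](x)
        x = [x_t + y_t for x_t, y_t in zip(x, y)]
    return x


print(hybrid_forward([1.0, 2.0, 3.0], ["ssm", "attn", "ssm", "ssm"]))
```

Because both mixers are causal, outputs at earlier positions never depend on later tokens, whatever mix of layer types the plan specifies; only the attention layers incur quadratic cost.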

Implications and Future Directions

The method lays groundwork for building subquadratic models that scale efficiently without substantially compromising performance. By showing that knowledge from pretrained Transformers can be carried over into subquadratic architectures, it opens a path to new applications and architectures in settings where computational resources are constrained.

In practice, this matters most in applications where latency and computational cost are critical. The work also bears on the broader question of distillation across architectural paradigms, showing that Transformer capabilities can be preserved in models that do not share the Transformer's structure.

Several future directions follow: extending the framework to architectures beyond SSMs, further refining the alignment and distillation techniques, and studying other types of sequence data. More broadly, methods like MOHAWK that make efficient architectures competitive become more important as machine learning models continue to proliferate and scale under real-world resource constraints.

Authors (5)
  1. Aviv Bick (4 papers)
  2. Kevin Y. Li (3 papers)
  3. Eric P. Xing (192 papers)
  4. J. Zico Kolter (151 papers)
  5. Albert Gu (40 papers)