M1: Towards Scalable Test-Time Compute with Mamba Reasoning Models (2504.10449v1)

Published 14 Apr 2025 in cs.LG

Abstract: Effective reasoning is crucial to solving complex mathematical problems. Recent LLMs have boosted performance by scaling test-time computation through long chain-of-thought reasoning. However, transformer-based models are inherently limited in extending context length due to their quadratic computational complexity and linear memory requirements. In this paper, we introduce a novel hybrid linear RNN reasoning model, M1, built on the Mamba architecture, which allows memory-efficient inference. Our approach leverages a distillation process from existing reasoning models and is further enhanced through RL training. Experimental results on the AIME and MATH benchmarks show that M1 not only outperforms previous linear RNN models but also matches the performance of state-of-the-art Deepseek R1 distilled reasoning models at a similar scale. We also compare our generation speed with a highly performant general purpose inference engine, vLLM, and observe more than a 3x speedup compared to a same size transformer. With throughput speedup, we are able to achieve higher accuracy compared to DeepSeek R1 distilled transformer reasoning models under a fixed generation time budget using self-consistency voting. Overall, we introduce a hybrid Mamba reasoning model and provide a more effective approach to scaling test-time generation using self-consistency or long chain of thought reasoning.

Summary

  • The paper introduces M1, a hybrid Mamba-Transformer reasoning model that sidesteps Transformer scalability limits. It is trained with a multi-stage pipeline of distillation, supervised fine-tuning, and reinforcement learning.
  • The paper demonstrates that M1-3B achieves competitive reasoning performance while delivering over 3x faster token generation for long sequences and large batch sizes.
  • The paper provides a practical training recipe with integrated framework optimizations, including effective handling of grouped query attention and accelerated RL training.

This paper introduces M1, a hybrid reasoning model based on the Mamba architecture, designed to overcome the scalability limitations of Transformer models for tasks requiring long chain-of-thought reasoning, such as complex mathematical problem-solving (2504.10449). Transformers suffer from quadratic computational complexity and linear memory requirements with increasing sequence length, making long context generation and large batch inference inefficient. M1 aims to provide comparable reasoning performance to state-of-the-art models but with significantly improved inference efficiency.
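The memory argument can be made concrete by estimating the decode-time state a batch must keep resident. A Transformer's KV cache grows linearly with sequence length, while a Mamba layer's recurrent state does not. The sketch below is illustrative only. Its dimensions are assumptions chosen to resemble Llama-3.2-3B (28 layers, 8 KV heads, 24 query heads, head_dim 128). They are combined with the M1-3B layout reported later in this summary (6 attention layers out of 28, SSM state size 16), with Mamba heads assumed expanded to the full 24 heads. The function names are hypothetical, fp16 storage is assumed, and the small Mamba convolution state is ignored.

```python
# Illustrative decode-state memory: full-attention Transformer vs. hybrid Mamba model.
# Shapes are assumptions (Llama-3.2-3B-like), not M1's published configuration.

BYTES_PER_ELEM = 2  # assume fp16/bf16 storage


def kv_cache_bytes(n_attn_layers, n_kv_heads, head_dim, seq_len, batch):
    """Keys plus values for every cached token in every attention layer (linear in seq_len)."""
    return 2 * n_attn_layers * n_kv_heads * head_dim * seq_len * batch * BYTES_PER_ELEM


def ssm_state_bytes(n_ssm_layers, n_heads, head_dim, state_size, batch):
    """Recurrent SSM state of n_heads * head_dim * state_size per layer (constant in seq_len).

    The short convolution state that Mamba also keeps is ignored here.
    """
    return n_ssm_layers * n_heads * head_dim * state_size * batch * BYTES_PER_ELEM


def compare(seq_len, batch, n_layers=28, n_attn=6, n_kv_heads=8, n_heads=24,
            head_dim=128, d_state=16):
    """Return (full-attention bytes, hybrid bytes) for one decoding batch at seq_len."""
    full = kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch)
    hybrid = kv_cache_bytes(n_attn, n_kv_heads, head_dim, seq_len, batch) + ssm_state_bytes(
        n_layers - n_attn, n_heads, head_dim, d_state, batch
    )
    return full, hybrid


if __name__ == "__main__":
    full, hybrid = compare(seq_len=4096, batch=512)
    print(f"full attention: {full / 2**30:.1f} GiB, hybrid: {hybrid / 2**30:.1f} GiB")
```

Under these assumptions, the script reports about 224 GiB for full attention and about 49 GiB for the hybrid at batch 512 and 4,096 tokens. Doubling seq_len doubles the attention terms, while the SSM term stays fixed.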

M1 Training Pipeline

The M1 model is developed through a three-stage process:

  1. Distillation:
    • The process starts by distilling knowledge from a pre-trained Transformer model (Llama3.2-3B-Instruct) into a hybrid Mamba-Transformer architecture.
    • It adapts the MambaInLlama framework [wang2025mamballamadistillingaccelerating], initializing the Mamba layer projections ($\mathbf{A}, \mathbf{B}, \mathbf{C}, \mathbf{D}$) from the corresponding Transformer attention projections ($\mathbf{K}, \mathbf{Q}, \mathbf{V}, \mathbf{O}$).
    • Implementation Detail: The source Transformer uses Grouped Query Attention (GQA). To handle this, two additional linear layers are introduced in the Mamba blocks to project from head_dim * kv_head to head_dim * n_head. This expansion increases the expressiveness of Mamba's $\mathbf{B}$ (input projection) and $\mathbf{X}$ (input sequence) parameters. Because Mamba keeps no KV cache, the memory savings that motivate GQA do not apply, so expanding to the full head count carries no cache-size penalty.
    • The distillation uses the reverse KL divergence, $D_{KL}(p_{student} \| p_{teacher})$, as the loss function, optimized with AdamW.
    • Practical Implementation: Training utilizes the Axolotl framework, employing data packing (merging sequences up to a max length of 8192 tokens) for efficiency, chat templates, and masking user prompts so loss is only computed on the assistant's response.
  2. Supervised Fine-Tuning (SFT):
    • Phase 1 (General Math): The distilled model is fine-tuned on the OpenMathInstruct-2 dataset [toshniwal2024openmathinstruct2acceleratingaimath] for 2 epochs to enhance general mathematical capabilities. Training setup mirrors the distillation stage.
    • Phase 2 (Reasoning Data): Further fine-tuning is performed on a mixed dataset (8 billion tokens total) comprising reasoning traces generated by DeepSeek R1 series models (e.g., OpenR1-Math-220k, OpenThoughts-114k-math).
    • Implementation Detail: The maximum sequence length for this stage is increased to 24,576 tokens to accommodate long reasoning chains (covering 99% of the data). Training runs for 5 epochs with a lower learning rate ($6 \times 10^{-6}$).
  3. Reinforcement Learning (RL) for Reasoning:
    • The model's reasoning ability is further boosted with RL, specifically the GRPO (Group Relative Policy Optimization) algorithm within the VeRL framework [sheng2024hybridflow].
    • Implementation Detail: The authors integrated Mamba generation into VeRL, resolving CUDA graph incompatibility issues with FSDP, leading to a 5x speedup in Mamba generation during RL training compared to running without CUDA graph optimizations.
    • The GRPO loss function is modified by removing the KL penalty term (found to destabilize training) and adding an entropy bonus ($\eta \, H(\pi_{\theta})$) to encourage policy diversity.
    • RL training uses a batch size of 128, a PPO batch size of 64, 8 sampled trajectories per prompt, and a maximum generation length of 32k tokens. Training runs for 50 steps, and the checkpoint with the highest critic reward is selected. The prompt "Let's think step by step..." is appended during both training and evaluation.
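The two losses named in the pipeline can be sketched in plain Python. The first is the reverse KL distillation objective at a single token position. The second is a GRPO-style clipped surrogate with group-relative advantages, no KL penalty, and an entropy bonus. This is a minimal illustration under assumptions, not the authors' VeRL code. clip_eps and eta are hypothetical defaults, and the group standard deviation uses the population form. Every token in a rollout is assumed to inherit that rollout's advantage.

```python
import math
from statistics import mean, pstdev


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def reverse_kl(student_logits, teacher_logits):
    """D_KL(p_student || p_teacher) at one token position, summed over the vocabulary.

    The expectation is taken under the student, which makes the loss mode-seeking.
    """
    log_ps = log_softmax(student_logits)
    log_pt = log_softmax(teacher_logits)
    return sum(math.exp(ls) * (ls - lt) for ls, lt in zip(log_ps, log_pt))


def group_advantages(rewards, eps=1e-6):
    """GRPO advantage: standardize each rollout's reward within its prompt's group."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def grpo_loss(logp_new, logp_old, advantages, entropies, clip_eps=0.2, eta=1e-3):
    """Clipped surrogate with no KL penalty, plus an entropy bonus eta * H(pi_theta).

    All lists are per token; each token carries its rollout's advantage.
    Returns a loss to minimize.
    """
    surrogate = []
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)
        clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)
        surrogate.append(min(ratio * adv, clipped * adv))
    # Negate to turn maximization into minimization; subtracting the entropy
    # term means that minimizing the loss pushes policy entropy up.
    return -mean(surrogate) - eta * mean(entropies)


if __name__ == "__main__":
    # Hypothetical binary correctness rewards for 4 rollouts of one prompt.
    print(group_advantages([1.0, 0.0, 0.0, 1.0]))
```

Dropping the KL term means nothing anchors the policy to a reference model. The entropy bonus is then the only explicit force against the policy collapsing onto a few outputs.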

Performance and Evaluation

  • Model Architecture: M1-3B uses 6 interleaved attention layers within 28 total layers, with a Mamba SSM state size of 16.
  • Reasoning Benchmarks: Evaluated on MATH500, AIME25, AIME24, AMC23, and OlympiadBench.
  • Results: M1-3B achieves performance comparable to the DeepSeek-R1-Distill-Qwen-1.5B model on most benchmarks (e.g., 81.7 vs 83.9 Pass@1 on MATH500). This is noteworthy as the baseline Qwen model was trained on significantly more math-specific data (>1T tokens).
  • Inference Speed:
    • Benchmarked against Llama-3.2-3B (same size Transformer) and DeepSeek-R1-Distill-Qwen-1.5B using vLLM on an H100 GPU.
    • Finding: M1-3B shows a >3x speedup in token generation throughput compared to the Llama-3.2-3B model at a large batch size (512) and long decoding length (4096 tokens).
    • The speed advantage increases with batch size and sequence length, attributed to Mamba's memory efficiency (smaller state, no KV cache) making decoding less memory-bound.
  • Test-Time Compute Scaling:
    • The inference speedup allows M1 to generate more samples (for self-consistency/majority voting) or longer reasoning chains within a fixed time budget.
    • Experiments show that when normalizing performance by generation time (seconds), M1 can achieve higher accuracy than the baseline Transformer (DeepSeek-R1-Distill-Qwen-1.5B) by either generating more samples for majority voting or generating longer sequences (Figure 4).
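A rough sketch of the fixed-budget comparison follows. With higher aggregate decode throughput, more complete samples fit in the same wall-clock budget. Those answers are then combined by majority vote (self-consistency). The throughput figures and function names are hypothetical. The model also ignores prefill time and the way throughput varies with batch size.

```python
from collections import Counter


def samples_in_budget(budget_s, tokens_per_sample, tokens_per_s):
    """Number of complete samples that fit in a wall-clock budget.

    Simplifying assumption: tokens_per_s is the aggregate decode throughput
    across the batch, and prefill time is ignored.
    """
    return int(budget_s * tokens_per_s // tokens_per_sample)


def majority_vote(answers):
    """Self-consistency: return the most frequent final answer.

    Counter.most_common breaks ties by first occurrence.
    """
    return Counter(answers).most_common(1)[0][0]


if __name__ == "__main__":
    # Hypothetical throughputs: a baseline and a model that is 3x faster.
    for name, tps in [("baseline", 1000), ("3x faster", 3000)]:
        print(name, samples_in_budget(60, 4096, tps))
    print(majority_vote(["42", "41", "42", "7", "42"]))
```

With a 60 s budget and 4,096-token samples, 1,000 tokens/s fits 14 samples, while 3,000 tokens/s fits 43. The vote therefore draws on roughly three times as many candidate answers.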

Analysis and Implementation Insights

  • Impact of RL Training Length: Increasing the maximum sequence length during RL training (up to 24k tokens) significantly improves reasoning performance (Figure 5). M1's efficiency makes training with such long sequences feasible.
  • Training Stage Ablation: Each stage (Distillation, SFT-Math, SFT-Reasoning, RL) contributes progressively to the final performance, with SFT on reasoning data providing a substantial boost (Table 2).
  • Distillation Strategy: Directly distilling from a reasoning model (DeepSeek-R1-Distill-Qwen-1.5B) yielded poor results given the limited reasoning dataset (8B tokens). The staged approach proved more effective for transferring reasoning capabilities across architectures with less specialized data: first distill a general math model, then fine-tune it on reasoning data.

Practical Takeaways

  • M1 demonstrates that hybrid Mamba-Transformer models can achieve strong reasoning performance comparable to specialized Transformer-based reasoning models.
  • The primary advantage of M1 is its significantly higher inference throughput (>3x) for long sequences and large batches, enabled by Mamba's linear time complexity and memory efficiency.
  • This speedup directly translates into improved performance under fixed computational budgets for test-time scaling techniques like self-consistency (majority voting) or generating longer chains of thought.
  • The paper provides a practical multi-stage training recipe (Distillation -> SFT (General Math) -> SFT (Reasoning Data) -> RL) for building high-performing hybrid reasoning models, particularly when specialized reasoning data is limited.
  • The successful integration with frameworks like Axolotl and VeRL (with modifications for Mamba/CUDA graph) highlights the feasibility of incorporating Mamba-based models into existing LLM training pipelines.
  • The efficiency gains are particularly relevant for RL training, where generating long rollouts is often a bottleneck. M1's architecture can alleviate this.