- The paper introduces Routing Mamba (RoM), which scales State Space Models (SSMs) by applying sparse Mixture-of-Experts to projection layers via a shared routing mechanism.
- RoM matches or exceeds the performance of dense Mamba models that use up to 2.3× more active parameters, demonstrating significant parameter efficiency.
- The RoM framework enables efficient inference by activating only selected experts per token, maintaining constant time per token and supporting both pure SSM and hybrid architectures.
Routing Mamba: Scaling State Space Models with Mixture-of-Experts Projection
The paper presents Routing Mamba (RoM), a framework for scaling State Space Models (SSMs), specifically Mamba, with a sparse Mixture-of-Experts (MoE) approach applied to the linear projection layers. The work addresses the challenge of efficiently increasing the expressive power of SSMs, which have recently emerged as competitive alternatives to Transformers for long-sequence modeling but have not previously benefited from MoE scaling in the way Transformer-based models have.
Technical Contributions
RoM introduces several key innovations:
- Sparse Mixture of Linear Projection Experts: Instead of applying MoE to feed-forward layers (as is standard in Transformer MoE models), RoM applies MoE to the main projection layers of Mamba: Conv Proj, Gate Proj, and Out Proj. This is motivated by the absence of large FFN blocks in SSMs and the centrality of these projections to SSM computation.
- Shared Routing Mechanism: RoM employs a single router per block to select experts for all three projection layers, ensuring that a token is processed by a coherent set of experts across the block. This contrasts with naive approaches that use independent routers for each projection, which the authors show leads to degraded performance and increased latency.
- Selective Expertization: For smaller or specialized projections (e.g., x Proj, dt Proj, Conv1D), parameters are shared across experts, inspired by Multi-Query Attention, to avoid unnecessary parameter and compute overhead (a parameter-layout sketch follows this list).
- No Load Balancing Loss Required: Empirical results show that RoM achieves balanced expert utilization without explicit load balancing loss, simplifying training and reducing auxiliary loss tuning.
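To make selective expertization concrete, here is a minimal parameter-layout sketch in PyTorch. The class name, dimensions, kernel size, and initialization are illustrative assumptions rather than the paper's configuration; the point is only which projections receive per-expert weights and which remain shared:

import torch
import torch.nn as nn

class RoMProjections(nn.Module):
    # Hypothetical layout: per-expert weights for the three major projections,
    # shared weights for the smaller Mamba-specific projections.
    def __init__(self, d_model=256, d_inner=512, d_state=16, dt_rank=16, num_experts=8):
        super().__init__()
        # One shared router produces expert scores used by the Conv, Gate, and Out projections alike
        self.router = nn.Linear(d_model, num_experts, bias=False)
        # Expert-specific: Conv Proj, Gate Proj, Out Proj
        self.conv_proj = nn.Parameter(torch.randn(num_experts, d_model, d_inner) * 0.02)
        self.gate_proj = nn.Parameter(torch.randn(num_experts, d_model, d_inner) * 0.02)
        self.out_proj = nn.Parameter(torch.randn(num_experts, d_inner, d_model) * 0.02)
        # Shared across experts: x Proj, dt Proj, Conv1D (selective expertization)
        self.x_proj = nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False)
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
        self.conv1d = nn.Conv1d(d_inner, d_inner, kernel_size=4, groups=d_inner, padding=3)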
Empirical Results
The paper provides extensive experimental validation:
- Parameter Efficiency: RoM achieves equivalent or better perplexity compared to dense Mamba models that require up to 2.3× more active parameters. For example, a RoM model with 1.3B active parameters (10B total) matches the performance of a dense Mamba model with over 3B active parameters.
- Computational Savings: When applied to hybrid SSM-attention models (e.g., Samba), RoM yields a 23% reduction in FLOPs for similar performance, demonstrating practical efficiency gains.
- Length Extrapolation: RoM maintains consistent perplexity across a range of context lengths, including those much longer than seen during training, indicating robust generalization for long-sequence tasks.
- Hybrid Architectures: RoM can be combined with FFN-MoE in hybrid models, matching or exceeding the performance of pure FFN-MoE baselines at similar or lower parameter counts.
- Throughput: Despite increased total parameter count, RoM models achieve ~80% of the training throughput of dense models with the same number of active parameters, without expert parallelism optimizations.
Implementation Considerations
Model Architecture
RoM modifies the standard Mamba block roughly as follows (simplified PyTorch sketch of the routed projections):
import torch
import torch.nn.functional as F

def rom_block(x, router_weights, expert_weights, ssm_forward, k=2):
    # Shared routing: one router scores the experts used by the Conv, Gate, and Out projections
    routing_scores = F.softmax(x @ router_weights, dim=-1)                   # [batch, seq, num_experts]
    _, topk_idx = routing_scores.topk(k, dim=-1)                             # top-K experts per token
    mask = F.one_hot(topk_idx, routing_scores.size(-1)).sum(-2).to(x.dtype)  # [batch, seq, num_experts]
    # Apply the selected experts to the Conv and Gate projections (dense reference computation)
    h = torch.einsum('bse,bsd,edh->bsh', mask, x, expert_weights['conv'])
    g = F.silu(torch.einsum('bse,bsd,edh->bsh', mask, x, expert_weights['gate']))
    y = ssm_forward(h)                                                       # SSM computation as in Mamba
    # Combine the experts' Out projections, weighted by the shared router's probabilities
    o = torch.einsum('bse,bsh,ehd->bsd', routing_scores * mask, y * g, expert_weights['out'])
    return o
- Router: A small MLP or linear layer projects the input to routing logits. The router is shared across the three main projections.
- Expert Weights: Each expert has its own set of projection matrices for Conv, Gate, and Out. For minor projections, weights are shared. Tensor shapes are illustrated in the toy example after this list.
- Sparse Activation: Only the top-K experts are activated per token, reducing compute and memory.
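As a quick sanity check, the block sketch above can be exercised with toy tensors; all dimensions below are arbitrary, and an identity function stands in for the actual Mamba SSM computation:

import torch

batch, seq, d_model, d_inner, num_experts = 2, 16, 256, 512, 8
x = torch.randn(batch, seq, d_model)
router_weights = torch.randn(d_model, num_experts) * 0.02
expert_weights = {
    'conv': torch.randn(num_experts, d_model, d_inner) * 0.02,
    'gate': torch.randn(num_experts, d_model, d_inner) * 0.02,
    'out': torch.randn(num_experts, d_inner, d_model) * 0.02,
}
# The identity stands in for the SSM scan; a real Mamba block runs its selective scan here
out = rom_block(x, router_weights, expert_weights, ssm_forward=lambda h: h, k=2)
print(out.shape)  # torch.Size([2, 16, 256])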
Training
- Framework: PyTorch with Fully Sharded Data Parallel (FSDP) and MegaBlocks for efficient grouped GEMM operations.
- No Expert Parallelism: All experts reside on the same device, avoiding token dropping and capacity factors, which simplifies implementation and improves stability.
- Optimization: AdamW with standard hyperparameters; no explicit load balancing loss is required (see the minimal training-step sketch after this list).
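The absence of an auxiliary loss can be illustrated with a minimal single-device training step that reuses the toy tensors above. The objective below is a stand-in regression loss purely to exercise the backward pass, the hyperparameters are illustrative rather than the paper's, and FSDP/MegaBlocks are omitted:

import torch
import torch.nn.functional as F

params = [router_weights, *expert_weights.values()]
for p in params:
    p.requires_grad_(True)
optimizer = torch.optim.AdamW(params, lr=4e-4, betas=(0.9, 0.95), weight_decay=0.1)

out = rom_block(x, router_weights, expert_weights, ssm_forward=lambda h: h, k=2)
loss = F.mse_loss(out, torch.zeros_like(out))  # stand-in objective; no load-balancing term is added
loss.backward()                                # the router trains through the score-weighted Out combination
optimizer.step()
optimizer.zero_grad()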
Deployment
- Inference: Only the selected experts are activated per token, maintaining constant-time inference per token and enabling efficient deployment on hardware with limited memory bandwidth (a decode-step sketch follows this list).
- Scalability: RoM is compatible with both pure SSMs and hybrid SSM-attention models, and can be extended to other SSM variants (e.g., Mamba2, Gated DeltaNet).
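The constant per-token cost is easiest to see in a decode-time sketch: for a single token, only the K selected experts' projection matrices are gathered and multiplied, so the work does not grow with the total number of experts. The function below is a hypothetical companion to rom_block, with ssm_step standing in for Mamba's recurrent state update:

import torch
import torch.nn.functional as F

def rom_decode_step(x_t, router_weights, expert_weights, ssm_step, k=2):
    # x_t: [batch, d_model], one token per sequence during autoregressive decoding
    scores = F.softmax(x_t @ router_weights, dim=-1)          # [batch, num_experts]
    topk_scores, topk_idx = scores.topk(k, dim=-1)            # [batch, K]
    # Gather only the K selected experts' matrices; cost is independent of num_experts
    w_conv = expert_weights['conv'][topk_idx]                 # [batch, K, d_model, d_inner]
    w_gate = expert_weights['gate'][topk_idx]
    w_out = expert_weights['out'][topk_idx]                   # [batch, K, d_inner, d_model]
    h = torch.einsum('bd,bkdh->bh', x_t, w_conv)              # sum of the selected Conv projections
    g = F.silu(torch.einsum('bd,bkdh->bh', x_t, w_gate))
    y = ssm_step(h)                                           # recurrent SSM update, as in Mamba decoding
    return torch.einsum('bk,bh,bkhd->bd', topk_scores, y * g, w_out)

With the toy tensors above, rom_decode_step(x[:, -1], router_weights, expert_weights, ssm_step=lambda h: h) returns a [batch, d_model] output for the latest token.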
Comparative Analysis
The paper provides a detailed ablation of naive MoE integration strategies, showing that independent routing for each projection layer leads to performance degradation. The shared routing mechanism is empirically superior, aligning with best practices in MoE-MLP integration in Transformers.
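For contrast, the naive variant gives each projection its own router, so a token's Conv, Gate, and Out projections can land on unrelated experts. The sketch below is hypothetical and follows the same conventions as rom_block above; it is meant only to make the ablated design explicit:

import torch
import torch.nn.functional as F

def naive_independent_routing(x, routers, expert_weights, ssm_forward, k=2):
    # Three separate routers: each projection selects its own experts per token
    def select(w_router):
        scores = F.softmax(x @ w_router, dim=-1)
        _, topk_idx = scores.topk(k, dim=-1)
        return F.one_hot(topk_idx, scores.size(-1)).sum(-2).to(x.dtype), scores
    m_conv, _ = select(routers['conv'])
    m_gate, _ = select(routers['gate'])
    m_out, s_out = select(routers['out'])
    h = torch.einsum('bse,bsd,edh->bsh', m_conv, x, expert_weights['conv'])
    g = F.silu(torch.einsum('bse,bsd,edh->bsh', m_gate, x, expert_weights['gate']))
    y = ssm_forward(h)
    return torch.einsum('bse,bsh,ehd->bsd', s_out * m_out, y * g, expert_weights['out'])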
RoM also outperforms attention-based MoE variants (e.g., MoA, SwitchHead) and simple width expansion of SSM layers in both perplexity and computational efficiency.
Implications and Future Directions
Practical Implications:
- Efficient Scaling: RoM enables SSMs to scale to tens of billions of parameters with sparse activation, making them viable for large-scale language modeling tasks previously dominated by Transformer MoEs.
- Hardware Efficiency: The approach is well-suited to modern accelerators, leveraging grouped GEMM and avoiding the communication overhead of expert parallelism.
- Generalization: The shared routing strategy may inform future MoE designs in other architectures, including attention and hybrid models.
Theoretical Implications:
- Blockwise Expertization: The results suggest that holistic expertization of functionally cohesive blocks (e.g., all major projections in an SSM block) is more effective than piecemeal expertization.
- Routing Synergy: Shared routing across interdependent projections fosters coherent specialization and improves training stability.
Future Work:
- Broader Applicability: Extending RoM to other SSM variants and to architectures with more complex interleaving of SSM and attention layers.
- Expert Parallelism: Investigating distributed expert placement and communication-efficient routing for even larger models.
- Dynamic Routing: Exploring more sophisticated routing mechanisms, including context-aware or hierarchical routers.
Conclusion
Routing Mamba provides a principled and empirically validated approach to scaling SSMs with sparse MoE, overcoming the limitations of naive MoE integration. Its shared routing mechanism and selective expertization yield strong performance and efficiency gains, positioning SSMs as a scalable and practical alternative to Transformer-based MoE models for long-sequence modeling. The framework opens new avenues for efficient, large-scale sequence modeling in both research and production settings.