MatMamba: Nested SSM Architecture
- MatMamba is a nested state space model architecture that integrates Matryoshka representation learning with a Mamba2 backbone.
- It enables dynamic extraction of multiple submodels from a single trained network, while maintaining competitive performance on vision and language tasks.
- The design supports elastic and adaptive inference, allowing efficient scaling across diverse computational budgets without retraining.
MatMamba is a state space model (SSM) architecture that integrates Matryoshka representation learning principles into the Mamba2 SSM backbone. The model is designed to achieve “nested” or adaptive model capacities within a single, jointly trained network, enabling efficient elastic inference and scaling across a wide range of computational budgets. MatMamba provides a universal model from which multiple smaller submodels—each defined by a parameter “slice”—can be extracted post-training, without sacrificing performance relative to independently trained models of the same size. This approach is inspired by prior work on Matryoshka Transformer models (MatFormer) and exhibits competitive performance on vision (ImageNet-1K) and language modeling (FineWeb) tasks compared to Transformers, while offering faster inference for long context lengths (Shukla et al., 2024).
1. Mathematical Foundation and Model Structure
A MatMamba layer is fundamentally a Mamba2 SSM block parameterized to support multiple explicit granularities. For a given granularity $g$, the recurrence equations are:

$$h_t = A_t h_{t-1} + B_t x_t, \qquad y_t = C_t^{\top} h_t,$$

where:
- $h_t$ is the hidden state at time $t$,
- $x_t$ is an input projection,
- $A_t$, $B_t$, $C_t$, $\Delta_t$ are nested parameter matrices.
These matrices are realized by slicing larger parameter tensors (e.g., $W^{(g)} = W[{:}\,d_g]$) to yield all necessary submodel weights. Each submodel corresponds to a prefix of the full model's parameters, guaranteeing the nested property $\theta_1 \subset \theta_2 \subset \cdots \subset \theta_G$, where $G$ is the number of slices and $\theta_G$ recovers the full model.
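The prefix-slicing mechanism can be illustrated with a minimal NumPy sketch; tensor names and shapes here are illustrative assumptions, not the reference implementation:

```python
import numpy as np

# Illustrative full-model projection tensor (shapes are assumptions).
d_full, d_in = 1024, 768
rng = np.random.default_rng(0)
W_full = rng.standard_normal((2 * d_full, d_in))  # e.g., a fused input projection

def slice_proj(W, d_g):
    """Prefix-slice the output rows: a submodel's weights are the first
    rows of the full tensor, so every slice shares the same parameters."""
    return W[: 2 * d_g, :]

W_half = slice_proj(W_full, 512)
print(W_half.shape)  # (1024, 768)
# The slice is a view, not a copy: submodels reuse the full model's storage.
print(np.shares_memory(W_half, W_full))  # True
```

Because every submodel is a view into the same storage, a single checkpoint serves all granularities.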
MatMamba layers compute the following sequence of operations for each slice:

$$[z, x, B, C, \Delta] = W_{\mathrm{in}}\, u, \quad \bar{x} = \sigma(\mathrm{conv1d}(x)), \quad y = \mathrm{SSM}(A, B, C, \Delta)(\bar{x}), \quad o = W_{\mathrm{out}}\,(y \odot \sigma(z)),$$

where $\sigma$ is the SiLU nonlinearity and $[\cdot]$ denotes concatenation.
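A toy forward pass for one slice might look as follows. This is a deliberately simplified single-head sketch: the real Mamba2 block uses a causal conv1d, multi-head structure, and input-dependent $\Delta$, all omitted here, and the scan below is a plain diagonal recurrence standing in for the selective scan:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def matmamba_block(u, W_in, W_out, A, d_g):
    """Run one slice of width d_g: slice projections, gated linear
    recurrence, output projection. Purely illustrative."""
    L = u.shape[0]
    zx = u @ W_in[:, : 2 * d_g]          # sliced input projection
    z, x = zx[:, :d_g], zx[:, d_g:]      # gate and SSM input
    a = np.exp(-np.abs(A[:d_g]))         # stable per-channel decay
    h = np.zeros(d_g)
    ys = np.zeros((L, d_g))
    for t in range(L):                   # stand-in for the selective scan
        h = a * h + x[t]
        ys[t] = h
    y = ys * silu(z)                     # SiLU-gated output
    return y @ W_out[:d_g, :]            # sliced output projection
```

Note that every tensor is sliced to the same `d_g`, so the same function serves any granularity of the shared parameters.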
2. Nested Parameterization and Matryoshka Learning
Nested parameterization in MatMamba enforces that all learnable tensors (input projections, head-specific weights, SSM recurrence matrices, convolution filters, output projections) are sliced in index order to define multiple submodels. For chosen granularities $g = 1, \dots, G$ with corresponding dimensions $d_1 < d_2 < \cdots < d_G$, each sliced tensor $W^{(g)}$ is the prefix $W[{:}\,d_g]$ of the full tensor. During training, every forward pass produces $G$ outputs, one for each slice, and joint optimization is performed over all slices.
3. Training Paradigm
MatMamba utilizes joint optimization for all slices in a single training regime:

$$\mathcal{L} = \sum_{g=1}^{G} \mathcal{L}\big(\hat{y}_g, y\big),$$

where $\hat{y}_g$ is the output of the $g$-th slice. Training runs all slices forward with gradient accumulation, then one backward update, and retains memory usage comparable to that of the largest slice (i.e., a single Mamba2 base model).
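As a sketch, the joint objective can be computed by averaging a per-slice loss over all $G$ outputs (equal slice weights are an assumption here; the paper's exact weighting may differ):

```python
import numpy as np

def joint_slice_loss(logits_per_slice, target_onehot):
    """Mean cross-entropy over all slice outputs: every slice of the
    shared network is trained on the same targets in one forward pass."""
    losses = []
    for logits in logits_per_slice:
        shifted = logits - logits.max(axis=-1, keepdims=True)  # stable softmax
        p = np.exp(shifted)
        p /= p.sum(axis=-1, keepdims=True)
        losses.append(-(target_onehot * np.log(p + 1e-12)).sum(axis=-1).mean())
    return float(np.mean(losses))
```

Because gradients from all slices flow into the same shared tensors, one optimizer step updates every submodel at once.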
Optimization regimes adhere to established recipes:
- Vision (ImageNet-1K): AdamW with cosine learning-rate decay, weight decay $0.1$, warmup $10$k steps, batch size $4096$/$8192$, dropout $0.1$, stochastic depth $0.1$.
- Language (FineWeb): AdamW with $2$k warmup steps and weight decay $0.1$.
No curriculum or progressive shrinking is used; all slices are trained jointly from scratch.
4. Elastic and Adaptive Inference
MatMamba enables adaptive deployment by runtime selection of parameter “slices,” providing compute-accuracy elasticity:
- Fixed-slice inference: selecting a single global dimension $d_g$, the submodel runs identically to a Mamba2 of the same width.
- “Mix’n’Match” inference: per-layer selection of slice dimensions or intermediate values (within head-divisibility constraints), supporting layer-wise adaptation.
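One plausible way to realize Mix'n'Match is a simple heuristic that shrinks per-layer widths until a FLOP budget is met; the greedy shrink-from-the-top policy below is an illustrative assumption, not the paper's selection procedure:

```python
def mix_n_match_dims(num_layers, budget_frac, d_full=1024, d_head=64):
    """Return per-layer widths whose average fraction of d_full is at most
    budget_frac, halving later layers first and keeping every width a
    multiple of d_head (the head-divisibility constraint)."""
    dims = [d_full] * num_layers
    i = num_layers - 1
    while i >= 0 and sum(dims) / (num_layers * d_full) > budget_frac:
        dims[i] = max(d_head, (dims[i] // 2) // d_head * d_head)
        i -= 1
    return dims

print(mix_n_match_dims(4, 0.75))  # [1024, 1024, 512, 512]
```

Any such per-layer assignment is valid at runtime because every layer's parameters are nested along the same prefix ordering.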
For a given sequence length $L$ and selected inner dimension $d_g$, compute scales linearly in sequence length, roughly $O(L \cdot d_g^2)$, contrasted with Transformer self-attention's $O(L^2 \cdot d)$. For $L \gg d$, MatMamba offers substantial speed and memory advantages.
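A back-of-envelope comparison makes the asymptotics concrete (constant factors and exact per-layer FLOP counts are omitted; these cost models are rough assumptions):

```python
def ssm_cost(L, d):
    """Sequence-mixing cost linear in sequence length: ~O(L * d^2)."""
    return L * d * d

def attn_cost(L, d):
    """Self-attention cost quadratic in sequence length: ~O(L^2 * d)."""
    return L * L * d

L, d = 16384, 1024
print(attn_cost(L, d) / ssm_cost(L, d))  # 16.0 -- the ratio is simply L / d
```

Doubling the context doubles the advantage, which is why the gap widens for long sequences.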
5. Empirical Performance and Comparative Evaluation
Performance assessments span vision and language domains:
- ImageNet-1K (Patch-16, 20-Layer Model): Slices at $d = 1024$, $512$, $256$, and $128$ retain classification accuracy closely matching independent Mamba2 baselines at each size, with only a small top-1 difference. “Mix’n’Match” slices interpolate between, or slightly exceed, the trained anchors.
- Image Retrieval (1-NN on CLS embeddings): Submodels derived from a common model (e.g., 135M parameters) preserve the learned metric space (a 0.5% drop in 1-NN accuracy for a 55% FLOP reduction), outperforming independently trained small Mamba2 baselines.
- Language Modeling (FineWeb, Decoder LM): Final validation losses for each slice show uniform, predictable scaling (roughly $0.4$ nats per halving of width). Each full-size slice matches a separately trained Mamba2.
Table: Example Results (ImageNet-1K)
| Full Dim | Params | Top-1 | Half-slice Dim | Params | Top-1 |
|---|---|---|---|---|---|
| 1024 | 132.7M | 81.9% | 512 | 69.0M | 76.8% |
| 256 | 37.1M | 72.3% | 128 | 21.2M | 67.4% |
A similar table is reported for language modeling losses. Each explicit slice consistently approximates the performance of an independent model, with layer-wise “Mix’n’Match” enabling resource-adaptive trade-offs (Shukla et al., 2024).
6. Practical Deployment and Guidance
MatMamba’s “nested” property supports a spectrum of deployment scenarios:
- Edge devices: select the smallest available slice (e.g., $d = 128$ or $d = 256$) for minimal latency.
- Backend servers: utilize the full model for maximal accuracy.
- Speculative decoding: perform rapid draft predictions at intermediate widths, then verify with higher capacity.
Inference latency is directly governed by the selected slice. It is recommended to pick the minimal adequate dimension $d_g$ whose latency satisfies the target budget, ensuring head divisibility (i.e., $d_g \bmod d_{\mathrm{head}} = 0$). “Mix’n’Match” enables any intermediate granularity (not only anchors), although anchor slices are recommended for maximal consistency.
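This guidance can be sketched as a small selection helper; the `accuracy` and `latency_ms` estimators are hypothetical caller-supplied functions (e.g., profiled offline), not part of MatMamba itself:

```python
def pick_slice(anchors, d_head, acc_floor, budget_ms, accuracy, latency_ms):
    """Smallest anchor width that is adequate (estimated accuracy >=
    acc_floor), meets the latency budget, and is divisible by d_head."""
    for d in sorted(anchors):
        if d % d_head == 0 and accuracy(d) >= acc_floor and latency_ms(d) <= budget_ms:
            return d
    return None  # no anchor satisfies the constraints

# Example using the ImageNet top-1 numbers from the table above as a
# stand-in accuracy estimator, with a made-up latency model.
top1 = {128: 0.674, 256: 0.723, 512: 0.768, 1024: 0.819}
d = pick_slice((128, 256, 512, 1024), 64, 0.75, 5.0,
               accuracy=top1.get, latency_ms=lambda d: d / 300)
print(d)  # 512
```

The same helper extends naturally to per-layer budgets when using Mix'n'Match rather than a single global width.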
This suggests MatMamba is particularly suited for environments demanding dynamic or heterogeneous compute allocation, such as multi-device inference, federated learning settings, or scalable service backends.
7. Comparison with Related Approaches
MatMamba submodels precisely interpolate between anchor model sizes, in contrast with retrained Mamba2 or Transformer baselines, which would require separate full training runs per configuration. Empirically, submodels extracted from MatMamba match the performance of separately trained SSMs to within a small margin in top-1 accuracy and LM validation loss, which pure width-pruning of Transformers does not achieve.
This Matryoshka-style elasticity was not previously realized in SSM frameworks and augments the deployment flexibility of the Mamba2 architecture. For high-performance long-sequence modeling, MatMamba provides a scalable universal model that obviates the need for retraining for each deployment scale, supporting dozens or potentially hundreds of slicing points within a single consistent metric space (Shukla et al., 2024).