
MatMamba: Nested SSM Architecture

Updated 21 February 2026
  • MatMamba is a nested state space model architecture that integrates Matryoshka representation learning with a Mamba2 backbone.
  • It enables dynamic extraction of multiple submodels from a single trained network while maintaining competitive performance on vision and language tasks.
  • The design supports elastic and adaptive inference, allowing efficient scaling across diverse computational budgets without retraining.

MatMamba is a state space model (SSM) architecture that integrates Matryoshka representation learning principles into the Mamba2 SSM backbone. The model is designed to achieve “nested” or adaptive model capacities within a single, jointly trained network, enabling efficient elastic inference and scaling across a wide range of computational budgets. MatMamba provides a universal model from which multiple smaller submodels—each defined by a parameter “slice”—can be extracted post-training, without sacrificing performance relative to independently trained models of the same size. This approach is inspired by prior work on Matryoshka Transformer models (MatFormer) and exhibits competitive performance on vision (ImageNet-1K) and language modeling (FineWeb) tasks compared to Transformers, while offering faster inference for long context lengths (Shukla et al., 2024).

1. Mathematical Foundation and Model Structure

A MatMamba layer is fundamentally a Mamba2 SSM block parameterized to support multiple explicit granularities. For a given granularity $k$, the recurrence equations are:
$$h_t^{(k)} = A^{(k)} h_{t-1}^{(k)} + B^{(k)} x_t, \qquad y_t^{(k)} = C^{(k)} h_t^{(k)} + D^{(k)} x_t,$$
where:

  • $h_t^{(k)} \in \mathbb{R}^{d_{\text{state}}^{(k)}}$ is the hidden state at time $t$,
  • $x_t \in \mathbb{R}^{d_{\text{inner}}^{(k)}}$ is an input projection,
  • $A^{(k)}$, $B^{(k)}$, $C^{(k)}$, $D^{(k)}$ are nested parameter matrices.
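
A minimal PyTorch sketch of this recurrence for one granularity $k$, taking $A$, $B$, $C$, $D$ as dense matrices exactly as written above (Mamba2's actual $A$ and $D$ are more structured in practice); all sizes are arbitrary assumptions:

```python
import torch

def ssm_recurrence(A, B, C, D, x):
    """h_t = A h_{t-1} + B x_t ; y_t = C h_t + D x_t for one granularity.
    Dense matrices as in the equations above. x: (T, d_inner)."""
    h = torch.zeros(A.shape[0])
    ys = []
    for x_t in x:
        h = A @ h + B @ x_t          # state update
        ys.append(C @ h + D @ x_t)   # output projection + skip
    return torch.stack(ys)

n, d = 16, 64                        # assumed d_state, d_inner
A, B = 0.01 * torch.randn(n, n), 0.01 * torch.randn(n, d)
C, D = 0.01 * torch.randn(d, n), torch.eye(d)
x = torch.randn(32, d)
# Half-width submodel: the same function on prefix slices of every tensor.
y_half = ssm_recurrence(A[:n // 2, :n // 2], B[:n // 2, :d // 2],
                        C[:d // 2, :n // 2], D[:d // 2, :d // 2], x[:, :d // 2])
```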

These matrices are realized by slicing larger parameter tensors (e.g., $W_x, W_B, W_C, W_{dt}, A, D$) to yield all necessary submodel weights. Each submodel $M^{(k)}$ corresponds to a prefix of the full model's parameters, guaranteeing the nested property:
$$M^{(1)} \subset M^{(2)} \subset \cdots \subset M^{(g)},$$
where $g$ is the number of slices and $M^{(g)}$ recovers the full model.

MatMamba layers compute the following sequence of operations for each slice:
$$\begin{aligned}
XBC^{(k)}(u) &= \sigma\left(\mathrm{Conv}\left(W_{\mathrm{conv}}^{(k)},\, \left[W_x^{(k)}u \,\Vert\, W_B u \,\Vert\, W_C u\right]\right)\right), \\
Y^{(k)}(u) &= \mathrm{SSM}\left(XBC^{(k)}(u),\, W_{dt}^{(k)} u,\, A^{(k)},\, D^{(k)}\right), \\
M^{(k)}(u) &= \mathrm{RMSNorm}\left(Y^{(k)}(u) \odot \sigma(W_z^{(k)}u)\right) W_{\mathrm{out}}^{(k)\top},
\end{aligned}$$
where $\sigma$ is the SiLU nonlinearity and $\Vert$ denotes concatenation.
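
The following structural sketch mirrors these three operations at a chosen slice width $k$. It is illustrative only: the scan is a simplified stand-in for Mamba2's selective SSM (the $W_{dt}$ discretization is omitted), and the parameter names, shapes, and convolution slicing are assumptions rather than the reference implementation:

```python
import torch
import torch.nn.functional as F

def rms_norm(x, eps=1e-5):
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def matmamba_layer(u, p, k, n):
    """One MatMamba layer at slice width k (structural sketch). `p` holds
    full-size tensors; projections are used through index-order slices.
    n is the state size, assumed shared across slices. u: (T, d_model)."""
    x = u @ p["Wx"][:, :k]                              # sliced input projection
    xbc = torch.cat([x, u @ p["WB"], u @ p["WC"]], -1)  # [W_x u ∥ W_B u ∥ W_C u]
    c = xbc.shape[-1]
    xbc = F.silu(F.conv1d(xbc.t()[None], p["Wconv"][:c],
                          padding="same", groups=c)[0].t())  # depthwise conv + σ
    x, B, C = xbc[:, :k], xbc[:, k:k + n], xbc[:, k + n:]
    a = torch.sigmoid(p["A"][:k])[:, None]              # per-channel decay (0, 1)
    h, ys = torch.zeros(k, n), []
    for t in range(u.shape[0]):
        h = a * h + x[t, :, None] * B[t][None, :]       # simplified SSM scan
        ys.append(h @ C[t] + p["D"][:k] * x[t])
    y = torch.stack(ys)
    gate = F.silu(u @ p["Wz"][:, :k])                   # σ(W_z u), σ = SiLU
    return rms_norm(y * gate) @ p["Wout"][:, :k].t()    # project back to d_model

d_model, k_max, n, T = 64, 32, 8, 10                    # assumed sizes
p = {"Wx": torch.randn(d_model, k_max) / 8, "WB": torch.randn(d_model, n) / 8,
     "WC": torch.randn(d_model, n) / 8, "Wz": torch.randn(d_model, k_max) / 8,
     "Wout": torch.randn(d_model, k_max) / 8, "A": torch.randn(k_max),
     "D": torch.zeros(k_max), "Wconv": torch.randn(k_max + 2 * n, 1, 4) / 4}
out_half = matmamba_layer(torch.randn(T, d_model), p, k=k_max // 2, n=n)
```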

2. Nested Parameterization and Matryoshka Learning

Nested parameterization in MatMamba enforces that all learnable tensors (input projections, head-specific weights, SSM recurrence matrices, convolution filters, output projections) are sliced in index order to define multiple submodels. For $g$ chosen granularities and corresponding dimensions $\{m_1, \dots, m_g\}$:
$$W_x^{(i)} = W_x[0\!:\!d_i, \cdot], \qquad A^{(i)} = A[0\!:\!h_i, :], \qquad W_{\mathrm{out}}^{(i)} = W_{\mathrm{out}}[:, 0\!:\!d_i].$$
During training, every forward pass produces $g$ outputs, one for each slice, and joint optimization is performed over all slices.
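
A minimal sketch of this index-order slicing under assumed tensor shapes; each submodel's parameters are views of the shared full tensors, so the nesting $M^{(1)} \subset \cdots \subset M^{(g)}$ holds by construction:

```python
import torch

def slice_params(full, d_i, h_i):
    """Return the i-th submodel's parameters as views (no copies) of the full
    tensors, per W_x^(i) = W_x[0:d_i, .], A^(i) = A[0:h_i, :],
    W_out^(i) = W_out[:, 0:d_i]. Names and shapes are assumptions."""
    return {"Wx": full["Wx"][:d_i, :],
            "A": full["A"][:h_i, :],
            "Wout": full["Wout"][:, :d_i]}

d, h, d_model = 1024, 16, 768
full = {"Wx": torch.randn(d, d_model), "A": torch.randn(h, 16),
        "Wout": torch.randn(d_model, d)}
# g = 4 nested granularities, each a prefix of the next: M1 ⊂ M2 ⊂ M3 ⊂ M4
subs = [slice_params(full, d >> s, h >> s) for s in (3, 2, 1, 0)]
assert torch.equal(subs[0]["Wx"], subs[3]["Wx"][: d >> 3])  # nesting holds
```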

3. Training Paradigm

MatMamba utilizes joint optimization over all $g$ slices in a single training run:
$$\mathcal{L}_{\text{joint}} = \sum_{i=1}^g \lambda_i \, \mathcal{L}(f_i(x), y), \qquad \lambda_i = \frac{1}{g},$$
where $f_i(x)$ is the output of the $i$th slice. Each training step runs every slice forward, accumulates gradients across the $g$ losses, and applies a single optimizer update; memory usage remains comparable to that of the largest slice (i.e., a single Mamba2 base model).
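
A toy but runnable illustration of this joint update, substituting a hypothetical two-matrix "sliced" model for MatMamba: each slice's loss is weighted by $\lambda_i = 1/g$ and gradients accumulate into the shared parameters before a single optimizer step.

```python
import torch

class ToySlicedModel(torch.nn.Module):
    """Stand-in for a MatMamba network: weight matrices used through
    width-d prefix slices (purely illustrative)."""
    def __init__(self, d_in, d_max, d_out):
        super().__init__()
        self.W1 = torch.nn.Parameter(torch.randn(d_in, d_max) / d_in ** 0.5)
        self.W2 = torch.nn.Parameter(torch.randn(d_max, d_out) / d_max ** 0.5)
    def forward(self, x, d):
        return torch.relu(x @ self.W1[:, :d]) @ self.W2[:d, :]

model = ToySlicedModel(32, 64, 10)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
slice_dims = [8, 16, 32, 64]                      # g = 4 granularities
x, y = torch.randn(16, 32), torch.randint(0, 10, (16,))

opt.zero_grad()
for d in slice_dims:                              # one forward pass per slice
    loss = torch.nn.functional.cross_entropy(model(x, d), y) / len(slice_dims)
    loss.backward()                               # λ_i = 1/g; grads accumulate
opt.step()                                        # single shared update
```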

Optimization regimes adhere to established recipes:

  • Vision (ImageNet-1K): AdamW (LR $= 0.005$, cosine decay), weight decay 0.1, warmup 10k steps, batch size 4096/8192, dropout 0.1, stochastic depth 0.1.
  • Language (FineWeb): AdamW ($\beta = (0.9, 0.95)$, LR $= 3 \times 10^{-4}$, warmup 2k steps), weight decay 0.1.

No curriculum or progressive shrinking is used; all slices are trained jointly from scratch.
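
As a concrete sketch, the language-modeling optimizer with linear warmup (borrowing the cosine decay stated for the vision recipe) could be wired as follows; the total step count and the dummy parameter are assumptions:

```python
import math
import torch

params = [torch.nn.Parameter(torch.zeros(1))]     # placeholder for model params
opt = torch.optim.AdamW(params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
warmup, total = 2_000, 100_000                    # total steps is an assumption

def lr_scale(step):                               # linear warmup, cosine decay
    if step < warmup:
        return step / warmup
    t = (step - warmup) / (total - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * t))

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_scale)
```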

4. Elastic and Adaptive Inference

MatMamba enables adaptive deployment by runtime selection of parameter “slices,” providing compute-accuracy elasticity:

  • Fixed-slice inference: selecting a global dimension $d_k$, the submodel runs identically to a Mamba2 of the same width.
  • “Mix’n’Match” inference: per-layer selection of slice dimensions $m_{\ell} \in \{m_1, \dots, m_g\}$ or intermediate values (within head-divisibility constraints), supporting layer-wise adaptation, as sketched below.
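
A hypothetical Mix’n’Match configuration, with the anchor set, layer count, and slice-aware forward interface all assumed for illustration:

```python
# Hypothetical Mix'n'Match schedule for a 20-layer model: each layer
# independently picks a width from the anchors (or any multiple of d_head).
anchors, d_head = [128, 256, 512, 1024], 64
widths = [1024] * 5 + [512] * 10 + [256] * 5   # wider early layers (arbitrary)
assert len(widths) == 20
assert all(w % d_head == 0 for w in widths)    # head-divisibility constraint
# y = model(x, layer_widths=widths)            # assumed slice-aware forward API
```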

For a given sequence length $L$ and selected inner dimension $d_k$:
$$T_{\text{MatMamba}}(L, d_k) = O(L d_k + d_k^2),$$
contrasted with Transformer self-attention:
$$T_{\text{Transformer}}(L, d) = O(L^2 d + d^2).$$
For $L \gg d$, MatMamba offers substantial speed and memory advantages.
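
A back-of-envelope comparison of these two expressions (constants dropped, so only the ratio is meaningful):

```python
# Relative step costs from the asymptotic expressions above.
L, d = 65_536, 1024
matmamba = L * d + d ** 2               # O(L d_k + d_k^2)
attention = L ** 2 * d + d ** 2         # O(L^2 d + d^2)
print(f"{attention / matmamba:.1e}")    # ~6.5e+04x in favor of the SSM
```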

5. Empirical Performance and Comparative Evaluation

Performance assessments span vision and language domains:

  • ImageNet-1K (Patch-16, 20-layer model): Slices at $d$, $d/2$, $d/4$, $d/8$ retain classification accuracy closely matching independent Mamba2 baselines at each size, within $\pm 0.2\%$ top-1. “Mix’n’Match” slices interpolate between, or slightly exceed, the trained anchors.
  • Image retrieval (1-NN on CLS embeddings): Submodels derived from a common model (e.g., 135M parameters) preserve the learned metric space (0.5% drop in 1-NN accuracy for a 55% FLOP reduction), outperforming independently trained small Mamba2 baselines.
  • Language modeling (FineWeb, decoder LM): Final validation losses for each slice show uniform, predictable scaling ($\sim 0.3$–$0.4$ nats per halving of width). Each full-size slice matches a separately trained Mamba2.

Table: Example Results (ImageNet-1K)

Slice Dim    Params    Top-1
1024         132.7M    81.9%
512           69.0M    76.8%
256           37.1M    72.3%
128           21.2M    67.4%

A similar table is reported for language modeling losses. Each explicit slice consistently approximates the performance of an independent model, with layer-wise “Mix’n’Match” enabling resource-adaptive trade-offs (Shukla et al., 2024).

6. Practical Deployment and Guidance

MatMamba’s “nested” property supports a spectrum of deployment scenarios:

  • Edge devices: select the smallest available slice ($d/8$ or $d/4$) for minimal latency.
  • Backend servers: utilize the full model for maximal accuracy.
  • Speculative decoding: perform rapid draft predictions at an intermediate width, then verify with higher capacity (see the sketch after this list).
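
A greedy self-speculative step under an assumed `model(tokens, d)` interface (not the paper's API), with a toy stand-in model so the sketch runs end to end:

```python
import torch

def speculative_decode_step(model, ctx, d_draft, d_full, k=4):
    """Draft k tokens with a narrow slice, verify them in one full-width
    pass, and keep the longest agreeing prefix plus one corrected token.
    `model(tokens, d)` returning (len(tokens), vocab) next-token logits
    for the d-wide slice is an assumed interface; decoding is greedy."""
    draft = ctx
    for _ in range(k):                               # cheap narrow-slice drafting
        nxt = model(draft, d_draft)[-1].argmax()
        draft = torch.cat([draft, nxt[None]])
    verify = model(draft, d_full).argmax(-1)         # one parallel full-width pass
    out = ctx
    for i in range(len(ctx), len(draft)):
        out = torch.cat([out, verify[i - 1][None]])  # keep full model's token
        if verify[i - 1] != draft[i]:                # first disagreement: stop
            break
    return out

# Toy stand-in so the sketch runs: logits depend only on the last token,
# and slice width just rescales them (a pure placeholder, not MatMamba).
vocab = 100
emb = torch.randn(vocab, vocab)
model = lambda toks, d: emb[toks] * (d / 1024)
print(speculative_decode_step(model, torch.tensor([1, 2, 3]), 256, 1024))
```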

Temporal complexity is directly governed by the selected slice. It is recommended to pick the minimal adequate dimension such that $O(L d_k)$ satisfies the target latency, while ensuring head divisibility (i.e., $m_i \bmod d_{\text{head}} = 0$). “Mix’n’Match” enables any intermediate granularity (not only anchors), although anchor slices are recommended for maximal consistency.
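
One way to operationalize this guidance, sketched below with placeholder numbers: profile a per-unit step cost on the target device, then take the widest head-divisible slice whose estimated $O(L d_k)$ time fits the latency budget.

```python
def pick_width(anchors, d_head, L, unit_cost, budget):
    """Widest slice whose estimated O(L * d_k) step time fits the latency
    budget, honoring head divisibility. `unit_cost` (seconds per L*d_k unit)
    would be profiled on the target device; all numbers are placeholders."""
    valid = [d for d in anchors if d % d_head == 0]
    fits = [d for d in valid if unit_cost * L * d <= budget]
    return max(fits) if fits else min(valid)   # nothing fits: cheapest slice

print(pick_width([128, 256, 512, 1024], d_head=64,
                 L=8192, unit_cost=2e-9, budget=0.01))   # -> 512
```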

This suggests MatMamba is particularly suited for environments demanding dynamic or heterogeneous compute allocation, such as multi-device inference, federated learning settings, or scalable service backends.

MatMamba submodels precisely interpolate between anchor model sizes, in contrast with retrained Mamba2 or Transformer baselines, which would require separate full training runs per configuration. Empirically, submodels extracted from MatMamba match, to within $\pm 0.2\%$ (top-1) or $\pm 0.1$ nats (LM validation loss), the performance of separately trained SSMs, which is not matched by pure width-pruning of Transformers.

This Matryoshka-style elasticity was not previously realized in SSM frameworks and augments the deployment flexibility of the Mamba2 architecture. For high-performance long-sequence modeling, MatMamba provides a scalable universal model that obviates the need for retraining for each deployment scale, supporting dozens or potentially hundreds of slicing points within a single consistent metric space (Shukla et al., 2024).

References

Shukla, A., Vemprala, S., Kusupati, A., & Kapoor, A. (2024). MatMamba: A Matryoshka State Space Model. arXiv:2410.06718.
