Distilled Matryoshka Sparse Autoencoders

Updated 2 January 2026
  • The paper introduces DMSAEs that iteratively distill and freeze high-value encoder features to form a stable, reusable core for sparse autoencoders.
  • It employs an attribution-guided selection method using gradient×activation metrics to retain the minimal set covering 90% cumulative attribution.
  • Empirical evaluation on Gemma-2-2B shows that the distilled 197-feature core enhances reconstruction, feature transferability, and downstream metric consistency.

Distilled Matryoshka Sparse Autoencoders (DMSAEs) constitute a training pipeline for sparse autoencoders that extracts a compact, transferable core of robust, human-interpretable features. By iteratively distilling and freezing the directions most consistently useful for a base model’s next-token loss, DMSAEs address the instability and redundancy of standard sparse feature learning. The methodology revolves around an attribution-guided selection process, transferring only the distilled core encoder weight vectors across cycles, while reinitializing the decoder and non-core latents. Empirical evaluation on the Gemma-2-2B model demonstrates that DMSAEs yield a strongly reusable core—comprising 197 stabilized features after seven cycles—which improves consistency, interpretability, and downstream SAEBench metrics relative to conventional Matryoshka Sparse Autoencoders (Martin-Linares et al., 31 Dec 2025).

1. Matryoshka Sparse Autoencoders: Hierarchical Feature Learning

A typical sparse autoencoder represents an overcomplete dictionary through an encoder–decoder architecture subject to a sparsity constraint:

$$f(x) = \mathrm{ReLU}(W_{\text{enc}} x + b_{\text{enc}}) \in \mathbb{R}^m, \qquad \hat{x} = W_{\text{dec}} f(x) + b_{\text{dec}},$$

where sparsity is enforced on $f(x)$ (commonly through TopK thresholding).
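
The following minimal sketch shows such a TopK sparse autoencoder forward pass, assuming a PyTorch setting; shapes, initialization, and class names are illustrative, not the paper's released code:

```python
import torch
import torch.nn as nn


class TopKSAE(nn.Module):
    """Sketch of a TopK sparse autoencoder; shapes and init are illustrative."""

    def __init__(self, d_model: int, m_latents: int, k: int):
        super().__init__()
        self.W_enc = nn.Parameter(0.02 * torch.randn(m_latents, d_model))
        self.b_enc = nn.Parameter(torch.zeros(m_latents))
        self.W_dec = nn.Parameter(0.02 * torch.randn(d_model, m_latents))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.k = k

    def forward(self, x: torch.Tensor):
        # f(x) = ReLU(W_enc x + b_enc), then keep only the k largest activations.
        pre = torch.relu(x @ self.W_enc.T + self.b_enc)
        topk = torch.topk(pre, self.k, dim=-1)
        f = torch.zeros_like(pre).scatter_(-1, topk.indices, topk.values)
        # x_hat = W_dec f(x) + b_dec
        x_hat = f @ self.W_dec.T + self.b_dec
        return f, x_hat
```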

Matryoshka Sparse Autoencoders (MSAEs) extend this framework by introducing a hierarchy of prefix sizes $M = \{ m_1 < m_2 < \dots < m_L \}$, with reconstruction objectives over all prefixes:

$$\mathcal{L}_{\text{MSAE}}(x) = \sum_{m \in M} \| x - \hat{x}_m \|_2^2 + \alpha \mathcal{L}_{\text{aux}},$$

where $\hat{x}_m$ is the reconstruction using only the first $m$ latents. Early latents must encode high-frequency, generalizable content, while later ones specialize. In vanilla MSAEs, no explicit “core” features are preserved across runs, resulting in significant variability and making feature reuse difficult (Martin-Linares et al., 31 Dec 2025).
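
A sketch of this prefix objective follows, reusing the TopKSAE above; the prefix sizes and auxiliary term are placeholders rather than the paper's exact settings:

```python
import torch


def msae_loss(sae, x, prefix_sizes=(1024, 4096, 16384, 65000), alpha=0.0, l_aux=0.0):
    """Sum of reconstruction losses over nested prefixes of the latent dictionary."""
    f, _ = sae(x)                          # sparse codes, shape (batch, m_latents)
    loss = torch.zeros((), device=x.device)
    for m in prefix_sizes:
        f_m = torch.zeros_like(f)
        f_m[:, :m] = f[:, :m]              # keep only the first m latents
        x_hat_m = f_m @ sae.W_dec.T + sae.b_dec
        loss = loss + ((x - x_hat_m) ** 2).sum(dim=-1).mean()
    return loss + alpha * l_aux
```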

2. Attribution Metric: Gradient × Activation

DMSAEs introduce an attribution-driven basis selection methodology for identifying high-value features. For a given token position $u$, let $x_u \in \mathbb{R}^d$ denote the residual stream activation and $g_u = \partial \mathcal{L}_{\text{NT}} / \partial x_u$ the gradient of the next-token loss. Encoding and masking to the smallest prefix, for each latent $j$:

  • Compute the activation $a_{u,j}$
  • Normalize the decoder vector: $\bar{w}^{\text{dec}}_j = w^{\text{dec}}_j / \| w^{\text{dec}}_j \|_2$
  • Compute $s_{u,j} = g_u^\top \bar{w}^{\text{dec}}_j$
  • Attribution score: $GxA_{u,j} = | a_{u,j} \cdot s_{u,j} |$

Due to the heavy-tailed attribution distribution, DMSAEs aggregate $GxA_{u,j}$ over positions via a high quantile (e.g., $q = 0.99$):

$$A_j = \text{Quantile}_u ( GxA_{u,j}; q ).$$

Latents are sorted by $A_j$; the smallest subset whose cumulative attribution exceeds a fraction $\tau$ of the total attribution is retained as the distilled core:

$$C = \operatorname*{argmin}_{S \subseteq P} |S| \quad \text{s.t.} \quad \sum_{j \in S} A_j \geq \tau \sum_{j \in P} A_j$$
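
A sketch of this scoring and selection step, assuming the per-token latent activations, residual-stream gradients, and decoder matrix have already been collected (tensor names are illustrative):

```python
import torch


def select_core(acts, grads, W_dec, q=0.99, tau=0.9):
    """Quantile-aggregated gradient x activation scores and minimal tau-covering core."""
    # acts:  (num_tokens, m) latent activations a_{u,j} under the smallest prefix
    # grads: (num_tokens, d) gradients g_u of the next-token loss w.r.t. the residual stream
    # W_dec: (d, m) decoder matrix whose columns are the latent directions w_j^dec
    w_bar = W_dec / W_dec.norm(dim=0, keepdim=True)   # unit-norm decoder directions
    s = grads @ w_bar                                  # s_{u,j} = g_u^T w_bar_j
    gxa = (acts * s).abs()                             # GxA_{u,j} = |a_{u,j} * s_{u,j}|
    A = torch.quantile(gxa, q, dim=0)                  # A_j: high quantile over positions
    order = torch.argsort(A, descending=True)
    cum = torch.cumsum(A[order], dim=0)
    n_core = int((cum < tau * A.sum()).sum().item()) + 1  # smallest tau-covering prefix
    return order[:n_core]                              # indices of the distilled core C
```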

3. Iterative Distillation and Transfer Protocol

The DMSAE pipeline operates as a multi-cycle, iterative distillation process. The key steps are:

  • Initialization: Score the released SAEBench model to select the initial core $C^{(0)}$.
  • Per-Cycle Training:
    • Freeze encoder rows corresponding to $C^{(t-1)}$; reinitialize all other parameters.
    • Train a two-group MSAE, letting the core latents remain dense while constraining non-core sparsity.
    • Upon convergence, compute quantile-based GxA attributions for core and prefix-0 latents.
    • Select the smallest core $C^{(t)}$ achieving $\tau$-coverage as above.
  • Core Stabilization: After $T$ cycles, the final distilled core is $C^* = C^{(T)} \cap C^{(T-1)}$, containing only latents that persist through the last two cycles.

This procedure is formalized in the DMSAE high-level pseudocode provided in (Martin-Linares et al., 31 Dec 2025).
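
A schematic sketch of this loop is given below; `train_two_group_msae` and `collect_acts_and_grads` are hypothetical helpers standing in for the training and attribution-collection stages, not functions from the paper:

```python
def distill_core(initial_sae, data, T=7, q=0.99, tau=0.9):
    """Iteratively distill and freeze a core of high-attribution latents."""
    acts, grads = collect_acts_and_grads(initial_sae, data)        # hypothetical helper
    cores = [select_core(acts, grads, initial_sae.W_dec, q, tau)]  # C^(0)
    for t in range(1, T + 1):
        # Freeze the encoder rows of the previous core, reinitialize everything else,
        # and train a two-group MSAE (dense core, TopK-sparse non-core latents).
        sae = train_two_group_msae(frozen_core=cores[-1], data=data)  # hypothetical helper
        acts, grads = collect_acts_and_grads(sae, data)
        cores.append(select_core(acts, grads, sae.W_dec, q, tau))     # C^(t)
    # Final distilled core C* = C^(T) ∩ C^(T-1): latents persisting across the last two cycles.
    return sorted(set(cores[-1].tolist()) & set(cores[-2].tolist()))
```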

4. Distilled Core Convergence and Empirical Evaluation

Applied to Gemma-2-2B, DMSAEs were trained on layer-12 residual stream activations with a dictionary of $K = 65{,}000$ latents, $T = 7$ distillation cycles, and a non-core sparsity of $k = 320$, using 500M tokens. The same prefix sizes $M$ and coverage threshold $\tau = 0.9$ were used throughout (Martin-Linares et al., 31 Dec 2025).

Across cycles, the selected core stabilized between 200 and 400 latents. The intersection of the last two cycles, $C^{(6)} \cap C^{(7)}$, yielded a distilled core of 197 features persisting across restarts. Empirical evidence indicates that a randomly chosen core of equal size becomes inactive (core $\ell_0 \rightarrow 0$), whereas the distilled core remains active and contributes significantly to loss reduction throughout training. This demonstrates that the distilled directions are systematically high-value.

5. Performance on SAEBench and Downstream Metrics

DMSAEs were benchmarked by transferring the distilled core $C^*$ to new SAEs at multiple sparsity levels ($k \in \{20, 40, 80, 160, 320, 640\}$), freezing only the $C^*$ encoder rows and reinitializing all other parameters; a sketch of this transfer step appears after the results below. SAEBench evaluations included:

  • Reconstruction loss
  • Fraction of variance explained
  • Feature absorption
  • RAVEL
  • Targeted concept removal
  • Spurious correlation removal
  • AutoInterp

Results show that, across sparsity levels, DMSAEs match or exceed the vanilla MSAE baseline on reconstruction, absorption, and RAVEL, maintaining stable downstream metrics except for a decrease in AutoInterp at the lowest $k$. A sparse-core ablation (a global TopK mask optionally including the core) reveals qualitatively similar trends (Martin-Linares et al., 31 Dec 2025).
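
The transfer step itself can be sketched as follows, assuming the new SAE shares the layout of the TopKSAE above and using a gradient hook as one illustrative way to keep the core encoder rows fixed (not necessarily the paper's exact mechanism):

```python
import torch


def transfer_core(new_sae, old_sae, core_idx):
    """Copy frozen core encoder rows into a fresh SAE and block their updates."""
    with torch.no_grad():
        new_sae.W_enc[core_idx] = old_sae.W_enc[core_idx]  # reuse distilled core rows
        new_sae.b_enc[core_idx] = old_sae.b_enc[core_idx]

    def zero_core_grad(grad):
        grad = grad.clone()
        grad[core_idx] = 0.0          # no gradient updates for the frozen core
        return grad

    new_sae.W_enc.register_hook(zero_core_grad)
    new_sae.b_enc.register_hook(zero_core_grad)
    return new_sae
```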

6. Implications: Interpretability, Transfer, and Model Compression

Distilled Matryoshka Sparse Autoencoders yield several key advantages:

  • Feature Transferability: Freezing only the most consistently useful encoder directions enables reliable reuse of core features across sparsity budgets, restarts, and tasks.
  • Interpretability: The distilled core produces a compact, monosemantic, and non-redundant basis, stabilizing feature semantics and simplifying manual or automated feature annotation. It also reduces feature splitting and absorption.
  • Compression: Only the $|C^*| \approx 197$ core encoder weight rows require preservation for transfer, facilitating a two-stage compression approach: a stable core “backbone” and a lightweight, re-trainable residual dictionary.

A plausible implication is that DMSAEs enable more modular and interpretable model analysis pipelines by decoupling the stable, attribution-maximizing core from non-core features adaptively specialized for different downstream requirements (Martin-Linares et al., 31 Dec 2025).

References

  1. Martin-Linares et al., “Distilled Matryoshka Sparse Autoencoders,” 31 Dec 2025.
