Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy (2505.24473v2)

Published 30 May 2025 in cs.LG and cs.AI

Abstract: Sparse Autoencoders (SAEs) have proven to be powerful tools for interpreting neural networks by decomposing hidden representations into disentangled, interpretable features via sparsity constraints. However, conventional SAEs are constrained by the fixed sparsity level chosen during training; meeting different sparsity requirements therefore demands separate models and increases the computational footprint during both training and evaluation. We introduce a novel training objective, HierarchicalTopK, which trains a single SAE to optimise reconstructions across multiple sparsity levels simultaneously. Experiments with Gemma-2 2B demonstrate that our approach achieves Pareto-optimal trade-offs between sparsity and explained variance, outperforming traditional SAEs trained at individual sparsity levels. Further analysis shows that HierarchicalTopK preserves high interpretability scores even at higher sparsity. The proposed objective thus closes an important gap between flexibility and interpretability in SAE design.

Summary

  • The paper introduces HierarchicalTopK, a novel objective that trains one sparse autoencoder across various sparsity levels to achieve competitive reconstruction and interpretability.
  • It demonstrates that a single model can replace multiple fixed-sparsity SAEs, reducing computational cost and minimizing inactive ('dead') features.
  • Experimental results show that adjusting ℓ0 at inference maintains high performance and interpretability, offering flexible and efficient feature hierarchies.

This paper introduces HierarchicalTopK, a novel training objective for Sparse Autoencoders (SAEs) that enables a single SAE model to perform effectively across multiple sparsity levels, i.e., different $\ell_0$ norms. Traditional SAEs are trained for a fixed sparsity level, meaning different models must be trained and stored if varying sparsity is required, increasing computational costs. HierarchicalTopK addresses this by optimizing reconstruction quality simultaneously for all sparsity levels up to a predefined maximum $K$.

The core idea is to modify the standard SAE reconstruction loss. A typical SAE reconstructs the input $\mathbf{x}$ as $\hat{\mathbf{x}} = W_{\text{dec}}\mathbf{l} + \mathbf{b}_{\text{dec}}$, where the sparse latent activations are $\mathbf{l} = \sigma(W_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}})$. With TopK-based activations, only the top $k$ latents are non-zero. The HierarchicalTopK approach defines a reconstruction $\hat{\mathbf{x}}_j$ for each sparsity level $j$ from $1$ to $K$ (or a subset $\mathcal{J}$ of these values):

$$\hat{\mathbf{x}}_j = \sum_{i \in \operatorname{top}_j} \mathbf{l}_i(\mathbf{x})\,\mathbf{e}_i + \mathbf{b}_{\text{dec}}$$

where $\mathbf{e}_i$ are the decoder embeddings (columns of $W_{\text{dec}}$) and $\mathbf{l}_i(\mathbf{x})$ is the component of the latent vector $\mathbf{l}$ corresponding to the $i$-th largest activation. The total loss is the average of the reconstruction errors over the chosen sparsity levels:

$$\mathcal{L}_{\text{hierarchical}} = \frac{1}{|\mathcal{J}|}\sum_{j\in\mathcal{J}} \|\mathbf{x} - \hat{\mathbf{x}}_j\|^2$$

This formulation encourages the SAE to learn features that are progressively useful, ensuring that reconstructions improve monotonically as more features (larger $j$) are included.
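
For example (writing index 1 for the largest activation and 2 for the second largest), with $K=2$ and $\mathcal{J}=\{1,2\}$ the objective above expands to

$$\mathcal{L}_{\text{hierarchical}} = \tfrac{1}{2}\left(\|\mathbf{x} - \hat{\mathbf{x}}_1\|^2 + \|\mathbf{x} - \hat{\mathbf{x}}_2\|^2\right), \qquad \hat{\mathbf{x}}_1 = \mathbf{l}_1(\mathbf{x})\,\mathbf{e}_1 + \mathbf{b}_{\text{dec}}, \quad \hat{\mathbf{x}}_2 = \hat{\mathbf{x}}_1 + \mathbf{l}_2(\mathbf{x})\,\mathbf{e}_2,$$

so the strongest latent is pushed to reconstruct $\mathbf{x}$ well on its own, and each additional latent refines that reconstruction rather than compensating for it.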

Implementation and Computational Cost:

The hierarchical loss can be computed efficiently. The paper notes that the cumulative sum of feature contributions needed for $\hat{\mathbf{x}}_j$ for all $j$ can be calculated in a single forward pass. A naive PyTorch implementation is provided:

def hierarchical_loss(sparse_idx, sparse_val, decoder, b_dec, target):
    """
    sparse_idx: LongTensor of shape (B, K) with indices of active embeddings
    sparse_val: FloatTensor of shape (B, K) with corresponding activation values
    decoder:     FloatTensor of shape (D, h) containing the dictionary embeddings
    b_dec:       FloatTensor of shape (h) containing decoder bias
    target:      FloatTensor of shape (B, h) with the original inputs
    """
    B, K = sparse_idx.shape
    flatten_idx = sparse_idx.view(-1)
    emb = decoder[flatten_idx].view(B, K, -1) # Shape (B, K, h)
    emb = emb * sparse_val.unsqueeze(-1) # Scale embeddings by activations

    # Compute cumulative reconstructions for j = 1 to K
    # recon_cum[b, j-1, :] is the reconstruction using top j features for batch element b
    recon_cum = emb.cumsum(dim=1) + b_dec.unsqueeze(0).unsqueeze(0) # Add bias to each cumulative sum

    # Target needs to be broadcastable for subtraction: (B, 1, h)
    diff = recon_cum - target.unsqueeze(1)
    total_err = diff.pow(2).mean() # Mean over B, K, and h dimensions
    return total_err
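
As a usage sketch (hypothetical tensors with dimensions mirroring the paper's setup, and a plain ReLU-plus-TopK encoder standing in for the trained SAE encoder rather than the authors' exact training code), the inputs to hierarchical_loss can be produced with torch.topk, whose values come back sorted largest-first:

import torch

B, h, D, K = 64, 2304, 65_536, 128             # batch size, model dim, dictionary size, max sparsity
x = torch.randn(B, h)                           # activations to reconstruct

W_enc = torch.randn(D, h) * 0.01                # stand-in encoder weights
b_enc = torch.zeros(D)
decoder = torch.randn(D, h) * 0.01              # dictionary embeddings (one row per feature)
b_dec = torch.zeros(h)

pre_acts = torch.relu(x @ W_enc.T + b_enc)      # (B, D) non-negative latent activations
sparse_val, sparse_idx = torch.topk(pre_acts, k=K, dim=-1)  # sorted, largest first

loss = hierarchical_loss(sparse_idx, sparse_val, decoder, b_dec, target=x)
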
The authors also developed optimized Triton kernels for their fused HierarchicalTopK (referred to as FlexSAE in the appendix), which reportedly run faster than baseline TopK implementations while using similar peak memory. Specifically, for a batch size of 64, model dimension $h=2304$, dictionary size $D=2^{16}$, and $\ell_0=128$, the "Fused Hierarchical" kernel had a time per step of $10.0531 \pm 0.0420$ ms, a 4.07% speedup over the baseline TopK kernel from Gao et al. (2025).

The paper also explores reducing computational load by subsampling the terms in the hierarchical loss. Instead of summing over all $j \in \{1, \dots, K\}$, they test using $J_x = \{1\} \cup \{i \in \mathbb{N} : i \bmod x = 0 \land 1 < i \le K\}$. Computing the loss on every 8th term ($J_8$) showed performance nearly identical to using the full set of terms, offering a potential 8x reduction in FLOPs for this part of the calculation, though the practical speedup in training time per step was minimal due to efficient kernel implementations.
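
As a minimal illustration (an assumed modification of the naive PyTorch version above, not the authors' fused kernel), the subsampled objective evaluates the squared error only at the selected prefix lengths:

import torch

def hierarchical_loss_subsampled(sparse_idx, sparse_val, decoder, b_dec, target, stride=8):
    """Same inputs as hierarchical_loss, but averages the error only over
    j in J_stride = {1} union {j : j % stride == 0, 1 < j <= K}."""
    B, K = sparse_idx.shape
    emb = decoder[sparse_idx.view(-1)].view(B, K, -1) * sparse_val.unsqueeze(-1)
    recon_cum = emb.cumsum(dim=1) + b_dec                        # (B, K, h); bias broadcasts
    keep = [0] + [j - 1 for j in range(stride, K + 1, stride)]   # 0-based positions of kept j
    idx = torch.tensor(sorted(set(keep)), device=recon_cum.device)
    diff = recon_cum.index_select(1, idx) - target.unsqueeze(1)  # (B, |J|, h)
    return diff.pow(2).mean()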

Experimental Validation:

Experiments were conducted using activations from the 12th layer of the Gemma-2 2B model, with SAEs having a dictionary size of $D = 65,536$. Key findings include:

  1. Pareto Optimality: A single HierarchicalTopK SAE (trained with $K=128$) achieves a Pareto-optimal trade-off between Fraction of Unexplained Variance (FVU) and $\ell_0$ sparsity across various sparsity levels. It matches or outperforms baseline TopK and BatchTopK SAEs that were individually trained for specific sparsity levels (e.g., $k \in \{32, 64, 128\}$). This means one HierarchicalTopK model can replace multiple standard SAEs.
  2. Performance when changing $\ell_0$ at inference: When $\ell_0$ is varied at inference time (interpolating within the training range of $K=128$), the HierarchicalTopK model consistently performs as well as or better than the baselines (see the sketch after this list).
  3. Interpretability: Using an automated interpretability (AutoInterp) score, the HierarchicalTopK SAE maintained high interpretability even at larger $\ell_0$ (e.g., the score at $\ell_0=128$ was similar to that at $\ell_0=32$). In contrast, standard TopK and BatchTopK models tended to show decreased interpretability as $\ell_0$ increased.
  4. Reduced Dead Features: HierarchicalTopK models exhibit significantly fewer "almost dead" features (features with activation frequency below $10^{-5}$) when $\ell_0$ is reduced at inference time, compared to TopK and BatchTopK models trained at a fixed higher $k$ and then evaluated at a lower $k$. This suggests features learned by HierarchicalTopK are more robustly useful across different sparsity budgets.
  5. Latent Structure: The cosine similarity between feature embeddings in the reconstruction sum showed that HierarchicalTopK maintains a desirable monotonic decrease in similarity as less important features are added, unlike vanilla TopK SAEs. Hierarchical training also leads to more latents with higher mean squared activation values.
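
For illustration, adjusting $\ell_0$ at inference with a trained HierarchicalTopK SAE simply means keeping a different number of top latents before decoding. The sketch below assumes the hypothetical tensor names from the earlier snippet (W_enc, b_enc, decoder, b_dec) hold the trained parameters:

import torch

@torch.no_grad()
def reconstruct_at_sparsity(x, W_enc, b_enc, decoder, b_dec, l0):
    """Reconstruct activations x using only the top-l0 latents of a trained SAE."""
    pre_acts = torch.relu(x @ W_enc.T + b_enc)       # (B, D) latent activations
    vals, idx = torch.topk(pre_acts, k=l0, dim=-1)   # keep the l0 strongest latents
    recon = (decoder[idx] * vals.unsqueeze(-1)).sum(dim=1) + b_dec
    return recon

# The same trained model can serve different sparsity budgets at inference:
# recon_32  = reconstruct_at_sparsity(x, W_enc, b_enc, decoder, b_dec, l0=32)
# recon_128 = reconstruct_at_sparsity(x, W_enc, b_enc, decoder, b_dec, l0=128)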

Applications:

The HierarchicalTopK approach is particularly useful for:

  • Flexible Interpretability Analysis: Researchers can use a single model to explore features at various levels of detail (sparsity) without retraining.
  • Resource-Constrained Deployment: A single SAE can be adapted to different computational budgets at inference time by simply choosing a different $\ell_0$.
  • Understanding Feature Hierarchies: The progressive nature of the reconstruction encourages learning features that build upon each other, potentially revealing hierarchical structures in the learned representations.

Limitations:

The authors acknowledge two main limitations:

  1. Evaluation Scope: Experiments were limited to the Gemma-2 2B model and the FineWeb dataset. Further testing is needed on other model architectures and data.
  2. Interpretability Metrics: The paper relied on automated metrics for interpretability. Human studies would be beneficial to validate the semantic alignment of the learned features.

In conclusion, HierarchicalTopK offers a more flexible, efficient, and interpretable method for training SAEs. By optimizing for a range of sparsity levels simultaneously, it allows a single model to achieve strong performance and interpretability across these levels, reducing the need for multiple specialized models.
