- The paper introduces HierarchicalTopK, a novel objective that trains one sparse autoencoder across various sparsity levels to achieve competitive reconstruction and interpretability.
- It demonstrates that a single model can replace multiple fixed-sparsity SAEs, reducing computational cost and the number of inactive ('dead') features.
- Experimental results show that adjusting ℓ0 at inference maintains high performance and interpretability, offering flexible and efficient feature hierarchies.
This paper introduces HierarchicalTopK, a novel training objective for Sparse Autoencoders (SAEs) that enables a single SAE model to perform effectively across multiple sparsity levels, or ℓ0 norms. Traditional SAEs are trained for a fixed sparsity level, meaning different models must be trained and stored if varying sparsity is required, increasing computational costs. HierarchicalTopK addresses this by optimizing reconstruction quality simultaneously for all sparsity levels up to a predefined maximum K.
The core idea is to modify the standard SAE reconstruction loss. A typical SAE reconstructs the input $x$ as $\hat{x} = W_{\text{dec}}\, l + b_{\text{dec}}$, where $l$ are the sparse latent activations obtained as $l = \sigma(W_{\text{enc}} x + b_{\text{enc}})$. With TopK-based activations, only the top $k$ latents are non-zero. The HierarchicalTopK approach defines a reconstruction $\hat{x}_j$ for each sparsity level $j$ from $1$ to $K$ (or a subset $J$ of these values):
$$\hat{x}_j = \sum_{i \in \text{top-}j} l_i(x)\, e_i + b_{\text{dec}}$$
where $e_i$ are the decoder embeddings (columns of $W_{\text{dec}}$) and $l_i(x)$ is the component of the latent vector $l$ corresponding to the $i$-th largest activation. The total loss is then the average of the reconstruction errors over the chosen sparsity levels $j$:
$$\mathcal{L}_{\text{hierarchical}} = \frac{1}{|J|} \sum_{j \in J} \lVert x - \hat{x}_j \rVert^2$$
This formulation encourages the SAE to learn features that are progressively useful, ensuring that reconstructions improve monotonically as more features (larger j) are included.
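To make this concrete, the sketch below shows where the sorted top-$K$ indices and values consumed by the implementation in the next section could come from. It assumes a simple linear-plus-ReLU encoder; the names `topk_encode`, `W_enc`, and `b_enc` are illustrative, not the authors' code.

```python
import torch

def topk_encode(x, W_enc, b_enc, k):
    """
    Illustrative TopK encoder step (assumed, not the paper's implementation).
    x:     (B, h) input activations
    W_enc: (h, D) encoder weights
    b_enc: (D,)   encoder bias
    k:     number of latents to keep per example
    Returns (indices, values), each of shape (B, k). torch.topk returns values
    sorted in descending order, so a cumulative sum over dim=1 later yields the
    top-j reconstructions for j = 1..k.
    """
    pre_acts = torch.relu(x @ W_enc + b_enc)            # (B, D) non-negative latents
    values, indices = torch.topk(pre_acts, k, dim=-1)   # sorted=True by default
    return indices, values
```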
Implementation and Computational Cost:
The hierarchical loss can be computed efficiently: the paper notes that the cumulative sums of feature contributions needed for $\hat{x}_j$ for all $j$ can be calculated in a single forward pass.
A naive PyTorch implementation is provided:
```python
import torch

def hierarchical_loss(sparse_idx, sparse_val, decoder, b_dec, target):
    """
    sparse_idx: LongTensor of shape (B, K) with indices of the active embeddings,
                ordered so that the largest activation comes first
    sparse_val: FloatTensor of shape (B, K) with the corresponding activation values
    decoder:    FloatTensor of shape (D, h) containing the dictionary embeddings
    b_dec:      FloatTensor of shape (h,) containing the decoder bias
    target:     FloatTensor of shape (B, h) with the original inputs
    """
    B, K = sparse_idx.shape
    flatten_idx = sparse_idx.view(-1)
    emb = decoder[flatten_idx].view(B, K, -1)            # (B, K, h) gathered embeddings
    emb = emb * sparse_val.unsqueeze(-1)                 # scale embeddings by activations
    # Cumulative reconstructions for j = 1..K:
    # recon_cum[b, j-1, :] is the reconstruction using the top-j features of example b
    recon_cum = emb.cumsum(dim=1) + b_dec.unsqueeze(0).unsqueeze(0)  # bias broadcasts over (B, K, h)
    diff = recon_cum - target.unsqueeze(1)               # broadcast target to (B, 1, h)
    total_err = diff.pow(2).mean()                       # mean over B, K, and h dimensions
    return total_err
```
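A minimal usage sketch with random tensors (the sizes are illustrative, not the paper's settings):

```python
import torch

B, K, D, h = 8, 32, 1024, 256            # illustrative sizes only
decoder = torch.randn(D, h) / h ** 0.5
b_dec = torch.zeros(h)
target = torch.randn(B, h)

# Pretend these came from a TopK encoder: values sorted in descending order.
pre_acts = torch.relu(torch.randn(B, D))
sparse_val, sparse_idx = torch.topk(pre_acts, K, dim=-1)

loss = hierarchical_loss(sparse_idx, sparse_val, decoder, b_dec, target)
print(loss.item())
```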
The authors also developed optimized Triton kernels for their fused HierarchicalTopK (referred to as FlexSAE in the appendix), which reportedly run faster than baseline TopK implementations while using similar peak memory. Specifically, for a batch size of 64, model dimension h = 2304, dictionary size D = 2¹⁶ (= 65,536), and ℓ0 = 128, the "Fused Hierarchical" kernel took 10.0531 ± 0.0420 ms per step, a 4.07% speedup over the baseline TopK kernel from Gao et al. (2025).
The paper also explores reducing the computational load by subsampling the terms in the hierarchical loss. Instead of summing over all $j \in \{1, \ldots, K\}$, they test $J_x = \{1\} \cup \{\, i \in \mathbb{N} : i \bmod x = 0 \,\wedge\, 1 < i \le K \,\}$. Computing the loss on every 8th term ($J_8$) performed nearly identically to using the full set of terms, offering a potential 8× reduction in FLOPs for this part of the calculation, though the practical speedup in training time per step was minimal thanks to the efficient kernel implementations.
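As a rough illustration (not the authors' fused kernel), the naive loss above can be restricted to the subsampled index set $J_x$. The helper name is an assumption, and this sketch still materializes the full cumulative sum, so the FLOP savings it suggests would only be realized in a fused kernel.

```python
import torch

def hierarchical_loss_subsampled(sparse_idx, sparse_val, decoder, b_dec, target, x=8):
    """Hierarchical loss evaluated only on j in J_x = {1} ∪ {x, 2x, ...} ∩ [1, K]."""
    B, K = sparse_idx.shape
    emb = decoder[sparse_idx.view(-1)].view(B, K, -1) * sparse_val.unsqueeze(-1)
    recon_cum = emb.cumsum(dim=1) + b_dec                # (B, K, h), bias broadcasts
    # Select the cumulative reconstructions for j in J_x (0-based positions j - 1).
    j_set = sorted({1} | {i for i in range(x, K + 1, x)})
    j_pos = torch.tensor([j - 1 for j in j_set], device=recon_cum.device)
    selected = recon_cum.index_select(1, j_pos)          # (B, |J_x|, h)
    diff = selected - target.unsqueeze(1)
    return diff.pow(2).mean()                            # average over the |J_x| terms
```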
Experimental Validation:
Experiments were conducted using activations from the 12th layer of the Gemma-2 2B model, with SAEs having a dictionary size of D=65,536. Key findings include:
- Pareto Optimality: A single HierarchicalTopK SAE (trained with K=128) achieves a Pareto-optimal trade-off between the fraction of variance unexplained (FVU) and ℓ0 sparsity across various sparsity levels. It matches or outperforms baseline TopK and BatchTopK SAEs that were individually trained for specific sparsity levels (e.g., k∈{32,64,128}). This means one HierarchicalTopK model can replace multiple standard SAEs.
- Performance when changing ℓ0 at inference: When ℓ0 is varied at inference time (interpolating within the training range up to K=128), the HierarchicalTopK model consistently performs as well as or better than the baselines (see the sketch after this list for how such an inference-time ℓ0 sweep can be run).
- Interpretability: Using an automated interpretability score (AutoInterp score), the HierarchicalTopK SAE maintained high interpretability even at higher ℓ0 values (e.g., the score at ℓ0=128 was similar to that at ℓ0=32). In contrast, standard TopK and BatchTopK models tended to show decreased interpretability as ℓ0 increased (i.e., at lower sparsity).
- Reduced Dead Features: HierarchicalTopK models exhibit significantly fewer "almost dead" features (features with activation frequency below 10⁻⁵) when ℓ0 is reduced at inference time, compared to TopK and BatchTopK models trained at a fixed higher k and then evaluated at lower k. This suggests features learned by HierarchicalTopK are more robustly useful across different sparsity budgets.
- Latent Structure: The cosine similarity between feature embeddings in the reconstruction sum showed that HierarchicalTopK maintains a desirable monotonic decrease in similarity as less important features are added, unlike vanilla TopK SAEs. Hierarchical training also leads to more latents with higher mean squared activation values.
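To make the inference-time ℓ0 sweep referenced above concrete, here is a minimal evaluation sketch: a trained SAE is truncated to different ℓ0 budgets and FVU is measured at each. The encoder/decoder names and the FVU computation as written are illustrative assumptions, not the paper's exact evaluation code.

```python
import torch

def fvu(target, recon):
    """Fraction of variance unexplained: residual variance over total variance (assumed definition)."""
    resid = (target - recon).pow(2).sum()
    total = (target - target.mean(dim=0)).pow(2).sum()
    return (resid / total).item()

@torch.no_grad()
def eval_at_l0(x, W_enc, b_enc, decoder, b_dec, l0_values=(32, 64, 128)):
    """Evaluate one trained SAE at several ℓ0 budgets by truncating its TopK selection."""
    pre_acts = torch.relu(x @ W_enc + b_enc)              # (B, D) latent pre-activations
    results = {}
    for k in l0_values:
        vals, idx = torch.topk(pre_acts, k, dim=-1)       # keep only the top-k latents
        recon = torch.einsum('bk,bkh->bh', vals, decoder[idx]) + b_dec
        results[k] = fvu(x, recon)
    return results
```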
Applications:
The HierarchicalTopK approach is particularly useful for:
- Flexible Interpretability Analysis: Researchers can use a single model to explore features at various levels of detail (sparsity) without retraining.
- Resource-Constrained Deployment: A single SAE can be adapted to different computational budgets at inference time by simply choosing a different ℓ0.
- Understanding Feature Hierarchies: The progressive nature of the reconstruction encourages learning features that build upon each other, potentially revealing hierarchical structures in the learned representations.
Limitations:
The authors acknowledge two main limitations:
- Evaluation Scope: Experiments were limited to the Gemma-2 2B model and the FineWeb dataset. Further testing is needed on other model architectures and data.
- Interpretability Metrics: The paper relied on automated metrics for interpretability. Human studies would be beneficial to validate the semantic alignment of the learned features.
In conclusion, HierarchicalTopK offers a more flexible, efficient, and interpretable method for training SAEs. By optimizing for a range of sparsity levels simultaneously, it allows a single model to achieve strong performance and interpretability across these levels, reducing the need for multiple specialized models.