Incorporating Hierarchical Semantics in Sparse Autoencoder Architectures (2506.01197v1)

Published 1 Jun 2025 in cs.CL, cs.AI, and cs.LG

Abstract: Sparse dictionary learning (and, in particular, sparse autoencoders) attempts to learn a set of human-understandable concepts that can explain variation on an abstract space. A basic limitation of this approach is that it neither exploits nor represents the semantic relationships between the learned concepts. In this paper, we introduce a modified SAE architecture that explicitly models a semantic hierarchy of concepts. Application of this architecture to the internal representations of LLMs shows both that semantic hierarchy can be learned, and that doing so improves both reconstruction and interpretability. Additionally, the architecture leads to significant improvements in computational efficiency.

Summary

  • The paper introduces a hierarchical SAE that combines a top-level encoder with expert-specific low-level autoencoders to enhance feature interpretability.
  • It employs a Mixture-of-Experts activation with projection matrices to reduce feature splitting and achieve computational efficiency.
  • Empirical results demonstrate improved reconstruction performance and reduced feature absorption compared to standard sparse autoencoders.

This paper introduces a Hierarchical Sparse Autoencoder (H-SAE) architecture designed to improve the interpretability and reconstruction performance of Sparse Autoencoders (SAEs) by explicitly modeling the hierarchical structure inherent in semantic concepts. Standard SAEs learn a flat set of features, which can lead to "feature splitting" (a single concept represented by multiple specialized features) and a trade-off between reconstruction accuracy and feature interpretability. The H-SAE aims to mitigate these issues.

The core idea is inspired by findings that LLMs represent categorical concepts with a parent feature (indicating concept activation) and a low-rank subspace containing child features (specific instances of the concept) (Shafayat et al., 16 Mar 2024). The H-SAE architecture mirrors this structure:

  1. Top-Level SAE: A standard SAE with a relatively small number of features, designed to capture high-level concepts.
  2. Projection Matrices: For each feature (expert) in the top-level SAE, there are learnable down-projection ($\mathbf{\Pi}^{\text{down}}_j$) and up-projection ($\mathbf{\Pi}^{\text{up}}_j$) matrices. These map the input to a lower-dimensional subspace associated with the high-level concept and then back to the original space.
  3. Low-Level SAEs (Experts): Each high-level feature has an associated low-level SAE that operates on the projected low-dimensional subspace. These low-level SAEs learn finer-grained features (sub-latents) specific to the activated high-level concept.
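
To make the dimensions concrete, the sketch below allocates parameters with the shapes implied by this description. The helper name `init_hsae_params`, the dictionary layout, and the random scaling are illustrative assumptions, not the paper's initialization scheme; `d` is the input dimension, `m_top` the number of top-level features, `m_low` the number of sub-latents per expert, and `s` the expert subspace dimension.

import numpy as np

def init_hsae_params(d, m_top, m_low, s, seed=0):
    """Allocate H-SAE parameters with the shapes implied by the architecture description."""
    rng = np.random.default_rng(seed)
    return {
        "E_top":     rng.normal(size=(m_top, d)) / np.sqrt(d),            # top-level encoder
        "D_top":     rng.normal(size=(d, m_top)) / np.sqrt(m_top),        # top-level decoder (columns are the features d_j)
        "Pi_down":   rng.normal(size=(m_top, s, d)) / np.sqrt(d),         # per-expert down-projections
        "Pi_up":     rng.normal(size=(m_top, d, s)) / np.sqrt(s),         # per-expert up-projections
        "E_experts": rng.normal(size=(m_top, m_low, s)) / np.sqrt(s),     # per-expert (low-level) encoders
        "D_experts": rng.normal(size=(m_top, s, m_low)) / np.sqrt(m_low), # per-expert (low-level) decoders
    }

These shapes line up with the pseudocode given later (Algorithm 1), where `Pi_down[j]`, `E_experts[j]`, and so on index the parameters of expert j.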

A key aspect is the Mixture-of-Experts (MoE) style activation: a low-level SAE is only activated if its corresponding high-level feature is among the top-$k$ activated features. This respects the conceptual hierarchy (e.g., "corgi" can only be active if "dog" is active) and significantly improves computational efficiency. The low-level SAEs use a $\text{TopK}_1$ operation, meaning only a single sub-latent is chosen per activated expert.

The forward pass for the H-SAE is given by:

$$\text{H-SAE}(\mathbf{x}) = \sum_{j \in \text{TopKIndices}_k} \left( z_j \mathbf{d}_j + \mathbf{\Pi}_j^{\text{up}} \, \text{SAE}^j_1\!\left(\mathbf{\Pi}_j^{\text{down}} \mathbf{x}\right) \right)$$

where $\mathbf{x}$ is the input, $\text{TopKIndices}_k$ are the indices of the top $k$ activated high-level features, $z_j$ is the activation of the $j$-th high-level feature, $\mathbf{d}_j$ is the $j$-th high-level decoder vector (feature), and $\text{SAE}^j_1$ is the expert-specific low-level autoencoder for the $j$-th high-level feature.

The training objective is:

$$\mathcal{L} = \mathcal{L}_\text{recon} + \lambda_1 \mathcal{L}_\text{ortho} + \lambda_2 \mathcal{L}_\text{sparse}$$

where:

  • $\mathcal{L}_\text{recon} = \|\mathbf{x} - \text{H-SAE}(\mathbf{x})\|_2^2 + \beta\|\mathbf{x} - \hat{\mathbf{x}}^{\text{high}}\|_2^2$: The reconstruction loss includes a term for the overall reconstruction and a term (weighted by $\beta = 0.1$) specifically for the reconstruction from only the top-level SAE. This encourages the top-level features to be meaningful on their own. Here $\hat{\mathbf{x}}^{\text{high}} = \mathbf{D}\,\text{TopK}_k(\text{LeakyReLU}_\alpha(\mathbf{E}(\mathbf{x}-\mathbf{b})))$.
  • $\mathcal{L}_\text{ortho} = \frac{\|\mathbf{E}\mathbf{D} - \text{diag}(\mathbf{E}\mathbf{D})\|_F}{m_{\text{top}}^2 - m_{\text{top}}}$: A bi-orthogonality penalty on the top-level encoder ($\mathbf{E}$) and decoder ($\mathbf{D}$) matrices to discourage semantic redundancy and reduce dead features. $m_{\text{top}}$ is the number of top-level features.
  • $\mathcal{L}_\text{sparse}$: An $\ell_1$ penalty on latent activations (both top- and low-level) outside the top-$k$ to encourage further specialization.

Algorithm 1 details the forward pass and loss computation; a NumPy sketch of it follows:

import numpy as np


def leaky_relu(x, alpha=0.01):
    """LeakyReLU_alpha used for both encoder levels (alpha value is a placeholder)."""
    return np.where(x > 0, x, alpha * x)


def get_top_k(activations, k):
    """Keep the k largest activations, zero out the rest; return (sparse vector, indices)."""
    indices = np.argsort(activations)[-k:]
    sparse_activations = np.zeros_like(activations)
    sparse_activations[indices] = activations[indices]
    return sparse_activations, indices


def forward_pass(x, E_top, D_top, top_k, E_experts, D_experts, Pi_down, Pi_up):
    # High-level encoding (x is assumed to already be centered, i.e. x - b)
    encoded_top_raw = leaky_relu(E_top @ x)
    z_top, top_k_indices = get_top_k(encoded_top_raw, top_k)

    # High-level reconstruction from the top-level dictionary alone
    x_hat_high = D_top @ z_top

    # Low-level (expert) reconstruction, accumulated over the k active experts
    x_hat_low = np.zeros_like(x)
    z_low_raw = []     # raw expert activations (needed for the sparsity penalty)
    z_low_sparse = []  # TopK_1-sparsified expert activations

    for j in top_k_indices:
        # Project the input into expert j's low-dimensional subspace
        x_sub_j = Pi_down[j] @ x

        # Expert encoding; SAE^j_1 keeps only the single largest sub-latent (TopK_1)
        encoded_low_raw_j = leaky_relu(E_experts[j] @ x_sub_j)
        z_j_low_sparse, _ = get_top_k(encoded_low_raw_j, 1)
        z_low_raw.append(encoded_low_raw_j)
        z_low_sparse.append(z_j_low_sparse)

        # Reconstruct in the subspace, project back, and accumulate
        x_hat_low += Pi_up[j] @ (D_experts[j] @ z_j_low_sparse)

    # Combined reconstruction
    x_hat = x_hat_high + x_hat_low
    return x_hat, x_hat_high, encoded_top_raw, z_top, z_low_raw, z_low_sparse


def compute_loss(x, x_hat, x_hat_high, z_top_raw, z_top, z_low_raw, z_low_sparse,
                 E_top, D_top, beta, lambda1, lambda2):
    # Reconstruction loss: full reconstruction plus a top-level-only term (weighted by beta)
    l_recon = np.sum((x - x_hat) ** 2) + beta * np.sum((x - x_hat_high) ** 2)

    # Sparsity loss: L1 on activations outside the retained top-k / TopK_1 sets
    l_sparse = np.sum(np.abs(z_top_raw)) - np.sum(np.abs(z_top))
    for raw_j, kept_j in zip(z_low_raw, z_low_sparse):
        l_sparse += np.sum(np.abs(raw_j)) - np.sum(np.abs(kept_j))

    # Bi-orthogonality penalty on the off-diagonal of the encoder/decoder product
    ED = E_top @ D_top
    off_diag = ED - np.diag(np.diag(ED))
    m_top = D_top.shape[1]  # number of top-level features
    l_ortho = np.linalg.norm(off_diag, 'fro') / (m_top ** 2 - m_top)

    return l_recon + lambda1 * l_ortho + lambda2 * l_sparse
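
As a quick sanity check, these functions can be exercised together with the `init_hsae_params` sketch given earlier; the toy dimensions here are illustrative, while the top-$k$ of 32 and the loss weights match the hyperparameters reported below.

params = init_hsae_params(d=256, m_top=512, m_low=16, s=4)  # toy sizes, not the paper's
x = np.random.default_rng(1).normal(size=256)
x /= np.linalg.norm(x)                                      # inputs are normalized to unit norm

x_hat, x_hat_high, z_top_raw, z_top, z_low_raw, z_low_sparse = forward_pass(
    x, params["E_top"], params["D_top"], 32,
    params["E_experts"], params["D_experts"], params["Pi_down"], params["Pi_up"])

loss = compute_loss(x, x_hat, x_hat_high, z_top_raw, z_top, z_low_raw, z_low_sparse,
                    params["E_top"], params["D_top"], beta=0.1, lambda1=0.1, lambda2=0.001)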

Experimental Setup and Results:

  • Data: 1 billion residual stream vectors from layer 20 of Gemma 2-2B, extracted from Wikipedia articles. Vectors are normalized to unit norm.
  • Baseline: TopK SAE.
  • Reconstruction: H-SAE shows significantly better reconstruction performance than standard SAEs: lower unexplained variance ($1 -$ explained variance) and lower LLM cross-entropy loss when the reconstructed activations are substituted back into the model (this metric and the cross-lingual measure below are sketched after this list). For example, an H-SAE with 8k top-level features and 64 sub-latents per expert performs comparably to a standard SAE with 32k features, but with 1/4 of the compute cost for the top-level.
  • Interpretability:
    • Qualitative: Visualizations (Figure 1, 2, 3, 7, 8) show H-SAE learns meaningful hierarchical features (e.g., "marriage" high-level, "divorce," "engagement" low-level; "airports" high-level, "US airport," "airport size" low-level).
    • Feature Absorption: H-SAE shows less feature absorption (undesirable merging of distinct concepts into one feature or splitting of one concept) on the SAEBench first-letter classification task. The H-SAE architecture had a lower "Mean Absorption Fraction Score" (Figure 5a).
    • Cross-Lingual Redundancy: H-SAE activates more similar sets of features for the same token in different languages (English, French, Spanish, German), indicating less redundancy and better composability (Figure 5b). It achieved lower mean set differences.
  • Computational Efficiency: Due to the sparse activation of experts, the H-SAE adds negligible computational overhead compared to a standard SAE with the same number of top-level features, while offering a much larger effective dictionary size ($m_{\text{top}} \times m_{\text{low}}$).
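
A hedged sketch of the two quantitative measures referenced in this list (the function names are illustrative, and the paper's exact normalization of the cross-lingual set difference may differ):

import numpy as np

def unexplained_variance(X, X_hat):
    """1 - explained variance: the fraction of input variance the reconstruction misses."""
    residual = np.sum((X - X_hat) ** 2)
    total = np.sum((X - X.mean(axis=0)) ** 2)
    return residual / total

def mean_set_difference(active_sets):
    """Mean pairwise symmetric-difference size between the sets of active top-level
    features for the same token across languages (smaller = less redundancy)."""
    diffs = [len(a ^ b)
             for i, a in enumerate(active_sets)
             for b in active_sets[i + 1:]]
    return sum(diffs) / len(diffs)

# Hypothetical active-feature index sets for one token in four languages
print(mean_set_difference([{3, 17, 42}, {3, 17, 42}, {3, 17, 58}, {3, 42, 58}]))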

Implementation Considerations:

  • Implemented in JAX and Equinox.
  • Trained with a batch size of 32,512 and top-k of 32 for high-level features.
  • Subspace dimension ($s$) was 4 for 16 sub-latents per expert, and 8 otherwise.
  • $\lambda_1$ (orthogonality) = 0.1, $\beta$ (top-level reconstruction) = 0.1, $\lambda_2$ ($\ell_1$ sparsity) = 0.001.
  • Adam optimizer, learning rate $5 \cdot 10^{-4}$ with warmup and cosine decay.
  • Ablation studies on token unembeddings (Appendix B) suggest that whitening the input (multiplying by the inverse square root of the covariance matrix; a minimal sketch follows this list) is crucial for learning meaningful features in that context, aligning with the concept of a "causal inner product." The orthogonality and $\ell_1$ regularizers were not strictly necessary for interpretability in these ablations but were kept for practical benefits such as reducing dead latents.
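
For the whitening step mentioned in the ablation, a minimal sketch, assuming the covariance is estimated empirically from the unembedding vectors (the `eps` regularizer is an added assumption for numerical stability, not from the paper):

import numpy as np

def whiten(X, eps=1e-6):
    """Multiply centered vectors by the inverse square root of their covariance matrix."""
    Xc = X - X.mean(axis=0)
    cov = Xc.T @ Xc / (len(Xc) - 1)
    eigvals, eigvecs = np.linalg.eigh(cov)  # cov = V diag(w) V^T
    inv_sqrt = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return Xc @ inv_sqrt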

Limitations and Future Work:

  • While improved, results are not perfect: some hard-to-interpret features remain, and reconstruction error is not fully eliminated.
  • The paper suggests exploring non-Euclidean reconstruction objectives or more sophisticated objectives from causal representation learning.

In summary, the H-SAE architecture provides a practical method to improve SAEs by incorporating semantic hierarchy. This leads to better reconstruction, improved interpretability (less feature splitting/absorption, more composable features), and significant computational efficiency, allowing for effectively larger and more fine-grained dictionaries.
