- The paper introduces a hierarchical SAE that combines a top-level encoder with expert-specific low-level autoencoders to enhance feature interpretability.
- It employs a Mixture-of-Experts activation with projection matrices to reduce feature splitting and achieve computational efficiency.
- Empirical results demonstrate improved reconstruction performance and reduced feature absorption compared to standard sparse autoencoders.
This paper introduces a Hierarchical Sparse Autoencoder (H-SAE) architecture designed to improve the interpretability and reconstruction performance of Sparse Autoencoders (SAEs) by explicitly modeling the hierarchical structure inherent in semantic concepts. Standard SAEs learn a flat set of features, which can lead to "feature splitting" (a single concept represented by multiple specialized features) and a trade-off between reconstruction accuracy and feature interpretability. The H-SAE aims to mitigate these issues.
The core idea is inspired by findings that LLMs represent categorical concepts with a parent feature (indicating concept activation) and a low-rank subspace containing child features (specific instances of the concept) (Shafayat et al., 16 Mar 2024). The H-SAE architecture mirrors this structure:
- Top-Level SAE: A standard SAE with a relatively small number of features, designed to capture high-level concepts.
- Projection Matrices: For each feature (expert) in the top-level SAE, there are learnable down-projection ($\Pi_j^{\mathrm{down}}$) and up-projection ($\Pi_j^{\mathrm{up}}$) matrices. These map the input to a lower-dimensional subspace associated with the high-level concept and then back to the original space.
- Low-Level SAEs (Experts): Each high-level feature has an associated low-level SAE that operates on the projected low-dimensional subspace. These low-level SAEs learn finer-grained features (sub-latents) specific to the activated high-level concept.
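For concreteness, here is a sketch of the array shapes these components imply; the variable names mirror the pseudocode in Algorithm 1 below and are illustrative assumptions, not the paper's code:

```python
from typing import Dict, Tuple

def hsae_param_shapes(d: int, m_top: int, s: int, m_low: int) -> Dict[str, Tuple[int, ...]]:
    """Array shapes for an H-SAE with input dimension d, m_top top-level features,
    expert subspace dimension s, and m_low sub-latents per expert."""
    return {
        "E_top": (m_top, d),             # top-level encoder
        "D_top": (d, m_top),             # top-level decoder; column j is the feature d_j
        "Pi_down": (m_top, s, d),        # per-expert down-projections Pi_j^down
        "Pi_up": (m_top, d, s),          # per-expert up-projections Pi_j^up
        "E_experts": (m_top, m_low, s),  # low-level (expert) encoders
        "D_experts": (m_top, s, m_low),  # low-level (expert) decoders
    }

# e.g. hsae_param_shapes(d=2304, m_top=8192, s=8, m_low=64) for the 8k-expert
# configuration reported in the paper (assuming d = 2304, the Gemma 2-2B residual-stream width).
```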
A key aspect is the Mixture-of-Experts (MoE) style activation: a low-level SAE is only activated if its corresponding high-level feature is among the top-k activated features. This respects the conceptual hierarchy (e.g., "corgi" can only be active if "dog" is active) and significantly improves computational efficiency. The low-level SAEs use a $\mathrm{TopK}_1$ operation, meaning only a single sub-latent is chosen per activated expert.
The forward pass for the H-SAE is given by:
$$\mathrm{H\text{-}SAE}(x) = \sum_{j \in \mathrm{TopKIndices}_k} \Big( z_j d_j + \Pi_j^{\mathrm{up}}\, \mathrm{SAE}_1^j\big(\Pi_j^{\mathrm{down}} x\big) \Big)$$
where $x$ is the input, $\mathrm{TopKIndices}_k$ contains the indices of the top-$k$ activated high-level features, $z_j$ is the activation of the $j$-th high-level feature, $d_j$ is the $j$-th high-level decoder vector (feature), and $\mathrm{SAE}_1^j$ is the expert-specific low-level autoencoder associated with the $j$-th high-level feature.
The training objective is:
$$\mathcal{L} = \mathcal{L}_{\mathrm{recon}} + \lambda_1 \mathcal{L}_{\mathrm{ortho}} + \lambda_2 \mathcal{L}_{\mathrm{sparse}}$$
where:
- $\mathcal{L}_{\mathrm{recon}} = \|x - \mathrm{H\text{-}SAE}(x)\|_2^2 + \beta \|x - \hat{x}_{\mathrm{high}}\|_2^2$, with $\hat{x}_{\mathrm{high}} = D\,\mathrm{TopK}_k(\mathrm{LeakyReLU}_\alpha(E(x - b)))$: the reconstruction loss combines the overall reconstruction error with a term (weighted by $\beta = 0.1$) for the reconstruction from the top-level SAE alone, which encourages the top-level features to be meaningful on their own.
- $\mathcal{L}_{\mathrm{ortho}} = \frac{1}{m_{\mathrm{top}}^2 - m_{\mathrm{top}}}\,\|E D - \mathrm{diag}(E D)\|_F$: a bi-orthogonality penalty on the top-level encoder ($E$) and decoder ($D$) matrices that discourages semantic redundancy and reduces dead features; $m_{\mathrm{top}}$ is the number of top-level features.
- $\mathcal{L}_{\mathrm{sparse}}$: an $\ell_1$ penalty on latent activations (both top- and low-level) outside the top-k, encouraging further specialization.
Algorithm 1 details the forward pass and loss computation:
```python
import numpy as np

def leaky_relu(x, alpha):
    """Leaky ReLU with negative slope alpha."""
    return np.where(x > 0, x, alpha * x)

def get_top_k(activations, k):
    """Keep only the k largest activations; return the sparse vector and the kept indices."""
    indices = np.argsort(activations)[-k:]
    sparse_activations = np.zeros_like(activations)
    sparse_activations[indices] = activations[indices]
    return sparse_activations, indices

def forward_pass(x, E_top, D_top, top_k, E_experts, D_experts, Pi_down, Pi_up, alpha):
    # High-level encoding (x is assumed to be pre-centered, i.e. x - b)
    encoded_top_raw = leaky_relu(E_top @ x, alpha)
    top_k_activations, top_k_indices = get_top_k(encoded_top_raw, top_k)

    # High-level reconstruction from the sparse top-level code
    # (dense matmul; a sparse decode over top_k_indices is equivalent)
    x_hat_high = D_top @ top_k_activations

    # Low-level (expert) reconstruction, accumulated over the active experts only
    x_hat_low = np.zeros_like(x)
    z_low_levels = []  # expert codes, kept for the sparsity loss

    for j in top_k_indices:
        # Project the input into expert j's low-dimensional subspace
        x_sub_j = Pi_down[j] @ x

        # Expert encoding; each expert is a TopK_1 SAE, so a single sub-latent is kept
        encoded_low_raw_j = leaky_relu(E_experts[j] @ x_sub_j, alpha)
        z_j_low_sparse, _ = get_top_k(encoded_low_raw_j, 1)
        z_low_levels.append(z_j_low_sparse)

        # Reconstruct in the subspace, then project back and accumulate
        x_hat_sub_j = D_experts[j] @ z_j_low_sparse
        x_hat_low += Pi_up[j] @ x_hat_sub_j

    # Combined reconstruction
    x_hat = x_hat_high + x_hat_low
    return x_hat, x_hat_high, top_k_activations, z_low_levels, top_k_indices

def compute_loss(x, x_hat, x_hat_high, z_top, z_low_levels, E_top, D_top, beta, lambda1, lambda2):
    # Reconstruction loss: full reconstruction plus a top-level-only term (weighted by beta)
    l_recon_total = np.sum((x - x_hat) ** 2)
    l_recon_top = np.sum((x - x_hat_high) ** 2)
    l_recon = l_recon_total + beta * l_recon_top

    # Sparsity loss: L1 on the top-level code and on each active expert's code
    # (the paper applies the penalty to activations outside the top-k)
    l_sparse = np.sum(np.abs(z_top))
    for z_j_low in z_low_levels:
        l_sparse += np.sum(np.abs(z_j_low))

    # Bi-orthogonality penalty on the top-level encoder/decoder pair
    ED_prod = E_top @ D_top
    diag_ED = np.diag(np.diag(ED_prod))
    m_top = D_top.shape[1]  # number of top-level features
    l_ortho = np.linalg.norm(ED_prod - diag_ED, 'fro') ** 2 / (m_top ** 2 - m_top)

    return l_recon + lambda1 * l_ortho + lambda2 * l_sparse
```
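For concreteness, a toy invocation of the two functions above; the dimensions, random initialization, and LeakyReLU slope below are illustrative only, not the paper's configuration (the regularization weights match the values reported later):

```python
d, m_top, m_low, s, top_k = 256, 512, 16, 4, 8   # toy sizes for a quick smoke test
rng = np.random.default_rng(0)

E_top = rng.normal(size=(m_top, d));            D_top = rng.normal(size=(d, m_top))
Pi_down = rng.normal(size=(m_top, s, d));       Pi_up = rng.normal(size=(m_top, d, s))
E_experts = rng.normal(size=(m_top, m_low, s)); D_experts = rng.normal(size=(m_top, s, m_low))

x = rng.normal(size=d)
x /= np.linalg.norm(x)   # the paper normalizes residual-stream vectors to unit norm

x_hat, x_hat_high, z_top, z_lows, active = forward_pass(
    x, E_top, D_top, top_k, E_experts, D_experts, Pi_down, Pi_up, alpha=0.01)
loss = compute_loss(x, x_hat, x_hat_high, z_top, z_lows, E_top, D_top,
                    beta=0.1, lambda1=0.1, lambda2=0.001)
print(loss, active)      # scalar loss and the indices of the k active experts
```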
Experimental Setup and Results:
- Data: 1 billion residual stream vectors from layer 20 of Gemma 2-2B, extracted from Wikipedia articles. Vectors are normalized to unit norm.
- Baseline: TopK SAE.
- Reconstruction: H-SAE shows significantly better reconstruction performance than standard SAEs (lower $1 -$ explained variance and lower LLM cross-entropy loss when the model is run on the reconstructed activations). For example, an H-SAE with 8k top-level features and 64 sub-latents per expert performs comparably to a standard SAE with 32k features, at roughly one quarter of the top-level compute cost.
- Interpretability:
- Qualitative: Visualizations (Figure 1, 2, 3, 7, 8) show H-SAE learns meaningful hierarchical features (e.g., "marriage" high-level, "divorce," "engagement" low-level; "airports" high-level, "US airport," "airport size" low-level).
- Feature Absorption: H-SAE exhibits less feature absorption (where a general feature fails to fire on some inputs because more specific latents absorb its role) on the SAEBench first-letter classification task, achieving a lower Mean Absorption Fraction Score (Figure 5a).
- Cross-Lingual Redundancy: H-SAE activates more similar sets of features for the same token in different languages (English, French, Spanish, German), indicating less redundancy and better composability (Figure 5b). It achieved lower mean set differences.
- Computational Efficiency: Due to the sparse activation of experts, the H-SAE adds negligible computational overhead compared to a standard SAE with the same number of top-level features, while offering a much larger effective dictionary size ($m_{\mathrm{top}} \times m_{\mathrm{low}}$); a rough cost comparison is sketched below.
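To make the efficiency claim concrete, here is a back-of-the-envelope sketch counting dense matrix-vector multiply-adds per token; the accounting and the assumed residual-stream width (2304 for Gemma 2-2B) are illustrative, not the paper's measurement:

```python
def flat_sae_cost(d: int, m: int) -> int:
    """Approximate multiply-adds per token for a flat SAE: encode (m x d) plus dense decode (d x m)."""
    return 2 * m * d

def hsae_cost(d: int, m_top: int, s: int, m_low: int, k: int) -> int:
    """Top-level encode/decode plus, for each of the k active experts, a down-projection,
    a small expert encode/decode, and an up-projection."""
    top_level = 2 * m_top * d
    per_expert = s * d + 2 * m_low * s + d * s
    return top_level + k * per_expert

d, m_top, s, m_low, k = 2304, 8192, 8, 64, 32                  # assumed width; paper's 8k config
print("flat 32k SAE     :", flat_sae_cost(d, 4 * m_top))       # ~151M multiply-adds
print("H-SAE, 8k experts:", hsae_cost(d, m_top, s, m_low, k))  # ~39M, roughly a quarter
print("effective dictionary:", m_top * m_low)                  # m_top * m_low = 524288 fine-grained latents
```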
Implementation Considerations:
- Implemented in JAX and Equinox.
- Trained with a batch size of 32,512 and top-k of 32 for high-level features.
- The subspace dimension $s$ was 4 when using 16 sub-latents per expert, and 8 otherwise.
- $\lambda_1$ (orthogonality) = 0.1, $\beta$ (top-level reconstruction) = 0.1, $\lambda_2$ ($\ell_1$ sparsity) = 0.001.
- Adam optimizer with a learning rate of $5 \cdot 10^{-4}$, warmup, and cosine decay.
- Ablation studies on token unembeddings (Appendix B) suggest that whitening the input (multiplying by the inverse square root of the covariance matrix) is crucial for learning meaningful features in that context, aligning with the concept of a "causal inner product." The orthogonality and ℓ1 regularizers were not strictly necessary for interpretability in these ablations but were kept for practical benefits like reducing dead latents.
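A minimal sketch of the whitening transform described in that ablation, using an eigendecomposition of the empirical covariance to form the inverse matrix square root; the epsilon regularizer and function name are assumptions:

```python
import numpy as np

def whiten(X: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Whiten rows of X (n_samples x d): subtract the mean, then multiply by C^{-1/2},
    where C is the empirical covariance. eps stabilizes near-zero eigenvalues."""
    Xc = X - X.mean(axis=0)
    C = Xc.T @ Xc / (X.shape[0] - 1)
    eigvals, eigvecs = np.linalg.eigh(C)  # C is symmetric positive semi-definite
    C_inv_sqrt = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return Xc @ C_inv_sqrt
```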
Limitations and Future Work:
- While improved, the results are not perfect: some hard-to-interpret features remain, and reconstruction error is not fully eliminated.
- The paper suggests exploring non-Euclidean reconstruction objectives or more sophisticated objectives from causal representation learning.
In summary, the H-SAE architecture provides a practical method to improve SAEs by incorporating semantic hierarchy. This leads to better reconstruction, improved interpretability (less feature splitting/absorption, more composable features), and significant computational efficiency, allowing for effectively larger and more fine-grained dictionaries.