- The paper introduces a novel dual-SAE approach that sparsifies the Jacobian mapping between input and output activations in MLPs.
- It leverages efficient k×k sub-matrix computations to reduce the cost of Jacobian evaluations in GPT-2 style architectures.
- Empirical results show that the induced computational (Jacobian) sparsity comes at little cost to reconstruction quality, while the interpretability of individual features remains comparable to standard SAEs.
This paper introduces Jacobian Sparse Autoencoders (JSAEs), a novel method designed to understand not just the representations within LLMs but also the computations performed by them, specifically within Multi-Layer Perceptrons (MLPs). Traditional Sparse Autoencoders (SAEs) excel at finding sparse, interpretable features in LLM activations, but they don't inherently sparsify the computational graph connecting these features across layers. JSAEs address this by training a pair of SAEs—one on the input and one on the output of an MLP—and adding a crucial term to the loss function that encourages sparsity in the Jacobian of the function mapping the input SAE's latent activations to the output SAE's latent activations.
Core Idea and Implementation
The setup involves two k-sparse SAEs:
- Input SAE ($\text{SAE}_x$): Encodes the MLP input $x$ into a sparse latent vector $s_x$, and decodes $s_x$ back to $\hat{x}$.
- $s_x = \text{encoder}_x(x) = \phi(W_{e,x}\, x + b_{e,x})$
- $\hat{x} = \text{decoder}_x(s_x) = W_{d,x}\, s_x + b_{d,x}$
- Output SAE ($\text{SAE}_y$): Encodes the MLP output $y$ (where $y = f(x)$ for an MLP $f$) into a sparse latent vector $s_y$, and decodes $s_y$ back to $\hat{y}$.
- $s_y = \text{encoder}_y(y) = \phi(W_{e,y}\, y + b_{e,y})$
- $\hat{y} = \text{decoder}_y(s_y) = W_{d,y}\, s_y + b_{d,y}$
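A minimal PyTorch sketch of this setup, with illustrative class names and sizes (the paper's implementation builds on SAELens, whose API differs; here $\phi$ is taken to be ReLU followed by TopK):

```python
import torch
import torch.nn as nn

def topk(latents: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest activations per token and zero out the rest."""
    vals, idx = latents.topk(k, dim=-1)
    return torch.zeros_like(latents).scatter(-1, idx, vals)

class TopKSAE(nn.Module):
    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.enc = nn.Linear(d_model, d_sae)  # W_e, b_e
        self.dec = nn.Linear(d_sae, d_model)  # W_d, b_d
        self.k = k

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # phi here is ReLU followed by TopK, so the latents are k-sparse
        return topk(torch.relu(self.enc(x)), self.k)

    def decode(self, s: torch.Tensor) -> torch.Tensor:
        return self.dec(s)

# SAE_x reconstructs the MLP input, SAE_y the MLP output (illustrative sizes)
d_model, expansion, k = 768, 32, 32
sae_x = TopKSAE(d_model, expansion * d_model, k)
sae_y = TopKSAE(d_model, expansion * d_model, k)
```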
The JSAE focuses on the function $f_s : S_X \to S_Y$ which describes the MLP's operation in the sparse bases learned by the SAEs:
$f_s = \text{encoder}_y \circ f \circ \text{decoder}_x \circ \text{TopK}$
The TopK activation function is applied to the input SAE's latents $s_x$. For $k$-sparse SAEs, where $s_x$ already has only $k$ non-zero elements, $\text{TopK}(s_x) = s_x$. However, its inclusion is vital for efficiently computing the Jacobian.
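Under the same assumptions as the sketch above, $f_s$ is simply this composition applied to the input-SAE latents (the `mlp` module standing in for $f$ is hypothetical):

```python
def f_s(s_x: torch.Tensor, mlp: nn.Module,
        sae_x: TopKSAE, sae_y: TopKSAE) -> torch.Tensor:
    """f_s = encoder_y ∘ f ∘ decoder_x ∘ TopK, acting on input-SAE latents."""
    s_x = topk(s_x, sae_x.k)    # identity for already k-sparse latents, but it
                                # restricts the Jacobian to a k x k block per token
    x_hat = sae_x.decode(s_x)   # back to the model's activation basis
    y = mlp(x_hat)              # run the MLP f
    return sae_y.encode(y)      # output-SAE latents s_y
```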
The loss function for JSAEs is:
$\mathcal{L} = \text{MSE}(x, \hat{x}) + \text{MSE}(y, \hat{y}) + \frac{\lambda}{k^2} \sum_{i,j} \left| J_{f_s, i, j} \right|$
where $J_{f_s, i, j} = \frac{\partial f_s(s_x)_i}{\partial s_{x, j}}$ is an element of the Jacobian matrix of $f_s$, and $\lambda$ is a hyperparameter controlling the strength of the Jacobian sparsity penalty. The factor $k^2$ normalizes the penalty, as there are at most $k^2$ non-zero elements in the relevant part of the Jacobian due to the TopK activation.
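A minimal sketch of this loss, assuming the per-token $k \times k$ active block of the Jacobian (`jac_kk`, computed via the efficient formula described next) is already available:

```python
import torch.nn.functional as F

def jsae_loss(x, x_hat, y, y_hat, jac_kk: torch.Tensor,
              lam: float, k: int) -> torch.Tensor:
    """MSE reconstruction terms plus the normalised L1 Jacobian penalty."""
    recon = F.mse_loss(x_hat, x) + F.mse_loss(y_hat, y)
    # sum |J_ij| over the (at most) k^2 non-zero entries, averaged over the batch
    jac_l1 = jac_kk.abs().sum(dim=(-2, -1)).mean()
    return recon + (lam / k**2) * jac_l1
```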
A key technical contribution is making the Jacobian calculation tractable. A naive computation would produce an excessively large tensor of shape $\text{batch size} \times \text{num. output features} \times \text{num. input features}$, where the feature counts are the (large) SAE dictionary sizes. The authors leverage two insights:
- Effective Jacobian Size Reduction: Since sx (after TopK) has only k active (non-zero) features, and we are interested in its effect on the k active features of sy, we only need to consider a k×k sub-matrix of the Jacobian per token. The TopK function ensures that derivatives with respect to non-active input features are zero.
- Efficient Formula: For GPT-2 style MLPs, the Jacobian of fs (specifically, the k×k active part) can be computed efficiently using a derived formula involving three matrix multiplications and pointwise operations, avoiding computationally expensive auto-differentiation for each of the k input features.
$J_{f_s}^{(\text{active})} = \left( W_{e,y}^{(\text{active})} W_2 \right) \cdot \text{diag}\!\left( \phi'_{\text{MLP}}(z) \right) \cdot \left( W_1 W_{d,x}^{(\text{active})} \right)$
where $W_1, W_2$ are the MLP's input and output projection weights, $W_{d,x}^{(\text{active})}$ and $W_{e,y}^{(\text{active})}$ are the decoder columns / encoder rows corresponding to the active latents, and $z$ is the pre-activation input to the MLP's nonlinearity.
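A minimal sketch of this formula, assuming a GPT-2 style MLP $y = W_2\, \phi_{\text{MLP}}(W_1 x + b_1) + b_2$ with an exact-GELU nonlinearity; tensor names and shape conventions are illustrative:

```python
def gelu_grad(z: torch.Tensor) -> torch.Tensor:
    """Elementwise derivative of exact GELU: Phi(z) + z * N(z; 0, 1)."""
    normal = torch.distributions.Normal(0.0, 1.0)
    return normal.cdf(z) + z * normal.log_prob(z).exp()

def active_jacobian(
    W1: torch.Tensor,           # [d_hidden, d_model]  MLP input projection
    W2: torch.Tensor,           # [d_model, d_hidden]  MLP output projection
    z: torch.Tensor,            # [batch, d_hidden]    pre-activation W1 @ x_hat + b1
    Wd_x_active: torch.Tensor,  # [batch, d_model, k]  decoder cols of active input latents
    We_y_active: torch.Tensor,  # [batch, k, d_model]  encoder rows of active output latents
) -> torch.Tensor:
    left = We_y_active @ W2     # [batch, k, d_hidden]
    right = W1 @ Wd_x_active    # [batch, d_hidden, k]
    # apply diag(phi'(z)) by broadcasting instead of materialising a diagonal matrix
    return left @ (gelu_grad(z).unsqueeze(-1) * right)  # [batch, k, k]
```

In this sketch the only dense operations are the three matrix multiplications (`We_y_active @ W2`, `W1 @ Wd_x_active`, and the final product); the nonlinearity's derivative enters purely elementwise.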
This optimized calculation makes training JSAEs only about twice as computationally expensive as training a single standard SAE.
Key Findings
- Induced Computational Sparsity: JSAEs significantly increase the sparsity of the Jacobian matrix (i.e., fewer strong connections between input and output SAE features) compared to standard SAEs, while maintaining comparable reconstruction quality and model performance.
- Hyperparameter Trade-off: There's a "sweet spot" for the Jacobian penalty coefficient λ (e.g., λ≈0.5 for Pythia-410m, λ≈1 for Pythia-70m) where substantial Jacobian sparsity is achieved with minimal degradation in SAE performance metrics.
- Interpretability: The interpretability of individual JSAE features, measured by automatic interpretability scores, remains similar to that of traditional SAEs.
- Learned Structure: JSAEs achieve significantly greater Jacobian sparsity when applied to pre-trained LLMs compared to randomly initialized LLMs. This suggests that the computational sparsity JSAEs find is a property learned during model training.
- Jacobian as a Valid Proxy: The scalar mappings from individual input SAE features to output SAE features through $f_s$ are found to be mostly linear, or else well-approximated by JumpReLU functions. This means the local Jacobian values are good indicators of the global input-output relationship, making Jacobian sparsity a strong proxy for actual computational sparsity: a near-zero Jacobian element implies a near-zero causal effect even for larger input changes. JSAEs also tend to increase the proportion of linear functions in $f_s$ (an illustrative probe is sketched just below).
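One illustrative way to probe this (not necessarily the paper's exact protocol), reusing the hypothetical `f_s` helper sketched earlier: sweep a single input latent, record a single output latent, and check how well a line or JumpReLU fits the resulting curve.

```python
def sweep_latent_pair(s_x: torch.Tensor, mlp: nn.Module, sae_x: TopKSAE,
                      sae_y: TopKSAE, j: int, i: int, n_points: int = 50):
    """Scale the j-th input latent and record the i-th output latent of f_s."""
    scales = torch.linspace(0.0, 2.0, n_points)
    outputs = []
    for a in scales:
        s = s_x.clone()
        s[..., j] = a * s_x[..., j]
        outputs.append(f_s(s, mlp, sae_x, sae_y)[..., i])
    # The closer (scales, outputs) is to a line (or a JumpReLU), the better the
    # local Jacobian entry J[i, j] predicts the global feature-to-feature effect.
    return scales, torch.stack(outputs)
```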
Practical Implications and Applications
- Understanding LLM Computation: JSAEs offer a new tool for mechanistic interpretability, moving beyond understanding static representations to understanding information flow and transformation within model components like MLPs. This can help identify sparse "circuits" – how a few input concepts (SAE features) combine to produce output concepts.
- Unsupervised Circuit Discovery: Unlike some circuit discovery methods that require task-specific supervision, JSAEs work in an unsupervised manner, potentially revealing fundamental computational pathways learned by the LLM.
- Model Editing and Safety: By identifying sparse computational pathways, JSAEs could inform more targeted model editing techniques or help identify and mitigate undesirable computations.
- Efficiency in Analysis: The tractable computation of Jacobians makes this approach scalable for analyzing large models.
Implementation Considerations:
- Computational Cost: Training a JSAE pair is roughly twice the cost of training a single standard SAE.
- Hyperparameter Tuning: The Jacobian coefficient λ needs to be tuned. The paper suggests optimal values depend on the model and layer. The k for TopK sparsity is also a key parameter.
- MLP Architecture: The efficient Jacobian derivation provided is for GPT-2 style MLPs. Extension to other architectures (e.g., GLU-based MLPs) would require deriving a new efficient Jacobian formula.
- TopK SAEs: The current implementation relies on TopK SAEs for Jacobian efficiency. If other SAE activation schemes (e.g., L1 penalized ReLU) are desired, the Jacobian computation strategy might need to be adapted.
- Software: The implementation is based on the SAELens library. The code is available in supplementary material / GitHub.
Limitations and Future Work
- Scope: Currently demonstrated on single MLP layers. Extending to multiple layers or entire models is a future step.
- MLP Types: The specific efficient Jacobian formula is for GPT-2 style MLPs. Adaptation for GLU variants (common in modern LLMs) is needed.
- Activation Functions: The reliance on TopK for efficiency might be a limitation if other SAE activation functions (which might avoid issues like high-density features) are preferred.
- Interpreting Feature Pairs: While individual features are interpretable, interpreting the connections (large Jacobian elements) between input and output features requires further methodological development.
- Linearity Assumption: While $f_s$ is found to be mostly linear, the implications of non-linear interactions for the Jacobian as a proxy for computation need continued investigation.
In conclusion, Jacobian Sparse Autoencoders provide a promising and computationally feasible approach to uncover sparse computational structures within LLM MLPs. By directly optimizing for a sparse Jacobian, JSAEs go beyond interpreting static activations to shed light on how these activations are processed, paving the way for a deeper understanding of learned computations in neural networks.