Jacobian Sparse Autoencoders: Sparsify Computations, Not Just Activations (2502.18147v2)

Published 25 Feb 2025 in cs.LG, cs.AI, and cs.CL

Abstract: Sparse autoencoders (SAEs) have been successfully used to discover sparse and human-interpretable representations of the latent activations of LLMs. However, we would ultimately like to understand the computations performed by LLMs and not just their representations. The extent to which SAEs can help us understand computations is unclear because they are not designed to "sparsify" computations in any sense, only latent activations. To solve this, we propose Jacobian SAEs (JSAEs), which yield not only sparsity in the input and output activations of a given model component but also sparsity in the computation (formally, the Jacobian) connecting them. With a naïve implementation, the Jacobians in LLMs would be computationally intractable due to their size. One key technical contribution is thus finding an efficient way of computing Jacobians in this setup. We find that JSAEs extract a relatively large degree of computational sparsity while preserving downstream LLM performance approximately as well as traditional SAEs. We also show that Jacobians are a reasonable proxy for computational sparsity because MLPs are approximately linear when rewritten in the JSAE basis. Lastly, we show that JSAEs achieve a greater degree of computational sparsity on pre-trained LLMs than on the equivalent randomized LLM. This shows that the sparsity of the computational graph appears to be a property that LLMs learn through training, and suggests that JSAEs might be more suitable for understanding learned transformer computations than standard SAEs.

Summary

  • The paper introduces a novel dual-SAE approach that sparsifies the Jacobian mapping between input and output activations in MLPs.
  • It leverages efficient k×k sub-matrix computations to reduce the cost of Jacobian evaluations in GPT-2 style architectures.
  • Empirical results show that induced computational sparsity maintains reconstruction quality while enhancing interpretability of LLMs.

This paper introduces Jacobian Sparse Autoencoders (JSAEs), a novel method designed to understand not just the representations within LLMs but also the computations performed by them, specifically within Multi-Layer Perceptrons (MLPs). Traditional Sparse Autoencoders (SAEs) excel at finding sparse, interpretable features in LLM activations, but they don't inherently sparsify the computational graph connecting these features across layers. JSAEs address this by training a pair of SAEs—one on the input and one on the output of an MLP—and adding a crucial term to the loss function that encourages sparsity in the Jacobian of the function mapping the input SAE's latent activations to the output SAE's latent activations.

Core Idea and Implementation

The setup involves two $k$-sparse SAEs (a minimal code sketch follows the list):

  1. Input SAE ($\text{SAE}_x$): Encodes the MLP input $x$ into a sparse latent vector $s_x$, and decodes $s_x$ back to $\hat{x}$.
    • $s_x = \text{encoder}_x(x) = \phi(W_e^x x + b_e^x)$
    • $\hat{x} = \text{decoder}_x(s_x) = W_d^x s_x + b_d^x$
  2. Output SAE ($\text{SAE}_y$): Encodes the MLP output $y$ (where $y = f(x)$ for an MLP $f$) into a sparse latent vector $s_y$, and decodes $s_y$ back to $\hat{y}$.
    • $s_y = \text{encoder}_y(y) = \phi(W_e^y y + b_e^y)$
    • $\hat{y} = \text{decoder}_y(s_y) = W_d^y s_y + b_d^y$

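As a concrete illustration of this setup, here is a minimal PyTorch sketch of a $k$-sparse (TopK) SAE of the kind a JSAE pairs on the MLP input and output. The class name `TopKSAE`, the dimensions, the ReLU choice for $\phi$, and the initialisation are assumptions made for illustration, not the paper's implementation.

```python
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    """Minimal k-sparse autoencoder: only the k largest latents stay non-zero."""

    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.k = k
        self.W_e = nn.Parameter(torch.randn(d_sae, d_model) * 0.01)  # encoder weights
        self.b_e = nn.Parameter(torch.zeros(d_sae))
        self.W_d = nn.Parameter(torch.randn(d_model, d_sae) * 0.01)  # decoder weights
        self.b_d = nn.Parameter(torch.zeros(d_model))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre = torch.relu(x @ self.W_e.T + self.b_e)            # phi(W_e x + b_e)
        vals, idx = pre.topk(self.k, dim=-1)                   # keep the k largest latents
        return torch.zeros_like(pre).scatter(-1, idx, vals)    # s_x (or s_y)

    def decode(self, s: torch.Tensor) -> torch.Tensor:
        return s @ self.W_d.T + self.b_d                       # W_d s + b_d

# A JSAE trains one such SAE on the MLP input and one on its output, e.g.:
# sae_x = TopKSAE(d_model, d_sae, k); sae_y = TopKSAE(d_model, d_sae, k)
```
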
The JSAE focuses on the function $f_s: S_X \to S_Y$, which describes the MLP's operation in the sparse bases learned by the SAEs:

$$f_s = \text{encoder}_y \circ f \circ \text{decoder}_x \circ \text{TopK}$$

The $\text{TopK}$ activation function is applied to the input SAE's latents $s_x$. For $k$-sparse SAEs, where $s_x$ already has only $k$ non-zero elements, $\text{TopK}(s_x) = s_x$. However, its inclusion is vital for efficiently computing the Jacobian.
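
Using the hypothetical `TopKSAE` modules from the sketch above, $f_s$ can be written as a simple composition; `mlp` stands for the MLP $f$ under study (an assumed callable), and this is a sketch rather than the paper's code:

```python
def f_s(s_x: torch.Tensor, sae_x: TopKSAE, sae_y: TopKSAE, mlp: nn.Module) -> torch.Tensor:
    """encoder_y ∘ f ∘ decoder_x ∘ TopK, built from the sketch modules above."""
    # TopK is a no-op on already k-sparse latents, but it guarantees that
    # derivatives with respect to the inactive input latents are exactly zero.
    vals, idx = s_x.topk(sae_x.k, dim=-1)
    s_x = torch.zeros_like(s_x).scatter(-1, idx, vals)
    x_hat = sae_x.decode(s_x)   # back to the model's activation space
    y = mlp(x_hat)              # run the MLP f on the reconstructed input
    return sae_y.encode(y)      # output-SAE latents s_y
```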

The loss function for JSAEs is:

$$\mathcal{L} = \text{MSE}(x, \hat{x}) + \text{MSE}(y, \hat{y}) + \frac{\lambda}{k^2} \sum_{i,j} |J_{f_s,i,j}|$$

where $J_{f_s,i,j} = \frac{\partial f_s(s_x)_i}{\partial s_{x,j}}$ is an element of the Jacobian matrix of $f_s$, and $\lambda$ is a hyperparameter controlling the strength of the Jacobian sparsity penalty. The factor $k^2$ normalizes the penalty, as there are at most $k^2$ non-zero elements in the relevant part of the Jacobian due to the $\text{TopK}$ activation.
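
A sketch of this objective, again assuming the hypothetical modules above. For clarity it computes the full per-token Jacobian with `torch.func.jacrev`, which is the naive route that does not scale; the paper's contribution is the efficient $k \times k$ computation described next.

```python
import torch.nn.functional as F
from torch.func import jacrev, vmap

def jsae_loss(x, sae_x, sae_y, mlp, lam: float) -> torch.Tensor:
    s_x, y = sae_x.encode(x), mlp(x)
    s_y = sae_y.encode(y)
    recon = F.mse_loss(sae_x.decode(s_x), x) + F.mse_loss(sae_y.decode(s_y), y)

    def f_s_single(s):  # s: (d_sae_x,), one token's input-SAE latents
        vals, idx = s.topk(sae_x.k, dim=-1)            # TopK zeroes grads of inactive latents
        s = torch.zeros_like(s).scatter(-1, idx, vals)
        return sae_y.encode(mlp(sae_x.decode(s)))

    # Naive per-token Jacobian of f_s, shape (batch, d_sae_y, d_sae_x).
    J = vmap(jacrev(f_s_single))(s_x)

    # L1 penalty on the Jacobian, normalised by the at-most-k^2 non-zero entries.
    jac_penalty = J.abs().sum(dim=(-2, -1)).mean() / sae_x.k**2
    return recon + lam * jac_penalty
```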

A key technical contribution is making the Jacobian calculation tractable. A naive computation would be excessively large ($\text{batch\_size} \times \text{num\_output\_features} \times \text{num\_input\_features}$ entries). The authors leverage two insights:

  1. Effective Jacobian Size Reduction: Since $s_x$ (after $\text{TopK}$) has only $k$ active (non-zero) features, and we are interested in its effect on the $k$ active features of $s_y$, we only need to consider a $k \times k$ sub-matrix of the Jacobian per token. The $\text{TopK}$ function ensures that derivatives with respect to non-active input features are zero.
  2. Efficient Formula: For GPT-2 style MLPs, the Jacobian of $f_s$ (specifically, the $k \times k$ active part) can be computed efficiently using a derived formula involving three matrix multiplications and pointwise operations, avoiding computationally expensive auto-differentiation for each of the $k$ input features (a code sketch follows below):

    $$J_{f_s}^{(\text{active})} = (W_{e,y}^{(\text{active})} W_2) \cdot \text{diag}(\phi'_{\text{MLP}}(z)) \cdot (W_1 W_{d,x}^{(\text{active})})$$

    where $W_1, W_2$ are the MLP weights, $W_{d,x}^{(\text{active})}$ and $W_{e,y}^{(\text{active})}$ are the active columns/rows of the SAE decoder/encoder, and $z$ is the MLP's hidden pre-activation.

This optimized calculation makes training JSAEs only about twice as computationally expensive as training a single standard SAE.
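
A sketch of what this closed-form computation might look like for one token, assuming a GPT-2-style MLP $f(\hat{x}) = W_2\,\phi_{\text{MLP}}(W_1 \hat{x} + b_1) + b_2$ and the shape conventions noted in the comments; `phi_grad` is an assumed callable returning the elementwise derivative of the MLP activation, and `active_in` / `active_out` are the index tensors of the $k$ active input and output latents.

```python
def active_jacobian(sae_x, sae_y, W1, b1, W2, phi_grad,
                    x_hat, active_in, active_out) -> torch.Tensor:
    """k x k active block of J_{f_s} for one token (shapes assumed, not the paper's code).

    Assumed shapes: W1 (d_mlp, d_model), W2 (d_model, d_mlp),
    sae_x.W_d (d_model, d_sae), sae_y.W_e (d_sae, d_model).
    """
    Wd_act = sae_x.W_d[:, active_in]     # active decoder columns, (d_model, k)
    We_act = sae_y.W_e[active_out, :]    # active encoder rows,    (k, d_model)
    z = W1 @ x_hat + b1                  # MLP hidden pre-activation, (d_mlp,)

    left = We_act @ W2                   # (k, d_mlp)
    right = W1 @ Wd_act                  # (d_mlp, k)
    # (W_e,y^act W2) diag(phi'(z)) (W1 W_d,x^act), without materialising diag(.)
    return (left * phi_grad(z)) @ right  # (k, k)
```

In this sketch, restricting to the active rows and columns stands in for the TopK/ReLU masks: on the active set their derivatives are 1, so no extra mask term is needed, and the whole block costs three small matrix multiplications per token.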

Key Findings

  1. Induced Computational Sparsity: JSAEs significantly increase the sparsity of the Jacobian matrix (i.e., fewer strong connections between input and output SAE features) compared to standard SAEs, while maintaining comparable reconstruction quality and model performance.
  2. Hyperparameter Trade-off: There's a "sweet spot" for the Jacobian penalty coefficient $\lambda$ (e.g., $\lambda \approx 0.5$ for Pythia-410m, $\lambda \approx 1$ for Pythia-70m) where substantial Jacobian sparsity is achieved with minimal degradation in SAE performance metrics.
  3. Interpretability: The interpretability of individual JSAE features, measured by automatic interpretability scores, remains similar to that of traditional SAEs.
  4. Learned Structure: JSAEs achieve significantly greater Jacobian sparsity when applied to pre-trained LLMs compared to randomly initialized LLMs. This suggests that the computational sparsity JSAEs find is a property learned during model training.
  5. Jacobian as a Valid Proxy: The function $f_s$ is found to be mostly linear when mapping individual input SAE features to output SAE features; most of these scalar mappings are linear or well approximated by JumpReLU functions. This linearity means the local Jacobian values are good indicators of the global relationship, making Jacobian sparsity a strong proxy for actual computational sparsity (i.e., a near-zero Jacobian element implies a near-zero causal effect even for larger input changes). JSAEs also tend to increase the proportion of linear functions in $f_s$.

Practical Implications and Applications

  • Understanding LLM Computation: JSAEs offer a new tool for mechanistic interpretability, moving beyond understanding static representations to understanding information flow and transformation within model components like MLPs. This can help identify sparse "circuits" – how a few input concepts (SAE features) combine to produce output concepts.
  • Unsupervised Circuit Discovery: Unlike some circuit discovery methods that require task-specific supervision, JSAEs work in an unsupervised manner, potentially revealing fundamental computational pathways learned by the LLM.
  • Model Editing and Safety: By identifying sparse computational pathways, JSAEs could inform more targeted model editing techniques or help identify and mitigate undesirable computations.
  • Efficiency in Analysis: The tractable computation of Jacobians makes this approach scalable for analyzing large models.

Implementation Considerations

  • Computational Cost: Training a JSAE pair is roughly twice the cost of training a single standard SAE.
  • Hyperparameter Tuning: The Jacobian coefficient $\lambda$ needs to be tuned; the paper suggests optimal values depend on the model and layer. The TopK sparsity level $k$ is also a key parameter.
  • MLP Architecture: The efficient Jacobian derivation provided is for GPT-2 style MLPs. Extension to other architectures (e.g., GLU-based MLPs) would require deriving a new efficient Jacobian formula.
  • TopK SAEs: The current implementation relies on TopK SAEs for Jacobian efficiency. If other SAE activation schemes (e.g., $L_1$-penalized ReLU) are desired, the Jacobian computation strategy might need to be adapted.
  • Software: The implementation is based on the SAELens library. The code is available in supplementary material / GitHub.

Limitations and Future Work

  • Scope: Currently demonstrated on single MLP layers. Extending to multiple layers or entire models is a future step.
  • MLP Types: The specific efficient Jacobian formula is for GPT-2 style MLPs. Adaptation for GLU variants (common in modern LLMs) is needed.
  • Activation Functions: The reliance on TopK for efficiency might be a limitation if other SAE activation functions (which might avoid issues like high-density features) are preferred.
  • Interpreting Feature Pairs: While individual features are interpretable, interpreting the connections (large Jacobian elements) between input and output features requires further methodological development.
  • Linearity Assumption: While $f_s$ is found to be mostly linear, the implications of non-linear interactions for the Jacobian as a proxy for computation need continued investigation.

In conclusion, Jacobian Sparse Autoencoders provide a promising and computationally feasible approach to uncover sparse computational structures within LLM MLPs. By directly optimizing for a sparse Jacobian, JSAEs go beyond interpreting static activations to shed light on how these activations are processed, paving the way for a deeper understanding of learned computations in neural networks.
