Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders (2407.14435v3)

Published 19 Jul 2024 in cs.LG

Abstract: Sparse autoencoders (SAEs) are a promising unsupervised approach for identifying causally relevant and interpretable linear features in a language model's (LM) activations. To be useful for downstream tasks, SAEs need to decompose LM activations faithfully; yet to be interpretable the decomposition must be sparse -- two objectives that are in tension. In this paper, we introduce JumpReLU SAEs, which achieve state-of-the-art reconstruction fidelity at a given sparsity level on Gemma 2 9B activations, compared to other recent advances such as Gated and TopK SAEs. We also show that this improvement does not come at the cost of interpretability through manual and automated interpretability studies. JumpReLU SAEs are a simple modification of vanilla (ReLU) SAEs -- where we replace the ReLU with a discontinuous JumpReLU activation function -- and are similarly efficient to train and run. By utilising straight-through-estimators (STEs) in a principled manner, we show how it is possible to train JumpReLU SAEs effectively despite the discontinuous JumpReLU function introduced in the SAE's forward pass. Similarly, we use STEs to directly train L0 to be sparse, instead of training on proxies such as L1, avoiding problems like shrinkage.

Authors (7)
  1. Senthooran Rajamanoharan (11 papers)
  2. Tom Lieberum (8 papers)
  3. Nicolas Sonnerat (10 papers)
  4. Arthur Conmy (22 papers)
  5. Vikrant Varma (10 papers)
  6. János Kramár (19 papers)
  7. Neel Nanda (50 papers)
Citations (28)

Summary

The paper "Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders" addresses a central problem for sparse autoencoders (SAEs): improving the trade-off between sparsity and reconstruction fidelity. This matters because SAEs are used to identify interpretable linear features in language model (LM) activations, and those features support a range of downstream interpretability applications.

Introduction

The goal of an SAE is to provide a sparse, interpretable decomposition of LM activations. It does so by approximating each activation vector as a sparse linear combination of directions drawn from a large dictionary of “feature” directions. However, ReLU-based SAEs face a trade-off between sparsity and reconstruction fidelity: pushing for sparser codes typically degrades reconstruction quality, which limits how useful the resulting decompositions are in practice.
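To make the decomposition concrete, here is a minimal sketch of a vanilla ReLU SAE forward pass, assuming NumPy and a single activation vector. The function name, parameter names (W_enc, b_enc, W_dec, b_dec), and the toy weights are illustrative assumptions, not the paper's code; the rows of W_dec play the role of the dictionary of feature directions.

```python
import numpy as np

def relu_sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Vanilla (ReLU) SAE: encode x into feature activations f, then decode.

    Illustrative shapes: x is (d_model,), W_enc is (d_model, d_sae),
    W_dec is (d_sae, d_model); each row of W_dec is one feature direction.
    """
    pre_acts = x @ W_enc + b_enc      # encoder pre-activations, one per feature
    f = np.maximum(pre_acts, 0.0)     # ReLU keeps only positive pre-activations
    x_hat = f @ W_dec + b_dec         # reconstruction: weighted sum of feature directions
    return f, x_hat

# Toy example: 2-dimensional activations, a dictionary of 3 features.
W_enc = np.array([[1.0, 0.0, -1.0],
                  [0.0, 1.0, -1.0]])
W_dec = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [-1.0, -1.0]])
x = np.array([2.0, 3.0])
f, x_hat = relu_sae_forward(x, W_enc, np.zeros(3), W_dec, np.zeros(2))
```

Here the third feature stays inactive, so x is reconstructed exactly from two of the three dictionary directions.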

JumpReLU Activation Function

To address this trade-off, the paper replaces the ReLU in the SAE encoder with the JumpReLU activation function, JumpReLU_θ(z) = z · H(z − θ), where H is the Heaviside step function and θ is a positive, learned, per-feature threshold. Pre-activations below the threshold are set to zero, while those above it pass through unchanged. This lets the SAE suppress small, spurious (false-positive) activations without shrinking the magnitudes of genuinely active features, which helps preserve reconstruction fidelity. Because JumpReLU is discontinuous at the threshold, the authors use straight-through estimators (STEs) to obtain useful gradients, so the model can still be trained with standard gradient-based methods.
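As a minimal sketch (assuming NumPy; the function name is chosen for illustration), JumpReLU_θ(z) = z · H(z − θ) can be written in one line. The comparison with a ReLU preceded by a bias of −θ shows the design point: both zero out small pre-activations, but only JumpReLU leaves the surviving values unshifted.

```python
import numpy as np

def jumprelu(z, theta):
    """JumpReLU_theta(z) = z * H(z - theta).

    Pre-activations at or below theta are zeroed; those above it pass through
    unchanged. theta may be a scalar or a per-feature array that broadcasts with z.
    """
    # np.heaviside(x, 0.0) is 1 for x > 0, 0 for x < 0, and 0.0 exactly at x == 0.
    return z * np.heaviside(z - theta, 0.0)

z = np.array([0.05, 0.3, 0.9, 1.4])
print(jumprelu(z, 0.5))          # small values zeroed; 0.9 and 1.4 kept as-is
print(np.maximum(z - 0.5, 0.0))  # ReLU with a -0.5 bias: survivors shrink by 0.5
```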

Methodology

Sparse Autoencoder Training

The proposed JumpReLU SAEs were trained on activations from the Gemma 2 9B model. The training loss combines an L2 reconstruction error with an L0 sparsity penalty, so sparsity is optimized directly rather than through a proxy such as the L1 norm. This is a notable departure from common practice, because an L1 penalty also pushes down the magnitudes of active features, causing SAEs to systematically underestimate them (the shrinkage effect).
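The loss described above, L(x) = ‖x − x̂‖₂² + λ‖f(x)‖₀, can be sketched in NumPy as follows. The function name and the `sparsity_coeff` parameter (standing in for λ) are hypothetical. Note that the L0 term only counts active features and is piecewise constant, so its ordinary gradient is zero almost everywhere; this is why the paper needs STEs to train with it.

```python
import numpy as np

def jumprelu_sae_loss(x, x_hat, f, sparsity_coeff):
    """Batch-mean of squared L2 reconstruction error plus lambda * L0.

    x, x_hat: (batch, d_model) activations and reconstructions.
    f: (batch, d_sae) feature activations.
    """
    reconstruction = np.sum((x - x_hat) ** 2, axis=-1)  # per-example squared error
    l0 = np.sum(f != 0, axis=-1)                        # number of active features
    # Unlike an L1 penalty (sum of |f|), L0 does not reward shrinking active features.
    return float(np.mean(reconstruction + sparsity_coeff * l0))

x = np.array([[1.0, 2.0]])
x_hat = np.array([[1.0, 1.0]])
f = np.array([[0.0, 3.0, 0.5]])
loss = jumprelu_sae_loss(x, x_hat, f, sparsity_coeff=0.1)  # 1.0 + 0.1 * 2
```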

Kernel Density Estimation and STEs

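Both the JumpReLU output and the L0 penalty are piecewise constant in the threshold θ, so the loss has zero gradient with respect to the thresholds almost everywhere and ordinary backpropagation cannot train them. The authors instead consider the gradient of the expected loss, which is well defined, and estimate it with STEs that substitute pseudo-derivatives for the true derivatives in the backward pass. Averaged over a batch, these pseudo-derivatives amount to a kernel density estimate (KDE) of the pre-activation density at the threshold, with a bandwidth parameter controlling the kernel width. This gives the thresholds a meaningful training signal while keeping each training step about as cheap as for a vanilla SAE.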
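The sketch below illustrates these pseudo-derivatives with a rectangle kernel K, assuming NumPy and writing the backward pass by hand. The function names and the `bandwidth` argument (ε) are assumptions for illustration; a real training loop would register these as custom gradients in an autodiff framework rather than call them directly. For the step function only the θ-gradient is computed here. The last lines check the KDE interpretation: averaged over many samples, minus the step pseudo-derivative approximates the sample density at θ.

```python
import numpy as np

def rectangle(u):
    # Rectangle kernel K(u): 1 where -1/2 < u < 1/2, else 0 (integrates to 1).
    return ((u > -0.5) & (u < 0.5)).astype(float)

def jumprelu_backward(z, theta, bandwidth):
    """Backward-pass gradients for JumpReLU_theta(z) = z * H(z - theta).

    d/dz: the ordinary derivative H(z - theta), valid almost everywhere.
    d/dtheta: STE pseudo-derivative -(theta / eps) * K((z - theta) / eps),
    used in place of the true derivative, which is zero almost everywhere.
    """
    dz = (z > theta).astype(float)
    dtheta = -(theta / bandwidth) * rectangle((z - theta) / bandwidth)
    return dz, dtheta

def step_backward_theta(z, theta, bandwidth):
    """STE pseudo-derivative of H(z - theta) w.r.t. theta (for the L0 term)."""
    return -(1.0 / bandwidth) * rectangle((z - theta) / bandwidth)

# KDE view: the batch mean of -step_backward_theta is a rectangle-kernel
# density estimate of the pre-activations' density at theta.
rng = np.random.default_rng(0)
z = rng.standard_normal(200_000)   # stand-in pre-activations drawn from N(0, 1)
theta, eps = 0.5, 0.05
kde_estimate = -np.mean(step_backward_theta(z, theta, eps))
true_density = np.exp(-theta**2 / 2) / np.sqrt(2 * np.pi)
print(kde_estimate, true_density)  # the two values should be close
```

Only pre-activations within ε/2 of the threshold contribute to the θ-gradient, so ε trades bias (wide kernel) against noise (narrow kernel).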

Evaluation

Reconstruction Fidelity and Sparsity

The paper compares JumpReLU SAEs with Gated and TopK SAEs across multiple layers and sites within the Gemma 2 9B model. At matched sparsity, measured as the average number of active features (L0), JumpReLU SAEs achieve reconstruction fidelity equal to or better than both baselines, so their sparsity-versus-fidelity Pareto frontiers lie at or beyond those of Gated and TopK SAEs.
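Comparisons "at matched sparsity" are easiest to see as points on a sparsity-versus-fidelity plane. The sketch below, assuming NumPy, computes mean L0, the fraction of variance unexplained (FVU, one common fidelity metric), and the Pareto frontier of a set of (L0, loss) points. All helper names are hypothetical; fidelity metrics that require running the LM with reconstructions spliced in, such as the change in LM loss, are not shown.

```python
import numpy as np

def mean_l0(f):
    """Average number of nonzero feature activations per example."""
    return float(np.mean(np.sum(f != 0, axis=-1)))

def fraction_variance_unexplained(x, x_hat):
    """Residual sum of squares divided by total variance of x around its mean."""
    residual = np.sum((x - x_hat) ** 2)
    total = np.sum((x - x.mean(axis=0)) ** 2)
    return float(residual / total)

def pareto_frontier(points):
    """Keep (l0, loss) points that no other point beats on both axes.

    Lower is better for both. Sorting by L0 (ties by loss) means a point is on
    the frontier exactly when its loss beats every point with L0 at or below it.
    """
    frontier = []
    best_loss = float("inf")
    for l0, loss in sorted(points):
        if loss < best_loss:
            frontier.append((l0, loss))
            best_loss = loss
    return frontier

points = [(10, 0.5), (20, 0.3), (15, 0.6), (30, 0.31), (40, 0.1)]
print(pareto_frontier(points))
```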

Interpretability

Interpretability was the other key criterion. According to both manual ratings and automated methods, JumpReLU SAE features are comparably interpretable to those of TopK and Gated SAEs. The paper reports a detailed interpretability study supporting the conclusion that the fidelity gains do not come at the cost of interpretability.

Practical Implications

JumpReLU SAEs offer a better balance between sparsity and reconstruction fidelity, which matters for downstream applications such as circuit analysis and model steering. They are also cheap to run: the JumpReLU activation is an elementwise comparison against a threshold, whereas a TopK SAE must find the k largest pre-activations for every input, which requires a partial sort. This can make JumpReLU SAEs easier to scale to large dictionaries and large training runs.
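The difference in per-input work can be sketched in NumPy as below. The function names are hypothetical and this is an illustration of the operations involved, not a benchmark; actual speed depends on hardware and implementation. TopK needs `np.argpartition` (a partial sort along the feature axis), while JumpReLU is a single elementwise comparison.

```python
import numpy as np

def topk_activation(z, k):
    """Keep the k largest pre-activations per row (then ReLU), zero the rest.

    Finding the k largest entries requires a partial sort over the feature axis.
    """
    idx = np.argpartition(z, -k, axis=-1)[..., -k:]  # indices of the k largest entries
    out = np.zeros_like(z)
    np.put_along_axis(out, idx, np.take_along_axis(z, idx, axis=-1), axis=-1)
    return np.maximum(out, 0.0)

def jumprelu_activation(z, theta):
    """Elementwise threshold: each entry is handled independently, no sorting."""
    return np.where(z > theta, z, 0.0)

z = np.array([[0.1, 2.0, -1.0, 0.7]])
print(topk_activation(z, k=2))
print(jumprelu_activation(z, theta=0.5))
```

A side effect of the design: TopK fixes the number of active features per input at k, whereas with JumpReLU the L0 varies from input to input and is controlled on average by the sparsity penalty.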

Conclusion

JumpReLU SAEs represent an incremental but useful advance in sparse autoencoder architectures, improving the sparsity-versus-fidelity trade-off without sacrificing interpretability. These results position JumpReLU SAEs as a promising tool for interpretable and computationally efficient decompositions of LM activations. Future work may explore alternative loss functions or better settings for hyperparameters such as the threshold initialization and the KDE bandwidth, to further improve training and performance.

Acknowledgments and Contributions

The paper acknowledges the contributions of the research team, which span conceiving the JumpReLU SAE approach, implementing SAE training, running the evaluations, and building the interpretability assessment pipelines.

By presenting a rigorous and detailed exploration of the JumpReLU SAE architecture, this work lays groundwork for future advances in AI interpretability and efficiency.