- The paper introduces Prisma, an open-source toolkit that unifies tools and pre-trained weights to advance mechanistic interpretability in vision and video models.
- The paper employs a hooked architecture via HookedViT to support over 75 transformer models, enabling efficient activation caching and intervention analyses.
- The paper demonstrates that substituting SAE reconstructions into the forward pass can reduce cross-entropy loss in deeper layers, and reveals modality-specific differences in sparse coding.
Prisma is an open-source toolkit designed to accelerate mechanistic interpretability research for vision and video models. It addresses the lack of accessible frameworks and pre-trained weights that has left the field lagging behind LLM interpretability, providing a unified platform that integrates the tools and resources needed for scalable vision interpretability studies.
The toolkit centers on a HookedViT class, which wraps models from popular vision repositories such as timm, OpenCLIP, and Hugging Face behind a unified interface for accessing and modifying activations and weights, mirroring what TransformerLens provides for LLMs. This hooked architecture supports over 75 vision and video transformers, letting researchers easily perform downstream interpretability tasks such as activation caching and interventions. For video models, it supports the 3D tubelet embeddings used in architectures like ViViT and V-JEPA.
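For concreteness, here is a minimal sketch of activation caching through the hooked interface, assuming an API analogous to TransformerLens; the import path, model identifier, and hook name are illustrative rather than Prisma's exact API:

```python
import torch
from vit_prisma.models import HookedViT  # import path is an assumption; check the Prisma docs

# Load a supported vision transformer through the unified hooked interface
# (the model identifier is illustrative).
model = HookedViT.from_pretrained("openai/clip-vit-base-patch32")

images = torch.randn(8, 3, 224, 224)  # dummy batch of images

# A single forward pass that caches every intermediate activation by hook name.
logits, cache = model.run_with_cache(images)
resid_post = cache["blocks.5.hook_resid_post"]  # residual stream after block 5
print(resid_post.shape)  # (batch, tokens, d_model), e.g. (8, 50, 768) for ViT-B/32
```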
Prisma offers comprehensive support for training and evaluating Sparse Autoencoders (SAEs) as well as related sparse coders, Transcoders and Crosscoders, which are key techniques for decomposing dense model representations into sparser, potentially more interpretable features. The toolkit provides implementations of several sparse coder architectures, including ReLU, Top-K, JumpReLU, and Gated variants. Training is highly configurable, with control over encoder/decoder initialization, activation normalization, ghost gradients (to revive dead features), dead feature resampling, and early stopping; it supports both activation caching for speed and on-the-fly computation to save VRAM. Evaluation tools measure the metrics needed to assess sparse coder quality, such as L0 sparsity, explained variance, reconstruction and substitution losses, and maximally activating inputs for specific features.
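The core idea behind the vanilla variant can be summarized in a short sketch; this is a generic ReLU SAE with an L1 penalty and tied initialization, not Prisma's exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VanillaSAE(nn.Module):
    """Generic ReLU sparse autoencoder with an L1 sparsity penalty (illustrative)."""

    def __init__(self, d_model: int = 768, expansion: int = 64, l1_coeff: float = 1e-5):
        super().__init__()
        d_sae = d_model * expansion                              # e.g. 768 -> 49,152 features
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.02)
        self.W_enc = nn.Parameter(self.W_dec.data.t().clone())   # encoder initialized as decoder transpose
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.l1_coeff = l1_coeff

    def forward(self, x):
        f = F.relu((x - self.b_dec) @ self.W_enc + self.b_enc)   # sparse feature activations
        x_hat = f @ self.W_dec + self.b_dec                      # reconstruction
        mse = F.mse_loss(x_hat, x)                               # reconstruction loss
        l1 = self.l1_coeff * f.abs().sum(dim=-1).mean()          # sparsity penalty
        return x_hat, f, mse + l1
```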
A significant contribution of Prisma is the release of over 80 pre-trained SAE weights for CLIP-B and DINO-B Vision Transformers, covering all layers and several token configurations (all patches, CLS-only, spatial patches), as well as transcoders for CLIP. These pre-trained models give researchers a starting point without requiring extensive training compute. Training details for the released models (a training-loop sketch follows the list):
- Expansion Factor: typically 64× (e.g., mapping a 768-dimensional activation space to 49,152 features).
- Weight Initialization: Encoder weights initialized as the transpose of decoder weights.
- Training Data: ImageNet1k for one epoch.
- Optimizer & Schedule: Adam optimizer with a learning rate sweep and cosine annealing with warmup.
- L1 Coefficients: Swept over a wide range to find optimal sparsity/reconstruction trade-offs.
- Auxiliary Loss: Ghost gradients utilized.
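Under those settings, a training loop might look roughly like the following; the learning rate, warmup length, and step count are placeholders (the actual runs swept learning rates and L1 coefficients), and `activation_loader` stands in for a stream of cached or on-the-fly ImageNet-1k activations:

```python
import torch

sae = VanillaSAE(d_model=768, expansion=64, l1_coeff=1e-5)   # sketch class from above
optimizer = torch.optim.Adam(sae.parameters(), lr=4e-4)      # placeholder learning rate

warmup_steps, total_steps = 500, 100_000                     # placeholder schedule
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps),
    ],
    milestones=[warmup_steps],
)

for step, acts in enumerate(activation_loader):              # assumed activation stream
    _, _, loss = sae(acts)                                    # reconstruction + L1 loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
```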
Beyond sparse coding, Prisma includes a suite of tools for circuit analysis, attention-head visualization, and logit lens techniques, which help researchers understand how model components interact. Ablations are performed as forward-pass interventions using the hooked model architecture. The HookedSAEViT module enables analysis by splicing trained sparse coders into the model's layers, and the HookedViT module integrates with automatic circuit discovery algorithms such as ACDC.
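As an illustration of this kind of intervention, the sketch below zero-ablates one attention head and reads out intermediate logits in a logit-lens style; it continues the caching sketch above, and the hook names and readout attributes are assumptions following the TransformerLens convention:

```python
# Zero-ablate attention head 3 in block 7 during the forward pass.
def zero_head(pattern, hook, head=3):
    pattern[:, head] = 0.0          # attention pattern shape: (batch, heads, query, key)
    return pattern

logits_ablated = model.run_with_hooks(
    images,
    fwd_hooks=[("blocks.7.attn.hook_pattern", zero_head)],
)

# Logit-lens style readout: decode an intermediate CLS representation through
# the model's final normalization and head (attribute names are assumptions).
cls_mid = cache["blocks.7.hook_resid_post"][:, 0]
intermediate_logits = model.head(model.ln_final(cls_mid))
```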
To lower the barrier to entry and enable research in compute-constrained environments, Prisma provides educational resources such as tutorial notebooks and documentation. It also includes toy Vision Transformers (1-4 layers, including attention-only variants) pre-trained on ImageNet, mirroring a strategy that proved successful in language interpretability research.
The paper presents preliminary analyses conducted with Prisma that reveal surprising differences between vision SAEs and their language counterparts. One observation is that vision SAEs trained on CLIP-B/32 exhibit substantially higher L0 sparsity (around 500+ active features per patch) while maintaining comparable explained variance, whereas language SAEs typically show L0 values below 100. Potential explanations include inherent differences in visual and linguistic information density, the granularity of image patches versus language tokens, and the specialization of different token types (CLS vs. spatial patches), which show varying feature utilization across layers (see the paper's figure on alive features). This suggests that optimal sparse coding techniques may differ significantly between modalities and even between token types within vision models.
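A minimal sketch of the L0 measurement behind this comparison, using a dummy activation tensor in place of real SAE feature activations:

```python
import torch

# feature_acts stands in for the SAE's feature activations over a batch of
# images: (batch, tokens, d_sae), with the CLS token at position 0.
feature_acts = torch.relu(torch.randn(8, 50, 49_152) - 3.0)

active = (feature_acts > 0).float().sum(dim=-1)   # L0 (number of active features) per token
avg_cls_l0 = active[:, 0].mean()                  # CLS token
avg_patch_l0 = active[:, 1:].mean()               # spatial patches
print(f"CLS L0: {avg_cls_l0:.1f}, patch L0: {avg_patch_l0:.1f}")
```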
Another unexpected finding is that injecting SAE reconstructions back into the model's forward pass can sometimes decrease the original model's cross-entropy loss. This effect was observed particularly in deeper layers for CLS tokens in certain CLIP and DINO SAEs (see the paper's figure on SAE performance). It suggests a possible denoising effect of the SAE reconstruction, which aligns with prior observations but warrants further investigation into the underlying mechanism. The improvement was not consistent across all SAE types or token configurations.
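A sketch of this substitution test, continuing the earlier sketches (`model`, `images`, `sae`) and assuming a `labels` tensor of class indices; the hook name is illustrative:

```python
import torch.nn.functional as F

# Splice the SAE reconstruction into the residual stream at block 9.
def substitute_with_sae(resid, hook):
    x_hat, _, _ = sae(resid)
    return x_hat

clean_logits = model(images)
recon_logits = model.run_with_hooks(
    images, fwd_hooks=[("blocks.9.hook_resid_post", substitute_with_sae)]
)

# Compare cross-entropy; a reconstruction CE below the clean CE is the "denoising" effect.
clean_ce = F.cross_entropy(clean_logits, labels)
recon_ce = F.cross_entropy(recon_logits, labels)
print(f"clean CE {clean_ce:.4f} vs. reconstruction CE {recon_ce:.4f}")
```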
The release includes detailed evaluation tables for the trained SAEs, showing metrics such as explained variance, average L0 (overall and for the CLS token), cosine similarity, reconstruction cosine similarity, cross-entropy (CE) loss, reconstruction CE loss, zero-ablation CE loss, percentage of CE recovered, and percentage of alive features. These metrics provide a quantitative basis for evaluating the quality and behavior of the released sparse coders.
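The "% CE recovered" column is presumably the standard definition used in SAE evaluations (the paper's exact formula is assumed to match): how much of the gap between the zero-ablated and clean losses the reconstruction closes.

```python
def pct_ce_recovered(clean_ce: float, recon_ce: float, zero_abl_ce: float) -> float:
    """Fraction of the zero-ablation CE gap recovered by the SAE reconstruction, in percent."""
    return 100.0 * (zero_abl_ce - recon_ce) / (zero_abl_ce - clean_ce)

# With the rounded values from the first table row below (layer 0, mlp_out):
print(pct_ce_recovered(6.762, 6.762, 6.779))  # ~100; the table's 99.51 reflects unrounded losses
```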
Here are the evaluation tables provided in the paper:
CLIP-ViT-B-32 SAEs
Vanilla SAEs (All Patches)
| Layer | Sublayer | l1 coeff. | % Explained var. | Avg L0 | Avg CLS L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model Link |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | mlp_out | 1e-5 | 98.7 | 604.44 | 36.92 | 0.994 | 0.998 | 6.762 | 6.762 | 6.779 | 99.51 | 100 | link |
| 0 | resid_post | 1e-5 | 98.6 | 1110.9 | 40.46 | 0.993 | 0.988 | 6.762 | 6.763 | 6.908 | 99.23 | 100 | link |
| 1 | mlp_out | 1e-5 | 98.4 | 1476.8 | 97.82 | 0.992 | 0.994 | 6.762 | 6.762 | 6.889 | 99.40 | 100 | link |
| 1 | resid_post | 1e-5 | 98.3 | 1508.4 | 27.39 | 0.991 | 0.989 | 6.762 | 6.763 | 6.908 | 99.02 | 100 | link |
| 2 | mlp_out | 1e-5 | 98.0 | 1799.7 | 376.0 | 0.992 | 0.998 | 6.762 | 6.762 | 6.803 | 99.44 | 100 | link |
| 2 | resid_post | 5e-5 | 90.6 | 717.84 | 10.11 | 0.944 | 0.960 | 6.762 | 6.767 | 6.908 | 96.34 | 100 | link |
| 3 | mlp_out | 1e-5 | 98.1 | 1893.4 | 648.2 | 0.992 | 0.999 | 6.762 | 6.762 | 6.784 | 99.54 | 100 | link |
| 3 | resid_post | 1e-5 | 98.1 | 2053.9 | 77.90 | 0.989 | 0.996 | 6.762 | 6.762 | 6.908 | 99.79 | 100 | link |
| 4 | mlp_out | 1e-5 | 98.1 | 1901.2 | 1115.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.786 | 99.55 | 100 | link |
| 4 | resid_post | 1e-5 | 98.0 | 2068.3 | 156.7 | 0.989 | 0.997 | 6.762 | 6.762 | 6.908 | 99.74 | 100 | link |
| 5 | mlp_out | 1e-5 | 98.2 | 1761.5 | 1259.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.797 | 99.76 | 100 | link |
| 5 | resid_post | 1e-5 | 98.1 | 1953.8 | 228.5 | 0.990 | 0.997 | 6.762 | 6.762 | 6.908 | 99.80 | 100 | link |
| 6 | mlp_out | 1e-5 | 98.3 | 1598.0 | 1337.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.789 | 99.83 | 100 | link |
| 6 | resid_post | 1e-5 | 98.2 | 1717.5 | 321.3 | 0.991 | 0.996 | 6.762 | 6.762 | 6.908 | 99.93 | 100 | link |
| 7 | mlp_out | 1e-5 | 98.2 | 1535.3 | 1300.0 | 0.992 | 0.999 | 6.762 | 6.762 | 6.796 | 100.17 | 100 | link |
| 7 | resid_post | 1e-5 | 98.2 | 1688.4 | 494.3 | 0.991 | 0.995 | 6.762 | 6.761 | 6.908 | 100.24 | 100 | link |
| 8 | mlp_out | 1e-5 | 97.8 | 1074.5 | 1167.0 | 0.990 | 0.998 | 6.762 | 6.761 | 6.793 | 100.57 | 100 | link |
| 8 | resid_post | 1e-5 | 98.2 | 1570.8 | 791.3 | 0.991 | 0.992 | 6.762 | 6.761 | 6.908 | 100.41 | 100 | link |
| 9 | mlp_out | 1e-5 | 97.6 | 856.68 | 1076.0 | 0.989 | 0.998 | 6.762 | 6.762 | 6.792 | 100.28 | 100 | link |
| 9 | resid_post | 1e-5 | 98.2 | 1533.5 | 1053.0 | 0.991 | 0.989 | 6.762 | 6.761 | 6.908 | 100.32 | 100 | link |
| 10 | mlp_out | 1e-5 | 98.1 | 788.49 | 965.5 | 0.991 | 0.998 | 6.762 | 6.762 | 6.772 | 101.50 | 99.80 | link |
| 10 | resid_post | 1e-5 | 98.4 | 1292.6 | 1010.0 | 0.992 | 0.987 | 6.762 | 6.760 | 6.908 | 100.83 | 99.99 | link |
| 11 | mlp_out | 5e-5 | 89.7 | 748.14 | 745.5 | 0.972 | 0.993 | 6.762 | 6.759 | 6.768 | 135.77 | 100 | link |
| 11 | resid_post | 1e-5 | 98.4 | 1405.0 | 1189.0 | 0.993 | 0.987 | 6.762 | 6.765 | 6.908 | 98.03 | 99.99 | link |
Vanilla SAEs (CLS only)
| Layer | Sublayer | l1 coeff. | % Explained var. | Avg CLS L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model Link |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 2e-8 | 82 | 934.83 | 0.98008 | 0.99995 | 6.7622 | 6.7622 | 6.9084 | 99.9984 | 4.33 | link |
| 1 | resid_post | 8e-6 | 85 | 314.13 | 0.97211 | 0.99994 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 2.82 | link |
| 2 | resid_post | 9e-8 | 96 | 711.84 | 0.98831 | 0.99997 | 6.7622 | 6.7622 | 6.9083 | 99.9977 | 2.54 | link |
| 3 | resid_post | 1e-8 | 95 | 687.41 | 0.98397 | 0.99994 | 6.7622 | 6.7622 | 6.9085 | 99.9998 | 4.49 | link |
| 4 | resid_post | 9e-8 | 95 | 681.08 | 0.98092 | 0.99988 | 6.7622 | 6.7622 | 6.9082 | 100.00 | 15.75 | link |
| 5 | resid_post | 1e-7 | 94 | 506.77 | 0.97404 | 0.99966 | 6.7622 | 6.7622 | 6.9081 | 99.9911 | 16.80 | link |
| 6 | resid_post | 1e-8 | 92 | 423.70 | 0.96474 | 0.99913 | 6.7622 | 6.7622 | 6.9083 | 99.9971 | 29.46 | link |
| 7 | resid_post | 2e-6 | 88 | 492.68 | 0.93899 | 0.99737 | 6.7622 | 6.7622 | 6.9082 | 99.9583 | 51.68 | link |
| 8 | resid_post | 4e-8 | 76 | 623.01 | 0.89168 | 0.99110 | 6.7622 | 6.7625 | 6.9087 | 99.7631 | 82.07 | link |
| 9 | resid_post | 1e-12 | 74 | 521.90 | 0.87076 | 0.98191 | 6.7622 | 6.7628 | 6.9083 | 99.5425 | 93.68 | link |
| 10 | resid_post | 3e-7 | 74 | 533.94 | 0.87646 | 0.96514 | 6.7622 | 6.7635 | 6.9082 | 99.1070 | 99.98 | link |
| 11 | resid_post | 1e-8 | 65 | 386.09 | 0.81890 | 0.89607 | 6.7622 | 6.7853 | 6.9086 | 84.1918 | 99.996 | link |
Top K SAEs (CLS only, k=64)
| Layer | Sublayer | % Explained var. | Avg CLS L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model Link |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 90 | 64 | 0.98764 | 0.99998 | 6.7622 | 6.7622 | 6.9084 | 99.995 | 46.80 | link |
| 1 | resid_post | 96 | 64 | 0.99429 | 0.99999 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 4.86 | link |
| 2 | resid_post | 96 | 64 | 0.99000 | 0.99998 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 5.50 | link |
| 3 | resid_post | 95 | 64 | 0.98403 | 0.99995 | 6.7622 | 6.7622 | 6.9085 | 100.00 | 5.21 | link |
| 4 | resid_post | 94 | 64 | 0.97485 | 0.99986 | 6.7621 | 6.7622 | 6.9082 | 99.998 | 6.81 | link |
| 5 | resid_post | 93 | 64 | 0.96985 | 0.99962 | 6.7622 | 6.7622 | 6.9081 | 99.997 | 21.89 | link |
| 6 | resid_post | 92 | 64 | 0.96401 | 0.99912 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 28.81 | link |
| 7 | resid_post | 90 | 64 | 0.95057 | 0.99797 | 6.7622 | 6.7621 | 6.9082 | 100.03 | 65.84 | link |
| 8 | resid_post | 87 | 64 | 0.93029 | 0.99475 | 6.7622 | 6.7620 | 6.9087 | 100.11 | 93.75 | link |
| 9 | resid_post | 85 | 64 | 0.91814 | 0.98865 | 6.7622 | 6.7616 | 6.9083 | 100.43 | 98.90 | link |
| 10 | resid_post | 86 | 64 | 0.93072 | 0.97929 | 6.7622 | 6.7604 | 6.9082 | 101.19 | 94.55 | link |
| 11 | resid_post | 84 | 64 | 0.91880 | 0.94856 | 6.7622 | 6.7578 | 6.9086 | 102.97 | 97.99 | link |
Vanilla SAEs (Spatial Patches)
| Layer | Sublayer | l1 coeff. | % Explained var. | Avg L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model Link |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 1e-12 | 99 | 989.19 | 0.99 | 0.99 | 6.7621 | 6.7621 | 6.9084 | 99.9981 | 100.00 | link |
| 1 | resid_post | 3e-11 | 99 | 757.83 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9083 | 99.9969 | 45.39 | link |
| 2 | resid_post | 4e-12 | 99 | 1007.89 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 97.93 | link |
| 3 | resid_post | 2e-8 | 99 | 935.06 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9085 | 99.9882 | 100.00 | link |
| 4 | resid_post | 3e-8 | 99 | 965.15 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9082 | 99.9842 | 100.00 | link |
| 5 | resid_post | 1e-8 | 99 | 966.38 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9081 | 99.9961 | 100.00 | link |
| 6 | resid_post | 1e-8 | 99 | 1006.62 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 99.97 | link |
| 7 | resid_post | 1e-8 | 99 | 984.19 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9082 | 100.00 | 100.00 | link |
| 8 | resid_post | 3e-8 | 99 | 965.12 | 0.99 | 1.00 | 6.7622 | 6.7622 | 6.9087 | 100.00 | 92.37 | link |
| 9 | resid_post | 9e-8 | 99 | 854.92 | 0.99 | 1.00 | 6.7622 | 6.7622 | 6.9083 | 99.9991 | 85.43 | link |
| 10 | resid_post | 1e-4 | 72 | 88.80 | 0.84 | 0.97 | 6.7621 | 6.7638 | 6.9082 | 98.85 | 100.00 | link |
| 11 | resid_post | 3e-7 | 99 | 829.09 | 0.99 | 1.00 | 6.7622 | 6.7622 | 6.9086 | 100.00 | 55.71 | link |
Top K Transcoders (All Patches)
| Layer | Block | % Explained var. | k | Avg CLS L0 | Cos sim | CE | Recon CE | Zero abl CE | % CE recovered | Model Link |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MLP | 96 | 768 | 767 | 0.9655 | 6.7621 | 6.7684 | 6.8804 | 94.68 | link |
| 1 | MLP | 94 | 256 | 255 | 0.9406 | 6.7621 | 6.7767 | 6.8816 | 87.78 | link |
| 2 | MLP | 93 | 1024 | 475 | 0.9758 | 6.7621 | 6.7681 | 6.7993 | 83.92 | link |
| 3 | MLP | 90 | 1024 | 825 | 0.9805 | 6.7621 | 6.7642 | 6.7999 | 94.42 | link |
| 4 | MLP | 76 | 512 | 29 | 0.9830 | 6.7621 | 6.7636 | 6.8080 | 96.76 | link |
| 5 | MLP | 91 | 1024 | 1017 | 0.9784 | 6.7621 | 6.7643 | 6.8296 | 96.82 | link |
| 6 | MLP | 94 | 1024 | 924 | 0.9756 | 6.7621 | 6.7630 | 6.8201 | 98.40 | link |
| 7 | MLP | 97 | 1024 | 1010 | 0.9629 | 6.7621 | 6.7631 | 6.8056 | 97.68 | link |
| 8 | MLP | 98 | 1024 | 1023 | 0.9460 | 6.7621 | 6.7630 | 6.8017 | 97.70 | link |
| 9 | MLP | 98 | 1024 | 1023 | 0.9221 | 6.7621 | 6.7630 | 6.7875 | 96.50 | link |
| 10 | MLP | 97 | 1024 | 1019 | 0.9334 | 6.7621 | 6.7636 | 6.7860 | 93.95 | link |
DINO-B SAEs
Vanilla (All patches)
| Layer | Sublayer | Avg L0 | % Explained var. | Avg CLS L0 | Cos sim | CE | Recon CE | Zero abl CE | % CE recovered | Model Link |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 507 | 98 | 347 | 0.95009 | 1.885033 | 1.936518 | 7.2714 | 99.04 | link |
| 1 | resid_post | 549 | 95 | 959 | 0.93071 | 1.885100 | 1.998274 | 7.2154 | 97.88 | link |
| 2 | resid_post | 812 | 95 | 696 | 0.9560 | 1.885134 | 2.006115 | 7.201461 | 97.72 | link |
| 3 | resid_post | 989 | 95 | 616 | 0.96315 | 1.885131 | 1.961913 | 7.2068 | 98.56 | link |
| 4 | resid_post | 876 | 99 | 845 | 0.99856 | 1.885224 | 1.883169 | 7.1636 | 100.04 | link |
| 5 | resid_post | 1001 | 98 | 889 | 0.99129 | 1.885353 | 1.875520 | 7.1412 | 100.19 | link |
| 6 | resid_post | 962 | 99 | 950 | 0.99945 | 1.885239 | 1.872594 | 7.1480 | 100.24 | link |
| 7 | resid_post | 1086 | 98 | 1041 | 0.99341 | 1.885371 | 1.869443 | 7.1694 | 100.30 | link |
| 8 | resid_post | 530 | 90 | 529 | 0.9475 | 1.885511 | 1.978638 | 7.1315 | 98.22 | link |
| 9 | resid_post | 1105 | 99 | 1090 | 0.99541 | 1.885341 | 1.894026 | 7.0781 | 99.83 | link |
| 10 | resid_post | 835 | 99 | 839 | 0.99987 | 1.885371 | 1.884487 | 7.3606 | 100.02 | link |
| 11 | resid_post | 1085 | 99 | 1084 | 0.99673 | 1.885370 | 1.911608 | 6.9078 | 99.48 | link |
Prisma aims to be a foundational tool for the emerging field of vision and video mechanistic interpretability, providing the necessary components and pre-trained resources to accelerate understanding of these complex models.