Prisma: An Open Source Toolkit for Mechanistic Interpretability in Vision and Video (2504.19475v3)

Published 28 Apr 2025 in cs.CV, cs.AI, and cs.LG

Abstract: Robust tooling and publicly available pre-trained models have helped drive recent advances in mechanistic interpretability for LLMs. However, similar progress in vision mechanistic interpretability has been hindered by the lack of accessible frameworks and pre-trained weights. We present Prisma (Access the codebase here: https://github.com/Prisma-Multimodal/ViT-Prisma), an open-source framework designed to accelerate vision mechanistic interpretability research, providing a unified toolkit for accessing 75+ vision and video transformers; support for sparse autoencoder (SAE), transcoder, and crosscoder training; a suite of 80+ pre-trained SAE weights; activation caching, circuit analysis tools, and visualization tools; and educational resources. Our analysis reveals surprising findings, including that effective vision SAEs can exhibit substantially lower sparsity patterns than language SAEs, and that in some instances, SAE reconstructions can decrease model loss. Prisma enables new research directions for understanding vision model internals while lowering barriers to entry in this emerging field.

Summary

  • The paper introduces Prisma, an open-source toolkit that unifies tools and pre-trained weights to advance mechanistic interpretability in vision and video models.
  • The paper employs a hooked architecture via HookedViT to support over 75 transformer models, enabling efficient activation caching and intervention analyses.
  • The paper demonstrates that SAE reconstructions can reduce cross-entropy loss in deeper layers, revealing modality-specific differences in sparse coding.

Prisma is an open-source toolkit designed to accelerate mechanistic interpretability research specifically for vision and video models. It addresses the current gap in accessible frameworks and pre-trained weights that has hindered progress compared to the field of LLM interpretability. Prisma provides a unified platform integrating various tools and resources needed for scalable vision interpretability studies.

The toolkit centers around a HookedViT class, which adapts popular vision model repositories like timm, OpenCLIP, and Hugging Face to provide a unified interface for accessing and modifying activations and weights. This is similar to the functionality offered by TransformerLens for LLMs. This hooked architecture supports over 75 vision and video transformers, enabling researchers to easily perform downstream interpretability tasks such as activation caching and interventions. For video models, it includes support for 3D tubelet embeddings found in architectures like ViViT and V-JEPA.
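
As a concrete illustration, the snippet below sketches this hooked workflow in the TransformerLens style the text describes; the import path, model identifier, and hook name are assumptions rather than Prisma's verbatim API.

```python
# Hedged sketch of activation caching with HookedViT; the import path,
# model identifier, and hook name are assumptions modeled on TransformerLens.
import torch
from vit_prisma.models.base_vit import HookedViT  # assumed import path

model = HookedViT.from_pretrained("openai/clip-vit-base-patch32")  # assumed identifier
images = torch.randn(2, 3, 224, 224)  # dummy batch of two 224x224 RGB images

# Forward pass that caches every intermediate activation.
logits, cache = model.run_with_cache(images)

# Pull one residual-stream activation out of the cache.
resid = cache["blocks.5.hook_resid_post"]  # assumed hook name
print(resid.shape)  # (batch, num_tokens, d_model), e.g. (2, 50, 768) for ViT-B/32
```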

Prisma offers comprehensive support for training and evaluating Sparse Autoencoders (SAEs) and their variants, Transcoders and Crosscoders, which are key techniques for decomposing dense model representations into sparser, potentially more interpretable features. The toolkit provides implementations for various sparse coder architectures, including ReLU, Top-K, JumpReLU, and Gated variants. The training process is highly configurable, allowing control over parameters like encoder/decoder initialization, activation normalization, ghost gradients (to prevent dead features), dead feature resampling, and early stopping. It supports both activation caching for speed and on-the-fly computation for VRAM optimization. Evaluation tools are included to measure metrics essential for assessing sparse coder quality, such as L0 sparsity, explained variance, reconstruction/substitution losses, and identifying maximally activating inputs for specific features.
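
To make the vanilla variant concrete, here is a minimal, self-contained PyTorch sketch of a ReLU SAE with an L1 sparsity penalty, using the encoder-as-decoder-transpose initialization listed in the training details below. It is illustrative, not Prisma's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VanillaSAE(nn.Module):
    def __init__(self, d_model: int = 768, expansion: int = 64):
        super().__init__()
        d_sae = d_model * expansion  # e.g. 768 -> 49,152 features
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.01)
        self.W_enc = nn.Parameter(self.W_dec.data.T.clone())  # encoder = decoder transpose at init
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor):
        f = F.relu((x - self.b_dec) @ self.W_enc + self.b_enc)  # sparse feature activations
        x_hat = f @ self.W_dec + self.b_dec                     # reconstruction
        return x_hat, f

sae = VanillaSAE()
x = torch.randn(8, 768)  # a batch of cached ViT activations
x_hat, f = sae(x)
loss = F.mse_loss(x_hat, x) + 1e-5 * f.abs().sum(-1).mean()  # MSE + L1 sparsity penalty
print(loss.item(), (f > 0).float().sum(-1).mean().item())    # total loss and average L0
```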

A significant contribution of Prisma is the release of over 80 pre-trained SAE weights for CLIP-B and DINO-B Vision Transformers, covering all layers and various token configurations (all patches, CLS-only, spatial patches), as well as transcoders for CLIP. These pre-trained models give researchers a starting point without the cost of extensive training computation. Training details for the released models include (see the training-loop sketch after this list):

  • Expansion Factor: Typically x64 (e.g., mapping a 768-dim space to 49,152 features).
  • Weight Initialization: Encoder weights initialized as the transpose of decoder weights.
  • Training Data: ImageNet1k for one epoch.
  • Optimizer & Schedule: Adam optimizer with a learning rate sweep and cosine annealing with warmup.
  • L1 Coefficients: Swept over a wide range to find optimal sparsity/reconstruction trade-offs.
  • Auxiliary Loss: Ghost gradients utilized.
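
A hedged sketch of this recipe, as referenced above: Adam with linear warmup into cosine annealing, and a fixed L1 coefficient taken from the swept range. `get_activation_batch` is a hypothetical loader over cached ImageNet-1k activations, the step counts and peak learning rate are illustrative, and ghost gradients and resampling are omitted for brevity.

```python
import torch
import torch.nn.functional as F

sae = VanillaSAE(d_model=768, expansion=64)  # from the SAE sketch above
l1_coeff = 1e-5                              # one point from the swept range
total_steps, warmup_steps, peak_lr = 100_000, 1_000, 4e-4
opt = torch.optim.Adam(sae.parameters(), lr=peak_lr)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps - warmup_steps)

for step in range(total_steps):
    x = get_activation_batch()               # hypothetical loader of cached activations
    x_hat, f = sae(x)
    loss = F.mse_loss(x_hat, x) + l1_coeff * f.abs().sum(-1).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step < warmup_steps:                  # linear warmup, then cosine decay
        for g in opt.param_groups:
            g["lr"] = peak_lr * (step + 1) / warmup_steps
    else:
        cosine.step()
```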

In addition to sparse coding, Prisma includes a suite of tools for circuit analysis, attention head visualization, and logit lens techniques. These tools facilitate understanding interactions between different model components. The toolkit supports ablations via forward pass interventions using the hooked model architecture. The HookedSAEViT module enables analysis by replacing layers with sparse coders, and the HookedViT module integrates with automatic circuit discovery algorithms like ACDC.
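
For instance, zero-ablating a single attention head can be written as a forward-pass hook. The sketch below reuses `model` from the earlier HookedViT snippet; the hook name follows TransformerLens conventions and is an assumption about Prisma's naming.

```python
import torch

def ablate_head_3(z, hook):
    # z has shape (batch, num_tokens, num_heads, d_head); zero out head 3.
    z = z.clone()
    z[:, :, 3, :] = 0.0
    return z

images = torch.randn(2, 3, 224, 224)
ablated_logits = model.run_with_hooks(       # model from the HookedViT sketch
    images,
    fwd_hooks=[("blocks.5.attn.hook_z", ablate_head_3)],  # assumed hook name
)
```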

To lower the barrier to entry and enable research in compute-constrained environments, Prisma provides educational resources such as tutorial notebooks and documentation. It also includes toy Vision Transformers with 1-4 layers, including attention-only variants, pre-trained on ImageNet, mirroring a successful strategy from language interpretability research.

The paper presents preliminary analyses conducted using Prisma, revealing surprising findings about vision SAEs compared to their language counterparts. One observation is that vision SAEs trained on CLIP-B/32 exhibit substantially higher L0 values (around 500+ active features per patch) while maintaining comparable explained variance, in contrast to language SAEs, which typically show L0 below 100; in other words, effective vision SAEs are markedly less sparse. Potential explanations explored include inherent differences in visual and linguistic information density, the granularity of image patches versus language tokens, and the specialization of different token types (CLS vs. spatial patches), which show varying feature utilization across layers (see the paper's figure on alive features). This suggests that optimal sparse coding techniques may differ significantly between modalities and even across token types within vision models.
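
The two metrics underpinning this comparison can be computed directly from a batch of SAE features and reconstructions; the helpers below sketch the standard definitions in plain PyTorch and are not Prisma's evaluation code.

```python
import torch

def avg_l0(f: torch.Tensor) -> float:
    """Average count of active (nonzero) features per token."""
    return (f != 0).float().sum(dim=-1).mean().item()

def explained_variance(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """Fraction of activation variance captured by the reconstruction."""
    resid_var = (x - x_hat).pow(2).sum()
    total_var = (x - x.mean(dim=0)).pow(2).sum()
    return (1.0 - resid_var / total_var).item()
```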

Another unexpected finding is that injecting SAE reconstructions back into the model's forward pass can sometimes decrease the original model's cross-entropy loss. This effect was observed particularly in deeper layers for CLS tokens in certain CLIP and DINO SAEs (see the paper's SAE-performance figure). It suggests a potential denoising effect of the SAE reconstruction, which aligns with prior observations but warrants further investigation into the underlying mechanism. The improvement was not consistent across all SAE types or token configurations.
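
Such a substitution experiment can be sketched as a hook that swaps a layer's residual stream for its SAE reconstruction mid-forward-pass. Here `model` and `sae` come from the earlier sketches, the hook name is an assumption, and the sketch assumes the hooked model exposes classification logits.

```python
import torch
import torch.nn.functional as F

def substitute_reconstruction(resid, hook):
    x_hat, _ = sae(resid)  # sae from the earlier sketch; resid is (batch, tokens, d_model)
    return x_hat           # the model continues the forward pass from the reconstruction

images = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 1000, (2,))  # stand-in ImageNet labels

logits = model(images)  # assumes the hooked model returns classification logits
recon_logits = model.run_with_hooks(
    images, fwd_hooks=[("blocks.9.hook_resid_post", substitute_reconstruction)]
)
ce = F.cross_entropy(logits, labels)
recon_ce = F.cross_entropy(recon_logits, labels)  # occasionally lower than ce
```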

The release includes detailed evaluation tables for the trained SAEs, showing metrics like explained variance, average L0 (overall and per CLS token), cosine similarity, reconstruction cosine similarity, cross-entropy loss, reconstruction cross-entropy loss, zero ablation cross-entropy loss, percentage of CE recovered, and percentage of alive features. These metrics provide a quantitative basis for evaluating the quality and behavior of the released sparse coders.
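
For reference, the "% CE recovered" column is conventionally derived from the three loss columns as shown below. This is the standard definition, assumed (not confirmed) to match the paper's, and it exceeds 100 exactly when the reconstruction's CE falls below the original.

```python
def pct_ce_recovered(ce: float, recon_ce: float, zero_abl_ce: float) -> float:
    """Share of the CE gap between zero-ablation and the clean model
    that the SAE reconstruction recovers."""
    return 100.0 * (zero_abl_ce - recon_ce) / (zero_abl_ce - ce)

# Layer 7 resid_post row from the CLIP table below (table values are rounded):
print(pct_ce_recovered(6.762, 6.761, 6.908))  # ~100.7; >100 since recon CE < CE
```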

Here are the evaluation tables provided in the paper:

CLIP-ViT-B-32 SAEs

Vanilla SAEs (All Patches)

| Layer | Sublayer | L1 coeff. | % Explained var. | Avg L0 | Avg CLS L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | mlp_out | 1e-5 | 98.7 | 604.44 | 36.92 | 0.994 | 0.998 | 6.762 | 6.762 | 6.779 | 99.51 | 100 | link |
| 0 | resid_post | 1e-5 | 98.6 | 1110.9 | 40.46 | 0.993 | 0.988 | 6.762 | 6.763 | 6.908 | 99.23 | 100 | link |
| 1 | mlp_out | 1e-5 | 98.4 | 1476.8 | 97.82 | 0.992 | 0.994 | 6.762 | 6.762 | 6.889 | 99.40 | 100 | link |
| 1 | resid_post | 1e-5 | 98.3 | 1508.4 | 27.39 | 0.991 | 0.989 | 6.762 | 6.763 | 6.908 | 99.02 | 100 | link |
| 2 | mlp_out | 1e-5 | 98.0 | 1799.7 | 376.0 | 0.992 | 0.998 | 6.762 | 6.762 | 6.803 | 99.44 | 100 | link |
| 2 | resid_post | 5e-5 | 90.6 | 717.84 | 10.11 | 0.944 | 0.960 | 6.762 | 6.767 | 6.908 | 96.34 | 100 | link |
| 3 | mlp_out | 1e-5 | 98.1 | 1893.4 | 648.2 | 0.992 | 0.999 | 6.762 | 6.762 | 6.784 | 99.54 | 100 | link |
| 3 | resid_post | 1e-5 | 98.1 | 2053.9 | 77.90 | 0.989 | 0.996 | 6.762 | 6.762 | 6.908 | 99.79 | 100 | link |
| 4 | mlp_out | 1e-5 | 98.1 | 1901.2 | 1115.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.786 | 99.55 | 100 | link |
| 4 | resid_post | 1e-5 | 98.0 | 2068.3 | 156.7 | 0.989 | 0.997 | 6.762 | 6.762 | 6.908 | 99.74 | 100 | link |
| 5 | mlp_out | 1e-5 | 98.2 | 1761.5 | 1259.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.797 | 99.76 | 100 | link |
| 5 | resid_post | 1e-5 | 98.1 | 1953.8 | 228.5 | 0.990 | 0.997 | 6.762 | 6.762 | 6.908 | 99.80 | 100 | link |
| 6 | mlp_out | 1e-5 | 98.3 | 1598.0 | 1337.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.789 | 99.83 | 100 | link |
| 6 | resid_post | 1e-5 | 98.2 | 1717.5 | 321.3 | 0.991 | 0.996 | 6.762 | 6.762 | 6.908 | 99.93 | 100 | link |
| 7 | mlp_out | 1e-5 | 98.2 | 1535.3 | 1300.0 | 0.992 | 0.999 | 6.762 | 6.762 | 6.796 | 100.17 | 100 | link |
| 7 | resid_post | 1e-5 | 98.2 | 1688.4 | 494.3 | 0.991 | 0.995 | 6.762 | 6.761 | 6.908 | 100.24 | 100 | link |
| 8 | mlp_out | 1e-5 | 97.8 | 1074.5 | 1167.0 | 0.990 | 0.998 | 6.762 | 6.761 | 6.793 | 100.57 | 100 | link |
| 8 | resid_post | 1e-5 | 98.2 | 1570.8 | 791.3 | 0.991 | 0.992 | 6.762 | 6.761 | 6.908 | 100.41 | 100 | link |
| 9 | mlp_out | 1e-5 | 97.6 | 856.68 | 1076.0 | 0.989 | 0.998 | 6.762 | 6.762 | 6.792 | 100.28 | 100 | link |
| 9 | resid_post | 1e-5 | 98.2 | 1533.5 | 1053.0 | 0.991 | 0.989 | 6.762 | 6.761 | 6.908 | 100.32 | 100 | link |
| 10 | mlp_out | 1e-5 | 98.1 | 788.49 | 965.5 | 0.991 | 0.998 | 6.762 | 6.762 | 6.772 | 101.50 | 99.80 | link |
| 10 | resid_post | 1e-5 | 98.4 | 1292.6 | 1010.0 | 0.992 | 0.987 | 6.762 | 6.760 | 6.908 | 100.83 | 99.99 | link |
| 11 | mlp_out | 5e-5 | 89.7 | 748.14 | 745.5 | 0.972 | 0.993 | 6.762 | 6.759 | 6.768 | 135.77 | 100 | link |
| 11 | resid_post | 1e-5 | 98.4 | 1405.0 | 1189.0 | 0.993 | 0.987 | 6.762 | 6.765 | 6.908 | 98.03 | 99.99 | link |

Vanilla SAEs (CLS only)

| Layer | Sublayer | L1 coeff. | % Explained var. | Avg CLS L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 2e-8 | 82 | 934.83 | 0.98008 | 0.99995 | 6.7622 | 6.7622 | 6.9084 | 99.9984 | 4.33 | link |
| 1 | resid_post | 8e-6 | 85 | 314.13 | 0.97211 | 0.99994 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 2.82 | link |
| 2 | resid_post | 9e-8 | 96 | 711.84 | 0.98831 | 0.99997 | 6.7622 | 6.7622 | 6.9083 | 99.9977 | 2.54 | link |
| 3 | resid_post | 1e-8 | 95 | 687.41 | 0.98397 | 0.99994 | 6.7622 | 6.7622 | 6.9085 | 99.9998 | 4.49 | link |
| 4 | resid_post | 9e-8 | 95 | 681.08 | 0.98092 | 0.99988 | 6.7622 | 6.7622 | 6.9082 | 100.00 | 15.75 | link |
| 5 | resid_post | 1e-7 | 94 | 506.77 | 0.97404 | 0.99966 | 6.7622 | 6.7622 | 6.9081 | 99.9911 | 16.80 | link |
| 6 | resid_post | 1e-8 | 92 | 423.70 | 0.96474 | 0.99913 | 6.7622 | 6.7622 | 6.9083 | 99.9971 | 29.46 | link |
| 7 | resid_post | 2e-6 | 88 | 492.68 | 0.93899 | 0.99737 | 6.7622 | 6.7622 | 6.9082 | 99.9583 | 51.68 | link |
| 8 | resid_post | 4e-8 | 76 | 623.01 | 0.89168 | 0.99110 | 6.7622 | 6.7625 | 6.9087 | 99.7631 | 82.07 | link |
| 9 | resid_post | 1e-12 | 74 | 521.90 | 0.87076 | 0.98191 | 6.7622 | 6.7628 | 6.9083 | 99.5425 | 93.68 | link |
| 10 | resid_post | 3e-7 | 74 | 533.94 | 0.87646 | 0.96514 | 6.7622 | 6.7635 | 6.9082 | 99.1070 | 99.98 | link |
| 11 | resid_post | 1e-8 | 65 | 386.09 | 0.81890 | 0.89607 | 6.7622 | 6.7853 | 6.9086 | 84.1918 | 99.996 | link |

Top K SAEs (CLS only, k=64)

| Layer | Sublayer | % Explained var. | Avg CLS L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 90 | 64 | 0.98764 | 0.99998 | 6.7622 | 6.7622 | 6.9084 | 99.995 | 46.80 | link |
| 1 | resid_post | 96 | 64 | 0.99429 | 0.99999 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 4.86 | link |
| 2 | resid_post | 96 | 64 | 0.99000 | 0.99998 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 5.50 | link |
| 3 | resid_post | 95 | 64 | 0.98403 | 0.99995 | 6.7622 | 6.7622 | 6.9085 | 100.00 | 5.21 | link |
| 4 | resid_post | 94 | 64 | 0.97485 | 0.99986 | 6.7621 | 6.7622 | 6.9082 | 99.998 | 6.81 | link |
| 5 | resid_post | 93 | 64 | 0.96985 | 0.99962 | 6.7622 | 6.7622 | 6.9081 | 99.997 | 21.89 | link |
| 6 | resid_post | 92 | 64 | 0.96401 | 0.99912 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 28.81 | link |
| 7 | resid_post | 90 | 64 | 0.95057 | 0.99797 | 6.7622 | 6.7621 | 6.9082 | 100.03 | 65.84 | link |
| 8 | resid_post | 87 | 64 | 0.93029 | 0.99475 | 6.7622 | 6.7620 | 6.9087 | 100.11 | 93.75 | link |
| 9 | resid_post | 85 | 64 | 0.91814 | 0.98865 | 6.7622 | 6.7616 | 6.9083 | 100.43 | 98.90 | link |
| 10 | resid_post | 86 | 64 | 0.93072 | 0.97929 | 6.7622 | 6.7604 | 6.9082 | 101.19 | 94.55 | link |
| 11 | resid_post | 84 | 64 | 0.91880 | 0.94856 | 6.7622 | 6.7578 | 6.9086 | 102.97 | 97.99 | link |

Vanilla SAEs (Spatial Patches)

| Layer | Sublayer | L1 coeff. | % Explained var. | Avg L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features | Model |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 1e-12 | 99 | 989.19 | 0.99 | 0.99 | 6.7621 | 6.7621 | 6.9084 | 99.9981 | 100.00 | link |
| 1 | resid_post | 3e-11 | 99 | 757.83 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9083 | 99.9969 | 45.39 | link |
| 2 | resid_post | 4e-12 | 99 | 1007.89 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 97.93 | link |
| 3 | resid_post | 2e-8 | 99 | 935.06 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9085 | 99.9882 | 100.00 | link |
| 4 | resid_post | 3e-8 | 99 | 965.15 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9082 | 99.9842 | 100.00 | link |
| 5 | resid_post | 1e-8 | 99 | 966.38 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9081 | 99.9961 | 100.00 | link |
| 6 | resid_post | 1e-8 | 99 | 1006.62 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9083 | 100.00 | 99.97 | link |
| 7 | resid_post | 1e-8 | 99 | 984.19 | 0.99 | 0.99 | 6.7622 | 6.7622 | 6.9082 | 100.00 | 100.00 | link |
| 8 | resid_post | 3e-8 | 99 | 965.12 | 0.99 | 1.00 | 6.7622 | 6.7622 | 6.9087 | 100.00 | 92.37 | link |
| 9 | resid_post | 9e-8 | 99 | 854.92 | 0.99 | 1.00 | 6.7622 | 6.7622 | 6.9083 | 99.9991 | 85.43 | link |
| 10 | resid_post | 1e-4 | 72 | 88.80 | 0.84 | 0.97 | 6.7621 | 6.7638 | 6.9082 | 98.85 | 100.00 | link |
| 11 | resid_post | 3e-7 | 99 | 829.09 | 0.99 | 1.00 | 6.7622 | 6.7622 | 6.9086 | 100.00 | 55.71 | link |

Top K Transcoders (All Patches)

| Layer | Block | % Explained var. | k | Avg CLS L0 | Cos sim | CE | Recon CE | Zero abl CE | % CE recovered | Model |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MLP | 96 | 768 | 767 | 0.9655 | 6.7621 | 6.7684 | 6.8804 | 94.68 | link |
| 1 | MLP | 94 | 256 | 255 | 0.9406 | 6.7621 | 6.7767 | 6.8816 | 87.78 | link |
| 2 | MLP | 93 | 1024 | 475 | 0.9758 | 6.7621 | 6.7681 | 6.7993 | 83.92 | link |
| 3 | MLP | 90 | 1024 | 825 | 0.9805 | 6.7621 | 6.7642 | 6.7999 | 94.42 | link |
| 4 | MLP | 76 | 512 | 29 | 0.9830 | 6.7621 | 6.7636 | 6.8080 | 96.76 | link |
| 5 | MLP | 91 | 1024 | 1017 | 0.9784 | 6.7621 | 6.7643 | 6.8296 | 96.82 | link |
| 6 | MLP | 94 | 1024 | 924 | 0.9756 | 6.7621 | 6.7630 | 6.8201 | 98.40 | link |
| 7 | MLP | 97 | 1024 | 1010 | 0.9629 | 6.7621 | 6.7631 | 6.8056 | 97.68 | link |
| 8 | MLP | 98 | 1024 | 1023 | 0.9460 | 6.7621 | 6.7630 | 6.8017 | 97.70 | link |
| 9 | MLP | 98 | 1024 | 1023 | 0.9221 | 6.7621 | 6.7630 | 6.7875 | 96.50 | link |
| 10 | MLP | 97 | 1024 | 1019 | 0.9334 | 6.7621 | 6.7636 | 6.7860 | 93.95 | link |

DINO-B SAEs

Vanilla (All patches)

| Layer | Sublayer | Avg L0 | % Explained var. | Avg CLS L0 | Cos sim | CE | Recon CE | Zero abl CE | % CE recovered | Model |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | resid_post | 507 | 98 | 347 | 0.95009 | 1.885033 | 1.936518 | 7.2714 | 99.04 | link |
| 1 | resid_post | 549 | 95 | 959 | 0.93071 | 1.885100 | 1.998274 | 7.2154 | 97.88 | link |
| 2 | resid_post | 812 | 95 | 696 | 0.9560 | 1.885134 | 2.006115 | 7.201461 | 97.72 | link |
| 3 | resid_post | 989 | 95 | 616 | 0.96315 | 1.885131 | 1.961913 | 7.2068 | 98.56 | link |
| 4 | resid_post | 876 | 99 | 845 | 0.99856 | 1.885224 | 1.883169 | 7.1636 | 100.04 | link |
| 5 | resid_post | 1001 | 98 | 889 | 0.99129 | 1.885353 | 1.875520 | 7.1412 | 100.19 | link |
| 6 | resid_post | 962 | 99 | 950 | 0.99945 | 1.885239 | 1.872594 | 7.1480 | 100.24 | link |
| 7 | resid_post | 1086 | 98 | 1041 | 0.99341 | 1.885371 | 1.869443 | 7.1694 | 100.30 | link |
| 8 | resid_post | 530 | 90 | 529 | 0.9475 | 1.885511 | 1.978638 | 7.1315 | 98.22 | link |
| 9 | resid_post | 1105 | 99 | 1090 | 0.99541 | 1.885341 | 1.894026 | 7.0781 | 99.83 | link |
| 10 | resid_post | 835 | 99 | 839 | 0.99987 | 1.885371 | 1.884487 | 7.3606 | 100.02 | link |
| 11 | resid_post | 1085 | 99 | 1084 | 0.99673 | 1.885370 | 1.911608 | 6.9078 | 99.48 | link |

Prisma aims to be a foundational tool for the emerging field of vision and video mechanistic interpretability, providing the necessary components and pre-trained resources to accelerate understanding of these complex models.
