- The paper presents AtMan, a method that explains transformer predictions by perturbing attention scores during the forward pass to assess token impact without backpropagation.
- It introduces correlated token suppression using cosine similarity to aggregate semantically related tokens, enhancing explanation quality for complex inputs.
- Empirical results on text and image-text tasks show that AtMan achieves competitive accuracy with significantly lower memory usage than traditional gradient-based approaches.
Explaining predictions of large generative transformer models is challenging due to their complexity and computational requirements. Traditional explanation methods, particularly gradient-based ones like Integrated Gradients (IG), Input x Gradient (IxG), and methods specifically for transformers like Chefer's method, are often memory-intensive because they rely on backpropagation. This severely limits their applicability in production environments, especially for large models that already push hardware limits during inference.
The paper "AtMan: Understanding Transformer Predictions Through Memory Efficient Attention Manipulation" (2301.08110) introduces AtMan, a novel explanation method designed to address this memory bottleneck while providing competitive explanation quality. AtMan operates by manipulating the attention mechanism during the forward pass, avoiding the memory overhead associated with backpropagation.
Core Concept: Attention Manipulation as Perturbation
AtMan reformulates the task of finding important input parts by studying how perturbing the input affects the model's output loss for a specific target. Instead of directly perturbing the raw input (which is high-dimensional) or relying on gradients computed during the backward pass, AtMan applies perturbations within the model's latent space by modifying attention scores.
The core idea is that the influence of an input token on the output prediction can be approximated by observing how the model's loss for the target changes when the attention weights related to that token are suppressed. Conceptually, this approximates influence functions by leaving parts of the input out, without any backward pass.
Specifically, for an input sequence of tokens, AtMan iterates through each input token. For a chosen token i, it modifies the pre-softmax attention scores (H) in all attention layers and heads. The modification involves multiplying the column of H corresponding to token i by a factor (1−f), where f is a suppression factor (e.g., f=0.9 as found empirically). This effectively reduces the influence of token i on subsequent token predictions within the sequence.
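A minimal sketch of this score modification, assuming the pre-softmax scores are laid out as (heads, queries, keys); the function name and tensor layout are illustrative and not taken from the paper's reference implementation:

```python
import torch

def suppress_token_attention(scores: torch.Tensor, token_idx: int, f: float = 0.9) -> torch.Tensor:
    """Scale one key column of the pre-softmax attention scores.

    scores:    pre-softmax attention scores, shape (num_heads, seq_len, seq_len),
               where scores[:, q, k] is how strongly query position q attends to key k.
    token_idx: index i of the input token whose influence is suppressed.
    f:         suppression factor; the column for token i is multiplied by (1 - f).
    """
    modified = scores.clone()
    modified[:, :, token_idx] *= (1.0 - f)
    return modified

# Toy usage: 2 heads, sequence of 4 tokens; AtMan applies this in every layer and head.
scores = torch.randn(2, 4, 4)
suppressed = suppress_token_attention(scores, token_idx=1, f=0.9)
```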
The explanation score for token i with respect to a target output (e.g., the next predicted token or a sequence of target tokens) is then calculated as the difference in the model's loss for the target between the unmodified forward pass and the forward pass with token i's attention suppressed:
$$\text{Explanation}(w_i, \text{target}) = L_{\text{target}}(w, \theta_{-i}) - L_{\text{target}}(w, \theta)$$
where w is the input sequence, θ are the model parameters, θ_{-i} denotes the model with attention for token i suppressed, and L_target is the loss function (e.g., cross-entropy) for the target tokens. A positive difference indicates that suppressing token i increases the loss for the target, meaning token i is important for predicting the target.
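Putting the two pieces together, per-token scores require one extra forward pass per input token. The sketch below assumes a hypothetical `model(input_ids, suppress_token=i, suppression=f)` interface that applies the column scaling above in every attention layer and head and returns next-token logits of shape (seq_len, vocab_size):

```python
import torch
import torch.nn.functional as F

def explain_tokens(model, prompt_ids: torch.Tensor, target_ids: torch.Tensor, f: float = 0.9):
    """Explanation scores as loss differences, one per prompt token (hedged sketch)."""
    input_ids = torch.cat([prompt_ids, target_ids])
    n_prompt, n_target = len(prompt_ids), len(target_ids)

    def target_loss(logits: torch.Tensor) -> torch.Tensor:
        # Autoregressively, position j predicts token j + 1, so the logits at
        # positions n_prompt-1 .. n_prompt+n_target-2 predict the target tokens.
        pred = logits[n_prompt - 1 : n_prompt + n_target - 1]
        return F.cross_entropy(pred, target_ids)

    with torch.no_grad():
        base_loss = target_loss(model(input_ids))
        scores = []
        for i in range(n_prompt):  # explain each prompt token in turn
            loss_i = target_loss(model(input_ids, suppress_token=i, suppression=f))
            # Positive difference: suppressing token i makes the target harder to predict.
            scores.append((loss_i - base_loss).item())
    return scores
```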
Correlated Token Attention Manipulation
For modalities where information is spread across multiple input tokens (such as image patches in Vision Transformers), suppressing a single token may not capture the full influence of a concept. To address this, AtMan introduces correlated token suppression: it uses cosine similarity in the initial token embedding space to identify semantically related tokens. When explaining a specific token i, instead of suppressing only token i's attention column, AtMan also suppresses the attention columns of all tokens k whose embedding similarity s_{i,k} with token i exceeds a threshold (κ, empirically set to 0.7). The suppression applied to such a token k is scaled by s_{i,k}, interpolating between the base suppression (1−f) and no suppression (1). This allows AtMan to highlight concepts spread across multiple input tokens.
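A sketch of how the per-token factors could be derived from embedding similarities; the threshold κ = 0.7 and the linear interpolation follow the description above, while the paper's exact scaling may differ:

```python
import torch
import torch.nn.functional as F

def correlated_suppression_factors(embeddings: torch.Tensor, i: int,
                                   f: float = 0.9, kappa: float = 0.7) -> torch.Tensor:
    """Multiplicative factors for every key column when explaining token i.

    embeddings: initial token embeddings, shape (seq_len, dim).
    Tokens with cosine similarity >= kappa to token i are also suppressed; the
    factor is interpolated from 1.0 (no suppression) at similarity kappa down
    to (1 - f) (full suppression) at similarity 1.0.
    """
    sims = F.cosine_similarity(embeddings, embeddings[i].unsqueeze(0), dim=-1)
    factors = torch.ones_like(sims)
    related = sims >= kappa
    t = (sims[related] - kappa) / (1.0 - kappa)   # 0 at kappa, 1 at perfect similarity
    factors[related] = 1.0 - t * f
    return factors  # multiply key column k of the pre-softmax scores by factors[k]
```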
Implementation and Practical Advantages
Implementing AtMan requires access to the model's attention layers during the forward pass. The core modification is applied to the attention scores before the softmax operation and causal masking.
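As a concrete illustration, here is a toy single-head causal self-attention module with a hook point for the per-token factors; in a real deployment one would patch or subclass the attention implementation of the actual model rather than use this standalone module:

```python
import math
from typing import Optional

import torch
import torch.nn as nn

class SuppressibleSelfAttention(nn.Module):
    """Toy single-head causal self-attention exposing AtMan-style suppression."""

    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.scale = 1.0 / math.sqrt(dim)

    def forward(self, x: torch.Tensor, suppression: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (seq_len, dim); suppression: (seq_len,) multiplicative factor per key token.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        scores = (q @ k.T) * self.scale                 # pre-softmax attention scores
        if suppression is not None:
            scores = scores * suppression               # scale key columns before masking
        causal = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        return torch.softmax(scores, dim=-1) @ v
```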
The key practical benefits of AtMan stem from its design:
- Memory Efficiency: By avoiding backpropagation, AtMan's memory footprint is similar to a standard forward pass. This allows it to be applied to very large models (tested up to 30B parameters) on hardware where gradient-based methods (which can require double the memory) would fail (e.g., exceeding 80GB GPU memory limits).
- Scalability: While the basic approach requires a separate forward pass for each input token to be explained (or each group of correlated tokens), these forward passes are independent and can be run in parallel across multiple GPUs or within larger batches (a minimal batching sketch follows this list). Token aggregation (e.g., explaining paragraphs instead of individual words) further reduces the number of required passes. This makes the effective runtime manageable even for long sequences in deployed systems.
- Modality and Architecture Agnosticism: AtMan's mechanism works by manipulating the standard transformer attention mechanism, making it applicable to various transformer architectures (decoder-only like GPT-J and MAGMA, encoder-decoder like BLIP) and modalities (text, image-text) as long as the input is represented as a sequence of tokens/embeddings processed by attention layers.
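Since each suppressed pass is independent, the perturbations can be packed into batches. A minimal sketch of building the per-sample suppression vectors, to be consumed by a batched variant of the forward above, might look like this:

```python
import torch

def batched_suppression_masks(seq_len: int, f: float = 0.9, batch_size: int = 8):
    """Yield (token_indices, factors) batches, one row per token to explain.

    Each row is all ones except a (1 - f) entry at the token being suppressed;
    a batched model forward would broadcast these onto the pre-softmax
    attention scores of every layer and head.
    """
    for start in range(0, seq_len, batch_size):
        idx = torch.arange(start, min(start + batch_size, seq_len))
        factors = torch.ones(len(idx), seq_len)
        factors[torch.arange(len(idx)), idx] = 1.0 - f
        yield idx, factors
```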
Empirical Evaluation
The paper evaluates AtMan on text (SQuAD QA) and image-text (OpenImages VQA) tasks using generative transformer models (GPT-J, MAGMA, BLIP). The evaluation compares AtMan to gradient-based methods (IG, IxG, GradCAM, Chefer's method) using metrics like mean average precision (mAP) and mean average recall (mAR), treating ground-truth explanations (QA spans, image segmentations) as binary labels.
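To illustrate the metric, a per-example average precision can be computed by ranking tokens (or pixels) by their explanation score against the binary ground-truth mask; mAP and mAR are then means over the dataset. A minimal sketch using scikit-learn:

```python
import numpy as np
from sklearn.metrics import average_precision_score

def explanation_average_precision(relevance, ground_truth_mask):
    """Average precision for one example.

    relevance:         per-token (or per-pixel) explanation scores.
    ground_truth_mask: 1 where the token/pixel lies in the annotated answer
                       span or segmentation, 0 elsewhere.
    """
    return average_precision_score(np.asarray(ground_truth_mask).ravel(),
                                   np.asarray(relevance).ravel())
```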
Results show that AtMan consistently outperforms previous gradient-based methods on both text and image-text benchmarks in terms of mAP and often mAR. Qualitatively, AtMan's explanations are often less noisy than gradient-based methods, especially around image object boundaries.
The memory and runtime analysis demonstrates AtMan's superior memory efficiency, staying close to baseline forward pass memory usage across different model sizes and sequence lengths, while gradient methods quickly exceed memory limits. Although sequential runtime increases with sequence length, its parallelizability makes it suitable for deployment.
Limitations and Future Work
While effective, AtMan requires tuning hyperparameters such as the suppression factor and the cosine similarity threshold, and the interpretation of an "explanation" depends on the chosen target tokens. Future work includes exploring alternative attention manipulation strategies, identifying the influence of specific layers, investigating hierarchical token aggregation, and applying AtMan to study the explanatory capabilities of larger, more complex generative models.
In summary, AtMan offers a practical, memory-efficient approach to explaining predictions of large generative transformers by manipulating attention scores during the forward pass. Its performance and scalability make it a viable tool for understanding these complex models in real-world applications and research.