Contextual Decomposition for Transformers
- The paper’s main contribution is CD-T’s principled framework for additive decomposition of Transformer activations, providing precise attributions.
- CD-T defines module-specific rules to partition hidden states into relevant and irrelevant parts, leveraging linearity and averaging-based nonlinear rules.
- CD-T facilitates efficient circuit discovery by extracting sparse attention-head subgraphs that preserve a significant fraction of model performance.
Contextual Decomposition for Transformers (CD-T) is a principled family of attribution and circuit-discovery methodologies that partition hidden activations and outputs in Transformer models into additive contributions from interpretable sources. CD-T generalizes contextual decomposition approaches from recurrent and feedforward networks to the parallel architecture and attention mechanisms of Transformers, delivering fine-grained, mathematically consistent attributions and enabling efficient, interpretable circuit extraction for mechanistic interpretability.
1. Formal Definition and Mathematical Framework
CD-T views the computation within a Transformer as a directed acyclic graph (DAG), whose nodes are activations and whose edges correspond to elementary operations (linear transformations, nonlinearities, attention, normalization, etc.) (Hsu et al., 2024). At each node, the activation vector is conceptually split into “relevant” and “irrelevant” summands corresponding to different sources:
For any module $f$ with input $x = \beta + \gamma$, CD-T propagates this decomposition (using module-specific rules) such that

$$f(x) = f^{\beta}(x) + f^{\gamma}(x),$$

where $f^{\beta}(x)$ is the output originating from the relevant part $\beta$, and $f^{\gamma}(x)$ from the irrelevant part $\gamma$. In linear modules $f(x) = Wx + b$, this propagation is trivial: the output splits linearly ($f^{\beta} = W\beta$, $f^{\gamma} = W\gamma + b$, with the bias conventionally apportioned to the irrelevant part). For nonlinearities such as ReLU, the decomposition uses averaging-based rules introduced by Murdoch et al. (2018):

$$\beta' = \tfrac{1}{2}\Big[\operatorname{ReLU}(\beta) + \big(\operatorname{ReLU}(\beta + \gamma) - \operatorname{ReLU}(\gamma)\big)\Big], \qquad \gamma' = \operatorname{ReLU}(\beta + \gamma) - \beta'.$$
Layer normalization, skip connections, and other common modules are handled analogously via summation and linearity.
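As a concrete sketch of these module rules, the following NumPy code implements the linear and ReLU propagation above (illustrative, not the authors' implementation; assigning the bias to the irrelevant part is one common convention):

```python
import numpy as np

def cd_linear(beta, gamma, W, b):
    """CD rule for a linear layer y = W x + b.

    Linearity splits the output exactly; here the bias is
    apportioned to the irrelevant part (conventions vary).
    """
    return W @ beta, W @ gamma + b

def cd_relu(beta, gamma):
    """Averaging-based CD rule for ReLU (Murdoch et al., 2018).

    The relevant share averages two attributions of the total:
    ReLU(beta) and ReLU(beta + gamma) - ReLU(gamma); the
    irrelevant share is whatever remains, so beta' + gamma'
    always reconstructs ReLU(beta + gamma) exactly.
    """
    relu = lambda z: np.maximum(z, 0.0)
    total = relu(beta + gamma)
    beta_out = 0.5 * (relu(beta) + (total - relu(gamma)))
    return beta_out, total - beta_out
```

Note that both rules preserve the additive invariant: the two returned parts always sum to the ordinary forward output of the module.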
The central contribution in CD-T is the handling of multi-head self-attention. Given token representations $x_1, \dots, x_n$, each split as $x_i = \beta_i + \gamma_i$, an attention head computes queries, keys, and values $q_i = W_Q x_i$, $k_i = W_K x_i$, $v_i = W_V x_i$, attention weights

$$\alpha_{ij} = \operatorname{softmax}_j\!\left(\frac{q_i^\top k_j}{\sqrt{d_k}}\right),$$

and outputs

$$o_i = \sum_j \alpha_{ij} v_j.$$

CD-T propagates only the decomposition of the value vectors, not the attention weights, which are computed from the full activations and treated as fixed mixing coefficients. Write $v_j = v_j^{\beta} + v_j^{\gamma}$ with $v_j^{\beta} = W_V \beta_j$; then

$$o_i^{\beta} = \sum_j \alpha_{ij} v_j^{\beta}, \qquad o_i^{\gamma} = \sum_j \alpha_{ij} v_j^{\gamma}.$$
By consistently applying these propagation rules throughout the computation graph, CD-T achieves an additive decomposition of final outputs (e.g., logits) into precise per-source contributions, supporting arbitrary granularity: tokens, heads, or intermediate features (Hsu et al., 2024, Oh et al., 2023, Modarressi et al., 2023).
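The attention rule above can be sketched in NumPy as follows (illustrative code under the stated convention that attention weights come from the full activations and are held fixed):

```python
import numpy as np

def softmax(z, axis=-1):
    # numerically stable softmax
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cd_attention_head(beta, gamma, Wq, Wk, Wv):
    """CD-T rule for one attention head (sketch).

    beta, gamma: (n_tokens, d_model) relevant/irrelevant parts.
    Attention weights are computed from x = beta + gamma and
    treated as fixed; only the value vectors' decomposition
    is mixed through them.
    """
    x = beta + gamma
    q, k = x @ Wq, x @ Wk
    d_k = Wq.shape[1]
    alpha = softmax(q @ k.T / np.sqrt(d_k), axis=-1)  # (n, n)
    v_beta, v_gamma = beta @ Wv, gamma @ Wv
    return alpha @ v_beta, alpha @ v_gamma
```

Because `alpha` multiplies both parts identically, the two outputs sum exactly to the head's ordinary output.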
2. Algorithmic Protocols and Pseudocode
Contextual Decomposition for Transformers can be implemented efficiently as a single augmented forward pass through the Transformer, with the following core algorithmic protocol (Hsu et al., 2024):
```
for node in nodes(M):
    a[node] = compute_activation(node, inputs)
for node in nodes(M):
    if node in S:
        beta[node] = a[node]; gamma[node] = 0
    else:
        beta[node] = 0; gamma[node] = a[node]
for node in topological_sort(nodes(M)):
    (beta[node], gamma[node]) = apply_cd_rule(
        node, beta[inputs(node)], gamma[inputs(node)])
return (beta[t], gamma[t])
```
Each module applies its specific CD rule to its input decomposition. For attention, only the value vectors' decomposition is propagated, with the attention weights $\alpha_{ij}$ held fixed; for nonlinearities, piecewise rules or linearization are applied per the activation and input.
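A minimal runnable instance of this protocol, assuming a simple sequential chain of linear and ReLU modules rather than a general DAG (all names illustrative):

```python
import numpy as np

def cd_forward(layers, x, relevant_mask):
    """Single augmented forward pass (toy sketch of the protocol).

    layers: list of ("linear", W, b) or ("relu",) ops, applied in order.
    relevant_mask: boolean mask over input features marking the source S.
    Returns (beta, gamma) at the output; their sum equals the
    ordinary forward pass by construction.
    """
    beta = np.where(relevant_mask, x, 0.0)   # relevant part of the input
    gamma = x - beta                         # irrelevant remainder
    relu = lambda z: np.maximum(z, 0.0)
    for op in layers:
        if op[0] == "linear":
            _, W, b = op
            beta, gamma = W @ beta, W @ gamma + b
        elif op[0] == "relu":
            total = relu(beta + gamma)
            new_beta = 0.5 * (relu(beta) + (total - relu(gamma)))
            beta, gamma = new_beta, total - new_beta
    return beta, gamma
```

Extending this to a true DAG amounts to iterating in topological order and gathering each node's input decompositions, as in the pseudocode above.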
Alternative but functionally similar token-wise and per-parcel decompositions for encoder or decoder-only architectures have also been formalized and implemented in works on linear decompositions and DecompX/Decomposed eXplanations (Oh et al., 2023, Modarressi et al., 2023), operating on all layers and supporting per-token, per-class attributions. These frameworks guarantee the decomposed attribution sums exactly to the original hidden states or logits by construction.
3. Application to Mechanistic Circuit Discovery
In CD-T-based circuit discovery, the objective is to extract a sparse subgraph (“circuit”) $C$ comprising a small subset of heads or modules, such that the output of the model with all components outside $C$ mean-ablated closely matches the full model’s behavior. The typical workflow proceeds as follows (Hsu et al., 2024):
- Reverse iterative selection: Starting from the final output node (e.g., the logits), treat it as the “receiver set” $R$.
- Direct-effect scoring: For each upstream attention head $h$ feeding into $R$, compute its direct effect $\mathrm{DE}(h)$: the magnitude of the decomposed contribution $\beta_{h \to R}$ propagated from $h$ to the receiver set $R$ under mean ablation.
- Selection and recursion: Retain the top-$k$ heads with the highest $\mathrm{DE}(h)$, add them to the candidate circuit, update $R$ to these heads, and iterate until reaching the input embedding layer.
This backwards pruning drastically reduces candidate circuit size at each stage, supporting the discovery of highly interpretable, fine-grained attention-head circuits.
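The selection loop can be sketched as follows, with `direct_effect` standing in as a callback for the CD-T scoring step (all names hypothetical, not the paper's implementation):

```python
def discover_circuit(heads_by_layer, direct_effect, k):
    """Reverse iterative head selection (sketch of the workflow above).

    heads_by_layer: per-layer lists of head identifiers, ordered from
        the layer nearest the output back toward the input.
    direct_effect(head, receivers): callback returning the head's
        direct-effect score with respect to the current receiver set,
        standing in for the CD-T decomposed-contribution computation.
    Returns the candidate circuit as a list of retained heads.
    """
    receivers = ["logits"]            # start from the final output node
    circuit = []
    for layer_heads in heads_by_layer:
        scored = sorted(layer_heads,
                        key=lambda h: direct_effect(h, receivers),
                        reverse=True)
        kept = scored[:k]             # retain top-k heads by direct effect
        circuit.extend(kept)
        receivers = kept              # recurse: kept heads become receivers
    return circuit
```

Each iteration shrinks the candidate set before moving one level upstream, which is what keeps the discovered circuits sparse.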
4. Empirical Faithfulness Evaluation and Runtime Characteristics
In experimental analyses, CD-T achieves high-fidelity circuit discovery with exceptional efficiency compared to prior path-patching algorithms (Hsu et al., 2024). For instance, empirical runtime per level is approximately 1:52:20 (hh:mm:ss) for CD-T versus 3:37:26 for path-patching, yielding nearly a 2× speedup due to single-pass decomposition rather than repeated ablation and tracing.
Faithfulness is quantified as the expected fraction of true-label logits preserved by the circuit. Notably, on pathology classification models, CD-T circuits comprising only a minute fraction of total heads recover a larger share of total model performance than path-patching (41.9% of performance with 0.03% of heads), while random head selection achieves negligible test recovery. Thus, CD-T identifies salient circuits with extremely sparse subgraphs and minimal approximation error.
Standard mechanistic interpretability benchmarks (indirect-object identification, greater-than, docstring completion), as well as comparisons to ACDC, EAP, or reporting of formal ROC-AUC/statistical metrics, are not present in the current published evaluations (Hsu et al., 2024).
5. Relationship to Other Vector and Tokenwise Decomposition Methods
Several recent frameworks—albeit under different nomenclature—provide mathematically analogous decompositions and attributions in Transformers:
- Token-wise Linear Decomposition: In decoder-only models, each next-token logit is decomposed additively into per-context-token parcels that sum to the logit, precisely tracking each token's influence through all layers (Oh et al., 2023). Position-wise ablation of parcels quantifies each token's effect on the output probability, supporting analyses of collocation, syntax, and coreference. The method is mathematically exact aside from local linearization at activation functions.
- DecompX: Provides a holistic per-token decomposition of intermediate and final states, propagating elemental attributions through self-attention, normalization, FFN, and classification head. The method handles all residual connections, preserves vector structure, and directly attributes per-class logit contributions (Modarressi et al., 2023). Comprehensive perturbation-based faithfulness metrics (AOPC, accuracy after masking, predictive performance with rationales) show DecompX outperforming both scalar-rollout and gradient-based baselines on standard datasets.
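As a toy illustration of the additivity these token-wise schemes guarantee (not the DecompX API; all names are illustrative), per-token parcels of a final hidden state passed through a linear readout contribute terms that sum exactly to the class logit:

```python
import numpy as np

def tokenwise_logit_contributions(parcels, w, b):
    """Attribute one class logit to context-token parcels (toy sketch).

    parcels: (n_tokens, d) additive parcels of the final hidden state,
        so the full hidden state is parcels.sum(axis=0).
    w, b: weight vector and bias of a linear readout for one class.
    By linearity, logit = sum_j (parcels[j] @ w) + b, so each term
    parcels[j] @ w is token j's exact contribution (bias kept apart).
    """
    return parcels @ w  # per-token contributions, shape (n_tokens,)
```

The same additivity is what lets these frameworks guarantee that attributions sum exactly to the original logits by construction.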
A summary comparison of salient decomposition approaches:
| Method | Decomposition Granularity | Handles Nonlinearities | Evaluation Scope |
|---|---|---|---|
| CD-T (Hsu et al., 2024) | Per-head, arbitrary node | Yes | Circuits, ablation |
| Token-wise (Oh et al., 2023) | Per-token, sequence pos. | Yes (locally) | Token ablation |
| DecompX (Modarressi et al., 2023) | Per-token, per-class | Yes | Perturbation, MNLI, SST-2 |
These approaches are compatible in their core principles—propagating additive decompositions through all network components—differing primarily in attribution granularity, handling of bias terms, and experimental focus.
6. Limitations, Assumptions, and Extensions
All current CD-T methodologies are exact in the sense that the sum of decomposed parcel contributions reconstructs the original model outputs, provided that activation functions are locally differentiable (e.g., GELU, ReLU, tanh) and that bias apportionment through LayerNorm and FFN is handled carefully. Non-differentiable points pose no practical obstacle, since inputs land exactly on them with probability zero (Oh et al., 2023, Modarressi et al., 2023).
Key assumptions and computational constraints include:
- The decomposition requires storing and updating $N$ parcels for $N$ tokens per layer, yielding quadratic ($O(N^2)$) memory and computational overhead in the sequence length.
- Current protocols are natively suited to causal or encoder-only architectures, but can be generalized to encoder-decoder models by decomposing both encoder and cross-attention components.
- Real-world faithfulness is evaluated by masking or ablating top-attribution features. For task-specific interpretability, decomposing the classification head or output softmax is essential for actionable attributions (Modarressi et al., 2023).
Potential future research avenues include:
- Extension to mixture-of-experts and sparsely activated Transformer variants.
- Analysis of decomposition behavior in very large (multi-billion parameter) and multilingual models.
- Span- and group-based ablations, whose parcel contributions combine additively under linearity.
- Use of CD-T attributions for targeted intervention, control, and fine-grained probing in LLMs.
7. Significance for Model Interpretability Research
Contextual Decomposition for Transformers unifies a class of algebraic, faithful attribution schemes that overcome approximation and runtime bounds of earlier patching or gradient-based methods. CD-T enables direct, fine-grained, and theoretically justified tracing of causal influence—at per-token, per-head, or arbitrary node granularity—facilitating principled circuit discovery in large networks.
Empirical evidence (Hsu et al., 2024, Oh et al., 2023, Modarressi et al., 2023) indicates that CD-T-based approaches make mechanistic explanations of model behavior more scalable and transparent while imposing tractable computational costs. The ability to extract minimal, highly faithful circuits with precise attribution clarifies which architectural features are causally responsible for specific predictions, supporting robust progress in mechanistic interpretability and model diagnostics within deep learning research.