Distributional Attention Distillation (CAB)
- Distributional Attention Distillation is a framework that transfers fine-grained attention biases from Transformer models to state-space models using an attention bridge.
- CAB employs lightweight MLPs to project and align teacher and student representations, avoiding the quadratic cost of full attention matrix matching.
- Empirical results demonstrate that CAB improves accuracy in low-data regimes and reduces perplexity in language tasks, while delivering significant computational and memory efficiency.
Distributional Attention Distillation is a methodological framework for transferring attention-based inductive biases from pretrained Transformer models to state-space models (SSMs), specifically under structural heterogeneity between teacher and student architectures. The approach addresses the challenge of distilling fine-grained sequence processing knowledge—encoded via self-attention in Transformers—into recurrent SSMs such as Mamba, which possess fundamentally different token interaction mechanisms. Distributional Attention Distillation as instantiated in Cross-architecture distillation via Attention Bridge (CAB) enables token-level supervision by aligning representation distributions at multiple layers, efficiently bridging the gap between these model types while maintaining computational and memory efficiency (Wang et al., 22 Oct 2025).
1. Architectural Components and Data Flow
CAB employs a three-part structure comprising a Transformer teacher, a Mamba student, and a lightweight "attention bridge" formed by small multi-layer perceptrons (MLPs). At each Transformer teacher layer $\ell \in \{1, \dots, N_T\}$, the model extracts query and key representations $Q_T^{(\ell)}, K_T^{(\ell)} \in \mathbb{R}^{L \times d_T}$, where $L$ is the sequence length and $d_T$ is the teacher's hidden dimension. Optionally, full $L \times L$ attention matrices per head can be extracted, but this incurs $O(L^2)$ computational and memory cost, which CAB avoids.
The student SSM, exemplified by Mamba, produces token-wise projections $B^{(s)}, C^{(s)} \in \mathbb{R}^{L \times d_S}$ at each student layer $s \in \{1, \dots, N_S\}$. Here $d_S$ is the student's hidden dimension; $B^{(s)}$ and $C^{(s)}$ drive the implicit recurrent dynamics:

$$h_t = \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t\, h_t.$$
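This selective-SSM recurrence is easy to sketch numerically. A minimal NumPy version, with shapes and discretization details simplified relative to the actual Mamba kernel (single channel, per-timestep scalar state decay):

```python
import numpy as np

def ssm_scan(A_bar, B_bar, C, x):
    """Run the recurrence h_t = A_bar_t * h_{t-1} + B_bar_t * x_t,
    y_t = C_t . h_t over a sequence.
    Shapes: A_bar, B_bar, C are (L, N); x is (L,)."""
    L, N = A_bar.shape
    h = np.zeros(N)
    y = np.empty(L)
    for t in range(L):
        h = A_bar[t] * h + B_bar[t] * x[t]  # input-dependent state update
        y[t] = C[t] @ h                     # input-dependent readout
    return y

rng = np.random.default_rng(0)
L, N = 6, 4
y = ssm_scan(rng.uniform(0.5, 0.9, (L, N)), rng.normal(size=(L, N)),
             rng.normal(size=(L, N)), rng.normal(size=L))
print(y.shape)  # (6,)
```

Because $\bar{A}_t$, $\bar{B}_t$, and $C_t$ vary per token, the recurrence implements input-dependent mixing without ever forming an $L \times L$ interaction matrix.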
The attention bridge consists of two independent 2-layer MLPs with SiLU nonlinearities:
- $\varphi_K : \mathbb{R}^{d_S} \to \mathbb{R}^{d_T}$, projecting student $B^{(s)}$ into the teacher key space,
- $\varphi_Q : \mathbb{R}^{d_S} \to \mathbb{R}^{d_T}$, projecting student $C^{(s)}$ into the teacher query space.
The data flow comprises: (1) a forward pass through the teacher to obtain $Q_T^{(\ell)}, K_T^{(\ell)}$; (2) a forward pass through the student to obtain $B^{(s)}, C^{(s)}$; (3) mapping the student projections into teacher space via the bridge and minimizing their discrepancy.
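A minimal NumPy sketch of the bridge itself, with hand-rolled 2-layer SiLU MLPs; the hidden width and initialization here are illustrative assumptions, not the paper's settings:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

class BridgeMLP:
    """2-layer MLP with SiLU, mapping student dim d_s -> teacher dim d_t."""
    def __init__(self, d_s, d_t, hidden, rng):
        self.W1 = rng.normal(0, 0.02, (d_s, hidden))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(0, 0.02, (hidden, d_t))
        self.b2 = np.zeros(d_t)

    def __call__(self, x):  # x: (L, d_s)
        return silu(x @ self.W1 + self.b1) @ self.W2 + self.b2

rng = np.random.default_rng(0)
L, d_s, d_t = 8, 384, 192
phi_K = BridgeMLP(d_s, d_t, 256, rng)  # student B -> teacher key space
phi_Q = BridgeMLP(d_s, d_t, 256, rng)  # student C -> teacher query space

B_s = rng.normal(size=(L, d_s))  # student "key-like" projection
C_s = rng.normal(size=(L, d_s))  # student "query-like" projection
K_hat, Q_hat = phi_K(B_s), phi_Q(C_s)
print(K_hat.shape, Q_hat.shape)  # (8, 192) (8, 192)
```

The bridge's parameter count scales with the hidden dimensions only, independent of sequence length.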
2. Loss Formulation and Distributional Alignment
CAB eschews explicit attention matrix matching to avoid quadratic costs, instead aligning keys and queries using an $\ell_2$ loss. For each student layer $s$ and mapped teacher layer $m(s)$, the loss is:

$$\mathcal{L}_{\text{align}}^{(s)} = \frac{1}{L} \left( \left\| \varphi_K\big(B^{(s)}\big) - K_T^{(m(s))} \right\|_2^2 + \left\| \varphi_Q\big(C^{(s)}\big) - Q_T^{(m(s))} \right\|_2^2 \right),$$

where $m(\cdot)$ is a layer-mapping function allowing for cross-depth alignment.
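A sketch of the per-layer alignment term, assuming the bridged projections have already been computed (names are illustrative):

```python
import numpy as np

def align_loss(K_hat, Q_hat, K_T, Q_T):
    """Token-averaged squared-l2 discrepancy between bridged student
    projections and teacher keys/queries; all inputs are (L, d_t)."""
    L = K_hat.shape[0]
    return (np.sum((K_hat - K_T) ** 2) + np.sum((Q_hat - Q_T) ** 2)) / L

rng = np.random.default_rng(0)
L, d_t = 8, 192
K_T, Q_T = rng.normal(size=(L, d_t)), rng.normal(size=(L, d_t))
print(align_loss(K_T, Q_T, K_T, Q_T))           # 0.0 when perfectly aligned
print(align_loss(K_T + 0.1, Q_T, K_T, Q_T) > 0) # True
```

The cost is linear in $L$: no pairwise token-token matrix is ever formed.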
A conceptual variant considers aligning full attention distributions with a KL-divergence objective:

$$\mathcal{L}_{\text{KL}}^{(s)} = \mathrm{KL}\!\left( \mathrm{softmax}\!\left( \frac{Q_T^{(m(s))} K_T^{(m(s))\top}}{\sqrt{d_T}} \right) \,\Big\|\, \mathrm{softmax}\!\left( \frac{\varphi_Q\big(C^{(s)}\big)\, \varphi_K\big(B^{(s)}\big)^{\top}}{\sqrt{d_T}} \right) \right),$$

applied row-wise over the resulting $L \times L$ matrices. However, CAB does not use this variant in practice due to the associated $O(L^2)$ memory cost.
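For contrast, a sketch of this unused KL variant makes the quadratic cost concrete: it must materialize two $L \times L$ attention matrices per layer (shapes and names illustrative):

```python
import numpy as np

def row_softmax(S):
    """Numerically stable softmax over the last axis."""
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def attn_kl(Q_T, K_T, Q_hat, K_hat):
    """KL(teacher attention || bridged-student attention), averaged over
    rows. Builds two (L, L) matrices -- the quadratic cost CAB avoids."""
    d = Q_T.shape[1]
    P = row_softmax(Q_T @ K_T.T / np.sqrt(d))      # teacher attention
    Q = row_softmax(Q_hat @ K_hat.T / np.sqrt(d))  # student-side attention
    return np.mean(np.sum(P * (np.log(P) - np.log(Q)), axis=-1))

rng = np.random.default_rng(0)
L, d = 16, 32
Qt, Kt = rng.normal(size=(L, d)), rng.normal(size=(L, d))
print(round(attn_kl(Qt, Kt, Qt, Kt), 6))  # 0.0 for identical distributions
```

The key/query $\ell_2$ loss above supervises the same quantities (the inputs to the softmax) while sidestepping the $(L, L)$ intermediates entirely.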
3. Layer-wise Alignment Strategies
Cross-depth heterogeneity between Transformer and SSM (i.e., $N_S \neq N_T$ for student and teacher layer counts) is managed through a flexible proportional mapping:

$$m(s) = \mathrm{round}\!\left( s \cdot \frac{N_T}{N_S} \right).$$

This effectively "stretches" or "compresses" the mapping so that each student layer is aligned with an appropriate teacher layer. Alternate schemes include:
- 1:1 alignment when $N_S = N_T$ (i.e., $m(s) = s$),
- skip alignment (align only every $k$-th layer),
- a weighted combination over several teacher layers, with learnable or preset weights $w_j$ such that $\sum_j w_j = 1$.
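The proportional scheme can be sketched as follows; the `round`-based form is one natural reading of the stretch/compress rule described above:

```python
def proportional_map(s, n_student, n_teacher):
    """Map student layer s in {1..n_student} to a teacher layer in
    {1..n_teacher}, stretching or compressing depth proportionally."""
    return max(1, round(s * n_teacher / n_student))

# 24-layer student distilled from a 12-layer teacher: pairs of adjacent
# student layers share a teacher layer.
pairs = [(s, proportional_map(s, 24, 12)) for s in (1, 2, 23, 24)]
print(pairs)  # [(1, 1), (2, 1), (23, 12), (24, 12)]
```

This matches the DeiT (12-layer) to Vision Mamba (24-layer) configuration reported in the experiments, where every teacher layer supervises two student layers.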
4. Training Procedure and Optimization
CAB optimizes a joint loss:
CAB optimizes a joint loss:

$$\mathcal{L} = \alpha\, \mathcal{L}_{\text{task}} + \beta\, \mathcal{L}_{\text{align}},$$

with default weights $\alpha = \beta = 1.0$ and $\beta$ tunable for low-data regimes. The task loss is cross-entropy (classification) or a soft KL-divergence (language modeling) between teacher and student outputs. In vision, a bidirectional variant is used: both forward and backward student projections are aligned per layer. In language modeling, CAB adopts a two-stage recipe: Stage 1 aligns on 200M tokens with $\mathcal{L}_{\text{align}}$ only; Stage 2 freezes the bridge MLPs and optimizes the soft KL-divergence on logits.
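The joint objective and two-stage language-modeling recipe reduce to a simple schematic; in this pure-Python skeleton, `task_loss` and `align_loss` stand in for the real model computations:

```python
def joint_loss(task_loss, align_loss, alpha=1.0, beta=1.0):
    """Stage-agnostic objective: L = alpha * L_task + beta * L_align."""
    return alpha * task_loss + beta * align_loss

def stage_losses(stage, task_loss, align_loss):
    """Two-stage LM recipe: stage 1 trains the bridge on L_align only;
    stage 2 freezes the bridge and optimizes the soft-KL task loss."""
    if stage == 1:
        return joint_loss(0.0, align_loss)  # alignment only
    return joint_loss(task_loss, 0.0)       # frozen bridge, logits KL

print(stage_losses(1, task_loss=2.5, align_loss=0.8))  # 0.8
print(stage_losses(2, task_loss=2.5, align_loss=0.8))  # 2.5
```

Freezing the bridge in Stage 2 means the alignment MLPs act purely as a warm-start for the student's representations; only the task head continues to receive gradients from the teacher's logits.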
The following table summarizes typical hyperparameters and alignment mapping:
| Regime | Layer Mapping | $\alpha$ | $\beta$ |
|---|---|---|---|
| All (robust) | proportional | 1.0 | 1.0 |
5. Empirical Evaluation
CAB demonstrates effectiveness across vision (ImageNet-1k) and language (OpenWebText) distillation tasks. In vision, experiments span 1%, 5%, 10%, and 20% of the dataset, with DeiT-Tiny and DeiT-Small teachers (12 layers, 192/384 dim) and Vision Mamba student variants (24 layers, matched dimensions).
Reported top-1 ImageNet-1k accuracy at 10% data:
| Method | Accuracy (%) |
|---|---|
| Vanilla | 32.9 |
| Soft KD | 42.0 |
| MOHAWK | 45.1 |
| CAB (Ours) | 49.2 |
In language modeling, CAB is evaluated on OpenWebText, C4, and WikiText (perplexity metric, 4B tokens, DistilGPT2 teacher to Phi-Mamba-123M student):
| Method | OpenWebText | C4 | WikiText |
|---|---|---|---|
| Attention-Weight-Reuse | 37.2 | 62.9 | 99.8 |
| MOHAWK | 31.1 | 51.2 | 77.9 |
| CAB (Ours) | 30.1 | 50.1 | 74.7 |
CAB delivers a 4.1 percentage point improvement over MOHAWK in low-data ImageNet (at 10% data) and the lowest perplexity on all three language benchmarks.
Efficiency is a central objective: by never materializing $L \times L$ attention matrices, CAB avoids $O(L^2)$ memory use, operating at substantially lower memory and higher speed than MOHAWK over the 200M-token alignment stage.
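The asymptotic saving is easy to quantify: per layer and head, attention-matrix matching stores $L^2$ scores, while key/query alignment stores $2 L d_T$ values, a ratio of $L / (2 d_T)$. A back-of-envelope comparison at illustrative sizes (fp32, $d_T = 192$ as in DeiT-Tiny):

```python
L, d_t = 2048, 192            # sequence length, teacher hidden dim (illustrative)
bytes_fp32 = 4
attn_matrix = L * L * bytes_fp32    # one (L, L) attention matrix
kq_pair = 2 * L * d_t * bytes_fp32  # one (K, Q) pair in teacher space
print(attn_matrix // kq_pair)       # prints 5
```

The gap widens linearly with sequence length, which is why the saving matters most for long-context language modeling.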
6. Computational and Methodological Implications
CAB’s design—particularly the avoidance of explicit attention matrix alignment—enables efficient cross-modal distillation under resource constraints. Bridging via MLPs, rather than matching attention weights or distributions, sidesteps the quadratic penalty typically associated with attention mechanisms, enhancing practical scalability. Flexible layer-wise alignment accommodates varied architectural depth between teacher and student. The two-stage training in language modeling demonstrates adaptability to large-scale, incremental tasks.
A plausible implication is that similar bridge-based approaches may generalize to other forms of cross-architecture knowledge transfer, especially where representational heterogeneity or scale induces architectural incompatibilities.
7. Significance and Outlook
Distributional Attention Distillation as realized in CAB substantiates the claim that fine-grained attention-based inductive biases from Transformers are transferable into recurrent state-space models. This is evidenced by strong empirical results in low-data regimes and both vision and language settings, with sustained efficiency gains in memory and computational cost. The method’s generic token-level supervision, lightweight bridge design, and adaptable depth alignment collectively address longstanding challenges in cross-architecture distillation. These outcomes suggest accelerated adoption and improved ecosystem maturity for SSMs, leveraging Transformer expertise for rapid model improvement in emerging architectures (Wang et al., 22 Oct 2025).