
Distributional Attention Distillation (CAB)

Updated 27 January 2026
  • Distributional Attention Distillation is a framework that transfers fine-grained attention biases from Transformer models to state-space models using an attention bridge.
  • CAB employs lightweight MLPs to project and align teacher and student representations, avoiding the quadratic cost of full attention matrix matching.
  • Empirical results demonstrate that CAB improves accuracy in low-data regimes and reduces perplexity in language tasks, while delivering significant computational and memory efficiency.

Distributional Attention Distillation is a methodological framework for transferring attention-based inductive biases from pretrained Transformer models to state-space models (SSMs), specifically under structural heterogeneity between teacher and student architectures. The approach addresses the challenge of distilling fine-grained sequence processing knowledge—encoded via self-attention in Transformers—into recurrent SSMs such as Mamba, which possess fundamentally different token interaction mechanisms. Distributional Attention Distillation as instantiated in Cross-architecture distillation via Attention Bridge (CAB) enables token-level supervision by aligning representation distributions at multiple layers, efficiently bridging the gap between these model types while maintaining computational and memory efficiency (Wang et al., 22 Oct 2025).

1. Architectural Components and Data Flow

CAB employs a three-part structure comprising a Transformer teacher, a Mamba student, and a lightweight “attention bridge” formed by small multi-layer perceptrons (MLPs). At each Transformer teacher layer $L^{\text{T}}_l$ ($l = 1 \ldots T$), the model extracts query and key representations $Q^{\text{T}}(l) \in \mathbb{R}^{L \times d_t}$ and $K^{\text{T}}(l) \in \mathbb{R}^{L \times d_t}$, where $L$ is the sequence length and $d_t$ is the teacher’s hidden dimension. Optionally, full attention matrices $A^{\text{T}}_{l,h}$ per head $h$ can be extracted, but this incurs $O(L^2)$ computational and memory cost, which CAB avoids.

The student SSM—exemplified by Mamba—produces token-wise projections $B^{(l)}, C^{(l)} \in \mathbb{R}^{L \times d_s}$ at each student layer $l$. Here, $d_s$ is the student’s hidden dimension; $B$ and $C$ drive the implicit recurrent dynamics:

$$h_t = \bar{A}_t h_{t-1} + B_t x_t, \qquad y_t = C_t h_t$$
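As a worked illustration, the recurrence can be unrolled directly. The following is a minimal NumPy sketch with a diagonal, per-token $\bar{A}_t$; real Mamba implementations use a hardware-efficient parallel scan instead:

```python
import numpy as np

def ssm_scan(A_bar, B, C, x):
    """Unrolled SSM recurrence: h_t = A_bar_t * h_{t-1} + B_t x_t, y_t = C_t . h_t.

    A_bar, B, C: (L, d_state) per-token parameters; x: (L,) scalar input channel.
    Illustrative only -- real Mamba uses an optimized parallel scan.
    """
    L, d_state = B.shape
    h = np.zeros(d_state)
    y = np.empty(L)
    for t in range(L):
        h = A_bar[t] * h + B[t] * x[t]   # elementwise (diagonal) state update
        y[t] = C[t] @ h                  # readout
    return y
```

With $\bar{A}_t = 0$ the state carries no history and each output reduces to $C_t^\top B_t x_t$; with $\bar{A}_t = 1$ the state accumulates all past inputs.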

The attention bridge consists of two independent 2-layer MLPs with SiLU nonlinearities:

  • $\phi_B: \mathbb{R}^{d_s} \rightarrow \mathbb{R}^{d_t}$, projecting student $B$ into the teacher key space,
  • $\phi_C: \mathbb{R}^{d_s} \rightarrow \mathbb{R}^{d_t}$, projecting student $C$ into the teacher query space.
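A bridge MLP of this shape can be sketched as follows (the hidden width and initialization are illustrative assumptions; the text specifies only 2-layer MLPs with SiLU):

```python
import numpy as np

def silu(z):
    """SiLU / swish nonlinearity: z * sigmoid(z)."""
    return z / (1.0 + np.exp(-z))

class BridgeMLP:
    """2-layer MLP phi: R^{d_s} -> R^{d_t}, applied token-wise.

    Hidden width defaults to d_t -- an assumption for illustration,
    not the paper's stated choice.
    """
    def __init__(self, d_s, d_t, d_hidden=None, seed=0):
        rng = np.random.default_rng(seed)
        d_hidden = d_hidden or d_t
        self.W1 = rng.normal(0.0, d_s ** -0.5, (d_s, d_hidden))
        self.W2 = rng.normal(0.0, d_hidden ** -0.5, (d_hidden, d_t))

    def __call__(self, X):
        # X: (L, d_s) student projections -> (L, d_t) in teacher space
        return silu(X @ self.W1) @ self.W2
```

For example, `phi_B = BridgeMLP(d_s=64, d_t=192)` maps a `(L, 64)` student projection to the teacher's 192-dimensional key space.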

The data flow comprises: (1) a forward pass through the teacher for $\{K^{\text{T}}(g(l)), Q^{\text{T}}(g(l))\}$, (2) a forward pass through the student for $\{B^{(l)}, C^{(l)}\}$, (3) mapping student projections to teacher space via the bridge and minimizing their discrepancy.

2. Loss Formulation and Distributional Alignment

CAB eschews explicit attention-matrix matching to avoid quadratic costs, instead aligning keys and queries using an $\ell_2$ loss. For each student layer $l = 1 \dots S$ and mapped teacher layer $g(l)$, the loss is:

$$\mathcal{L}_\text{attn} = \frac{1}{S} \sum_{l=1}^S \Big( \|\phi_B(B^{(l)}) - K^{\text{T}}(g(l))\|_2^2 + \|\phi_C(C^{(l)}) - Q^{\text{T}}(g(l))\|_2^2 \Big)$$

where $g(l)$ is a layer-mapping function allowing for cross-depth alignment.
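Putting the pieces together, the alignment loss can be sketched as follows (squared norms summed over tokens and dimensions, averaged over layers, per the formula above; `phi_B`, `phi_C`, and `g` are passed in as callables):

```python
import numpy as np

def attn_bridge_loss(B_list, C_list, K_t, Q_t, phi_B, phi_C, g):
    """L_attn = (1/S) * sum_l ( ||phi_B(B_l) - K_T(g(l))||^2
                              + ||phi_C(C_l) - Q_T(g(l))||^2 ).

    B_list, C_list: per-student-layer projections, each (L, d_s).
    K_t, Q_t: per-teacher-layer keys/queries, each (L, d_t).
    g: layer-mapping function from student index to teacher index.
    """
    S = len(B_list)
    total = 0.0
    for l in range(S):
        total += np.sum((phi_B(B_list[l]) - K_t[g(l)]) ** 2)
        total += np.sum((phi_C(C_list[l]) - Q_t[g(l)]) ** 2)
    return total / S
```

The loss vanishes exactly when the bridged student projections coincide with the teacher's keys and queries at every aligned layer.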

A conceptual variant considers aligning full attention distributions with a KL-divergence objective:

$$\mathcal{L}_\text{att}^{KL} = \sum_{l=1}^T \sum_{h=1}^H \sum_{t=1}^L \mathrm{KL}\!\left( A^{\text{T}}_{l,h,t} \,\Big\|\, B_{l\to l'}\big(\hat{A}^{\text{S}}_{l',h,t}\big) \right)$$

However, CAB does not use this variant in practice due to the associated memory cost.
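The cost gap is simple arithmetic: matching full attention maps requires storing $T \cdot H \cdot L^2$ scores, whereas matching keys and queries requires only $2 \cdot T \cdot L \cdot d_t$ values. A back-of-envelope comparison (fp32; the layer counts and dimensions below are illustrative, not the paper's configuration):

```python
def attn_matrix_bytes(layers, heads, seq_len, bytes_per=4):
    """Memory to store full attention maps A_{l,h} for one sequence."""
    return layers * heads * seq_len * seq_len * bytes_per

def kq_bytes(layers, seq_len, d_t, bytes_per=4):
    """Memory to store K and Q (two (L, d_t) matrices) per layer."""
    return layers * 2 * seq_len * d_t * bytes_per

# Illustrative sizes: 12 layers, 12 heads, L = 2048, d_t = 768
full = attn_matrix_bytes(12, 12, 2048)   # ~2.4 GB per sequence
kq = kq_bytes(12, 2048, 768)             # ~151 MB per sequence
```

Even at this modest sequence length the attention maps are more than an order of magnitude larger, and the gap widens quadratically as $L$ grows.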

3. Layer-wise Alignment Strategies

Cross-depth heterogeneity between Transformer and SSM—i.e., $S \neq T$ for $S$ student and $T$ teacher layers—is managed through a flexible mapping:

$$g(l) = \left\lfloor \frac{l}{S} \cdot T \right\rfloor$$

This effectively “stretches” or “compresses” the mapping so that each student layer is aligned with an appropriate teacher layer. Alternate schemes include:

  • 1:1 alignment when $S = T$ ($g(l) = l$),
  • skip alignment (align only every $k$-th layer),
  • weighted combination with learnable or preset weights $w_l$ such that $\sum_l w_l = 1$ (CAB uses $w_l = 1/S$).
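The default floor mapping can be sketched directly (1-indexed layers; clamping the shallowest student layers to teacher layer 1 is an assumption made here to keep indices valid):

```python
def g_stretch(l, S, T):
    """Default CAB mapping g(l) = floor((l/S) * T) for 1-indexed layers.

    Stretches (S < T) or compresses (S > T) student depth onto teacher depth.
    The clamp to >= 1 is an assumption for the shallowest layers.
    """
    return max(1, (l * T) // S)

def g_identity(l, S, T):
    """1:1 alignment, valid only when S == T."""
    assert S == T
    return l
```

For a 24-layer student distilled from a 12-layer teacher, consecutive pairs of student layers share one teacher target: `g_stretch(2, 24, 12) == 1`, `g_stretch(24, 24, 12) == 12`.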

4. Training Procedure and Optimization

CAB optimizes a joint loss:

$$\mathcal{L}_\text{total} = \alpha \cdot \mathcal{L}_\text{task} + \beta \cdot \mathcal{L}_\text{attn}$$

with default weights $\alpha = 1.0$ and $\beta = 1.0$, and $\beta$ tunable in the range $[0.1, 1.0]$ for low-data regimes. The task loss is cross-entropy (classification) or a soft KL-divergence (language modeling) between teacher and student outputs. In vision, a bidirectional variant is used: both forward and backward student projections are aligned per layer. In language modeling, CAB adopts a two-stage recipe: Stage 1 aligns $\phi_B, \phi_C$ on 200M tokens with $\mathcal{L}_\text{attn}$ only; Stage 2 freezes these MLPs and optimizes the soft KL on logits.
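A sketch of the joint objective and a soft-KL task loss (the temperature parameter `tau` is an assumption; the text specifies a soft KL without further detail):

```python
import numpy as np

def soft_kl(teacher_logits, student_logits, tau=1.0):
    """Mean KL(softmax(teacher/tau) || softmax(student/tau)) over tokens.

    tau is an illustrative temperature, not a value given in the source.
    """
    log_p = teacher_logits / tau
    log_p -= np.log(np.exp(log_p).sum(axis=-1, keepdims=True))
    log_q = student_logits / tau
    log_q -= np.log(np.exp(log_q).sum(axis=-1, keepdims=True))
    p = np.exp(log_p)
    return float(np.mean(np.sum(p * (log_p - log_q), axis=-1)))

def total_loss(task_loss, attn_loss, alpha=1.0, beta=1.0):
    """L_total = alpha * L_task + beta * L_attn (defaults alpha = beta = 1.0)."""
    return alpha * task_loss + beta * attn_loss
```

Lowering `beta` toward 0.1, as suggested for low-data regimes, shifts weight from the alignment term back to the task term.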

The following table summarizes typical hyperparameters and alignment mapping:

| Regime | $\alpha$ | $\beta$ | Layer Mapping $g(l)$ |
|---|---|---|---|
| All (robust) | 1.0 | 1.0 | $\left\lfloor \frac{l}{S} \cdot T \right\rfloor$ |

5. Empirical Evaluation

CAB demonstrates effectiveness across vision (ImageNet-1k) and language (OpenWebText) distillation tasks. In vision, experiments span 1%, 5%, 10%, and 20% of the dataset, with DeiT-Tiny and DeiT-Small teachers (12 layers, 192/384-dimensional) and Vision Mamba student variants (24 layers, matched dimensions).

Reported top-1 ImageNet-1k accuracy at 10% data:

| Method | Accuracy (%) |
|---|---|
| Vanilla | 32.9 |
| Soft KD | 42.0 |
| MOHAWK | 45.1 |
| CAB (Ours) | 49.2 |

In language modeling, CAB is evaluated on OpenWebText, C4, and WikiText (perplexity metric, 4B tokens, DistilGPT2 teacher to Phi-Mamba-123M student):

| Method | OpenWebText | C4 | WikiText |
|---|---|---|---|
| Attention-Weight-Reuse | 37.2 | 62.9 | 99.8 |
| MOHAWK | 31.1 | 51.2 | 77.9 |
| CAB (Ours) | 30.1 | 50.1 | 74.7 |

CAB delivers a $\sim 7.2$ percentage-point improvement over MOHAWK in low-data ImageNet and the lowest perplexity on the benchmark language datasets.

Efficiency is a central objective: CAB avoids $O(L^2)$ memory use, operating at $\sim 1/10$ the memory and $4\times$ the speed of MOHAWK for 200M tokens.

6. Computational and Methodological Implications

CAB’s design—particularly the avoidance of explicit attention matrix alignment—enables efficient cross-modal distillation under resource constraints. Bridging via MLPs, rather than matching attention weights or distributions, sidesteps the quadratic penalty typically associated with attention mechanisms, enhancing practical scalability. Flexible layer-wise alignment accommodates varied architectural depth between teacher and student. The two-stage training in language modeling demonstrates adaptability to large-scale, incremental tasks.

A plausible implication is that similar bridge-based approaches may generalize to other forms of cross-architecture knowledge transfer, especially where representational heterogeneity or scale induces architectural incompatibilities.

7. Significance and Outlook

Distributional Attention Distillation as realized in CAB substantiates the claim that fine-grained attention-based inductive biases from Transformers are transferable into recurrent state-space models. This is evidenced by strong empirical results in low-data regimes and both vision and language settings, with sustained efficiency gains in memory and computational cost. The method’s generic token-level supervision, lightweight bridge design, and adaptable depth alignment collectively address longstanding challenges in cross-architecture distillation. These outcomes suggest accelerated adoption and improved ecosystem maturity for SSMs, leveraging Transformer expertise for rapid model improvement in emerging architectures (Wang et al., 22 Oct 2025).
