Attention-weighted CKA for Knowledge Distillation
- The paper introduces ACCKA, extending classic CKA by incorporating token-level attention weights to focus on salient audio regions during knowledge distillation.
- ACCKA eliminates the need for explicit projection layers by naturally handling mismatched embedding dimensions between teacher and student models.
- Empirical results within the PL-Distill framework demonstrate that ACCKA enables efficient compression while achieving or exceeding teacher performance in speech emotion recognition.
Attention-weighted Centered Kernel Alignment (ACCKA) is a kernel similarity measure designed to enhance alignment between representations of teacher and student models, particularly in the context of knowledge distillation for large audio-language models (LALMs). ACCKA extends classic Centered Kernel Alignment (CKA) by incorporating attention-based weighting at the level of individual time steps (audio tokens), thereby emphasizing regions deemed important by the teacher model's attention mechanism. This framework both highlights salient local structure and naturally accommodates mismatched feature spaces between teacher and student, obviating the need for explicit projection layers. ACCKA is a cornerstone of the PL-Distill framework for knowledge distillation in speech emotion recognition (SER), enabling efficient model compression while retaining or even exceeding teacher-level performance (Yang et al., 2 Feb 2026).
1. Foundation: Centered Kernel Alignment (CKA)
Centered Kernel Alignment is a normalized similarity measure between two sets of features $X \in \mathbb{R}^{n \times d_1}$ and $Y \in \mathbb{R}^{n \times d_2}$. The linear-kernel Gram matrices, $K = XX^\top$ and $L = YY^\top$, are centered using the centering matrix
$$H = I_n - \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top,$$
resulting in $K_c = HKH$, and similarly for $L_c = HLH$. The linear CKA is defined as
$$\mathrm{CKA}(X, Y) = \frac{\langle K_c, L_c \rangle_F}{\|K_c\|_F\,\|L_c\|_F},$$
where $\langle \cdot, \cdot \rangle_F$ denotes the Frobenius inner product. CKA is closely related to the Hilbert-Schmidt independence criterion (HSIC) and measures the similarity of covariance structure, remaining invariant to orthogonal transformations and isotropic scaling of $X$ or $Y$. Notably, CKA accommodates feature spaces of differing dimensionality ($d_1 \neq d_2$) and scales to high dimensions (Cortes et al., 2012, Yang et al., 2 Feb 2026).
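For concreteness, the following is a minimal PyTorch sketch of the linear CKA above; the function name linear_cka and the small eps guard against a zero denominator are illustrative choices rather than details from the cited works.

```python
import torch

def linear_cka(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Linear CKA between features X (n x d1) and Y (n x d2); d1 need not equal d2."""
    n = X.shape[0]
    # Linear-kernel Gram matrices (n x n).
    K = X @ X.T
    L = Y @ Y.T
    # Centering matrix H = I_n - (1/n) 1 1^T, applied on both sides.
    H = torch.eye(n, dtype=X.dtype, device=X.device) - torch.ones(n, n, dtype=X.dtype, device=X.device) / n
    Kc = H @ K @ H
    Lc = H @ L @ H
    # Frobenius inner product divided by the product of Frobenius norms.
    return (Kc * Lc).sum() / (Kc.norm() * Lc.norm() + eps)

# Example: mismatched feature widths are handled without any projection.
score = linear_cka(torch.randn(64, 768), torch.randn(64, 256))
```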
2. ACCKA: Attention-weighted Extension
Attention-weighted Centered Kernel Alignment generalizes CKA by injecting importance weights reflecting token-level attention from the teacher model. For audio inputs, let the teacher's last-layer self-attention scores from the final 'response' token to the $T$ audio tokens be $a_1, \dots, a_T$. Normalize these to a probability vector,
$$w_t = \frac{a_t}{\sum_{j=1}^{T} a_j}, \qquad t = 1, \dots, T,$$
yielding weights $w = (w_1, \dots, w_T)$ with $w_t \geq 0$ and $\sum_t w_t = 1$.
Each embedding row (time step) $t$ of the teacher representation $X \in \mathbb{R}^{T \times d_T}$ and the student representation $Y \in \mathbb{R}^{T \times d_S}$ is scaled by $\sqrt{w_t}$:
$$\tilde{X}_t = \sqrt{w_t}\, X_t, \qquad \tilde{Y}_t = \sqrt{w_t}\, Y_t.$$
The weighted embeddings are then centered by subtracting their columnwise means, giving $\tilde{X}_c$ and $\tilde{Y}_c$. The attention-weighted CKA ("ACCKA") is
$$\mathrm{ACCKA}(X, Y; w) = \frac{\|\tilde{X}_c^\top \tilde{Y}_c\|_F^2}{\|\tilde{X}_c^\top \tilde{X}_c\|_F\,\|\tilde{Y}_c^\top \tilde{Y}_c\|_F},$$
where the centering operator is applied after weighting. ACCKA directs the alignment measure toward acoustically or semantically salient regions, as defined by the teacher's attention, improving the focus of knowledge transfer (Yang et al., 2 Feb 2026).
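A minimal sketch of this computation in PyTorch follows, assuming rows are scaled by $\sqrt{w_t}$ so that the weighted Gram matrices carry the weights $w$; the function name accka and the eps guard are illustrative, and the paper's exact parameterization may differ.

```python
import torch

def accka(X: torch.Tensor, Y: torch.Tensor, attn: torch.Tensor,
          eps: float = 1e-8) -> torch.Tensor:
    """Attention-weighted CKA between teacher embeddings X (T x d_T) and
    student embeddings Y (T x d_S), given raw attention scores attn (T,)."""
    # Normalize attention scores to a probability vector w.
    w = attn / (attn.sum() + eps)
    sw = w.clamp_min(0).sqrt().unsqueeze(1)          # (T, 1), sqrt(w_t) per row
    # Weight each time step, then center columnwise (centering after weighting).
    Xc = sw * X
    Yc = sw * Y
    Xc = Xc - Xc.mean(dim=0, keepdim=True)
    Yc = Yc - Yc.mean(dim=0, keepdim=True)
    # Feature-space form of linear CKA: only d_T x d_S cross-products appear,
    # so mismatched teacher/student widths need no projection.
    num = (Xc.T @ Yc).norm() ** 2
    den = (Xc.T @ Xc).norm() * (Yc.T @ Yc).norm() + eps
    return num / den
```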
3. Objective Function and Optimization
The distillation loss at the projector level is defined as the negative of the ACCKA similarity:
$$\mathcal{L}_{\mathrm{proj}} = -\,\mathrm{ACCKA}(X, Y; w).$$
The goal is to minimize this loss, thereby maximizing correspondence between the statistical geometry of teacher and student embeddings at attention-critical time steps. Unlike adversarial or regression-based distillation losses, ACCKA requires no additional regularization, as normalization ensures the score remains bounded.
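A hedged usage sketch follows, reusing the accka function above; the tensor shapes, random inputs, and variable names are placeholders rather than values from PL-Distill.

```python
import torch

# Illustrative shapes only: T audio time steps, mismatched teacher/student widths.
T_steps, d_teacher, d_student = 200, 4096, 1024
teacher_emb = torch.randn(T_steps, d_teacher)                       # frozen teacher output
student_emb = torch.randn(T_steps, d_student, requires_grad=True)   # student projector output
attn_scores = torch.rand(T_steps)                                   # response-token attention

# Projector-level loss: negative ACCKA; minimizing it maximizes alignment.
loss_proj = -accka(teacher_emb, student_emb, attn_scores)
loss_proj.backward()   # gradients flow only into the student-side tensor
```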
4. Handling Mismatched Embedding Dimensions
A fundamental property of both CKA and ACCKA is that embedding dimensionalities for teacher ($d_T$) and student ($d_S$) need not match. The formulation only requires products of the form $\tilde{X}_c^\top \tilde{Y}_c \in \mathbb{R}^{d_T \times d_S}$ (together with the within-model products $\tilde{X}_c^\top \tilde{X}_c$ and $\tilde{Y}_c^\top \tilde{Y}_c$), avoiding any explicit projection between feature spaces. Thus, the projector-level MLPs for teacher and student are free to evolve independently. ACCKA aligns the empirical covariance structures of these spaces, facilitating knowledge transfer even when teacher and student operate with different representational capacities (Yang et al., 2 Feb 2026).
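To make this concrete, the only cross-model term is a $d_T \times d_S$ matrix product, which is well defined for any pair of widths; the shapes in the toy check below are arbitrary.

```python
import torch

Xc = torch.randn(300, 4096)   # weighted, centered teacher embeddings (toy shapes)
Yc = torch.randn(300, 1024)   # weighted, centered student embeddings
cross = Xc.T @ Yc             # (4096, 1024): no learned projection required
print(cross.shape)            # torch.Size([4096, 1024])
```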
5. Computational Implementation
The main stages of ACCKA computation are as follows:
- Normalization of attention: compute $w_t = a_t / \sum_{j=1}^{T} a_j$ from the teacher's response-token attention scores.
- Application of weights: multiply $\sqrt{w_t}$ with row $t$ of the corresponding teacher and student embeddings.
- Centering: Subtract per-column means from weighted embeddings.
- Covariance computation: Form cross-covariance matrices $\tilde{X}_c^\top \tilde{Y}_c$, $\tilde{X}_c^\top \tilde{X}_c$, $\tilde{Y}_c^\top \tilde{Y}_c$.
- Frobenius norms: Compute $\|\tilde{X}_c^\top \tilde{X}_c\|_F$ and $\|\tilde{Y}_c^\top \tilde{Y}_c\|_F$.
- Final ACCKA score and loss: $\mathrm{ACCKA} = \|\tilde{X}_c^\top \tilde{Y}_c\|_F^2 \,/\, \big(\|\tilde{X}_c^\top \tilde{X}_c\|_F\,\|\tilde{Y}_c^\top \tilde{Y}_c\|_F\big)$, $\mathcal{L}_{\mathrm{proj}} = -\mathrm{ACCKA}$.
The entire process is batchable, numerically stable in standard floating-point precision, and robust to division by zero through $\epsilon$-stabilization of the denominator. The computational complexity per sample is $O\!\big(T \cdot \max(d_T, d_S)^2\big)$, scaling linearly with sequence length and quadratically with embedding dimension (dominated by the larger of $d_T$ and $d_S$) (Yang et al., 2 Feb 2026).
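A batched variant of the stages above can be sketched as follows; the (B, T, d) tensor layout, the clamp on the weights, and the placement of the eps terms are assumptions of this sketch rather than details from the paper.

```python
import torch

def accka_batched(X: torch.Tensor, Y: torch.Tensor, attn: torch.Tensor,
                  eps: float = 1e-8) -> torch.Tensor:
    """Batched ACCKA: X (B, T, d_T), Y (B, T, d_S), attn (B, T) raw scores.
    Returns a (B,) tensor of per-sample scores."""
    w = attn / (attn.sum(dim=1, keepdim=True) + eps)       # normalize per sample
    sw = w.clamp_min(0).sqrt().unsqueeze(-1)               # (B, T, 1)
    Xc = sw * X
    Yc = sw * Y
    Xc = Xc - Xc.mean(dim=1, keepdim=True)                 # center after weighting
    Yc = Yc - Yc.mean(dim=1, keepdim=True)
    cross = torch.einsum("btd,bte->bde", Xc, Yc)           # (B, d_T, d_S)
    xx = torch.einsum("btd,bte->bde", Xc, Xc)              # (B, d_T, d_T)
    yy = torch.einsum("btd,bte->bde", Yc, Yc)              # (B, d_S, d_S)
    num = cross.flatten(1).norm(dim=1) ** 2                # squared Frobenius norm
    den = xx.flatten(1).norm(dim=1) * yy.flatten(1).norm(dim=1) + eps
    return num / den
```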
6. Statistical and Learning-Theoretic Properties
Classic centered alignment, as formalized by Cortes, Mohri, and Rostamizadeh (Cortes et al., 2012), admits concentration bounds, kernel learning guarantees via convex quadratic programming, and stability-based generalization theorems. The extension to attention-weighted alignment introduces new statistical considerations: concentration now depends on the maximal weight $\max_t w_t$, and stability must account for bi-level fitting if $w$ is optimized on the same data. Proper regularization of the attention vector is necessary to avoid overfitting, though in ACCKA $w$ is fixed by the teacher's attention and thus not subject to direct optimization. Theoretical tools such as algorithmic stability and Rademacher complexity can be adapted to accommodate weighted kernels, provided constraints on $w$ are observed (Cortes et al., 2012).
7. Applications and Significance
ACCKA is deployed within the PL-Distill framework to enable projector-level knowledge distillation for LALMs applied to speech emotion recognition (SER). By combining ACCKA-guided projector-level alignment with logits-level KL divergence minimization, PL-Distill achieves compression of an 8.4B-parameter teacher to a 1.1B-parameter student while consistently outperforming both the teacher and SOTA baselines across diverse SER benchmarks (IEMOCAP, RAVDESS, SAVEE). ACCKA's conceptual innovation is its use of teacher-driven attention to selectively transfer representation structure, and its formal kernel-theoretic underpinning ensures robust alignment without requiring ad hoc dimension matching or additional regularization (Yang et al., 2 Feb 2026).
A plausible implication is that the ACCKA formalism may generalize to other cross-modal or structured distillation settings where attention scores signal salience. Its computational efficiency and principled handling of embedding mismatch make it a compelling candidate for ongoing research in model compression and transfer learning.