Cross-Self Pruning in Deep Learning
- Cross-Self Pruning (CSP) denotes two related approaches to efficient deep learning: cross-self key–value (KV) cache pruning for inference and self-distilled weight pruning for model compression.
- The cache variant decomposes attention into self- and cross-modality views to balance token retention in vision–language models, mitigating bias between text and vision modalities.
- The weight-pruning variant combines iterative pruning with self-distillation under a cross-correlation loss, preserving feature fidelity while substantially reducing computational overhead.
Cross-Self Pruning (CSP) refers to two distinct but methodologically related approaches in deep learning for model compression and efficient inference: (1) cross-self key–value (KV) cache pruning for vision–language autoregressive models, which operates at the level of attention and token retention during inference (Pei et al., 2024); and (2) self-distilled pruning via cross-correlation loss, an iterative weight-pruning strategy for deep neural networks informed by internal representational similarity (Neill et al., 2021). Both paradigms improve compute and memory efficiency while achieving strong empirical performance, accomplished by integrating structural awareness—either at the modality or feature level—into the pruning decision process.
1. Motivation and Limitations of Unified Pruning Approaches
Vision–language models (VLMs) and LLMs under high compression demands often degrade when standard pruning methods are applied without accounting for structural heterogeneity among tokens or features. In VLMs, autoregressive decoding stores key–value caches for both text and visual tokens, so cache memory grows with context length and the attention computed over the cache grows with it at every decoding step. Previous approaches such as SnapKV and H2O prune according to aggregate self-attention scores without distinguishing text from vision tokens. Because the two modalities have different attention-score distributions, retention is biased toward whichever modality (often text) naturally receives higher scores, causing loss of critical visual information and compromised cross-modal reasoning (Pei et al., 2024).
In weight pruning, conventional magnitude-based pruning (MBP) and its variants ignore representational differences between pruned and original networks, often leading to slow recovery, loss of generalization, or sharp degradation at high sparsity. This motivates CSP-style methods that enforce architectural or representational fidelity during or after pruning (Neill et al., 2021).
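For reference, the sketch below shows one magnitude-based pruning step of the kind these baselines iterate on; the function name, PyTorch framing, and the 90% sparsity in the example are illustrative assumptions, not details from Neill et al. (2021).

```python
import torch

def magnitude_mask(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Binary keep-mask over a weight tensor: 1 for the largest-magnitude
    weights, 0 for the rest (one step of magnitude-based pruning)."""
    k_keep = int(weight.numel() * (1.0 - sparsity))          # number of weights to keep
    flat = weight.abs().flatten()
    threshold = flat.kthvalue(weight.numel() - k_keep + 1).values
    return (weight.abs() >= threshold).to(weight.dtype)

# Example: prune a layer to 90% sparsity. SDP-CC (Section 4) would additionally
# distil the masked "student" from the unpruned "teacher" while fine-tuning.
w = torch.randn(256, 256)
w_pruned = w * magnitude_mask(w, sparsity=0.9)
```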
2. Cross-Self KV Cache Pruning for Multimodal Models
The CSP method for VLMs decomposes the attention matrix into intra-modality (self) and inter-modality (cross) components. Given a post-softmax attention matrix $A \in \mathbb{R}^{(n_t+n_v)\times(n_t+n_v)}$ (with $n_t$ text tokens and $n_v$ vision tokens), the blocks are:
- $A^{st}$: text→text (self)
- $A^{sv}$: vision→vision (self)
- $A^{ct}$: vision attends to text (cross)
- $A^{cv}$: text attends to vision (cross)
Token/key importance is scored in both views by aggregating the attention each key receives, e.g. $s^{s}_j = \sum_i A^{s}_{ij}$ for the self view and $s^{c}_j = \sum_i A^{c}_{ij}$ for the cross view, where $A^s$ stacks the self blocks and $A^c$ the cross blocks. By independently budgeting the number of keys retained under the self ($K^s$) and cross ($K^c$) views, CSP ensures balanced preservation of both modalities. Only keys winning selection in both views are retained, supplemented by always keeping the $R$ most recent tokens. This selective masking eliminates the systematic bias arising from the differing source distributions of attention scores (Pei et al., 2024).
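The NumPy sketch below illustrates the block decomposition and dual-budget selection described above; the block layout (text tokens first), the column-sum scoring, and the helper names are assumptions for illustration rather than the exact formulation in Pei et al. (2024).

```python
import numpy as np

def csp_select(A: np.ndarray, n_t: int, ks: int, kc: int) -> np.ndarray:
    """Select keys kept by CSP: top-k in the self view AND top-k in the cross view.

    A   : post-softmax attention, shape (n_t + n_v, n_t + n_v), text tokens first
    n_t : number of text tokens; ks / kc : self / cross budgets (assumed >= 1)
    Returns a boolean keep-mask over the n_t + n_v keys.
    """
    A_st, A_cv = A[:n_t, :n_t], A[:n_t, n_t:]      # text queries -> text / vision keys
    A_ct, A_sv = A[n_t:, :n_t], A[n_t:, n_t:]      # vision queries -> text / vision keys

    # per-key importance: attention received by each key in each view (text keys first)
    score_self  = np.concatenate([A_st.sum(axis=0), A_sv.sum(axis=0)])
    score_cross = np.concatenate([A_ct.sum(axis=0), A_cv.sum(axis=0)])

    def topk_mask(scores, k):
        mask = np.zeros_like(scores, dtype=bool)
        mask[np.argsort(scores)[-k:]] = True
        return mask

    # retain only keys that win selection in BOTH the self and cross views
    return topk_mask(score_self, ks) & topk_mask(score_cross, kc)

# Example: 4 text + 6 vision tokens, keep 3 keys per view
A = np.random.rand(10, 10); A /= A.sum(axis=1, keepdims=True)
keep = csp_select(A, n_t=4, ks=3, kc=3)
```

As in the Section 5 pseudocode, the $R$ most recent tokens would then be OR-ed into the returned mask before the cache is actually trimmed.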
3. n-Softmax and Cache Smoothness Restoration
Pruning keys from the KV cache can artificially concentrate the softmax distribution, as the denominator of the attention calculation contracts. To counteract this sharpening, CSP applies the n-softmax function
$$\operatorname{n\text{-}softmax}(z)_i = \frac{e^{z_i}}{\,n + \sum_{j \in \mathcal{S}} e^{z_j}\,},$$
where $\mathcal{S}$ indexes the retained keys and $n$ is a small constant. This restores smoothness to the distribution and stabilizes autoregressive generation without retraining, preventing spurious degradation or instability in VLMs under aggressive pruning constraints (Pei et al., 2024).
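A minimal NumPy sketch of the n-softmax as reconstructed above; the default value of n and the numerically stable rearrangement are illustrative choices, not taken from the paper.

```python
import numpy as np

def n_softmax(logits: np.ndarray, n: float = 1.0, axis: int = -1) -> np.ndarray:
    """Softmax over the retained keys with an extra constant n in the denominator,
    restoring probability mass lost when keys are pruned from the cache."""
    m = logits.max(axis=axis, keepdims=True)                 # shift for numerical stability
    e = np.exp(logits - m)
    # exactly equal to exp(z_i) / (n + sum_j exp(z_j)) after the shift by m
    return e / (n * np.exp(-m) + e.sum(axis=axis, keepdims=True))
```

With n = 0 this reduces to the ordinary softmax, so the correction can be ablated directly.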
4. Cross-Self Pruning with Self-Distillation and Cross-Correlation Loss
In deep neural network compression, CSP (also known as SDP-CC) proceeds by joint iterative pruning and self-distillation. A student network, dynamically pruned via a binary mask $M$, is trained to mimic a frozen teacher on both predictions and representational similarity. The composite batch-level loss is
$$\mathcal{L} = \mathcal{L}_{\mathrm{CE}} + \alpha\,\mathcal{L}_{\mathrm{KL}} + \beta\,\mathcal{L}_{\mathrm{CC}},$$
where $\mathcal{L}_{\mathrm{CE}}$ is the cross-entropy loss, $\mathcal{L}_{\mathrm{KL}}$ is the KL divergence on softened logits, and $\mathcal{L}_{\mathrm{CC}}$ is the cross-correlation loss between student and teacher penultimate activations. The cross-correlation objective maximizes diagonal alignment and minimizes off-diagonal covariance between channels,
$$\mathcal{L}_{\mathrm{CC}} = \sum_i \bigl(1 - \mathcal{C}_{ii}\bigr)^2 + \lambda \sum_i \sum_{j \neq i} \mathcal{C}_{ij}^2,$$
with $\mathcal{C}$ the normalized feature cross-covariance matrix. This enforces both feature fidelity and decorrelation after pruning, improving downstream accuracy and the generalization of heavily pruned networks (Neill et al., 2021).
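The PyTorch sketch below shows one way to assemble this composite objective; the coefficient values, temperature, and the Barlow-Twins-style normalization of the cross-covariance are assumptions for illustration, not necessarily the authors' exact implementation.

```python
import torch
import torch.nn.functional as F

def cross_correlation_loss(z_s: torch.Tensor, z_t: torch.Tensor, lam: float = 5e-3) -> torch.Tensor:
    """Cross-correlation loss between student and teacher penultimate features
    (batch x dim): push the diagonal of the normalized cross-covariance toward 1
    and the off-diagonal entries toward 0."""
    z_s = (z_s - z_s.mean(0)) / (z_s.std(0) + 1e-6)
    z_t = (z_t - z_t.mean(0)) / (z_t.std(0) + 1e-6)
    c = (z_s.T @ z_t) / z_s.shape[0]                         # normalized cross-covariance
    on_diag = (torch.diagonal(c) - 1.0).pow(2).sum()
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()
    return on_diag + lam * off_diag

def sdp_cc_loss(student_logits, teacher_logits, labels, z_s, z_t,
                alpha: float = 1.0, beta: float = 1.0, temp: float = 2.0) -> torch.Tensor:
    """Composite objective: cross-entropy + KL on softened logits + cross-correlation."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(F.log_softmax(student_logits / temp, dim=-1),
                  F.softmax(teacher_logits / temp, dim=-1),
                  reduction="batchmean") * temp ** 2
    return ce + alpha * kl + beta * cross_correlation_loss(z_s, z_t)
```

In the iterative pruning loop, the student's forward pass would use the currently masked weights while the teacher remains the frozen, unpruned network.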
5. Algorithmic Procedures and Implementation
Pseudocode for Cross-Self KV Cache Pruning (VLMs):
```
Input:  cached keys K, values V, queries Q,
        total budget T, self budget K^s, cross budget K^c,
        window params R, O
for each decoding iteration:
    if current_cache_size <= T:
        return K, V                          # cache within budget; no pruning
    logits = (Q @ K.T) / sqrt(d)
    A = n_softmax(logits)                    # smoothed attention (Section 3)
    # per-key importance in the self and cross views
    A^s = concat(sum_rows(A^{st}), sum_rows(A^{sv}))
    A^c = concat(sum_rows(A^{cv}), sum_rows(A^{ct}))
    mask_s = topK_indices(A^s, K^s)
    mask_c = topK_indices(A^c, K^c)
    M = mask_s & mask_c                      # keep only keys selected in both views
    keep = M | indices_of_last(R)            # always retain the R most recent tokens
    return K[keep], V[keep]
```
6. Empirical Performance and Quantitative Evaluation
On MileBench, a comprehensive multimodal VLM benchmark spanning 29 datasets, CSP achieves substantial improvements in both efficiency and accuracy under stringent cache budgets:
| Task | Full KV | Best Prior | CSP | CSP Gain |
|---|---|---|---|---|
| T-3 Reasoning | 32.2 | 32.5 | 41.6 | +29% |
| S-5 Semantic QA | 60.5 | 60.5 | 61.0 | +0.5 absolute |
| NH Retrieval | 4.7 | 5.3 | 6.3 | highest, +18% rel. |
- CSP reduces the KV cache by 13.6% on average while matching or surpassing baselines in accuracy.
- Up to 41% relative accuracy improvement on conversational embodied dialogue benchmarks under aggressive pruning.
- The n-softmax correction recovers 1–2% accuracy relative to the regular softmax under strong pruning.
- On GLUE/XGLUE, cross-self pruning with cross-correlation self-distillation (SDP-CC) outperforms all magnitude-based pruning baselines, achieving up to 78.9 accuracy on BERT-base at 10% weights and 70.8 on XLM-R-base at 30% weights (Neill et al., 2021).
7. Limitations and Prospects
Both CSP variants (cache pruning and self-distilled pruning) have fixed budget parameters ($K^s$, $K^c$, mask schedules) that must be chosen a priori; adaptive strategies could offer further gains. Distributional overlap between modalities or representation spaces can reduce the benefit of cross/self decomposition. Extension to more than two modalities or feature partitions may require more complex decomposition and selection. Hybrid approaches, combining CSP with token-merging techniques or learned controllers for adaptive pruning, are potential future directions (Pei et al., 2024).
CSP in both inference-time and training-time formulations consistently demonstrates superior retention of critical information under strong constraints, balancing computational efficiency with robust performance across a wide range of tasks and architectures.