TRIM-KV: Efficient KV Pruning for LVLMs
- The paper introduces TRIM-KV, an inference-phase pruning method that retains the most salient visual tokens to reduce the KV cache in large vision-language models.
- It leverages the sparsity of early cross-attention maps by aggregating head-wise attention scores to identify and select key visual tokens.
- Empirical results demonstrate latency reductions of up to 26%, with accuracy maintained at 50% token retention across standard vision-language benchmarks.
TRIM-KV is an inference-phase pruning method for reducing the key-value (KV) cache footprint of visual features in cross-attention-based large vision-language models (LVLMs). Developed in the context of models such as LLaMA-3.2-Vision, which employ dedicated cross-attention layers to fuse image and text information, TRIM-KV exploits the empirical sparsity of early cross-attention maps to retain only the most salient visual tokens, thereby significantly reducing memory and computational costs while maintaining competitive performance across standard benchmarks (Lee et al., 1 Apr 2025).
1. KV Cache Bottleneck in Cross-Attended LVLMs
In cross-attention-based LVLMs, images are encoded into a set of $n_k$ visual tokens, each mapped via linear projections to key ($K$) and value ($V$) representations of dimension $d$. In contrast to self-attention layers over textual data, where the number of tokens is modest, the number of image tokens can be significantly larger, especially for high-resolution inputs (e.g., 384×384 images yield far more visual tokens than a typical text prompt contains). The KV cache required to store these representations is therefore dominated by visual features. Across the $L_{\text{ca}}$ cross-attention layers, the memory footprint scales as

$$M_{\text{cross}} \propto 2 \, L_{\text{ca}} \, n_k \, d,$$
where the factor of 2 corresponds to storing both keys and values. For LLaMA-3.2-Vision-11B, the cross-attention KV cache is approximately 12.5 times larger than the self-attention KV cache on text at these settings. Thus, unmitigated visual KV caching poses a principal bottleneck during inference (Lee et al., 1 Apr 2025).
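To make the scaling concrete, here is a minimal sketch of the footprint comparison. All layer counts, token counts, and head dimensions below are illustrative assumptions for a hypothetical configuration, not figures from the paper; only the factor-of-2 structure of the formula is taken from the text.

```python
def kv_cache_bytes(n_layers, n_tokens, n_heads, head_dim, bytes_per_elem=2):
    """KV cache size in bytes; the factor of 2 covers both keys and values."""
    return 2 * n_layers * n_tokens * n_heads * head_dim * bytes_per_elem

# Hypothetical setting: a few cross-attention layers over many visual tokens
# versus many self-attention layers over a short text prompt.
visual = kv_cache_bytes(n_layers=8, n_tokens=6400, n_heads=8, head_dim=128)
text = kv_cache_bytes(n_layers=32, n_tokens=128, n_heads=8, head_dim=128)
print(visual / text)
```

With these assumed dimensions the visual-to-text cache ratio comes out at 12.5×, on the order of the disparity the paper reports; the point is simply that the per-token cost times the large visual token count dominates the cache.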
2. Mathematical Formulation of TRIM-KV Token Importance
Cross-attention operates by computing attention weights

$$\alpha = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) \in \mathbb{R}^{n \times n_k},$$

where $Q \in \mathbb{R}^{n \times d}$ denotes queries from the $n$ text tokens and $K \in \mathbb{R}^{n_k \times d}$ the visual keys. The output is a weighted sum $O = \alpha V$, with $V \in \mathbb{R}^{n_k \times d}$ the visual values. In TRIM-KV, token importance is measured head-wise in the first cross-attention layer, aggregating over all query tokens:

$$p_h[j] = \sum_{i=1}^{n} \alpha_h[i, j]$$

for each attention head $h \in \{1, \dots, H\}$ and visual token $j \in \{1, \dots, n_k\}$, producing a vector of importance scores $p_h \in \mathbb{R}^{n_k}$ per head. This approach exploits the sparsity of attention maps, where a small subset of visual tokens typically dominates the aggregate attention distribution.
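The head-wise aggregation over query tokens can be sketched in NumPy. The shapes below (2 heads, 4 text queries, 6 visual tokens) are arbitrary toy choices, not the model's real dimensions.

```python
import numpy as np

# alpha: (H, n, n_k) = (heads, text queries, visual tokens), toy random values.
rng = np.random.default_rng(0)
H, n, n_k = 2, 4, 6
logits = rng.standard_normal((H, n, n_k))
alpha = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over keys

# Importance of each visual token per head: total attention mass it receives
# from all text queries, i.e. p_h[j] = sum_i alpha_h[i, j].
p = alpha.sum(axis=1)  # shape (H, n_k)
print(p.shape)
```

Since each query's attention row sums to 1, the scores in `p` sum to H×n overall; ranking visual tokens by `p` within each head is exactly the importance measure TRIM-KV uses.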
3. Trimming Algorithm and Pseudocode
The TRIM-KV pruning process operates as follows:
- For each attention head h (h = 1…H), select the top k visual tokens by their aggregate importance p_h, where k = ⌈K_ratio × n_k⌉ for a chosen retention ratio K_ratio (e.g., 0.5).
- The overall retained token set is the union across heads: T = T_1 ∪ ⋯ ∪ T_H.
- Only the key and value vectors belonging to T are kept; the remainder are discarded from the KV cache for all subsequent cross-attention layers.
Pseudocode:
```
Input:  α ∈ ℝ^{H×n×n_k}, K_ratio
Output: index set T

T ← ∅
for h in 1…H:
    p ← sum over queries of α[h,:,:]    # shape (n_k,)
    k ← ⌈K_ratio × n_k⌉
    T_h ← indices of the top k entries of p
    T ← T ∪ T_h
return T
```
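A minimal NumPy rendering of the trimming pseudocode above; the function name and array layout are illustrative, not from the paper.

```python
import numpy as np

def trim_kv_select(alpha: np.ndarray, k_ratio: float) -> np.ndarray:
    """Select visual-token indices to keep.

    alpha: attention weights of shape (H, n, n_k) from the first
    cross-attention layer. Returns the sorted union, over heads, of each
    head's top-k tokens, with k = ceil(k_ratio * n_k).
    """
    H, n, n_k = alpha.shape
    k = int(np.ceil(k_ratio * n_k))
    keep = set()
    for h in range(H):
        p = alpha[h].sum(axis=0)           # aggregate over text queries -> (n_k,)
        top = np.argpartition(p, -k)[-k:]  # indices of the k largest scores
        keep.update(top.tolist())
    return np.array(sorted(keep))
```

`np.argpartition` avoids a full sort, which matters when n_k is in the thousands; note that because the union is taken across heads, the number of retained tokens can slightly exceed the per-head budget k.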
4. Integration with Standard LVLM Inference Pipelines
The TRIM-KV method is integrated purely at inference, requiring no re-training or model fine-tuning. The sequence is:
- Encode the image to obtain visual features of shape (n_k, d) and project them to keys K and values V.
- At the first generation step, compute the attention weights α of the first cross-attention layer and run TRIM-KV to obtain the index set T.
- Create pruned key and value caches K' = K[T], V' = V[T].
- For all subsequent steps and cross-attention layers, use (K', V') as the cached memory.
This process, termed "plug-and-play," hinges on the empirically observed stability of cross-attention patterns after the first block, such that a one-time token selection remains stable across remaining layers and generation steps (Lee et al., 1 Apr 2025).
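The cache pruning step itself reduces to index selection on the cached tensors; a sketch with made-up shapes and indices (real pipelines cache K and V per layer and per head, which this flattens for brevity):

```python
import numpy as np

n_k, d = 8, 4
K = np.arange(n_k * d, dtype=np.float32).reshape(n_k, d)  # visual keys (toy)
V = -K                                                    # visual values (toy)
T = np.array([1, 3, 6])  # indices produced by the one-time TRIM-KV selection

# One-time pruning: these smaller caches serve all subsequent
# cross-attention layers and generation steps.
K_pruned, V_pruned = K[T], V[T]
print(K_pruned.shape)
```

Because the selection happens once at the first generation step, the per-step decoding cost afterward is identical to running the model with a natively smaller visual context.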
5. Empirical Impact and Benchmark Results
Introducing TRIM-KV into LLaMA-3.2-Vision-11B yields significant memory and latency reductions during inference. At a batch size of 32, total latency drops from approximately 4,000 ms (full tokens) to 3,165 ms (50% tokens, −19.7%) and 2,917 ms (40% tokens, −26%).
Accuracy is preserved across six vision-language benchmarks at up to 50.9% effective visual-token retention (K_ratio = 0.5; the union over heads keeps slightly more than the per-head budget), with <0.5% deviation from full-cache baselines. A summary of performance follows:
| K_ratio | SEED-Image | MME | MMVP | LLaVA-Bench |
|---|---|---|---|---|
| 1.00 | 72.6 | 1685.9 | 46.7 | 88.3 |
| 0.50 | 72.3 | 1687.3 | 47.3 | 88.1 |
| 0.40 | 72.1 | 1682.8 | 47.3 | 87.3 |
Ablation studies at K_ratio=0.5 demonstrate substantial superiority of TRIM-KV over random selection (e.g., SEED=67.0, LLaVA=83.2) and spatially uniform selection (SEED=71.8, LLaVA=85.9) at the same token budget. This underscores the importance of attention-derived token selection for minimizing performance loss (Lee et al., 1 Apr 2025).
6. Scalability, Limitations, and Theoretical Considerations
The efficacy of TRIM-KV is strengthened as image resolution and the number of visual tokens increase, due to a greater absolute reduction in memory and computation. The assumption underlying the method is a structured stability in early cross-attention maps, such that one-time token selection remains valid across all subsequent operations. Scenarios with dynamically shifting attention patterns may degrade the method's effectiveness.
TRIM-KV is designed specifically for cross-attention-based architectures ("Flamingo-style" LVLMs); self-attention-only models require alternative approaches for efficient token pruning. Additionally, the optimal choice of K_ratio may require empirical validation for new tasks or domains, as excessive trimming could affect particular downstream behaviors (Lee et al., 1 Apr 2025).
7. Summary and Outlook
Cross-attention-based TRIM-KV achieves up to 50% reduction in visual KV cache, corresponding to 12–26% latency improvements, while maintaining state-of-the-art performance on multiple vision-language benchmarks without any model retraining or fine-tuning. By identifying and retaining only those visual tokens most salient to the initial cross-attention computation, the method addresses a central bottleneck in cross-attention-based LVLMs. This suggests that attention-based token pruning, particularly at the early layers, is a viable pathway for efficient multimodal inference at scale, especially as image and batch sizes increase (Lee et al., 1 Apr 2025).