LLMs are widely used across applications, but their inference carries significant costs, dominated by the memory required to store the Key-Value (KV) cache during autoregressive decoding. The KV-cache grows linearly with sequence length, batch size, and the number of attention layers, often becoming much larger than the model weights themselves, leading to memory bottlenecks and reduced throughput.
Existing methods to optimize the KV-cache primarily focus on the sequence dimension: they identify and discard less important tokens in the input sequence based on heuristics such as recency (Sliding Window Attention (Beltagy et al., 2020), StreamingLLM (Xiao et al., 2023)) or attention scores (Heavy-Hitter Oracle (H2O) (Zhang et al., 2023), Scissorhands (Liu et al., 2023)). These methods typically apply the same cache budget or policy uniformly across all attention layers.
The paper "SqueezeAttention: 2D Management of KV-Cache in LLM Inference via Layer-wise Optimal Budget" (Wang et al., 7 Apr 2024 ) proposes optimizing the KV-cache from a new perspective: the layer dimension. It observes that not all attention layers contribute equally to the final output embedding. By quantifying the "importance" of each layer and adaptively allocating KV-cache budgets layer-wise, SqueezeAttention aims to achieve better accuracy for a given memory budget or reduce memory usage while maintaining accuracy.
Observation on Layer Importance:
The core observation is that the transformation applied by the self-attention mechanism varies across layers. To quantify this, the authors measure the cosine similarity between the hidden states of a token before and after passing through the self-attention block in each layer. A higher cosine similarity indicates less change (or less "information inserted") by that layer's attention computation, suggesting potentially lower importance for caching its full KV embeddings.
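As a concrete illustration, here is a minimal sketch of this metric for a single layer, assuming PyTorch tensors of shape (seq_len, hidden_dim) captured before and after the layer's self-attention block (the helper name is ours, not the paper's):

import torch
import torch.nn.functional as F

def attention_importance(hidden_before: torch.Tensor,
                         hidden_after: torch.Tensor) -> float:
    """Cosine similarity between a layer's hidden states before and after
    its self-attention block, averaged over all prompt tokens.

    Both tensors are assumed to have shape (seq_len, hidden_dim). A value
    close to 1.0 means attention barely changed the embedding, i.e. the
    layer is a weaker candidate for a large KV-cache budget.
    """
    sims = F.cosine_similarity(hidden_before, hidden_after, dim=-1)  # (seq_len,)
    return sims.mean().item()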
Experiments across various LLMs (Mistral-7B, Llama2-7B/70B, Falcon-7B) show common patterns:
- The first half of layers generally causes larger changes (lower cosine similarity).
- The very first and last layers might be particularly important, depending on the model.
- Some layers, particularly in the later half, show very high cosine similarity (close to 1.0), suggesting they have minimal impact on the embedding transformation via attention and their KV-cache might be less critical.
This observation suggests that a uniform KV-cache budget across all layers is sub-optimal and that KV-cache optimization can be performed along both the sequence dimension (which tokens to keep) and the layer dimension (which layers need more cache).
SqueezeAttention Algorithm:
SqueezeAttention works by combining an existing sequence-wise KV-cache compression algorithm (e.g., sliding window or H2O) with a layer-wise budget allocation strategy. The process involves:
- Layer Importance Measurement: During the prompt prefilling phase, for each token, calculate the cosine similarity between the hidden state before and after the self-attention block in each layer. Average this similarity over all prompt tokens to get a single importance score for each layer.
- Layer Grouping: Cluster the layers into groups based on their averaged cosine similarities using KMeans (empirically, 3 groups were found effective). Groups with lower average similarity are considered more important, and groups with higher average similarity less important.
- Budget Reallocation: Given a total KV-cache budget (often expressed as a percentage of the full cache size or maximum sequence length), reallocate this budget among layers. Layers in "unimportant" groups (those with the highest similarity) are assigned a reduced budget (e.g., a fraction of the initial uniform per-layer budget). The budget saved from these layers is then redistributed among the more "important" layer groups.
- Layer-wise Compression: During the decoding phase, for each layer, apply the chosen sequence-wise compression algorithm, but constrained by the specific budget assigned to that layer, which may differ from other layers' budgets.
This process is detailed in Algorithm 1 in the paper. A hyperparameter controls how much budget is taken from the least important layers; a sketch of the grouping and reallocation steps follows.
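The following is a rough sketch of the grouping and budget-reallocation steps under stated assumptions: the function, the token-based budget units, and the hyperparameter name p are ours, and scikit-learn's KMeans stands in for whatever clustering implementation is used.

import numpy as np
from sklearn.cluster import KMeans

def allocate_layer_budgets(scores, total_budget, p=0.5, n_groups=3):
    """Cluster layers by importance and reallocate a shared KV-cache budget.

    scores       : per-layer averaged cosine similarities (higher = less important)
    total_budget : total number of cached tokens summed over all layers
    p            : fraction of the uniform per-layer budget kept by the
                   least-important group (assumed hyperparameter name)
    Returns one token budget per layer.
    """
    n_layers = len(scores)
    uniform = total_budget // n_layers

    labels = KMeans(n_clusters=n_groups, n_init=10).fit_predict(
        np.asarray(scores).reshape(-1, 1))

    # The cluster whose average similarity is highest is the least important.
    centroids = [np.mean([s for s, l in zip(scores, labels) if l == g])
                 for g in range(n_groups)]
    least_important = int(np.argmax(centroids))

    # Shrink the least-important layers' budgets and collect the savings.
    budgets = [uniform] * n_layers
    saved = 0
    for i, label in enumerate(labels):
        if label == least_important:
            budgets[i] = int(uniform * p)
            saved += uniform - budgets[i]

    # Redistribute the saved budget evenly over the remaining layers.
    others = [i for i, label in enumerate(labels) if label != least_important]
    for i in others:
        budgets[i] += saved // max(len(others), 1)
    return budgets

For example, with 32 layers, a total budget of 32 x 1024 tokens, and p = 0.5, each layer in the least-important group would keep 512 tokens while the remaining layers split the saved budget evenly.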
Practical Implementation Details:
Implementing SqueezeAttention requires modifications to the LLM's forward pass and KV-cache management:
- Forward Pass Modification: During prefilling, the forward pass needs to be augmented to compute and store the hidden states before and after the self-attention block for each layer. This adds a small computational overhead and temporary memory usage during prefilling.
- Cosine Similarity Calculation: Implement the cosine similarity calculation efficiently, likely leveraging optimized linear algebra libraries.
- KMeans Clustering: Run KMeans clustering on the layer importance scores after prefilling. This is a standard clustering algorithm and the overhead is minimal for the typical number of layers (e.g., 32-80).
- Layer-wise Budget Management: The KV-cache management system needs to track and enforce a different maximum capacity for the KV-cache of each individual layer.
- Integration with the Sequence-Wise Algorithm: The logic for the chosen sequence-wise compression algorithm (e.g., Sliding Window, H2O) must be adapted to operate independently on each layer's KV-cache within that layer's allocated budget. This requires modifying the attention mechanism's cache handling logic; for example, a layer receiving a smaller budget will simply retain fewer tokens than a layer with a larger budget. The sketches below illustrate a minimal compressor interface and a per-layer decode step.
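To make the interaction concrete, here is a minimal sliding-window compressor standing in for the generic sequence-wise algorithm. The apply(cache, budget) interface mirrors the decode-step sketch that follows and is an assumption of this write-up, not an API from the paper.

class SlidingWindowCompressor:
    """Minimal sequence-wise compressor: keep only the most recent tokens.

    Stands in for any sequence-wise method (sliding window, H2O, ...).
    The cache for one layer is assumed to be a (keys, values) pair of
    tensors with token count on dimension 0.
    """

    def apply(self, layer_cache, budget_tokens):
        keys, values = layer_cache  # each: (num_tokens, num_heads, head_dim)
        if keys.shape[0] <= budget_tokens:
            return keys, values
        # Keep only the last `budget_tokens` tokens (recency heuristic).
        return keys[-budget_tokens:], values[-budget_tokens:]

The decode-step sketch below then calls such a compressor on each layer's cache with that layer's own budget.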
import torch

def squeeze_attention_decode_step(layer, layer_idx, hidden_state, kv_cache,
                                  layer_budget_tokens, sequence_compressor):
    # hidden_state is the input to this layer's self-attention block.
    # Compute query, key, value projections for the new token.
    query = layer.wq(hidden_state)
    new_key = layer.wk(hidden_state)
    new_value = layer.wv(hidden_state)

    # Append the new token's key/value to this layer's cache.
    # kv_cache[layer_idx] stores a (keys, values) pair for this layer.
    cached_keys, cached_values = kv_cache[layer_idx]
    updated_cache = (torch.cat([cached_keys, new_key], dim=0),
                     torch.cat([cached_values, new_value], dim=0))

    # Apply sequence-wise compression constrained by this layer's own
    # budget, which may differ from other layers' budgets.
    pruned_cache = sequence_compressor.apply(updated_cache,
                                             layer_budget_tokens[layer_idx])
    kv_cache[layer_idx] = pruned_cache

    # Compute attention using the pruned cache only.
    attention_output = layer.self_attention(query, pruned_cache)

    # Continue with the rest of the layer's forward pass.
    output_hidden_state = layer.mlp(attention_output)
    return output_hidden_state
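For the prefill-time measurement described under Forward Pass Modification, one option is a forward hook on each layer's attention submodule. This rough sketch assumes a Hugging Face-style module layout where each decoder block exposes a self_attn attribute (names vary by model), and it compares the attention submodule's input and output directly, glossing over the residual connection.

import torch.nn.functional as F

def register_importance_hooks(decoder_layers):
    """Attach forward hooks that record, per layer, the cosine similarity
    between the input and output of the self-attention submodule.

    decoder_layers: list of transformer blocks, each assumed to expose a
    `self_attn` submodule (attribute name varies by model).
    Returns (scores, handles) so the hooks can be removed after prefilling.
    """
    scores = [0.0] * len(decoder_layers)
    handles = []

    def make_hook(idx):
        def hook(module, inputs, output):
            hidden_in = inputs[0]                                   # (batch, seq, dim)
            hidden_out = output[0] if isinstance(output, tuple) else output
            # Average similarity over all prompt tokens.
            scores[idx] = F.cosine_similarity(hidden_in, hidden_out,
                                              dim=-1).mean().item()
        return hook

    for idx, layer in enumerate(decoder_layers):
        handles.append(layer.self_attn.register_forward_hook(make_hook(idx)))
    return scores, handles

After running the prompt through the model once, scores holds one averaged similarity per layer, and each handle can be released with handle.remove() before decoding begins.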
Performance and Resource Implications:
The paper demonstrates significant practical benefits:
- Memory Savings: SqueezeAttention consistently achieves comparable accuracy to baseline sequence-only methods with 30% to 70% less total KV-cache memory. Compared to Full Cache, it saves 70% to 80% of per-token decoding memory usage.
- Throughput Improvement: Reduced memory footprint allows for larger batch sizes, directly increasing token generation throughput. Experiments show up to 2.2x throughput improvement for Mistral-7B and 1.4x for Llama2-70B compared to Full Cache, enabling batch sizes that would otherwise run out of memory (OOM).
- Computational Overhead: The layer importance calculation and clustering add a small overhead during prefilling (a single pass). The per-layer budget application during decoding adds no significant per-token computation over the base method; it simply applies the same pruning logic with a different budget per layer. The main computational benefit comes from operating on smaller KV-caches.
Limitations and Considerations:
- Hyperparameter Tuning: The budget-reduction hyperparameter and the number of clusters (3 in the paper) might require tuning depending on the specific model and task to balance accuracy and memory savings.
- Importance Metric Sensitivity: The cosine similarity metric is a heuristic for importance. While effective in the paper's experiments, its robustness across all possible models and tasks might need further investigation.
- Integration Complexity: Integrating SqueezeAttention requires modifying the core attention and KV-cache management logic within the inference framework (e.g., Hugging Face Transformers, vLLM). This can be complex depending on the framework's architecture.
Conclusion:
SqueezeAttention (Wang et al., 2024) introduces a practical approach to optimizing LLM inference by recognizing and exploiting the varying importance of attention layers. By dynamically allocating KV-cache budgets layer-wise based on simple cosine similarity measurements, it can be combined with existing sequence-wise compression techniques to achieve substantial memory savings and throughput improvements without significant accuracy degradation. This 2D optimization offers a valuable tool for practitioners aiming to deploy LLMs more efficiently, especially for long contexts and large batch sizes. Because it is orthogonal to sequence-wise pruning, SqueezeAttention can potentially enhance the efficiency of many existing and future sequence-wise KV-cache optimization methods.