Get More with LESS: Synthesizing Recurrence with KV Cache Compression for Efficient LLM Inference

Published 14 Feb 2024 in cs.LG and cs.AI | (2402.09398v2)

Abstract: Many computational factors limit broader deployment of LLMs. In this paper, we focus on a memory bottleneck imposed by the key-value (KV) cache, a computational shortcut that requires storing previous KV pairs during decoding. While existing KV cache methods approach this problem by pruning or evicting large swaths of relatively less important KV pairs to dramatically reduce the memory footprint of the cache, they can have limited success in tasks that require recollecting a majority of previous tokens. To alleviate this issue, we propose LESS, a simple integration of a (nearly free) constant sized cache with eviction-based cache methods, such that all tokens can be queried at later decoding steps. Its ability to retain information throughout time shows merit on a variety of tasks where we demonstrate LESS can help reduce the performance gap from caching everything, sometimes even matching it, all while being efficient. Relevant code can be found at https://github.com/hdong920/LESS.

Citations (32)

Summary

  • The paper addresses the KV cache memory bottleneck by proposing LESS, a method that synthesizes low-rank embeddings with sparse policies to retain essential information.
  • LESS improves performance on tasks such as language modeling and summarization, reducing perplexity by over 20% and recovering more than 40% of Rouge-1 degradation.
  • The method enhances inference speed by achieving up to 1.3x lower latency and 1.7x higher throughput compared to full KV cache implementations.

LESS: Efficient LLM Inference via KV Cache Compression

The paper "Get More with LESS: Synthesizing Recurrence with KV Cache Compression for Efficient LLM Inference" (2402.09398) addresses the memory bottleneck in LLM inference caused by the key-value (KV) cache. The authors propose LESS (Low-rank Embedding Sidekick with Sparse policy), a method that integrates a constant-sized low-rank cache with eviction-based sparse KV cache policies to retain information from discarded KV pairs. This approach aims to improve performance on tasks requiring recollection of previous tokens while maintaining efficiency.

Background and Motivation

The KV cache, which stores previous keys and values at each layer during decoding, significantly increases memory consumption. Sparse policies reduce the KV cache size by pruning less important KV pairs. However, these methods can lead to performance degradation on tasks that require recalling a majority of previous tokens. The paper identifies this issue and motivates the need for a more effective KV cache policy that minimizes performance degradation, grows slowly with sequence length, and is easy to integrate into existing LLMs.

Figure 1: Toy (top row) and Llama 2 7B (bottom row) example decoder attention maps with H2O as the underlying sparse policy.

The authors observe that the residual between full and sparse attention outputs is low-rank, motivating the use of low-rank methods to approximate these residuals for efficient caching. LESS leverages this observation by learning the residual between the original attention output and the output approximated by a sparse policy, thereby allowing queries to access previously omitted regions in attention maps (Figure 1).
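This observation can be illustrated with a small, self-contained sketch (hypothetical toy code, not the authors' implementation): compute full causal attention, a crude stand-in for an eviction-style sparse policy, and the singular values of the difference between their outputs.

```python
# Toy check that the (full - sparse) attention residual is approximately low-rank.
# The "keep top-8 scores per query" rule is a crude stand-in for a real eviction policy.
import torch

torch.manual_seed(0)
T, d = 256, 64                                   # sequence length, head dimension
q, k, v = (torch.randn(T, d) for _ in range(3))

mask = torch.ones(T, T).triu(1).bool()           # causal mask
scores = (q @ k.T / d**0.5).masked_fill(mask, float("-inf"))
out_full = scores.softmax(-1) @ v                # full attention output

topk = scores.topk(8, dim=-1)                    # keep only 8 "important" tokens per query
sparse = torch.full_like(scores, float("-inf")).scatter(-1, topk.indices, topk.values)
out_sparse = sparse.softmax(-1) @ v              # sparse-policy attention output

sv = torch.linalg.svdvals(out_full - out_sparse)
print((sv / sv[0])[:8])                          # singular values drop off quickly
```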

LESS Method

LESS synthesizes sparse KV policies with low-rank states to improve performance. The method adds a constant-sized cache, built from learned kernel functions, that does not grow with sequence length. The kernels are defined as:

$$\phi(q) = \left| \sigma_\phi\bigl(\sigma_\phi(q W_{\phi,1})\, W_{\phi,2}\bigr) \right|$$

$$\psi(k) = \left| \sigma_\psi\bigl(\sigma_\psi(k W_{\psi,1})\, W_{\psi,2}\bigr)\, W_{\psi,3} \right|$$

where $\sigma_\phi, \sigma_\psi$ are activation functions and the $W_{\cdot, i}$ are learned weight matrices.

Figure 2: LESS algorithm during inference.
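A minimal sketch of the kernel maps, mirroring the formulas above, is shown below; the hidden width, output rank, and choice of activation are illustrative assumptions rather than the paper's exact configuration.

```python
import torch
import torch.nn as nn

class Kernel(nn.Module):
    """phi(q) = |sigma(sigma(q W1) W2)|; psi(k) additionally applies W3 before |.|."""
    def __init__(self, d_head: int, hidden: int, rank: int, extra_proj: bool = False):
        super().__init__()
        self.w1 = nn.Linear(d_head, hidden, bias=False)
        self.w2 = nn.Linear(hidden, rank, bias=False)
        self.w3 = nn.Linear(rank, rank, bias=False) if extra_proj else None
        self.act = nn.GELU()                      # stand-in for sigma_phi / sigma_psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.act(self.w2(self.act(self.w1(x))))
        if self.w3 is not None:
            h = self.w3(h)
        return h.abs()                            # outer |.| keeps kernel features non-negative

phi = Kernel(d_head=128, hidden=256, rank=64)                   # query kernel phi(q)
psi = Kernel(d_head=128, hidden=256, rank=64, extra_proj=True)  # key kernel psi(k)
```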

At each decoding step, attention is approximated using the keys and values retained by the sparse policy together with the low-rank cache (Figure 2). Whenever the sparse policy discards a KV pair, its information is folded into the low-rank cache via a recursive update, as sketched below.
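As a rough illustration of that procedure (variable names, shapes, and the exact normalization are assumptions for exposition, not the released implementation), a single decoding step and an eviction update might look like:

```python
import torch

def less_decode_step(q, K_kept, V_kept, H, z, phi, scale):
    """q: (d,) current query; K_kept/V_kept: (n, d)/(n, d_v) KVs retained by the sparse policy;
    H: (r, d_v) low-rank state over evicted tokens; z: (r,) its normalizer accumulator."""
    s = torch.exp(K_kept @ q * scale)       # unnormalized attention scores on kept tokens
    f = phi(q)                              # (r,) kernelized query features
    numer = s @ V_kept + f @ H              # sparse contribution + low-rank contribution
    denom = s.sum() + f @ z                 # shared normalizer couples both parts
    return numer / denom

def fold_in_evicted(k_old, v_old, H, z, psi):
    """Recursive update: absorb a KV pair discarded by the sparse policy into the constant-sized cache."""
    g = psi(k_old)                          # (r,)
    return H + g[:, None] * v_old[None, :], z + g
```

The key point is that H and z have fixed shapes, so the extra cache stays constant in size regardless of how many tokens have been evicted.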

Implementation Details

The kernel functions $\phi$ and $\psi$ are trained independently at each layer to minimize the $\ell_2$ distance to the output projection of the original attention layer. All weights except those in $\phi$ and $\psi$ are frozen. The paper emphasizes that kernel initialization is critical, using learnable scalars initialized to a small value to allow the sparse policy to act as a warm start. An efficient implementation is developed to enhance throughput and reduce latency, using fused linear kernels and adapting an existing implementation for sparse caching.
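A hypothetical outline of that per-layer training setup is shown below; the `frozen_attention` and `less_forward` callables stand in for the teacher attention layer and the LESS-style forward pass (they are placeholders, not functions from the released code), and the optimizer and loss details are assumptions.

```python
import torch

def train_kernels_for_layer(phi, psi, alpha, batches, frozen_attention, less_forward, lr=1e-3):
    """frozen_attention(h): full-cache attention output of this layer (weights frozen);
    less_forward(h, phi, psi, alpha): LESS approximation using a sparse policy plus the low-rank cache.
    Only phi, psi, and the gating scalar alpha receive gradient updates."""
    opt = torch.optim.Adam(list(phi.parameters()) + list(psi.parameters()) + [alpha], lr=lr)
    for hidden_states in batches:
        with torch.no_grad():
            target = frozen_attention(hidden_states)    # teacher: original attention output
        approx = less_forward(hidden_states, phi, psi, alpha)
        loss = (approx - target).pow(2).mean()          # l2-style objective, one layer at a time
        opt.zero_grad()
        loss.backward()
        opt.step()

# Learnable scalar initialized to a small value so the sparse policy acts as a warm start.
alpha = torch.nn.Parameter(torch.tensor(1e-4))
```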

Experimental Results

The paper evaluates LESS on Llama 2 and Falcon models with different sparse policies and sparsity levels. The experiments demonstrate that LESS achieves performance improvements in language modeling, classification, and summarization tasks.

Figure 3: Layer-wise Llama 2 7B mean Hellinger distance from original attention probabilities, aggregated across WikiText evaluation samples.

Specifically, LESS reduces word perplexities on WikiText and PG-19 by over 20% relative to the sparse policy alone, and it recovers more than 40% of the Rouge-1 degradation caused by a sparse policy on the CNN/DailyMail dataset. The paper also reports up to 1.3x lower latency and up to 1.7x higher throughput compared to full caching. Further analysis shows that LESS closely replicates the original attention probability distributions (Figure 3) and performs well on longer sequences.

Conclusion

The paper presents LESS as an effective approach to mitigate the KV cache bottleneck by synthesizing sparse KV cache algorithms with low-rank states. The method demonstrates significant performance gains while maintaining efficiency and being easy to train and deploy. Future work includes improving kernel design and investigating the residual of LESS.
