The paper introduces LLMpresso, a method for extending the context window of pre-trained LLMs while preserving performance on shorter contexts. LLMpresso addresses the out-of-distribution (OOD) issue in rotary positional embeddings (RoPE) by focusing on the hypothesis that higher RoPE dimensions are insufficiently trained, which affects the effectiveness of existing rescaling methods. The method includes a RoPE rescaling algorithm using evolutionary search guided by "needle-driven" perplexity (PPL) and mixed context window training.
The authors identify two major challenges in extending LLM context windows:
- Existing rescaling methods fail to reach the intended effective context length
- Performance degrades on the original short context window
The authors attribute these issues to insufficient training in higher RoPE dimensions, resulting in shorter effective RoPE rotation ranges.
LLMpresso includes the following innovations:
- A RoPE rescaling algorithm that uses evolutionary search to identify critical RoPE dimensions and optimal rescaling factors, guided by a "needle-driven" perplexity evaluation.
- A mixed context window training approach, which fine-tunes model weights to adopt rescaled RoPE for long-context sequences while preserving short-context performance with the original RoPE.
The RoPE is calculated as follows:
$\mathbf{q}_m=f_q(\mathbf{x}_m,m);\quad f_q(\mathbf{x}_m,m)=e^{im\theta}\mathbf{W}_q\mathbf{x}_m$
$\mathbf{k}_n=f_k(\mathbf{x}_n,n);\quad f_k(\mathbf{x}_n,n)=e^{in\theta}\mathbf{W}_k\mathbf{x}_n$
- $\mathbf{q}_m$: query representation at position $m$
- $\mathbf{x}_m$: input embedding vector at position $m$
- $m$, $n$: position indices
- $f_q$: function that incorporates position information into the word embedding and transforms it into the query representation
- $i$: imaginary unit
- $\theta$: per-dimensional rotation angle
- $\mathbf{W}_q$, $\mathbf{W}_k$: query and key projection matrices
- $\mathbf{k}_n$: key representation at position $n$
- $f_k$: function that incorporates position information into the word embedding and transforms it into the key representation
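To make the rotation concrete, here is a minimal NumPy sketch of the complex-rotation form above; the function name, the 8-dimensional vectors, and the default base of 10000 are illustrative assumptions, not details from the paper.

```python
import numpy as np

def rope_rotate(x, m, theta_base=10000.0):
    """Rotate vector x (even length d) to position m via RoPE."""
    d = x.shape[0]
    i = np.arange(d // 2)
    theta = theta_base ** (-2 * i / d)      # per-dimension rotation angles
    z = x[0::2] + 1j * x[1::2]              # pair (x_{2i}, x_{2i+1}) as complex
    z_rot = z * np.exp(1j * m * theta)      # multiply by e^{i m theta}
    out = np.empty_like(x)
    out[0::2], out[1::2] = z_rot.real, z_rot.imag
    return out

# The attention logit between a query at position m and a key at position n
# then depends only on the relative offset m - n.
rng = np.random.default_rng(0)
x, y = rng.normal(size=8), rng.normal(size=8)
logit = rope_rotate(x, 5) @ rope_rotate(y, 3) / np.sqrt(8)
```

The relative-position property is what makes RoPE attractive for context extension: shifting both positions by the same amount leaves the attention logit unchanged.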
The attention weights are computed as:
$\mathrm{softmax}\left(\frac{\mathbf{q}_m^{\top}\mathbf{k}_n}{\sqrt{d}}\right)$
- $\mathbf{q}_m$: query representation at position $m$
- $\mathbf{k}_n$: key representation at position $n$
- $d$: attention head dimension
The per-dimensional rotation angle $\theta_i$ for RoPE dimension $i$ is defined as:
$\theta_i=\theta_{base}^{-2i/d},\quad i\in\{0,1,\ldots,d/2-1\}$
- $i$: RoPE dimension (pair) index
- $\theta_i$: per-dimensional rotation angle for dimension $i$
- $\theta_{base}$: a predefined RoPE base value
The corresponding period length can be calculated as:
$T_i=\frac{2\pi}{\theta_i}=2\pi\,\theta_{base}^{2i/d}$
- $T_i$: the period length of dimension $i$
- $\theta_i$: per-dimensional rotation angle for dimension $i$
The theoretical critical dimension can be computed as:
$d_{\text{tcd}}=2\left\lceil \frac{d}{2}\log_{\theta_{base}}\frac{L_{\text{train}}}{2\pi}\right\rceil$
- $d_{\text{tcd}}$: theoretical critical dimension
- $d$: attention head dimension
- $\theta_{base}$: a predefined RoPE base value
- $L_{\text{train}}$: input sequence length used in pre-training
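Plugging illustrative LLaMA3-8B-style values into the definitions above (head dimension $d=128$, $\theta_{base}=500{,}000$, pre-training length 8192; these specific numbers are assumptions for this sketch, not figures quoted from the paper):

```python
import math

d, theta_base, L_train = 128, 500_000.0, 8192

def theta(i):
    return theta_base ** (-2 * i / d)    # rotation angle of dimension pair i

def period(i):
    return 2 * math.pi / theta(i)        # tokens needed for one full rotation

# theoretical critical dimension: 70 for these values
d_tcd = 2 * math.ceil((d / 2) * math.log(L_train / (2 * math.pi), theta_base))

# Dimension pairs at or beyond d_tcd / 2 have periods longer than L_train,
# so they never complete a full rotation during pre-training -- the
# "insufficiently trained higher dimensions" the paper targets.
undertrained = period(d_tcd // 2) > L_train
```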
The rescaled per-dimensional rotation angle is defined as:
$\hat\theta_i=\frac{1}{\lambda_i\times\theta_{base}^{2i/d}}=\frac{\theta_i}{\lambda_i}$
- $\hat\theta_i$: rescaled per-dimensional rotation angle
- $\lambda_i$: rescaling factor for the $i$-th RoPE dimension
- $\theta_{base}$: a predefined RoPE base value
- $d$: attention head dimension
The constraint to avoid OOD is defined as:
$\lambda_i \ge \frac{L}{L_{\text{train}}},\quad \forall\, i \ge \frac{d_{\text{tcd}}}{2}$
- $\lambda_i$: rescaling factor for the $i$-th RoPE dimension
- $L$: target context window size
- $L_{\text{train}}$: pre-trained context window size
- $d_{\text{tcd}}$: theoretical critical dimension

That is, dimensions at or beyond the critical dimension must be scaled at least by the extension ratio, so that the rotation reached at the target length stays within the range seen during pre-training.
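Under the same illustrative values as before, the rescaled angle and the OOD constraint can be checked numerically; the 8k-to-128k extension and the toy choice of factors are assumptions for this sketch.

```python
import math

d, theta_base = 128, 500_000.0
L_train, L_target = 8192, 131_072
s = L_target / L_train                  # minimum scale for higher dims (16x)
d_tcd = 2 * math.ceil((d / 2) * math.log(L_train / (2 * math.pi), theta_base))

def theta(i):
    return theta_base ** (-2 * i / d)

def theta_hat(i, lam):
    return theta(i) / lam               # rescaled per-dimension angle

for i in range(d // 2):
    lam = s if 2 * i >= d_tcd else 1.0  # toy factors: rescale only higher dims
    if 2 * i >= d_tcd:
        # rotation reached at the target length must not exceed the
        # rotation range covered during pre-training
        assert theta_hat(i, lam) * L_target <= theta(i) * L_train * (1 + 1e-12)
```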
The evolutionary search identifies the real critical dimension $d_{\text{rcd}}$ and the optimal rescaling factors using the following steps:
- Initialize $d_{\text{rcd}}$ and the rescaling factors $\lambda_i$
- Generate evaluation documents of the target context length
- Compute the needle-PPL for each candidate by applying its rescaling factors to the LLM and evaluating it on the generated documents.
The theta base for $d_{\text{rcd}}$ is updated after each mutation, and NTK scaling is applied to the rescaling factors in the lower-dimension group.
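A highly simplified sketch of such a search loop, with needle-PPL stubbed out by a toy fitness function (the real evaluation requires running the LLM over long documents, and the population sizes, mutation scheme, and target values below are invented for illustration):

```python
import random

def needle_ppl(d_rcd, factors):
    """Stub for needle-driven PPL: the paper scores perplexity only on the
    'needle' answer tokens of long documents; a toy fitness stands in here."""
    return sum((f - 16.0) ** 2 for f in factors) + abs(d_rcd - 70)

def evolutionary_search(d=128, s=16.0, pop_size=8, generations=30, seed=0):
    rng = random.Random(seed)

    def constrain(d_rcd, factors):
        # dims at/beyond d_rcd must scale at least by the extension ratio s
        return d_rcd, [max(f, s) if 2 * i >= d_rcd else f
                       for i, f in enumerate(factors)]

    def init():
        d_rcd = rng.randrange(60, 90, 2)           # candidate critical dim
        return constrain(d_rcd, [rng.uniform(1.0, 2 * s)
                                 for _ in range(d // 2)])

    def mutate(cand):
        d_rcd, factors = cand
        d_rcd = max(2, d_rcd + rng.choice([-2, 0, 2]))
        return constrain(d_rcd, [max(1.0, f + rng.gauss(0, 0.5))
                                 for f in factors])

    pop = [init() for _ in range(pop_size)]
    for _ in range(generations):
        pop.sort(key=lambda c: needle_ppl(*c))     # lower PPL is better
        parents = pop[: pop_size // 2]
        pop = parents + [mutate(p) for p in parents]
    return min(pop, key=lambda c: needle_ppl(*c))

best_d_rcd, best_factors = evolutionary_search()
```

The key design point carried over from the paper is that the OOD constraint is enforced on every candidate, so the search only explores factors that keep higher-dimension rotations in-distribution.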
The paper presents experiments on LLaMA3-8B and Phi3-mini-3.8B. The models were extended to a 128k context window and mid-trained on 64 A100 GPUs using a 10B-token dataset. Baselines include state-of-the-art RoPE rescaling methods such as YaRN, NTK, and LongRoPE.
The evaluation included:
- Long-context stress tests, including RULER and Needle in a Haystack
- Real-world long-context benchmarks including LOFT, InfiniteBench, and LongBench
- Standard benchmarks within a 4096-token context.
Key results include:
- LLMpresso consistently outperformed prior methods on RULER, achieving superior results across all evaluation lengths within the 128k window
- LLMpresso achieves near-perfect accuracy across all evaluation lengths in the Needle in a Haystack test
- On real-world benchmarks, LLMpresso consistently improves performance across the board, demonstrating strong generalization to practical scenarios
Ablation studies validated:
- The effectiveness of real critical dimension
- The effectiveness of needle-PPL guided search
- The effectiveness of mixed context window training
The authors conclude by noting that LLMpresso uses evolutionary search-guided rescaling and mixed context window training to achieve a 128k effective context length with just 10B tokens, retaining 97.6% of the original short-context performance.