Exact Conversion of In-Context Learning to Model Weights in Linearized-Attention Transformers
(2406.02847v2)
Published 5 Jun 2024 in cs.LG and stat.ML
Abstract: In-Context Learning (ICL) has been a powerful emergent property of LLMs that has attracted increasing attention in recent years. In contrast to regular gradient-based learning, ICL is highly interpretable and does not require parameter updates. In this paper, we show that, for linearized transformer networks, ICL can be made explicit and permanent through the inclusion of bias terms. We mathematically demonstrate the equivalence between a model with ICL demonstration prompts and the same model with the additional bias terms. Our algorithm (ICLCA) allows for exact conversion in an inexpensive manner. Existing methods are not exact and require expensive parameter updates. We demonstrate the efficacy of our approach through experiments that show the exact incorporation of ICL tokens into a linear transformer. We further suggest how our method can be adapted to achieve cheap approximate conversion of ICL tokens, even in regular transformer networks that are not linearized. Our experiments on GPT-2 show that, even though the conversion is only approximate, the model still gains valuable context from the included bias terms.
This paper addresses the limitations of In-Context Learning (ICL) in LLMs, specifically that ICL information is temporary and requires computationally expensive re-prompting for each inference. Existing methods to make ICL permanent, such as context distillation, rely on costly gradient-based fine-tuning and lack theoretical guarantees. The authors propose a novel approach that leverages the structure of linearized attention mechanisms to achieve exact conversion of ICL prompts into model weights.
The core insight is that in linearized attention models, the contribution of past tokens to the attention output can be factored into a "Key-Value matrix" that is a sum over previous token representations. When ICL tokens are prepended to the input sequence, they modify this Key-Value matrix for all subsequent input tokens. The paper demonstrates that, for linearized attention layers satisfying certain assumptions (autoregressive property, relative positional encoding such as RoPE, and a separation property for the normalization term), the effect of prepending ICL tokens can be captured exactly by adding bias terms to the linearized attention layer.
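To make the Key-Value factorization concrete, here is a minimal numpy sketch of causal linearized attention; the elu-based feature map and the absence of positional rotation are simplifying assumptions for illustration, not choices made by the paper.

```python
import numpy as np

def phi(x):
    # A common positive feature map, elu(x) + 1; the paper only assumes
    # some kernel feature map, so this specific choice is an assumption.
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V):
    """Causal linearized attention. Each position t attends to tokens j <= t
    through a running Key-Value matrix S = sum_j phi(K_j) V_j^T and a
    normalizer z = sum_j phi(K_j)."""
    T, d_k = K.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))     # the "Key-Value matrix"
    z = np.zeros(d_k)            # normalization term
    out = np.zeros((T, d_v))
    for t in range(T):
        S += np.outer(phi(K[t]), V[t])
        z += phi(K[t])
        out[t] = phi(Q[t]) @ S / (phi(Q[t]) @ z + 1e-9)
    return out
```

Because the output at position t depends on the prefix only through (S, z), prepending an ICL prompt simply shifts the starting values of these sums, which is what the bias terms below capture.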
The proposed method, called the ICL Conversion Algorithm (ICLCA), calculates the specific contribution of the ICL tokens to both the Key-Value sum and the normalization factor in each linearized attention layer. These contributions are then folded into new bias terms ($b_{KV}'$ and $b_D'$) that replace the existing biases ($b_{KV}$ and $b_D$). The updated biases are
$$b_{KV}' = R^{d_K}_{\Theta,-M}\, b_{KV} + \sum_{j=1}^{M} R^{d_K}_{\Theta,\, j-M}\, \phi(K_j')\, {V_j'}^{T} \qquad \text{and} \qquad b_D' = b_D + D_2^{*}(X'),$$
where $X'$ is the ICL prompt, $M$ is its length, $\phi$ is the kernel feature map, $K'$ and $V'$ are the key and value projections of $X'$, $R^{d_K}_{\Theta,m}$ is the RoPE rotation matrix, and $D_2^{*}$ is a function related to the normalizing term. With these biases in place, the model exactly replicates the effect of having the ICL prompt $X'$ prepended to the input $X$.
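As an illustration of the update itself, the sketch below drops the RoPE rotation, so the bias changes reduce to adding the ICL prompt's contribution to the Key-Value sum and to the normalizer; treating the normalizer contribution $D_2^{*}(X')$ as a simple sum of key features is an assumption made only for this sketch.

```python
import numpy as np

def iclca_bias_update(K_icl, V_icl, b_KV, b_D, phi):
    """Simplified ICLCA-style update for one linearized attention layer,
    with the RoPE rotation omitted. K_icl and V_icl are the key/value
    projections of the M ICL prompt tokens, shapes (M, d_k) and (M, d_v)."""
    feats = phi(K_icl)                   # (M, d_k) features of the ICL keys
    b_KV_new = b_KV + feats.T @ V_icl    # += sum_j phi(K'_j) V'_j^T
    b_D_new = b_D + feats.sum(axis=0)    # += the prompt's normalizer contribution
    return b_KV_new, b_D_new
```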
The ICLCA is computationally inexpensive: it requires a single pass through the model with the ICL prompt to calculate the necessary terms, followed by a linear update to the bias parameters. The added bias terms amount to a small fraction (around 1%) of the total model parameters, making storage efficient. This contrasts sharply with fine-tuning methods that require many gradient steps and potentially large datasets. The conversion is exact: given the input without the ICL prefix, the model with the updated biases produces the same output as the original model given the input with the ICL prompt prepended.
For standard transformers with softmax attention, the authors propose an approximate conversion method (ICLAA). This method first approximates the softmax attention using a kernel function (similar to methods used in linear transformers like Performers (Choromanski et al., 2020)). It then applies the logic of ICLCA to this approximated linearized attention to calculate bias terms corresponding to the ICL prompt. These biases are then introduced back into the original softmax attention mechanism in an approximate manner. The accuracy of this approximate conversion depends on how well the chosen kernel approximates the softmax function.
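One concrete option for such a kernel is the positive random-feature map used by Performers; the sketch below shows that construction as an assumed example, not the specific $\phi$ evaluated in the paper.

```python
import numpy as np

def performer_features(x, W):
    """Positive random features (FAVOR+, Choromanski et al., 2020) such that
    E[phi(q) . phi(k)] approximates the softmax kernel exp(q . k / sqrt(d)).
    W is an (m, d) matrix whose rows are drawn from N(0, I)."""
    d = x.shape[-1]
    x = x / d**0.25                                   # fold in the 1/sqrt(d) temperature
    proj = x @ W.T                                    # (T, m) random projections
    return np.exp(proj - 0.5 * np.sum(x * x, axis=-1, keepdims=True)) / np.sqrt(W.shape[0])

# Example usage: fix one random projection matrix per model and reuse it as phi.
# rng = np.random.default_rng(0); W = rng.standard_normal((256, 64))
# phi = lambda x: performer_features(x, W)
```

How faithfully such a map reproduces the pretrained model's softmax attention is exactly the approximation error the ICLAA results quantify.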
The paper provides experimental validation:
Exact Conversion Verification: Experiments on linear attention transformers with RoPE of various sizes (up to 1.98B parameters) demonstrate that the converted model's logits match those of the original model given the ICL prefix, up to numerical rounding errors (relative errors on the order of $10^{-6}$).
Induction Head Task: On a synthetic task designed to test ICL, a converted linear attention model achieved 99.95% in-context accuracy, exactly matching the performance of the original model with the ICL prompt and significantly outperforming the original model without the prompt (2.23% accuracy). Training a model architecture that includes the proposed bias terms from scratch also converged faster than training without them.
Approximate Conversion (GPT-2): Experiments on a pretrained GPT-2 model (with softmax attention) showed that the approximate conversion reduced the relative error of the output logits, measured against the original model run with the ICL prompt, from 16.56% without conversion to 9.17% after conversion. Generated text examples qualitatively suggest that the converted model gains valuable context from the added biases.
From an implementation perspective, ICLCA provides a direct method to permanently inject context into linear attention models by simply computing and adding specific bias terms per attention layer (and potentially per head in multi-head attention). Algorithm 1 provides the step-by-step process. For ICLAA on softmax models (Algorithm 2), the key challenge lies in choosing an appropriate kernel approximation $\phi$ that works well for the specific pretrained model. Once $\phi$ is chosen, the bias calculation follows a similar pattern based on the ICL tokens and the model's learned weights.
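The following schematic sketch mirrors the flow of Algorithm 1 under the simplifications above, reusing `iclca_bias_update` from the earlier sketch; the model and layer attribute names (`embed`, `project_kv`, `b_KV`, `b_D`, `forward`) are hypothetical stand-ins rather than the paper's or any library's API.

```python
def convert_icl_to_biases(model, icl_tokens, phi):
    """Schematic sketch of the ICLCA flow: a single forward pass over the ICL
    prompt, folding each layer's contribution into that layer's attention
    biases. All attribute names are hypothetical, for illustration only."""
    hidden = model.embed(icl_tokens)
    for layer in model.layers:
        K_icl, V_icl = layer.project_kv(hidden)          # keys/values of the ICL tokens
        layer.b_KV, layer.b_D = iclca_bias_update(       # gradient-free bias update
            K_icl, V_icl, layer.b_KV, layer.b_D, phi)
        hidden = layer.forward(hidden)                   # continue the single pass
    return model
```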
The paper highlights the potential for this method to save computation during inference by eliminating the need to repeatedly process long ICL prompts. It also suggests applications in ICL-guided fine-tuning, where the bias conversion could serve as a cheap initialization step. Furthermore, the connection of linearized transformers to RNNs, where ICL conversion is equivalent to modifying RNN hidden state initialization, opens possibilities for extending this approach to emerging RNN-based architectures like Mamba (Gu et al., 2023).
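In that RNN view, the converted biases are nothing more than an initial hidden state; a sketch continuing the linearized-attention example above (RoPE again omitted) makes this explicit.

```python
import numpy as np

def linear_attention_from_state(Q, K, V, b_KV, b_D, phi):
    """Linearized attention as an RNN over the state (S, z): converting the
    ICL prompt amounts to starting from (b_KV, b_D) instead of zeros."""
    T, d_v = V.shape
    S, z = b_KV.copy(), b_D.copy()       # ICL-derived initial hidden state
    out = np.zeros((T, d_v))
    for t in range(T):
        S += np.outer(phi(K[t]), V[t])
        z += phi(K[t])
        out[t] = phi(Q[t]) @ S / (phi(Q[t]) @ z + 1e-9)
    return out
```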