Linear Transformers are Versatile In-Context Learners (2402.14180v2)

Published 21 Feb 2024 in cs.LG

Abstract: Recent research has demonstrated that transformers, particularly linear attention models, implicitly execute gradient-descent-like algorithms on data provided in-context during their forward inference step. However, their capability in handling more complex problems remains unexplored. In this paper, we prove that each layer of a linear transformer maintains a weight vector for an implicit linear regression problem and can be interpreted as performing a variant of preconditioned gradient descent. We also investigate the use of linear transformers in a challenging scenario where the training data is corrupted with different levels of noise. Remarkably, we demonstrate that for this problem linear transformers discover an intricate and highly effective optimization algorithm, surpassing or matching in performance many reasonable baselines. We analyze this algorithm and show that it is a novel approach incorporating momentum and adaptive rescaling based on noise levels. Our findings show that even linear transformers possess the surprising ability to discover sophisticated optimization strategies.

This paper, "Linear Transformers are Versatile In-Context Learners" (Vladymyrov et al., 2024), investigates the in-context learning capabilities of linear transformers, focusing on their ability to implicitly implement sophisticated optimization algorithms. The core finding is that even simple linear transformers, when trained on data provided within the input sequence, learn to solve the task by executing a variant of gradient descent, and this capability extends to discovering more complex, adaptive algorithms for challenging problems such as noisy linear regression with mixed noise levels.

The paper establishes theoretically that any linear transformer layer maintains an implicit linear model of the input data $(x_i, y_i)$. Specifically, if the input tokens to layer $\ell$ are $(x_i^{\ell}, y_i^{\ell})$, the output tokens $(x_i^{\ell+1}, y_i^{\ell+1})$ can be expressed as a linear transformation of the original input $(x_i, y_i)$:

$$x_i^{\ell+1} = M^{\ell} x_i + y_i u^{\ell},$$

$$y_i^{\ell+1} = a^{\ell} y_i - \langle w^{\ell}, x_i \rangle,$$

for some matrix $M^{\ell}$, vectors $u^{\ell}, w^{\ell}$, and scalar $a^{\ell}$. These parameters $(M^{\ell}, u^{\ell}, a^{\ell}, w^{\ell})$ are not static weights but are recursively updated based on the layer's attention parameters $(P_k^\ell, Q_k^\ell)$ and aggregated statistics of the current layer's input tokens $(x_j^\ell, y_j^\ell)$, such as $\sum_j x_j^\ell (x_j^\ell)^\top$, $\sum_j y_j^\ell x_j^\ell$, etc. The prediction for a query $x_t$ is read off the final layer's output token $(x_t^L, y_t^L)$ as $-y_t^L$.

This implies that the linear transformer's forward pass on a sequence of data points effectively executes an iterative algorithm, where each layer represents a step. The algorithm is learning to find parameters $(a^L, w^L)$ of a linear model $y_t^L = a^L y_t - \langle w^L, x_t \rangle$ for the query. Since the query token has $y_t = 0$, the prediction $-y_t^L$ simplifies to $\langle w^L, x_t \rangle$, meaning the network is learning a weight vector $w^L$ to predict $y_t$ from $x_t$.
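
To make this iterative-algorithm view concrete, here is a minimal NumPy sketch of a forward pass over in-context tokens $e_i = (x_i, y_i)$, with the prediction read off as the negated $y$-coordinate of the query token. The update rule $e_i \leftarrow e_i + \tfrac{1}{n}\sum_k P_k (\sum_j e_j e_j^\top) Q_k e_i$ is one common parameterization of multi-head linear attention; the paper's exact normalization and parameter layout may differ.

```python
import numpy as np

def linear_attention_layer(E, P_heads, Q_heads):
    """One linear self-attention layer over tokens E of shape (n, d+1).

    Assumed update: e_i <- e_i + (1/n) * sum_k P_k (sum_j e_j e_j^T) Q_k e_i.
    """
    n = E.shape[0]
    A = E.T @ E / n                      # aggregated token statistics
    update = np.zeros_like(E)
    for P, Q in zip(P_heads, Q_heads):   # sum over heads
        update += E @ (P @ A @ Q).T      # apply P_k A Q_k to every token
    return E + update

def predict(E, layers):
    """Run the layer stack and read off the query prediction as -y_t^L.

    `layers` is a list of (P_heads, Q_heads) tuples; the last row of E is
    the query token (x_t, 0).
    """
    for P_heads, Q_heads in layers:
        E = linear_attention_layer(E, P_heads, Q_heads)
    return -E[-1, -1]
```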

A key practical consideration explored is the use of diagonal attention matrices $(Q_k^\ell, P_k^\ell)$ for each head. This restriction significantly simplifies the architecture and reduces per-layer computation (linear attention already scales as $O(N)$ rather than the $O(N^2)$ of softmax attention in the sequence length $N$). The paper shows that for diagonal attention, the layer updates can be re-parameterized by four scalar variables: $l_{xx}, l_{xy}, l_{yx}, l_{yy}$. These variables effectively control the flow of information between the $x$ and $y$ components across layers. Despite this simplification, the diagonal linear transformer (Diag) maintains significant power.
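
As an illustration of this four-scalar reparameterization, the sketch below implements one plausible realization of a diagonal layer, in which the scalars gate how the per-sequence statistics $\Sigma_{xx} = \tfrac{1}{n}\sum_j x_j x_j^\top$, $\Sigma_{xy} = \tfrac{1}{n}\sum_j y_j x_j$, and $\Sigma_{yy} = \tfrac{1}{n}\sum_j y_j^2$ update each token; the paper's exact update rule (signs, normalization, per-head structure) may differ.

```python
import numpy as np

def diag_layer(X, Y, l_xx, l_xy, l_yx, l_yy):
    """Diagonal linear-attention layer reparameterized by four scalars (sketch).

    X: (n, d) x-components, Y: (n,) y-components of the current layer's tokens.
    Each scalar gates how one block of the aggregated statistics updates the
    x- or y-component of every token.
    """
    n = X.shape[0]
    Sxx = X.T @ X / n            # (d, d)  ~ sum_j x_j x_j^T
    Sxy = X.T @ Y / n            # (d,)    ~ sum_j y_j x_j
    Syy = Y @ Y / n              # scalar  ~ sum_j y_j^2
    X_new = X + l_xx * X @ Sxx + l_xy * np.outer(Y, Sxy)
    Y_new = Y + l_yx * X @ Sxy + l_yy * Syy * Y
    return X_new, Y_new
```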

The paper relates the learned algorithm to gradient descent. The $\text{GD}^{++}$ model (von Oswald et al., 2023), a restricted diagonal linear transformer, is shown to implement a form of preconditioned gradient descent. For standard least squares problems, the paper proves $\text{GD}^{++}$ can reach error $\epsilon$ in $O(\log \kappa + \log\log 1/\epsilon)$ steps, where $\kappa$ is the condition number of the data covariance matrix, suggesting second-order optimization behavior similar to Newton's method. This indicates that simple linear transformers can learn efficient algorithms for well-defined problems.
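
A rough sketch of a $\text{GD}^{++}$-flavored iteration for least squares is given below. The explicit form is an assumption based on the description above: a plain gradient step on the implicit weights interleaved with a shrinking transform of the data that acts as a preconditioner. The step sizes `eta` and `gamma` are hypothetical, and the paper's layer realizes this through residual token updates rather than an explicit loop.

```python
import numpy as np

def gd_plus_plus_predict(X, y, x_query, steps=5, eta=0.5, gamma=0.1):
    """Preconditioned-GD sketch: gradient steps on w interleaved with a data
    transform X <- X (I - gamma * X^T X / n) that flattens the covariance
    spectrum.  The query is transformed in lock-step so the final prediction
    w @ x_query still refers to the original regression problem.
    """
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(steps):
        w -= eta * X.T @ (X @ w - y) / n          # gradient step on w
        T = np.eye(d) - gamma * (X.T @ X) / n     # preconditioning transform
        X, x_query = X @ T, T @ x_query
    return w @ x_query
```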

The most compelling demonstration of linear transformers' versatility comes from the experiments on noisy linear regression with mixed noise variance. In this setup, each training sequence has data generated with a different noise level $\sigma_\tau^2$, drawn from a distribution (e.g., uniform $U(0, \sigma_{\max})$ or categorical). The optimal solution for a known $\sigma_\tau$ is ridge regression, $w^*_{\sigma^2} = (\Sigma + \sigma_\tau^2 I)^{-1}\alpha$. However, the noise level is unknown to the model and varies per sequence. The linear transformer must learn an in-context algorithm that adapts to the noise level of the current sequence to make good predictions.
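
For concreteness, here is a sketch of the mixed-noise task setup and the noise-aware ridge oracle it is compared against; the sampling distributions and the normalization of $\Sigma$ and $\alpha$ are assumptions rather than the paper's exact choices.

```python
import numpy as np

def sample_task(n=20, d=8, sigma_max=2.0, rng=None):
    """One in-context sequence: a fresh w*, a hidden noise level sigma drawn
    from U(0, sigma_max), and n noisy observations (x_i, y_i)."""
    if rng is None:
        rng = np.random.default_rng(0)
    w_star = rng.normal(size=d)
    sigma = rng.uniform(0.0, sigma_max)          # unknown to the model
    X = rng.normal(size=(n, d))
    y = X @ w_star + sigma * rng.normal(size=n)
    return X, y, sigma

def ridge_oracle(X, y, sigma):
    """Oracle ridge estimate (X^T X + sigma^2 I)^{-1} X^T y, i.e. the paper's
    (Sigma + sigma^2 I)^{-1} alpha up to normalization; it needs the true
    per-sequence noise level, which the transformer never sees."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + sigma**2 * np.eye(d), X.T @ y)
```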

The paper's reverse-engineering of the learned algorithm in the Diag model reveals mechanisms for noise adaptation:

  1. The $l_{yy}$ term can lead to adaptive rescaling of the $y$ component based on the $y$ norms ($\lambda$), which are correlated with noise levels. A negative $l_{yy}$ effectively shrinks predictions more when the noise is higher, mirroring ridge regression, which shrinks the OLS solution more strongly for larger regularization (analogous to higher noise); a tiny numerical illustration of this shrinkage analogy follows the list.
  2. The $l_{xy}$ term influences the step size of the implicit gradient descent. Analysis suggests that it helps create a step size that depends on the residual variance $\sum_i r_i^2$, another quantity correlated with noise. Higher residual variance leads to a smaller effective step size, again consistent with the intuition of effectively 'early stopping' or regularizing more when noise is high.
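
The following toy calculation illustrates the shrinkage analogy in point 1 with standard ridge regression rather than the learned algorithm; $\Sigma$ is set to the identity purely for simplicity, so the absolute numbers are not meaningful, only the monotone trend of stronger shrinkage at higher noise.

```python
import numpy as np

# With Sigma = I, the ridge solution (Sigma + sigma^2 I)^{-1} Sigma w_ols
# shrinks w_ols by a factor 1 / (1 + sigma^2): more shrinkage for noisier data.
w_ols = np.ones(4)
Sigma = np.eye(4)
for sigma in (0.1, 1.0, 3.0):
    w_ridge = np.linalg.solve(Sigma + sigma**2 * np.eye(4), Sigma @ w_ols)
    shrink = np.linalg.norm(w_ridge) / np.linalg.norm(w_ols)
    print(f"sigma={sigma}: shrink factor ~ {shrink:.3f}")
```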

In experiments, both the Full and Diag linear transformers significantly outperform standard ridge-regression baselines (ConstRR, AdaRR) and the simpler $\text{GD}^{++}$ model on mixed noise variance problems. The Diag model performs comparably to the Full model across various $\sigma_{\max}$ values and numbers of layers (up to 7 layers were tested). This is a crucial finding for practical implementation, as the efficiency of Diag makes it much more suitable for longer sequences and larger models.

For implementation:

  • The model architecture is a stack of linear self-attention layers. For the diagonal variant, the attention calculation needs to be specialized to use only diagonal parameter matrices or their equivalent scalar reparameterization $(l_{xx}, l_{xy}, l_{yx}, l_{yy})$.
  • Training involves minimizing the prediction error of $-y_{n+1}^L$ for the query token $x_t$ using standard optimizers such as Adam, on sequences containing multiple $(x_i, y_i)$ pairs followed by the query token $(x_t, 0)$; a minimal training-loop sketch is given after this list.
  • The number of layers $L$ is a hyperparameter corresponding to the number of steps in the learned optimization algorithm. Experiments show performance improves with more layers, but significant gains are seen even with 3-5 layers.
  • The comparison between Diag and Full suggests that for tasks requiring adaptive linear estimation, the computational savings of the diagonal architecture come with minimal performance loss. This is highly relevant for scaling these models to larger data contexts.
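
The sketch below is a minimal, self-contained training loop under the assumptions above (PyTorch, single-head full linear attention, the noisy query label as the regression target). All hyperparameters, such as `dim=9`, `n_layers=5`, the learning rate, and the loss convention, are illustrative choices rather than values from the paper.

```python
import torch

class LinearSelfAttentionLayer(torch.nn.Module):
    """One (single-head) linear self-attention layer over tokens e_i = (x_i, y_i).

    Assumed update: e_i <- e_i + (1/n) P (sum_j e_j e_j^T) Q e_i; the paper's
    parameterization, number of heads, and normalization may differ.
    """
    def __init__(self, dim):
        super().__init__()
        self.P = torch.nn.Parameter(0.01 * torch.randn(dim, dim))
        self.Q = torch.nn.Parameter(0.01 * torch.randn(dim, dim))

    def forward(self, E):                      # E: (batch, n, dim)
        n = E.shape[1]
        A = E.transpose(1, 2) @ E / n          # per-sequence token statistics
        return E + E @ (self.P @ A @ self.Q).transpose(1, 2)

class LinearTransformer(torch.nn.Module):
    def __init__(self, dim, n_layers):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [LinearSelfAttentionLayer(dim) for _ in range(n_layers)])

    def forward(self, E):
        for layer in self.layers:
            E = layer(E)
        return -E[:, -1, -1]                   # prediction = -y of the query token

def make_batch(batch=64, n=20, d=8, sigma_max=2.0):
    """Mixed-noise regression sequences; the query is appended as (x_t, 0)."""
    w = torch.randn(batch, d)
    sigma = sigma_max * torch.rand(batch, 1)
    X = torch.randn(batch, n + 1, d)
    y = (X @ w.unsqueeze(-1)).squeeze(-1) + sigma * torch.randn(batch, n + 1)
    target = y[:, -1].clone()                  # (noisy) query label as the target
    y[:, -1] = 0.0                             # the model sees the query as (x_t, 0)
    return torch.cat([X, y.unsqueeze(-1)], dim=-1), target

model = LinearTransformer(dim=9, n_layers=5)   # dim = d + 1
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(2000):
    E, target = make_batch()
    loss = ((model(E) - target) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```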

The paper demonstrates that linear transformers, even with diagonal constraints, can move beyond simple gradient descent to learn sophisticated, adaptive algorithms directly from data presentation. This ability to learn data-dependent optimization strategies in-context holds promise for applications where models need to quickly adapt to varying data characteristics without explicit retraining or complex meta-learning setups. Future work could explore if this phenomenon generalizes to more complex tasks and model architectures.

References (40)
  1. GPT-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
  2. Transformers learn to implement preconditioned gradient descent for in-context learning. arXiv preprint arXiv:2306.00297, 2023.
  3. What learning algorithm is in-context learning? investigations with linear models. arXiv preprint arXiv:2211.15661, 2022.
  4. In-Context language learning: Architectures and algorithms. arXiv preprint arXiv:2401.12973, 2024.
  5. PaLM 2 technical report. arXiv preprint arXiv:2305.10403, 2023.
  6. Transformers as statisticians: Provable in-context learning with in-context algorithm selection. arXiv preprint arXiv:2306.04637, 2023.
  7. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  8. Data distributional properties drive emergent in-context learning in transformers. Advances in Neural Information Processing Systems, 35:18878–18891, 2022.
  9. Transformers implement functional gradient descent to learn non-linear functions in context. arXiv preprint arXiv:2312.06528, 2023.
  10. Comparison of model selection for regression. Neural computation, 15(7):1691–1714, 2003.
  11. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
  12. Transformers learn higher-order optimization methods for in-context learning: A study with linear models. arXiv preprint arXiv:2310.17086, 2023.
  13. What can transformers learn in-context? a case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
  14. Looped transformers as programmable computers. arXiv preprint arXiv:2301.13196, 2023.
  15. How do transformers learn in-context beyond simple functions? a case study on learning with representations. arXiv preprint arXiv:2310.10616, 2023.
  16. In-context learning creates task vectors. arXiv preprint arXiv:2310.15916, 2023.
  17. In-context convergence of transformers. arXiv preprint arXiv:2310.05249, 2023.
  18. Risks from learned optimization in advanced machine learning systems. arXiv preprint arXiv:1906.01820, 2019.
  19. Mistral 7B. arXiv preprint arXiv:2310.06825, 2023.
  20. Transformers are RNNs: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pp. 5156–5165. PMLR, 2020.
  21. In-context learning in large language models learns label relationships but is not conventional learning. arXiv preprint arXiv:2307.12375, 2023.
  22. Transformers as algorithms: Generalization and stability in in-context learning. In International Conference on Machine Learning, pp. 19565–19594. PMLR, 2023.
  23. One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention. arXiv preprint arXiv:2307.03576, 2023.
  24. Large language models as general pattern machines. arXiv preprint arXiv:2307.04721, 2023.
  25. In-context learning and induction heads. arXiv preprint arXiv:2209.11895, 2022.
  26. Transformers can optimally learn regression mixture models. arXiv preprint arXiv:2311.08362, 2023.
  27. Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning, pp. 9355–9366. PMLR, 2021.
  28. Do pretrained transformers really learn in-context by gradient descent? arXiv preprint arXiv:2310.08540, 2023.
  29. Max-margin token selection in attention mechanism. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  30. Gemini: a family of highly capable multimodal models. arXiv preprint arXiv:2312.11805, 2023.
  31. Scan and snap: Understanding training dynamics and token composition in 1-layer transformer. arXiv preprint arXiv:2305.16380, 2023a.
  32. JoMA: Demystifying multilayer transformers via joint dynamics of MLP and attention. In NeurIPS 2023 Workshop on Mathematics of Modern Machine Learning, 2023b.
  33. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  34. Transformers learn in-context by gradient descent. In International Conference on Machine Learning, pp. 35151–35174. PMLR, 2023a.
  35. Uncovering mesa-optimization algorithms in transformers. arXiv preprint arXiv:2309.05858, 2023b.
  36. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
  37. Larger language models do in-context learning differently. arXiv preprint arXiv:2303.03846, 2023.
  38. Transformers are uninterpretable with myopic methods: a case study with bounded dyck grammars. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  39. Pretraining data mixtures enable narrow model selection capabilities in transformer models. arXiv preprint arXiv:2311.00871, 2023.
  40. Trained transformers learn linear models in-context. arXiv preprint arXiv:2306.09927, 2023.
Authors (4)
  1. Max Vladymyrov
  2. Johannes von Oswald
  3. Mark Sandler
  4. Rong Ge
Citations (8)