Trained Transformers Learn Linear Models In-Context (2306.09927v3)
Abstract: Attention-based neural networks such as transformers have demonstrated a remarkable ability to exhibit in-context learning (ICL): Given a short prompt sequence of tokens from an unseen task, they can formulate relevant per-token and next-token predictions without any parameter updates. By embedding a sequence of labeled training data and unlabeled test data as a prompt, transformers can behave like supervised learning algorithms. Indeed, recent work has shown that when training transformer architectures over random instances of linear regression problems, these models' predictions mimic those of ordinary least squares. Towards understanding the mechanisms underlying this phenomenon, we investigate the dynamics of ICL in transformers with a single linear self-attention layer trained by gradient flow on linear regression tasks. We show that despite non-convexity, gradient flow with a suitable random initialization finds a global minimum of the objective function. At this global minimum, when given a test prompt of labeled examples from a new prediction task, the transformer achieves prediction error competitive with the best linear predictor over the test prompt distribution. We additionally characterize the robustness of the trained transformer to a variety of distribution shifts and show that although a number of shifts are tolerated, shifts in the covariate distribution of the prompts are not. Motivated by this, we consider a generalized ICL setting where the covariate distributions can vary across prompts. We show that although gradient flow succeeds at finding a global minimum in this setting, the trained transformer is still brittle under mild covariate shifts. We complement this finding with experiments on large, nonlinear transformer architectures, which we show are more robust under covariate shifts.
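To make the setting concrete, the sketch below illustrates one common way to embed a linear regression task as a prompt and run it through a single linear (softmax-free) self-attention layer, in the spirit of the setup described in the abstract. The variable names, the (d+1)-dimensional token embedding, and the 1/n normalization are illustrative assumptions rather than the paper's verbatim construction; the parameters here are random and untrained, so the snippet only exercises the forward pass.

```python
# Minimal sketch of in-context linear regression with one linear self-attention
# (LSA) layer. Assumptions (not taken verbatim from the paper): each token is the
# (d+1)-vector (x_i, y_i), the query token carries label 0, and attention is
# normalized by the number of labeled examples n.
import numpy as np

rng = np.random.default_rng(0)
d, n = 5, 20  # covariate dimension, number of labeled in-context examples


def sample_prompt(d, n, rng):
    """Sample one linear regression task and embed it as a prompt matrix."""
    w = rng.standard_normal(d)              # task-specific linear predictor
    X = rng.standard_normal((n + 1, d))     # n labeled examples + 1 query covariate
    y = X @ w                                # noiseless labels y_i = <w, x_i>
    E = np.zeros((d + 1, n + 1))             # column i is the token (x_i, y_i)
    E[:d, :] = X.T
    E[d, :n] = y[:n]                         # query label slot left at 0 (unlabeled)
    return E, y[n]                           # prompt embedding, target for the query


def lsa_predict(E, W_kq, W_pv, n):
    """One linear self-attention layer; the prediction is read off the label slot
    of the query (last) column of the output."""
    attn = W_pv @ E @ (E.T @ W_kq @ E) / n   # linear attention, no softmax
    out = E + attn                            # residual connection
    return out[-1, -1]


# Random, untrained parameters just to run the forward pass.
W_kq = rng.standard_normal((d + 1, d + 1)) / np.sqrt(d + 1)
W_pv = rng.standard_normal((d + 1, d + 1)) / np.sqrt(d + 1)

E, y_query = sample_prompt(d, n, rng)
print("prediction:", lsa_predict(E, W_kq, W_pv, n), "target:", y_query)
```

Under the training described in the abstract (gradient flow on the population loss over random linear regression prompts), the learned W_kq and W_pv would make this prediction competitive with the best linear predictor on the test prompt distribution; here they are left random purely for illustration.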