In-Context Learning of a Linear Transformer Block: Benefits of the MLP Component and One-Step GD Initialization (2402.14951v1)
Abstract: We study the \emph{in-context learning} (ICL) ability of a \emph{Linear Transformer Block} (LTB) that combines a linear attention component and a linear multi-layer perceptron (MLP) component. For ICL of linear regression with a Gaussian prior and a \emph{non-zero mean}, we show that LTB can achieve nearly Bayes optimal ICL risk. In contrast, any estimator that uses only linear attention must incur an irreducible, additive approximation error. Furthermore, we establish a correspondence between LTB and one-step gradient descent estimators with learnable initialization ($\mathsf{GD}\text{-}\mathbf{\beta}$), in the sense that every $\mathsf{GD}\text{-}\mathbf{\beta}$ estimator can be implemented by an LTB estimator and every optimal LTB estimator that minimizes the in-class ICL risk is effectively a $\mathsf{GD}\text{-}\mathbf{\beta}$ estimator. Finally, we show that $\mathsf{GD}\text{-}\mathbf{\beta}$ estimators can be efficiently optimized with gradient flow, despite a non-convex training objective. Our results reveal that LTB achieves ICL by implementing $\mathsf{GD}\text{-}\mathbf{\beta}$, and they highlight the role of MLP layers in reducing approximation error.
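To make the $\mathsf{GD}\text{-}\mathbf{\beta}$ correspondence concrete, below is a minimal NumPy sketch of a one-step gradient descent estimator with a learnable initialization $\beta$ and a learnable preconditioner; the function name `gd_beta_predict`, the matrix `Gamma`, and the $1/n$ gradient scaling are illustrative assumptions, not the paper's exact parameterization.

```python
# Minimal sketch (not the paper's code) of a GD-beta style estimator for
# in-context linear regression: one preconditioned gradient descent step on
# the context examples, started from a learnable initialization beta.
# Gamma (a learnable preconditioner absorbing the step size) and the 1/n
# scaling are assumptions made here for illustration.
import numpy as np

def gd_beta_predict(X, y, x_query, beta, Gamma):
    """One-step GD prediction with learnable initialization.

    X       : (n, d) context inputs
    y       : (n,)   context labels
    x_query : (d,)   query input
    beta    : (d,)   learnable initialization of the weight vector
    Gamma   : (d, d) learnable preconditioner
    """
    n = X.shape[0]
    residual = y - X @ beta           # errors of the initial guess on the context
    grad = -(X.T @ residual) / n      # gradient of 0.5 * mean squared error at beta
    w_one_step = beta - Gamma @ grad  # single preconditioned GD step
    return x_query @ w_one_step       # linear prediction at the query

# Toy usage: a task whose weight prior has a non-zero mean.
rng = np.random.default_rng(0)
d, n = 5, 32
w_star = np.ones(d) + 0.1 * rng.standard_normal(d)  # task weights near the prior mean
X = rng.standard_normal((n, d))
y = X @ w_star + 0.01 * rng.standard_normal(n)
x_q = rng.standard_normal(d)

beta = np.ones(d)        # initialization at the (assumed) prior mean
Gamma = 0.5 * np.eye(d)  # simple isotropic preconditioner
print(gd_beta_predict(X, y, x_q, beta, Gamma), x_q @ w_star)
```

Intuitively, fixing $\beta = 0$ collapses this to a plain one-step GD predictor, while a learnable non-zero $\beta$ lets the estimator exploit the non-zero prior mean, in line with the abstract's point that the MLP component removes the otherwise irreducible approximation error of linear attention alone.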
- Ruiqi Zhang
- Jingfeng Wu
- Peter L. Bartlett