JoMA: Demystifying Multilayer Transformers via JOint Dynamics of MLP and Attention (2310.00535v3)
Abstract: We propose Joint MLP/Attention (JoMA) dynamics, a novel mathematical framework for understanding the training procedure of multilayer Transformer architectures. JoMA is obtained by integrating out the self-attention layer in Transformers, yielding modified dynamics for the MLP layers alone. It removes unrealistic assumptions made in previous analyses (e.g., the absence of residual connections) and predicts that, in the presence of nonlinear activations, attention first becomes sparse (to learn salient tokens) and then dense (to learn less salient tokens), while in the linear case it is consistent with existing work showing that attention becomes sparse over time. We leverage JoMA to explain qualitatively how tokens are combined to form hierarchies in multilayer Transformers when the input tokens are generated by a latent hierarchical generative model. Experiments on models trained on real-world datasets (Wikitext2/Wikitext103) and on various pre-trained models (OPT, Pythia) verify our theoretical findings. Code can be found at https://github.com/facebookresearch/luckmatters/tree/yuandong3.
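A central empirical prediction of the abstract is that, with nonlinear activations, attention first becomes sparse (low entropy, focused on salient tokens) and later becomes dense again. Below is a minimal sketch, not taken from the paper or its repository, of how one could probe this on a toy task: a single-head attention layer feeding a ReLU MLP is trained on sequences in which one "salient" token determines the label, and the entropy of the attention weights is logged over training. The data generator, model sizes, and hyperparameters are illustrative assumptions, not the paper's setup.

```python
# Minimal sketch (not the paper's code): track attention entropy while training a
# tiny 1-layer attention+MLP model, to probe the "sparse then dense" prediction.
# The toy task, model sizes, and hyperparameters below are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, SEQ, DIM, CLASSES = 50, 16, 64, 4

def make_batch(bs=256):
    # Distractor tokens come from {CLASSES, ..., VOCAB-1}; one position is
    # overwritten with the label token, the only class-relevant (salient) signal.
    x = torch.randint(CLASSES, VOCAB, (bs, SEQ))
    y = torch.randint(0, CLASSES, (bs,))
    pos = torch.randint(0, SEQ, (bs,))
    x[torch.arange(bs), pos] = y
    return x, y

class OneLayerAttnMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, DIM)
        self.q = nn.Parameter(torch.randn(DIM) / DIM ** 0.5)   # single learned query
        self.key = nn.Linear(DIM, DIM, bias=False)
        self.mlp = nn.Sequential(nn.Linear(DIM, 4 * DIM), nn.ReLU(),
                                 nn.Linear(4 * DIM, CLASSES))

    def forward(self, x):
        h = self.emb(x)                               # (bs, SEQ, DIM)
        scores = self.key(h) @ self.q / DIM ** 0.5    # (bs, SEQ)
        attn = scores.softmax(dim=-1)                 # attention over positions
        ctx = (attn.unsqueeze(-1) * h).sum(dim=1)     # attention-weighted context
        return self.mlp(ctx), attn

model = OneLayerAttnMLP()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(2001):
    x, y = make_batch()
    logits, attn = model(x)
    loss = F.cross_entropy(logits, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 200 == 0:
        # Entropy of the attention distribution; lower entropy = sparser attention.
        ent = -(attn * attn.clamp_min(1e-9).log()).sum(dim=-1).mean().item()
        print(f"step {step:5d}  loss {loss.item():.3f}  attention entropy {ent:.3f}")
```

Whether the entropy curve shows the predicted dip-then-rise will depend on the task mix of salient and less salient tokens; the snippet only illustrates the kind of measurement the paper's experiments report on real models.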
- What learning algorithm is in-context learning? Investigations with linear models. arXiv preprint arXiv:2211.15661, 2022.
- A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pp. 242–252. PMLR, 2019.
- Exploring length generalization in large language models. arXiv preprint arXiv:2207.04901, 2022.
- A convergence analysis of gradient descent for deep linear neural networks. arXiv preprint arXiv:1810.02281, 2018.
- Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In International Conference on Machine Learning, pp. 322–332. PMLR, 2019.
- Transformers as statisticians: Provable in-context learning with in-context algorithm selection. arXiv preprint arXiv:2306.04637, 2023.
- Hidden progress in deep learning: SGD learns parities near the computational limit. Advances in Neural Information Processing Systems, 35:21750–21764, 2022.
- Gradient descent with identity initialization efficiently learns positive definite linear transformations by deep residual networks. In International Conference on Machine Learning, pp. 521–530. PMLR, 2018.
- On the ability and limitations of transformers to recognize formal languages. arXiv preprint arXiv:2009.11264, 2020a.
- On the computational power of transformers and its implications in sequence modeling. arXiv preprint arXiv:2006.09286, 2020b.
- Pythia: A suite for analyzing large language models across training and scaling. In International Conference on Machine Learning, pp. 2397–2430. PMLR, 2023.
- Birth of a transformer: A memory viewpoint. arXiv preprint arXiv:2306.00802, 2023.
- Transformers learn through gradual rank increase. arXiv preprint arXiv:2306.07042, 2023.
- Globally optimal gradient descent for a ConvNet with Gaussian inputs. In International Conference on Machine Learning, pp. 605–614. PMLR, 2017.
- On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in Neural Information Processing Systems, 31, 2018.
- On lazy training in differentiable programming. Advances in Neural Information Processing Systems, 32, 2019.
- Universal transformers. arXiv preprint arXiv:1807.03819, 2018.
- BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
- A survey for in-context learning. arXiv preprint arXiv:2301.00234, 2022.
- An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
- Gradient descent learns one-hidden-layer CNN: Don't be afraid of spurious local minima. In International Conference on Machine Learning, pp. 1339–1348. PMLR, 2018a.
- Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning, pp. 1675–1685. PMLR, 2019.
- When is a convolutional filter easy to learn? arXiv preprint arXiv:1709.06129, 2017.
- Gradient descent provably optimizes over-parameterized neural networks, 2018b. URL https://arxiv.org/abs/1810.02054.
- Inductive biases and variable creation in self-attention mechanisms. In International Conference on Machine Learning, pp. 5793–5831. PMLR, 2022.
- A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021.
- Modeling from features: A mean-field framework for over-parameterized deep neural networks. In Conference on Learning Theory, pp. 1887–1936. PMLR, 2021.
- What can transformers learn in-context? A case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
- Learning one convolutional layer with overlapping patches. In International Conference on Machine Learning, pp. 1783–1791. PMLR, 2018.
- Infinite attention: NNGP and NTK for deep attention networks. In International Conference on Machine Learning, pp. 4376–4386. PMLR, 2020.
- Neural tangent kernel: Convergence and generalization in neural networks. Advances in Neural Information Processing Systems, 31, 2018.
- Vision transformers provably learn spatial structure. Advances in Neural Information Processing Systems, 35:37822–37836, 2022.
- A theoretical understanding of shallow vision transformers: Learning, generalization, and sample complexity. In The Eleventh International Conference on Learning Representations, 2023a. URL https://openreview.net/forum?id=jClGv3Qjhb.
- The closeness of in-context learning and weight shifting for softmax regression. arXiv preprint arXiv:2304.13276, 2023b.
- Learning overparameterized neural networks via stochastic gradient descent on structured data. Advances in Neural Information Processing Systems, 31, 2018.
- How do transformers learn topic structure: Towards a mechanistic understanding. arXiv preprint arXiv:2303.04245, 2023c.
- On the expressive power of self-attention matrices. arXiv preprint arXiv:2106.03764, 2021.
- Towards understanding the importance of shortcut connections in residual networks. Advances in Neural Information Processing Systems, 32, 2019.
- A mean field analysis of deep ResNet and beyond: Towards provably optimization via overparameterization from depth. In International Conference on Machine Learning, pp. 6426–6436. PMLR, 2020.
- A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
- Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.
- A rigorous framework for the mean field limit of multilayer neural networks. arXiv preprint arXiv:2001.11443, 2020.
- In-context learning and induction heads. arXiv preprint arXiv:2209.11895, 2022.
- OpenAI. GPT-4 technical report, 2023.
- Toward moderate overparameterization: Global convergence guarantees for training shallow neural networks. IEEE Journal on Selected Areas in Information Theory, 1(1):84–105, 2020.
- On the role of attention in prompt-tuning. arXiv preprint arXiv:2306.03435, 2023.
- Attention is Turing complete. The Journal of Machine Learning Research, 22(1):3463–3497, 2021.
- Approximating how single head attention learns. arXiv preprint arXiv:2103.07601, 2021.
- Mahdi Soltanolkotabi. Learning ReLUs via gradient descent. Advances in Neural Information Processing Systems, 30, 2017.
- Transformers as support vector machines. arXiv preprint arXiv:2308.16898, 2023a.
- Max-margin token selection in attention mechanism. CoRR, 2023b.
- Yuandong Tian. An analytical formula of population gradient for two-layered ReLU network and its applications in convergence and critical point analysis. In International Conference on Machine Learning, pp. 3404–3413. PMLR, 2017.
- Yuandong Tian. Understanding the role of nonlinearity in training dynamics of contrastive learning. arXiv preprint arXiv:2206.01342, 2022.
- Yuandong Tian. Understanding the role of nonlinearity in training dynamics of contrastive learning. ICLR, 2023.
- Understanding self-supervised learning with dual deep networks. arXiv preprint arXiv:2010.00578, 2020.
- Scan and snap: Understanding training dynamics and token composition in 1-layer transformer, 2023.
- Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017. URL https://arxiv.org/pdf/1706.03762.pdf.
- Transformers learn in-context by gradient descent. arXiv preprint arXiv:2212.07677, 2022.
- Over-parameterization exponentially slows down gradient descent for learning a single neuron. arXiv preprint arXiv:2302.10034, 2023.
- Tensor programs V: Tuning large neural networks via zero-shot hyperparameter transfer. arXiv preprint arXiv:2203.03466, 2022.
- Self-attention networks can process bounded hierarchical languages. arXiv preprint arXiv:2105.11115, 2021.
- Are transformers universal approximators of sequence-to-sequence functions? arXiv preprint arXiv:1912.10077, 2019.
- Why are adaptive methods good for attention models? Advances in Neural Information Processing Systems, 33:15383–15393, 2020.
- Trained transformers learn linear models in-context. arXiv preprint arXiv:2306.09927, 2023.
- OPT: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.
- Do transformers parse while predicting the masked word? arXiv preprint arXiv:2303.08117, 2023.
- Toward understanding the importance of noise in training neural networks. In International Conference on Machine Learning, pp. 7594–7602. PMLR, 2019.
- Gradient descent optimizes over-parameterized deep ReLU networks. Machine Learning, 109:467–492, 2020.