Implicit Optimization Bias of Next-Token Prediction in Linear Models (2402.18551v2)
Abstract: We initiate an investigation into the optimization properties of next-token prediction (NTP), the dominant training paradigm for modern LLMs. Specifically, we study the structural properties of the solutions selected by gradient-based optimizers among the many possible minimizers of the NTP objective. By framing NTP as cross-entropy minimization across distinct contexts, each tied to a sparse conditional probability distribution over a finite vocabulary of tokens, we introduce "NTP-separability conditions" under which the data-entropy lower bound can be reached. In this setup, focusing on linear models with fixed context embeddings, we characterize the optimization bias of gradient descent (GD): within the data subspace defined by the sparsity patterns of the distinct contexts, GD selects parameters that equate the logit differences of in-support tokens to their log-odds. In the orthogonal subspace, the GD parameters diverge in norm and select the direction that maximizes a margin specific to NTP. These findings extend previous research on implicit bias in one-hot classification to the NTP setting, highlighting key differences and prompting further research into the optimization and generalization properties of NTP, irrespective of the specific architecture used to generate the context embeddings.
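
To make the setting concrete, the sketch below is a minimal toy instantiation of the abstract's setup, not the paper's code: a linear decoder over fixed, randomly drawn context embeddings is trained by plain gradient descent on the NTP cross-entropy objective, where each distinct context supports only two tokens of a small vocabulary. The dimensions, embeddings, support patterns, and learning-rate schedule are illustrative assumptions; the printout checks that the in-support logit differences approach the corresponding log-odds while the parameter norm keeps growing, consistent with the bias described above.

```python
# Toy sketch (assumed setup, not the paper's experiments): NTP as cross-entropy
# minimization over m distinct contexts, each with a sparse next-token distribution
# over a vocabulary of size V, using a linear decoder W on fixed context embeddings H.
import numpy as np

rng = np.random.default_rng(0)

V, d, m = 5, 8, 3                       # vocabulary size, embedding dim, number of contexts
H = rng.standard_normal((m, d))         # fixed context embeddings (one row per context)

# Sparse conditional distributions p(token | context): each context supports 2 tokens.
P = np.zeros((m, V))
P[0, [0, 1]] = [0.7, 0.3]
P[1, [1, 2]] = [0.5, 0.5]
P[2, [3, 4]] = [0.2, 0.8]

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

W = np.zeros((V, d))                    # linear decoder; logits = H @ W.T
lr, steps = 0.5, 50_000
for _ in range(steps):
    S = softmax(H @ W.T)                # model conditionals, shape (m, V)
    grad = (S - P).T @ H / m            # gradient of the average cross-entropy
    W -= lr * grad

# In-support logit differences should approach the log-odds of the true conditionals,
# while ||W|| keeps growing (the divergent component lies in the orthogonal subspace).
logits = H @ W.T
for j, (a, b) in enumerate([(0, 1), (1, 2), (3, 4)]):
    print(f"context {j}: logit diff = {logits[j, a] - logits[j, b]:+.3f}, "
          f"log-odds = {np.log(P[j, a] / P[j, b]):+.3f}")
print("||W|| =", np.linalg.norm(W))
```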