Weight decay induces low-rank attention layers (2410.23819v1)
Abstract: The effect of regularizers such as weight decay when training deep neural networks is not well understood. We study the influence of weight decay as well as $L_2$-regularization when training neural network models in which parameter matrices interact multiplicatively. This combination is of particular interest because this parametrization is common in attention layers, the workhorse of transformers: the key-query and value-projection parameter matrices are multiplied directly with each other, as $W_K^\top W_Q$ and $P W_V$. We extend previous results and show, on one hand, that any local minimum of an $L_2$-regularized loss of the form $L(AB^\top) + \lambda(\|A\|^2 + \|B\|^2)$ coincides with a minimum of the nuclear norm-regularized loss $L(AB^\top) + \lambda\|AB^\top\|_*$, and, on the other hand, that the two losses become identical exponentially quickly during training. We thus complement existing work linking $L_2$-regularization with low-rank regularization, and in particular explain why such regularization on the matrix product affects early stages of training. Based on these theoretical insights, we verify empirically that the key-query and value-projection matrix products $W_K^\top W_Q$ and $P W_V$ within attention layers, when optimized with weight decay, as is usually done in vision tasks and language modeling, indeed exhibit a significant reduction in rank, even in fully online training. We find that, in accordance with existing work, inducing low rank in attention matrix products can damage LLM performance, and we observe advantages when decoupling weight decay in attention layers from that of the rest of the parameters.
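To make the central claim concrete, here is a minimal numerical sketch (not code from the paper): plain gradient descent on a factorized matrix $AB^\top$ with an $L_2$ penalty on the factors drives the small singular values of the product toward zero, so the product becomes low rank, mirroring the claimed effect of weight decay on $W_K^\top W_Q$ and $PW_V$. The target matrix, dimensions, learning rate, number of steps, and penalty strength below are arbitrary illustrative choices.

```python
# Sketch: L2 penalty on factors A, B induces low rank in the product A @ B.T.
import numpy as np

rng = np.random.default_rng(0)
n, k = 20, 20                       # product is n x n, factors are n x k
M = rng.normal(size=(n, n))         # generic full-rank target matrix

def train(lam, steps=20000, lr=1e-2):
    """Gradient descent on 0.5*||A B^T - M||_F^2 + lam*(||A||_F^2 + ||B||_F^2)."""
    A = 0.1 * rng.normal(size=(n, k))
    B = 0.1 * rng.normal(size=(n, k))
    for _ in range(steps):
        R = A @ B.T - M                       # residual
        A_new = A - lr * (R @ B + 2 * lam * A)    # gradient w.r.t. A
        B_new = B - lr * (R.T @ A + 2 * lam * B)  # gradient w.r.t. B
        A, B = A_new, B_new
    return A @ B.T

for lam in (0.0, 0.5):
    s = np.linalg.svd(train(lam), compute_uv=False)
    rank = int(np.sum(s > 1e-3 * s[0]))       # effective rank via relative threshold
    print(f"lambda={lam}: effective rank {rank}, "
          f"top singular values {np.round(s[:5], 2)}")
```

With `lam=0.0` the fit recovers the (numerically full-rank) target, whereas with a nonzero penalty the spectrum of the product is soft-thresholded, which is the qualitative behavior one would expect from the nuclear-norm equivalence stated in the abstract.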