Infinite Limits of Multi-head Transformer Dynamics (2405.15712v2)

Published 24 May 2024 in stat.ML, cond-mat.dis-nn, and cs.LG

Abstract: In this work, we analyze various scaling limits of the training dynamics of transformer models in the feature learning regime. We identify the set of parameterizations that admit well-defined infinite width and depth limits, allowing the attention layers to update throughout training--a relevant notion of feature learning in these models. We then use tools from dynamical mean field theory (DMFT) to analyze various infinite limits (infinite key/query dimension, infinite heads, and infinite depth) which have different statistical descriptions depending on which infinite limit is taken and how attention layers are scaled. We provide numerical evidence of convergence to the limits and discuss how the parameterization qualitatively influences learned features.

Analysis of the Scaling Limits in Transformer Models

This paper examines the scaling limits inherent to transformer models, with a particular emphasis on their training dynamics in the feature learning regime. The authors meticulously analyze the infinite width and depth limits to discern how different parameterizations influence model behavior during training, using tools from dynamical mean field theory (DMFT) to elucidate these dynamics.

The paper is motivated by the persistent challenge of understanding the behavior of increasingly large transformer models, a class of architectures that has markedly improved performance in fields such as computer vision and language modeling. By identifying parameterizations that maintain stable training dynamics across scales—such as the mean field parameterization (μP)—the authors aim to provide theoretical groundwork that could guide empirical practices.

Contributions

The paper makes several notable contributions:

  1. DMFT Derivation for Transformer Models: The authors derive the DMFT for randomly initialized transformers, particularly focusing on the key/query dimension $N$, head count $\mathcal{H}$, and depth $L$.
  2. Necessary Scaling for Infinite $N$ Limit: It is analytically demonstrated that the $N \to \infty$ limit necessitates μP scaling of the key/query inner product by $1/N$. This finding holds even when keys and queries are reparameterized to decrease gradient descent update sizes.
  3. Head Dynamics in $N \to \infty$ Limit: The paper shows that in the $N \to \infty$ limit, multi-head self-attention effectively collapses to single-head attention, as all heads follow identical dynamics.
  4. Addressing Head Collapse: To mitigate the collapse of multi-head attention, the authors analyze the $\mathcal{H} \to \infty$ limit at finite $N$. Here, they find that attention dynamics remain distributed across heads, leading to deterministically evolving training dynamics.
  5. Large Depth Limits: The paper further explores the large depth limits of transformers with residual branch scaling, identifying the tension between maintaining initial kernel structure and enabling feature learning within multi-head self-attention (MHSA) and multi-layer perceptron (MLP) blocks.

Implications and Theoretical Insights

Infinite Width and Head Limits

The paper's exploration of the infinite key/query dimension limit $N \to \infty$ highlights a critical aspect of transformer training: μP scaling of the key/query inner product by $1/N$ is essential to maintain stable updates. Without this scaling, the dynamical variables would diverge, fundamentally altering the training process.
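
The need for the $1/N$ factor can be seen in a toy numpy sketch (an illustration of the scaling argument, not the paper's DMFT derivation): once training aligns keys and queries, their inner product grows like $N$, so only a $1/N$ normalization keeps the attention logits of order one, whereas the standard $1/\sqrt{N}$ factor lets them grow like $\sqrt{N}$. The shared component added below is a crude stand-in for that alignment.

```python
# Toy illustration (not the paper's DMFT analysis): why the key/query inner
# product must be scaled by 1/N rather than 1/sqrt(N) if the attention logits
# are to stay O(1) once keys and queries have aligned during training.
import numpy as np

rng = np.random.default_rng(0)

for N in [64, 256, 1024, 4096]:
    # Random init: q, k have i.i.d. O(1) entries, so q.k is only O(sqrt(N)).
    q0 = rng.standard_normal(N)
    k0 = rng.standard_normal(N)
    # Crude stand-in for feature learning: updates give q and k a shared
    # component, so the trained inner product grows like N.
    shared = rng.standard_normal(N)
    q, k = q0 + shared, k0 + shared

    logit_sqrtN = q @ k / np.sqrt(N)  # standard 1/sqrt(N) scaling
    logit_muP = q @ k / N             # muP-style 1/N scaling

    print(f"N={N:5d}  1/sqrt(N): {logit_sqrtN:8.2f}   1/N: {logit_muP:6.2f}")
```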

One particularly intriguing result is the effective collapse of multi-head attention to single-head dynamics as $N$ grows. This could imply diminishing returns when increasing $N$ unless counterbalanced by scaling the number of heads independently. The authors thus pivot to investigating the $\mathcal{H} \to \infty$ limit, arguing that distributed attention dynamics maintain diversity across heads and lead to deterministic training behavior.
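
Some intuition for why heads become redundant at large $N$ comes from concentration: under the $1/N$ logit scaling, each head's logit is an average over $N$ roughly independent terms, so head-to-head variation shrinks as $N$ grows. The sketch below is a toy check of this effect at initialization only (the paper's result concerns the full training dynamics); the head count, width, and single-token setup are illustrative choices.

```python
# Toy check (initialization only): under 1/N scaling, the spread of attention
# logits across heads shrinks as the key/query dimension N grows, which is the
# concentration effect behind the collapse to effectively single-head dynamics.
import numpy as np

rng = np.random.default_rng(0)
H, d_model = 8, 512                 # number of heads, model width (illustrative)
x = rng.standard_normal(d_model)    # a single token embedding

for N in [16, 64, 256, 1024]:
    logits = []
    for _ in range(H):
        # Per-head key/query projections with O(1/sqrt(d_model)) entries.
        W_q = rng.standard_normal((N, d_model)) / np.sqrt(d_model)
        W_k = rng.standard_normal((N, d_model)) / np.sqrt(d_model)
        q, k = W_q @ x, W_k @ x
        logits.append(q @ k / N)    # muP-style 1/N scaling
    print(f"N={N:5d}  std of logits across heads: {np.std(logits):.4f}")
```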

Large Depth Limits

For the large depth limit $L \to \infty$, the paper identifies two regimes based on the residual branch scaling exponent $\alpha_L$. With $\alpha_L = 1$, updates to the MHSA and MLP blocks persist throughout training, although the initial representations lose structure. Conversely, $\alpha_L = \frac{1}{2}$ preserves the initial kernel structure but results in frozen weights within residual blocks, reminiscent of shallow feature learning but in the context of deep networks.
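
Concretely, the residual branch scaling discussed here multiplies each block's contribution by $L^{-\alpha_L}$ before it is added back to the residual stream. The sketch below is a minimal forward pass under that assumption, with a generic nonlinear block standing in for the MHSA/MLP sublayers; it illustrates where $\alpha_L$ enters, not the paper's parameterization in full.

```python
# Minimal sketch of a depth-scaled residual stream: each block's output is
# multiplied by L**(-alpha_L) before being added to the residual stream.
# alpha_L = 1 and alpha_L = 1/2 correspond to the two regimes discussed above.
import numpy as np

rng = np.random.default_rng(0)
d, L = 64, 128                                         # width, depth (illustrative)
Ws = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(L)]

def residual_forward(h, Ws, alpha_L):
    """Residual stream with each branch scaled by L**(-alpha_L)."""
    L = len(Ws)
    for W in Ws:
        # Generic nonlinear block standing in for an MHSA or MLP sublayer.
        h = h + L ** (-alpha_L) * np.tanh(h @ W)
    return h

h0 = rng.standard_normal(d)
for alpha_L in (1.0, 0.5):
    h = residual_forward(h0, Ws, alpha_L)
    print(f"alpha_L={alpha_L}: ||h_L|| / ||h_0|| = "
          f"{np.linalg.norm(h) / np.linalg.norm(h0):.2f}")
```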

Empirical Evidence and Practical Considerations

The authors validate their theoretical models through empirical investigations on CIFAR-5M and the C4 dataset, demonstrating that scaling different model parameters produces varying degrees of performance improvement. Notably, increasing key/query dimensions ($N$) under μP scaling improved stability across model scales, while scaling the number of heads ($\mathcal{H}$) or layer depth ($L$) provided different benefits depending on the specific parameterization.

Conclusion and Future Work

This paper provides a comprehensive analysis of the scaling limits in transformer models using DMFT, presenting important insights into the stable convergence of training dynamics in large-scale settings. Future work could extend these findings to different optimizer settings, particularly Adam, and explore the interplay between model size and training duration to optimize compute resources effectively.

The research underscores the nuanced trade-offs between various scaling parameters, painting a detailed picture that can guide the development of more efficient, scalable transformer architectures in the future.

Authors

  1. Blake Bordelon
  2. Hamza Tahir Chaudhry
  3. Cengiz Pehlevan