- The paper proves that transformer hidden states and gradients converge uniformly to their mean-field limit at rate $O(L^{-1} + L^{-1/3}H^{-1/2})$, independent of the number of input tokens.
- It develops a mean-field analysis based on McKean–Vlasov ODEs. Lions’ derivatives are used to establish global Lipschitz continuity of the flow maps, and martingale techniques are used to obtain concentration bounds.
- A blockwise AdamW variant is shown to remove embedding dimension dependency, enhancing scalability and generalization for deep, attention-only transformer architectures.
Introduction
This work presents a rigorous theoretical treatment of the uniform scaling limits of transformers trained with the AdamW optimizer. The focus is on understanding the behavior of attention-only (unmasked) transformers as both depth (L) and the number of attention heads (H) become large, with an explicit analysis of hidden-state and gradient (adjoint) dynamics under AdamW training. A key novelty is the uniformity of the main convergence results: the bounds established are independent of the number of input tokens and, under an appropriate variant of the optimizer, independent of the token embedding dimension as well. This addresses a significant gap in the existing landscape of deep learning theory, particularly for transformer architectures.
Problem Setting and Mathematical Framework
The authors formulate the hidden-state evolution of deep, wide transformers as an interacting particle system (IPS) governed by a McKean–Vlasov ODE (MVODE) in the large-L, large-H limit. The parameters of each attention head are initialized i.i.d., typically from a compactly supported distribution (e.g., Xavier uniform).
Given the unmasked self-attention mechanism, the update for the i-th token in layer r is:
$$x_i^{r+1} = x_i^r + \frac{1}{L}\,\Gamma\big(x_i^r, m_N^r, \nu_H^r\big),$$
where $m_N^r$ is the empirical measure of the $N$ token embeddings at layer $r$, and $\nu_H^r$ is the empirical measure of the $H$ head parameters. The authors generalize this to a mean-field description with measure-valued mappings, using tools from probability (Wasserstein distances) and from calculus on spaces of measures (Lions' derivative).
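The sketch below makes the discrete particle system concrete by implementing the layer update above in NumPy. The summary does not spell out the form of $\Gamma$, so this sketch assumes one. It takes $\Gamma$ to be the average over $H$ unmasked softmax heads, where each head has its own $(Q, K, V)$ matrices drawn i.i.d. from a Xavier-uniform-style distribution, independently across layers. Because softmax weights are normalized, attending over all $N$ tokens amounts to integrating against the empirical measure $m_N^r$. Averaging over heads amounts to integrating against $\nu_H^r$. The names `sample_heads`, `gamma`, and `forward` are hypothetical.

```python
import numpy as np

def sample_heads(H, d, rng):
    """Draw H i.i.d. heads (Q, K, V); entries ~ Uniform[-a, a] with a Xavier-uniform bound."""
    a = np.sqrt(6.0 / (d + d))
    return rng.uniform(-a, a, size=(H, 3, d, d))

def gamma(X, heads):
    """Assumed Gamma(x_i, m_N, nu_H), evaluated at every token x_i (the rows of X)."""
    Q, K, V = heads[:, 0], heads[:, 1], heads[:, 2]    # each of shape (H, d, d)
    q = np.einsum("hab,nb->hna", Q, X)                 # queries (H, N, d)
    k = np.einsum("hab,nb->hna", K, X)                 # keys    (H, N, d)
    v = np.einsum("hab,nb->hna", V, X)                 # values  (H, N, d)
    logits = np.einsum("hia,hja->hij", q, k)           # (H, N, N), no causal mask
    logits -= logits.max(axis=-1, keepdims=True)       # numerical stability
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)                 # softmax over all N tokens
    out = np.einsum("hij,hja->hia", w, v)              # per-head outputs (H, N, d)
    return out.mean(axis=0)                            # head average = integral against nu_H

def forward(X0, layer_heads):
    """Iterate x_i^{r+1} = x_i^r + (1/L) Gamma(x_i^r, m_N^r, nu_H^r) for r = 0..L-1."""
    L = len(layer_heads)
    X = X0.copy()
    traj = [X]
    for heads in layer_heads:                          # the H heads of layer r
        X = X + gamma(X, heads) / L
        traj.append(X)
    return np.stack(traj)                              # shape (L + 1, N, d)

rng = np.random.default_rng(0)
N, d, L, H = 8, 4, 16, 32
X0 = rng.standard_normal((N, d))
layer_heads = [sample_heads(H, d, rng) for _ in range(L)]
traj = forward(X0, layer_heads)
```

The attention is unmasked, so the update is permutation-equivariant in the tokens. This equivariance is what allows the tokens to be treated as an empirical measure rather than as an ordered sequence.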
Crucially, the work analyzes the joint dynamics of the forward pass (hidden states) and the backward pass (adjoint variables, i.e., gradients computed by backpropagation). Both passes are needed to understand optimization in deep networks.
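The backward pass can be illustrated with a deliberately simplified stand-in rather than the attention dynamics themselves, whose adjoint also carries measure-dependent terms. The stand-in is a single-particle residual network $x^{r+1} = x^r + \frac{1}{L}\tanh(W_r x^r)$ with a quadratic loss. The sketch propagates the adjoint $p^r = \partial \ell / \partial x^r$ backwards through the layers and reads off the parameter gradients from it. The model, the loss, and the function names are assumptions made for illustration.

```python
import numpy as np

def forward(x0, Ws):
    """Residual forward pass x^{r+1} = x^r + (1/L) tanh(W_r x^r); returns all states."""
    L = len(Ws)
    xs = [x0]
    for W in Ws:
        xs.append(xs[-1] + np.tanh(W @ xs[-1]) / L)
    return xs

def loss(xL, y):
    """Quadratic terminal loss (an illustrative choice)."""
    return 0.5 * np.sum((xL - y) ** 2)

def backward(xs, Ws, y):
    """Adjoint recursion p^r = p^{r+1} + (1/L) W_r^T (tanh'(W_r x^r) * p^{r+1})."""
    L = len(Ws)
    p = xs[-1] - y                                  # p^L = gradient of the loss at x^L
    grads = [None] * L
    for r in reversed(range(L)):
        s = 1.0 - np.tanh(Ws[r] @ xs[r]) ** 2       # tanh'(W_r x^r)
        grads[r] = np.outer(s * p, xs[r]) / L       # dloss/dW_r, uses p^{r+1}
        p = p + Ws[r].T @ (s * p) / L               # step back from p^{r+1} to p^r
    return p, grads                                 # p is now p^0 = dloss/dx^0

rng = np.random.default_rng(1)
d, L = 3, 10
Ws = [rng.standard_normal((d, d)) for _ in range(L)]
x0, y = rng.standard_normal(d), rng.standard_normal(d)
xs = forward(x0, Ws)
p0, grads = backward(xs, Ws, y)
```

A single backward sweep yields the gradients with respect to every layer's parameters. Mean-field analyses study this sweep jointly with the forward states.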
Main Results
Convergence Theorem
Theorem [Uniform Convergence]:
Under compactly supported i.i.d. parameter initialization and AdamW training (with blockwise or standard variants), the difference between the discrete sequence of layerwise states and the continuum ODE model vanishes as L,H→∞ at rate:
$$O\big(L^{-1} + L^{-1/3}H^{-1/2}\big),$$
uniformly over initial conditions and, importantly, independently of the number of input tokens.
When the blockwise AdamW variant of [xie2024adam] is used, the constants in the uniform convergence bound also become independent of the embedding dimension.
This contrasts with prior results, where uniform convergence bounds typically grow with the size of the input set (e.g., with the number of data points in [chaintron2026resnets]). Here, concentration-of-measure arguments that avoid covering numbers keep the constants free of such dependence.
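Among the contributions to the rate, the $H^{-1/2}$ factor is the easiest to see numerically. It comes from replacing the limiting head distribution by the empirical measure of $H$ i.i.d. heads. The following Monte Carlo sketch makes that concrete for a single layer and a fixed set of tokens. It assumes the same unmasked softmax heads with Xavier-uniform entries as a stand-in for the paper's setting, and it approximates the mean-field field by averaging over a very large number of heads (`H_ref`, an arbitrary choice). It does not reproduce the theorem's uniform bound or the $L^{-1/3}$ factor.

```python
import numpy as np

def sample_heads(H, d, rng):
    """H i.i.d. heads (Q, K, V) with Xavier-uniform entries (illustrative choice)."""
    a = np.sqrt(6.0 / (2 * d))
    return rng.uniform(-a, a, size=(H, 3, d, d))

def head_outputs(X, heads):
    """Per-head unmasked softmax attention outputs, shape (H, N, d)."""
    Q, K, V = heads[:, 0], heads[:, 1], heads[:, 2]
    q = np.einsum("hab,nb->hna", Q, X)
    k = np.einsum("hab,nb->hna", K, X)
    v = np.einsum("hab,nb->hna", V, X)
    logits = np.einsum("hia,hja->hij", q, k)
    logits -= logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)
    return np.einsum("hij,hja->hia", w, v)

def head_sampling_slope(X, rng, Hs=(16, 64, 256, 1024), trials=50, H_ref=100_000):
    """RMS error of the H-head average vs. a large-H proxy, and its log-log slope in H."""
    d = X.shape[1]
    ref = head_outputs(X, sample_heads(H_ref, d, rng)).mean(axis=0)   # mean-field proxy
    rms = []
    for H in Hs:
        errs = [np.linalg.norm(head_outputs(X, sample_heads(H, d, rng)).mean(axis=0) - ref)
                for _ in range(trials)]
        rms.append(np.sqrt(np.mean(np.square(errs))))
    slope = np.polyfit(np.log(Hs), np.log(rms), 1)[0]                 # expect about -1/2
    return np.array(rms), slope

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 2))            # N = 5 tokens in dimension d = 2
rms, slope = head_sampling_slope(X, rng)
```

The fitted slope should sit close to $-1/2$, because the variance of an average of $H$ i.i.d. head outputs scales as $1/H$.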
Analytical Techniques
- AdamW and Parameter Compactness: AdamW's decoupled weight decay ensures (Lemma 1) that the parameter measures remain in a compact set throughout training. This controls the growth of the Lipschitz constants of the dynamics, a crucial guarantee for deep networks given the cubic gradients induced by self-attention.
- Flow Maps with Lions’ Derivatives: The map from the initial condition and initial measure to the solution is proven to be globally Lipschitz using Lions’ (measure) derivatives, with explicit, locally uniform bounds.
- Martingale Techniques for Concentration: Uniform bounds on the discrepancy are obtained via martingale inequalities and Rademacher symmetrization, sidestepping traditional entropy integrals or covering arguments.
- Blockwise AdamW: The variant introduced in [xie2024adam] further sharpens the bounds by removing the dependence on the embedding dimension.
- Explicit Computation of Measure Derivatives: The paper provides closed-form expressions for the Lions' derivative of the attention kernel and its higher-order derivatives, enabling sharp local uniformity and control of regularity properties.
- Generalization to Loss Functions: The analysis accommodates general loss functions defined on measures, provided their Lions' derivative is Lipschitz continuous, so the result applies to a broad class of objectives.
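As an example of the kind of measure derivative involved, consider one unmasked softmax head viewed as a function of a query $x$ and a token measure $m$:

$\Gamma(x,m) = \int e^{\langle Qx, Ky\rangle} Vy\, m(dy) \,/\, Z(x,m)$, with $Z(x,m) = \int e^{\langle Qx, Ky\rangle} m(dy)$.

The linear functional derivative of $\Gamma$ in $m$ is $k(x,y)\big(Vy - \Gamma(x,m)\big)/Z(x,m)$, where $k(x,y) = e^{\langle Qx, Ky\rangle}$. Differentiating it in $y$ gives the Lions derivative:

$\partial_m\Gamma(x,m)(y) = \frac{k(x,y)}{Z(x,m)}\big[V + (Vy - \Gamma(x,m))(K^\top Qx)^\top\big]$.

This is our own single-head derivation, not the paper's closed-form expressions, which may be parameterized differently (e.g., with scaled logits or several heads). The check below uses the standard identity that, for $m_N = \frac{1}{N}\sum_j \delta_{y_j}$, the gradient of $\Gamma(x, m_N)$ with respect to $y_j$ equals $\frac{1}{N}\,\partial_m\Gamma(x, m_N)(y_j)$. The function names are hypothetical.

```python
import numpy as np

def attn(x, Y, Q, K, V):
    """Gamma(x, m_N) for m_N = (1/N) sum_j delta_{y_j}: one unmasked softmax head."""
    s = Y @ (K.T @ (Q @ x))          # <Qx, K y_j> for every token j
    k = np.exp(s - s.max())          # shifted for stability; the shift cancels below
    w = k / k.sum()
    return (V @ Y.T) @ w             # sum_j w_j V y_j

def lions_derivative(x, Y, Q, K, V, y):
    """d_m Gamma(x, m_N)(y) as a (d_out, d) matrix, from the formula above."""
    c = K.T @ (Q @ x)                # gradient in y of <Qx, K y>
    s_all = Y @ c
    shift = s_all.max()
    Z = np.mean(np.exp(s_all - shift))          # Z(x, m_N), up to the common shift
    ky = np.exp(y @ c - shift)                  # k(x, y), same shift
    g = attn(x, Y, Q, K, V)
    return (ky / Z) * (V + np.outer(V @ y - g, c))

rng = np.random.default_rng(2)
d, N = 3, 6
Q, K, V = (rng.standard_normal((d, d)) for _ in range(3))
x, Y = rng.standard_normal(d), rng.standard_normal((N, d))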
Numerical and Theoretical Implications
The established rate, $O(L^{-1} + L^{-1/3}H^{-1/2})$, is slightly suboptimal compared with previous non-uniform analyses, which achieve faster rates in some cases. However, it is uniform over the input domain, a critical property for generalization guarantees and out-of-distribution robustness.
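A short calculation, ignoring the constants hidden in the $O(\cdot)$, shows how many heads are needed before the depth term of the rate $O(L^{-1} + L^{-1/3}H^{-1/2})$ dominates. This is our own arithmetic on the stated rate, not a result from the paper:

```latex
\varepsilon(L,H) = L^{-1} + L^{-1/3}H^{-1/2}, \qquad
L^{-1} = L^{-1/3}H^{-1/2} \iff H^{1/2} = L^{2/3} \iff H = L^{4/3}.

\text{If } H \ge L^{4/3}, \text{ then } \varepsilon(L,H) \le 2L^{-1}.
\text{Example: } L = 64 \Rightarrow L^{4/3} = 4^4 = 256, \quad
64^{-1/3}\cdot 256^{-1/2} = \tfrac{1}{4}\cdot\tfrac{1}{16} = \tfrac{1}{64} = L^{-1}.
```

Under this reading, the number of heads must grow slightly faster than the depth for the bound to reach the $L^{-1}$ discretization order.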
Key implications:
- Generalizability: Because the result is uniform with respect to initial conditions, it applies not only to inputs seen during training but also to unseen inputs.
- Scalability: By making error bounds dimension-agnostic (with blockwise AdamW), the results remain relevant as models and inputs scale up.
- Practicality: Analysis under AdamW, the optimizer of choice for large transformers in practice (including LLMs and Diffusion Transformers), ensures that the theory is aligned with state-of-the-art empirical practice.
- ResNet Mean-Field Limits: Extends and improves on [avelin2020neuralodesdeeplimit], [ding2], [chaintron2026resnets], removing the parameter dependence across layers at initialization and allowing general i.i.d. initializations.
- Transformers and Self-Attention: Builds on recent mean-field analyses of transformers ([rigollet2025meanfielddynamicstransformers], [gao2024global]), but is the first to handle AdamW analytically and to derive uniform scaling limits independent of both the number of tokens and the embedding dimension.
- Optimizer Theory: The result leverages the recent characterization of AdamW as performing constrained optimization with explicit compactness guarantees ([pmlr-v235-xie24e]).
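The compactness mechanism behind Lemma 1 can be seen in a toy setting, sketched below. AdamW's normalized step has bounded size, while decoupled weight decay pulls the parameters towards zero. Under a constant gradient, each coordinate therefore settles inside the $\ell_\infty$ ball of radius $1/\lambda$, whatever the gradient's scale. The hyperparameters, the constant-gradient oracle, and the blockwise rule (one second-moment statistic per block, the mean of squared gradients) are illustrative assumptions. They are one plausible reading of the blockwise variant in [xie2024adam], not its exact definition. In the blockwise case, this toy bound loosens to $\sqrt{B}/\lambda$ for blocks of size $B$.

```python
import numpy as np

def adamw(grad_fn, theta0, steps, lr=1e-2, wd=0.1, betas=(0.9, 0.999), eps=1e-8, blocks=None):
    """AdamW with decoupled weight decay; if `blocks` (list of index arrays) is given,
    the second moment is shared within each block (hypothetical blockwise variant)."""
    b1, b2 = betas
    theta = np.asarray(theta0, dtype=float).copy()
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        m = b1 * m + (1 - b1) * g                 # per-coordinate first moment
        g2 = g ** 2                               # new array; g itself is untouched
        if blocks is not None:
            for idx in blocks:
                g2[idx] = g2[idx].mean()          # one second-moment statistic per block
        v = b2 * v + (1 - b2) * g2
        m_hat = m / (1 - b1 ** t)                 # bias corrections
        v_hat = v / (1 - b2 ** t)
        # Decoupled weight decay: wd * theta is added outside the adaptive normalization.
        theta -= lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * theta)
    return theta

wd = 0.1
scales = 10.0 ** np.arange(-3, 5)                 # gradient scales from 1e-3 to 1e4
grad_fn = lambda th: scales                       # constant gradient oracle
th = adamw(grad_fn, np.zeros(8), steps=5000, wd=wd)
blocks = [np.arange(0, 4), np.arange(4, 8)]
th_b = adamw(grad_fn, np.zeros(8), steps=5000, wd=wd, blocks=blocks)
```

In the coordinatewise run, every coordinate approaches magnitude $1/\lambda = 10$ even though the gradients span seven orders of magnitude. In the blockwise run, the largest coordinate in each block exceeds 10 but stays below $\sqrt{4}/\lambda = 20$.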
Open Problems and Future Directions
The authors note a remaining gap. In some regimes, the order of the uniform convergence rate is worse than the order obtained in non-uniform settings, and this gap could potentially be closed by extending the arguments used there. In addition, the current results are restricted to a finite training horizon. Making them uniform in training time is an open avenue that may require refined properties of the adjoint variable near the optimum (as in [marion2024implicit]).
The authors highlight three extensions of the analysis as important future milestones: (i) stochastic depth (SDE scaling regimes), (ii) attention with causal masking (as in autoregressive LMs), and (iii) architectures with normalization layers.
The techniques developed may also impact the theoretical understanding of the generalization, robustness, and optimization landscape in large transformer models, informing both architecture design and scaling policy in practical AI systems.
Conclusion
This paper delivers the first convergence guarantees for deep transformers trained with AdamW that are uniform in the input and independent of the token count, and also of the embedding dimension when blockwise AdamW is used. It thereby establishes the continuous-time mean-field limit as a robust theoretical tool for analyzing modern large-scale attention architectures. The results advance the mathematical foundations of deep learning theory, with practical implications for scalable and reliable transformer optimization.
References: