This paper investigates the phenomenon of rapidly increasing gradient norms observed towards the end of long-duration LLM training runs. It posits that this increase is not an inherent property of optimization but rather an unintended consequence of the interaction between weight decay, normalization layers (like LayerNorm or BatchNorm), and learning rate schedules. The authors propose a simple correction to the weight decay mechanism that mitigates this issue and leads to improved training performance.
The core of the analysis builds upon the work of van der Walt and Germishuizen (2017), which showed that for layers followed by normalization, weight decay controls the ratio of the gradient norm to the weight norm at steady state. Specifically, for SGD, this relationship is:
||∇f(x_t)|| / ||x_t|| = √(2λ / γ_t)

where ||∇f(x_t)|| is the gradient norm, ||x_t|| is the weight norm, λ is the weight decay coefficient, and γ_t is the current learning rate.
The paper highlights that when a learning rate schedule (e.g., cosine decay) is used, γ_t changes over time, making the steady-state ratio a moving target. As γ_t decreases towards the end of training (approaching zero), the theoretical steady-state ratio √(2λ/γ_t) grows rapidly. The shrinking learning rate also slows optimization, which prevents the measured ratios from fully tracking this exploding target, but it still produces a noticeable "tail blow-up" in gradient norms. This behavior is illustrated in Figure 1 of the paper.
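To make the moving target concrete, here is a small illustrative calculation (the values of λ, γ_max, and T are made up for illustration, not taken from the paper) that evaluates the uncorrected steady-state target √(2λ/γ_t) under a cosine schedule:

```python
import math

lam = 0.1          # weight decay coefficient (illustrative value, not from the paper)
gamma_max = 1e-3   # peak learning rate (illustrative value)
T = 10_000         # total training steps

def cosine_lr(t):
    # Cosine decay from gamma_max toward zero over T steps.
    return gamma_max * 0.5 * (1 + math.cos(math.pi * t / T))

for t in (0, T // 2, int(0.9 * T), int(0.99 * T)):
    gamma_t = cosine_lr(t)
    target = math.sqrt(2 * lam / gamma_t)  # steady-state ||grad|| / ||weights|| target
    print(f"step {t:>5}: lr = {gamma_t:.2e}, uncorrected target ratio = {target:8.1f}")
```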
The authors also extend this analysis to AdamW, arguing that a similar principle applies, though the derivation is more approximate. They suggest that AdamW effectively balances the layer-wise infinity norms of the weights, which is beneficial because the infinity norm is a key scaling quantity for Adagrad-family optimizers. This is contrasted with the original Adam, where weight decay enters the gradient as L2 regularization and is therefore rescaled by the adaptive step sizes, preventing this balancing and helping to explain AdamW's superior performance.
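The distinction is easiest to see side by side. The sketch below shows schematic single-parameter updates (bias correction omitted, written for illustration rather than taken from the paper): in Adam-with-L2 the decay term passes through the adaptive rescaling, while in AdamW it acts directly on the weight.

```python
import math

def adam_l2_step(x, g, m, v, lr, lam, beta1=0.9, beta2=0.999, eps=1e-8):
    # Original Adam with L2 regularization: lam * x is folded into the gradient,
    # so the decay is also divided by sqrt(v) and no longer pulls ||x|| directly.
    g = g + lam * x
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    x = x - lr * m / (math.sqrt(v) + eps)
    return x, m, v

def adamw_step(x, g, m, v, lr, lam, beta1=0.9, beta2=0.999, eps=1e-8):
    # AdamW: decoupled weight decay is applied outside the adaptive rescaling,
    # which is what ties lam and lr to the gradient-to-weight norm ratio.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    x = x - lr * m / (math.sqrt(v) + eps) - lr * lam * x
    return x, m, v
```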
To address the gradient norm increase, the paper proposes a corrected weight decay coefficient, λ̂_t, which aims to make the steady-state gradient-to-weight ratio independent of the current learning rate γ_t. The proposed correction is:

λ̂_t = λ · (γ_t / γ_max)

where γ_max is the maximum learning rate used during the training schedule. When this corrected weight decay is applied, the steady-state ratio becomes:

||∇f(x_t)|| / ||x_t|| = √(2λ̂_t / γ_t) = √(2 (λ γ_t / γ_max) / γ_t) = √(2λ / γ_max)

This new target ratio is constant throughout training, depending only on the original weight decay λ and the maximum learning rate γ_max. The correction is applied only to layers that are followed by normalization; for other layers, standard weight decay is used. The modified optimizers are termed AdamC (Adam with Corrected Weight Decay) and SGDC.
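Repeating the earlier illustrative calculation with the corrected coefficient shows the target staying flat at √(2λ/γ_max) across the schedule (same made-up values as before):

```python
import math

lam, gamma_max, T = 0.1, 1e-3, 10_000  # same illustrative values as above

def cosine_lr(t):
    return gamma_max * 0.5 * (1 + math.cos(math.pi * t / T))

for t in (0, T // 2, int(0.99 * T)):
    gamma_t = cosine_lr(t)
    lam_hat = lam * gamma_t / gamma_max        # corrected weight decay coefficient
    target = math.sqrt(2 * lam_hat / gamma_t)  # = sqrt(2 * lam / gamma_max), constant
    print(f"step {t:>5}: corrected target ratio = {target:.2f}")
```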
The implementation of AdamC is detailed in Algorithm 1 of the paper:
Algorithm 1: Adam with Corrected Weight Decay (AdamC)

  Input: initial weights x_{1,l} for all layers l; learning rate schedule γ_t with maximum γ_max; weight decay λ; β₁, β₂, ε
  Initialize m_{0,l} = v_{0,l} = 0
  For t = 1 to T:
    For layer l = 1 to L:
      g_{t,l} ∈ ∂f(x_{t,l}, ζ_{t,l})                      // minibatch gradient
      m_{t,l} = β₁ m_{t-1,l} + (1-β₁) g_{t,l}             // standard Adam moment updates
      v_{t,l} = β₂ v_{t-1,l} + (1-β₂) g_{t,l} ∘ g_{t,l}
      m̂_{t,l} = m_{t,l} / (1-β₁^t)
      v̂_{t,l} = v_{t,l} / (1-β₂^t)
      If layer l is followed by normalization:
        // corrected decay: γ_t λ̂_t x_{t,l} with λ̂_t = λ γ_t / γ_max
        x_{t+1,l} = x_{t,l} - γ_t m̂_{t,l} / (√(v̂_{t,l} + ε)) - (λ γ_t^2 / γ_max) x_{t,l}
      Else:
        // standard decoupled (AdamW) weight decay
        x_{t+1,l} = x_{t,l} - γ_t m̂_{t,l} / (√(v̂_{t,l} + ε)) - γ_t λ x_{t,l}
  Return x_{T,l}
A note on the corrected decay term: the paper's Algorithm 1 writes it as γ_t^2 / γ_max λ · x_{t-1,l}, without parentheses. Read with λ in the numerator, this is exactly γ_t λ̂_t x_{t-1,l} = (λ γ_t^2 / γ_max) x_{t-1,l}, which matches the text's definition λ̂_t = λ (γ_t / γ_max) and yields the constant target √(2λ/γ_max); that reading is used throughout this summary. Reading λ as a divisor instead would give an effective coefficient γ_t / (γ_max λ) and a target of √(2 / (λ γ_max)), which is also independent of the learning rate but inverts the dependence on λ and contradicts the text's derivation, so it is best treated as a typesetting ambiguity.
The AdamC algorithm (Algorithm 1) proceeds like standard AdamW but modifies the weight decay term for layers followed by normalization.
For each layer l at timestep t:
- Compute the minibatch gradient g_{t,l}.
- Update the momentum and variance estimates m_{t,l} and v_{t,l}.
- Compute the bias-corrected m̂_{t,l} and v̂_{t,l}.
- Conditional Weight Decay Application:
  - If layer l is followed by normalization (e.g., LayerNorm, BatchNorm), the update uses the corrected decay:
    x_{t+1,l} = x_{t,l} - γ_t m̂_{t,l} / (√(v̂_{t,l} + ε)) - (λ γ_t^2 / γ_max) x_{t,l}
    Note that γ_max appears in the denominator of the corrective term; this is what keeps the target ratio constant at √(2λ/γ_max).
  - Else (layer l is not followed by normalization), the update uses standard AdamW weight decay:
    x_{t+1,l} = x_{t,l} - γ_t m̂_{t,l} / (√(v̂_{t,l} + ε)) - γ_t λ x_{t,l}

(A PyTorch-style sketch of this update loop follows below.)
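Putting the steps together, here is a minimal PyTorch-style sketch under the λ-in-the-numerator reading discussed above. The class name, the is_normalized group flag, and the overall layout are illustrative choices, not the authors' released implementation:

```python
import torch


class AdamC(torch.optim.Optimizer):
    """Sketch of Adam with corrected weight decay for layers followed by normalization."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.1, max_lr=None):
        if max_lr is None:
            raise ValueError("max_lr (the schedule's peak learning rate) is required")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        max_lr=max_lr, is_normalized=False)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr, wd = group["lr"], group["weight_decay"]
            max_lr, eps = group["max_lr"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                state["step"] += 1
                t = state["step"]
                m, v = state["exp_avg"], state["exp_avg_sq"]

                # Standard Adam moment updates with bias correction.
                m.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
                m_hat = m / (1 - beta1 ** t)
                v_hat = v / (1 - beta2 ** t)
                p.addcdiv_(m_hat, (v_hat + eps).sqrt(), value=-lr)

                # Weight decay: corrected (lambda * lr^2 / max_lr) for layers
                # followed by normalization, plain decoupled AdamW decay otherwise.
                if group["is_normalized"]:
                    p.mul_(1 - wd * lr * lr / max_lr)
                else:
                    p.mul_(1 - wd * lr)
```

One way to drive this is with two parameter groups, one carrying is_normalized=True for weights that feed into normalization; a grouping helper in that spirit is sketched after the implementation-considerations list below.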
Experimental Validation:
The paper validates this approach on two main tasks:
- LLM Pre-training: A 120M parameter Llama 3 architecture model trained for 200B tokens on FineWeb-Edu.
- AdamC resulted in significantly lower loss values throughout training compared to AdamW.
- The rapid increase in gradient norm towards the end of training was largely eliminated.
- Weight norms were much more stable with AdamC, whereas AdamW showed rapidly decreasing weight norms.
- For applying the correction, every linear layer (excluding the output layer) was considered normalized.
- ImageNet Classification: A ResNet-50 model trained on ImageNet with SGD with momentum and a cosine learning rate schedule.
- SGDC (SGD with the correction) eliminated the rapid gradient norm increase at the end of training (a minimal SGDC update sketch follows this list).
- A slow, overall upward trend in gradient norms over time remained, suggesting the theory addresses only one component of gradient norm dynamics.
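For reference, a minimal sketch of the corresponding SGDC update on a single parameter (plain Python with heavy-ball momentum; the function name and the exact momentum convention are assumptions, not the paper's pseudocode):

```python
def sgdc_step(x, g, buf, lr, max_lr, lam, mu=0.9, is_normalized=True):
    # Heavy-ball momentum on the raw gradient.
    buf = mu * buf + g
    # Corrected coefficient lam_hat = lam * lr / max_lr for layers followed by
    # normalization; plain decoupled weight decay otherwise.
    lam_eff = lam * lr / max_lr if is_normalized else lam
    x = x - lr * buf - lr * lam_eff * x
    return x, buf
```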
Practical Implementation Considerations:
- Identifying Normalized Layers: The correction is applied selectively, only to the weights of layers immediately followed by a normalization operation. For Transformer models, the paper treats every linear layer except the final output layer as normalized. Getting this grouping right is essential in practice (see the parameter-grouping sketch after this list).
- Hyperparameters: The method introduces γ_max (the maximum learning rate in the schedule) into the weight decay calculation, so this value must be known up front. The original weight decay λ remains a hyperparameter.
- Computational Cost: The modification adds negligible computational overhead, involving a simple scaling of the weight decay term.
- Scope of Applicability: The benefits are most pronounced in long training runs where the learning rate schedule drives γ_t far below γ_max, and in architectures where normalization layers are prevalent (e.g., modern Transformers and ResNets).
- Optimizer Choice: The correction can be applied to both SGD (SGDC) and AdamW (AdamC).
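As an illustration of the layer-identification heuristic above, here is a hypothetical helper that builds the two parameter groups the AdamC sketch expects; the is_normalized key, the helper name, and the "lm_head" output-layer name are assumptions about a particular model, not something specified by the paper:

```python
import torch.nn as nn

def build_param_groups(model, weight_decay, output_layer_name="lm_head"):
    # Treat every nn.Linear weight except the output head as "normalized",
    # following the paper's Transformer heuristic; everything else (biases,
    # embeddings, norm gains, the output head) gets standard weight decay.
    normalized_ids = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and output_layer_name not in name:
            normalized_ids.add(id(module.weight))

    normalized = [p for p in model.parameters() if id(p) in normalized_ids]
    other = [p for p in model.parameters() if id(p) not in normalized_ids]
    return [
        {"params": normalized, "weight_decay": weight_decay, "is_normalized": True},
        {"params": other, "weight_decay": weight_decay, "is_normalized": False},
    ]

# Example wiring with the AdamC sketch above (learning rates are illustrative):
# optimizer = AdamC(build_param_groups(model, weight_decay=0.1), lr=3e-4, max_lr=3e-4)
```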
Key Takeaways for Practitioners:
- The observed "gradient norm explosion" at the end of long training runs with learning rate schedules is likely an artifact of how weight decay interacts with normalization layers and the schedule, rather than a fundamental optimization issue.
- Applying the proposed corrected weight decay (e.g., AdamC) can lead to more stable training dynamics, lower final loss, and more stable weight norms, particularly in LLM pre-training.
- The correction scales the decoupled weight decay coefficient for normalized layers by γ_t / γ_max, i.e., λ̂_t = λ · (γ_t / γ_max); in the standard AdamW update form this makes the per-step decay term (λ γ_t^2 / γ_max) x_{t,l}, as in Algorithm 1.
- This relatively simple modification can be readily implemented in existing training pipelines.
The paper provides a clear explanation and a practical, theory-motivated solution to a common and perplexing issue in large-scale model training. It also offers insights into why AdamW generally outperforms Adam by linking AdamW's behavior to the balancing of layer-wise norm ratios.