SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training (2501.06842v2)

Published 12 Jan 2025 in cs.LG, cs.AI, and cs.CL

Abstract: LLMs have demonstrated exceptional performance across diverse tasks, yet their training remains highly resource-intensive and susceptible to critical challenges such as training instability. A predominant source of this instability is gradient and loss spikes, which disrupt the learning process, often leading to costly interventions like checkpoint recovery and experiment restarts, further amplifying inefficiencies. This paper presents a comprehensive investigation into gradient spikes observed during LLM training, revealing their prevalence across multiple architectures and datasets. Our analysis shows that these spikes can be up to $1000\times$ larger than typical gradients, substantially deteriorating model performance. To address this issue, we propose Spike-Aware Adam with Momentum Reset (SPAM), a novel optimizer designed to counteract gradient spikes through momentum reset and spike-aware gradient clipping. Extensive experiments, including both pre-training and fine-tuning, demonstrate that SPAM consistently surpasses Adam and its variants across various tasks, including (1) LLM pre-training from 60M to 1B parameters, (2) 4-bit LLM pre-training, (3) reinforcement learning, and (4) time series forecasting. Additionally, SPAM facilitates memory-efficient training by enabling sparse momentum, where only a subset of momentum terms is maintained and updated. When operating under memory constraints, SPAM outperforms state-of-the-art memory-efficient optimizers such as GaLore and Adam-Mini. Our work underscores the importance of mitigating gradient spikes in LLM training and introduces an effective optimization strategy that enhances both training stability and resource efficiency at scale. Code is available at https://github.com/TianjinYellow/SPAM-Optimizer.git
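The abstract names two mechanisms: momentum reset and spike-aware gradient clipping. The following is a minimal sketch of how they can be combined with Adam. It rests on stated assumptions and is not the authors' implementation, which lives in the linked repository. It treats a gradient entry as a spike when $g^2 > \theta v$ (with $v$ the current second-moment estimate), clips that entry to magnitude $\sqrt{\theta v}$ while keeping its sign, and zeroes both Adam moments every reset_interval steps. The threshold theta, the reset schedule, and the helper names init_state and spam_like_step are hypothetical. Any learning-rate warmup after a reset is omitted.

```python
import math


def init_state(n):
    """Fresh optimizer state for n scalar parameters (hypothetical layout)."""
    return {"step": 0, "since_reset": 0, "m": [0.0] * n, "v": [0.0] * n}


def spam_like_step(params, grads, state, lr=1e-3, beta1=0.9, beta2=0.999,
                   eps=1e-8, theta=5000.0, reset_interval=500):
    """Apply one simplified, SPAM-style Adam update to params in place."""
    state["step"] += 1
    # Momentum reset (assumed schedule): zero both moments every
    # reset_interval steps and restart the bias-correction counter.
    if state["step"] > 1 and (state["step"] - 1) % reset_interval == 0:
        state["m"] = [0.0] * len(params)
        state["v"] = [0.0] * len(params)
        state["since_reset"] = 0
    state["since_reset"] += 1
    k = state["since_reset"]  # steps since the last reset
    for i, g in enumerate(grads):
        v_prev = state["v"][i]
        # Spike-aware clipping (assumed rule): if g^2 > theta * v, the entry
        # is a spike and is clipped to magnitude sqrt(theta * v), sign kept.
        if v_prev > 0.0 and g * g > theta * v_prev:
            g = math.copysign(math.sqrt(theta * v_prev), g)
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g
        state["v"][i] = beta2 * v_prev + (1 - beta2) * g * g
        m_hat = state["m"][i] / (1 - beta1 ** k)  # bias-corrected moments
        v_hat = state["v"][i] / (1 - beta2 ** k)
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
```

The clipping compares each entry against its own running second moment rather than against a global norm. A single outlier coordinate is tamed without shrinking the rest of the update. The periodic reset limits how long any residual spike contribution can linger in $m$ and $v$.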

Summary

  • The paper presents SPAM, a modified Adam optimizer that resets momentum when encountering gradient spikes to improve LLM training stability.
  • It provides a detailed regret bound analysis showing how large gradients inflate adaptive estimates and loosen AMSGrad's regret bounds.
  • The study informs practitioners on tuning AMSGrad parameters in the presence of erratic gradients, enhancing real-world training performance.

Regret Bound Analysis for AMSGrad with Gradient Spikes

This paper analyzes how gradient spikes affect the regret bounds of AMSGrad, a well-known variant of the Adam optimizer. Building on the standard AMSGrad regret analysis, its contribution is a focus on gradient spikes, occasional gradients far larger than typical ones, and their influence on the bound. Such spikes are shown to enter the standard regret derivation directly, which suggests that AMSGrad can incur increased regret when they occur.

Theoretical Framework and Assumptions

The analysis is set in an online convex optimization problem: at each iteration, a decision maker selects a point in a compact convex domain and then incurs a convex loss. The subgradients of these losses are assumed to be bounded, and gradient spikes are considered within this constraint. AMSGrad is used rather than standard Adam because it has a convergence guarantee in the convex setting, which standard Adam lacks.
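To make the setting concrete, here is a minimal one-dimensional sketch. It assumes squared losses $f_t(x) = (x - z_t)^2$ on the interval $[-1, 1]$. The losses, targets, and hyperparameters are illustrative choices, not taken from the paper. The sketch runs projected AMSGrad with step size $\alpha/\sqrt{t}$ and no bias correction, as in the original AMSGrad formulation, and reports regret against the best fixed point in hindsight. A small eps guards against division by zero.

```python
import math


def amsgrad_online(targets, lo=-1.0, hi=1.0, alpha=0.1, beta1=0.9,
                   beta2=0.99, eps=1e-12):
    """Projected AMSGrad on f_t(x) = (x - z_t)^2 over [lo, hi].

    Returns the list of iterates x_t and the regret against the best
    fixed point in hindsight.
    """
    x = min(hi, max(lo, 0.0))          # start inside the domain
    m = v = v_hat = 0.0
    iterates, total_loss = [], 0.0
    for t, z in enumerate(targets, start=1):
        iterates.append(x)
        total_loss += (x - z) ** 2      # loss f_t(x_t), revealed after choosing x_t
        g = 2.0 * (x - z)               # gradient of f_t at x_t
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        v_hat = max(v_hat, v)           # AMSGrad's non-decreasing second moment
        step = (alpha / math.sqrt(t)) * m / (math.sqrt(v_hat) + eps)
        x = min(hi, max(lo, x - step))  # projection back onto [lo, hi]
    # For squared losses the best fixed point is the mean target, clipped to the domain.
    x_star = min(hi, max(lo, sum(targets) / len(targets)))
    best_loss = sum((x_star - z) ** 2 for z in targets)
    return iterates, total_loss - best_loss


iterates, regret = amsgrad_online([0.5] * 2000)
print(iterates[-1], regret / 2000)
```

Average regret $R_T / T$ shrinking toward zero is what a sublinear regret bound promises. Projection is a simple clip here only because the domain is an interval.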

Impact of Gradient Spikes on Regret Bounds

Gradient spikes are shown to have a tangible effect on the regret bounds of AMSGrad. Large gradients increase the adaptive second-moment estimate $\hat{v}_t$. Because AMSGrad keeps the running maximum of its second-moment estimates, this inflation persists after the spike has passed. The inflated $\hat{v}_t$ loosens the regret bound because both quantities that appear in it, $\hat{v}_{T,i}^{1/2}$ and $\sum_{t=1}^T g_{t,i}^2$, grow in the presence of large gradients.
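The persistence of this inflation can be seen by tracking $\hat{v}_t$ for a single coordinate. This sketch assumes illustrative values: a constant gradient of 1.0, one spike of 1000.0 at step 501, and $\beta_2 = 0.999$. Because AMSGrad takes a running maximum, $\hat{v}_t$ never comes back down after the spike.

```python
def amsgrad_vhat_trace(grads, beta2=0.999):
    """Return v_hat_t after each step for one coordinate's gradient sequence."""
    v, v_hat, trace = 0.0, 0.0, []
    for g in grads:
        v = beta2 * v + (1 - beta2) * g * g  # exponential moving average of g^2
        v_hat = max(v_hat, v)                # AMSGrad never lets v_hat decrease
        trace.append(v_hat)
    return trace


calm = amsgrad_vhat_trace([1.0] * 1000)
spiky = amsgrad_vhat_trace([1.0] * 500 + [1000.0] + [1.0] * 499)
print(calm[-1], spiky[-1])
```

With these values the final $\hat{v}_t$ is about 1000.4 in the spiky run versus about 0.63 in the calm one. The effective step on that coordinate scales as $1/\sqrt{\hat{v}_t}$, so it ends up roughly 40 times smaller for the rest of training.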

The derived expression for the regret $R_T$ is:

$$R_T \leq \frac{D^2 \sqrt{T}}{2 \alpha (1 - \beta_1)} \sum_{i=1}^d \hat{v}_{T,i}^{1/2} + \frac{(1 + \beta_1) \alpha \sqrt{1 + \log T}}{2 (1 - \beta_1) \sqrt{(1 - \beta_2)(1 - \gamma)}} \sum_{i=1}^d \sqrt{\sum_{t=1}^T g_{t,i}^2}$$
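The bound stated above can be evaluated numerically for a given gradient sequence. The sketch follows the standard AMSGrad analysis in assuming that $D$ bounds the diameter of the domain, that $\alpha$ is the base step size, and that $\gamma = \beta_1/\sqrt{\beta_2} < 1$. The function names are hypothetical. The gradients are passed in as a fixed T-by-d list rather than generated by running the optimizer.

```python
import math


def amsgrad_vhat_final(grads, beta2):
    """Per-coordinate v_hat_T from the AMSGrad recursion; grads is a T x d list."""
    d = len(grads[0])
    v = [0.0] * d
    v_hat = [0.0] * d
    for g in grads:
        for i in range(d):
            v[i] = beta2 * v[i] + (1 - beta2) * g[i] ** 2
            v_hat[i] = max(v_hat[i], v[i])  # running maximum, as in AMSGrad
    return v_hat


def regret_bound_terms(grads, D, alpha, beta1, beta2):
    """Return (first_term, second_term) of the regret bound for this sequence."""
    T, d = len(grads), len(grads[0])
    gamma = beta1 / math.sqrt(beta2)  # assumed definition of gamma
    if gamma >= 1:
        raise ValueError("the bound requires beta1 / sqrt(beta2) < 1")
    v_hat = amsgrad_vhat_final(grads, beta2)
    # First term: D^2 sqrt(T) / (2 alpha (1 - beta1)) * sum_i v_hat_{T,i}^{1/2}
    first = D ** 2 * math.sqrt(T) / (2 * alpha * (1 - beta1)) * sum(
        math.sqrt(x) for x in v_hat)
    # Second term: coefficient times sum_i sqrt(sum_t g_{t,i}^2)
    coeff = (1 + beta1) * alpha * math.sqrt(1 + math.log(T)) / (
        2 * (1 - beta1) * math.sqrt((1 - beta2) * (1 - gamma)))
    second = coeff * sum(
        math.sqrt(sum(g[i] ** 2 for g in grads)) for i in range(d))
    return first, second


if __name__ == "__main__":
    calm = [[1.0, 1.0] for _ in range(100)]
    spiky = [row[:] for row in calm]
    spiky[50][0] = 1000.0  # one spike on coordinate 0
    for name, seq in (("calm", calm), ("spiky", spiky)):
        first, second = regret_bound_terms(seq, D=1.0, alpha=0.1,
                                           beta1=0.9, beta2=0.999)
        print(f"{name}: first={first:.1f} second={second:.1f}")
```

Comparing the two runs shows both terms growing when a single coordinate receives one large gradient. This is the dual effect the analysis describes.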

Here, gradient spikes have a dual effect:

  1. First term: Large gradients inflate $\hat{v}_{T,i}^{1/2}$, increasing the first component of the bound.
  2. Second term: The sum of squared gradients $\sum_{t=1}^T g_{t,i}^2$ is inflated in the same way, increasing the second component.
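A short worked calculation shows how a single spike drives both terms. For illustration, assume that coordinate $i$ sees gradients of magnitude $g_0$ at every step except one step $t_0 \le T$, where the magnitude is $s \gg g_0$. Also assume that $v_{0,i} = 0$, so that $\hat{v}_{T,i} = \max_{t \le T} v_{t,i}$.

```latex
\begin{aligned}
\sqrt{\sum_{t=1}^T g_{t,i}^2} &= \sqrt{(T-1)\,g_0^2 + s^2} \;\ge\; s, \\
v_{t_0,i} &= \beta_2\, v_{t_0-1,i} + (1-\beta_2)\, s^2 \;\ge\; (1-\beta_2)\, s^2, \\
\hat{v}_{T,i}^{1/2} &= \max_{t \le T} v_{t,i}^{1/2} \;\ge\; \sqrt{1-\beta_2}\; s .
\end{aligned}
```

For example, take $T = 1000$, $g_0 = 1$, and $s = 1000$. The square root in the second term is then $\sqrt{999 + 10^6} \approx 1000.5$, compared with $\sqrt{1000} \approx 31.6$ without the spike. One step causes a roughly 32-fold increase.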

Methodological Limitations

The paper acknowledges limitations inherent in its analysis:

  • The convexity assumption makes the analysis tractable but limits its applicability to nonconvex settings such as LLM training.
  • The gradient bound assumption $\| g_t \|_\infty \leq G$ sits uneasily with gradient spikes: if $G$ is set from typical gradient magnitudes, spikes violate the assumption.
  • Conversely, choosing $G$ large enough to cover every spike makes the estimate overly conservative, which enlarges the constants in the regret bound and diminishes its practical relevance.

Practical and Theoretical Implications

For researchers and practitioners using AMSGrad in environments prone to gradient spikes, this analysis clarifies which conditions lead to larger regret bounds. Spikes inflate $\hat{v}_t$, which shrinks the effective per-coordinate step size (it scales as $1/\sqrt{\hat{v}_{t,i}}$), and they also inflate cumulative gradient quantities. Both effects deserve attention when configuring AMSGrad for real-world training with irregularly large gradients.
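Before tuning, it helps to know whether spikes occur at all. The following is a minimal detector. It assumes that a spike is any step whose absolute gradient exceeds a fixed multiple of the mean absolute gradient over the recorded history. The threshold of 50 and the function name are hypothetical choices for illustration.

```python
def find_gradient_spikes(history, threshold=50.0):
    """Return the step indices whose |g| exceeds threshold * mean(|g|).

    history holds per-step values of one gradient coordinate (or a gradient norm).
    """
    if not history:
        return []
    mean_abs = sum(abs(g) for g in history) / len(history)
    if mean_abs == 0.0:
        return []  # all-zero history: nothing to compare against
    return [t for t, g in enumerate(history) if abs(g) > threshold * mean_abs]


print(find_gradient_spikes([1.0] * 999 + [1000.0]))  # [999]
```

The spike itself is included in the mean, so a very large spike raises the baseline it is compared against. A median or a trailing window would be more robust, at the cost of extra bookkeeping.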

Future Directions

Future research could extend these findings to nonconvex optimization or investigate adaptive mechanisms within AMSGrad that counteract the loosening of regret bounds caused by gradient spikes. Techniques that anticipate or dynamically adjust for gradient spikes may improve the efficiency and effectiveness of AMSGrad in complex machine learning tasks.

In summary, the paper diligently explores the nuanced impact of gradient spikes on AMSGrad's regret bounds, offering a detailed yet specialized extension to the understanding of adaptive gradient methods in the presence of erratic gradient behavior.
