
Gradient Multi-Normalization for Stateless and Scalable LLM Training (2502.06742v1)

Published 10 Feb 2025 in cs.LG and cs.AI

Abstract: Training LLMs typically relies on adaptive optimizers like Adam (Kingma & Ba, 2015), which store additional state information to accelerate convergence but incur significant memory overhead. Recent efforts, such as SWAN (Ma et al., 2024), address this by eliminating the need for optimizer states while achieving performance comparable to Adam via a multi-step preprocessing procedure applied to instantaneous gradients. Motivated by the success of SWAN, we introduce a novel framework for designing stateless optimizers that normalizes stochastic gradients according to multiple norms. To achieve this, we propose a simple alternating scheme to enforce the normalization of gradients w.r.t. these norms. We show that our procedure can produce, up to an arbitrary precision, a fixed point of the problem, and that SWAN is a particular instance of our approach with carefully chosen norms, providing a deeper understanding of its design. However, SWAN's computationally expensive whitening/orthogonalization step limits its practicality for large LMs. Using our principled perspective, we develop a more efficient, scalable, and practical stateless optimizer. Our algorithm relaxes the properties of SWAN, significantly reducing its computational cost while retaining its memory efficiency, making it applicable to training large-scale models. Experiments on pre-training LLaMA models with up to 1 billion parameters demonstrate a 3X speedup over Adam with significantly reduced memory requirements, outperforming other memory-efficient baselines.

Authors (4)
  1. Meyer Scetbon (22 papers)
  2. Chao Ma (187 papers)
  3. Wenbo Gong (16 papers)
  4. Edward Meeds (15 papers)

Summary

The paper "Gradient Multi-Normalization for Stateless and Scalable LLM Training" (Scetbon et al., 10 Feb 2025 ) introduces a novel framework and a practical optimizer designed to mitigate the substantial memory overhead associated with adaptive optimizers like Adam during the training of LLMs. While adaptive methods typically accelerate convergence, their storage of first and second-moment estimates significantly increases memory requirements, posing a challenge for scaling. The work builds upon recent efforts like SWAN (Abu-Ajamieh et al., 19 Jan 2024 ), which demonstrated the feasibility of stateless optimizers achieving Adam-like performance through gradient preprocessing, but suffered from computational bottlenecks. This paper proposes a general framework for designing stateless optimizers based on simultaneous normalization of gradients with respect to multiple norms and introduces SinkGD, a computationally efficient and scalable optimizer derived from this framework.

Multi-Normalized Gradient Descent (MNGD) Framework

The core theoretical contribution is the Multi-Normalized Gradient Descent (MNGD) framework. Traditional interpretations often view optimizers as performing steepest descent under a specific, potentially adaptive, norm. MNGD generalizes this by seeking an update direction z that maximizes alignment with the current gradient (i.e., maximizing the inner product ⟨∇, z⟩) while simultaneously satisfying normalization constraints with respect to multiple norms g₁, ..., g_K. Formally, the optimization problem is:

argmax_z ⟨∇, z⟩   subject to   gᵢ(z) = 1,   i = 1, ..., K

Solving this multi-constraint optimization problem directly is generally NP-hard. Therefore, the paper proposes a practical iterative algorithm, termed MultiNorm (Algorithm 1), based on alternating projections. MultiNorm initializes with the gradient x ← ∇ and iteratively projects x onto the unit sphere defined by each norm gᵢ for i = 1 to K. This sequence of projections is repeated L times. The projection P_g(x) for a given norm g is the point on the unit sphere of g that maximizes the inner product with x. Under specific conditions (for K = 2 norms satisfying Assumption 3.3, including the critical assumption that the L2 norm of the projection P_{g₂}(x) remains constant across iterations), the paper proves (Theorem 3.6) that the MultiNorm algorithm converges to a fixed point, effectively achieving simultaneous normalization with respect to the chosen norms up to a desired precision.
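A minimal sketch of this alternating scheme, assuming each norm is supplied with its inner-product-maximizing projection onto its unit sphere. The function names are hypothetical, and the second norm here (a scaled Frobenius norm) is only an illustrative stand-in, not one of the norms analyzed in the paper:

import numpy as np

def project_rowwise_l2(x, eps=1e-8):
    # Projection for g1(W) = max_i ||W_{i,:}||_2 / sqrt(n): scale each row to L2 norm sqrt(n).
    n = x.shape[1]
    return np.sqrt(n) * x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)

def project_scaled_frobenius(x, eps=1e-8):
    # Illustrative norm g(W) = ||W||_F / sqrt(mn): rescale to Frobenius norm sqrt(mn).
    m, n = x.shape
    return np.sqrt(m * n) * x / np.maximum(np.linalg.norm(x), eps)

def multinorm(grad, projections, L=1):
    # Algorithm 1 (sketch): start from the gradient and cycle through the projections L times.
    x = grad.copy()
    for _ in range(L):
        for proj in projections:
            x = proj(x)
    return x

z = multinorm(np.random.default_rng(0).standard_normal((4, 6)),
              [project_rowwise_l2, project_scaled_frobenius], L=3)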

Relation to SWAN

The MNGD framework provides a new perspective on the SWAN optimizer. The paper demonstrates that SWAN's gradient preprocessing step corresponds precisely to a single iteration (L=1) of the MultiNorm algorithm using two specific norms applied to the gradient matrix (of size m x n):

  1. Row-wise L2-norm: g₁(W) = maxᵢ ||Wᵢ,:||₂ / √n
  2. Spectral norm: g₂(W) = ||W||₂

This interpretation elucidates SWAN's mechanism but also pinpoints its practical limitation: the projection onto the spectral norm unit sphere involves computing the matrix inverse square root, a step with O(m²(m+n)) complexity (via SVD or similar methods). This high computational cost hinders SWAN's applicability to the large matrices encountered in LLM layers.
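For illustration, a hedged sketch of what one such L = 1 iteration could look like with these two norms, consistent with the inverse matrix square root mentioned above; the function name and exact scaling conventions are assumptions, not the authors' reference implementation:

import numpy as np

def swan_like_preprocess(G, eps=1e-8):
    m, n = G.shape
    # Projection w.r.t. the row-wise L2 norm: scale each row to L2 norm sqrt(n).
    X = np.sqrt(n) * G / np.maximum(np.linalg.norm(G, axis=1, keepdims=True), eps)
    # Projection w.r.t. the spectral norm: whitening X <- (X X^T)^{-1/2} X,
    # which sets all singular values of X to 1. Forming the m x m Gram matrix
    # and its eigendecomposition is the O(m^2(m+n)) bottleneck.
    S = X @ X.T + eps * np.eye(m)
    evals, evecs = np.linalg.eigh(S)
    inv_sqrt = evecs @ np.diag(1.0 / np.sqrt(np.maximum(evals, eps))) @ evecs.T
    return inv_sqrt @ X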

SinkGD: A Scalable MNGD Instantiation

Leveraging the MNGD framework, the authors introduce Sinkhorn Gradient Descent (SinkGD), a stateless optimizer designed for scalability. SinkGD retains the row-wise L2-norm (g₁) from SWAN but replaces the computationally expensive spectral norm (g₂) with the significantly cheaper column-wise L2-norm:

g₂(W) = maxⱼ ||W:,ⱼ||₂ / √m

The normalization procedure within SinkGD, termed SR-Sinkhorn (Algorithm 3), implements the MultiNorm alternating projection scheme for these row and column L2 norms. Crucially, the paper establishes a direct connection between this alternating row/column L2 normalization process and the classical Sinkhorn algorithm. Specifically, the SR-Sinkhorn iterates, involving alternating multiplications by inverse diagonal matrices of row/column L2 norms, correspond exactly to the square-root iterates of the Sinkhorn algorithm applied to the matrix of squared gradient entries (∇⊙∇). This connection provides theoretical grounding, leveraging the known linear convergence properties of the Sinkhorn algorithm to guarantee efficient convergence of the normalization procedure.
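A small numerical check of this correspondence (a sketch under the assumption that the Sinkhorn targets are row sums equal to n and column sums equal to m, matching the √n and √m scalings of SR-Sinkhorn; the matrix size and iteration count are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
m, n, L = 8, 12, 5
G = rng.standard_normal((m, n))

# Alternating row/column L2 normalization (SR-Sinkhorn-style iterates).
X = G.copy()
for _ in range(L):
    X = np.sqrt(n) * X / np.linalg.norm(X, axis=1, keepdims=True)
    X = np.sqrt(m) * X / np.linalg.norm(X, axis=0, keepdims=True)

# Sinkhorn iterates on the squared entries: rows rescaled to sum to n, columns to m.
M = G ** 2
for _ in range(L):
    M = n * M / M.sum(axis=1, keepdims=True)
    M = m * M / M.sum(axis=0, keepdims=True)

# The SR-Sinkhorn iterate is the signed square root of the Sinkhorn iterate.
print(np.allclose(X, np.sign(G) * np.sqrt(M)))  # True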

The computational complexity of the SR-Sinkhorn normalization (performing L iterations of row and column normalizations) is O(Lmn), which matches the O(mn) complexity of standard matrix multiplications and element-wise operations in Adam updates for dense layers, assuming L is a small constant (empirically, L=5 is used). This represents a substantial improvement over SWAN's O(m²(m+n)) complexity. Furthermore, SinkGD maintains the memory efficiency of stateless methods, requiring only storage for the gradient itself, akin to SGD.
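A rough back-of-the-envelope comparison of the per-layer preprocessing cost for a single hypothetical 4096 x 11008 weight matrix (dimensions chosen only for illustration; constant factors are ignored):

m, n, L = 4096, 11008, 5
sinkgd_ops = L * m * n        # O(Lmn) row/column rescalings
swan_ops = m * m * (m + n)    # O(m^2 (m + n)) whitening/orthogonalization
print(f"{sinkgd_ops:.2e} vs {swan_ops:.2e}, ratio ~{swan_ops / sinkgd_ops:.0f}x")
# roughly 2.3e+08 vs 2.5e+11 operations, i.e. a ~1000x gap for this layer shape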

Implementation and Practical Considerations

Implementing SinkGD involves replacing the optimizer update step in a standard training loop. Instead of computing Adam's moments and applying the Adam update rule, one computes the gradient and then applies the SR-Sinkhorn normalization procedure before performing the weight update w ← w - η * z, where z is the normalized gradient and η is the learning rate.

The SR-Sinkhorn normalization (Algorithm 3) for a gradient matrix X (initialized to the gradient ∇) can be sketched as the following minimal NumPy routine:

import numpy as np

def sr_sinkhorn(X, L=5, eps=1e-8):
    """SR-Sinkhorn (Algorithm 3): alternating row/column L2 normalization."""
    # X is the m x n gradient matrix; L is the number of normalization sweeps.
    X = X.copy()
    m, n = X.shape
    for _ in range(L):
        # Row normalization: rescale each row to L2 norm sqrt(n). Dividing by the
        # row norms is equivalent to left-multiplying by the inverse diagonal matrix Q^{-1}.
        row_norms = np.sqrt((X ** 2).sum(axis=1, keepdims=True))
        X = np.sqrt(n) * X / np.maximum(row_norms, eps)  # eps floor avoids division by zero

        # Column normalization: rescale each column to L2 norm sqrt(m). Equivalent to
        # right-multiplying by the inverse diagonal matrix R^{-1} of column norms.
        col_norms = np.sqrt((X ** 2).sum(axis=0, keepdims=True))
        X = np.sqrt(m) * X / np.maximum(col_norms, eps)
    return X
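A minimal sketch of how the stateless update described earlier would use this routine, assuming the sr_sinkhorn function above and a hypothetical dictionary of 2-D NumPy weight matrices and their gradients; real training code would instead hook into the framework's backward pass:

def sinkgd_step(params, grads, lr, L=5):
    # One stateless update: normalize each gradient with SR-Sinkhorn, then
    # apply a plain SGD-style step w <- w - eta * z. No optimizer state is kept.
    for name, W in params.items():
        z = sr_sinkhorn(grads[name], L=L)
        params[name] = W - lr * z
    return params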

Key implementation details include:

  • Number of Iterations L: Experiments typically use L=5. This small constant keeps the normalization overhead low while, given the Sinkhorn algorithm's linear convergence, sufficing for the normalization to approximately converge.
  • Scaling Factors: The √n and √m factors are incorporated during row and column normalization, respectively. Their purpose is to ensure the final normalized gradient matrix z has a Frobenius norm of √(nm) (see the quick numerical check after this list). This scaling matches the expected magnitude of gradients processed by Adam (assuming variance normalization), allowing the use of similar learning rate schedules tuned for Adam without significant modification.
  • Numerical Stability: The row/column norms are floored at a small epsilon before division to prevent division by zero, a standard practice in normalization techniques.
  • Distributed Training: As a stateless optimizer, SinkGD is amenable to standard data-parallel distributed training setups without the complexities associated with managing distributed optimizer states.
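As a quick numerical check of the Frobenius-norm property mentioned in the scaling-factors item above (reusing the sr_sinkhorn sketch from earlier; the matrix size is arbitrary):

import numpy as np

# After the final column step every column has L2 norm sqrt(m), so the
# squared Frobenius norm is n * m exactly (up to the epsilon floor).
G = np.random.default_rng(1).standard_normal((64, 256))
Z = sr_sinkhorn(G, L=5)
print(np.linalg.norm(Z), np.sqrt(64 * 256))  # both ~128.0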

Experimental Results and Scalability

The paper presents empirical results primarily from pre-training LLaMA models up to 1.3 billion parameters. Key findings highlight SinkGD's effectiveness:

  • Memory Efficiency: The stateless design translates into significantly reduced memory usage compared to Adam. For a LLaMA-1.3B model, SinkGD required approximately 2.98 GB of optimizer memory, compared to 7.48 GB for Adam (Table 1).
  • Computational Throughput: Raw throughput (tokens/sec) of SinkGD was measured to be comparable to or slightly faster than Adam, indicating the O(mn) normalization overhead is negligible in practice (Figure 1c).
  • Convergence Speed: In terms of wall-clock time to reach a target perplexity, SinkGD demonstrated significantly faster convergence than Adam. The paper reports an effective throughput speedup of approximately 3x over Adam for the 1.3B model training (Table 3), suggesting better sample efficiency.
  • Performance: SinkGD achieved final perplexity scores on par with or slightly better than AdamW and outperformed other memory-efficient baselines like GaLore and Apollo across various model sizes (Figure 1a, Figure 2). Notably, a LLaMA-1B model trained with SinkGD reportedly achieved performance comparable to baseline 7B models, underscoring its efficiency.

These results suggest that SinkGD effectively balances memory efficiency (stateless), computational efficiency (O(mn) per update), and convergence performance, making it a viable alternative to Adam for large-scale LLM training.

Conclusion

The Gradient Multi-Normalization framework offers a principled approach to designing stateless optimizers by enforcing multiple normalization constraints simultaneously. SinkGD, derived from this framework by using computationally tractable row and column L2 norms and leveraging the Sinkhorn algorithm's convergence properties, emerges as a practical and scalable optimizer. Its demonstrated ability to match or exceed Adam's performance while drastically reducing memory overhead and improving effective throughput makes it a compelling method for training increasingly large models under memory and computational constraints.