A Stable Whitening Optimizer for Efficient Neural Network Training
(2506.07254v2)
Published 8 Jun 2025 in cs.LG
Abstract: In this work, we take an experimentally grounded look at neural network optimization. Building on the Shampoo family of algorithms, we identify and alleviate three key issues, resulting in the proposed SPlus method. First, we find that naive Shampoo is prone to divergence when matrix-inverses are cached for long periods. We introduce an alternate bounded update combining a historical eigenbasis with instantaneous normalization, resulting in across-the-board stability and significantly lower computational requirements. Second, we adapt a shape-aware scaling to enable learning rate transfer across network width. Third, we find that high learning rates result in large parameter noise, and propose a simple iterate-averaging scheme which unblocks faster learning. To properly confirm these findings, we introduce a pointed Transformer training benchmark, considering three objectives (language modelling, image classification, and diffusion modelling) across different stages of training. On average, SPlus is able to reach the validation performance of Adam within 44% of the gradient steps and 62% of the wallclock time.
Summary
The paper introduces SPlus, an optimizer that stabilizes training with instant-sign normalization, which prevents divergence and allows the expensive eigendecomposition to be computed far less often.
It applies a symmetric shape-aware scaling so that the optimal learning rate transfers across layer dimensions and network widths, simplifying hyperparameter tuning.
Finally, it uses iterate averaging to reduce parameter noise. Together, these changes let SPlus reach Adam's validation performance with only 44% of the gradient steps and 62% of the wall-clock time.
This paper introduces SPlus, a novel optimization algorithm for training neural networks, particularly Transformers, designed to need fewer gradient steps and less wall-clock time than existing methods such as Adam and SPlus's own predecessor, Shampoo. SPlus builds on the Shampoo family of optimizers, which approximate a whitening metric from historical gradients, and addresses three key shortcomings of naive Shampoo.
The core contributions and mechanisms of SPlus are:
Instant-Sign Normalization for Stability:
The paper identifies that naive Shampoo often diverges, especially at high learning rates or when the matrix inverses used for preconditioning are cached for extended periods (e.g., more than 25 steps). Shampoo's update can be written using eigendecompositions of the gradient covariance matrices $L = \mathbb{E}[G G^T]$ and $R = \mathbb{E}[G^T G]$:
$U_{\text{Shampoo}} = (Q_L \Lambda_L^{-1/2} Q_L^T)\, G\, (Q_R \Lambda_R^{-1/2} Q_R^T)$,
where $Q_L, Q_R$ contain the eigenvectors and $\Lambda_L, \Lambda_R$ the eigenvalues of $L$ and $R$. Divergence can occur when an incoming gradient aligns with an eigenbasis direction that had a small historical magnitude (a small eigenvalue $\lambda$): the factor $\lambda^{-1/2}$ is then large, so the update along that direction is amplified.
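To make this failure mode concrete, here is a small NumPy sketch (an illustration, not code from the paper or the SPlus repository). It builds $L$ and $R$ from synthetic historical gradients in which one row direction is almost never excited, caches their eigendecompositions, and then applies the Shampoo update above to a fresh gradient concentrated on that rare direction. The matrix sizes, the synthetic history, and the helper name `shampoo_update` are assumptions chosen for illustration.

```python
import numpy as np

def shampoo_update(G, QL, lamL, QR, lamR, eps=1e-12):
    """(Q_L Lam_L^{-1/2} Q_L^T) G (Q_R Lam_R^{-1/2} Q_R^T), eigenvectors as columns.

    eps only guards against division by an exactly-zero eigenvalue (an assumption).
    """
    left = QL @ np.diag(1.0 / np.sqrt(lamL + eps)) @ QL.T
    right = QR @ np.diag(1.0 / np.sqrt(lamR + eps)) @ QR.T
    return left @ G @ right

rng = np.random.default_rng(0)
m, n = 8, 4

# Synthetic gradient history in which row direction 0 is almost never excited.
history = []
for _ in range(200):
    H = rng.normal(size=(m, n))
    H[0, :] *= 1e-3
    history.append(H)

# Estimate L = E[G G^T], R = E[G^T G] and cache their eigendecompositions.
L = np.mean([H @ H.T for H in history], axis=0)
R = np.mean([H.T @ H for H in history], axis=0)
lamL, QL = np.linalg.eigh(L)
lamR, QR = np.linalg.eigh(R)

typical = rng.normal(size=(m, n))
typical[0, :] *= 1e-3                 # resembles the history
aligned = np.zeros((m, n))
aligned[0, :] = 1.0                   # concentrated on the rarely seen direction
aligned *= np.linalg.norm(typical) / np.linalg.norm(aligned)  # equal Frobenius norms

norms = {}
for name, G in [("typical", typical), ("aligned", aligned)]:
    norms[name] = np.linalg.norm(shampoo_update(G, QL, lamL, QR, lamR))
    print(f"{name}: ||G||_F = {np.linalg.norm(G):.2f}, ||U||_F = {norms[name]:.2f}")
```

Both gradients have the same Frobenius norm, but the one aligned with the small-eigenvalue direction produces a far larger update, which is the amplification described above.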
SPlus replaces this with "instant-sign normalization." It retains the historical eigenbasis ($Q_L$, $Q_R$) but ignores the historical magnitudes ($\Lambda_L$, $\Lambda_R$). Instead, it projects the gradient $G$ (or the momentum-averaged gradient $\bar{G}$) into the eigenbasis and normalizes it instantaneously with the elementwise sign function:
$U_{\text{SPlus}} = Q_L \operatorname{sign}(Q_L^T \bar{G} Q_R)\, Q_R^T$.
This update is inherently bounded: each entry of the sign matrix is $\pm 1$ and the orthogonal rotations preserve the Frobenius norm, so for an $m \times n$ layer $\|U\|_{\text{spectral}} \le \|U\|_F = \sqrt{mn}$, which prevents divergence. This stability allows SPlus to cache the eigenbasis ($Q_L$, $Q_R$) for much longer intervals (e.g., more than 100 steps), significantly reducing the overhead of frequent eigendecompositions and thus improving wall-clock time.
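The boundedness claim can be checked numerically. The NumPy sketch below applies instant-sign normalization with eigenvectors stored as columns (the convention of `np.linalg.eigh`) and shows that the Frobenius norm of the result stays at $\sqrt{mn}$ however small or large the input is. The helper name `splus_direction` and the random stand-in eigenbases are illustrative assumptions, not part of the SPlus code.

```python
import numpy as np

def splus_direction(G_bar, QL, QR):
    """Instant-sign update Q_L sign(Q_L^T G_bar Q_R) Q_R^T, eigenvectors as columns."""
    return QL @ np.sign(QL.T @ G_bar @ QR) @ QR.T

rng = np.random.default_rng(1)
m, n = 8, 4
# Stand-in eigenbases from random symmetric PSD matrices (any orthogonal basis works).
A = rng.normal(size=(m, m))
B = rng.normal(size=(n, n))
_, QL = np.linalg.eigh(A @ A.T)
_, QR = np.linalg.eigh(B @ B.T)

results = []
for scale in [1e-3, 1.0, 1e3]:
    G_bar = scale * rng.normal(size=(m, n))
    U = splus_direction(G_bar, QL, QR)
    spec = np.linalg.norm(U, ord=2)   # spectral norm: largest singular value
    fro = np.linalg.norm(U)           # Frobenius norm
    results.append((scale, spec, fro))
    print(f"scale={scale:g}: spectral={spec:.3f}, Frobenius={fro:.3f}, sqrt(mn)={np.sqrt(m * n):.3f}")
```

Unlike the Shampoo update, the magnitude no longer depends on how the gradient lines up with directions that had small historical eigenvalues, which is what makes long caching intervals safe.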
Symmetric Shape-Aware Scaling for Learning Rate Transfer:
To keep the optimal learning rate consistent across network widths, SPlus introduces a per-layer scaling factor. This makes hyperparameter tuning easier, because a learning rate tuned at one width can be reused at another. For a dense layer of shape $m \times n$, the update $U$ (from instant-sign normalization) is scaled by $2 / (m+n)$:
$U \leftarrow U \cdot \frac{2}{m+n}$.
This symmetric scaling (which treats $m$ and $n$ identically) was found empirically to outperform the "spectral" scaling (e.g., $1/m$) proposed in other works, especially in Transformer MLP blocks, while still enabling learning rate transfer.
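A short worked example of the scaling factor, in NumPy with hypothetical layer shapes. It prints the $1/m$ factor alongside for comparison; it only illustrates the arithmetic, not the empirical comparison reported in the paper. Because the unscaled sign update has Frobenius norm $\sqrt{mn}$, the scaled update has norm $2\sqrt{mn}/(m+n)$, which is at most 1 by the AM-GM inequality and equals 1 for square layers.

```python
import numpy as np

def shape_aware_scale(m: int, n: int) -> float:
    """Symmetric scaling factor 2 / (m + n) for an m x n dense layer."""
    return 2.0 / (m + n)

# Hypothetical layer shapes (e.g., square projections and MLP up/down layers).
shapes = [(256, 256), (768, 768), (768, 3072), (3072, 768)]
for m, n in shapes:
    s = shape_aware_scale(m, n)
    # Frobenius norm of the scaled sign update: s * sqrt(m * n) <= 1 by AM-GM.
    scaled_norm = s * np.sqrt(m * n)
    print(f"{m}x{n}: 2/(m+n)={s:.3e}  1/m={1.0 / m:.3e}  scaled Frobenius norm={scaled_norm:.3f}")
```

The symmetric factor is identical for a layer and its transpose (768x3072 and 3072x768 both give a scaled norm of 0.8), whereas $1/m$ differs by a factor of four between them.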
Iterate Averaging to Reduce Parameter Noise:
Whitening-based optimizers, because their updates are normalized, can suffer from parameter noise: the parameters oscillate around their optimal values, especially at high learning rates. Rather than lowering the learning rate, which would slow learning, SPlus mitigates this with iterate averaging. It maintains two sets of parameters:
"Live" parameters ($\theta'$): these are updated aggressively using the SPlus update rule and a potentially high learning rate. Gradients are computed at these parameters: $G = \nabla_{\theta'} \mathcal{L}(\theta', x)$.
"Slow" parameters ($\theta$): an exponential moving average (EMA) of the live parameters, $\theta \leftarrow \beta_3 \theta + (1 - \beta_3) \theta'$.
During training, updates are applied to $\theta'$, while the model used for evaluation (and typically for checkpoints) is $\theta$. This averaging smooths out the oscillations in $\theta'$ and reveals the underlying learning progress. The EMA decay rate $\beta_3$ (e.g., 0.999) is a new hyperparameter, but performance is reported to be relatively insensitive to it.
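A toy illustration of why averaging helps, assuming a simple quadratic objective and plain sign updates as a stand-in for a normalized optimizer step (both are assumptions; this is not the paper's experiment). With a fixed step size the live parameters end up bouncing between two points around the optimum, while their EMA, with $\beta_3$ weighting the running average, settles near the midpoint of that oscillation. A smaller $\beta_3 = 0.99$ is used only because the run is short.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1000
theta_star = rng.normal(size=dim)   # optimum of a toy quadratic (assumption)
alpha, beta3, steps = 0.05, 0.99, 2000

live = np.zeros(dim)   # theta': receives the normalized updates
slow = live.copy()     # theta: EMA of the live parameters, used for evaluation
for _ in range(steps):
    grad = live - theta_star                  # gradient of 0.5 * ||theta' - theta*||^2
    live = live - alpha * np.sign(grad)       # normalized (sign) step of fixed size
    slow = beta3 * slow + (1 - beta3) * live  # EMA; beta3 weights the running average

live_err = np.linalg.norm(live - theta_star)
slow_err = np.linalg.norm(slow - theta_star)
print(f"live error {live_err:.3f}, slow (EMA) error {slow_err:.3f}")
```

The learning rate is never lowered here; the smaller error comes purely from evaluating the averaged iterate.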
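Algorithm Outline (the sign-based update, shape-aware scaling, and live/slow parameter split are SPlus's changes to Shampoo):
For each dense layer with gradient $G \in \mathbb{R}^{m \times n}$:
Compute the gradient at the live parameters: $G = \nabla_{\theta'} \mathcal{L}(\theta', x)$.
Update the momentum $\bar{G}$ (rate $\beta_1$) and the covariance factors $L$, $R$ (EMA rate $\beta_2$) from $G$.
Every $N$ steps, recompute the eigendecompositions $Q_L, \Lambda_L \leftarrow \operatorname{eigh}(L)$ and $Q_R, \Lambda_R \leftarrow \operatorname{eigh}(R)$, and cache $Q_L$, $Q_R$ (SPlus does not use the eigenvalues).
Compute the SPlus preconditioned update:
$U_{\text{raw}} \leftarrow Q_L \operatorname{sign}(Q_L^T \bar{G} Q_R)\, Q_R^T$.
$U \leftarrow U_{\text{raw}} \cdot \frac{2}{m+n}$ (shape-aware scaling).
Update the live parameters: $\theta' \leftarrow \theta' - \alpha U$, where $\alpha$ is the learning rate.
Update the slow (EMA) parameters: $\theta \leftarrow \beta_3 \theta + (1 - \beta_3) \theta'$.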
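Putting the steps together, the following NumPy class is a minimal single-layer sketch of this outline. It is not the reference implementation from the SPlus repository: the EMA form of the momentum and covariance updates, the zero initialization of $L$ and $R$, the placeholder values `beta1=0.9` and `beta2=0.999`, and the class name `SPlusDenseSketch` are assumptions; weight decay and the separate handling of non-matrix parameters are omitted. The learning rate 0.2, $\beta_3 = 0.999$, and $N = 100$ defaults are the values quoted in this summary.

```python
import numpy as np

class SPlusDenseSketch:
    """Minimal single-layer sketch of the SPlus outline (not the reference code).

    Assumed: EMA-style momentum and covariance updates, zero-initialized L and R,
    eigenvectors stored as columns (np.linalg.eigh), no weight decay.
    """

    def __init__(self, theta0, lr=0.2, beta1=0.9, beta2=0.999, beta3=0.999, inv_every=100):
        m, n = theta0.shape
        self.live = np.array(theta0, dtype=float)  # theta': receives the updates
        self.slow = self.live.copy()               # theta: EMA copy used for evaluation
        self.momentum = np.zeros((m, n))           # G_bar
        self.L = np.zeros((m, m))                  # running estimate of E[G G^T]
        self.R = np.zeros((n, n))                  # running estimate of E[G^T G]
        self.QL, self.QR = np.eye(m), np.eye(n)    # cached eigenbases
        self.lr, self.beta1, self.beta2, self.beta3 = lr, beta1, beta2, beta3
        self.inv_every = inv_every                 # N: steps between eigendecompositions
        self.t = 0

    def step(self, G):
        """One update, given G evaluated at the live parameters; returns the scaled U."""
        m, n = G.shape
        self.momentum = self.beta1 * self.momentum + (1 - self.beta1) * G
        self.L = self.beta2 * self.L + (1 - self.beta2) * (G @ G.T)
        self.R = self.beta2 * self.R + (1 - self.beta2) * (G.T @ G)
        if self.t % self.inv_every == 0:
            # Refresh the cached eigenbases; the eigenvalues are discarded.
            _, self.QL = np.linalg.eigh(self.L)
            _, self.QR = np.linalg.eigh(self.R)
        # Instant-sign normalization in the cached eigenbasis.
        U = self.QL @ np.sign(self.QL.T @ self.momentum @ self.QR) @ self.QR.T
        U *= 2.0 / (m + n)                         # symmetric shape-aware scaling
        self.live -= self.lr * U                   # theta' <- theta' - alpha * U
        self.slow = self.beta3 * self.slow + (1 - self.beta3) * self.live
        self.t += 1
        return U

# Usage on a toy quadratic loss 0.5 * ||W - W_star||_F^2 (an illustrative assumption).
rng = np.random.default_rng(0)
W_star = rng.normal(size=(16, 8))
opt = SPlusDenseSketch(np.zeros((16, 8)), beta3=0.99)  # short run, so a faster EMA
init_err = np.linalg.norm(opt.live - W_star)
for _ in range(500):
    last_U = opt.step(opt.live - W_star)   # gradient evaluated at the live parameters
live_err = np.linalg.norm(opt.live - W_star)
slow_err = np.linalg.norm(opt.slow - W_star)
print(f"initial {init_err:.2f}, live {live_err:.2f}, slow {slow_err:.2f}")
```

Because only the eigenvectors are cached, a stale basis can at worst rotate the update direction; it cannot inflate its magnitude, which is why the sketch can refresh the basis only every `inv_every` steps.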
Experimental Evaluation:
SPlus was evaluated on Transformer architectures across three tasks: autoregressive language modeling (a Transformer LM on OpenWebText), latent diffusion modeling (DiT for ImageNet generation), and image classification (ViT on ImageNet). To assess robustness, training was started from different points: random initialization, and checkpoints after 10k and 50k steps of Adam pre-training.
Results:
SPlus, on average, reached the validation performance of Adam using approximately 44% of the gradient steps and 62% of the wall-clock time. It consistently outperformed the other optimizers tested (naive Shampoo, which often diverged, Schedule-Free Adam, Sophia, SOAP, PSGD, and Muon) across tasks and training stages. The stability gains were substantial: SPlus could recompute its eigenbasis only every 100 steps, whereas Shampoo needed to recompute every 10 steps to avoid divergence.
Practical Implementation and Usage:
Code: Single-file JAX and PyTorch implementations are provided at https://github.com/kvfrans/splus.
Application: SPlus updates are applied to 2D dense layers. Other parameters (e.g., LayerNorm scales, embeddings, output heads) receive a simpler update: the sign of the momentum, multiplied by a fixed constant such as 0.001.
Learning Rate:
A suggested conversion from a tuned Adam learning rate: splus_lr = adam_lr · network_width · 2.
A general starting point: splus_lr = 0.2.
Hyperparameters:
$\beta_1$ (momentum), $\beta_2$ (covariance EMA), and weight decay can often be kept at their Adam settings.
SPlus-specific hyperparameters, namely the iterate-averaging EMA rate ($\beta_3$, default 0.999), the inversion frequency ($N$, default 100), and the scaling constant for non-matrix parameters, are reported to be ones to which performance is relatively insensitive.
Parameter Handling: A key consideration is that gradients are computed using the "live" parameters (θ′), while evaluations and checkpoints typically use the "slow" EMA parameters (θ). Helper functions are provided in the implementations to manage this.
Memory: SPlus requires approximately 60% more memory per dense layer than Adam, for storing the live parameters, slow parameters, momentum, Kronecker factors ($L$, $R$), and cached eigenvectors ($Q_L$, $Q_R$).
Computational Cost: The main additional costs are:
Matrix multiplications for rotating gradients into and out of the eigenbasis (per step, generally negligible).
Eigendecomposition of the $L$ and $R$ matrices (every $N$ steps; more expensive, but amortized over those $N$ steps). In distributed settings, this step can be parallelized by sharding the matrices across devices.
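The learning-rate conversion and the non-matrix update described under "Application" and "Learning Rate" can be written as two small helpers. This NumPy sketch uses hypothetical function names; the width of 768 and Adam learning rate of 3e-4 are illustrative values, and multiplying the non-matrix sign update by the learning rate as well as by the 0.001 constant is an assumption about how the constant is applied.

```python
import numpy as np

def splus_lr_from_adam(adam_lr: float, network_width: int) -> float:
    """Hypothetical helper for the rule of thumb splus_lr = adam_lr * network_width * 2."""
    return adam_lr * network_width * 2

def non_matrix_update(momentum: np.ndarray, lr: float, scale: float = 1e-3) -> np.ndarray:
    """Hypothetical helper: sign-of-momentum step for parameters outside the 2D dense
    layers (e.g., LayerNorm scales). Scaling by lr as well is an assumption."""
    return -lr * scale * np.sign(momentum)

lr = splus_lr_from_adam(3e-4, 768)
print(lr)  # approximately 0.4608
delta = non_matrix_update(np.array([0.5, -2.0, 0.0]), lr)
print(delta)
```

Note that the magnitude of the non-matrix step depends only on the constant and the learning rate, not on the size of the momentum.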
Limitations:
Increased memory and computational cost compared to Adam.
The paper's experiments are focused on Transformer architectures. Performance on other architectures is an open question.
SPlus offers a promising approach to accelerating Transformer training: by improving optimizer stability, it permits higher learning rates while reducing the computational bottleneck of preconditioning.