Distributed Shampoo Optimizer
- Distributed Shampoo is an optimizer for deep learning that employs block-diagonal preconditioning with Kronecker approximations to efficiently mimic full-matrix AdaGrad.
- It leverages DTensor-based distributed data parallelism to shard preconditioner blocks across GPUs, reducing memory load and computational cost.
- Empirical evaluations on ResNet-50/ImageNet show that Distributed Shampoo improves convergence and accuracy with modest per-step overhead.
Distributed Shampoo is an optimizer for large-scale neural network training that belongs to the AdaGrad family. It employs a block-diagonal preconditioner with each block constructed as a coarse Kronecker product approximation of full-matrix AdaGrad, balancing the trade-offs between memory, computational cost, and statistical effectiveness. Distributed Shampoo leverages advanced PyTorch primitives, specifically the DTensor infrastructure and distributed data-parallelism, enabling efficient multi-GPU training with minimal per-step overhead compared to diagonal adaptive methods, while delivering improved convergence and accuracy, as demonstrated in empirical studies on the ImageNet benchmark using ResNet-50 (Shi et al., 2023).
1. Algorithmic Foundations and Kronecker Preconditioning
Distributed Shampoo positions itself between diagonal AdaGrad and full-matrix AdaGrad by constructing a block-diagonal preconditioner, with each block corresponding to a parameter tensor. For each block (parameter) $W \in \mathbb{R}^{m \times n}$ with gradient $G_t \in \mathbb{R}^{m \times n}$ at step $t$, let $g_t = \mathrm{vec}(G_t)$. The full-matrix AdaGrad preconditioner accumulates $A_t = \sum_{s=1}^{t} g_s g_s^{\top} \in \mathbb{R}^{mn \times mn}$, incurring $O(m^2 n^2)$ memory and $O(m^3 n^3)$ computation per update, which is prohibitive for large-scale models.
Shampoo approximates each $A_t$ by constructing Kronecker factors $L_t = \sum_{s=1}^{t} G_s G_s^{\top} \in \mathbb{R}^{m \times m}$ and $R_t = \sum_{s=1}^{t} G_s^{\top} G_s \in \mathbb{R}^{n \times n}$, and uses the approximation

$$\bar{A}_t = L_t^{1/2} \otimes R_t^{1/2} \approx A_t,$$

which allows the per-block update

$$W_{t+1} = W_t - \eta_t \, L_t^{-1/4}\, G_t\, R_t^{-1/4},$$

with the global update $w_{t+1} = w_t - \eta_t \bar{A}_t^{-1/2} g_t$, and $\bar{A}_t$ as the block-diagonal concatenation of all per-block factors $L_t^{1/2} \otimes R_t^{1/2}$. This structure leads to substantial reductions in resource usage while retaining much of the improved conditioning of full-matrix AdaGrad.
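To make the per-block recursion concrete, here is a minimal single-block sketch in plain PyTorch; the helper names (`matrix_inv_root`, `shampoo_step`) and the eigendecomposition-based root computation are illustrative assumptions, not the library's internals.

```python
import torch

def matrix_inv_root(A: torch.Tensor, root: int, eps: float = 1e-12) -> torch.Tensor:
    """Compute A^(-1/root) for a symmetric PSD matrix via eigendecomposition."""
    evals, evecs = torch.linalg.eigh(A)
    inv_root_evals = (evals.clamp(min=0.0) + eps).pow(-1.0 / root)
    return evecs @ torch.diag(inv_root_evals) @ evecs.T

def shampoo_step(W, G, L, R, lr):
    """One Shampoo update for a single m x n parameter block W with gradient G."""
    L += G @ G.T          # accumulate left Kronecker factor  (m x m)
    R += G.T @ G          # accumulate right Kronecker factor (n x n)
    direction = matrix_inv_root(L, 4) @ G @ matrix_inv_root(R, 4)
    W -= lr * direction
    return W, L, R
```

In the actual optimizer these factors are maintained per block and the inverse roots are cached and refreshed periodically (Section 3).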
A key enhancement is learning-rate grafting: Shampoo inherits the norm of a diagonal method's update (e.g., AdaGrad), rescaling the search direction for each block as

$$W_{t+1} = W_t - \eta_t \, \frac{\lVert D_t \rVert_F}{\lVert P_t \rVert_F} \, P_t, \qquad P_t = L_t^{-1/4} G_t R_t^{-1/4},$$

where $D_t$ denotes the grafted diagonal method's step for that block, facilitating reuse of established learning-rate schedules.
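The rescaling itself is a one-line operation; the following sketch grafts the norm of a plain diagonal AdaGrad step onto the Shampoo direction. The function and variable names are illustrative assumptions, not the library API.

```python
import torch

def grafted_direction(shampoo_dir: torch.Tensor,
                      grad: torch.Tensor,
                      adagrad_accum: torch.Tensor,
                      eps: float = 1e-10) -> torch.Tensor:
    """Rescale the Shampoo direction to carry the norm of a diagonal AdaGrad step."""
    adagrad_accum += grad * grad                     # diagonal AdaGrad accumulator
    diag_step = grad / (adagrad_accum.sqrt() + eps)  # the diagonal method's step
    scale = diag_step.norm() / shampoo_dir.norm().clamp(min=eps)
    return scale * shampoo_dir                       # Shampoo direction, diagonal norm
```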
2. Distributed Data Parallelism and DTensor Sharding
A naïve distributed implementation of Shampoo would replicate all Kronecker-factor state and perform the expensive matrix root inverses on every GPU, incurring substantial slowdowns (50–75% compared to diagonal methods). Distributed Shampoo instead exploits PyTorch's DTensor interface to shard these preconditioners.
The DTensor-based strategy partitions the set of block preconditioners across GPUs using a greedy load-balancing algorithm. Each GPU (or process group) manages only its allocated subset, reducing per-GPU optimizer-state memory roughly in proportion to the number of workers and localizing computation. Each GPU computes inverse roots for its blocks and applies them to the corresponding gradients, accumulating partial search directions.
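As an illustration of the partitioning idea, the sketch below implements a greedy load balancer that assigns blocks to workers in descending order of size; the cost model (parameter count per block) and the function name are assumptions, not the library's exact heuristic.

```python
import heapq

def assign_blocks(block_sizes: list[int], num_workers: int) -> list[list[int]]:
    """Return, for each worker, the list of preconditioner-block indices it owns."""
    # Min-heap of (current load, worker id); largest blocks are placed first.
    heap = [(0, w) for w in range(num_workers)]
    heapq.heapify(heap)
    assignment = [[] for _ in range(num_workers)]
    for idx in sorted(range(len(block_sizes)), key=lambda i: -block_sizes[i]):
        load, w = heapq.heappop(heap)
        assignment[w].append(idx)
        heapq.heappush(heap, (load + block_sizes[idx], w))
    return assignment
```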
Local search directions are packed into a single 1D int8 buffer on each GPU. The AllGather primitive then distributes the complete set of search directions to all participants, synchronizing updates efficiently.
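A minimal sketch of this communication step, assuming an initialized `torch.distributed` process group, equal-sized buffers per rank, and a float buffer rather than the quantized int8 buffer described above:

```python
import torch
import torch.distributed as dist

def gather_search_directions(local_dirs: list[torch.Tensor]) -> torch.Tensor:
    """AllGather flattened local search directions from every rank."""
    local_buf = torch.cat([d.reshape(-1) for d in local_dirs])
    world_size = dist.get_world_size()
    out = torch.empty(world_size * local_buf.numel(),
                      dtype=local_buf.dtype, device=local_buf.device)
    dist.all_gather_into_tensor(out, local_buf)  # one collective per optimizer step
    return out  # concatenation of every rank's search directions
```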
To further balance computation and communication, multi-group hierarchies can be created: preconditioner shards are replicated across fixed-size subgroups of workers, and AllGather operations become localized to these subgroups, minimizing congestion while maintaining synchronous parameter updates.
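A sketch of how such subgroups could be set up with standard `torch.distributed` process groups; the group size and layout (e.g., one subgroup per node) are assumptions, not the library's configuration.

```python
import torch.distributed as dist

def make_subgroups(group_size: int):
    """Create disjoint process groups of `group_size` ranks; return this rank's group."""
    world_size, rank = dist.get_world_size(), dist.get_rank()
    my_group = None
    for start in range(0, world_size, group_size):
        ranks = list(range(start, min(start + group_size, world_size)))
        g = dist.new_group(ranks=ranks)  # every rank must call new_group for all groups
        if rank in ranks:
            my_group = g
    return my_group  # pass as `group=` to the subgroup-local AllGather
```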
3. Performance Optimizations and Overhead Assessment
Distributed Shampoo incorporates several system-level optimizations:
- Periodic root-inverse computation: The most expensive computation, the matrix fourth-root inverses, is amortized by recomputing it only every `precondition_frequency` steps ("stale roots"); e.g., a frequency of 50 preserves accuracy for ResNet-50/ImageNet while incurring only 5–8% wall-clock overhead (see the sketch after this list).
- Dimension and block-size heuristics: Tensors with any dimension above a threshold (`max_preconditioner_dim`) are either blocked into smaller sub-blocks, handled with a diagonal preconditioner, or fall back to diagonal AdaGrad; a typical setting is `max_preconditioner_dim = 2048` with merging of small dimensions, as recommended in Section 6.
- Fused elementwise operations: PyTorch's `_foreach` kernels batch elementwise work such as the preconditioner-state updates and weight decay.
- Guarded eigendecomposition: a retry in double precision mitigates rare decomposition failures.
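The following sketch illustrates the stale-root amortization from the first bullet: Kronecker factors are accumulated every step, while cached inverse roots are refreshed only every `precondition_frequency` steps. The class and method names are illustrative, not the library's internals.

```python
class StaleRootCache:
    """Cache inverse-root matrices and refresh them only periodically."""

    def __init__(self, precondition_frequency: int = 50):
        self.freq = precondition_frequency
        self.cached_roots = {}  # block id -> (L^{-1/4}, R^{-1/4})

    def maybe_refresh(self, step, block_id, L, R, inv_root_fn):
        """Recompute the cached inverse roots only on refresh steps."""
        if block_id not in self.cached_roots or step % self.freq == 0:
            self.cached_roots[block_id] = (inv_root_fn(L, 4), inv_root_fn(R, 4))
        return self.cached_roots[block_id]
```

Between refreshes the update reuses slightly stale roots, which the cited ablations indicate does not measurably degrade ResNet-50/ImageNet accuracy.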
Measured benchmarks demonstrate that, on 8×V100 GPUs with a batch size of 128 per GPU, per-step overhead is 8–10% compared to SGD-Nesterov when the root inverses are recomputed frequently, dropping below 2% at larger `precondition_frequency` values without significant loss in final accuracy.
4. Resource Complexity: Memory and Computation
A detailed complexity analysis for a single $m \times n$ parameter block is as follows:

| Optimizer | Memory | Per-Step Compute |
|---|---|---|
| Full-matrix AdaGrad | $O(m^2 n^2)$ | $O(m^3 n^3)$ |
| Diagonal AdaGrad | $O(mn)$ | $O(mn)$ |
| Shampoo | $O(m^2 + n^2)$ | $O(m^3 + n^3)$ |
In aggregate, Shampoo's total memory requirement is roughly 4–7× the model parameters, far more feasible than full-matrix AdaGrad; its computational cost grows cubically in the block dimensions (for the matrix roots) but remains well below that of full-matrix alternatives.
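As a worked instance of the table, the arithmetic below counts preconditioner entries for a hypothetical 2048 × 1024 weight block (sizes chosen only for illustration):

```python
m, n = 2048, 1024
full_matrix = (m * n) ** 2   # full-matrix AdaGrad: (mn)^2   ~ 4.4e12 entries
diagonal    = m * n          # diagonal AdaGrad:    mn       ~ 2.1e6 entries
shampoo     = m * m + n * n  # Shampoo factors:     m^2+n^2  ~ 5.2e6 entries
print(f"full: {full_matrix:.2e}, diag: {diagonal:.2e}, shampoo: {shampoo:.2e}")
```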
5. Empirical Evaluation: ResNet-50/ImageNet Ablations
Comprehensive ablation experiments were conducted on ImageNet (1,000 classes) with ResNet-50 (25.5M parameters) on 8×V100 GPUs (batch size 128 per GPU, 1,024 total), using a cosine-decay learning-rate schedule with a 5-epoch warmup.
- Fixed 90-epoch budget: Shampoo achieves 77.44% Top-1 accuracy compared to 76.85% for SGD-Nesterov, with only +8% wall-clock overhead.
- “Equal-time” comparison: Shampoo at 60 epochs matches SGD-Nesterov’s 76.9% accuracy at 90 epochs, yielding a 1.35× time savings and requiring 1.5× fewer steps for convergence.
- Learning-rate sensitivity: Shampoo exhibits superior robustness and consistency in accuracy across a 10× range of base learning rates, outperforming SGD-Nesterov in both accuracy and variance.
A representative subset of epoch sweeps:
| Method / Epochs | 40 | 60 | 80 | 90 |
|---|---|---|---|---|
| SGD-Nesterov | 75.2% | 76.1% | 76.6% | 76.9% |
| Shampoo | 76.4% | 77.2% | 77.3% | 77.4% |
In this setup, Shampoo reaches the 90-epoch SGD-Nesterov accuracy after only 60 epochs.
6. Deployment Recommendations
The following guidelines are advised for scalable deployment of Distributed Shampoo; a configuration sketch combining them appears after the list:
- Utilize DTensor for state sharding across GPUs (`use_dtensor=True`).
- Set `max_preconditioner_dim=2048` to limit block size; apply blocking and merging as needed.
- Update matrix roots every 50 steps (`precondition_frequency=50`) for efficiency.
- Incorporate learning-rate grafting from SGD or Adam for schedule reuse.
- Enable decoupled weight decay (`use_decoupled_weight_decay=True`) and bias correction for regularization.
- Combine with momentum/Nesterov acceleration (`momentum=0.9`, `use_nesterov=True`).
- Use fused elementwise operations (`_foreach`) and guarded eigendecomposition for kernel reliability.
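A hedged configuration sketch combining these guidelines is shown below; the import path, the grafting-config class, and the placeholder model are assumptions about the open-source PyTorch implementation and may differ across versions, so consult the library's documentation before use.

```python
import torch
# Assumed import path for the open-source optimizer; may vary by release.
from distributed_shampoo import DistributedShampoo, SGDGraftingConfig

model = torch.nn.Linear(1024, 1024)  # placeholder model for illustration

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.5,                             # tune as for the grafted SGD schedule
    momentum=0.9,
    use_nesterov=True,
    weight_decay=1e-4,
    use_decoupled_weight_decay=True,
    max_preconditioner_dim=2048,        # block large tensors
    precondition_frequency=50,          # "stale" root-inverse updates
    grafting_config=SGDGraftingConfig(),  # inherit the SGD step-size scale
    # use_dtensor=True,                 # DTensor sharding flag cited in the text above
)
```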
Using these practices, Distributed Shampoo delivers improved convergence speed and variance control over diagonal adaptive methods, with a per-step wall-clock cost in the single-digit percent range, validating its practicality for production-scale distributed neural network training (Shi et al., 2023).