MARS-AdamW Optimizer for Scalable LLM Training
- MARS-AdamW is an optimizer that unifies AdamW momentum preconditioning with STORM-inspired variance reduction to improve convergence in large-scale neural network training.
- The algorithm employs gradient clipping and stabilization techniques to manage variance and ensure robust training across different model scales.
- Empirical evaluations on GPT-2 pretraining show that MARS-AdamW reaches target validation losses with substantially fewer training tokens and less wall-clock time than vanilla AdamW, while achieving comparable or better final validation losses.
MARS-AdamW is an optimizer instance derived from the MARS (Make vAriance Reduction Shine) framework, which integrates preconditioned adaptive gradient methods with scalable stochastic variance reduction to improve the efficiency and convergence of large model training. Specifically, MARS-AdamW unifies AdamW-style momentum preconditioning with STORM-inspired variance-reduced gradient estimation and implements explicit stabilization via gradient clipping. Developed to address the observed gap between theoretical advances in variance reduction and their practical adoption in large-scale neural network training, MARS-AdamW demonstrates significant improvements in both token and time efficiency over vanilla AdamW for tasks such as LLM pretraining (Yuan et al., 2024).
1. Mathematical Derivation and Update Steps
MARS-AdamW builds on a variance-reduced preconditioned gradient framework. The core mechanism involves mixing the standard stochastic gradient with a STORM-style recursive momentum correction and then applying AdamW's adaptive moment preconditioning:
- Variance-Reduced Gradient Estimator:

$$c_t = \nabla f(x_t, \xi_t) + \gamma_t \cdot \frac{\beta_1}{1-\beta_1}\left[\nabla f(x_t, \xi_t) - \nabla f(x_{t-1}, \xi_t)\right]$$

Here, $x_t$ is the parameter vector at iteration $t$, $\xi_t$ denotes the stochastic mini-batch, and $\gamma_t$ modulates the variance reduction intensity (recovering vanilla AdamW at $\gamma_t = 0$ and full STORM at $\gamma_t = 1$).
- Gradient Clipping:

$$\tilde{c}_t = \begin{cases} c_t / \|c_t\|_2 & \text{if } \|c_t\|_2 > 1, \\ c_t & \text{otherwise.} \end{cases}$$
- Moment Estimation:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,\tilde{c}_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,\tilde{c}_t \odot \tilde{c}_t$$

Bias corrections are then applied:

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$
- Preconditioned Parameter Update (with decoupled weight decay):

$$x_{t+1} = x_t - \eta_t \left[\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda x_t\right]$$
For $\gamma_t = 0$, the procedure exactly recovers AdamW. When $\gamma_t = 1$ and $\lambda = 0$, it recovers Adam combined with the full STORM recursion.
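As a quick check of the first limiting case, substituting $\gamma_t = 0$ removes the correction term, so $c_t = g_t = \nabla f(x_t, \xi_t)$ and the remaining recursion is the standard AdamW update applied to the (clipped) stochastic gradient:

```latex
% Worked substitution of \gamma_t = 0 into the MARS-AdamW recursion:
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t \odot g_t, \\
x_{t+1} &= x_t - \eta_t \left[ \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda x_t \right],
\end{aligned}
```

which is AdamW with decoupled weight decay, as stated above.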
2. Algorithmic Workflow and Pseudocode
MARS-AdamW's step-by-step execution is summarized below:
```
Input: x₀ ∈ ℝᵈ, step sizes ηₜ, weight decay λ ≥ 0, Adam β₁, β₂, VR scale γₜ, clip threshold 1, ε > 0
Initialize: m₀ ← 0, v₀ ← 0, x₁ ← x₀
For t = 1, ..., T:
  1. Draw mini-batch ξₜ
  2. Compute gₜ = ∇f(xₜ, ξₜ)
  3. If t > 1: δₜ = (β₁/(1−β₁)) · [gₜ − ∇f(xₜ₋₁, ξₜ)]
     Else:     δₜ = 0
  4. cₜ = gₜ + γₜ · δₜ
  5. If ∥cₜ∥₂ > 1: c̃ₜ = cₜ / ∥cₜ∥₂
     Else:         c̃ₜ = cₜ
  6. mₜ = β₁·mₜ₋₁ + (1−β₁)·c̃ₜ
  7. vₜ = β₂·vₜ₋₁ + (1−β₂)·(c̃ₜ ⊙ c̃ₜ)
  8. m̂ₜ = mₜ / (1−β₁ᵗ),  v̂ₜ = vₜ / (1−β₂ᵗ)
  9. xₜ₊₁ = xₜ − ηₜ·[ m̂ₜ / (√v̂ₜ + ε) + λ·xₜ ]
EndFor
```
Key distinguishing feature: steps 3–5 integrate STORM variance reduction into AdamW's preconditioner; the remaining steps are the standard AdamW recursion applied to the clipped, variance-reduced estimate.
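For concreteness, the sketch below transcribes the pseudocode into a single-tensor PyTorch-style update. It is a minimal illustration rather than the reference implementation: the function name `mars_adamw_step`, the `state` dictionary, and the `eps` default are our own conventions, and `grad_at_prev_params` must be supplied by the caller (in the exact variant, $\nabla f(x_{t-1}, \xi_t)$, which requires a second backward pass).

```python
import torch

@torch.no_grad()
def mars_adamw_step(x, grad, grad_at_prev_params, state, *,
                    lr, beta1, beta2, gamma, weight_decay, eps=1e-8):
    """One MARS-AdamW update for a single parameter tensor (illustrative sketch).

    grad                : ∇f(x_t, ξ_t)
    grad_at_prev_params : ∇f(x_{t-1}, ·); the exact variant evaluates it on the
                          current mini-batch ξ_t, costing a second backward pass.
    """
    t = state["step"] = state.get("step", 0) + 1

    # Steps 3-4: STORM-style variance-reduced gradient estimator c_t.
    if t > 1:
        delta = (beta1 / (1.0 - beta1)) * (grad - grad_at_prev_params)
    else:
        delta = torch.zeros_like(grad)
    c = grad + gamma * delta

    # Step 5: clip the corrected gradient to unit l2 norm.
    c_norm = c.norm(2)
    if c_norm > 1.0:
        c = c / c_norm

    # Steps 6-7: exponential moving averages of first and second moments.
    m = state.setdefault("m", torch.zeros_like(x))
    v = state.setdefault("v", torch.zeros_like(x))
    m.mul_(beta1).add_(c, alpha=1.0 - beta1)
    v.mul_(beta2).addcmul_(c, c, value=1.0 - beta2)

    # Step 8: bias correction.
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)

    # Step 9: preconditioned update with decoupled weight decay.
    x.add_(m_hat / (v_hat.sqrt() + eps) + weight_decay * x, alpha=-lr)
    return x
```

In practice this would be wrapped in a `torch.optim.Optimizer` that iterates over parameter groups and keeps per-parameter state; the sketch retains only the arithmetic of steps 1–9.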
3. Hyperparameters and Selection Properties
MARS-AdamW introduces several hyperparameters, some inherited from AdamW and others governing the VR component. The major hyperparameters and empirical guidelines are:
| Hyperparameter | Role in Optimization | Recommended Setting / Range |
|---|---|---|
| Learning rate $\eta_t$ | Step-size schedule | Cosine decay with linear warmup; peak 6e-4 (GPT-2 small), 3e-4 (medium), 2e-4 (large) |
| $\beta_1$ | Moment-1 EMA | Best-performing value reported in Yuan et al. (2024) |
| $\beta_2$ | Moment-2 EMA | Default $0.99$ |
| $\gamma_t$ | VR scale | Constant; $0.025$ is robust |
| $\epsilon$ | Numerical stability | Small positive constant |
| $\lambda$ | Decoupled weight decay | $0.1$–$0.5$ (model-dependent) |
| Clipping threshold | Gradient stabilization | $\ell_2$-norm threshold of $1$ |
Batch size (480) and warm-up steps (2k) follow large-model conventions. A constant VR scale is favored over schedule-based variants.
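Collecting these guidelines into a configuration sketch for GPT-2 small (125M) pretraining, written as a plain Python dictionary; the key names are illustrative, and the `beta1` and `eps` values are assumptions not specified above:

```python
# Illustrative MARS-AdamW hyperparameters for GPT-2 small (125M) pretraining.
# Entries marked "assumed" are not given in the table above.
mars_adamw_config = {
    "peak_lr": 6e-4,        # cosine decay with linear warmup to this peak
    "warmup_steps": 2_000,  # warm-up convention noted above
    "batch_size": 480,      # batch-size convention noted above
    "beta1": 0.95,          # assumed; see Yuan et al. (2024) for the tuned value
    "beta2": 0.99,          # default moment-2 EMA
    "gamma": 0.025,         # constant VR scale
    "weight_decay": 0.1,    # lower end of the 0.1-0.5 model-dependent range
    "clip_threshold": 1.0,  # l2-norm clipping of the corrected gradient
    "eps": 1e-8,            # assumed numerical-stability constant
}
```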
4. Theoretical Properties and Convergence
Under standard assumptions that $f$ is $L$-smooth, stochastic gradients are unbiased with bounded variance, and preconditioners are positive definite, MARS-AdamW inherits and extends convergence properties of its constituent algorithms:
- Incremental first-order oracle complexity to find an $\epsilon$-approximate stationary point ($\mathbb{E}\|\nabla f(x)\| \le \epsilon$) is $O(\epsilon^{-3})$.
- For $\gamma_t = 1$, the method recovers the nearly-optimal STORM convergence rate for non-convex smooth optimization, as formalized in Arjevani et al. (2023).
- Full formal proof of convergence for the preconditioned variant is indicated as a direction for future work; empirically, no divergence or instability was observed in large-scale runs (Yuan et al., 2024).
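For reference, the stationarity criterion behind these complexity statements can be written explicitly; this is the standard STORM-type formulation for non-convex stochastic optimization, stated under the assumptions above rather than quoted from the MARS analysis, with constants and logarithmic factors suppressed:

```latex
% STORM-type guarantee (up to logarithmic factors and problem-dependent constants):
\frac{1}{T}\sum_{t=1}^{T} \mathbb{E}\bigl[\lVert \nabla f(x_t) \rVert_2\bigr]
  \;=\; \widetilde{O}\!\left(T^{-1/3}\right)
\quad\Longleftrightarrow\quad
\mathbb{E}\bigl[\lVert \nabla f(x_{\text{out}}) \rVert_2\bigr] \le \epsilon
\ \text{after}\ O\!\left(\epsilon^{-3}\right)\ \text{gradient evaluations,}
```

where $x_{\text{out}}$ is an iterate selected (e.g., uniformly at random) from the trajectory.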
5. Comparative Empirical Evaluation on GPT-2 Pretraining
Performance was assessed on GPT-2 models of varying scales (small: 125M, medium: 355M, large: 770M) using the OpenWebText dataset.
- Token Efficiency: On GPT-2 large, MARS-AdamW reached validation loss 2.58 after 27B tokens, whereas AdamW required 50B tokens. Final validation losses: 2.53 (MARS-AdamW) vs. 2.56 (AdamW).
- Wall-Clock Efficiency: The per-iteration cost of MARS-AdamW is approximately 5–10% higher than AdamW (due to the VR correction), but overall wall-clock time to achieve a given loss is reduced by 50–60%.
- Ablation Studies: Little difference was observed between the exact and approximate VR corrections (MARS vs. MARS-AP), suggesting MARS-AP is preferable when computational cost is a concern. A constant $\gamma_t$ outperformed linear scheduling schemes in terms of final validation loss, and MARS-AdamW consistently exceeded MARS-Lion under matched conditions.
| Metric | AdamW | MARS-AdamW |
|---|---|---|
| Tokens to val loss 2.58 (GPT-2 large) | $50$B | $27$B |
| Final validation loss (GPT-2 large) | $2.56$ | $2.53$ |
| Wall-clock time to a given loss | Baseline | Reduced by $\approx$50–60% |
| Hellaswag 0-shot acc. (GPT-2 medium/large, 50B tokens) | Lower | Higher |
- Downstream Transfer: For zero-shot transfer on Hellaswag (after 50B tokens), MARS-AdamW outperformed AdamW for both medium and large GPT-2 models.
6. Practical Considerations and Implementation
MARS-AdamW requires two gradient evaluations per iteration in the default setting (for the VR correction), though the approximate variant (MARS-AP) mitigates this overhead with minimal loss in quality. The optimizer introduces a modest (5–10%) per-iteration computational overhead relative to AdamW, but the net speedup from reduced training steps is substantial for large models. Stability is ensured by $\ell_2$-norm gradient clipping and careful hyperparameter tuning. Empirically, the method demonstrated robust convergence and efficiency across all tested LLM-scale configurations (Yuan et al., 2024).
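To make the single-backward-pass variant concrete, the fragment below sketches one plausible reading of MARS-AP, in which $\nabla f(x_{t-1}, \xi_t)$ is approximated by the gradient already computed at the previous step, $\nabla f(x_{t-1}, \xi_{t-1})$. This reading, the helpers `compute_loss`, `lr_schedule`, `data_loader`, `model`, and `param`, and the reuse of the `mars_adamw_step` sketch above are illustrative assumptions, not a description of the reference implementation.

```python
# Hypothetical training-loop fragment for the approximate variant (MARS-AP).
# Assumption: the correction reuses last step's stored gradient
# ∇f(x_{t-1}, ξ_{t-1}) in place of ∇f(x_{t-1}, ξ_t), so each iteration
# needs only one backward pass.
grad_prev = None
state = {}
for step, batch in enumerate(data_loader, start=1):
    loss = compute_loss(model, batch)                  # hypothetical helper
    grad = torch.autograd.grad(loss, [param])[0]       # single backward pass

    mars_adamw_step(
        param, grad,
        grad_at_prev_params=grad if grad_prev is None else grad_prev,
        state=state,
        lr=lr_schedule(step),                          # cosine decay with warmup
        beta1=0.95,                                    # assumed value (see note above)
        beta2=0.99, gamma=0.025, weight_decay=0.1,
    )
    grad_prev = grad.detach().clone()                  # reused at the next step
```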