Momentum-Aligned Gradient Masking (Magma)
- The paper introduces Magma as a lightweight wrapper for adaptive optimizers that uses stochastic masking modulated by momentum–gradient alignment to enforce curvature-dependent regularization.
- Magma is an optimization method that applies Bernoulli masking and cosine-similarity scaling to parameter updates, selectively damping steps in the high-curvature regions typical of transformers.
- Empirical evaluations demonstrate that Magma improves perplexity and convergence for large language models while conferring enhanced learning-rate robustness compared to standard optimizers.
Momentum-aligned Gradient Masking (Magma) is an optimization technique developed for training LLMs. It operates as a lightweight wrapper around existing dense adaptive optimizers. Magma stochastically sparsifies parameter updates and modulates the surviving updates by momentum–gradient alignment. This achieves geometric regularization in the high-curvature, heterogeneous loss landscapes typical of transformers, leading to more stable and better-performing pre-training. The method was first proposed and analyzed in "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers" (Joo et al., 17 Feb 2026).
1. Algorithmic Structure and Implementation
Magma operates as a block-wise masking-and-scaling transformation inside any dense adaptive optimizer (e.g., Adam, RMSProp, Muon). Parameters are partitioned into blocks \(b = 1, \dots, B\). At each optimization step \(t\), a mask \(m_t^{(b)} \sim \mathrm{Bernoulli}(p)\) is sampled for every block, where \(p\) is the survival probability. The key innovation is not simply to mask updates. Instead, each surviving update is scaled by an alignment score \(s_t^{(b)}\), derived from the cosine similarity between the running momentum estimate \(\mu_t^{(b)}\) and the instantaneous gradient \(g_t^{(b)}\).
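Let \(\Delta_t^{(b)}\) denote the block-wise update proposed by the base optimizer. For RMSProp/Adam it commonly has the form
\[
\Delta_t^{(b)} = -\eta \, \frac{u_t^{(b)}}{\sqrt{v_t^{(b)}} + \epsilon},
\]
where \(\eta\) is the learning rate, \(v_t^{(b)}\) is the second-moment estimate, and \(u_t^{(b)}\) is the gradient (RMSProp) or the first-moment estimate (Adam). The final masked and modulated update is
\[
\tilde{\Delta}_t^{(b)} = s_t^{(b)} \, m_t^{(b)} \, \Delta_t^{(b)},
\]
and it is applied as \(\theta_{t+1}^{(b)} = \theta_t^{(b)} + \tilde{\Delta}_t^{(b)}\). Momentum and the other optimizer state variables (e.g., second moments) are still updated densely, without masking.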
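The following Python sketch illustrates the masked and modulated update for a list of parameter blocks. It makes two assumptions. First, the base optimizer has already produced each block's proposed step, with the learning rate and sign included. Second, the smoothed alignment scores are already available. The helper name `magma_apply` is hypothetical. The base optimizer's moments are assumed to be updated densely elsewhere, before this call, so only the parameter step is masked.

```python
import random
from typing import List


def magma_apply(theta: List[List[float]], proposed: List[List[float]],
                scores: List[float], p: float, rng: random.Random) -> List[int]:
    """Hypothetical helper: theta_b += s_b * m_b * delta_b for every block b, in place.

    proposed[b] is the base optimizer's step for block b (learning rate and sign
    already included); scores[b] is the block's smoothed alignment score. The
    base optimizer's moments are assumed to have been updated densely beforehand.
    Returns the sampled block mask.
    """
    mask = []
    for params, delta, s_b in zip(theta, proposed, scores):
        m_b = 1 if rng.random() < p else 0  # Bernoulli(p) survival draw
        mask.append(m_b)
        if m_b:  # dropped blocks keep their current parameters
            for i, d in enumerate(delta):
                params[i] += s_b * d
    return mask
```

With `p=1.0` every block survives and the step reduces to the score times the proposed step. With `p=0.0` the parameters are left untouched.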
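The instantaneous alignment \(\tilde{s}_t^{(b)}\) is computed from the temperature-scaled cosine similarity
\[
\frac{1}{\tau} \, \frac{\langle \mu_t^{(b)}, g_t^{(b)} \rangle}{\lVert \mu_t^{(b)} \rVert \, \lVert g_t^{(b)} \rVert},
\]
where \(\tau\) is a temperature hyperparameter. This score is then smoothed with an exponential moving average,
\[
s_t^{(b)} = \beta \, s_{t-1}^{(b)} + (1 - \beta) \, \tilde{s}_t^{(b)},
\]
where \(\beta\) is the EMA decay.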
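A minimal sketch of the alignment computation follows. The text specifies only two things: the score is derived from the cosine similarity between momentum and gradient divided by a temperature, and it is then smoothed with an exponential moving average. Mapping the scaled similarity into (0, 1) with a logistic sigmoid is an assumption made here for illustration, as are the helper names.

```python
import math
from typing import Sequence


def cosine_similarity(u: Sequence[float], w: Sequence[float]) -> float:
    """Cosine similarity of two flattened blocks; 0.0 if either has zero norm."""
    dot = sum(a * b for a, b in zip(u, w))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_w = math.sqrt(sum(b * b for b in w))
    if norm_u == 0.0 or norm_w == 0.0:
        return 0.0
    return dot / (norm_u * norm_w)


def instantaneous_alignment(momentum: Sequence[float], grad: Sequence[float],
                            tau: float) -> float:
    """ASSUMPTION: squash cos(momentum, grad) / tau into (0, 1) with a stable sigmoid."""
    z = cosine_similarity(momentum, grad) / tau
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # avoids overflow in exp(-z) for large negative z
    return ez / (1.0 + ez)


def smooth_alignment(prev_s: float, s_tilde: float, beta: float) -> float:
    """EMA from the text: s_t = beta * s_{t-1} + (1 - beta) * s_tilde_t."""
    return beta * prev_s + (1.0 - beta) * s_tilde
```

Under this assumed map, a perfectly aligned block scores above 0.5 and an anti-aligned block scores below 0.5. A smaller temperature sharpens that contrast.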
2. Theoretical Properties and Regularization Dynamics
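Magma induces an explicit, curvature-dependent regularization through its stochastic masking. Consider a simplified base optimizer such as SGD, where \(\Delta_t^{(b)} = -\eta\, g_t^{(b)}\). A second-order expansion of the expected loss after the masked step \(\theta_{t+1}^{(b)} = \theta_t^{(b)} + m_t^{(b)} \Delta_t^{(b)}\) gives
\[
\mathbb{E}_{m}\big[\mathcal{L}(\theta_{t+1})\big] \approx \mathcal{L}(\theta_t + p\,\Delta_t) + \frac{p(1-p)}{2} \sum_{b} \Delta_t^{(b)\top} H_b \, \Delta_t^{(b)},
\]
where \(H_b\) is the \(b\)-th diagonal block of the Hessian \(\nabla^2 \mathcal{L}(\theta_t)\). Masking therefore adds a quadratic penalty proportional to the curvature along each block's update direction, selectively damping updates in high-curvature regions.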
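For a quadratic loss, the second-order expansion of the masked expected loss is exact:
\[
\mathbb{E}_m[\mathcal{L}(\theta + m \odot \Delta)] = \mathcal{L}(\theta + p\Delta) + \tfrac{p(1-p)}{2}\sum_b \Delta^{(b)\top} H_b \Delta^{(b)}.
\]
This makes the identity easy to check numerically. The sketch below enumerates all \(2^B\) block masks for a small quadratic \(\mathcal{L}(\theta) = \tfrac{1}{2}\theta^\top A \theta + c^\top \theta\), so that \(H = A\), and compares the exact expectation with the penalty form. The matrix, vectors, block split, and function names are arbitrary illustrative choices.

```python
import itertools
from typing import List

Matrix = List[List[float]]


def quad_loss(A: Matrix, c: List[float], theta: List[float]) -> float:
    """L(theta) = 0.5 * theta^T A theta + c^T theta, so the Hessian is A."""
    n = len(theta)
    quad = sum(theta[i] * A[i][j] * theta[j] for i in range(n) for j in range(n))
    return 0.5 * quad + sum(ci * ti for ci, ti in zip(c, theta))


def expected_masked_loss(A, c, theta, delta, blocks, p):
    """Exact E_m[L(theta + m * delta)] by enumerating every block mask."""
    total = 0.0
    for mask in itertools.product((0, 1), repeat=len(blocks)):
        prob = 1.0
        moved = list(theta)
        for m_b, idx in zip(mask, blocks):
            prob *= p if m_b else 1.0 - p
            for i in idx:
                moved[i] += m_b * delta[i]
        total += prob * quad_loss(A, c, moved)
    return total


def curvature_penalty_form(A, c, theta, delta, blocks, p):
    """L(theta + p * delta) + p(1 - p)/2 * sum_b delta_b^T H_b delta_b, with H = A."""
    mean_point = [t + p * d for t, d in zip(theta, delta)]
    penalty = sum(delta[i] * A[i][j] * delta[j]
                  for idx in blocks for i in idx for j in idx)
    return quad_loss(A, c, mean_point) + 0.5 * p * (1.0 - p) * penalty


# Illustrative problem: blocks {0, 1} and {2}, SGD step delta = -eta * grad.
A = [[4.0, 1.0, 0.0], [1.0, 3.0, 0.5], [0.0, 0.5, 0.2]]
c = [0.1, -0.2, 0.3]
theta = [1.0, -1.0, 2.0]
grad = [sum(A[i][j] * theta[j] for j in range(3)) + c[i] for i in range(3)]
delta = [-0.1 * g for g in grad]
blocks = [[0, 1], [2]]
```

Only the diagonal blocks \(H_b\) enter the penalty. Cross-block curvature is already accounted for by evaluating the loss at the mean step \(\theta + p\Delta\).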
The momentum–gradient alignment further modulates the surviving updates. Directions in which the current gradient agrees with the historical momentum are preserved, while poorly aligned updates are attenuated. Because the mask is sampled independently of the score, the expected update is \(\mathbb{E}_m[\tilde{\Delta}_t^{(b)}] = p\, s_t^{(b)} \Delta_t^{(b)}\). This is a damped rather than rescaled version of the base step, so the modulation introduces a bias. Empirically, this tradeoff yields markedly better optimization stability.
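The bias can be made concrete by averaging over the two mask outcomes. The short sketch below uses a hypothetical helper, `expected_masked_step`. It contrasts damping by an alignment score (here an arbitrary 0.8) with a \(1/p\) rescaling, which is what an unbiased variant such as SkipUpdate (discussed under limitations) requires. All numbers are arbitrary.

```python
from typing import List


def expected_masked_step(delta: List[float], scale: float, p: float) -> List[float]:
    """Exact E_m[scale * m * delta] for m ~ Bernoulli(p), drawn independently of scale."""
    outcomes = [(1, p), (0, 1.0 - p)]  # (mask value, probability)
    return [sum(prob * scale * m * d for m, prob in outcomes) for d in delta]


delta = [0.2, -0.4]  # arbitrary base-optimizer step for one block
p = 0.5
magma_like = expected_masked_step(delta, scale=0.8, p=p)     # p * s * delta: damped
skip_like = expected_masked_step(delta, scale=1.0 / p, p=p)  # equals delta: unbiased
```

Here `magma_like` is 0.4 times the base step, while `skip_like` recovers it exactly.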
A one-step descent bound for constant-step SGD with masking shows that Magma shrinks the effective block-wise smoothness constants while preserving a quantifiable fraction of the true descent rate. In the resulting guarantee, the progress term is scaled by a measure of average momentum–gradient alignment, and the residual term is set by a reduced noise floor. This supports stable convergence in highly nonconvex, ill-conditioned problems.
3. Empirical Evaluation and Benchmarking
Extensive pre-training experiments on transformer architectures substantiate the effectiveness of Magma. The key empirical axes are as follows:
- Llama 2 on the C4 dataset:
- Models: 60M, 130M, 350M, 1B parameters (GPT-style)
- At 1B parameters, Adam+Magma achieved a validation perplexity of 13.71 vs. 16.35 for Adam; RMSProp+Magma outperformed all baselines, converging stably where plain RMSProp diverged (perplexity 13.19 vs. 14.52 for Muon).
- MoE Transformer on OpenWebText:
- 124M parameters, 8 experts.
- Both Adam+Magma and Muon+Magma improved perplexity and training stability compared with their respective baselines and with C-Adam.
- Linear attention with synthetic noise:
- For heavy-tailed gradient noise, Magma consistently yielded lower losses and a reduced robust Hessian condition number than Adam.
- Quadratic benchmarking:
- Magma showed accelerated convergence and lower final loss than AdamW under heterogeneous curvature across blocks, yet conferred no improvement for homogeneous (CNN-like) landscapes.
4. Ablations, Sensitivity Analyses, and Practical Guidance
Ablation studies and sensitivity analyses identify practical instantiations and design choices:
- Masking location and granularity: Applying masking to both attention and MLP blocks gives the best performance (e.g., for 130M Llama, perplexity 21.65 vs. 22.64 for the baseline). Masking at the element, row, column, or block level yields comparable validation perplexity; block level is chosen for computational simplicity.
- Hyperparameters: The default survival probability \(p\), temperature \(\tau\), and EMA decay \(\beta\) are robust, and performance remains stable across a range of settings.
- Momentum updates: Dense (unmasked) momentum is necessary for stable convergence; momentum masking without additional damping destabilizes training.
- Learning-rate tolerance: Magma confers enhanced learning-rate robustness. Adam+Magma remains stable over a wider range of learning rates, whereas Adam and C-Adam require narrowly tuned learning rates.
- Computational overhead: Each step adds one Bernoulli draw, one cosine similarity, and one scalar EMA update per block. This is negligible compared with the cost of the forward and backward passes.
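Masking granularity can be illustrated for a single weight matrix. The hypothetical helper below samples a Bernoulli(p) survival mask at the element, row, column, or block level. It treats the whole matrix as one block; this is an assumption, since the text does not specify how blocks are formed. Block level needs only one random draw per matrix, which reflects the computational simplicity noted in the granularity ablation.

```python
import random
from typing import List


def sample_mask(rows: int, cols: int, p: float, level: str,
                rng: random.Random) -> List[List[int]]:
    """Hypothetical helper: Bernoulli(p) survival mask for a rows x cols matrix.

    level: "element" (one draw per entry), "row", "column", or "block"
    (a single draw shared by the whole matrix).
    """
    def draw() -> int:
        return 1 if rng.random() < p else 0

    if level == "element":
        return [[draw() for _ in range(cols)] for _ in range(rows)]
    if level == "row":
        return [[draw()] * cols for _ in range(rows)]
    if level == "column":
        col_mask = [draw() for _ in range(cols)]
        return [list(col_mask) for _ in range(rows)]
    if level == "block":
        m = draw()
        return [[m] * cols for _ in range(rows)]
    raise ValueError(f"unknown masking level: {level!r}")
```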
5. Limitations and Open Questions
Limitations of Magma are concentrated around its operational domain and theoretical underpinnings:
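- The momentum-based damping by \(s_t^{(b)}\) biases updates. This contrasts with the unbiased rescaling in SkipUpdate, a Magma variant in which the alignment score is replaced by the constant \(1/p\), so that \(\mathbb{E}_m[\tilde{\Delta}_t^{(b)}] = \Delta_t^{(b)}\). How to achieve stable but unbiased modulation remains an open question.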
- The benefits of Magma are largely confined to ill-conditioned, heterogeneous curvature settings such as large transformers. No improvement is observed for architectures with more homogeneous loss landscapes, e.g., ResNet-50 on CIFAR-10.
- The theoretical guarantees are presently rigorously established only for plain SGD; extending full convergence proofs to adaptive schemes (e.g., Adam, RMSProp) is an area for future work.
- Two directions remain open. One is integrating Magma with other sparse or low-rank optimization techniques, such as GaLore or subspace descent. The other is adapting the masking schedule or the survival probability \(p\) to the task or training phase.
6. Summary of Design Choices and Application Scope
Magma is implemented as a drop-in wrapper around modern adaptive optimizers. It requires no additional gradient evaluations and adds only one smoothed alignment score per block to the optimizer state. Its defaults yield robust improvements for transformer-based LLMs, especially where gradient noise is heavy-tailed and curvature is strongly heterogeneous across blocks. The method uses momentum-driven adaptation to shape the regularization induced by stochastic masking, prioritizing optimizer stability and suppressing steps along high-curvature directions. Its effectiveness in LLM pre-training, combined with negligible computational cost, suggests high practical utility for state-of-the-art generative modeling (Joo et al., 17 Feb 2026).
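The pieces can be combined into a compact sketch of a drop-in wrapper. The class below is a hypothetical Adam+Magma-style optimizer written in plain Python over lists of parameter blocks. It rests on several assumptions:

- Adam's first moment, taken after incorporating the current gradient, serves as the momentum estimate.
- A logistic sigmoid maps the temperature-scaled cosine similarity to a score.
- Smoothed scores start at 1.
- The hyperparameter values are illustrative, not the paper's defaults.

The toy objective, with per-block curvatures 100 and 1, is likewise invented for illustration.

```python
import math
import random
from typing import List, Sequence, Tuple


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function (ASSUMED alignment map)."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def _cos(u: Sequence[float], w: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, w))
    nu = math.sqrt(sum(a * a for a in u))
    nw = math.sqrt(sum(b * b for b in w))
    return dot / (nu * nw) if nu > 0.0 and nw > 0.0 else 0.0


class AdamMagmaSketch:
    """Hypothetical Adam + Magma-style optimizer over a list of parameter blocks."""

    def __init__(self, blocks: List[List[float]], lr: float = 1e-2,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
                 p: float = 0.5, tau: float = 1.0, score_beta: float = 0.9,
                 seed: int = 0) -> None:
        self.blocks = blocks  # parameter blocks, updated in place
        self.lr, self.eps = lr, eps
        self.b1, self.b2 = betas
        self.p, self.tau, self.score_beta = p, tau, score_beta
        self.rng = random.Random(seed)
        self.mu = [[0.0] * len(b) for b in blocks]  # first moment (momentum)
        self.v = [[0.0] * len(b) for b in blocks]   # second moment
        self.s = [1.0] * len(blocks)                # smoothed alignment (ASSUMED init)
        self.t = 0

    def step(self, grads: List[List[float]]) -> None:
        self.t += 1
        bc1 = 1.0 - self.b1 ** self.t  # Adam bias corrections
        bc2 = 1.0 - self.b2 ** self.t
        for b, (theta, g) in enumerate(zip(self.blocks, grads)):
            mu, v = self.mu[b], self.v[b]
            # Dense state update: runs for every block, masked or not.
            for i, gi in enumerate(g):
                mu[i] = self.b1 * mu[i] + (1.0 - self.b1) * gi
                v[i] = self.b2 * v[i] + (1.0 - self.b2) * gi * gi
            # Instantaneous alignment, then its EMA across steps.
            s_tilde = _sigmoid(_cos(mu, g) / self.tau)
            self.s[b] = self.score_beta * self.s[b] + (1.0 - self.score_beta) * s_tilde
            # Bernoulli(p) block mask: a dropped block keeps its parameters unchanged.
            if self.rng.random() >= self.p:
                continue
            for i in range(len(theta)):
                delta = -self.lr * (mu[i] / bc1) / (math.sqrt(v[i] / bc2) + self.eps)
                theta[i] += self.s[b] * delta


def quadratic_demo(steps: int = 3000):
    """Hypothetical toy objective f = sum_b c_b * ||x_b||^2 with curvatures (100, 1)."""
    curv = [100.0, 1.0]
    params = [[1.0, -1.0], [2.0, 0.5]]

    def loss() -> float:
        return sum(c * sum(x * x for x in blk) for c, blk in zip(curv, params))

    opt = AdamMagmaSketch(params)
    start = loss()
    for _ in range(steps):
        opt.step([[2.0 * c * x for x in blk] for c, blk in zip(curv, params)])
    return start, loss(), list(opt.s)
```

Calling `quadratic_demo()` returns the initial loss, the final loss, and the per-block smoothed scores. The moments are always updated densely, while only the parameter step is masked and scaled, mirroring the structure described in Section 1.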