Momentum-Aligned Gradient Masking (Magma)
- The paper presents Magma, an adaptive optimizer enhancement that uses momentum-gradient alignment to mask parameter updates, reducing validation perplexity by over 19% in LLM pre-training.
- Magma employs block-wise Bernoulli masking with rescaling to inject curvature-dependent regularization, steering optimization toward flatter regions and improved convergence.
- Empirical evaluations on LLaMA-style transformers confirm Magma’s efficiency, achieving faster convergence and lower error floors with minimal additional computational overhead.
Momentum-Aligned Gradient Masking (Magma) is an optimization scheme for large-scale neural network training, particularly designed for LLMs. Magma augments adaptive optimizers by randomly masking parameter updates in a manner guided by the alignment between momentum and instantaneous gradients. This method introduces implicit geometric regularization, enhances training stability and generalization, and maintains computational efficiency suitable for large-scale transformer architectures (Joo et al., 17 Feb 2026).
1. Motivation: Curvature-Dependent Regularization via Masked Updates
Conventional adaptive optimizers such as RMSProp and Adam employ dense preconditioning to mitigate curvature heterogeneity by adjusting learning rates per parameter or block. Magma’s foundation is the observation that randomly masking a subset of parameter updates at each step—while still updating momentum and second-moment statistics densely—can yield systematic improvements in optimizer behavior and model generalization. This masking can be formalized by partitioning parameters into blocks and, at each iteration $t$, sampling independent Bernoulli masks $z_b^t \sim \mathrm{Bernoulli}(p)$ for each block $b$. Masked blocks forgo updates, and surviving updates are rescaled for unbiasedness: $\tilde{u}_b^t = (z_b^t / p)\, u_b^t$, so that $\mathbb{E}[\tilde{u}_b^t] = u_b^t$.
A second-order Taylor analysis demonstrates that this random masking injects an additional penalty of order $\frac{\eta^2}{2}\,\frac{1-p}{p}\sum_b (u_b^t)^\top H_b\, u_b^t$ (with $H_b$ the block Hessian), explicitly regularizing updates in directions of high curvature. This steers optimization toward flatter regions, functioning analogously to sharpness-aware methods but derived from stochastic masking rather than explicit regularizers (Joo et al., 17 Feb 2026).
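The masking-with-rescaling construction can be checked numerically. The following NumPy sketch uses illustrative block counts, sizes, and keep probability (none of these values are from the paper) to show that the rescaled masked update is unbiased:

```python
import numpy as np

# Illustrative sketch of block-wise Bernoulli masking with 1/p rescaling.
# Block count (4), block size (16), and p are assumptions for the demo.
rng = np.random.default_rng(0)

p = 0.8                        # per-block keep probability
u = rng.normal(size=(4, 16))   # 4 blocks of updates, 16 params each

# One mask draw: each block is kept or dropped whole; survivors are
# rescaled by 1/p so that E[(z/p) * u] = u.
z = rng.binomial(1, p, size=(4, 1))
u_masked = (z / p) * u

# Averaging over many independent mask draws recovers u (unbiasedness).
trials = 20_000
avg = sum((rng.binomial(1, p, size=(4, 1)) / p) * u
          for _ in range(trials)) / trials
```

The same Monte Carlo setup also makes the variance term visible: the per-entry variance of the masked update is $u^2 (1-p)/p$, which is the quantity the Taylor analysis weights by the block Hessian.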
2. SkipUpdate and the Masked RMSProp Baseline
The precursor to Magma, termed “SkipUpdate,” applies masking uniformly across blocks, irrespective of block-specific signals. In the baseline masked RMSProp, block-wise statistics are computed as follows:
- First moment (momentum): $m_b^t = \beta_1\, m_b^{t-1} + (1-\beta_1)\, g_b^t$
- Second moment: $v_b^t = \beta_2\, v_b^{t-1} + (1-\beta_2)\, (g_b^t)^2$
- Preconditioned update: $u_b^t = m_b^t / (\sqrt{v_b^t} + \epsilon)$. After masking and rescaling as above, the first moment (expectation) of the update is preserved, while the added curvature regularization persists.
While effective, SkipUpdate does not differentiate between blocks according to their optimization state or noisiness, potentially discarding informative updates and retaining erratic ones.
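A single SkipUpdate-style step can be sketched as follows. This is a minimal NumPy rendering of the baseline described above; hyperparameter names and default values are standard RMSProp-with-momentum conventions, assumed here rather than taken from the paper:

```python
import numpy as np

def skipupdate_rmsprop_step(theta, g, m, v, lr=1e-3, beta1=0.9,
                            beta2=0.999, eps=1e-8, p=0.8, rng=None):
    """One SkipUpdate-style masked RMSProp step (illustrative sketch).

    theta, g, m, v are (B, d) arrays: B blocks of d parameters each.
    The moment statistics m and v are updated densely every step;
    only the parameter update is masked block-wise and rescaled by
    1/p for unbiasedness. Defaults are assumptions, not paper values.
    """
    rng = rng or np.random.default_rng()
    m = beta1 * m + (1 - beta1) * g            # dense first moment
    v = beta2 * v + (1 - beta2) * g**2         # dense second moment
    u = m / (np.sqrt(v) + eps)                 # preconditioned update
    z = rng.binomial(1, p, size=(theta.shape[0], 1))
    theta = theta - lr * (z / p) * u           # masked, rescaled update
    return theta, m, v
```

Note that a masked block still accumulates momentum and second-moment statistics; only the parameter write is skipped, which is what distinguishes this scheme from simply dropping gradients.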
3. Momentum-Gradient Alignment and Adaptive Masking
Magma extends SkipUpdate by modulating the masking probability based on the alignment between the current momentum estimate $m_b^t$ and the instantaneous gradient $g_b^t$ for each block. The alignment quality is measured by their cosine similarity:

$$a_b^t = \frac{\langle m_b^t,\, g_b^t \rangle}{\|m_b^t\|\,\|g_b^t\|}.$$

High (positive) values of $a_b^t$ indicate that current gradients reinforce accumulated momentum, signaling reliable descent directions; low or negative values suggest stochasticity or gradient oscillations.
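The block-wise cosine similarity is a one-line computation. A minimal sketch (array shapes and the zero-norm guard `eps` are assumptions for illustration):

```python
import numpy as np

def block_alignment(m, g, eps=1e-12):
    """Cosine similarity between momentum and gradient, per block.

    m, g: (B, d) arrays of block-wise momentum and gradient.
    Returns a length-B vector in [-1, 1]; eps guards zero norms.
    """
    num = (m * g).sum(axis=1)
    den = np.linalg.norm(m, axis=1) * np.linalg.norm(g, axis=1) + eps
    return num / den
```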
4. Magma Update Rule and Algorithm
For each block $b$ and iteration $t$, Magma computes an alignment-based score:
- Alignment score: $s_b^t = \sigma(a_b^t / \tau)$, with $\sigma$ the logistic sigmoid and $\tau$ a temperature hyperparameter.
- Exponential smoothing: $\bar{p}_b^t = \rho\, \bar{p}_b^{t-1} + (1-\rho)\, s_b^t$, with smoothing coefficient $\rho \in [0, 1)$.
A Bernoulli mask $z_b^t \sim \mathrm{Bernoulli}(\bar{p}_b^t)$ is then sampled. The update applied is:

$$\tilde{u}_b^t = \frac{z_b^t}{\bar{p}_b^t}\, u_b^t,$$

and the parameters are updated via $\theta_b^{t+1} = \theta_b^t - \eta\, \tilde{u}_b^t$.
This process preferentially transmits updates with strong momentum-gradient alignment, adaptively suppressing those likely to be noisy or unproductive. Magma’s computational burden is minimal, requiring only a few extra dot products and scalar operations per block per step (Joo et al., 17 Feb 2026).
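Putting the pieces together, one Magma step on top of RMSProp can be sketched as below. The hyperparameter values (`tau`, `rho`) and the probability floor `p_min` (which prevents huge `1/p` rescaling factors for poorly aligned blocks) are illustrative assumptions, not values specified by the paper:

```python
import numpy as np

def magma_rmsprop_step(theta, g, m, v, p_bar, lr=1e-3, beta1=0.9,
                       beta2=0.999, eps=1e-8, tau=0.1, rho=0.9,
                       p_min=0.1, rng=None):
    """One Magma-on-RMSProp step (illustrative sketch).

    theta, g, m, v: (B, d) arrays; p_bar: length-B smoothed keep
    probabilities, carried across iterations. tau, rho, and p_min
    are assumptions for the demo, not paper-specified values.
    """
    rng = rng or np.random.default_rng()
    m = beta1 * m + (1 - beta1) * g                 # dense momentum
    v = beta2 * v + (1 - beta2) * g**2              # dense 2nd moment
    u = m / (np.sqrt(v) + eps)                      # preconditioned update
    # Block-wise momentum/gradient alignment -> keep probability.
    a = (m * g).sum(1) / (np.linalg.norm(m, axis=1)
                          * np.linalg.norm(g, axis=1) + 1e-12)
    s = 1.0 / (1.0 + np.exp(-a / tau))              # sigmoid(a / tau)
    p_bar = rho * p_bar + (1 - rho) * s             # exponential smoothing
    p = np.clip(p_bar, p_min, 1.0)[:, None]         # floor avoids huge 1/p
    z = rng.binomial(1, p)                          # per-block mask
    theta = theta - lr * (z / p) * u                # masked, rescaled
    return theta, m, v, p_bar
```

Carrying `p_bar` as optimizer state mirrors how the first and second moments are carried, which is what makes the scheme a thin wrapper around the base optimizer.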
5. Mechanisms Behind Improved Optimization
Analysis reveals that Magma, through alignment-aware masking, amplifies optimizer selectivity:
- Progress is focused on blocks with coherent, low-noise gradient trajectories.
- Noisy or high-curvature blocks, identified via poor alignment, are down-weighted or skipped.

This mechanism reduces curvature-weighted noise, smooths the effective loss landscape, and widens the range of stable learning-rate schedules. Empirical evidence demonstrates accelerated convergence and lower error floors in transformer architectures with high parameter heterogeneity.
6. Empirical Performance in LLM Pre-Training
Magma’s efficacy has been evaluated on LLaMA-style transformers using the C4 corpus across model scales from 60M to 1B parameters. Integrations with Adam, LaProp, and RMSProp were compared to C-Adam, SGG, Adafactor, APOLLO, SOAP, and matrix-prefactorized Muon. For 1B-parameter models, RMSProp+Magma achieved validation perplexity of 13.19, representing reductions of over 19% compared to Adam (16.35) and over 9% compared to Muon (14.52). These gains are substantial relative to established adaptive optimizers and are consistent across smaller model regimes.
Magma also demonstrated robust performance enhancements in mixture-of-experts (Nano MoE) pre-training and synthetic heavy-tailed, heterogeneous quadratic benchmarks, exhibiting superior stability and final accuracy (Joo et al., 17 Feb 2026).
7. Practical Considerations and Computational Overhead
Magma is implemented as a wrapper around existing adaptive optimizers. The additional computational cost is restricted to a few extra inner products, a scalar sigmoid evaluation, an exponential-moving-average update, and a Bernoulli sample per block per iteration. Since the block count $B$ is far smaller than the overall parameter count in modern LLMs, this overhead is negligible in practice. Magma requires no extra gradient computations or significant changes to data handling, supporting straightforward integration into large-scale training workflows.
Momentum-Aligned Gradient Masking (Magma) is positioned as a robust, theoretically motivated, and empirically validated drop-in enhancement to adaptive optimization for large-scale neural network models (Joo et al., 17 Feb 2026).