Momentum-Aligned Gradient Masking (Magma)

Updated 2 July 2026
  • The paper introduces Magma as a lightweight wrapper for adaptive optimizers that uses stochastic masking modulated by momentum–gradient alignment to enforce curvature-dependent regularization.
  • Magma is an optimization method that applies Bernoulli masking and cosine-similarity scaling to gradient updates, selectively damping steps in high-curvature regions typical of transformers.
  • Empirical evaluations demonstrate that Magma improves perplexity and convergence for large language models while conferring enhanced learning-rate robustness compared to standard optimizers.

Momentum-aligned Gradient Masking (Magma) is an optimization technique for training large language models (LLMs) that operates as a lightweight wrapper around existing dense adaptive optimizers. By stochastically sparsifying parameter updates and modulating them by momentum–gradient alignment, Magma achieves geometric regularization in the high-curvature, heterogeneous loss landscapes typical of transformers, leading to more stable and better-performing pre-training. The method was first proposed and analyzed in "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers" (Joo et al., 17 Feb 2026).

1. Algorithmic Structure and Implementation

Magma operates as a parameter-wise masking-and-scaling transformation within any dense adaptive optimizer (e.g., Adam, RMSProp, Muon). At each optimization step, parameters are divided into $B$ blocks. For each block $b$, a Bernoulli mask $m_t^{(b)} \sim \mathrm{Bernoulli}(p)$ is sampled, where $p$ is the survival probability (default $p = 0.5$). The key innovation is that Magma does not simply mask updates: it also scales the surviving updates by an alignment score $s_t^{(b)}$ derived from the cosine similarity between the running momentum estimate $\mu_t^{(b)}$ and the instantaneous gradient $g_t^{(b)}$.

Let $\Delta_t^{(b)}$ denote the proposed block-wise update from the base optimizer, commonly of the form $\Delta_t^{(b)} = \eta_t (V_t^{(b)})^{-1/2} g_t^{(b)}$ for RMSProp/Adam. The final parameter update is computed as

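$$\tilde{\Delta}_t^{(b)} = s_t^{(b)} \, m_t^{(b)} \, \Delta_t^{(b)},$$

with the masked and modulated step applied as $\theta_{t+1}^{(b)} = \theta_t^{(b)} - \tilde{\Delta}_t^{(b)}$. Momentum and other optimizer state variables (e.g., moments) are still updated densely, i.e., without masking.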

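The instantaneous alignment $\tilde{s}_t^{(b)}$ is calculated as

[equation missing in source: $\tilde{s}_t^{(b)}$ as a function of the cosine similarity between $\mu_t^{(b)}$ and $g_t^{(b)}$]

where $\tau$ is a temperature hyperparameter (default [value missing in source]). This is further smoothed:

$$s_t^{(b)} = \beta \, s_{t-1}^{(b)} + (1 - \beta) \, \tilde{s}_t^{(b)},$$

where $\beta$ (default [value missing in source]) is the exponential moving average decay.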
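The procedure described in this section can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names, the sigmoid form of the instantaneous alignment, and the values `tau=1.0` and `beta=0.9` are all hypothetical, since the exact expression and defaults are not available here. The base optimizer is assumed to have already produced its dense updates and momenta for the step.

```python
import numpy as np


def cosine_similarity(a, b, eps=1e-12):
    """Cosine similarity between two arrays, treated as flat vectors."""
    a, b = a.ravel(), b.ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps))


def magma_step(params, updates, momenta, grads, scores, rng,
               p=0.5, tau=1.0, beta=0.9):
    """Apply one masked, alignment-scaled step to every block, in place.

    params, updates, momenta, grads: lists of arrays, one entry per block.
    scores: list of floats, the smoothed alignment score of each block.
    tau and beta are hypothetical values, not the paper's defaults.
    """
    for b in range(len(params)):
        # Instantaneous alignment. ASSUMED form: sigmoid(cosine / tau).
        cos = cosine_similarity(momenta[b], grads[b])
        s_inst = 1.0 / (1.0 + np.exp(-cos / tau))
        # Exponential moving average of the alignment score.
        scores[b] = beta * scores[b] + (1.0 - beta) * s_inst
        # A single Bernoulli(p) draw decides whether the whole block moves.
        m = float(rng.random() < p)
        params[b] -= scores[b] * m * updates[b]
    return params, scores


# Usage with two toy blocks; momenta and gradients coincide here.
rng = np.random.default_rng(0)
grads = [np.array([1.0, 2.0]), np.array([3.0])]
params = [np.zeros(2), np.zeros(1)]
scores = [0.5, 0.5]
magma_step(params, grads, grads, grads, scores, rng)
```

Note that the mask only gates the parameter write: the momenta passed in are whatever the base optimizer computed densely, which matches the requirement that optimizer state is never masked.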

2. Theoretical Properties and Regularization Dynamics

Magma induces an explicit form of curvature-dependent regularization through its stochastic masking. For a simplified base optimizer (e.g., SGD), the expectation of the loss after applying the masked update can be written as

[equation missing in source]

where the additional term involves the $b$-th diagonal block of the Hessian. Thus, masking introduces a quadratic penalty proportional to the curvature in each blockwise direction, selectively dampening updates in high-curvature regions.
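To see where a block-diagonal curvature penalty of this kind can come from, the following derivation works through a simplified case. It is a sketch under assumptions that are not taken from the source: plain SGD with step size $\eta$, independent block masks, unbiased $1/p$ rescaling of the surviving blocks, and a second-order Taylor model of the loss. The symbols $\tilde g$, $H$, and $H^{(b,b)}$ are introduced here for illustration, and the paper's own expression may differ.

```latex
% Assumptions (not from the source): plain SGD, independent block masks,
% unbiased 1/p rescaling, second-order Taylor model of the loss L.
\begin{align*}
\tilde g^{(b)} &= \frac{m^{(b)}}{p}\, g^{(b)},
  \qquad m^{(b)} \sim \mathrm{Bernoulli}(p) \ \text{independently}, \\
\mathbb{E}\!\left[\frac{m^{(b)} m^{(c)}}{p^2}\right] &=
  \begin{cases} 1/p, & b = c, \\ 1, & b \neq c, \end{cases} \\
\mathbb{E}\big[L(\theta - \eta \tilde g)\big]
  &\approx L(\theta) - \eta \, \|g\|^2
     + \frac{\eta^2}{2}\, \mathbb{E}\big[\tilde g^\top H \tilde g\big] \\
  &= L(\theta) - \eta \, \|g\|^2 + \frac{\eta^2}{2}\, g^\top H g
     + \frac{\eta^2}{2} \cdot \frac{1-p}{p}
       \sum_{b=1}^{B} g^{(b)\top} H^{(b,b)} g^{(b)}.
\end{align*}
```

The first three terms are the usual expansion of an unmasked SGD step. The last term appears only because of masking, grows as $p$ decreases, and weights each block by its own diagonal curvature $H^{(b,b)}$.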

The momentum–gradient alignment further modulates masked updates by prioritizing directions where the current gradient and the historical momentum agree, so that updates with poor alignment are attenuated and updates with strong alignment are preserved. Although this introduces a small bias (the expected masked update no longer equals the base optimizer's update), empirically this tradeoff results in substantially increased optimization stability.
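The bias can be made concrete by computing the exact expectation of a masked step for a single block. The snippet below is an illustrative sketch: `expected_masked_update` is a hypothetical helper, and the alignment score of 0.7 is an arbitrary value chosen for the example. It contrasts a fixed $1/p$ rescaling, which is unbiased, with a damping factor below one.

```python
import numpy as np


def expected_masked_update(delta, p, scale):
    """Exact expectation of scale * m * delta for m ~ Bernoulli(p)."""
    # With probability p the scaled update is applied; otherwise nothing moves.
    return p * (scale * delta) + (1.0 - p) * np.zeros_like(delta)


delta = np.array([0.2, -0.1, 0.4])  # base optimizer update for one block
p = 0.5

# Rescaling by 1/p recovers the base update in expectation.
unbiased = expected_masked_update(delta, p, scale=1.0 / p)

# A hypothetical alignment score of 0.7 shrinks the expected update.
damped = expected_masked_update(delta, p, scale=0.7)
```

Here `unbiased` equals `delta`, while `damped` equals `0.35 * delta`, so the alignment-scaled step is systematically smaller than the base step.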

A one-step descent bound for constant-step SGD with masking establishes that Magma shrinks the effective block-wise smoothness constants and preserves a quantifiable fraction of the true descent rate, yielding the guarantee

[equation missing in source]

which is stated in terms of a measure of average alignment and a reduced noise floor, supporting stability and convergence in highly nonconvex, ill-conditioned problems.

3. Empirical Evaluation and Benchmarking

Extensive pre-training experiments on transformer architectures substantiate the effectiveness of Magma. The key empirical axes are as follows:

  • Llama 2 on C4 dataset:
    • Models: 60M, 130M, 350M, 1B parameters (GPT-style)
    • At 1B parameters, Adam+Magma achieved a validation perplexity of 13.71 vs. 16.35 for Adam; RMSProp+Magma outperformed all baselines, converging stably where plain RMSProp diverged (perplexity 13.19 vs. 14.52 for Muon).
  • MoE Transformer on OpenWebText:
    • 124M parameters, 8 experts.
    • Both Magma+Adam and Magma+Muon produced improved perplexity and training stability compared to respective baselines and C-Adam.
  • Linear attention with synthetic noise:
    • For heavy-tailed gradient noise, Magma consistently yielded lower losses and a reduced robust Hessian condition number than Adam.
  • Quadratic benchmarking:
    • Magma showed accelerated convergence and lower final loss than AdamW under heterogeneous curvature across blocks, yet conferred no improvement for homogeneous (CNN-like) landscapes.

4. Ablations, Sensitivity Analyses, and Practical Guidance

Ablation studies and sensitivity analyses identify practical instantiations and design choices:

  • Masking location and granularity: Applying masking to both attention and MLP blocks yields optimal performance (e.g., for 130M Llama, perplexity 21.65 vs. baseline 22.64). Masking at element, row, column, or block level yields comparable validation perplexity; block-level is chosen for computational simplicity.
  • Hyperparameters: The defaults for the survival probability ($p = 0.5$), the temperature, and the EMA decay are robust. Performance is stable for [range missing in source].
  • Momentum updates: Dense (unmasked) momentum is necessary for stable convergence; momentum masking without additional damping destabilizes training.
  • Learning-rate tolerance: Magma confers enhanced learning-rate robustness; Adam+Magma is stable up to a learning rate of [value missing in source], while Adam and C-Adam require narrow schedules [range missing in source].
  • Computational overhead: The additional operations are one mask draw, one cosine similarity, and one EMA update per block per step, negligible in comparison to forward and backward computational cost.
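The masking granularities compared in the ablation can be illustrated with broadcasting. The sketch below assumes a 2-D weight matrix and treats "block" as a single draw for the whole matrix; both choices, along with the helper name `sample_mask`, are assumptions made for illustration rather than details from the paper.

```python
import numpy as np


def sample_mask(shape, granularity, p, rng):
    """Return a 0/1 mask that broadcasts against a 2-D weight of `shape`."""
    rows, cols = shape
    mask_shape = {
        "element": (rows, cols),  # independent draw per entry
        "row": (rows, 1),         # one draw per row
        "column": (1, cols),      # one draw per column
        "block": (1, 1),          # one draw for the whole matrix (assumed)
    }[granularity]
    return (rng.random(mask_shape) < p).astype(np.float64)


rng = np.random.default_rng(0)
update = np.ones((4, 6))  # stand-in for a base optimizer update

# Apply each granularity to the same update via broadcasting.
masked = {
    g: sample_mask(update.shape, g, 0.5, rng) * update
    for g in ("element", "row", "column", "block")
}
```

Block-level masking needs only one random draw per block, which is consistent with the ablation's remark that it is chosen for computational simplicity.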

5. Limitations and Open Questions

Limitations of Magma are concentrated around its operational domain and theoretical underpinnings:

  • The momentum-based damping $s_t^{(b)}$ biases updates, contrasting with unbiased rescaling in SkipUpdate (a Magma variant where $s_t^{(b)} = 1/p$). The question of achieving stable but unbiased modulation remains open.
  • The benefits of Magma are largely confined to ill-conditioned, heterogeneous curvature settings such as large transformers. No improvement is observed for architectures with more homogeneous loss landscapes, e.g., ResNet-50 on CIFAR-10.
  • The theoretical guarantees are presently rigorously established only for plain SGD; extending full convergence proofs to adaptive schemes (e.g., Adam, RMSProp) is an area for future work.
  • Integration with other sparse or low-rank optimization techniques, such as GaLore or subspace descent, as well as adaptively varying the masking schedule or survival probability $p$ with task or training phase, constitutes open terrain for further exploration.

6. Summary of Design Choices and Application Scope

Magma is implemented as a drop-in wrapper around modern adaptive optimizers and requires no additional gradient evaluations and no additional parameter-sized optimizer state; the only extra quantity carried between steps is the smoothed alignment score of each block. Its defaults confer robust improvements for transformer-based LLMs, especially where gradient noise is heavy-tailed and curvature is highly block-heterogeneous. The method emphasizes momentum-driven adaptation of stochastic mask-induced regularization, prioritizing optimizer stability and effective curvature suppression. Its effectiveness in LLM pre-training, combined with negligible computational cost, suggests high practical utility for state-of-the-art generative modeling (Joo et al., 17 Feb 2026).
