Sharpness-Aware Pretraining

Updated 9 May 2026
  • Sharpness-aware pretraining is a set of optimization techniques that adjust neural network training to find flatter minima by managing loss landscape curvature through adversarial perturbations and learning rate scheduling.
  • Methods like SAM, SSAM, GSAM, and X-SAM balance computational cost with enhanced generalization, robustness, and reduced catastrophic forgetting across various architectures and datasets.
  • Empirical studies demonstrate that these approaches improve performance in vision and language models, achieving tangible gains in accuracy and model survivability under fine-tuning and quantization.

Sharpness-aware pretraining encompasses a family of optimization approaches that explicitly manage the curvature of the loss landscape during neural network pretraining, with the goal of locating flatter minima and thereby improving generalization, robustness, and downstream model survivability under subsequent updates such as fine-tuning and quantization. Methods include perturbation-based minimization (SAM and its variants SSAM, GSAM, X-SAM), curvature/flatness-guided learning-rate scheduling (blockwise LR), and combined recipes that concentrate sharpness-reducing steps in specific training phases. Sharpness is commonly formalized via the spectrum of the Hessian of the loss function, particularly its largest eigenvalue, or via surrogate metrics connected to the adversarial loss in parameter space.

1. Theoretical Foundations and Objectives

Sharpness-aware optimization methods are motivated by the empirical observation that flatter minima (those in which the loss changes slowly with respect to parameter perturbations) lead to improved generalization and robustness. The canonical formulation is the minimax objective $\min_w \max_{\|\epsilon\|_2\leq\rho} L(w+\epsilon)$, where $L$ denotes the loss and $\rho$ is the perturbation radius. This robust optimization seeks parameters $w$ where the worst-case loss under small perturbations is minimized, inherently preferring flatter regions of the landscape. The effect is to regularize the dominant Hessian eigenvalue, reducing curvature in the sharpest directions. Analytical results establish a correspondence between the adversarial-local loss gap and the largest Hessian eigenvalue at a minimum, connecting these metrics to generalization performance (Zhuang et al., 2022, Duan et al., 15 Jan 2026, Watts et al., 4 May 2026).
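
The correspondence between the worst-case loss gap and the top Hessian eigenvalue can be sketched with a second-order Taylor expansion. This is a standard textbook argument rather than a derivation taken from the cited papers. It assumes $w$ is a local minimum (so $\nabla L(w) = 0$) with positive semidefinite Hessian $H$; the symbols $H$ and $\lambda_{\max}$ are introduced here for illustration.

```latex
\max_{\|\epsilon\|_2\leq\rho} L(w+\epsilon) - L(w)
  \;\approx\; \max_{\|\epsilon\|_2\leq\rho} \tfrac{1}{2}\,\epsilon^\top H\,\epsilon
  \;=\; \tfrac{\rho^2}{2}\,\lambda_{\max}(H)
```

The maximum is attained when $\epsilon$ has norm $\rho$ and points along the leading eigenvector of $H$, which is why shrinking the gap at a minimum amounts to shrinking $\lambda_{\max}(H)$.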

2. Core Algorithms: SAM and Major Extensions

Sharpness-Aware Minimization (SAM)

SAM (Zhuang et al., 2022, Mi et al., 2022) operates by alternately perturbing parameters in the direction of steepest ascent (scaled to radius $\rho$) and then performing a descent step using the gradient at this adversarially perturbed point:

  • Compute $\epsilon^* = \rho\, \nabla L(w)/\|\nabla L(w)\|_2$
  • Set $\tilde w = w + \epsilon^*$
  • Update: $w \gets w - \eta\nabla L(\tilde w)$

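This mechanism yields robust minima, but at the expense of roughly doubling the per-step cost, since each update requires two gradient evaluations (one at $w$ and one at the perturbed point).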
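
The three steps above can be sketched in a few lines of plain Python. This is a minimal illustration on a toy quadratic loss, not the implementation from the cited papers; the loss, the function names `grad` and `sam_step`, and the values of `lr` and `rho` are all assumptions chosen for the example.

```python
import math

def grad(w):
    # Gradient of the toy loss L(w) = 0.5 * (10 * w[0]**2 + w[1]**2) (assumed for illustration).
    return [10.0 * w[0], 1.0 * w[1]]

def sam_step(w, lr=0.05, rho=0.1):
    g = grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return list(w)  # no ascent direction at a stationary point
    # Ascent step: epsilon* = rho * grad / ||grad||_2
    eps = [rho * gi / norm for gi in g]
    w_tilde = [wi + ei for wi, ei in zip(w, eps)]
    # Descent step at the original point, using the gradient at the perturbed point
    g_tilde = grad(w_tilde)
    return [wi - lr * gi for wi, gi in zip(w, g_tilde)]

w = [1.0, 1.0]
for _ in range(100):
    w = sam_step(w)
```

Note the two calls to `grad` per step, which is where the extra cost comes from. With a fixed `rho` the perturbation never shrinks, so in this toy setting the iterate tends to hover close to the minimum rather than converge to it exactly.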

Sparse SAM (SSAM)

SSAM reduces the perturbation overhead of full SAM by masking the perturbation direction with a binary mask $m$ chosen to reflect parameter importance or flatness. Only a sparse subset of weights receives the adversarial perturbation at each step, selected by either (a) Fisher-information-based masks or (b) dynamic sparse assignments:

  • Fisher: perturb only the coordinates with the highest Fisher information (squared gradient magnitude)
  • Dynamic: iteratively prune and regrow mask entries, focusing on "flat" directions (Mi et al., 2022)

SSAM achieves up to a $2\times$ reduction in ascent-step computational cost with negligible or positive impact on test accuracy; for example, on CIFAR-10 and ImageNet, SSAM at 50% sparsity matches or improves over full SAM.
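
A minimal sketch of the Fisher-style masking idea follows. It makes several simplifying assumptions that are not in the source: the Fisher information is approximated by the squared entries of a single gradient vector, the mask is applied after normalizing by the full gradient norm, and the names `fisher_mask` and `sparse_perturbation` are hypothetical.

```python
import math

def fisher_mask(g, sparsity=0.5):
    # Keep the (1 - sparsity) fraction of coordinates with the largest squared gradient.
    k = max(1, round(len(g) * (1.0 - sparsity)))
    top = sorted(range(len(g)), key=lambda i: g[i] * g[i], reverse=True)[:k]
    keep = set(top)
    return [1.0 if i in keep else 0.0 for i in range(len(g))]

def sparse_perturbation(g, rho=0.1, sparsity=0.5):
    # SAM-style perturbation, zeroed outside the mask.
    mask = fisher_mask(g, sparsity)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return [0.0] * len(g)
    return [rho * gi / norm * mi for gi, mi in zip(g, mask)]

g = [0.5, -2.0, 0.1, 1.5]
mask = fisher_mask(g)          # selects the two largest-magnitude coordinates
eps = sparse_perturbation(g)   # nonzero only where mask is 1
```

A mask like this only saves compute when the backward pass can skip the masked coordinates, so the sketch shows the selection rule, not the speedup.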

Surrogate Gap Guided SAM (GSAM)

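GSAM introduces the surrogate gap $h(w) = \max_{\|\epsilon\|_2\leq\rho} L(w+\epsilon) - L(w)$ as a flatness metric, and augments the SAM update by:

  • Regular SAM step (gradient at the perturbed point $\tilde w = w + \epsilon^*$)
  • Orthogonal ascent in the direction that reduces the surrogate gap without increasing the perturbed loss $L(\tilde w)$
  • Update rule: $w \gets w - \eta\,(\nabla L(\tilde w) - \alpha\,\nabla L(w)_{\perp})$ for suitable $\alpha$, where $\nabla L(w)_{\perp}$ is the component of $\nabla L(w)$ orthogonal to $\nabla L(\tilde w)$ (Zhuang et al., 2022)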

This two-step approach tightens generalization bounds via PAC-Bayesian analysis and yields consistent increases in out-of-distribution performance and transfer learning.
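
The gradient decomposition behind this update can be sketched as follows. The sketch assumes the form in which the clean gradient is split into parts parallel and orthogonal to the perturbed gradient, and the orthogonal part is ascended with a small weight; the names `g`, `g_pert`, `alpha`, and `gsam_direction` are illustrative, not taken from the paper's code.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def gsam_direction(g, g_pert, alpha=0.1):
    # g: gradient at w; g_pert: gradient at the perturbed point.
    denom = dot(g_pert, g_pert)
    if denom == 0.0:
        return list(g_pert)
    # Remove from g its projection onto g_pert, leaving the orthogonal component.
    coef = dot(g, g_pert) / denom
    g_orth = [gi - coef * pi for gi, pi in zip(g, g_pert)]
    # Descend along g_pert while ascending along the orthogonal part of g.
    return [pi - alpha * oi for pi, oi in zip(g_pert, g_orth)]

g = [1.0, 1.0]
g_pert = [2.0, 0.0]
d = gsam_direction(g, g_pert)  # the weights would then move as w <- w - lr * d
```

Because the extra term is orthogonal to the perturbed gradient, it leaves the perturbed loss unchanged to first order while raising the clean loss slightly, which narrows the gap between the two.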

X-SAM: Eigenvector-Aligned Correction

X-SAM addresses a limitation of canonical SAM: the worst-case loss can be insensitive to the sharpest curvature direction when the gradient is orthogonal to it. X-SAM explicitly corrects the perturbed gradient by decomposing it along the leading eigenvector of the local Hessian:

  • Compute the perturbed gradient, i.e., the gradient at the SAM-perturbed point
  • Orthogonalize and subtract/attenuate the component along the leading eigenvector, scaled by a correction factor
  • Update the weights with the corrected gradient (Duan et al., 15 Jan 2026)

Under mild assumptions, X-SAM provably suppresses the largest eigenvalue of the Hessian more effectively than SAM and accelerates the reduction of sharpness.

3. Empirical Findings and Benchmark Results

The suite of sharpness-aware methods has been evaluated in large-scale vision and LLM pretraining across multiple architectures (ResNet, ViT, GPT-2, LLaMA) and datasets (CIFAR-10/100, ImageNet, OpenWebText, MiniPile). Key empirical results include:

| Method/Setting | Dataset/Model | Baseline Top-1 | SAM | Variant (Best) | Sparsity | Gain |
|---|---|---|---|---|---|---|
| SAM vs SGD | CIFAR-10/ResNet-18 | 96.58% (SGD) | 96.83% | - | 0% | +0.25% |
| SSAM-F/D | CIFAR-10/ResNet-18 | - | 96.83% | 96.84% / 96.74% | 50% | ≈+0.01% |
| X-SAM | CIFAR-10/ResNet-18 | 93.64% (SAM) | 93.64% | 94.71% (X-SAM) | - | +1.07% |
| GSAM (ViT-B/32) | ImageNet-1k | 71.4% (AdamW) | 73.6% | 76.8% (GSAM) | - | +3.2% |
| SAM (OLMo-60M) | StarCoder fine-tune | ΔLoss +0.5 (AdamW) | +0.5 | +0.1 (SAM) | - | 80% forgetting reduction |

Collectively, these methods improve generalization and robustness, with consistent boosts in both in-domain and transfer tasks, as well as improved resilience under downstream fine-tuning, post-training, quantization, and random weight perturbations (Mi et al., 2022, Duan et al., 15 Jan 2026, Zhuang et al., 2022, Watts et al., 4 May 2026).
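
The Gain column in the table above can be checked with simple arithmetic. The snippet below reads the gains as differences against the comparison value that reproduces the reported figure; note that under this reading the GSAM gain of +3.2% is measured against SAM (73.6%), not against AdamW (71.4%), and the 80% forgetting reduction follows from the ΔLoss moving from +0.5 to +0.1. This interpretation is inferred from the numbers, not stated in the source.

```python
# (comparison value, improved value) pairs copied from the table, in percent.
rows = {
    "SAM vs SGD": (96.58, 96.83),   # SGD -> SAM
    "X-SAM": (93.64, 94.71),        # SAM -> X-SAM
    "GSAM": (73.6, 76.8),           # SAM -> GSAM
}
gains = {name: round(new - old, 2) for name, (old, new) in rows.items()}

# Relative reduction in the loss increase after fine-tuning: (0.5 - 0.1) / 0.5.
forgetting_reduction = round((0.5 - 0.1) / 0.5, 2)
```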

4. Practical Methodologies and Implementation Details

Implementation of sharpness-aware pretraining algorithms requires careful selection of perturbation radii, correction strength, and update intervals, all of which interact with learning rate schedules and batch normalization. Notable practical guidelines are:

  • Typical $\rho$ values: 0.07–0.1 (ImageNet-size), up to 0.6 for ViT; tune via grid search.
  • Correction strengths: 0.01–0.4 (task-dependent), higher for ViT/Mixer.
  • Mask update interval ρ\rho2 (best for SSAM); Fisher sample size ρ\rho3 is sufficient for stable masking.
  • In large-scale LLM pretraining, blockwise learning rate scaling can be switched on post-warmup and tuned with 4 multipliers per architectural block (ρ\rho4–12, etc.) (Wang et al., 26 Feb 2025).

Where computational cost is critical, interventions such as "SAM-shade" (applying SAM or a variant only during the final learning-rate decay) recover most benefits at a fraction of the cost (Watts et al., 4 May 2026). Monitoring the dominant Hessian eigenvalue $\lambda_{\max}$ throughout training provides a diagnostic for whether flatness is actually being achieved.
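
One common way to monitor the dominant Hessian eigenvalue without forming the Hessian is power iteration on Hessian-vector products. The sketch below is a generic illustration, not the procedure of the cited papers: it uses a toy quadratic loss, a finite-difference Hessian-vector product, and hypothetical names (`grad`, `hvp`, `top_eigenvalue`).

```python
import math

def grad(w):
    # Toy loss L(w) = 0.5 * (10 * w[0]**2 + w[1]**2); its Hessian is diag(10, 1).
    return [10.0 * w[0], 1.0 * w[1]]

def hvp(w, v, h=1e-4):
    # Central finite difference: (grad(w + h*v) - grad(w - h*v)) / (2*h) approximates H @ v.
    gp = grad([wi + h * vi for wi, vi in zip(w, v)])
    gm = grad([wi - h * vi for wi, vi in zip(w, v)])
    return [(a - b) / (2.0 * h) for a, b in zip(gp, gm)]

def top_eigenvalue(w, iters=50):
    v = [1.0 / math.sqrt(len(w))] * len(w)  # unit-norm starting vector
    lam = 0.0
    for _ in range(iters):
        hv = hvp(w, v)
        lam = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient for unit v
        norm = math.sqrt(sum(a * a for a in hv))
        if norm == 0.0:
            break
        v = [a / norm for a in hv]
    return lam

lam_max = top_eigenvalue([0.3, -0.2])
```

Power iteration converges to the eigenvalue of largest magnitude, so away from a minimum it may return a large negative eigenvalue instead of the sharpest positive one. In a real training loop the finite difference would normally be replaced by an automatic-differentiation Hessian-vector product evaluated on a mini-batch.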

5. Robustness, Catastrophic Forgetting, and Model Survivability

Sharpness-aware pretraining directly impacts the survivability of pretrained models under subsequent updates. Lower Hessian trace and directional curvature reduce catastrophic forgetting, i.e., the loss of pretrained capabilities due to fine-tuning, quantization, or weight perturbations. Empirical results on OLMo models show up to 80% reductions in forgetting across modalities, including 31% less forgetting after MetaMath fine-tuning and a 40% smaller drop under 4-bit quantization (Watts et al., 4 May 2026). High peak learning rates and shortened annealing phases (inducing the "Edge of Stability" effect) also cap Hessian eigenvalues and yield similar robustness, even if such configurations do not minimize base pretraining loss.
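
The intuition for why lower curvature helps under quantization can be illustrated with a toy example: the same weight displacement costs more loss at a sharp minimum than at a flat one. Everything below is assumed for illustration (a quadratic loss around a minimum `w_star`, a crude round-to-grid quantizer, and the specific numbers); it does not reproduce the OLMo experiments.

```python
def quantize(w, step):
    # Round each weight to the nearest multiple of `step` (a crude uniform quantizer).
    return [round(wi / step) * step for wi in w]

def quadratic_loss(w, w_star, curvature):
    # Loss increase around a minimum at w_star with isotropic curvature.
    return 0.5 * curvature * sum((wi - si) ** 2 for wi, si in zip(w, w_star))

w_star = [0.13, -0.42, 0.77]
w_q = quantize(w_star, step=0.25)

flat_increase = quadratic_loss(w_q, w_star, curvature=1.0)
sharp_increase = quadratic_loss(w_q, w_star, curvature=10.0)
```

In this quadratic model the loss increase scales linearly with curvature, so a minimum that is ten times sharper loses ten times more under the identical quantization error.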

6. Blockwise Sharpness Disparity and Learning-Rate Scheduling

Transformers display a persistent blockwise sharpness disparity: embeddings remain the flattest, while the final LayerNorm is the sharpest. This observation underpins blockwise learning-rate (LR) schedules, in which peak LRs are scaled up in blocks with low sharpness, accelerating training and lowering terminal loss in GPT-2 and LLaMA pretraining tasks (Wang et al., 26 Feb 2025). The per-block LR is scaled according to that block's sharpness, estimated from the blockwise average Fisher trace. Integration into memory-efficient Adam variants (e.g., Adam-mini) is direct. This approach is complementary to perturbation-based sharpness methods.

Block Type Typical Sharpness Recommended LR Multiplier
Embedding Lowest ρ\rho9–12
Normww0 Low ww1–8
Attention Moderate ww2–6
FFN Moderate-high ww3–6
Normww4 Highest ww5

Parameterized tuning is stable across dataset/model variations.
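
A blockwise LR schedule of this kind reduces, at its simplest, to multiplying a base learning rate by a per-block factor. The sketch below uses hypothetical multipliers chosen only to illustrate the ordering (flatter blocks get larger values); they are not the values recommended by Wang et al., and the block names and the function `blockwise_lrs` are likewise assumptions.

```python
def blockwise_lrs(base_lr, multipliers):
    # Each block trains at base_lr times its own multiplier.
    return {block: base_lr * m for block, m in multipliers.items()}

# Hypothetical multipliers, for illustration only.
multipliers = {"embedding": 10.0, "attention": 4.0, "ffn": 4.0, "final_norm": 1.0}
lrs = blockwise_lrs(3e-4, multipliers)
```

In practice the resulting dictionary would be mapped onto optimizer parameter groups, with the multipliers switched on after warmup as described in the guidelines above.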

7. Limitations and Spectral Perspectives

Both theoretical analysis and empirical findings demonstrate that standard SAM and related methods may occasionally fail to directly minimize the sharpest curvature when the gradient is nearly orthogonal to the top Hessian eigenvector. Methods such as X-SAM and GSAM address this by explicitly manipulating gradient projections, offering stronger control over the loss landscape's most problematic directions (Zhuang et al., 2022, Duan et al., 15 Jan 2026).

There is no guarantee that a minimal adversarially perturbed loss always coincides with the lowest Hessian spectral norm, and sharp minima may present spurious plateaux for poorly chosen perturbation radii or schedules. Attention to spectral metrics and surrogate gap minimization is critical for robust deployment.


Sharpness-aware pretraining, in its various forms, constitutes a rigorously characterized and empirically validated set of methods for improving neural network generalization, robustness, and downstream adaptability by systematically reducing the curvature of the loss landscape through perturbation-based optimization, blockwise learning-rate policies, and spectral corrections (Mi et al., 2022, Zhuang et al., 2022, Watts et al., 4 May 2026, Wang et al., 26 Feb 2025, Duan et al., 15 Jan 2026).
