Sharpness-Aware Pretraining
- Sharpness-aware pretraining is a set of optimization techniques that adjust neural network training to find flatter minima by managing loss landscape curvature through adversarial perturbations and learning rate scheduling.
- Methods like SAM, SSAM, GSAM, and X-SAM balance computational cost with enhanced generalization, robustness, and reduced catastrophic forgetting across various architectures and datasets.
- Empirical studies demonstrate that these approaches improve performance in vision and language models, achieving tangible gains in accuracy and model survivability under fine-tuning and quantization.
Sharpness-aware pretraining encompasses a family of optimization approaches that explicitly manage the curvature of the loss landscape during neural network pretraining, with the goal of locating flatter minima and thereby improving generalization, robustness, and downstream model survivability under subsequent updates such as fine-tuning and quantization. Methods include perturbation-based minimization (SAM and its variants SSAM, GSAM, X-SAM), curvature/flatness-guided learning-rate scheduling (blockwise LR), and combined recipes focusing sharpness-reducing steps at specific training phases. Sharpness is commonly formalized via the spectrum of the Hessian of the loss function, particularly the largest eigenvalue, or via surrogate metrics connected to the adversarial loss in parameter space.
1. Theoretical Foundations and Objectives
Sharpness-aware optimization methods are motivated by the empirical observation that flatter minima (those in which the loss changes slowly with respect to parameter perturbations) lead to improved generalization and robustness. The canonical formulation is the minimax objective $\min_{w} \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)$, where $L$ denotes the loss, $w$ the parameters, and $\rho$ the perturbation radius. This robust optimization seeks parameters where the worst-case loss under small perturbations is minimized, inherently preferring flatter regions of the landscape. The effect is to regularize the dominant Hessian eigenvalue, reducing curvature in the sharpest directions. Analytical results establish a correspondence between the adversarial-local loss gap and the largest Hessian eigenvalue at a minimum, connecting these metrics to generalization performance (Zhuang et al., 2022, Duan et al., 15 Jan 2026, Watts et al., 4 May 2026).
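The link between the worst-case loss gap and the largest Hessian eigenvalue can be illustrated with a second-order expansion. This is a sketch under stated assumptions, not a result quoted from the cited papers: $w^*$ is taken to be a local minimum of the loss $L$ (so the gradient vanishes and the Hessian $H$ is positive semidefinite), and the perturbation radius $\rho$ is assumed small enough for the quadratic approximation to hold. The symbols $w^*$, $H$, and $\lambda_{\max}$ are notation introduced here.

$$
\max_{\|\epsilon\|_2 \le \rho} L(w^* + \epsilon) - L(w^*)
\;\approx\; \max_{\|\epsilon\|_2 \le \rho} \tfrac{1}{2}\, \epsilon^\top H \epsilon
\;=\; \tfrac{\rho^2}{2}\, \lambda_{\max}(H)
$$

The maximum is attained when $\epsilon$ has norm $\rho$ and points along the leading eigenvector of $H$, which is why penalizing the worst-case loss acts, near a minimum, like a penalty on $\lambda_{\max}(H)$.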
2. Core Algorithms: SAM and Major Extensions
Sharpness-Aware Minimization (SAM)
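SAM (Zhuang et al., 2022, Mi et al., 2022) operates by alternately perturbing parameters in the direction of steepest ascent (scaled to radius $\rho$) and then performing a descent step using the gradient at this adversarially perturbed point:
- Compute the gradient $g = \nabla L(w)$
- Set the perturbation $\epsilon = \rho\, g / \|g\|_2$
- Update: $w \leftarrow w - \eta\, \nabla L(w + \epsilon)$, where $\eta$ is the learning rate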
This mechanism yields robust minima, but the second gradient evaluation nearly doubles the per-step cost.
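The three steps above can be sketched on a toy problem. The example below is a minimal illustration, assuming a quadratic loss as a stand-in for a network loss; the function names (`loss`, `grad`, `sam_step`), the matrix `A`, and the hyperparameter values are hypothetical choices made for this sketch, not taken from the cited papers.

```python
import numpy as np

def loss(w, A):
    """Toy quadratic loss L(w) = 0.5 * w^T A w (stand-in for a network loss)."""
    return 0.5 * w @ A @ w

def grad(w, A):
    """Gradient of the toy quadratic loss, valid for symmetric A."""
    return A @ w

def sam_step(w, A, lr=0.05, rho=0.1, eps=1e-12):
    """One SAM update: perturb along the gradient, then descend with the
    gradient evaluated at the perturbed point."""
    g = grad(w, A)                            # gradient at the current weights
    e = rho * g / (np.linalg.norm(g) + eps)   # ascent perturbation of norm ~rho
    g_adv = grad(w + e, A)                    # second gradient evaluation
    return w - lr * g_adv                     # step is applied to the original w

# Illustrative run: one sharp direction (curvature 10) and one flat (curvature 1).
A = np.diag([10.0, 1.0])
w = np.array([1.0, 1.0])
for _ in range(50):
    w = sam_step(w, A)
print(loss(w, A))
```

Each call to `sam_step` evaluates the gradient twice, which is where the roughly doubled per-step cost comes from.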
Sparse SAM (SSAM)
SSAM reduces the perturbation overhead of full SAM by masking the perturbation direction with a binary mask chosen to reflect parameter importance or flatness. Only a sparsified subset of weights receive adversarial perturbation at each step, using either (a) Fisher-information-based masks or (b) dynamic sparse assignments:
- Fisher: perturb only the coordinates with the highest Fisher information (estimated from squared gradient magnitudes)
- Dynamic: iteratively prune and regrow mask entries, focusing on "flat" directions (Mi et al., 2022)
SSAM reduces the ascent-step computational cost with negligible or positive impact on test accuracy; for example, on CIFAR-10 and ImageNet, SSAM at 50% sparsity matches or improves over full SAM.
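A Fisher-style mask can be sketched as a top-k selection over an empirical Fisher estimate. The code below is an illustrative sketch only: the helper names (`fisher_mask`, `sparse_perturbation`), the use of the mean squared per-sample gradient as the Fisher estimate, and the toy gradient values are assumptions made here, not the reference implementation from Mi et al. It shows the masking logic and does not model the sparse backpropagation that would be needed to realize compute savings.

```python
import numpy as np

def fisher_mask(per_sample_grads, sparsity):
    """Binary mask that keeps the (1 - sparsity) fraction of coordinates with
    the largest empirical Fisher information (mean squared per-sample gradient)."""
    fisher = np.mean(per_sample_grads ** 2, axis=0)
    k = int(round((1.0 - sparsity) * fisher.size))  # number of perturbed coordinates
    mask = np.zeros(fisher.size)
    if k > 0:
        mask[np.argsort(fisher)[-k:]] = 1.0         # indices of the k largest entries
    return mask

def sparse_perturbation(g, mask, rho=0.1, eps=1e-12):
    """SAM ascent perturbation, zeroed outside the mask."""
    return mask * rho * g / (np.linalg.norm(g) + eps)

# Toy per-sample gradients: 2 samples, 4 parameters (hypothetical values).
per_sample_grads = np.array([[0.1, 2.0, -0.3, 1.0],
                             [0.2, -1.5, 0.1, 0.8]])
mask = fisher_mask(per_sample_grads, sparsity=0.5)
g = per_sample_grads.mean(axis=0)
e = sparse_perturbation(g, mask, rho=0.1)
print(mask, e)
```

At 50% sparsity only half of the coordinates receive a perturbation; the masked perturbation therefore has norm at most `rho`.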
Surrogate Gap Guided SAM (GSAM)
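GSAM introduces the concept of the surrogate gap $h(w) = L_p(w) - L(w)$, where $L_p(w) = \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)$ is the perturbed loss, as a flatness metric, and augments the SAM update by:
- Regular SAM step (gradient at the perturbed point $w + \epsilon$)
- Orthogonal ascent in the direction that reduces the surrogate gap without increasing $L_p(w)$
- Update rule: $w \leftarrow w - \eta\,(\nabla L_p(w) - \alpha\, \nabla L(w)_{\perp})$ for suitable $\alpha > 0$, where $\eta$ is the learning rate and $\nabla L(w)_{\perp}$ is the component of $\nabla L(w)$ orthogonal to $\nabla L_p(w)$ (Zhuang et al., 2022)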
This two-step approach tightens generalization bounds via PAC-Bayesian analysis and yields consistent increases in out-of-distribution performance and transfer learning.
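The gradient decomposition behind the GSAM update can be sketched in a few lines. This is a minimal illustration, assuming the two gradients (at the current weights and at the perturbed point) have already been computed; the function name `gsam_direction`, the default `alpha`, and the toy vectors are hypothetical.

```python
import numpy as np

def gsam_direction(g, g_p, alpha=0.1, eps=1e-12):
    """Combine the clean gradient g and the perturbed-point gradient g_p.

    g is split into a part parallel to g_p and a part orthogonal to it.
    The returned direction is g_p minus alpha times the orthogonal part, so a
    descent step along it ascends the clean loss only in directions that,
    to first order, leave the perturbed loss unchanged.
    """
    unit = g_p / (np.linalg.norm(g_p) + eps)  # unit vector along g_p
    g_parallel = (g @ unit) * unit            # projection of g onto g_p
    g_perp = g - g_parallel                   # component orthogonal to g_p
    return g_p - alpha * g_perp

# Toy gradients (hypothetical values).
g = np.array([1.0, 1.0])
g_p = np.array([2.0, 0.0])
d = gsam_direction(g, g_p, alpha=0.1)
print(d)
```

Because the correction term is orthogonal to `g_p`, the component of the update along the perturbed gradient is the same as in plain SAM.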
X-SAM: Eigenvector-Aligned Correction
X-SAM addresses limitations of canonical SAM, specifically the scenario where worst-case loss is insensitive to the sharpest curvature direction due to gradient orthogonality. X-SAM explicitly corrects the perturbed gradient by decomposing it along the leading eigenvector $v_1$ of the local Hessian:
- Compute the perturbed gradient $g_p = \nabla L(w + \epsilon)$
- Orthogonalize and subtract/attenuate the component along $v_1$, scaled by a correction factor
- Update: take a descent step along the corrected gradient (Duan et al., 15 Jan 2026)
Under mild assumptions, X-SAM provably suppresses the largest eigenvalue of the Hessian more effectively than SAM and accelerates the reduction of sharpness.
3. Empirical Findings and Benchmark Results
The suite of sharpness-aware methods has been evaluated in large-scale vision and LLM pretraining across multiple architectures (ResNet, ViT, GPT-2, LLaMA) and datasets (CIFAR-10/100, ImageNet, OpenWebText, MiniPile). Key empirical results include:
| Method/Setting | Dataset/Model | Baseline Top-1 | SAM | Variant (Best) | Sparsity | Gain |
|---|---|---|---|---|---|---|
| SAM vs SGD | CIFAR-10/ResNet-18 | 96.58% (SGD) | 96.83% | - | 0% | +0.25% |
| SSAM-F/D | CIFAR-10/ResNet-18 | - | 96.83% | 96.84% / 96.74% | 50% | ≈+0.01% |
| X-SAM | CIFAR-10/ResNet-18 | 93.64% (SAM) | 93.64% | 94.71% (X-SAM) | - | +1.07% |
| GSAM (ViT-B/32) | ImageNet-1k | 71.4% (AdamW) | 73.6% | 76.8% (GSAM) | - | +3.2% |
| SAM (OLMo-60M) | StarCoder fine-tune | ΔLoss +0.5 (AdamW) | +0.5 | +0.1 (SAM) | - | ≈80% forgetting reduction |
Collectively, these methods improve generalization and robustness, with consistent boosts in both in-domain and transfer tasks, as well as improved resilience under downstream fine-tuning, post-training, quantization, and random weight perturbations (Mi et al., 2022, Duan et al., 15 Jan 2026, Zhuang et al., 2022, Watts et al., 4 May 2026).
4. Practical Methodologies and Implementation Details
Implementation of sharpness-aware pretraining algorithms requires careful selection of perturbation radii, correction strength, and update intervals, all of which interact with learning rate schedules and batch normalization. Notable practical guidelines are:
- Typical $\rho$ values: 0.07–0.1 (ImageNet-size), up to 0.6 for ViT; tune via grid search.
- Correction strengths ($\alpha$): 0.01–0.4 (task-dependent), higher for ViT/Mixer.
- Mask update interval 2 (best for SSAM); Fisher sample size 3 is sufficient for stable masking.
- In large-scale LLM pretraining, blockwise learning rate scaling can be switched on post-warmup and tuned with per-block multipliers for each architectural block (4–12, etc.) (Wang et al., 26 Feb 2025).
Where computational cost is critical, interventions such as "SAM-shade" (applying SAM or a variant only during final learning rate decay) recover most benefits at a fraction of the cost (Watts et al., 4 May 2026). Monitoring the dominant Hessian eigenvalue $\lambda_{\max}$ throughout training provides diagnostic control to ensure flatness is achieved.
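One common way to monitor the dominant Hessian eigenvalue without forming the Hessian is power iteration on Hessian-vector products. The sketch below is an assumption-laden illustration: it approximates Hessian-vector products by central differences of a user-supplied gradient function, and the names `hvp`, `top_eigenvalue`, and the step size `h` are hypothetical. Power iteration converges to the eigenvalue of largest magnitude, which equals the largest eigenvalue when the Hessian is positive semidefinite.

```python
import numpy as np

def hvp(grad_fn, w, v, h=1e-4):
    """Approximate Hessian-vector product H v by central differences of gradients."""
    return (grad_fn(w + h * v) - grad_fn(w - h * v)) / (2.0 * h)

def top_eigenvalue(grad_fn, w, iters=100, seed=0):
    """Estimate the dominant Hessian eigenvalue at w by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(w.shape)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        hv = hvp(grad_fn, w, v)
        lam = float(v @ hv)        # Rayleigh quotient for the current unit vector
        norm = np.linalg.norm(hv)
        if norm == 0.0:            # v lies in the null space; stop early
            break
        v = hv / norm
    return lam

# Toy check on a quadratic whose Hessian is diag(10, 3, 1).
A = np.diag([10.0, 3.0, 1.0])
lam = top_eigenvalue(lambda w: A @ w, np.zeros(3))
print(lam)
```

In an autodiff framework the finite-difference `hvp` would normally be replaced by an exact Hessian-vector product; each estimate then costs a handful of extra gradient-sized passes, so it is usually computed only at intervals.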
5. Robustness, Catastrophic Forgetting, and Model Survivability
Sharpness-aware pretraining directly impacts the survivability of pretrained models under subsequent updates. Lower Hessian trace and directional curvature minimize catastrophic forgetting—the loss of pretrained capabilities due to fine-tuning, quantization, or weight perturbations. Empirical results on OLMo models show up to 80% reductions in forgetting across modalities, including 31% less forgetting post MetaMath, and 40% less drop under 4-bit quantization (Watts et al., 4 May 2026). High peak learning rates and shortened annealing phases (inducing the "Edge of Stability" effect) also cap Hessian eigenvalues and yield similar robustness, even if such configurations do not minimize base pretraining loss.
6. Blockwise Sharpness Disparity and Learning-Rate Scheduling
Transformers display persistent blockwise sharpness disparity: embeddings remain the flattest, while the final LayerNorm is the sharpest. This observation underpins blockwise learning-rate (LR) schedules, where peak LRs are upscaled in blocks with low sharpness, delivering faster training and lower terminal loss in GPT-2 and LLaMA pretraining tasks (Wang et al., 26 Feb 2025). The per-block LR $\eta_b$ is scaled according to the block's sharpness, estimated from the blockwise average Fisher trace. Integration into memory-efficient Adam variants (e.g., Adam-mini) is direct. This approach is complementary to perturbation-based sharpness methods.
| Block Type | Typical Sharpness | Recommended LR Multiplier |
|---|---|---|
| Embedding | Lowest | 9–12 |
| Norm0 | Low | 1–8 |
| Attention | Moderate | 2–6 |
| FFN | Moderate-high | 3–6 |
| Norm4 | Highest | 5 |
Parameterized tuning is stable across dataset/model variations.
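A blockwise schedule of this kind reduces to a per-block multiplier on the base learning rate. The sketch below is a hedged illustration: the function name `blockwise_lr`, the block labels, and the multiplier values in `MULTIPLIERS` are hypothetical placeholders chosen for the example, not the values recommended by Wang et al. It also assumes the multipliers are switched on only after warmup.

```python
def blockwise_lr(base_lr, block, multipliers, step, warmup_steps):
    """Learning rate for one block at a given step.

    During warmup every block uses base_lr; afterwards the block's multiplier
    is applied. Blocks without an entry keep a multiplier of 1.0.
    """
    scale = multipliers.get(block, 1.0) if step >= warmup_steps else 1.0
    return base_lr * scale

# Hypothetical multipliers for illustration only: flatter blocks get larger
# values, the sharpest block keeps the base rate.
MULTIPLIERS = {"embedding": 10.0, "attention": 4.0, "ffn": 4.0, "final_norm": 1.0}

for block in MULTIPLIERS:
    print(block, blockwise_lr(1e-3, block, MULTIPLIERS, step=500, warmup_steps=100))
```

In a typical optimizer these per-block rates would be assigned through separate parameter groups, one per block type, with the base schedule (warmup and decay) applied to `base_lr` before scaling.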
7. Limitations and Spectral Perspectives
Both theoretical analysis and empirical findings demonstrate that standard SAM and related methods may occasionally fail to directly minimize the sharpest curvature when the gradient is nearly orthogonal to the top Hessian eigenvector. Methods such as X-SAM and GSAM address this by explicitly manipulating gradient projections, offering stronger control over the loss landscape's most problematic directions (Zhuang et al., 2022, Duan et al., 15 Jan 2026).
There is no guarantee that the minimum adversarially perturbed loss always coincides with the lowest Hessian spectral norm, and sharp minima may appear as spurious plateaux when the perturbation radius or schedule is poorly chosen. Attention to spectral metrics and surrogate gap minimization is therefore critical for robust deployment.
Sharpness-aware pretraining, in its various forms, constitutes a rigorously characterized and empirically validated set of methods for improving neural network generalization, robustness, and downstream adaptability by systematically reducing the curvature of the loss landscape through perturbation-based optimization, blockwise learning-rate policies, and spectral corrections (Mi et al., 2022, Zhuang et al., 2022, Watts et al., 4 May 2026, Wang et al., 26 Feb 2025, Duan et al., 15 Jan 2026).