SharpDRO: Sharpness-Aware Robust Optimization
- SharpDRO is a robust optimization method that integrates sharpness-aware penalties to mitigate overfitting on rare, severely corrupted data.
- It employs a three-level min–max–max structure, combining worst-case distribution reweighting with parameter-space loss perturbations to shape a flat, reliable loss landscape for the hardest examples.
- Empirical results on CIFAR-10, CIFAR-100, and ImageNet30 benchmarks show significant gains in robustness, particularly under high-severity corruptions, outperforming standard DRO.
SharpDRO is a robust optimization method—“Sharpness-aware Distributionally Robust Optimization”—designed to achieve robust generalization on data mixtures where rare, severely corrupted examples (notably photon-limited corruptions) are present. Unlike traditional Distributionally Robust Optimization (DRO), which minimizes the worst-case empirical risk and consequently may produce sharp loss landscapes with poor test generalization, SharpDRO augments the standard DRO formulation with an explicit sharpness minimization penalty concentrated on the hardest (worst-case) distributions or examples. This strategy promotes solutions that not only achieve low risk but are also robust to local perturbations in the parameter space for the most challenging subsets of data (Huang et al., 2023).
1. Formal Problem Statement and Objective
Let $\Theta$ denote the parameter space and $\mathcal{L}(\theta; (x, y))$ the per-example loss. Training data are drawn from a mixture of sub-distributions $\{P_s\}_{s=0}^{S}$ reflecting different corruption severities, with mixture weights $q_s > 0$, $\sum_s q_s = 1$. The overall data distribution is $P = \sum_s q_s P_s$.
Traditional DRO formulates robust risk minimization as
$$\min_{\theta \in \Theta} \; \max_{Q \in \mathcal{Q}} \; \mathbb{E}_{(x,y)\sim Q}\big[\mathcal{L}(\theta; (x, y))\big],$$
where $\mathcal{Q}$ is typically an $f$-divergence ball around $P$ or a set of mixture distributions.
SharpDRO introduces a sharpness penalty, focusing the sharpness minimization on the worst-case empirical distribution. The objective becomes
$$\min_{\theta} \; \max_{Q \in \mathcal{Q}} \; \mathbb{E}_{(x,y)\sim Q}\big[\mathcal{L}(\theta; (x, y)) + \mathcal{S}(\theta; (x, y))\big],$$
where $\mathcal{S}(\theta; (x, y))$ measures the sharpness of the loss landscape locally at $\theta$ for each sample. $Q$ can be parameterized by weights $\omega$ over sub-distributions (distribution-aware) or by per-example out-of-distribution (OOD) scores (distribution-agnostic), leading to
$$\min_{\theta} \; \max_{\omega \in \Delta} \; \sum_s \omega_s \, \mathbb{E}_{(x,y)\sim P_s}\big[\mathcal{L}(\theta; (x, y)) + \mathcal{S}(\theta; (x, y))\big],$$
with $\Delta$ being the appropriate probability simplex.
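For concreteness, here is a minimal sketch of the distribution-aware weighted risk, assuming a PyTorch-style workflow; the function name and tensor layout are illustrative, not the paper's reference implementation:

```python
import torch

def weighted_group_risk(per_example_loss, group_ids, omega):
    """omega-weighted sum of per-group (per-severity) average losses,
    i.e. sum_s omega_s * E_{P_s}[loss]. Groups absent from the batch
    contribute zero.

    per_example_loss: (N,) losses L(theta; (x_i, y_i))
    group_ids:        (N,) long tensor of severity indices s_i
    omega:            (S,) weights on the probability simplex
    """
    risks = []
    for s in range(omega.numel()):
        mask = group_ids == s
        risks.append(per_example_loss[mask].mean() if mask.any()
                     else per_example_loss.new_zeros(()))
    return (omega * torch.stack(risks)).sum()
```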
2. Sharpness Definition and Computation
SharpDRO quantifies sharpness per example as the maximum increase in loss under a $\rho$-bounded parameter perturbation:
$$\mathcal{S}(\theta; (x, y)) = \max_{\|\epsilon\|_2 \le \rho} \mathcal{L}(\theta + \epsilon; (x, y)) - \mathcal{L}(\theta; (x, y)).$$
For smooth $\mathcal{L}$ and small $\rho$, a first-order approximation yields
$$\epsilon^* = \rho \, \frac{\nabla_\theta \mathcal{L}(\theta; (x, y))}{\|\nabla_\theta \mathcal{L}(\theta; (x, y))\|_2}, \qquad \mathcal{S}(\theta; (x, y)) \approx \mathcal{L}(\theta + \epsilon^*; (x, y)) - \mathcal{L}(\theta; (x, y)).$$
In practice, sharpness can be computed with two forward/backward passes per batch, evaluating the loss at $\theta$ and at $\theta + \epsilon^*$.
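A minimal sketch of this two-pass estimate, assuming a PyTorch classifier with cross-entropy loss; `rho=0.05` mirrors the common SAM default and is an assumption here, not a value taken from the paper:

```python
import torch
import torch.nn.functional as F

def per_example_sharpness(model, x, y, rho=0.05):
    """Two-pass sharpness estimate: per-example loss at the perturbed
    point theta + eps* minus the loss at theta, where
    eps* = rho * g / ||g|| is computed from the mean batch loss
    (first-order approximation)."""
    params = [p for p in model.parameters() if p.requires_grad]

    # Pass 1: per-example losses and the batch gradient at theta.
    loss_i = F.cross_entropy(model(x), y, reduction="none")
    grads = torch.autograd.grad(loss_i.mean(), params)
    scale = rho / (torch.sqrt(sum(g.pow(2).sum() for g in grads)).item() + 1e-12)

    # Pass 2: evaluate at theta + eps*, then restore theta.
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.add_(g, alpha=scale)
        perturbed_loss_i = F.cross_entropy(model(x), y, reduction="none")
        for p, g in zip(params, grads):
            p.sub_(g, alpha=scale)

    return perturbed_loss_i - loss_i.detach()
```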
3. Optimization Structure and Algorithm
Standard DRO is a min–max problem:
$$\min_{\theta} \; \max_{\omega \in \Delta} \; \sum_s \omega_s \, \mathbb{E}_{(x,y)\sim P_s}\big[\mathcal{L}(\theta; (x, y))\big].$$
SharpDRO extends this to a three-level min–max–max structure (the innermost maximization is over the parameter perturbation $\epsilon$) by penalizing sharpness only for the worst-case distribution:
$$\min_{\theta} \; \max_{\omega} \; \sum_s \omega_s \, \mathbb{E}_{(x,y)\sim P_s}\Big[\max_{\|\epsilon\|_2 \le \rho} \mathcal{L}(\theta + \epsilon; (x, y))\Big], \quad \text{subject to } \omega \in \Delta.$$
The iterative algorithm follows these main steps:
- Max-step (worst-case reweighting): Find or update $\omega$ to concentrate weight on the hardest (group- or example-level) distributions.
- Min-step (parameter update): Update $\theta$ with respect to the risk and sharpness terms under the current $\omega$.
Training Pseudocode (abbreviated)
```
for t = 0 ... T-1:
    # Max-step: update ω
    if distribution-aware:
        ω_{t+1} ← argmax_{ω ∈ Δ} Σ_s ω_s E_{(x,y)∼P_s}[ℒ(θ_t; (x, y))]
    else:
        # OOD scoring: ω_i ∝ max f(θ_t; x_i) − max f(θ_t + ε*; x_i)
        normalize ω to the simplex Δ
    # Min-step: update θ on risk + sharpness
    L₁ = E_{i ∈ batch}[ω_{t+1,i} ℒ(θ_t; (x_i, y_i))]
    θ' = θ_t + ρ · ∇θ L₁ / ‖∇θ L₁‖              # ascend to the perturbed point
    L₂ = E_{i ∈ batch}[ω_{t+1,i} ℒ(θ'; (x_i, y_i))]
    θ_{t+1} = θ_t − η_θ [∇θ L₁ + (∇θ L₂ − ∇θ L₁)]   # = θ_t − η_θ ∇θ L₂
```
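The same loop in runnable PyTorch form, as a sketch under stated assumptions: the max-step uses exponentiated-gradient ascent on $\omega$ (a standard surrogate for the exact argmax over the simplex), and all names (`sharpdro_step`, `group_ids`) are illustrative rather than the authors' reference code:

```python
import torch
import torch.nn.functional as F

def sharpdro_step(model, optimizer, x, y, group_ids, omega,
                  rho=0.05, eta_omega=0.1):
    """One SharpDRO-style iteration (distribution-aware variant)."""
    params = [p for p in model.parameters() if p.requires_grad]

    # Per-example losses and per-group risks at theta.
    loss_i = F.cross_entropy(model(x), y, reduction="none")
    group_risk = torch.stack([
        loss_i[group_ids == s].mean() if (group_ids == s).any()
        else loss_i.new_zeros(()) for s in range(omega.numel())
    ])

    # Max-step: reweight omega toward the worst-case groups.
    with torch.no_grad():
        omega = omega * torch.exp(eta_omega * group_risk)
        omega = omega / omega.sum()

    # Min-step, part 1: weighted risk L1 and the ascent to theta + eps*.
    w_i = omega[group_ids]                      # per-example weight
    L1 = (w_i * loss_i).mean()
    grads = torch.autograd.grad(L1, params)
    scale = rho / (torch.sqrt(sum(g.pow(2).sum() for g in grads)).item() + 1e-12)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.add_(g, alpha=scale)              # theta' = theta + eps*

    # Min-step, part 2: gradient of L2 at theta', applied at theta.
    optimizer.zero_grad()
    L2 = (w_i * F.cross_entropy(model(x), y, reduction="none")).mean()
    L2.backward()
    with torch.no_grad():                       # restore theta before stepping
        for p, g in zip(params, grads):
            p.sub_(g, alpha=scale)
    optimizer.step()                            # theta_{t+1}

    return omega, L1.detach(), L2.detach()
```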
4. Theoretical Properties
Under the following conditions:
- $\mathcal{L}$ is differentiable and $L$-smooth in both $\theta$ and $\omega$.
- In $\omega$, the objective satisfies a Polyak–Łojasiewicz (PL) condition with constant $\mu$.
- Stochastic gradients have bounded variance $\sigma^2$.
Defining $\Phi(\theta) := \max_{\omega \in \Delta} F(\theta, \omega)$, with $F$ the weighted objective, the SharpDRO training loop converges to an $\epsilon$-stationary point:
$$\frac{1}{T}\sum_{t=0}^{T-1} \mathbb{E}\big[\|\nabla \Phi(\theta_t)\|^2\big] \le O\!\Big(\tfrac{1}{\sqrt{T}}\Big) + O\!\Big(\tfrac{\sigma^2}{b}\Big),$$
where $T$ is the number of iterations and $b$ is the batch size. Achieving $\epsilon$-stationarity thus requires $T = O(\epsilon^{-4})$.
The proof leverages Danskin’s theorem to establish smoothness of $\Phi$, and constructs a potential function to jointly analyze suboptimality in $\omega$ and descent in $\theta$.
5. Experimental Methodology
Experiments are conducted on CIFAR-10, CIFAR-100, and ImageNet30. The backbone is a Wide ResNet-28-2, trained with SGD (learning rate 0.03, momentum 0.9, with weight decay) for 200 epochs at batch size 128.
Corruptions: For each sample, a corruption severity $s \in \{0, 1, \dots, 5\}$ is drawn at random, with higher severities sampled increasingly rarely. Four corruption types (Gaussian noise, JPEG compression, Snow, Shot noise) are applied following [Hendrycks & Dietterich, ICLR’19]. Clean images correspond to $s = 0$.
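To make the setup concrete, here is a sketch of building such a mixed-severity training stream for the Gaussian-noise corruption; the severity probabilities and noise scales below are illustrative assumptions, not the paper's exact sampling scheme:

```python
import numpy as np

# Illustrative only: probabilities biased toward low severities, so that
# heavily corrupted samples (s = 5) are rare; s = 0 means clean.
SEVERITY_PROBS = [0.50, 0.20, 0.12, 0.09, 0.06, 0.03]
NOISE_STD = [0.00, 0.04, 0.06, 0.08, 0.09, 0.10]  # assumed per-severity scales

def sample_corrupted(image, rng):
    """Draw a severity level and apply Gaussian noise at that level.

    image: float array in [0, 1]; returns (corrupted_image, severity).
    """
    s = rng.choice(len(SEVERITY_PROBS), p=SEVERITY_PROBS)
    noisy = image + rng.normal(0.0, NOISE_STD[s], size=image.shape)
    return np.clip(noisy, 0.0, 1.0), s

rng = np.random.default_rng(0)
x = rng.random((32, 32, 3))          # stand-in for a CIFAR image
x_corrupt, severity = sample_corrupted(x, rng)
```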
Evaluation protocols: Test accuracy is reported per severity level and averaged over all $s$.
Baselines:
- Distribution-aware: ERM, IRM, REx, GroupDRO
- Distribution-agnostic: Just‐Train‐Twice (JTT), EIIL
Hyperparameters: The perturbation radius $\rho$ is set as in SAM, and learning rates are tuned on a small validation set.
6. Empirical Findings
- Robustness across severities: On Gaussian-noise CIFAR-10, SharpDRO yields a clear absolute accuracy improvement at the highest severity over the best DRO baseline, and gains even on clean data. Similar trends hold on CIFAR-100 and ImageNet30, and across all corruption types.
- Distribution-agnostic performance: With OOD-score-based selection, SharpDRO surpasses JTT/EIIL on ImageNet30, with the largest margins at high severities.
- Ablation studies: Removing data selection (i.e., standard SAM on the full mixture) benefits clean accuracy but significantly reduces performance on highly corrupted data (underperforming GroupDRO). Disabling sharpness minimization (GroupDRO) fails to achieve flat worst-case loss surfaces, consistently underperforming across all corruption severities.
- Sensitivity to hyperparameters: Increasing $\rho$ improves worst-case (highest-severity) accuracy at the cost of a slight reduction in clean accuracy, reflecting a trade-off between perturbation radius and flatness.
- OOD scoring: The OOD score effectively isolates high-severity samples.
- Training stability: SharpDRO produces the smallest and most uniform gradient norms across all severity levels, evidencing balanced optimization dynamics.
- Computational efficiency: The method adds negligible overhead per epoch compared to SAM, as sharpness is evaluated only for the focused hard subset rather than the entire mixture.
7. Context and Significance
SharpDRO addresses a known limitation of standard DRO by mitigating overfitting to rare, severely corrupted subsets that are prone to sharp and poorly generalizing minima. By integrating sharpness penalization targeted at worst-case distributions or severe examples, SharpDRO enables robust generalization and consistent performance gains in challenging realistic settings involving photon-limited corruptions. The method maintains strong theoretical convergence properties and is empirically validated on several large-scale benchmarks, outperforming existing robust and OOD optimization baselines (Huang et al., 2023).