Riemannian Stochastic Gradient Descent
- Riemannian Stochastic Gradient Descent (RSGD) is an optimization method that generalizes classical SGD to curved spaces by enforcing manifold consistency through tangent-space projections and retractions.
- It incorporates sharpness-aware minimization (RSAM), adjusting local adversarial perturbations within the manifold for improved robustness and convergence.
- Empirical evaluations show that RSGD variants, including Monge SAM, outperform standard SGD and Euclidean SAM in classification, pretraining, and multi-modal alignment tasks.
Riemannian Stochastic Gradient Descent (RSGD) encompasses a family of optimization algorithms that generalize the classical stochastic gradient descent paradigm to settings where the model parameters live on a Riemannian manifold rather than flat Euclidean space. This framework is essential when the feasible set or the natural geometry of the problem is non-Euclidean, as in learning on the Stiefel or Grassmann manifold, or when parameter space symmetry/constraints are most naturally enforced with manifold structure. Recent advances have extended sharpness-aware minimization (SAM)—originally formulated for flat, Euclidean spaces—to the Riemannian context, resulting in Riemannian Sharpness-Aware Minimization (RSAM). Notable modern instantiations include Monge SAM (M-SAM) and geometric approaches tailored for learning on manifolds.
1. Mathematical Foundations of Riemannian SAM
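Let $\mathcal{M}$ be a $d$-dimensional Riemannian manifold embedded in $\mathbb{R}^k$ that contains the parameter set of a model, with parameters $\theta \in \mathcal{M}$. The loss function, typically sample-averaged,
$$L(\theta) = \frac{1}{n}\sum_{i=1}^{n} \ell(\theta; z_i),$$
is defined for $\theta$ on $\mathcal{M}$. At a given iterate $\theta$, the sharpness-aware objective is
$$L^{\mathrm{SAM}}(\theta) = \max_{\delta \in T_\theta\mathcal{M},\ \|\delta\| \le \rho} L\big(\mathrm{Exp}_\theta(\delta)\big),$$
where $T_\theta\mathcal{M}$ is the tangent space at $\theta$, $\mathrm{Exp}_\theta$ is the Riemannian exponential map (or a computationally tractable retraction $R_\theta$), $\rho > 0$ is the perturbation radius, and $\|\cdot\|$ is the standard Euclidean norm in $T_\theta\mathcal{M}$ as a subset of $\mathbb{R}^k$ (Truong et al., 2023).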
The optimization of this objective requires machinery unique to Riemannian geometry—such as projections onto tangent spaces, manipulation through retraction/exponential maps, and computation of Riemannian gradients and transports.
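The basic machinery (tangent-space projection, Riemannian gradient, retraction) can be sketched on the simplest embedded manifold, the unit sphere. This is a minimal illustration rather than code from the cited works: the function names, the quadratic toy objective, and the step size are all assumptions chosen for the example.

```python
import numpy as np

def project_to_tangent(x, g):
    """Project an ambient vector g onto the tangent space of the unit sphere at x."""
    return g - np.dot(x, g) * x

def riemannian_grad(x, euclid_grad):
    """On an embedded manifold, the Riemannian gradient is the tangent projection
    of the Euclidean gradient."""
    return project_to_tangent(x, euclid_grad)

def retract(x, v):
    """Projection-based retraction: step in the ambient space, then renormalize."""
    y = x + v
    return y / np.linalg.norm(y)

# Toy problem (assumption): minimize f(x) = x^T A x over the unit sphere.
# The minimizers are the unit eigenvectors of the smallest eigenvalue of A.
A = np.diag([3.0, 2.0, 1.0])
x = np.ones(3) / np.sqrt(3.0)
for _ in range(200):
    g = riemannian_grad(x, 2.0 * A @ x)  # Euclidean gradient of f is 2 A x
    x = retract(x, -0.1 * g)             # plain Riemannian gradient step
```

Every iterate stays on the sphere by construction, which is the "manifold consistency" that the Riemannian formulation is meant to guarantee.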
2. RSGD Algorithmic Structure: Teleportation and Descent
Riemannian SAM (RSAM) generalizes SAM via a two-step process at each iteration $t$:
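- Inner maximization ("Teleportation" step):
  - Compute the Riemannian gradient $\mathrm{grad}\,L(\theta_t)$, the projection of the Euclidean gradient $\nabla L(\theta_t)$ onto $T_{\theta_t}\mathcal{M}$.
  - Solve for the adversarial ascent direction $\delta_t$ in the tangent space, typically via a closed-form expression in which a metric-adjustment matrix rescales the gradient (simple choices such as the identity are common) and a projection maps the result back to the tangent space (Truong et al., 2023).
  - Teleport the parameters onto the manifold: $\theta_t^{\mathrm{adv}} = R_{\theta_t}(\delta_t)$.
- Outer minimization (Riemannian descent):
  - Compute the Riemannian gradient at the perturbed point $\theta_t^{\mathrm{adv}}$.
  - Update by retracting from $\theta_t$ along the negative of this gradient, scaled by the learning rate $\eta_t$, to obtain $\theta_{t+1}$.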
Complete pseudocode is given in Truong et al. (2023). This framework accommodates both exact and approximate solutions to the inner maximization; the latter (e.g., a direct relaxation that takes the metric-adjustment matrix to be the identity) is faster, with negligible loss in empirical accuracy.
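The two-step structure above can be sketched on the unit sphere. This is an illustrative approximation, not the pseudocode of Truong et al. (2023): it assumes the identity as the metric-adjustment matrix, a normalized-gradient perturbation of radius `rho`, and a re-projection of the perturbed gradient onto the tangent space at the current iterate. The names `rsam_step`, `proj`, `retract`, and the toy loss are hypothetical.

```python
import numpy as np

def proj(x, g):
    """Tangent-space projection on the unit sphere."""
    return g - np.dot(x, g) * x

def retract(x, v):
    """Projection-based retraction on the unit sphere."""
    y = x + v
    return y / np.linalg.norm(y)

def rsam_step(x, grad_fn, rho=0.05, lr=0.1, eps=1e-12):
    """One RSAM-style update (sketch, identity metric adjustment assumed)."""
    # Inner maximization: ascent direction in the tangent space, scaled to radius rho.
    rgrad = proj(x, grad_fn(x))
    delta = rho * rgrad / (np.linalg.norm(rgrad) + eps)
    # Teleportation: move to the perturbed point on the manifold.
    x_adv = retract(x, delta)
    # Outer minimization: Riemannian gradient at the perturbed point,
    # re-projected onto the tangent space at x (a choice made for this sketch).
    descent = proj(x, proj(x_adv, grad_fn(x_adv)))
    return retract(x, -lr * descent)

# Toy problem (assumption): f(x) = x^T A x on the unit sphere.
A = np.diag([3.0, 2.0, 1.0])
loss = lambda x: float(x @ A @ x)
grad_fn = lambda x: 2.0 * A @ x

x = np.ones(3) / np.sqrt(3.0)
for _ in range(300):
    x = rsam_step(x, grad_fn)
```

Because the perturbation has fixed norm `rho`, the iterates settle into a small neighborhood of the minimizer rather than converging to it exactly; each step costs two gradient evaluations.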
3. Geometric and Theoretical Properties
RSAM inherits and augments critical geometric properties relative to its Euclidean counterparts:
- Manifold Consistency: All updates and perturbations are confined to $\mathcal{M}$, leveraging projection, retraction, and manipulation in tangent spaces.
- Reparametrization Invariance (M-SAM): Monge SAM generalizes SAM by introducing a loss-induced Riemannian metric $g(\theta) = I + \nabla L(\theta)\nabla L(\theta)^\top$ (the Monge metric). Adversarial directions and steps are computed with respect to this geometry, yielding invariance under smooth reparametrizations: if $\varphi$ is a diffeomorphism of the parameter space, the metric and steps transform covariantly, and the constrained step size is preserved under the change of variables (Jacobsen et al., 12 Feb 2025).
- Generalization Bound: Under compactness, a Lipschitz loss, and controlled retraction error, RSAM admits a generalization bound whose complexity term depends on the intrinsic dimension $d$ of $\mathcal{M}$ (as opposed to the ambient dimension $k$ in Euclidean SAM), which underpins improved statistical guarantees in manifold-constrained problems (Truong et al., 2023).
- Critical Point Behavior: In RSAM and especially in M-SAM, the method is less prone than Euclidean SAM to becoming trapped at suboptimal saddle points, because its step size modulates itself. The effective radius
$$\rho_{\mathrm{eff}}(\theta) = \frac{\rho}{\sqrt{1 + \|\nabla L(\theta)\|^2}}$$
endows M-SAM with self-damping properties that make it more robust to hyperparameter choices and gradient magnitude (Jacobsen et al., 12 Feb 2025).
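The self-damping behavior can be made explicit with a short derivation. It assumes that the Monge metric is the pullback metric of the loss graph, $g = I + \nabla L \nabla L^\top$, and that the perturbation is the steepest-ascent direction of the linearized loss inside a $g$-ball of radius $\rho$; both are assumptions of this sketch rather than statements quoted from the cited paper.

```latex
\begin{aligned}
g^{-1} &= I - \frac{\nabla L \nabla L^\top}{1 + \|\nabla L\|^2}
  && \text{(Sherman-Morrison)} \\
g^{-1}\nabla L &= \frac{\nabla L}{1 + \|\nabla L\|^2},
  \qquad
  \nabla L^\top g^{-1} \nabla L = \frac{\|\nabla L\|^2}{1 + \|\nabla L\|^2} \\
\delta^\star &= \rho\,\frac{g^{-1}\nabla L}{\sqrt{\nabla L^\top g^{-1}\nabla L}}
  = \frac{\rho}{\sqrt{1 + \|\nabla L\|^2}}\cdot\frac{\nabla L}{\|\nabla L\|}
\end{aligned}
```

Under these assumptions the perturbation points along the ordinary gradient but its Euclidean length shrinks as the gradient grows: it approaches the SAM radius $\rho$ near stationary points and vanishes in steep regions, where the update then resembles plain gradient descent.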
4. Implementation Details and Computational Overhead
Riemannian SAM-type methods require neither explicit Hessian computation nor high-rank matrix inverses. Key operations per iteration include:
- One forward and one backward pass for the base point.
- Computation of Riemannian (projected) gradients and norm/scaling.
- One additional forward-backward for the adversarial direction.
- Retraction/exponential-map computations, and (optionally) tangent-space projections.
The runtime increase relative to standard SAM is marginal: RSAM incurs a small additional overhead per epoch over Euclidean SAM, and both are slower than vanilla SGD because each update requires two forward-backward passes. All empirical evidence in the literature uses tractable choices of retraction and projection, such as the identity map in unconstrained Euclidean space, or projection-based retraction for embedded manifolds (Truong et al., 2023; Jacobsen et al., 12 Feb 2025).
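As an illustration of a tractable projection and projection-based retraction on an embedded manifold, the following sketch uses the Stiefel manifold of matrices with orthonormal columns. The function names are hypothetical, and the polar retraction via a thin SVD is one common choice, not necessarily the one used in the cited experiments.

```python
import numpy as np

def stiefel_tangent_projection(X, G):
    """Project an ambient matrix G onto the tangent space at X, where X^T X = I.
    Tangent vectors V satisfy X^T V + V^T X = 0."""
    XtG = X.T @ G
    sym = 0.5 * (XtG + XtG.T)
    return G - X @ sym

def stiefel_polar_retraction(X, V):
    """Projection-based retraction: map X + V to the nearest matrix with
    orthonormal columns, computed from a thin SVD."""
    U, _, Vt = np.linalg.svd(X + V, full_matrices=False)
    return U @ Vt

# Usage on random data (assumption: a 5x3 Stiefel point obtained from a QR factorization).
rng = np.random.default_rng(0)
X, _ = np.linalg.qr(rng.standard_normal((5, 3)))
V = stiefel_tangent_projection(X, rng.standard_normal((5, 3)))
Y = stiefel_polar_retraction(X, 0.1 * V)
```

Both operations cost a few small matrix products and one thin SVD per step, which is typically negligible next to the two forward-backward passes of a SAM-type update.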
5. Empirical Performance and Benchmarks
Empirical investigations show RSAM and its Monge metric variant outperform standard SGD and Euclidean SAM in multiple settings, particularly:
- Supervised Classification (ResNet50, CIFAR-100): RSAM achieves 77.78% top-1 accuracy compared to 75.04% for SAM and 74.62% for SGD at identical hyperparameters (Truong et al., 2023).
- Contrastive Pretraining (SupCon + RSAM): Linear evaluation after SupCon pretraining and RSAM achieves 81.62% accuracy (vs. 76.73% for SAM and 75.29% for SGD) (Truong et al., 2023).
- Robustness to Hyperparameter Choices: On CIFAR-10, both SAM and M-SAM can escape local minima inaccessible to SGD, but only M-SAM avoids catastrophic divergence with large $\rho$, demonstrating its conservative step size adaptation (Jacobsen et al., 12 Feb 2025).
- Multi-modal Representation Alignment (CLIP Fine-tuning): On WIT/MS-COCO, M-SAM achieved higher mutual-kNN similarity than SAM or SGD/Adam, while exhibiting less sensitivity to $\rho$ (Jacobsen et al., 12 Feb 2025).
- Ablations: Empirical evaluations demonstrate that approximate inner maximization for adversarial perturbations is nearly as accurate as exact projection, and that choice of metric-adjustment matrix has limited effect on final performance. RSAM is robust to auxiliary constraints, such as orthogonality in autoencoders, where Euclidean regularization strategies struggle (Truong et al., 2023).
6. Limitations and Open Directions
Current limitations of Riemannian SAM-type methods include the following:
- Approximate Inner Maximization: The practical implementations use heuristics for the inner maximization in the tangent space; more accurate manifold-specific solvers (e.g., geodesic searches) remain an open area (Truong et al., 2023).
- Extension to Quotient Manifolds: While existing RSAM treats embedded manifolds, extension to quotient structures (e.g., the Grassmann manifold) with explicit parallel transport between tangent spaces poses technical challenges.
- Second-order Corrections: Incorporating Riemannian analogues of curvature- or sharpness-aware corrections beyond first-order remains unexplored (Truong et al., 2023).
- Empirical Overheads: Although the per-iteration cost is only marginally higher than Euclidean SAM, the overall training cost remains a barrier compared to vanilla SGD, particularly for large-scale applications.
7. Comparison of Methodological Variants
The following table summarizes core distinctions and similarities between Monge SAM, RSAM, and classical SAM, as discussed in the referenced works:
| Method | Geometry | Reparametrization Invariant | Manifold Support |
|---|---|---|---|
| SAM | Euclidean | No | Flat Euclidean space only |
| Monge SAM (M-SAM) | Loss-induced | Yes | Any model with Euclidean parameters |
| RSAM | Manifold | Yes (by construction) | General Riemannian manifold $\mathcal{M}$ |
Monge SAM is a special case of manifold-aware sharpness-aware minimization where the metric is induced by the loss graph in ambient space (Jacobsen et al., 12 Feb 2025), while RSAM is a general framework for optimizing over arbitrary Riemannian manifolds (Truong et al., 2023).
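The practical difference between the Euclidean and loss-induced geometries shows up in the size of the adversarial perturbation. The sketch below compares the two under the assumption that the Monge metric is $g = I + \nabla L \nabla L^\top$ and that the M-SAM perturbation is the steepest-ascent direction within a $g$-ball of radius `rho`; the function names are hypothetical, and the explicit linear solve is used only for clarity.

```python
import numpy as np

def sam_perturbation(grad, rho, eps=1e-12):
    """Euclidean SAM: normalized gradient scaled to radius rho."""
    return rho * grad / (np.linalg.norm(grad) + eps)

def monge_sam_perturbation(grad, rho, eps=1e-12):
    """M-SAM-style perturbation (sketch): steepest ascent within a ball of
    radius rho measured in the metric g = I + grad grad^T."""
    g = np.eye(grad.size) + np.outer(grad, grad)
    nat = np.linalg.solve(g, grad)           # g^{-1} grad
    return rho * nat / (np.sqrt(grad @ nat) + eps)

grad = np.array([3.0, 4.0])                  # toy gradient with norm 5
d_sam = sam_perturbation(grad, rho=0.1)
d_msam = monge_sam_perturbation(grad, rho=0.1)
ratio = np.linalg.norm(d_msam) / np.linalg.norm(d_sam)
```

Under these assumptions the ratio equals $1/\sqrt{1 + \|\nabla L\|^2}$, so the two perturbations share a direction and differ only in length, with the M-SAM one shrinking as the gradient grows.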