SA-SGLD: Adaptive MCMC Sampling
- SA-SGLD is an adaptive MCMC method that employs time-rescaling to adjust stepsizes based on local gradient norms for unbiased posterior sampling.
- It enhances stability and mixing in high-dimensional models, effectively managing varying curvature in Bayesian neural networks.
- The algorithm guarantees ergodicity with controlled discretization bias, offering a robust alternative to conventional SGLD and pSGLD.
SA-SGLD (Sundman-Adapted Stochastic Gradient Langevin Dynamics) refers to an adaptive, time-rescaled Markov chain Monte Carlo (MCMC) algorithm designed for efficient Bayesian posterior sampling in high-dimensional parameter spaces such as Bayesian neural networks (BNNs). SA-SGLD leverages the methodology of time rescaling, adaptively modulating the discretization stepsize based on the local geometry of the posterior landscape as measured by stochastic gradients. This mechanism enables robust, unbiased approximation of the correct invariant measure, addressing challenges of stability, mixing, and tuning complexity inherent to classical SGLD and preconditioned variants (Rajpal et al., 11 Nov 2025).
1. Foundations and Motivation
Stochastic Gradient Langevin Dynamics (SGLD) is a method for sampling from posteriors of the form $\pi(\theta) \propto \exp(-\beta U(\theta))$, where $U(\theta)$ comprises both negative log-likelihood and log-prior terms. The overdamped Langevin diffusion is
$$d\theta_t = -\nabla U(\theta_t)\,dt + \sqrt{2\beta^{-1}}\,dW_t,$$
with invariant density proportional to $\exp(-\beta U(\theta))$. In practice, SGLD applies an Euler–Maruyama discretization:
$$\theta_{n+1} = \theta_n - \epsilon\,\widehat{\nabla U}(\theta_n) + \sqrt{2\beta^{-1}\epsilon}\,\xi_n,$$
where $\xi_n \sim \mathcal{N}(0, I)$ and $\widehat{\nabla U}(\theta_n)$ is a mini-batch stochastic gradient. Most applications forgo vanishing stepsizes in favor of a fixed $\epsilon$, risking a tradeoff between poor mixing in flat regions and instability in high-curvature regions.
Preconditioned SGLD (pSGLD) seeks to address curvature sensitivity by using a diagonal metric $G(\theta)$ (often RMSprop-style) to rescale both the gradient step and the injected noise. However, omitting the necessary divergence correction term $\nabla \cdot G(\theta)$ in high dimensions breaks detailed balance and induces bias, an issue that is insurmountable at scale due to the cost of evaluating this term.
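For concreteness, the following is a minimal Python sketch of the fixed-stepsize Euler–Maruyama update above; the toy potential, function names, and stepsize are illustrative assumptions, not reference code from the paper.

```python
import numpy as np

def sgld_step(theta, grad_U_hat, eps, beta=1.0, rng=None):
    """One fixed-stepsize Euler-Maruyama SGLD update."""
    rng = np.random.default_rng() if rng is None else rng
    xi = rng.standard_normal(theta.shape)                      # ξ ~ N(0, I)
    return theta - eps * grad_U_hat(theta) + np.sqrt(2.0 * eps / beta) * xi

# Toy usage: U(θ) = 0.5 ||θ||^2, so ∇U(θ) = θ
theta = np.zeros(3)
for _ in range(1000):
    theta = sgld_step(theta, lambda th: th, eps=1e-2)
```

The fixed stepsize `eps` is exactly the quantity that must trade off mixing in flat regions against stability in high-curvature regions.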
2. SA-SGLD: Time Rescaling and Algorithmic Design
SA-SGLD is derived via a Sundman-type time-rescaling mechanism wherein the stepsize adapts according to a monitored scalar function, conventionally the instantaneous squared norm of the local mini-batch stochastic gradient,
$$g_n = \|\widehat{\nabla U}(\theta_n)\|^2 + \delta,$$
with $\delta > 0$ a small constant. An auxiliary variable $\zeta_n$ (“clock”) tracks a running average of $g_n$:
$$\zeta_{n+1} = \rho\,\zeta_n + (1-\rho)\,\frac{g_n}{\alpha}, \qquad \rho = e^{-\alpha h},$$
where $h$ is a base time step and $\alpha > 0$ controls memory decay. The overall time rescaling is governed by a user-specified bounded, Lipschitz map $\psi$ (e.g., a smooth, saturating function of $\zeta$ bounded away from zero):
$$\Delta t_n = \psi(\zeta_{n+1})\,h.$$
This leads to the SA-SGLD update:
$$\theta_{n+1} = \theta_n - \Delta t_n\,\widehat{\nabla U}(\theta_n) + \sqrt{2\beta^{-1}\Delta t_n}\,\xi_n, \qquad \xi_n \sim \mathcal{N}(0, I).$$
Because the time-rescaled SDE is a reparameterization, the invariant measure remains unchanged, eliminating the bias incurred by naive adaptation schemes.
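A sketch of the underlying time-change identity, written in the notation above (this is the generic Sundman-transform argument rather than the paper's exact statement): with rescaled time $s$ defined through $dt = \psi(\zeta_s)\,ds$, the diffusion becomes
$$d\theta_s = -\psi(\zeta_s)\,\nabla U(\theta_s)\,ds + \sqrt{2\beta^{-1}\psi(\zeta_s)}\,dW_s,$$
which traces the same paths as the original Langevin dynamics at a variable speed, so averages weighted by the elapsed physical time (discretely, by $\Delta t_n$) recover expectations under $\pi \propto \exp(-\beta U)$.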
3. Theoretical Guarantees
Under standard assumptions (smoothness of $U$ with Lipschitz gradient, dissipativity, unbiased stochastic gradients with bounded variance, and a bounded, Lipschitz rescaling map $\psi$), SA-SGLD enjoys uniform moment bounds for the chain and ergodicity to the correct invariant measure. Key properties include:
- Uniform moment stability: There exists a threshold $h_0 > 0$ so that for base steps $h \le h_0$, the moments $\mathbb{E}\|\theta_n\|^2$ remain bounded uniformly in $n$.
- Ergodicity and bias: Weighted averages with respect to the variable stepsizes (see the estimator sketched below) converge almost surely, with a discretization bias that vanishes as $h \to 0$ for suitable test functions $\phi$; no additional bias is introduced by the adaptation mechanism itself.
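Concretely, the weighted ergodic average referenced above takes the standard time-weighted form (a sketch in the notation of Section 2, with $\phi$ a test function):
$$\hat{\pi}_N(\phi) = \frac{\sum_{n=0}^{N-1} \Delta t_n\,\phi(\theta_n)}{\sum_{n=0}^{N-1} \Delta t_n},$$
which converges almost surely as $N \to \infty$ to $\int \phi\,d\pi$, up to the discretization error controlled by the base step $h$.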
This distinguishes SA-SGLD from pSGLD and similar variants, whose failure to correct the metric change introduces persistent bias in the samples (Rajpal et al., 11 Nov 2025).
4. Algorithmic Workflow
The SA-SGLD algorithm proceeds as follows:
```
for n = 0, 1, 2, ...
    # 1. Stochastic mini-batch gradient G_n ≈ ∇U(θ_n)
    # 2. Monitor g_n = ||G_n||^2 + δ
    # 3. Update exponential moving average:
    #    ζ_{n+1} = ρ ζ_n + (1 - ρ) (g_n / α),  with  ρ = exp(-α h)
    # 4. Adaptive stepsize Δt = ψ(ζ_{n+1}) · h
    # 5. Gaussian noise ξ ~ N(0, I)
    # 6. Parameter update θ_{n+1} = θ_n - Δt · G_n + sqrt(2 β^{-1} Δt) · ξ
end
```
The computational overhead relative to SGLD is negligible: only one additional running scalar $\zeta$, one gradient-norm evaluation, and a few arithmetic operations per iteration.
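The loop above translates directly into code. Below is a minimal, self-contained Python sketch of one possible implementation on a toy anisotropic Gaussian target; the choice of rescaling map `psi`, the potential, and all numeric values are illustrative assumptions rather than the paper's reference implementation.

```python
import numpy as np

def sa_sgld(grad_U, theta0, h=0.1, alpha=1.0, delta=1e-8, beta=1.0,
            psi=lambda z: 1.0 / (1.0 + z), n_iter=10_000, rng=None):
    """Sundman-adapted SGLD sketch: adaptive stepsize driven by a gradient-norm clock."""
    rng = np.random.default_rng() if rng is None else rng
    theta = np.asarray(theta0, dtype=float).copy()
    zeta = 0.0                        # auxiliary "clock" (running gradient-norm average)
    rho = np.exp(-alpha * h)          # exponential moving-average factor
    samples, weights = [], []
    for _ in range(n_iter):
        G = grad_U(theta)                         # (stochastic) gradient of U
        g = float(np.dot(G, G)) + delta           # monitor: squared gradient norm
        zeta = rho * zeta + (1.0 - rho) * g / alpha
        dt = psi(zeta) * h                        # adaptive stepsize
        xi = rng.standard_normal(theta.shape)
        theta = theta - dt * G + np.sqrt(2.0 * dt / beta) * xi
        samples.append(theta.copy())
        weights.append(dt)                        # weights for time-weighted averages
    return np.array(samples), np.array(weights)

# Toy example: sample from a 2D Gaussian, U(θ) = 0.5 θᵀ A θ (illustrative only)
A = np.diag([1.0, 50.0])                          # anisotropic curvature
samples, w = sa_sgld(lambda th: A @ th, theta0=np.zeros(2))
mean_est = (w[:, None] * samples).sum(0) / w.sum()  # stepsize-weighted posterior mean
print(mean_est)
```

The stepsize-weighted average at the end mirrors the weighted ergodic averages discussed in Section 3.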
5. Empirical Performance
Performance is evaluated on both synthetic high-curvature potentials and practical BNNs:
- 2D toy examples (Müller–Brown, Star potential): SA-SGLD dynamically shrinks stepsizes in narrow, high-curvature regions, enabling sampling to cross barriers and traverse “funnels” that stall fixed-ε SGLD.
- BNNs on MNIST:
- Architecture: Fully-connected, 784–1200–1200–10, trained for 200 epochs (100 burn-in), reporting NLL, test accuracy, and expected calibration error (ECE).
- Under Gaussian prior, SA-SGLD matches SGLD performance; under the sharper Horseshoe prior, SA-SGLD yields lower NLL, higher accuracy, and better calibration.
- Stability is retained for large base stepsizes in SA-SGLD, with SGLD diverging beyond its stability threshold.
| Prior | Method | NLL (↓) | Accuracy (↑) | ECE (↓) |
|---|---|---|---|---|
| Gaussian | SGLD | 0.192±0.005 | 95.23%±0.06 | 5.72%±0.27 |
| Gaussian | SA-SGLD | 0.193±0.004 | 95.25%±0.03 | 5.76%±0.22 |
| Horseshoe | SGLD | 0.086±0.004 | 98.03%±0.04 | 3.64%±0.11 |
| Horseshoe | SA-SGLD | 0.080±0.003 | 98.12%±0.03 | 3.49%±0.09 |
SA-SGLD improves on SGLD in posterior quality, mixing, and stability under sharp priors (Rajpal et al., 11 Nov 2025).
6. Practical Guidance and Implementation Details
Recommendations for hyperparameters include the following (a hedged usage sketch follows the list):
- Base step $h$: As large as feasible in flat regions, typically $0.1–1.0$.
- Monitor function: $g_n = \|G_n\|^2 + \delta$, with $\delta$ a small positive offset.
- Exponential decay $\alpha$: Set to a moderate value so the clock retains medium memory of recent gradient norms.
- Rescaling map $\psi$: A bounded, Lipschitz, saturating function of the clock variable $\zeta$, governed by a small number of shape constants.
- Overhead: Minimal, dominated by backpropagation cost in gradient computation.
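As a purely illustrative usage sketch (the numeric values and the form of `psi` below are placeholders, not the paper's reported settings, and `sa_sgld` refers to the hypothetical helper sketched in Section 4), the recommendations map onto keyword arguments as follows:

```python
import numpy as np

# Toy target: U(θ) = 0.5 θᵀ A θ with anisotropic curvature (illustrative only).
A = np.diag([1.0, 100.0])

samples, weights = sa_sgld(
    grad_U=lambda th: A @ th,               # mini-batch gradient estimator of U
    theta0=np.zeros(2),
    h=0.5,                                  # base step: as large as is stable in flat regions
    alpha=1.0,                              # memory decay of the clock's moving average
    delta=1e-8,                             # small offset in the gradient-norm monitor
    psi=lambda z: 0.1 + 0.9 / (1.0 + z),    # bounded, Lipschitz rescaling map (placeholder)
)
posterior_mean = (weights[:, None] * samples).sum(0) / weights.sum()
```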
A plausible implication is that SA-SGLD presents a scalable and robust approach for large-scale BNN posterior sampling, particularly beneficial when the geometry of the posterior exhibits regions of disparate curvature.
7. Related Methodologies and Distinctions
In the literature, "SA" (stochastic approximation) also refers to temporal averaging in Langevin Monte Carlo for multi-armed bandit problems, as exemplified by TS-SA. In that context, stochastic approximation smooths iterates to achieve improved posterior sampling and regret bounds in the non-stationary setting (Wang et al., 6 Oct 2025).
The core distinction is that SA-SGLD specifically denotes the Sundman-adapted SGLD via time rescaling for posterior sampling in continuous parameter spaces, as constructed in Leimkuhler, Lohmann, and Whalley (2025) (Rajpal et al., 11 Nov 2025). In contrast, stochastic approximation in the TS-SA framework is applied to averaging in bandit decision processes. Both exploit adaptivity, but the formalism and objectives differ. Coordination of stepsize to posterior geometry through time rescaling, as in SA-SGLD, addresses long-standing instability and bias issues in stochastic gradient MCMC.
For SGMCMC in high dimensions, SA-SGLD supplies a provably correct, lightly tuned, and computationally efficient alternative to existing adaptive schemes.