A Hessian-Aware Stochastic Differential Equation for Modelling SGD (2405.18373v3)
Abstract: Continuous-time approximation of Stochastic Gradient Descent (SGD) is a crucial tool to study its escaping behaviors from stationary points. However, existing stochastic differential equation (SDE) models fail to fully capture these behaviors, even for simple quadratic objectives. Built on a novel stochastic backward error analysis framework, we derive the Hessian-Aware Stochastic Modified Equation (HA-SME), an SDE that incorporates Hessian information of the objective function into both its drift and diffusion terms. Our analysis shows that HA-SME achieves the order-best approximation error guarantee among existing SDE models in the literature, while significantly reducing the dependence on the smoothness parameter of the objective. Empirical experiments on neural network-based loss functions further validate this improvement. Further, for quadratic objectives, under mild conditions, HA-SME is proved to be the first SDE model that recovers exactly the SGD dynamics in the distributional sense. Consequently, when the local landscape near a stationary point can be approximated by quadratics, HA-SME provides a more precise characterization of the local escaping behaviors of SGD. With the enhanced approximation guarantee, we further conduct an escape time analysis using HA-SME, showcasing how it can be employed to analytically study the escaping behavior of SGD for general function classes.
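The claim that an SDE can recover SGD exactly in distribution on quadratics can be illustrated in one dimension. The sketch below makes assumptions for illustration only: f(x) = h·x²/2, additive Gaussian gradient noise with variance σ², and 0 < ηh < 1. It does not follow the paper's general setting. It builds an Ornstein-Uhlenbeck SDE dX = -aX dt + b dW whose coefficients a and b depend on the curvature h, chosen so that its exact transition over time η equals one SGD step. For contrast, it also evaluates the standard first-order SME choice a = h, b = √η·σ. This is not claimed to be the paper's HA-SME formula, and all function names are hypothetical.

```python
import math

def sgd_step_law(h, eta, sigma):
    """One SGD step on f(x) = h*x**2/2 with additive Gaussian gradient noise:
    x_{k+1} = (1 - eta*h) * x_k - eta*sigma*xi_k, xi_k ~ N(0, 1).
    Returns (mean multiplier, noise variance) of the Gaussian transition."""
    return 1.0 - eta * h, (eta * sigma) ** 2

def ou_step_law(a, b, dt):
    """Exact transition of dX = -a*X dt + b dW over a time step dt:
    X(t+dt) | X(t)=x ~ N(exp(-a*dt)*x, b**2 * (1 - exp(-2*a*dt)) / (2*a))."""
    return math.exp(-a * dt), b * b * (1.0 - math.exp(-2.0 * a * dt)) / (2.0 * a)

def hessian_aware_ou(h, eta, sigma):
    """Hypothetical helper: pick OU drift rate a and diffusion b from the
    curvature h so the time-eta OU transition equals the SGD transition."""
    m = 1.0 - eta * h
    if not 0.0 < m < 1.0:
        raise ValueError("this sketch assumes 0 < eta*h < 1")
    a = -math.log(m) / eta                         # forces exp(-a*eta) = 1 - eta*h
    b2 = 2.0 * a * (eta * sigma) ** 2 / (1.0 - m * m)  # forces variance = (eta*sigma)**2
    return a, math.sqrt(b2)

if __name__ == "__main__":
    h, eta, sigma = 1.0, 0.5, 1.0
    print("SGD step         :", sgd_step_law(h, eta, sigma))
    a, b = hessian_aware_ou(h, eta, sigma)
    print("curvature-aware  :", ou_step_law(a, b, eta))
    # First-order SME-style choice: drift -h*X, diffusion sqrt(eta)*sigma.
    print("first-order SME  :", ou_step_law(h, math.sqrt(eta) * sigma, eta))
```

Both chains are linear with Gaussian transitions, so when the transitions match and the initial laws agree, the SDE sampled at t = kη has exactly the law of the k-th SGD iterate. The first-order choice matches only up to higher-order terms in ηh. That gap is why the curvature h must enter both the drift and the diffusion.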