Modular Hessian Backpropagation
- Modular Hessian Backpropagation is a paradigm that efficiently propagates second-order curvature via modular network components, enabling scalable Hessian approximations.
- It unifies techniques such as exact, stochastic, and fast diagonal methods to compute local Hessian data, thereby boosting optimization and sensitivity analysis.
- The framework integrates with autograd tools like PyTorch, facilitating rapid experimentation in large-scale machine learning, neural ODEs, and scientific computing.
Modular Hessian Backpropagation is a computational paradigm for efficiently propagating second-order derivative information (specifically, Hessian matrices or their approximations) through complex computational graphs constructed from modular sub-components such as neural network layers, ODE solvers, or general differentiable modules. This framework enables scalable, automated, and extensible computation of (possibly block-diagonal or stochastic) Hessian or curvature approximations at a per-module level, which can be leveraged for optimization, uncertainty quantification, and sensitivity analysis in large-scale machine learning and scientific computing.
1. Core Principles of Modular Hessian Backpropagation
The central concept is to generalize the backpropagation pattern familiar from first-order automatic differentiation (AD): each computational module is responsible for providing not only a backward routine for propagating gradients, but also a local rule for passing back Hessians (or their block or diagonal approximations). At each node in a computational graph, the module receives downstream curvature information—e.g., the Hessian with respect to its output or the diagonal thereof—and computes the corresponding quantities for its inputs and parameters by means of local, layer-specific matrix operations. The global Hessian or block-diagonal can be constructed by composing these local propagations.
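In the one-dimensional case the local rule reduces to $g_x = f'(x)\,g_y$ and $h_x = f'(x)^2 h_y + f''(x)\,g_y$ for a module $y = f(x)$. The pattern can be sketched in a few lines; the API below is hypothetical, not taken from any cited library:

```python
# Minimal scalar sketch of modular Hessian backpropagation (hypothetical API).
# Each module caches its input on the forward pass and propagates the pair
# (gradient, curvature) backwards with the local rule
#   g_x = f'(x) * g_y,    h_x = f'(x)**2 * h_y + f''(x) * g_y.

class Square:                     # z = x^2
    def forward(self, x):
        self.x = x                # cache input for the backward pass
        return x * x
    def backward(self, g, h):
        return 2 * self.x * g, (2 * self.x) ** 2 * h + 2 * g

class Scale:                      # z = a * x (linear: residual term vanishes)
    def __init__(self, a):
        self.a = a
    def forward(self, x):
        return self.a * x
    def backward(self, g, h):
        return self.a * g, self.a ** 2 * h

def hess_backprop(modules, x):
    for m in modules:
        x = m.forward(x)
    g, h = 1.0, 0.0               # seeds: d loss / d loss = 1, curvature 0
    for m in reversed(modules):
        g, h = m.backward(g, h)
    return x, g, h                # loss, d loss/dx, d^2 loss/dx^2

# loss(x) = (3x)^2 = 9x^2, so loss'(2) = 36 and loss''(2) = 18
loss, g, h = hess_backprop([Scale(3.0), Square()], x=2.0)
print(loss, g, h)   # 36.0 36.0 18.0
```

The global second derivative falls out of purely local per-module rules, which is the composability property the block-diagonal and diagonal methods below all exploit.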
This approach unifies several families of curvature/second-order estimators, including:
- Exact Hessian block or diagonal computation via local Hessian backpropagation (Dangel et al., 2019)
- Stochastic or unbiased estimators (Curvature Propagation) (Martens et al., 2012)
- Structured/approximate curvature (e.g., Generalized Gauss-Newton, Positive-Curvature Hessian) (Dangel et al., 2019)
- Fast diagonal approximations (e.g., HesScale) (Elsayed et al., 2022)
- Backpropagation through ODEs/computational flow models (Ciceri et al., 2023, Solvik et al., 2024)
- Efficient auto-differentiation tooling (e.g., BackPACK (Dangel et al., 2019))
2. Mathematical Formulation and Algorithmic Structure
2.1 Exact and Block-Diagonal Hessian Backpropagation
For a scalar objective $\ell$, each module with input $x$, parameters $\theta$, and output $z$ implements

$$\nabla_x^2 \ell \;=\; \left(\frac{\partial z}{\partial x}\right)^{\!\top} \left(\nabla_z^2 \ell\right) \frac{\partial z}{\partial x} \;+\; \sum_k \left(\nabla_x^2 z_k\right) \frac{\partial \ell}{\partial z_k},$$

where $\partial z / \partial x$ is the Jacobian, $\nabla_z^2 \ell$ is the downstream block Hessian, $\nabla_x^2 z_k$ is the Hessian of output component $z_k$ w.r.t. $x$, and $\partial \ell / \partial z_k$ is the associated gradient; the analogous rule with $x$ replaced by $\theta$ yields the parameter block (Dangel et al., 2019).
For Generalized Gauss–Newton (GGN) and Positive-Curvature Hessian (PCH) approximations, the second term is omitted or projected to maintain positive semi-definiteness, leading to block-structured, PSD approximations.
2.2 Stochastic/Unbiased Hessian Estimation—Curvature Propagation
Martens et al.'s Curvature Propagation (CP) algorithm:
- Injects random i.i.d. vectors $v$ at each node, with $\mathbb{E}[v v^\top] = I$.
- Runs modular recursions:
- At each module, propagates curvature samples using square-root factors of the local Hessian contribution.
- The output vector $\hat{s}$ satisfies $\mathbb{E}[\hat{s}\hat{s}^\top] = H$, so each sample is an unbiased rank-one estimate of the Hessian.
- For diagonal estimation, accumulates $\hat{s}_i^2$ per entry on the fly.
- Each rank-one sample costs $O(1)$ gradient evaluations; variance is strictly lower than that of Pearlmutter's Hessian-vector method (Martens et al., 2012).
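The Hessian-vector baseline that CP's variance is compared against rests on the identity $\operatorname{diag}(H) = \mathbb{E}[v \odot Hv]$ for random $v$ with $\mathbb{E}[v v^\top] = I$. The toy sketch below illustrates that baseline identity (not the CP recursion itself), averaging over all Rademacher sign vectors of a fixed 2×2 Hessian so the expectation is computed exactly:

```python
# Hutchinson-style diagonal identity: diag(H) = E[v ⊙ Hv] for Rademacher v.
# For a tiny fixed H we can average over ALL 2^d sign vectors, so the
# "expectation" is exact and the off-diagonal cross terms cancel in full.
from itertools import product

H = [[4.0, 1.0],
     [1.0, 3.0]]

def hvp(v):  # Hessian-vector product for the toy H
    return [sum(H[i][j] * v[j] for j in range(2)) for i in range(2)]

est = [0.0, 0.0]
signs = list(product([-1.0, 1.0], repeat=2))
for v in signs:                       # exhaustive average == expectation
    Hv = hvp(v)
    for i in range(2):
        est[i] += v[i] * Hv[i] / len(signs)

print(est)   # recovers diag(H) = [4.0, 3.0]
```

In practice $v$ is sampled rather than enumerated, and the estimator's variance grows with the off-diagonal mass of $H$; CP's per-module injection is designed to reduce exactly that variance.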
2.3 Efficient Diagonal Approximations—HesScale
HesScale propagates approximate Hessian diagonals via a per-layer recurrence akin to gradient backprop, but under the simplifying assumption of ignoring all off-block/diagonal entries. For layer $l$ and neuron $i$, with pre-activation $a_i^l = \sum_j w_{ij}^l h_j^{l-1}$ and activation $h_i^l = \sigma(a_i^l)$:

$$\frac{\partial^2 \mathcal{L}}{\partial (a_i^l)^2} \;\approx\; \sigma'(a_i^l)^2 \sum_j \left(w_{ji}^{l+1}\right)^2 \frac{\partial^2 \mathcal{L}}{\partial (a_j^{l+1})^2} \;+\; \sigma''(a_i^l)\, \frac{\partial \mathcal{L}}{\partial h_i^l}$$

and

$$\frac{\partial^2 \mathcal{L}}{\partial (w_{ij}^l)^2} \;\approx\; \left(h_j^{l-1}\right)^2 \frac{\partial^2 \mathcal{L}}{\partial (a_i^l)^2},$$

with complexity linear in the number of parameters and empirically high accuracy (Elsayed et al., 2022).
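A worked toy instance of this recurrence: one scalar input, a two-neuron hidden layer $h_i = \sigma(w_i x)$ with $\sigma(a) = a^2$, a linear read-out $y = \sum_i v_i h_i$, and loss $\mathcal{L} = \tfrac{1}{2} y^2$. With a single output the diagonal recurrence happens to be exact, so it can be checked against hand-derived Hessian entries (all names here are local to the sketch):

```python
# Toy HesScale-style diagonal recurrence for a 1 -> 2 -> 1 network with
# sigma(a) = a^2 and loss L = 0.5 * y^2. Single output => the diagonal
# approximation is exact here, and matches the analytic Hessian diagonal.
x = 1.0
w = [1.0, 2.0]           # hidden weights
v = [0.5, 0.25]          # read-out weights

sigma, dsigma, d2sigma = (lambda a: a * a), (lambda a: 2 * a), (lambda a: 2.0)

# forward pass
a = [wi * x for wi in w]
h = [sigma(ai) for ai in a]
y = sum(vi * hi for vi, hi in zip(v, h))

# seeds at the output: g_y = dL/dy = y, hhat_y = d2L/dy2 = 1
g_y, hhat_y = y, 1.0

# per-neuron recurrence:
#   hhat_{a_i} = sigma'(a_i)^2 * v_i^2 * hhat_y + sigma''(a_i) * (v_i * g_y)
#   diag-Hessian of w_i ≈ x^2 * hhat_{a_i}
hhat_a = [dsigma(ai) ** 2 * vi ** 2 * hhat_y + d2sigma(ai) * vi * g_y
          for ai, vi in zip(a, v)]
hess_w = [x ** 2 * hi for hi in hhat_a]

print(hess_w)   # [2.5, 1.75], the analytic diagonal of the Hessian w.r.t. w
```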
3. Implementation Approaches and Module Patterns
The modular framework allows new module types, such as batch normalization, convolutions, or custom ODE solvers, to be augmented with Hessian rules seamlessly. The required interface per module consists of:
- forward(inputs; θ) — computes output and caches required tensors.
- gradback(∂ℓ/∂output) — vanilla backward pass for gradients.
- hessback(∇²ℓ w.r.t. output, output) — modular Hessian/curvature backprop.
BackPACK (Dangel et al., 2019) demonstrates this strategy in PyTorch, using per-module hooks to intercept backprop and extract arbitrary curvature statistics (Hessian diagonals, variances, GGN/KFAC factors) with minimal user overhead.
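The hook pattern can be illustrated as a dispatch table that maps each module type to its local curvature rule, so new layers plug in without touching the engine. The names below are illustrative only, not BackPACK's actual API:

```python
# Sketch of the per-module extension pattern: a registry maps a module type
# to its local Hessian/curvature rule. Registering a rule for a new module
# type is all that is needed to make it participate in curvature backprop.
HESS_RULES = {}

def register(module_type):
    def deco(rule):
        HESS_RULES[module_type] = rule
        return rule
    return deco

class ScaleModule:                    # a toy "layer": z = a * x
    def __init__(self, a):
        self.a = a

@register(ScaleModule)
def scale_hess(mod, g_out, h_out):
    # linear module: H_in = a^2 * H_out, residual second-order term vanishes
    return mod.a * g_out, mod.a ** 2 * h_out

def hess_backward(mod, g_out, h_out):
    return HESS_RULES[type(mod)](mod, g_out, h_out)

g, h = hess_backward(ScaleModule(3.0), g_out=2.0, h_out=1.0)
print(g, h)   # 6.0 9.0
```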
For ODE solvers, modular Hessian backprop employs a two-level adjoint system: the first-order adjoint propagates sensitivities via a backward ODE, while the second-order adjoint propagates the Hessian according to an analogous evolution equation. This approach is integrated as a module in the graph, handling all reverse- and reverse-of-reverse passes (Ciceri et al., 2023).
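In the discretized view, each solver step can itself be treated as a module whose local first and second derivatives are pushed back through the step chain. A minimal sketch for explicit Euler on $\dot{z} = z^2$, using the scalar rule $g \leftarrow s'(z)g$, $h \leftarrow s'(z)^2 h + s''(z)g$ with step map $s(z) = z + \Delta t\, z^2$ (a discretize-then-differentiate sketch, not the continuous adjoint system of the cited papers):

```python
# Solver-as-module sketch: each Euler step of dz/dt = z^2 is one module, and
# (gradient, curvature) w.r.t. the initial condition z0 is backpropagated
# through the step chain, then checked against finite differences.
dt, n, z0 = 0.1, 5, 0.5

def solve(z):
    zs = [z]
    for _ in range(n):
        z = z + dt * z * z            # step map s(z) = z + dt * z^2
        zs.append(z)
    return zs

zs = solve(z0)
g, h = 1.0, 0.0                       # seeds for loss L = z_N
for z in reversed(zs[:-1]):           # walk cached step inputs backwards
    s1, s2 = 1 + 2 * dt * z, 2 * dt   # s'(z), s''(z)
    g, h = s1 * g, s1 * s1 * h + s2 * g

# finite-difference check of dL/dz0 and d2L/dz0^2
e = 1e-4
lp, lm, l0 = solve(z0 + e)[-1], solve(z0 - e)[-1], zs[-1]
print(abs(g - (lp - lm) / (2 * e)) < 1e-6,
      abs(h - (lp - 2 * l0 + lm) / e ** 2) < 1e-4)
```

The continuous second-order adjoint of Ciceri et al. (2023) plays the same role in the limit of small step sizes, but is derived and solved at the ODE level rather than per discrete step.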
4. Algorithmic and Computational Properties
| Method/Class | Memory Scaling | Computational Cost | Accuracy/Structure |
|---|---|---|---|
| Exact Hessian block | quadratic per layer block | quadratic per layer block | Exact block-diagonal entries |
| GGN, PCH | – | – | PSD, structured approximation |
| Curvature Propagation | linear in parameters | one gradient pass per sample | Unbiased, low-variance |
| HesScale | linear in parameters | comparable to one backprop | Inexpensive, diagonal only |
| Block-Kronecker RNNs | – | – | Exact in time-unrolled limit |
For diagonal approximations such as HesScale, runtime is comparable to backpropagation, independent of full Hessian size (Elsayed et al., 2022). Curvature Propagation estimates arbitrary Hessian entries with strictly lower variance than Hessian-vector sampling for the same computational budget (Martens et al., 2012).
5. Applications Across Domains
Modular Hessian backpropagation is central to advanced optimization, uncertainty quantification, and scientific machine learning:
- Optimization: Newton-type and natural gradient methods require access to Hessians or curvature approximations, which modular Hessian backprop can supply at the per-parameter or block level (Elsayed et al., 2022, Dangel et al., 2019).
- Neural ODEs and scientific computing: Reverse-mode AD architectures for ODE and PDE solvers handle Hessian propagation via custom second-order adjoint solvers registered as modules in the computational graph (Ciceri et al., 2023, Solvik et al., 2024).
- Score matching and generative modeling: Traces or diagonals of the Hessian appear in density-estimation objectives, efficiently estimated via modular strategies (Martens et al., 2012).
- Variational data assimilation: Incremental 4D-Var schemes in weather forecasting are re-expressed in modular-differentiable frameworks using Hessian approximation and AD (backprop) rather than explicit tangent linear/adjoint coding, dramatically reducing implementation complexity (Solvik et al., 2024).
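For the optimization use case above, a per-parameter Newton-type step driven by a Hessian-diagonal estimate can be written in one line; this is an illustrative sketch of the general pattern (regularized absolute curvature in the denominator), not any specific library's optimizer:

```python
# Diagonal-Newton step: scale each gradient entry by the (regularized,
# absolute) per-parameter curvature estimate supplied by Hessian backprop.
def diag_newton_step(params, grads, hess_diag, lr=1.0, eps=1e-8):
    return [p - lr * g / (abs(h) + eps)
            for p, g, h in zip(params, grads, hess_diag)]

new = diag_newton_step([1.0, -2.0], grads=[0.5, 1.0], hess_diag=[2.0, 4.0])
print(new)   # approximately [0.75, -2.25]
```

Taking the absolute value of the curvature estimate is one common way to keep the step direction a descent direction when the local Hessian diagonal is negative.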
6. Empirical Evaluation and Comparative Analysis
CP achieves lower mean squared error in Hessian diagonal estimation than Hessian-vector product methods or deterministic diagonal estimators such as Becker & Le Cun (1988), especially for small sample budgets (Martens et al., 2012).
HesScale outperforms stochastic MC and earlier deterministic diagonal estimators, matching the true Hessian diagonal within small empirical error, and yields consistent improvements over AdaHessian and first-order Adam-type optimizers in wall-clock time-to-accuracy (Elsayed et al., 2022).
In data assimilation and forecasting tasks, modular Hessian backpropagation via autodiff achieves equivalent RMSE to conventional adjoint- and tangent-linear-based 4D-Var (within 1%) but at a fraction of the computational cost, scaling to ∼8,000+ state dimensions (Solvik et al., 2024).
7. Extensibility and Integration
Because modular Hessian backpropagation requires only that each module exposes its Jacobian (and optionally, second-derivative) information, it is naturally composable with autograd frameworks (e.g., PyTorch, JAX) and extension systems such as BackPACK, making it straightforward to incorporate new layer types, surrogate physical models, or differentiable solvers (Dangel et al., 2019).
Modern Kronecker-factored and block-diagonal schemes, as well as unbiased/sampled estimators, can all be realized within this unifying per-module interface, facilitating rapid experimentation, algorithmic diversity, and application to large-scale, heterogeneous graphs.
References:
- (Martens et al., 2012): "Estimating the Hessian by Back-propagating Curvature"
- (Ciceri et al., 2023): "On backpropagating Hessians through ODEs"
- (Naumov, 2017): "Feedforward and Recurrent Neural Networks Backward Propagation and Hessian in Matrix Form"
- (Dangel et al., 2019): "BackPACK: Packing more into backprop"
- (Elsayed et al., 2022): "HesScale: Scalable Computation of Hessian Diagonals"
- (Dangel et al., 2019): "Modular Block-diagonal Curvature Approximations for Feedforward Architectures"
- (Solvik et al., 2024): "4D-Var using Hessian approximation and backpropagation applied to automatically-differentiable numerical and machine learning models"