
Modular Hessian Backpropagation

Updated 17 March 2026
  • Modular Hessian Backpropagation is a paradigm that efficiently propagates second-order curvature via modular network components, enabling scalable Hessian approximations.
  • It unifies techniques such as exact, stochastic, and fast diagonal methods for computing local Hessian information, supporting second-order optimization and sensitivity analysis.
  • The framework integrates with autograd tools like PyTorch, facilitating rapid experimentation in large-scale machine learning, neural ODEs, and scientific computing.

Modular Hessian Backpropagation is a computational paradigm for efficiently propagating second-order derivative information (specifically, Hessian matrices or their approximations) through complex computational graphs constructed from modular sub-components such as neural network layers, ODE solvers, or general differentiable modules. This framework enables scalable, automated, and extensible computation of (possibly block-diagonal or stochastic) Hessian or curvature approximations at a per-module level, which can be leveraged for optimization, uncertainty quantification, and sensitivity analysis in large-scale machine learning and scientific computing.

1. Core Principles of Modular Hessian Backpropagation

The central concept is to generalize the backpropagation pattern familiar from first-order automatic differentiation (AD): each computational module is responsible for providing not only a backward routine for propagating gradients, but also a local rule for passing back Hessians (or their block or diagonal approximations). At each node in a computational graph, the module receives downstream curvature information—e.g., the Hessian with respect to its output or the diagonal thereof—and computes the corresponding quantities for its inputs and parameters by means of local, layer-specific matrix operations. The global Hessian or block-diagonal can be constructed by composing these local propagations.

This approach unifies several families of curvature/second-order estimators, including exact and block-diagonal Hessian backpropagation, stochastic unbiased curvature propagation, and fast deterministic diagonal approximations such as HesScale; these are detailed in the following subsections.

2. Mathematical Formulation and Algorithmic Structure

2.1 Exact and Block-Diagonal Hessian Backpropagation

For a scalar objective $E$, each module $i$ with input $x$, parameters $\theta$, and output $z$ implements

$$H_x = D_z(x)^\top \, H_z \, D_z(x) + \sum_k H_{z_k}(x) \, \delta z_k$$

$$H_\theta = D_z(\theta)^\top \, H_z \, D_z(\theta) + \sum_k H_{z_k}(\theta) \, \delta z_k$$

where $D_z(x)$ is the Jacobian, $H_z$ is the downstream block Hessian, $H_{z_k}(x)$ is the Hessian of output component $k$ w.r.t. $x$, and $\delta z_k$ is the associated gradient (Dangel et al., 2019).

For Generalized Gauss–Newton (GGN) and Positive-Curvature Hessian (PCH) approximations, the second term is omitted or projected to maintain positive semi-definiteness, leading to block-structured, PSD approximations.
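
The following is a minimal sketch of how these module-local rules look in code for a toy composition $z = \sigma(Wx)$ with a tanh nonlinearity; the function and tensor names and the downstream quantities $H_z$, $\delta z$ are illustrative assumptions, not a library implementation. Dropping the second (diagonal residual) term in hessback_elementwise gives the GGN-style PSD variant mentioned above.

```python
import torch

# Minimal sketch of module-local Hessian backprop rules of the form
#   H_x = D_z(x)^T H_z D_z(x) + sum_k H_{z_k}(x) * delta_z_k.
# All names and the toy composition below are illustrative assumptions.

def hessback_linear(W, H_z):
    # z = W x: the Jacobian is W and all second derivatives H_{z_k}(x) vanish,
    # so only the curvature "sandwich" term survives.
    return W.T @ H_z @ W

def hessback_elementwise(x, sigma_p, sigma_pp, H_z, delta_z):
    # z = sigma(x) elementwise: the Jacobian is diag(sigma'(x)) and
    # H_{z_k}(x) = sigma''(x_k) e_k e_k^T, so the residual term is diagonal.
    D = torch.diag(sigma_p(x))
    return D @ H_z @ D + torch.diag(sigma_pp(x) * delta_z)

# Toy composition z = sigma(W x), given downstream curvature H_z and gradient delta_z.
torch.manual_seed(0)
W, x = torch.randn(3, 4), torch.randn(4)
H_z, delta_z = torch.eye(3), torch.randn(3)
sigma_p = lambda a: 1 - torch.tanh(a) ** 2
sigma_pp = lambda a: -2 * torch.tanh(a) * (1 - torch.tanh(a) ** 2)

H_a = hessback_elementwise(W @ x, sigma_p, sigma_pp, H_z, delta_z)  # Hessian w.r.t. pre-activation
H_x = hessback_linear(W, H_a)                                       # Hessian w.r.t. the input x
print(H_x.shape)  # torch.Size([4, 4])
```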

2.2 Stochastic/Unbiased Hessian Estimation—Curvature Propagation

Martens et al.'s Curvature Propagation (CP) algorithm (a minimal numerical sketch follows the list below):

  • Injects random i.i.d. vectors $v_i$ at each node with $E[v_i v_i^\top] = I$.
  • Runs modular recursions:
    • At each module, propagates curvature samples using square-root factors $F_i$ of the local Hessian contribution.
    • The output $S = S_y^1$ yields $E[S S^\top] = H$.
  • For diagonal estimation, accumulates $S[i]^2$ per entry on the fly.
  • Each rank-1 sample costs $\sim 2$ gradient evaluations; the variance is strictly lower than that of Pearlmutter's Hessian-vector method (Martens et al., 2012).
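
The full CP recursion propagates per-node square-root factors; the toy sketch below only demonstrates the underlying rank-1 identity $E[S S^\top] = H$ for a squared-error objective, where the square-root factor of the output Hessian is the identity and PyTorch autograd supplies the $J^\top v$ products. All tensor names and the sample count are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

# Toy objective L = 0.5 * ||W x - t||^2, whose Hessian w.r.t. vec(W) is J^T J,
# where J = d(Wx)/d vec(W). The output-Hessian square-root factor is the identity,
# so a CP-style rank-1 sample is simply s = J^T v with E[v v^T] = I.
W = torch.randn(3, 4, requires_grad=True)
x, t = torch.randn(4), torch.randn(3)
z = W @ x

samples = []
for _ in range(2000):
    v = torch.randn_like(z)                                   # E[v v^T] = I
    s, = torch.autograd.grad(z, W, grad_outputs=v, retain_graph=True)
    samples.append(s.reshape(-1))
S = torch.stack(samples)
H_est = S.T @ S / S.shape[0]                                  # Monte Carlo estimate of E[s s^T]

# Exact Hessian for comparison, built row by row from the Jacobian of z.
J = torch.stack([torch.autograd.grad(z[k], W, retain_graph=True)[0].reshape(-1)
                 for k in range(3)])
H_exact = J.T @ J
print((H_est - H_exact).abs().max())                          # shrinks as the sample count grows
```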

2.3 Efficient Diagonal Approximations—HesScale

HesScale propagates approximate Hessian diagonals via a per-layer recurrence akin to gradient backprop, under the simplifying assumption that off-diagonal entries of the backpropagated Hessian blocks are ignored. For layer $l$ and neuron $i$:

$$\widehat{\frac{\partial^2 L}{\partial a_{l,i}^2}} = \sigma'(a_{l,i})^2 \sum_k \widehat{\frac{\partial^2 L}{\partial a_{l+1,k}^2}}\, W_{l+1,k,i}^2 + \sigma''(a_{l,i}) \sum_k \frac{\partial L}{\partial a_{l+1,k}}\, W_{l+1,k,i}$$

and

$$\widehat{\frac{\partial^2 L}{\partial W_{l,i,j}^2}} = \widehat{\frac{\partial^2 L}{\partial a_{l,i}^2}}\, h_{l-1,j}^2$$

with $O(n)$ complexity and empirically high accuracy (Elsayed et al., 2022).
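
A minimal numerical sketch of these recurrences, assuming a hypothetical two-layer tanh network with a squared-error loss and no biases (the tensors W1, W2, x, t are illustrative), is:

```python
import torch

torch.manual_seed(0)
sigma = torch.tanh
dsigma = lambda a: 1 - torch.tanh(a) ** 2
d2sigma = lambda a: -2 * torch.tanh(a) * (1 - torch.tanh(a) ** 2)

# Hypothetical two-layer MLP, squared-error loss, no biases for brevity.
W1, W2 = torch.randn(5, 4), torch.randn(3, 5)
x, t = torch.randn(4), torch.randn(3)

# Forward pass, caching pre-activations a_l and activations h_l.
a1 = W1 @ x
h1 = sigma(a1)
a2 = W2 @ h1                                  # linear output layer

# Backward recurrences: exact gradients plus HesScale-style diagonal estimates.
g2 = a2 - t                                   # dL/da2
d2 = torch.ones_like(a2)                      # d2L/da2^2 (exact for squared error)

g1 = dsigma(a1) * (W2.T @ g2)                 # dL/da1, the usual gradient recurrence
d1 = dsigma(a1) ** 2 * ((W2.T ** 2) @ d2) + d2sigma(a1) * (W2.T @ g2)

# Parameter diagonals: d2L/dW_{l,i,j}^2 ~ (d2L/da_{l,i}^2) * h_{l-1,j}^2.
diag_W2 = torch.outer(d2, h1 ** 2)
diag_W1 = torch.outer(d1, x ** 2)
print(diag_W1.shape, diag_W2.shape)           # torch.Size([5, 4]) torch.Size([3, 5])
```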

3. Implementation Approaches and Module Patterns

The modular framework allows new module types, such as batch normalization, convolutions, or custom ODE solvers, to be augmented with Hessian rules seamlessly. The required interface per module consists of the following (a minimal sketch appears after the list):

  • forward(inputs; $\theta$) — computes the output and caches required tensors.
  • gradback($\delta_{\text{output}}$) — vanilla backward pass for gradients.
  • hessback($H_{\text{out}}$, $\delta_{\text{output}}$) — modular Hessian/curvature backprop.
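
A minimal Python rendering of this interface might look as follows; the class and method names mirror the bullets above but are otherwise illustrative, not the API of any particular library.

```python
from abc import ABC, abstractmethod

class HessianModule(ABC):
    """Hypothetical per-module interface for modular Hessian backpropagation."""

    @abstractmethod
    def forward(self, inputs):
        """Compute the module output and cache tensors needed by the backward passes."""

    @abstractmethod
    def gradback(self, delta_output):
        """Vanilla backward pass: gradients w.r.t. inputs and parameters."""

    @abstractmethod
    def hessback(self, H_out, delta_output):
        """Curvature backward pass: return (H_in, H_params) given the output Hessian."""
```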

BackPACK (Dangel et al., 2019) demonstrates this strategy in PyTorch, using per-module hooks to intercept backprop and extract arbitrary curvature statistics (Hessian diagonals, variances, GGN/KFAC factors) with minimal user overhead.
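
A typical workflow, sketched from BackPACK's documented extend/backpack pattern (the model, data, and choice of extension here are illustrative), is roughly:

```python
import torch
from backpack import backpack, extend
from backpack.extensions import DiagHessian

# Illustrative model and data; any PyTorch module supported by BackPACK works.
model = extend(torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.Sigmoid(),
                                   torch.nn.Linear(5, 1)))
lossfunc = extend(torch.nn.MSELoss())

X, y = torch.randn(8, 10), torch.randn(8, 1)
loss = lossfunc(model(X), y)

# Intercept the backward pass and extract per-parameter Hessian diagonals.
with backpack(DiagHessian()):
    loss.backward()

for name, param in model.named_parameters():
    print(name, param.diag_h.shape)  # curvature statistic stored alongside .grad
```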

For ODE solvers, modular Hessian backprop employs a two-level adjoint system: the first-order adjoint propagates sensitivities via a backward ODE, while the second-order adjoint propagates the Hessian according to an analogous evolution equation. This approach is integrated as a module in the graph, handling all reverse- and reverse-of-reverse passes (Ciceri et al., 2023).
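
The dedicated second-order adjoint formulation is not reproduced here; the sketch below only illustrates the generic reverse-over-reverse route through a hand-rolled explicit-Euler solver in PyTorch, with the dynamics, objective, and step count chosen arbitrarily for illustration.

```python
import torch

def euler_solve(theta, x0, dt=0.1, steps=20):
    # Differentiable explicit-Euler integration of dx/dt = -theta * x:
    # the solver acts as one module in the computational graph.
    x = x0
    for _ in range(steps):
        x = x + dt * (-theta * x)
    return x

def loss(theta):
    xT = euler_solve(theta, torch.tensor(1.0))
    return (xT - 0.2) ** 2            # arbitrary terminal-state objective

theta = torch.tensor(0.5, requires_grad=True)
# First reverse pass (kept differentiable), then reverse-over-reverse for the Hessian.
g, = torch.autograd.grad(loss(theta), theta, create_graph=True)
h, = torch.autograd.grad(g, theta)
print(g.item(), h.item())
```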

4. Algorithmic and Computational Properties

| Method/Class | Memory Scaling | Computational Cost | Accuracy/Structure |
|---|---|---|---|
| Exact Hessian block | $O(n^2)$ | $O(n^2)$ | Exact block-diagonal entries |
| GGN, PCH | $O(n)$–$O(n^2)$ | $O(n)$–$O(n^2)$ | PSD, structured approximation |
| Curvature Propagation | $O(n)$ | $\sim 2\times$ gradient cost | Unbiased, low variance |
| HesScale | $O(n)$ | $O(n)$ | Inexpensive, diagonal only |
| Block-Kronecker RNNs | $O(t^2)$ | $O(t^2)$ | Exact in the time-unrolled limit |

For diagonal approximations such as HesScale, runtime is comparable to backpropagation, independent of full Hessian size (Elsayed et al., 2022). Curvature Propagation estimates arbitrary Hessian entries with strictly lower variance than Hessian-vector sampling for the same computational budget (Martens et al., 2012).

5. Applications Across Domains

Modular Hessian backpropagation is central to advanced optimization, uncertainty quantification, and scientific machine learning:

  • Optimization: Newton-type and natural gradient methods require access to Hessians or curvature approximations, which modular Hessian backprop can supply at the per-parameter or block level (Elsayed et al., 2022, Dangel et al., 2019); a small update sketch follows this list.
  • Neural ODEs and scientific computing: Reverse-mode AD architectures for ODE and PDE solvers handle Hessian propagation via custom second-order adjoint solvers registered as modules in the computational graph (Ciceri et al., 2023, Solvik et al., 2024).
  • Score matching and generative modeling: Traces or diagonals of the Hessian appear in density-estimation objectives, efficiently estimated via modular strategies (Martens et al., 2012).
  • Variational data assimilation: Incremental 4D-Var schemes in weather forecasting are re-expressed in modular-differentiable frameworks using Hessian approximation and AD (backprop) rather than explicit tangent linear/adjoint coding, dramatically reducing implementation complexity (Solvik et al., 2024).
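
As a small example of the optimization use case, a per-parameter second-order update built from such a diagonal curvature estimate might look as follows; the function name, damping value, and tensors are illustrative, not a specific published optimizer.

```python
import torch

def diagonal_newton_step(param, grad, hess_diag, lr=1.0, damping=1e-4):
    # Scale each gradient entry by its damped, absolute curvature estimate,
    # in the spirit of diagonal-Newton / HesScale-style optimizers.
    return param - lr * grad / (hess_diag.abs() + damping)

p, g = torch.randn(10), torch.randn(10)
h = torch.randn(10)       # e.g. a HesScale or BackPACK DiagHessian estimate
p_new = diagonal_newton_step(p, g, h)
```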

6. Empirical Evaluation and Comparative Analysis

CP achieves lower mean squared error in Hessian diagonal estimation than Hessian-vector product methods or deterministic diagonal estimators such as Becker & Le Cun (1988), especially for small sample budgets (Martens et al., 2012).

HesScale outperforms stochastic Monte Carlo and earlier deterministic diagonal estimators, matching the true Hessian diagonal to within small empirical error, and yields consistent improvements over AdaHessian and first-order Adam-type optimizers in wall-clock time-to-accuracy (Elsayed et al., 2022).

In data assimilation and forecasting tasks, modular Hessian backpropagation via autodiff matches the RMSE of conventional adjoint- and tangent-linear-based 4D-Var to within 1%, at a fraction of the computational cost, scaling to state dimensions of roughly 8,000 (Solvik et al., 2024).

7. Extensibility and Integration

Because modular Hessian backpropagation requires only that each module exposes its Jacobian (and optionally, second-derivative) information, it is naturally composable with autograd frameworks (e.g., PyTorch, JAX) and extension systems such as BackPACK, making it trivial to incorporate new layer types, surrogate physical models, or differentiable solvers (Dangel et al., 2019, Dangel et al., 2019).

Modern Kronecker-factored and block-diagonal schemes, as well as unbiased/sampled estimators, can all be realized within this unifying per-module interface, facilitating rapid experimentation, algorithmic diversity, and application to large-scale, heterogeneous graphs.


References:

(Martens et al., 2012): "Estimating the Hessian by Back-propagating Curvature"
(Ciceri et al., 2023): "On backpropagating Hessians through ODEs"
(Naumov, 2017): "Feedforward and Recurrent Neural Networks Backward Propagation and Hessian in Matrix Form"
(Dangel et al., 2019): "BackPACK: Packing more into backprop"
(Elsayed et al., 2022): "HesScale: Scalable Computation of Hessian Diagonals"
(Dangel et al., 2019): "Modular Block-diagonal Curvature Approximations for Feedforward Architectures"
(Solvik et al., 2024): "4D-Var using Hessian approximation and backpropagation applied to automatically-differentiable numerical and machine learning models"
