HesScale: Scalable Hessian Diagonals
- HesScale is a computational methodology for efficiently estimating the diagonal of the Hessian matrix in deep neural networks, enabling scalable second-order optimization.
- It employs a recursive backward pass that propagates exact last-layer Hessian diagonals through the network at linear computational cost, with a Gauss–Newton variant for piecewise-linear activations.
- Empirical benchmarks demonstrate that HesScale reduces computational overhead compared to Monte Carlo methods while improving convergence in both supervised and reinforcement learning tasks.
HesScale is a computational methodology and family of approximations for efficiently estimating the diagonal of the Hessian matrix in deep learning models. It enables scalable, accurate incorporation of second-order (curvature) information into optimization and reinforcement learning, while maintaining linear computational and memory complexity, similar to standard backpropagation. HesScale refines earlier approaches (notably, Becker and LeCun 1989) by providing a principled backward recursion for diagonal Hessian entries, utilizing exact computation in the last layer when possible, and propagating these diagonals through the network with minimal additional cost (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).
1. Mathematical Foundations
The Hessian matrix, $H = \nabla^2_{\theta}\,\mathcal{L}(\theta)$, encapsulates the local curvature of the loss $\mathcal{L}$ with respect to the parameters $\theta$. For models with $n$ parameters, direct computation and storage of $H$ scale as $O(n^2)$, which is infeasible for modern deep networks. Practical implementations require efficient approximations to the Hessian or its diagonal. The diagonal entries, $H_{ii} = \partial^2 \mathcal{L}/\partial \theta_i^2$, provide per-parameter curvature estimates critical for preconditioning updates and adaptive step-size selection.
HesScale builds upon the principle that, for a feed-forward network, the Hessian diagonal can be recursively propagated layer by layer if all off-diagonal terms are ignored, following the scalable approximation of Becker and LeCun (BL89). The major advancement is the replacement of the final-layer diagonal approximation (for standard losses) with its analytic, exact form (e.g., for softmax cross-entropy, $\partial^2 \mathcal{L}/\partial a_{L,i}^2 = p_i(1 - p_i)$, where $p$ is the output probability vector) (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).
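As a concrete illustration of this closed form, a short NumPy check is sketched below; the function name and interface are illustrative assumptions rather than part of the original implementation.

```python
import numpy as np

def softmax_ce_hessian_diag(logits):
    """Exact Hessian diagonal of softmax cross-entropy w.r.t. the logits.

    The full Hessian is diag(p) - p p^T, so its diagonal is p * (1 - p),
    independent of the target label.
    """
    z = logits - logits.max()              # numerically stable softmax
    p = np.exp(z) / np.exp(z).sum()
    return p * (1.0 - p)

# Example: exact last-layer curvature for three-class logits.
print(softmax_ce_hessian_diag(np.array([2.0, 0.5, -1.0])))
```

Because this last-layer diagonal is exact, approximation error in the propagated estimates comes only from the off-diagonal terms dropped in earlier layers.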
2. HesScale Algorithmic Formulation
Let $a_l$ denote the pre-activations at layer $l$, $h_l = \sigma(a_l)$ the activations, $W_l$ the weights, and $\mathcal{L}$ the loss.
The HesScale diagonal recursion for the pre-activation $a_{l,i}$ at layer $l$ is
$$\hat{H}_{a_{l,i}} = \sigma'(a_{l,i})^2 \sum_k W_{l+1,ki}^{2}\, \hat{H}_{a_{l+1,k}} + \sigma''(a_{l,i})\, \frac{\partial \mathcal{L}}{\partial h_{l,i}},$$
initialized at the output layer ($l = L$) by its closed-form diagonal. For the weights,
$$\hat{H}_{W_{l,ij}} = h_{l-1,j}^{2}\, \hat{H}_{a_{l,i}}.$$
A Gauss–Newton variant, "HesScaleGN" (Editor's term), omits the $\sigma''(a_{l,i})\,\partial \mathcal{L}/\partial h_{l,i}$ term for piecewise-linear activations (e.g., ReLU), yielding
$$\hat{H}^{\mathrm{GN}}_{a_{l,i}} = \sigma'(a_{l,i})^2 \sum_k W_{l+1,ki}^{2}\, \hat{H}^{\mathrm{GN}}_{a_{l+1,k}}.$$
This iterative scheme matches the per-layer computational structure and cost of standard backpropagation (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).
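A minimal NumPy sketch of this recursion for a fully connected network with sigmoid hidden units and a softmax cross-entropy output is given below; the layer structure, shapes, and variable names are assumptions made for illustration, not the authors' reference implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def hesscale_mlp(weights, x, y):
    """Per-parameter Hessian-diagonal estimates for one example.

    weights : list of (W, b) tuples defining a_l = W @ h_{l-1} + b
    x       : input vector;  y : one-hot target vector
    Returns a list of (diag for W_l, diag for b_l) per layer.
    """
    # Forward pass, caching activations.
    h, acts = x, [x]
    for i, (W, b) in enumerate(weights):
        a = W @ h + b
        if i < len(weights) - 1:
            h = sigmoid(a)                       # hidden activation
        else:
            e = np.exp(a - a.max())
            h = e / e.sum()                      # softmax output p
        acts.append(h)

    # Exact output-layer quantities for softmax cross-entropy.
    p = acts[-1]
    dL_da = p - y                                # gradient w.r.t. logits
    H_a = p * (1.0 - p)                          # exact Hessian diagonal

    diags = [None] * len(weights)
    for i in reversed(range(len(weights))):
        W, _ = weights[i]
        h_prev = acts[i]
        # Weight/bias diagonals: H_W[r, c] = h_{l-1, c}^2 * H_a[r].
        diags[i] = (np.outer(H_a, h_prev ** 2), H_a.copy())
        if i > 0:
            s = acts[i]                          # sigmoid(a_{l-1})
            dL_dh = W.T @ dL_da                  # gradient w.r.t. h_{l-1}
            H_h = (W ** 2).T @ H_a               # off-diagonal terms dropped
            # HesScale recursion; HesScaleGN would drop the sigma'' term.
            H_a = (s * (1 - s)) ** 2 * H_h + s * (1 - s) * (1 - 2 * s) * dL_dh
            dL_da = s * (1 - s) * dL_dh
    return diags
```

The only per-layer work added to the standard backward pass is the squared-weight matrix-vector product `(W ** 2).T @ H_a` and a few elementwise operations, which is the source of the linear overhead discussed in the next section.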
3. Computational Complexity and Practical Implementation
HesScale achieves $O(n)$ time and space complexity per example, in stark contrast to the $O(n^2)$ cost of the exact Hessian. At each layer, the additional computation beyond backpropagation comprises a matrix-vector product, elementwise squares, and sums with the same dimensions as the usual gradient backward pass. Empirical timings indicate that AdaHesScale incurs approximately twice the computational cost of Adam, and AdaHesScaleGN about 1.25 times (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024). By comparison, Monte Carlo-based unbiased approximations (e.g., AdaHessian) cost roughly three times that of Adam or more (Elsayed et al., 5 Jun 2024).
The HesScale backward sweep can be implemented as an augmentation to the standard backward pass, requiring only the storage of diagonal Hessian estimates and gradients per parameter. For convolutional networks, the per-parameter complexity remains linear, though with larger constant factors due to weight sharing in convolutional layers.
4. Integration with Optimization and Reinforcement Learning
Incorporation into optimization is direct: the diagonal Hessian estimates serve as adaptive preconditioners for Newton-style updates,
$$\theta_{t+1} = \theta_t - \alpha\, \frac{g_t}{|\hat{H}_{\mathrm{diag},t}| + \epsilon},$$
where $g_t$ is the gradient and $\epsilon$ a small damping constant.
Improved stability and generalization are empirically observed when HesScale is integrated into Adam-like schemes (AdaHesScale), where exponential moving averages of gradients and of squared Hessian diagonals are used:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,\hat{H}_{\mathrm{diag},t}^{\,2}, \qquad \theta_{t+1} = \theta_t - \alpha\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},$$
with the usual bias-corrected estimates $\hat{m}_t$ and $\hat{v}_t$ (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).
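The following sketch shows an Adam-style optimizer in which the squared gradient is replaced by the squared Hessian-diagonal estimate, in the spirit of AdaHesScale; the class name, hyperparameter defaults, and flat-array interface are illustrative assumptions.

```python
import numpy as np

class AdaHesScaleLike:
    """Adam-style update using squared Hessian diagonals instead of squared gradients."""

    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m, self.v, self.t = None, None, 0

    def step(self, theta, grad, hess_diag):
        """theta, grad, hess_diag: flat arrays of the same shape."""
        if self.m is None:
            self.m = np.zeros_like(theta)
            self.v = np.zeros_like(theta)
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * hess_diag ** 2
        m_hat = self.m / (1 - self.beta1 ** self.t)   # bias correction
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return theta - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
```

In use, `theta = opt.step(theta, grad, hess_diag)` would be called once per minibatch with the gradient and the HesScale diagonal estimates for the current parameters.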
For reinforcement learning, HesScale enables efficient step-size scaling and robust trust-region-style updating. The nominal parameter update $\Delta\theta$ can be scaled so that its second-order (quadratic) term, $\tfrac{1}{2}\sum_i \hat{H}_{ii}\,\Delta\theta_i^2$, stays below a chosen bound, with minimal additional computational overhead. Empirical evidence indicates significant improvements in stability and insensitivity to learning-rate choices in both simulated and real-world RL tasks (Elsayed et al., 5 Jun 2024).
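A possible form of such scaling is sketched below, assuming a diagonal quadratic model of the loss change and a fixed curvature budget `delta`; both the constraint form and the names are assumptions for illustration.

```python
import numpy as np

def scale_update(delta_theta, hess_diag, delta=0.01):
    """Shrink delta_theta so that 0.5 * sum_i |H_ii| * dtheta_i^2 <= delta."""
    quad = 0.5 * np.sum(np.abs(hess_diag) * delta_theta ** 2)
    if quad <= delta:
        return delta_theta                    # already inside the budget
    return delta_theta * np.sqrt(delta / quad)  # rescale onto the boundary
```

Rescaling by $\sqrt{\delta/\text{quad}}$ makes the quadratic term exactly equal to the budget, which is why the resulting updates remain well behaved across a wide range of learning rates.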
5. Empirical Results and Benchmarking
Extensive experiments compare HesScale, HesScaleGN, BL89, GGN-diagonal, AdaHessian (MC sampling), and first-order methods. Key findings include:
- Approximation accuracy: HesScale achieves the lowest error relative to the true Hessian diagonal among scalable methods, outperforming BL89 and even high-sample MC methods (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).
- Supervised classification: On DeepOBS benchmarks (MNIST-MLP, CIFAR-10/100-CNNs), AdaHesScale and AdaHesScaleGN converge faster and reach lower test loss than first-order (Adam, SGD) and stochastic second-order methods (Elsayed et al., 2022).
- Reinforcement learning: In MuJoCo and real-robot environments, AdaHesScale achieves higher final returns and faster learning in several tasks compared to Adam and AdaHessian. Step-size scaling with HesScale leads to robust, tuning-free optimization across wide learning rate ranges (Elsayed et al., 5 Jun 2024).
| Method | Relative Cost (Adam = 1.0) | Diagonal Approx. Error (normalized to HesScale = 1.0) |
|---|---|---|
| AdaHesScale | 2.0 | 1.0 |
| AdaHesScaleGN | 1.25 | 1.0 |
| AdaHessian (MC1) | 3.0 | >6.5 |
| BL89 | 1.8 | 1.8 |
6. Extensions, Limitations, and Future Directions
Limitations of HesScale include the neglect of all off-diagonal Hessian structure, which may affect performance in settings with strong parameter coupling or highly non-linear architectures. Dependence on analytic second derivatives restricts some activation choices, though the GN variant mitigates this for piecewise-linear activations (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).
Potential extensions:
- Block-diagonal HesScale, propagating small blocks of Hessian diagonals for richer curvature.
- Hybrid stochastic–deterministic diagonals, mixing HesScale recursion with MC-based traces (e.g., Hutchinson).
- Application to natural gradient/trust region methods and as a pruning criterion (e.g., optimal brain surgeon).
- Integration into non-standard architectures (RNNs, transformers, GNNs).
Ongoing research areas involve automatic switching between HesScale and Gauss–Newton variants by layer, optimizing implementations for large-scale convolutions, and theoretical analysis of sample complexity and convergence in highly nonconvex optimization (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).
7. Significance and Impact
HesScale addresses the fundamental bottleneck of incorporating second-order information into large-scale neural network optimization, offering a practical compromise between computational feasibility and approximation fidelity. Published empirical validation demonstrates superior performance over both traditional and stochastic diagonal approximations, with minimal additional compute in standard learning workflows.
By enabling accurate, scalable Hessian diagonal estimation, HesScale supports improved optimization speed, adaptive step size scaling, and enhanced stability in both supervised and reinforcement learning domains. Its lightweight implementation and extensibility portend adoption in future optimization frameworks and architectures that rely on fine-grained curvature adaptation (Elsayed et al., 2022, Elsayed et al., 5 Jun 2024).