Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures (2311.00636v2)

Published 1 Nov 2023 in cs.LG and stat.ML

Abstract: The core components of many modern neural network architectures, such as transformers, convolutional, or graph neural networks, can be expressed as linear layers with $\textit{weight-sharing}$. Kronecker-Factored Approximate Curvature (K-FAC), a second-order optimisation method, has shown promise to speed up neural network training and thereby reduce computational costs. However, there is currently no framework to apply it to generic architectures, specifically ones with linear weight-sharing layers. In this work, we identify two different settings of linear weight-sharing layers which motivate two flavours of K-FAC -- $\textit{expand}$ and $\textit{reduce}$. We show that they are exact for deep linear networks with weight-sharing in their respective setting. Notably, K-FAC-reduce is generally faster than K-FAC-expand, which we leverage to speed up automatic hyperparameter selection via optimising the marginal likelihood for a Wide ResNet. Finally, we observe little difference between these two K-FAC variations when using them to train both a graph neural network and a vision transformer. However, both variations are able to reach a fixed validation metric target in $50$-$75\%$ of the number of steps of a first-order reference run, which translates into a comparable improvement in wall-clock time. This highlights the potential of applying K-FAC to modern neural network architectures.
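To make the expand/reduce distinction in the abstract concrete, here is a minimal, hypothetical sketch in JAX (not the paper's implementation; the array shapes, function names, and scaling conventions are assumptions for illustration) of how the two Kronecker factors could be estimated for a single linear layer whose weight matrix is shared across R positions per example, e.g. the tokens of a transformer.

import jax
import jax.numpy as jnp


def kfac_expand_factors(acts, grads):
    """K-FAC-expand: treat every (example, position) pair as an independent data point."""
    n, r, _ = acts.shape
    a = jnp.reshape(acts, (n * r, -1))        # (N*R, D_in) layer inputs
    g = jnp.reshape(grads, (n * r, -1))       # (N*R, D_out) gradients w.r.t. layer outputs
    A = a.T @ a / (n * r)                     # input-covariance factor
    B = g.T @ g / (n * r)                     # output-gradient-covariance factor
    return A, B                               # curvature block approximated as A (Kronecker) B


def kfac_reduce_factors(acts, grads):
    """K-FAC-reduce: collapse the shared dimension first, then form the outer products."""
    n = acts.shape[0]
    a_bar = jnp.mean(acts, axis=1)            # (N, D_in): per-example averaged activations
    g_sum = jnp.sum(grads, axis=1)            # (N, D_out): per-example summed output gradients
    A = a_bar.T @ a_bar / n
    B = g_sum.T @ g_sum / n
    return A, B                               # only N outer products per factor instead of N*R


if __name__ == "__main__":
    key_a, key_g = jax.random.split(jax.random.PRNGKey(0))
    acts = jax.random.normal(key_a, (8, 16, 32))   # N=8 examples, R=16 shared positions, D_in=32
    grads = jax.random.normal(key_g, (8, 16, 64))  # D_out=64
    for name, fn in [("expand", kfac_expand_factors), ("reduce", kfac_reduce_factors)]:
        A, B = fn(acts, grads)
        print(name, A.shape, B.shape)              # (32, 32) and (64, 64) in both cases

Because the reduce variant collapses the shared dimension before forming the outer products, each factor is built from N rather than N*R vectors, which is consistent with the abstract's observation that K-FAC-reduce is generally faster than K-FAC-expand.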
