Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent (1902.06720v4)

Published 18 Feb 2019 in stat.ML and cs.LG

Abstract: A longstanding goal in deep learning research has been to precisely characterize training and generalization. However, the often complex loss landscapes of neural networks have made a theory of learning dynamics elusive. In this work, we show that for wide neural networks the learning dynamics simplify considerably and that, in the infinite width limit, they are governed by a linear model obtained from the first-order Taylor expansion of the network around its initial parameters. Furthermore, mirroring the correspondence between wide Bayesian neural networks and Gaussian processes, gradient-based training of wide neural networks with a squared loss produces test set predictions drawn from a Gaussian process with a particular compositional kernel. While these theoretical results are only exact in the infinite width limit, we nevertheless find excellent empirical agreement between the predictions of the original network and those of the linearized version even for finite practically-sized networks. This agreement is robust across different architectures, optimization methods, and loss functions.

Citations (1,018)

Summary

  • The paper shows that the gradient descent dynamics of wide neural networks match those of a linearized model obtained from the first-order Taylor expansion around the initial parameters.
  • It establishes that the empirical Neural Tangent Kernel converges to a deterministic kernel as width increases, enabling precise prediction of learning trajectories.
  • The work empirically validates the theory across diverse architectures, confirming robust convergence and Gaussian predictive output distributions in practical settings.

Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent

The paper "Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent" addresses a fundamental challenge in deep learning: characterizing the training and generalization properties of neural networks. By exploring the dynamics of gradient descent in wide neural networks, the authors show that in the infinite width limit, the training process simplifies significantly, behaving as a linear model.

Theoretical Foundations

The authors build on existing literature, particularly leveraging insights from the Neural Tangent Kernel (NTK) introduced by Jacot et al. To provide a precise mathematical framework, they demonstrate that as the network width approaches infinity, the kernel formed from the network's parameter gradients becomes deterministic and remains constant throughout training, while the network's outputs at initialization are governed by a Gaussian process.
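
For concreteness, the empirical NTK at initialization can be written as the Gram matrix of parameter gradients (notation chosen here for exposition rather than copied verbatim from the paper):

$$\hat{\Theta}_0(x, x') = \nabla_\theta f(x, \theta_0)\,\nabla_\theta f(x', \theta_0)^\top.$$

In the infinite-width limit, $\hat{\Theta}_0$ converges to a deterministic kernel $\Theta$. These properties underpin the paper's main theoretical contributions: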

  1. Parameter Space Dynamics:
    • The paper establishes that the gradient descent dynamics of wide neural networks in parameter space are equivalent to those of a linearized model obtained through the first-order Taylor expansion with respect to initial parameters.
    • For networks trained with squared loss, the dynamics admit a closed-form solution (written out after this list), which simplifies the analysis of learning trajectories considerably.
  2. Convergence of Empirical NTK:
    • The empirical NTK converges to its deterministic counterpart as network width increases, which allows the authors to predict the evolution of network predictions accurately.
  3. Gaussian Processes from Gradient Descent Training:
    • The authors extend the idea that the output of large-width networks during gradient descent training can be described by a Gaussian process (GP).
    • They derive explicit time-dependent expressions for this GP (its predictive mean is given after this list), highlighting differences from Bayesian posterior sampling, even though both approaches yield draws from a Gaussian process.
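
As a sketch of items 1 and 2 (notation lightly adapted: $\mathcal{X}$ and $\mathcal{Y}$ are the training inputs and targets, $\Theta$ is the limiting NTK Gram matrix on $\mathcal{X}$, $\Theta(x, \mathcal{X})$ its test-train block, and $\eta$ the learning rate), the linearized model and its closed-form squared-loss dynamics under continuous-time gradient descent are

$$f^{\mathrm{lin}}_t(x) = f_0(x) + \nabla_\theta f(x, \theta_0)\,(\theta_t - \theta_0),$$

$$f^{\mathrm{lin}}_t(\mathcal{X}) = \big(I - e^{-\eta \Theta t}\big)\,\mathcal{Y} + e^{-\eta \Theta t} f_0(\mathcal{X}),$$

$$f^{\mathrm{lin}}_t(x) = f_0(x) + \Theta(x, \mathcal{X})\,\Theta^{-1}\big(I - e^{-\eta \Theta t}\big)\big(\mathcal{Y} - f_0(\mathcal{X})\big).$$

On the training set, predictions decay exponentially toward the targets at rates set by the eigenvalues of $\Theta$, while the test-point dynamics are an NTK-weighted interpolation of the training residual.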
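
For item 3, averaging over random initializations (for which $\mathbb{E}[f_0] = 0$) gives the time-dependent predictive mean of the resulting Gaussian process; the covariance also admits a closed form that combines the NNGP kernel with $\Theta$ and, as $t \to \infty$, generally differs from the exact Bayesian GP posterior:

$$\mu_t(x) = \mathbb{E}\big[f^{\mathrm{lin}}_t(x)\big] = \Theta(x, \mathcal{X})\,\Theta^{-1}\big(I - e^{-\eta \Theta t}\big)\,\mathcal{Y}.$$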

Practical Implications and Empirical Validation

The theoretical results are empirically validated across multiple architectures, including fully connected networks, convolutional networks, and wide residual networks. Key findings include:

  • Parameterization Independence:
    • The dynamics described hold for both standard and NTK parameterizations, confirming that the result stems from the large-width limit rather than from a specific parameterization.
    • In the small-learning-rate regime, the training dynamics predicted by the linearized model closely match those of the actual network (see the JAX sketch following this list).
  • Predictive Output Distributions:
    • The predictive distributions of the outputs for networks under gradient descent training remain Gaussian, and the mean and variance dynamics align accurately with theoretical predictions.
  • Convergence and Robustness:
    • The empirical evaluation shows that as the network width increases, the empirical NTK converges to the theoretical NTK, making the linearized network a reliable approximation for finite-width neural networks of practical size.
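
The comparison referenced above can be reproduced in miniature with plain JAX. The following is a minimal sketch, not the authors' code: it assumes a two-layer tanh network in NTK parameterization trained by full-batch gradient descent on a toy regression task.

```python
# Minimal sketch (not the authors' code): compare a wide two-layer tanh network
# trained by full-batch gradient descent with its first-order linearization
# around the initial parameters (NTK parameterization, toy regression data).
import jax
import jax.numpy as jnp

WIDTH, LR, STEPS = 2048, 1.0, 200

key = jax.random.PRNGKey(0)
kx, kp1, kp2 = jax.random.split(key, 3)
x_train = jax.random.normal(kx, (64, 8))
y_train = jnp.sin(x_train.sum(axis=-1, keepdims=True))

# NTK parameterization: O(1) weights, 1/sqrt(fan_in) scaling in the forward pass.
params0 = {"W1": jax.random.normal(kp1, (8, WIDTH)),
           "W2": jax.random.normal(kp2, (WIDTH, 1))}

def f(params, x):
    h = jnp.tanh(x @ params["W1"] / jnp.sqrt(x.shape[-1]))
    return h @ params["W2"] / jnp.sqrt(WIDTH)

def f_lin(params, x):
    # First-order Taylor expansion of f around params0, evaluated at params.
    delta = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
    f0, df = jax.jvp(lambda p: f(p, x), (params0,), (delta,))
    return f0 + df

def train(apply_fn):
    # Full-batch gradient descent on squared loss.
    loss = lambda p: 0.5 * jnp.mean((apply_fn(p, x_train) - y_train) ** 2)
    grad_fn = jax.jit(jax.grad(loss))
    params = params0
    for _ in range(STEPS):
        params = jax.tree_util.tree_map(lambda p, g: p - LR * g,
                                        params, grad_fn(params))
    return params

# At large width, the trained network and its trained linearization
# should make nearly identical predictions.
preds_nn = f(train(f), x_train)
preds_lin = f_lin(train(f_lin), x_train)
print("max |f - f_lin|:", jnp.max(jnp.abs(preds_nn - preds_lin)))
```

The gap between the two sets of predictions should shrink as WIDTH grows, mirroring the paper's finite-width experiments on practically sized networks.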

Extensions and Future Directions

The authors speculate on extending their framework to other optimizers, architectures, and loss functions. Concretely, they outline how similar theoretical results would hold for:

  • Multi-dimensional output settings.
  • Networks trained with cross-entropy loss or stochastic gradient descent with momentum.
  • Broader classes of neural network architectures such as recurrent neural networks (RNNs) and residual networks.

This research opens avenues for more in-depth analysis of neural networks' generalization capabilities based on NTK properties.

Conclusion

While the findings considerably simplify the understanding of neural network training dynamics for wide networks, they also highlight key differences between gradient descent training and Bayesian inference models. The detailed analysis of NTK dynamics offers new tools to probe the inductive biases introduced by neural network architectures and training procedures. Future research could leverage these insights to design neural networks with better trainability and generalization properties. This work provides a robust theoretical foundation while also demonstrating practical relevance by connecting abstract mathematical results with empirical observations on state-of-the-art architectures.
