- The paper shows that the gradient descent dynamics of wide neural networks match those of a linearized model obtained from the first-order Taylor expansion of the network around its initial parameters.
- It establishes that the empirical Neural Tangent Kernel converges to a deterministic kernel as width increases, enabling precise prediction of learning trajectories.
- The work empirically validates theory across diverse architectures, confirming robust convergence and Gaussian predictive output distributions in practical settings.
Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
The paper "Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent" addresses a fundamental challenge in deep learning: characterizing the training and generalization properties of neural networks. By exploring the dynamics of gradient descent in wide neural networks, the authors show that in the infinite width limit, the training process simplifies significantly, behaving as a linear model.
Theoretical Foundations
The authors build on existing literature, particularly the Neural Tangent Kernel (NTK) introduced by Jacot et al. To provide a precise mathematical framework, they show that as the width approaches infinity, the kernel formed by the network's parameter gradients (the empirical NTK) becomes deterministic and remains essentially constant throughout training, while the network's outputs at initialization are governed by a Gaussian process. These facts underpin their main theoretical contributions:
- Parameter Space Dynamics:
- The paper establishes that the gradient descent dynamics of wide neural networks in parameter space are equivalent to those of a linearized model obtained through a first-order Taylor expansion around the initial parameters.
- For networks trained with squared loss, the dynamics admit a closed-form solution, simplifying the analysis of learning trajectories significantly (the explicit expressions are sketched after this list).
- Convergence of Empirical NTK:
- The empirical NTK converges to its deterministic infinite-width counterpart as the network width increases, and it changes negligibly during training; together these facts let the authors predict the evolution of the network's outputs accurately.
- Gaussian Processes from Gradient Descent Training:
- The authors show that the outputs of wide networks remain Gaussian-distributed (over random initializations) throughout gradient descent training, extending the Gaussian process (GP) correspondence known to hold at initialization.
- They derive explicit time-dependent expressions for the mean and covariance of this GP, highlighting differences from Bayesian posterior sampling, despite both approaches yielding draws from a Gaussian process (a comparison of the limiting means is sketched after this list).
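To make the closed-form solution concrete, here is a sketch in notation close to the paper's (training inputs $\mathcal{X}$ with targets $\mathcal{Y}$, learning rate $\eta$, continuous-time gradient flow, squared loss). The empirical NTK at initialization is

$$
\hat\Theta_0(x, x') \;=\; \nabla_\theta f_0(x)\,\nabla_\theta f_0(x')^{\top},
$$

and, writing $\hat\Theta_0 \equiv \hat\Theta_0(\mathcal{X}, \mathcal{X})$, the linearized network's outputs evolve as

$$
f^{\mathrm{lin}}_t(\mathcal{X}) \;=\; \big(I - e^{-\eta \hat\Theta_0 t}\big)\,\mathcal{Y} \;+\; e^{-\eta \hat\Theta_0 t} f_0(\mathcal{X}),
\qquad
f^{\mathrm{lin}}_t(x) \;=\; f_0(x) \;+\; \hat\Theta_0(x, \mathcal{X})\,\hat\Theta_0^{-1}\big(I - e^{-\eta \hat\Theta_0 t}\big)\big(\mathcal{Y} - f_0(\mathcal{X})\big)
$$

for a test point $x$. Training is thus reduced to matrix exponentials of a kernel fixed at initialization, which is what makes the learning trajectory predictable in closed form.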
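On the Gaussian-process point: in the infinite-width limit $f_0$ is itself a draw from a GP with the NNGP kernel $\mathcal{K}$, and the closed-form expressions for $f^{\mathrm{lin}}_t$ are affine in $f_0$, so the prediction at any time $t$ is Gaussian over random initializations. As a sketch of the contrast with Bayesian inference (using $\Theta$ for the deterministic infinite-width NTK), the mean prediction at convergence ($t \to \infty$) is

$$
\mu_\infty(x) \;=\; \Theta(x, \mathcal{X})\,\Theta^{-1}\,\mathcal{Y},
$$

whereas exact Bayesian inference in the NNGP gives the posterior mean $\mathcal{K}(x, \mathcal{X})\,\mathcal{K}^{-1}\,\mathcal{Y}$. The two involve different kernels and generally disagree, even though both predictive distributions are Gaussian.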
Practical Implications and Empirical Validation
The theoretical results are empirically validated across multiple architectures, including fully connected networks, convolutional networks, and wide residual networks. Key findings include:
- Parameterization Independence:
- The dynamics described hold for both standard and NTK parameterizations, confirming that the result stems from increasing width rather than from a specific parameterization.
- In small learning rate regimes, the training dynamics predicted by the linearized model closely match those of the actual network.
- Predictive Output Distributions:
- The predictive distributions of the outputs for networks under gradient descent training remain Gaussian, and the mean and variance dynamics align accurately with theoretical predictions.
- Convergence and Robustness:
- The empirical evaluation shows that as the network width increases, the empirical NTK converges to the theoretical NTK, making the linearized network a reliable approximation for finite-width neural networks of practical size.
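As a concrete illustration of this convergence, the following is a minimal sketch (written for this summary, not taken from the paper's code) that measures how much the empirical NTK of a small NTK-parameterized ReLU network varies across random initializations at several widths. The architecture, widths, and inputs here are arbitrary choices for the demo; per the paper's result, the relative deviation should shrink as the width grows.

```python
# Minimal sketch: cross-seed variability of the empirical NTK at several widths.
# All architectural choices below are illustrative, not the paper's experimental setup.
import jax
import jax.numpy as jnp

def init_params(key, d_in, width):
    # NTK parameterization: weights drawn from N(0, 1); scaling lives in the forward pass.
    k1, k2 = jax.random.split(key)
    return (jax.random.normal(k1, (width, d_in)), jax.random.normal(k2, (1, width)))

def forward(params, x):
    # Scalar-output two-layer ReLU network with 1/sqrt(fan_in) scaling.
    w1, w2 = params
    h = jax.nn.relu(w1 @ x / jnp.sqrt(x.shape[0]))
    return (w2 @ h / jnp.sqrt(h.shape[0]))[0]

def empirical_ntk(params, xs):
    # Theta_hat[i, j] = <grad_theta f(x_i), grad_theta f(x_j)>.
    def flat_grad(x):
        grads = jax.grad(forward)(params, x)
        return jnp.concatenate([g.ravel() for g in jax.tree_util.tree_leaves(grads)])
    jac = jax.vmap(flat_grad)(xs)   # (n_points, n_params)
    return jac @ jac.T              # (n_points, n_points)

xs = jax.random.normal(jax.random.PRNGKey(0), (8, 16))  # 8 toy inputs of dimension 16

for width in (64, 256, 1024, 4096):
    kernels = jnp.stack([
        empirical_ntk(init_params(jax.random.PRNGKey(seed), 16, width), xs)
        for seed in range(10)
    ])
    mean_k = kernels.mean(axis=0)
    rel_dev = jnp.linalg.norm(kernels - mean_k, axis=(1, 2)).mean() / jnp.linalg.norm(mean_k)
    print(f"width={width:5d}  relative cross-seed deviation of Theta_hat: {rel_dev:.3f}")
```

The deviation reported here is only a proxy for convergence to the deterministic kernel, but it captures the qualitative behavior the paper validates at much larger scale and across architectures.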
Extensions and Future Directions
The authors also discuss extending their framework to other optimizers, architectures, and loss functions, outlining how analogous results are expected to hold for:
- Multi-dimensional output settings.
- Networks trained with cross-entropy loss or stochastic gradient descent with momentum.
- Broader classes of neural network architectures such as recurrent neural networks (RNNs) and residual networks.
This research opens avenues for more in-depth analysis of neural networks' generalization capabilities based on NTK properties.
Conclusion
While the findings considerably simplify the understanding of neural network training dynamics for wide networks, they also highlight key differences between gradient descent training and Bayesian inference models. The detailed analysis of NTK dynamics offers new tools to probe the inductive biases introduced by neural network architectures and training procedures. Future research could leverage these insights to design neural networks with better trainability and generalization properties. This work provides a robust theoretical foundation while also demonstrating practical relevance by connecting abstract mathematical results with empirical observations on state-of-the-art architectures.