- The paper demonstrates that the maximal update learning rate is width-independent for all layers except the first and last.
- It establishes that in deep ReLU MLPs this maximal learning rate scales as L^{-3/2}, so deeper networks require smaller learning rates to remain stable.
- The analysis assumes a mean-field weight initialization, under which the hidden-layer pre-activations are independent and identically distributed at initialization, and turns the resulting scaling law into concrete guidance for setting learning rates in deep networks.
The paper "Depth Dependence of μP Learning Rates in ReLU MLPs" investigates the maximal update learning rate in fully connected ReLU neural networks with varying depths and fixed widths. The paper aims to understand how this learning rate, which ensures the mean squared change in pre-activations remains bounded after one gradient descent step, scales with the network's depth (L) and width (n).
Key points of the paper include:
- Network Architecture and Initialization:
- The focus is on randomly initialized, fully connected multi-layer perceptrons (MLPs) using ReLU activations.
- These networks are equipped with a mean-field weight initialization, ensuring that the hidden layer pre-activations are initially independent and identically distributed.
- Learning Rate Analysis:
- The authors explore the maximal update (μP) learning rate, defined as the largest learning rate for which the mean squared change in pre-activations after one gradient descent step remains uniformly bounded as n and L grow.
- Following the work of Yang et al., the paper confirms that this maximal update learning rate does not depend on the width (n) for any layer except the first and last. In other words, once the network is wide enough, the hidden-layer learning rate does not need to be rescaled with width to remain stable.
- Depth Dependence:
- A significant finding of this paper is the non-trivial dependence of the maximal update learning rate on the network's depth.
- Specifically, the paper derives that this learning rate scales as L^{-3/2}. This relationship implies the learning rate must decrease as the network's depth increases to maintain stability in the learning dynamics.
- This scaling law provides insight into how deep architectures need to adjust learning rates compared to shallower ones, aiding the design of more effective training regimes for deep networks (a sketch of this rescaling appears after this list).
- Implications and Applications:
- Understanding the depth dependence of learning rates is crucial for training very deep neural networks efficiently.
- By deriving this scaling law, the paper contributes to the broader effort of optimizing gradient-based training algorithms for deep networks, potentially benefiting various applications in machine learning where deep architectures are prominent.
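As a practical illustration of the L^{-3/2} law, the following hedged sketch rescales a learning rate tuned at one depth to a deeper network. The idea of transferring a reference rate across depths, and the specific numbers used, are assumptions for illustration rather than a recipe given in the paper.

```python
def depth_scaled_lr(base_lr: float, base_depth: int, target_depth: int) -> float:
    """Rescale a hidden-layer learning rate tuned at base_depth to target_depth,
    following the eta ~ L^{-3/2} scaling derived in the paper."""
    return base_lr * (base_depth / target_depth) ** 1.5

# Example (hypothetical numbers): a rate tuned on an 8-layer ReLU MLP,
# transferred to a 32-layer one.
print(depth_scaled_lr(base_lr=0.1, base_depth=8, target_depth=32))  # -> 0.0125
```

Note that this rule applies to the hidden layers; the first and last layers are treated separately, since their maximal rates retain a width dependence in the μP analysis.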
In conclusion, this paper offers valuable insights into how the largest stable learning rate for gradient descent in ReLU MLPs is influenced by network depth. The finding that the maximal update learning rate scales as L^{-3/2} provides a clear framework for adjusting learning rates in deep networks, impacting both theoretical understanding and practical implementations.