Depth Dependence of $\mu$P Learning Rates in ReLU MLPs (2305.07810v1)

Published 13 May 2023 in cs.LG and stat.ML

Abstract: In this short note we consider random fully connected ReLU networks of width $n$ and depth $L$ equipped with a mean-field weight initialization. Our purpose is to study the dependence on $n$ and $L$ of the maximal update ($\mu$P) learning rate, the largest learning rate for which the mean squared change in pre-activations after one step of gradient descent remains uniformly bounded at large $n,L$. As in prior work on $\mu$P of Yang et al., we find that this maximal update learning rate is independent of $n$ for all but the first and last layer weights. However, we find that it has a non-trivial dependence on $L$, scaling like $L^{-3/2}$.

Summary

  • The paper demonstrates that the maximal update learning rate remains width-independent for most layers except the first and last.
  • It establishes that in deep ReLU MLPs the maximal update learning rate scales as $L^{-3/2}$, necessitating smaller rates for deeper networks to ensure stability.
  • The study employs mean-field initialization to guarantee independent pre-activations, providing actionable insights for designing effective deep network training regimes.

The paper "Depth Dependence of μμP Learning Rates in ReLU MLPs" investigates the maximal update learning rate in fully connected ReLU neural networks with varying depths and fixed widths. The paper aims to understand how this learning rate, which ensures the mean squared change in pre-activations remains bounded after one gradient descent step, scales with the network's depth (LL) and width (nn).

Key points of the paper include:

  1. Network Architecture and Initialization:
    • The focus is on randomly initialized, fully connected multi-layer perceptrons (MLPs) using ReLU activations.
    • These networks are equipped with a mean-field weight initialization, ensuring that the hidden-layer pre-activations are initially independent and identically distributed (the code sketch after this list includes a minimal version of such an initialization).
  2. Learning Rate Analysis:
    • The authors study the maximal update ($\mu$P) learning rate, defined as the largest learning rate for which the mean squared change in pre-activations after one gradient descent step remains uniformly bounded for large $n$ and $L$ (the numerical sketch at the end of this summary estimates this quantity directly).
    • Following the work of Yang et al., the paper confirms that this maximal update learning rate does not depend on the width ($n$) for any layer except the first and last. For the intermediate layers, the learning rate therefore does not need to be re-tuned as the network is made wider.
  3. Depth Dependence:
    • A significant finding of this paper is the non-trivial dependence of the maximal update learning rate on the network's depth.
    • Specifically, the paper derives that this learning rate scales as $L^{-3/2}$. This relationship implies the learning rate must shrink as the network gets deeper in order to keep the learning dynamics stable (the code sketch after this list applies this scaling to the hidden-layer learning rate).
    • This scaling law clarifies how deep architectures should adjust learning rates relative to shallower ones, which can aid the design of more effective training regimes for deep networks.
  4. Implications and Applications:
    • Understanding the depth dependence of learning rates is crucial for training very deep neural networks efficiently.
    • By deriving this scaling law, the paper contributes to the broader effort of optimizing gradient-based training algorithms for deep networks, potentially benefiting various applications in machine learning where deep architectures are prominent.
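
To make points 1 and 3 concrete, here is a minimal NumPy sketch of a randomly initialized ReLU MLP with a mean-field-style initialization, together with the $L^{-3/2}$ depth scaling applied to the hidden-layer learning rate. The helper names, the variance constants (He-style $2/\text{fan-in}$ for hidden layers, an extra $1/n$ damping on the readout), and the base learning rate are illustrative assumptions, not values taken from the paper.

```python
import numpy as np

def init_mean_field_mlp(n_in, n, L, n_out, rng):
    """Random ReLU MLP weights with a mean-field-style initialization.

    Hidden layers use variance 2/fan_in (He scaling); the readout layer gets
    an extra 1/n damping, in the spirit of mean-field / muP parameterizations.
    The exact constants are illustrative, not taken from the paper.
    """
    widths = [n_in] + [n] * L + [n_out]
    weights = []
    for l in range(len(widths) - 1):
        std = np.sqrt(2.0 / widths[l])
        if l == len(widths) - 2:          # readout layer: extra 1/n factor
            std /= widths[l]
        weights.append(std * rng.standard_normal((widths[l + 1], widths[l])))
    return weights

def hidden_layer_lr(L, base_lr=0.5):
    """Depth-scaled hidden-layer learning rate: base_lr * L^(-3/2).

    The L^(-3/2) exponent is the paper's scaling; base_lr is an arbitrary
    illustrative constant.
    """
    return base_lr * L ** (-1.5)

rng = np.random.default_rng(0)
weights = init_mean_field_mlp(n_in=32, n=1024, L=64, n_out=1, rng=rng)
print(f"hidden-layer learning rate at L=64: {hidden_layer_lr(64):.5f}")
```

Under this convention, only the hidden-layer learning rate needs the depth scaling; the first- and last-layer rates are the ones whose width dependence the paper treats separately.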

In conclusion, this paper offers valuable insights into how the largest stable learning rate for gradient descent in ReLU MLPs is influenced by network depth. The finding that the maximal update learning rate scales as $L^{-3/2}$ provides a clear framework for adjusting learning rates in deep networks, impacting both theoretical understanding and practical implementations.
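
Finally, the stability criterion itself can be probed numerically. The self-contained sketch below estimates the mean squared change in hidden-layer pre-activations after a single gradient descent step on a squared loss for one random input, sweeping the depth while scaling the learning rate as $L^{-3/2}$. The initialization constants, the loss, the single-example setup, and the base learning rate are again assumptions made for illustration; the paper's claim is that, under this scaling, the printed quantity should remain bounded as $n$ and $L$ grow.

```python
import numpy as np

def preactivation_shift(n_in, n, L, lr, rng):
    """Mean squared change in hidden pre-activations after one GD step.

    Uses a single random input/target pair and the squared loss
    0.5 * ||f(x) - y||^2; all constants are illustrative.
    """
    # Mean-field-style initialization: He scaling plus extra 1/n on the readout.
    widths = [n_in] + [n] * L + [1]
    Ws = [np.sqrt(2.0 / widths[l]) * rng.standard_normal((widths[l + 1], widths[l]))
          for l in range(len(widths) - 1)]
    Ws[-1] /= n

    x = rng.standard_normal(n_in) / np.sqrt(n_in)
    y = np.zeros(1)

    def forward(weights):
        zs, h = [], x
        for W in weights:
            z = W @ h
            zs.append(z)
            h = np.maximum(z, 0.0)        # ReLU
        return zs                         # pre-activations of every layer

    zs = forward(Ws)
    hs = [x] + [np.maximum(z, 0.0) for z in zs]

    # Backward pass for the squared loss; the network output is the last
    # pre-activation (no ReLU on the readout).
    delta = zs[-1] - y
    grads = [None] * len(Ws)
    for l in reversed(range(len(Ws))):
        grads[l] = np.outer(delta, hs[l])
        if l > 0:
            delta = (Ws[l].T @ delta) * (zs[l - 1] > 0)

    # One gradient descent step, then recompute the pre-activations.
    Ws_new = [W - lr * g for W, g in zip(Ws, grads)]
    zs_new = forward(Ws_new)

    # Average the squared change over hidden layers and neurons.
    return float(np.mean([np.mean((zn - zo) ** 2)
                          for zn, zo in zip(zs_new[:-1], zs[:-1])]))

rng = np.random.default_rng(0)
for L in (4, 16, 64):
    lr = 0.5 * L ** (-1.5)                # the paper's depth scaling; 0.5 is arbitrary
    shift = preactivation_shift(n_in=32, n=512, L=L, lr=lr, rng=rng)
    print(f"L={L:3d}  lr={lr:.4f}  mean squared pre-activation shift = {shift:.3e}")
```

Replacing the $L^{-3/2}$ factor with a depth-independent learning rate would be expected to make the same diagnostic grow with $L$, which is exactly the instability the paper's scaling rule is meant to avoid.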