Towards Training Without Depth Limits: Batch Normalization Without Gradient Explosion (2310.02012v1)
Abstract: Normalization layers are one of the key building blocks of deep neural networks. Several theoretical studies have shown that batch normalization improves signal propagation by preventing the representations from becoming collinear across layers. However, results from the mean-field theory of batch normalization also conclude that this benefit comes at the expense of gradients that explode with depth. Motivated by these two aspects of batch normalization, we pose the following question: "Can a batch-normalized network keep the optimal signal propagation properties, but avoid exploding gradients?" We answer this question in the affirmative by giving a particular construction of a Multi-Layer Perceptron (MLP) with linear activations and batch normalization that provably has bounded gradients at any depth. Based on Weingarten calculus, we develop a rigorous and non-asymptotic theory for this constructed MLP that gives a precise characterization of forward signal propagation, while proving that gradients remain bounded for linearly independent input samples, a condition that holds in most practical settings. Inspired by our theory, we also design an activation shaping scheme that empirically achieves the same properties for certain non-linear activations.
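Below is a minimal sketch, in PyTorch, of the kind of setup the abstract describes: a deep MLP with linear (identity) activations and a batch-normalization layer after every linear map. The orthogonal weight initialization used here is an assumption made for this illustration; the paper's exact construction may differ. The sketch simply tracks the norm of the gradient with respect to the input as depth grows, which is the quantity that explodes under standard mean-field analyses of batch normalization.

```python
# Minimal sketch (not the paper's exact construction): a batch-normalized deep
# linear MLP with orthogonal weight matrices (assumed here for illustration).
# We measure how the input-gradient norm behaves as depth increases.

import torch
import torch.nn as nn

torch.manual_seed(0)


def make_bn_linear_mlp(width: int, depth: int) -> nn.Sequential:
    """Deep MLP with identity (linear) activations and BatchNorm after each linear layer."""
    layers = []
    for _ in range(depth):
        linear = nn.Linear(width, width, bias=False)
        nn.init.orthogonal_(linear.weight)  # assumption: orthogonal initialization
        layers += [linear, nn.BatchNorm1d(width, affine=False)]
    return nn.Sequential(*layers)


width, batch_size = 64, 32
# A random Gaussian batch is almost surely linearly independent (batch_size <= width).
x = torch.randn(batch_size, width, requires_grad=True)

for depth in [10, 50, 200]:
    net = make_bn_linear_mlp(width, depth)
    loss = net(x).pow(2).mean()
    (grad,) = torch.autograd.grad(loss, x)
    print(f"depth={depth:4d}  input-grad norm={grad.norm():.3e}")
```

Under the theory summarized in the abstract, the printed gradient norms should remain of the same order across depths for this linear, batch-normalized construction; swapping the orthogonal initialization for i.i.d. Gaussian weights is the setting in which mean-field analyses predict gradient explosion.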
Authors: Alexandru Meterez, Amir Joudaki, Francesco Orabona, Alexander Immer, Gunnar Rätsch, Hadi Daneshmand