CR-SAM: Curvature Regularized Sharpness-Aware Minimization (2312.13555v2)
Abstract: The capacity to generalize to future unseen data stands as one of the most crucial attributes of deep neural networks. Sharpness-Aware Minimization (SAM) aims to enhance generalization by minimizing the worst-case loss, using one-step gradient ascent as an approximation. However, as training progresses, the non-linearity of the loss landscape increases, rendering one-step gradient ascent less effective; multi-step gradient ascent, on the other hand, incurs higher training cost. In this paper, we introduce a normalized Hessian trace to accurately measure the curvature of the loss landscape on *both* training and test sets. In particular, to counter excessive non-linearity of the loss landscape, we propose Curvature Regularized SAM (CR-SAM), which integrates the normalized Hessian trace into SAM as a regularizer. Additionally, we present an efficient way to compute the trace via finite differences with parallelism. Our theoretical analysis, based on PAC-Bayes bounds, establishes the regularizer's efficacy in reducing generalization error. Empirical evaluation on CIFAR and ImageNet datasets shows that CR-SAM consistently enhances classification performance for ResNet and Vision Transformer (ViT) models across various datasets. Our code is available at https://github.com/TrustAIoT/CR-SAM.
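The abstract mentions estimating the Hessian trace through finite differences. Below is a minimal, illustrative sketch of one standard way to do this: Hutchinson's randomized trace estimator, with each Hessian-vector product approximated by a central finite difference of gradients. This is an assumption-laden sketch in PyTorch, not the authors' implementation; the function names, the probe count, the step size `h`, and the data interface are all hypothetical, and the paper's specific normalization and parallel scheme are not reproduced here (see the linked repository for the actual code).

```python
# Hedged sketch: Hutchinson trace estimation with finite-difference
# Hessian-vector products. Not the CR-SAM authors' implementation.
import torch


def hessian_trace_estimate(model, loss_fn, data, target, num_probes=4, h=1e-3):
    """Estimate tr(H) of the loss w.r.t. model parameters.

    tr(H) ~= E_v[v^T H v] with Rademacher probes v, and
    H v  ~= (grad L(w + h v) - grad L(w - h v)) / (2 h).
    """
    params = [p for p in model.parameters() if p.requires_grad]

    def loss_grads():
        loss = loss_fn(model(data), target)
        return torch.autograd.grad(loss, params)

    trace_samples = []
    for _ in range(num_probes):
        # Rademacher probe (+1/-1 entries), one block per parameter tensor.
        vs = [torch.randint_like(p, high=2, dtype=p.dtype) * 2 - 1 for p in params]

        # Shift weights to w + h*v and evaluate gradients there.
        with torch.no_grad():
            for p, v in zip(params, vs):
                p.add_(h * v)
        grads_plus = loss_grads()

        # Shift to w - h*v (undoing the previous shift) and evaluate gradients.
        with torch.no_grad():
            for p, v in zip(params, vs):
                p.sub_(2 * h * v)
        grads_minus = loss_grads()

        # Restore the original weights.
        with torch.no_grad():
            for p, v in zip(params, vs):
                p.add_(h * v)

        # v^T H v via the finite-difference Hessian-vector product.
        vHv = sum(
            ((gp - gm) / (2 * h) * v).sum()
            for gp, gm, v in zip(grads_plus, grads_minus, vs)
        )
        trace_samples.append(vHv.item())

    return sum(trace_samples) / len(trace_samples)
```

Rademacher probes make the estimator unbiased for the trace, and the central finite difference avoids second-order automatic differentiation, so each probe costs only two extra gradient evaluations that can be computed in parallel.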