CR-SAM: Curvature Regularized Sharpness-Aware Minimization (2312.13555v2)

Published 21 Dec 2023 in cs.LG and cs.CV

Abstract: The capacity to generalize to future unseen data stands as one of the most crucial attributes of deep neural networks. Sharpness-Aware Minimization (SAM) aims to enhance generalizability by minimizing the worst-case loss, using one-step gradient ascent as an approximation. However, as training progresses, the non-linearity of the loss landscape increases, rendering one-step gradient ascent less effective; on the other hand, multi-step gradient ascent incurs a higher training cost. In this paper, we introduce a normalized Hessian trace to accurately measure the curvature of the loss landscape on both training and test sets. In particular, to counter excessive non-linearity of the loss landscape, we propose Curvature Regularized SAM (CR-SAM), integrating the normalized Hessian trace as a SAM regularizer. Additionally, we present an efficient way to compute the trace via finite differences with parallelism. Our theoretical analysis based on PAC-Bayes bounds establishes the regularizer's efficacy in reducing generalization error. Empirical evaluation on CIFAR and ImageNet datasets shows that CR-SAM consistently enhances classification performance for ResNet and Vision Transformer (ViT) models across various datasets. Our code is available at https://github.com/TrustAIoT/CR-SAM.
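
The abstract describes two computational ingredients: a SAM-style one-step adversarial weight perturbation and a Hessian-trace penalty whose value and gradient are obtained by finite differences. The sketch below illustrates, in PyTorch, how such a training step could be assembled. It is not the authors' implementation (see the linked repository for that); the Rademacher probe, the finite-difference step size `h`, the penalty weight `lam`, and the way the curvature gradient is combined with the SAM gradient are illustrative assumptions rather than details taken from the paper.

```python
# Hedged sketch of a CR-SAM-style update in PyTorch. NOT the authors' code
# (see https://github.com/TrustAIoT/CR-SAM for the official implementation).
# The probe distribution, step size `h`, penalty weight `lam`, and the exact
# normalization of the trace term are illustrative assumptions.
import torch


def _grads(model, loss_fn, x, y):
    """Gradient of the loss at the model's *current* weights, as a list of tensors."""
    loss = loss_fn(model(x), y)
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.autograd.grad(loss, params)


@torch.no_grad()
def _perturb(params, direction, scale):
    """In-place update w <- w + scale * direction."""
    for p, d in zip(params, direction):
        p.add_(d, alpha=scale)


def cr_sam_step(model, loss_fn, optimizer, x, y, rho=0.05, h=1e-3, lam=1e-3):
    params = [p for p in model.parameters() if p.requires_grad]

    # (1) SAM: one-step gradient ascent to the worst-case weights in an L2 ball.
    g = _grads(model, loss_fn, x, y)
    g_norm = torch.sqrt(sum((gi ** 2).sum() for gi in g)) + 1e-12
    eps = [rho * gi / g_norm for gi in g]
    _perturb(params, eps, +1.0)            # w <- w + eps
    g_adv = _grads(model, loss_fn, x, y)   # gradient at the perturbed weights
    _perturb(params, eps, -1.0)            # restore w

    # (2) Curvature penalty: Hutchinson-style trace probe via finite differences.
    # For a Rademacher vector v, the gradient of v^T H(w) v is approximated by
    #   (grad L(w + h v) + grad L(w - h v) - 2 grad L(w)) / h^2.
    v = [(torch.randint_like(p, 0, 2) * 2 - 1) for p in params]
    _perturb(params, v, +h)
    g_plus = _grads(model, loss_fn, x, y)
    _perturb(params, v, -2 * h)
    g_minus = _grads(model, loss_fn, x, y)
    _perturb(params, v, +h)                # restore w
    g_curv = [(gp + gm - 2 * gi) / h ** 2 for gp, gm, gi in zip(g_plus, g_minus, g)]

    # (3) Combine the SAM gradient with the (assumed) curvature-penalty gradient.
    optimizer.zero_grad(set_to_none=True)
    for p, ga, gc in zip(params, g_adv, g_curv):
        p.grad = ga + lam * gc
    optimizer.step()
```

One detail the sketch reflects: the extra gradient evaluations used by the finite-difference probe are independent of one another, which is what makes the trace computation amenable to the parallel evaluation mentioned in the abstract.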
