Online Importance Sampling for Stochastic Gradient Optimization (2311.14468v3)
Abstract: Machine learning optimization often depends on stochastic gradient descent, where the precision of gradient estimation is vital for model performance. Gradients are calculated from mini-batches formed by uniformly selecting data samples from the training dataset. However, not all data samples contribute equally to gradient estimation. To address this, various importance sampling strategies have been developed to prioritize more significant samples. Despite these advancements, current importance sampling methods face challenges in computational efficiency and in integrating seamlessly into practical machine learning pipelines. In this work, we propose a practical algorithm that efficiently computes data importance on-the-fly during training, eliminating the need for dataset preprocessing. We also introduce a novel metric based on the derivative of the loss w.r.t. the network output, designed for mini-batch importance sampling. Our metric prioritizes influential data points, thereby enhancing gradient estimation accuracy. We demonstrate the effectiveness of our approach across various applications. We first evaluate it on classification and regression tasks to demonstrate improvements in accuracy. We then show how our approach can also be used for online data pruning, identifying and discarding data samples that contribute minimally to the training loss. This significantly reduces training time with negligible loss in model accuracy.
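For intuition, below is a minimal sketch of how an output-gradient importance metric might be computed and used to resample a mini-batch in PyTorch. The function names and the cross-entropy classification setup are assumptions chosen for illustration; this is not the authors' implementation, only an example of scoring samples by the norm of the loss gradient w.r.t. the network output and drawing samples proportionally to that score.

```python
import torch
import torch.nn.functional as F


def output_gradient_importance(model, x, y):
    """Per-sample importance: norm of dL/d(output).

    Hypothetical sketch of the metric described in the abstract. The gradient
    is taken w.r.t. the network output (logits), not the parameters, so it is
    obtained from a single forward pass plus a cheap backward to the logits.
    For cross-entropy this gradient is simply softmax(logits) - one_hot(y).
    """
    logits = model(x).detach().requires_grad_(True)      # (B, C)
    loss = F.cross_entropy(logits, y, reduction="sum")   # sum keeps samples independent
    (grad_out,) = torch.autograd.grad(loss, logits)      # (B, C) per-sample gradients
    return grad_out.norm(dim=1)                          # (B,) importance scores


def resample_minibatch(importance, batch_size):
    """Draw indices proportionally to importance, with the inverse-probability
    weights needed to keep the resulting gradient estimate unbiased."""
    probs = importance / importance.sum()
    idx = torch.multinomial(probs, batch_size, replacement=True)
    weights = 1.0 / (probs[idx] * len(probs))             # importance-sampling weights
    return idx, weights
```

In a training loop, one might score a larger candidate batch with `output_gradient_importance`, resample a smaller working batch with `resample_minibatch`, and multiply each sample's loss by its returned weight before the backward pass, so the weighted gradient remains an unbiased estimate of the uniform-sampling gradient.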