On Feynman--Kac training of partial Bayesian neural networks (2310.19608v3)
Abstract: Recently, partial Bayesian neural networks (pBNNs), which treat only a subset of the parameters as stochastic, were shown to perform competitively with full Bayesian neural networks. However, pBNNs are often multi-modal in the latent-variable space and thus challenging to approximate with parametric models. To address this problem, we propose an efficient sampling-based training strategy, wherein the training of a pBNN is formulated as simulating a Feynman--Kac model. We then describe variations of sequential Monte Carlo samplers that allow us to simultaneously estimate the parameters and the latent posterior distribution of this model at a tractable computational cost. Using various synthetic and real-world datasets, we show that our proposed training scheme outperforms the state of the art in terms of predictive performance.
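To make the training recipe in the abstract concrete, here is a minimal sketch (not the authors' implementation) of the idea in JAX: the stochastic subset of weights is represented by particles, the deterministic weights are updated by gradient ascent on the particle estimate of the log marginal likelihood, and the particles are reweighted, resampled, and rejuvenated at each step. The toy model, the Gaussian likelihood, the step sizes, and the random-walk rejuvenation kernel are illustrative assumptions, not the SMC-sampler variants studied in the paper.

```python
# A minimal sketch (not the authors' implementation) of SMC-based training of a
# partial Bayesian neural network. The toy model, Gaussian likelihood, and
# random-walk rejuvenation kernel are illustrative assumptions.
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

N_PARTICLES, DIM_IN, DIM_H = 128, 1, 16

def predict(phi, theta, x):
    # phi: deterministic hidden-layer parameters; theta: one particle of the
    # stochastic output-layer weights.
    h = jnp.tanh(x @ phi["W"] + phi["b"])
    return h @ theta

def log_lik(phi, theta, x, y, noise=0.1):
    # Gaussian observation model; returns the log-likelihood of the batch.
    resid = y - predict(phi, theta, x)
    return -0.5 * jnp.sum(resid ** 2) / noise ** 2

# Evaluate the log-likelihood for every particle at once.
batched_log_lik = jax.vmap(log_lik, in_axes=(None, 0, None, None))

def log_marginal(phi, particles, x, y):
    # Particle estimate of the log marginal likelihood; its gradient with
    # respect to phi drives the deterministic-parameter update.
    lw = batched_log_lik(phi, particles, x, y)
    return logsumexp(lw) - jnp.log(N_PARTICLES)

@jax.jit
def smc_step(key, phi, particles, x, y, lr=1e-2, jitter=0.05):
    k_resample, k_move = jax.random.split(key)
    # 1) Gradient ascent on the deterministic parameters through the estimate.
    lml, grads = jax.value_and_grad(log_marginal)(phi, particles, x, y)
    phi = jax.tree_util.tree_map(lambda p, g: p + lr * g, phi, grads)
    # 2) Reweight the particles and resample them multinomially.
    lw = batched_log_lik(phi, particles, x, y)
    idx = jax.random.categorical(k_resample, lw, shape=(N_PARTICLES,))
    particles = particles[idx]
    # 3) Rejuvenate with a random-walk move, a crude stand-in for the Markov
    #    kernels used inside full SMC samplers.
    particles = particles + jitter * jax.random.normal(k_move, particles.shape)
    return phi, particles, lml

# Toy one-dimensional regression data and initialization.
key = jax.random.PRNGKey(0)
k_x, k_phi, k_theta = jax.random.split(key, 3)
x = jax.random.uniform(k_x, (64, DIM_IN), minval=-3.0, maxval=3.0)
y = jnp.sin(x)
phi = {"W": 0.1 * jax.random.normal(k_phi, (DIM_IN, DIM_H)),
       "b": jnp.zeros(DIM_H)}
particles = jax.random.normal(k_theta, (N_PARTICLES, DIM_H, 1))

for _ in range(200):
    key, sub = jax.random.split(key)
    phi, particles, lml = smc_step(sub, phi, particles, x, y)
```

Note that multinomial resampling is not differentiable, so the gradient in this sketch flows only through the importance weights at fixed particle locations; handling such issues at tractable cost is precisely where the samplers developed in the paper go beyond this simplified loop.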