Weighted Fisher divergence for high-dimensional Gaussian variational inference (2503.04246v1)
Abstract: Bayesian inference has many advantages for complex models. However, standard Monte Carlo methods for summarizing the posterior can be computationally demanding, and it is attractive to consider optimization-based variational approximations. Our work considers Gaussian approximations with sparse precision matrices which are tractable to optimize in high-dimensional problems. Although the optimal Gaussian approximation is usually defined as the one closest to the target posterior in Kullback-Leibler divergence, it is useful to consider other divergences when the Gaussian assumption is crude, in order to capture important features of the posterior for a given application. Our work studies the weighted Fisher divergence, which focuses on gradient differences between the target posterior and its approximation, with the Fisher and score-based divergences being special cases. We make three main contributions. First, under mean-field assumptions, we compare approximations based on weighted Fisher divergences with Kullback-Leibler approximations, for both Gaussian and non-Gaussian targets. Second, we go beyond mean-field and consider approximations with sparse precision matrices reflecting posterior conditional independence structure for hierarchical models. Using stochastic gradient descent to enforce sparsity, we develop two approaches to minimize the weighted Fisher divergence, based on the reparametrization trick and a batch approximation of the objective. Finally, we examine the performance of our methods on examples involving logistic regression, generalized linear mixed models and stochastic volatility models.
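To make "gradient differences" concrete, here is a minimal sketch of a Monte Carlo estimate of a weighted Fisher divergence for a Gaussian approximation q = N(mu, L L^T). It assumes the form E_q[(grad log q - grad log p)^T W (grad log q - grad log p)] with a fixed positive-definite weight matrix W, so that W = I gives the Fisher divergence. Treating W = Sigma_q as the score-based case is an assumption made here for illustration, not a restatement of the paper's definitions. Samples use the reparametrization theta = mu + L z with z ~ N(0, I). The target is a correlated Gaussian so the estimate can be checked against a closed form. The function names are hypothetical. This is not the authors' implementation: it parametrizes q through a covariance Cholesky factor rather than a sparse precision factor, and it only evaluates the objective without optimizing it.

```python
import numpy as np


def weighted_fisher_mc(mu, L, score_p, W, n_samples=10_000, seed=None):
    """Monte Carlo estimate of E_q[(s_q - s_p)^T W (s_q - s_p)] for q = N(mu, L L^T).

    Assumed form of the weighted Fisher divergence (illustrative sketch only);
    s_q and s_p denote the gradients of log q and log p.
    """
    rng = np.random.default_rng(seed)
    dim = mu.shape[0]
    z = rng.standard_normal((n_samples, dim))   # z ~ N(0, I)
    theta = mu + z @ L.T                         # reparametrization: theta = mu + L z
    # Score of q: -Sigma^{-1}(theta - mu) = -L^{-T} z, computed by a linear solve.
    s_q = -np.linalg.solve(L.T, z.T).T
    s_p = score_p(theta)                         # target score, one row per sample
    diff = s_q - s_p
    # Per-sample quadratic form diff^T W diff, averaged over samples.
    return float(np.mean(np.einsum("ni,ij,nj->n", diff, W, diff)))


def weighted_fisher_gaussian_exact(mu, L, m, Lam, W):
    """Closed form when the target is N(m, Lam^{-1}), where s_q - s_p = A z + d."""
    A = Lam @ L - np.linalg.inv(L).T             # coefficient of z in s_q - s_p
    d = Lam @ (mu - m)                           # constant offset from the mean mismatch
    return float(np.trace(W @ A @ A.T) + d @ W @ d)


# Correlated Gaussian "posterior" and a mean-field (diagonal) Gaussian q.
m = np.zeros(2)
Sigma_p = np.array([[1.0, 0.8], [0.8, 1.0]])
Lam = np.linalg.inv(Sigma_p)
score_p = lambda theta: -(theta - m) @ Lam    # rows of -Lam (theta - m); Lam is symmetric

mu = np.array([0.5, -0.5])
L = np.diag([0.6, 0.6])

for name, W in [("Fisher, W = I", np.eye(2)), ("W = Sigma_q (assumed)", L @ L.T)]:
    mc = weighted_fisher_mc(mu, L, score_p, W, n_samples=200_000, seed=0)
    exact = weighted_fisher_gaussian_exact(mu, L, m, Lam, W)
    print(f"{name}: MC = {mc:.4f}, exact = {exact:.4f}")
```

Because only scores enter the objective, the normalizing constant of the posterior is never needed, and the estimate is differentiable in (mu, L) through the reparametrization. The divergence is zero exactly when q matches p, which gives a quick sanity check.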