Bayesian Deep Learning Via Expectation Maximization and Turbo Deep Approximate Message Passing (2402.07366v2)

Published 12 Feb 2024 in cs.LG and cs.AI

Abstract: Efficient learning and model compression algorithms for deep neural networks (DNNs) are key workhorses behind the rise of deep learning (DL). In this work, we propose a message passing based Bayesian deep learning algorithm called EM-TDAMP to avoid the drawbacks of traditional stochastic gradient descent (SGD) based learning algorithms and regularization-based model compression methods. Specifically, we formulate the problem of DNN learning and compression as a sparse Bayesian inference problem, in which a group sparse prior is employed to achieve structured model compression. Then, we propose an expectation maximization (EM) framework to estimate posterior distributions for parameters (E-step) and update hyperparameters (M-step), where the E-step is realized by a newly proposed turbo deep approximate message passing (TDAMP) algorithm. We further extend EM-TDAMP and propose a novel Bayesian federated learning framework, in which the clients perform TDAMP to efficiently calculate the local posterior distributions based on the local data, and the central server first aggregates the local posterior distributions to update the global posterior distributions and then updates hyperparameters based on EM to accelerate convergence. We detail the application of EM-TDAMP to Boston housing price prediction and handwriting recognition, and present extensive numerical results to demonstrate the advantages of EM-TDAMP.


Summary

  • The paper casts DNN training and compression as a sparse Bayesian inference problem and extends the resulting EM-TDAMP algorithm into a Bayesian federated learning framework.
  • Clients run TDAMP, which splits posterior inference into a deep approximate message passing module and a module for the group sparse prior; the server aggregates the local posteriors and updates hyperparameters via EM.
  • On Boston housing price prediction and handwriting recognition, numerical experiments show faster convergence and better prediction performance than traditional SGD-based methods.

Bayesian Federated Learning Framework Combining EM-TDAMP for Efficient Training and Compression

Introduction

The growing demand for machine learning across many domains, together with the limitations of traditional training algorithms, poses new challenges for federated learning (FL). Traditional FL algorithms predominantly rely on stochastic gradient descent (SGD) or its variants for client-side training, which often converges slowly and can settle on suboptimal solutions. Model compression is also increasingly necessary, because many client devices have limited computational resources. To address these issues, Xu et al. propose a Bayesian federated learning (BFL) framework built on Expectation Maximization and Turbo Deep Approximate Message Passing (EM-TDAMP), designed to accelerate convergence while achieving structured model compression.

Bayesian Federated Learning Framework

The proposed BFL framework treats the learning and compression of deep neural network (DNN) models as a sparse Bayesian inference problem. A group sparse prior lets the framework prune entire neurons during training, which yields structured compression. The likelihood function includes zero-mean Gaussian noise whose variance acts as a hyperparameter that regulates the learning process. Expectation maximization ties these pieces together: the E-step estimates the posterior distributions of the parameters (via TDAMP), and the M-step updates the hyperparameters of both the prior and the likelihood, which accelerates convergence.
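To make the group sparse prior and the EM hyperparameter update more concrete, here is a minimal Python sketch for groups of weights (for example, the outgoing weights of one neuron). It assumes a Bernoulli-Gaussian group prior, in which a whole group is either exactly zero or drawn i.i.d. from N(0, tau), with activity probability rho, and it assumes Gaussian messages N(w_i; m_i, v_i) on each weight, such as a message passing module might produce. The function names, the parameters rho and tau, and the noise-variance update for a regression output y = z + n are illustrative assumptions, not the paper's exact equations.

```python
import math

# Illustrative sketch only: hypothetical names, assumed Bernoulli-Gaussian
# group prior; not the paper's exact TDAMP/EM equations.


def _sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def group_activity_posterior(means, variances, rho, tau):
    """Posterior probability that one group of weights is active (not pruned).

    Assumed prior: with probability rho, every weight in the group is drawn
    i.i.d. from N(0, tau); otherwise every weight in the group is exactly zero.
    means[i], variances[i]: Gaussian message N(w_i; m_i, v_i) on weight i.
    """
    # Log-likelihood ratio of the messages under "active" vs. "all zero".
    llr = 0.0
    for m, v in zip(means, variances):
        llr += 0.5 * math.log(v / (tau + v))
        llr += 0.5 * m * m * (1.0 / v - 1.0 / (tau + v))
    return _sigmoid(llr + math.log(rho / (1.0 - rho)))


def group_posterior_means(means, variances, activity, tau):
    # E[w_i] = P(active) * (Gaussian-prior shrinkage of the message mean).
    return [activity * tau * m / (tau + v) for m, v in zip(means, variances)]


def em_update_rho(activities):
    # M-step for the activity probability: average posterior group activity.
    return sum(activities) / len(activities)


def em_update_noise_variance(targets, pred_means, pred_vars):
    # M-step for zero-mean Gaussian output noise y = z + n (regression case):
    # sigma^2 = mean over samples of E[(y - z)^2] = (y - E[z])^2 + Var[z].
    total = sum((y - m) ** 2 + v for y, m, v in zip(targets, pred_means, pred_vars))
    return total / len(targets)


# Toy usage: one group with strong evidence, one with near-zero messages.
strong = group_activity_posterior([1.2, -0.9, 1.1], [0.1, 0.1, 0.1], rho=0.5, tau=1.0)
weak = group_activity_posterior([0.02, -0.01, 0.03], [0.1, 0.1, 0.1], rho=0.5, tau=1.0)
rho_new = em_update_rho([strong, weak])
print(f"strong={strong:.4f} weak={weak:.4f} rho_new={rho_new:.4f}")
```

In this toy run the first group is almost certainly active while the second falls to an activity probability of a few percent, so a pruning threshold (another assumption) would remove it. In the paper, the group-level computation belongs to TDAMP's Module B and the hyperparameter update to the M-step; their exact forms may differ from this sketch.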

The framework consists of two main components operating in a federated learning setting:

  1. Central Server Algorithm: The central server aggregates the local posterior distributions reported by the clients to update the global posterior distributions, and then updates the hyperparameters of the prior and likelihood using the EM algorithm to accelerate convergence.
  2. Client-Side Learning (TDAMP Algorithm): Each client runs TDAMP, the E-step of EM-TDAMP, to compute its local posterior distributions from its local data while avoiding the high computational complexity of standard message passing algorithms. TDAMP decomposes inference into two modules that iteratively exchange messages: Module A performs deep approximate message passing using independent priors supplied by Module B, and Module B performs message passing over the group sparse prior. This enables efficient learning in the federated setting.
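The server-side aggregation can be illustrated for a single weight under a simplifying assumption: every client reports a Gaussian approximation N(m_k, v_k) of its local posterior, computed from its own data and a shared Gaussian prior N(m_0, v_0). Because the global posterior is proportional to the prior times the product of the K local likelihoods, it equals the product of the local posteriors divided by the prior raised to the power K - 1; in natural parameters this amounts to adding precisions and precision-weighted means. This product-of-Gaussians fusion rule is a common choice assumed here for illustration, not necessarily the aggregation used in the paper (which works with a group sparse prior), and the helper names are hypothetical.

```python
# Illustrative sketch only: assumed Gaussian local posteriors and a shared
# Gaussian prior; hypothetical helper names, not the paper's server algorithm.


def local_posterior(observation, noise_var, prior_mean, prior_var):
    # Toy client: scalar observation y = w + n with n ~ N(0, noise_var),
    # combined with the shared Gaussian prior N(prior_mean, prior_var).
    precision = 1.0 / prior_var + 1.0 / noise_var
    mean = (prior_mean / prior_var + observation / noise_var) / precision
    return mean, 1.0 / precision


def aggregate_gaussian_posteriors(client_means, client_vars, prior_mean, prior_var):
    """Fuse K Gaussian local posteriors that each include the same prior once.

    The global posterior is proportional to prior^(1 - K) * prod_k local_k,
    so in natural parameters:
      precision        = sum_k 1/v_k   - (K - 1) / v_0
      precision * mean = sum_k m_k/v_k - (K - 1) * m_0 / v_0
    """
    k = len(client_means)
    precision = sum(1.0 / v for v in client_vars) - (k - 1) / prior_var
    if precision <= 0.0:
        raise ValueError("fused precision is not positive; check the local posteriors")
    weighted_mean = (sum(m / v for m, v in zip(client_means, client_vars))
                     - (k - 1) * prior_mean / prior_var)
    return weighted_mean / precision, 1.0 / precision


# Two clients observe the same weight with different noise levels.
prior_mean, prior_var = 0.0, 1.0
clients = [local_posterior(1.0, 0.5, prior_mean, prior_var),
           local_posterior(2.0, 1.0, prior_mean, prior_var)]
fused_mean, fused_var = aggregate_gaussian_posteriors(
    [m for m, _ in clients], [v for _, v in clients], prior_mean, prior_var)
print(f"fused mean={fused_mean:.3f} var={fused_var:.3f}")
```

With Gaussian clients the fused result (mean 1.0, variance 0.25 here) matches what a single server would obtain from both observations at once, which is why dividing out the K - 1 extra copies of the prior matters: without it, the prior would be counted once per client.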

Numerical Results and Analysis

Extensive numerical experiments compare EM-TDAMP with traditional SGD-based training. On Boston housing price prediction and handwriting recognition, EM-TDAMP converges markedly faster and achieves better prediction performance. Its advantage also holds at high model compression ratios, which matters for running federated learning on devices with limited computational resources.

Future Directions and Conclusions

The Bayesian federated learning framework with EM-TDAMP combines distributed learning and structured model compression in a single inference procedure. Its ability to accelerate convergence and achieve substantial model compression without compromising performance is a notable advance in federated learning research. Natural next steps include extending the framework to a wider range of DNN architectures and optimizing it further.