
FMAPLS: Bayesian EM for Label Shift

Updated 30 November 2025
  • The paper introduces a Bayesian EM framework that jointly estimates target class priors and Dirichlet hyperparameters to correct label shift in supervised learning.
  • It leverages a closed-form linear surrogate function for efficient hyperparameter updates, reducing KL divergence by up to 40% in severe imbalance scenarios.
  • Empirical evaluations on datasets like CIFAR100 and ImageNet-LT demonstrate significant accuracy gains and robustness in both batch and online adaptations.

Full Maximum A Posterior Label Shift (FMAPLS) is a Bayesian framework for label-shift correction in supervised learning. Under the label shift assumption—where the class prior distribution varies between source (training) and target (test) domains, but class-conditional likelihoods remain fixed—FMAPLS enables joint and dynamic estimation of both the unknown target priors and the Dirichlet hyperparameters that govern uncertainty over these priors. The method leverages Expectation-Maximization (EM) algorithms in both batch and online variants and introduces a closed-form Linear Surrogate Function (LSF) for efficient hyperparameter updates. Empirical results demonstrate that FMAPLS and its online form outperform previous maximum a posteriori-based label-shift estimators, particularly under severe class imbalance and distributional uncertainty, in terms of Kullback–Leibler divergence and classification accuracy (Hu et al., 23 Nov 2025).

1. Problem Formulation and Generative Model

FMAPLS addresses the canonical label shift scenario with the following structure:

  • Source (training) data: (X_s, Y_s) ~ P_s, where P_s(Y) is the source class prior π_s = (ε_1, ..., ε_K), and P_s(X|Y) is the known class-conditional likelihood.
  • Target (test) data: (X_t, Y_t) ~ P_t, assuming P_t(X|Y) = P_s(X|Y) but P_t(Y) = π_t ≠ π_s.
  • Classifier: trained on P_s, provides f_j(x) = P_s(Y=j|X=x). Under label shift, P(Y=j|X=x; π) ∝ f_j(x) π_j.
  • Bayesian model: places a Dirichlet prior on π with hyperparameter α = (α_1, ..., α_K) > 0:

p(\pi|\alpha)=\mathrm{Dir}(\pi;\alpha)=\frac{1}{B(\alpha)}\prod_{j=1}^K \pi_j^{\alpha_j-1},\qquad B(\alpha)=\frac{\prod_{j=1}^K\Gamma(\alpha_j)}{\Gamma\left(\sum_{j=1}^K\alpha_j\right)}

Optionally, a weak prior p(α) may be included.

Given N test samples {x_i}_{i=1}^N, the joint posterior for parameters θ = (π, α) is (up to normalization):

p(\pi,\alpha|\{x_i\}) \propto p(\alpha)\,\mathrm{Dir}(\pi;\alpha)\prod_{i=1}^N\sum_{j=1}^K f_j(x_i)\,\pi_j

The log-posterior (incomplete data) is:

\mathcal{L}(\pi, \alpha) = \log p(\pi|\alpha) + \log p(\alpha) + \sum_{i=1}^N \log\left(\sum_{j=1}^K f_j(x_i)\,\pi_j\right)
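The objective above can be evaluated directly from classifier outputs. A minimal NumPy/SciPy sketch (the function name and array conventions are illustrative, not from the paper; the log p(α) term is omitted, i.e. a flat hyperprior is assumed):

```python
import numpy as np
from scipy.special import gammaln

def log_posterior(pi, alpha, f):
    """Incomplete-data log-posterior L(pi, alpha), up to the log p(alpha) term.

    pi:    (K,) target class-prior estimate
    alpha: (K,) Dirichlet hyperparameters
    f:     (N, K) classifier posteriors f_j(x_i) on the test set
    """
    # log Dir(pi; alpha) = -log B(alpha) + sum_j (alpha_j - 1) log pi_j,
    # where -log B(alpha) = log Gamma(sum_j alpha_j) - sum_j log Gamma(alpha_j)
    log_dir = (gammaln(alpha.sum()) - gammaln(alpha).sum()
               + ((alpha - 1.0) * np.log(pi)).sum())
    # sum_i log sum_j f_j(x_i) pi_j
    log_lik = np.log(f @ pi).sum()
    return log_dir + log_lik
```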

2. Batch EM Algorithm for Joint Estimation

FMAPLS employs a batch Expectation-Maximization (EM) procedure by treating the unknown test labels Y_i as latent variables:

  • E-step: Computes posterior responsibilities

r_{ij}^{(t)} = P(y_i=j|x_i;\pi^{(t)}) = \frac{f_j(x_i)\,\pi_j^{(t)}}{\sum_{k=1}^K f_k(x_i)\,\pi_k^{(t)}}

  • M-step: Separately maximizes with respect to π and α using the expected complete-data log-posterior.

    • Update for π (closed form):

    \pi_j^{(t+1)} = \frac{\alpha_j^{(t)}-1 + R_j^{(t)}}{\sum_{k=1}^K \left(\alpha_k^{(t)}-1 + R_k^{(t)}\right)},\qquad R_j^{(t)} = \sum_{i=1}^N r_{ij}^{(t)}

    • Update for α (MAP estimate for the Dirichlet):

    \alpha^{(t+1)} = \arg\max_{\alpha>0} \left\{ -\log B(\alpha) + \sum_{j=1}^K (\alpha_j-1)\log\pi_j^{(t+1)} + \log p(\alpha)\right\}

    In standard MAPLS, this subproblem is solved via gradient ascent involving digamma functions, which becomes computationally expensive when K is large.
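The E-step and the closed-form π update can be sketched as follows, holding α fixed within each iteration (the α update, by gradient ascent or by the LSF below, is a separate step). This is a hypothetical NumPy helper, not the authors' reference implementation:

```python
import numpy as np

def fmapls_batch_em(f, alpha0, n_iters=100, eps=1e-12):
    """Batch EM for the target prior pi with alpha held fixed (sketch).

    f:      (N, K) source-classifier posteriors f_j(x_i)
    alpha0: (K,) Dirichlet hyperparameters (kept fixed here)
    """
    N, K = f.shape
    alpha = alpha0.astype(float)
    pi = np.full(K, 1.0 / K)                 # start from a uniform prior
    for _ in range(n_iters):
        # E-step: responsibilities r_ij ∝ f_j(x_i) pi_j, row-normalized
        r = f * pi
        r /= r.sum(axis=1, keepdims=True) + eps
        R = r.sum(axis=0)                    # per-class soft counts R_j
        # M-step for pi (closed form): pi_j ∝ alpha_j - 1 + R_j
        pi = np.clip(alpha - 1.0 + R, eps, None)
        pi /= pi.sum()
    return pi
```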

3. Linear Surrogate Function (LSF) Update

To overcome the computational and tuning issues of gradient-based updates for α, FMAPLS introduces a Linear Surrogate Function (LSF):

  • Key mechanism: replace the α-subproblem by enforcing α ∝ π with a large constant c:

\alpha_j \leftarrow \hat{c}\,\pi_j,\qquad \hat{c} := c/\max_k \pi_k

so that max_j α_j = c.

  • Rationale: direct substitution α_j = ĉ π_j yields updates that are asymptotically stationary as c → ∞ (gradient terms of order O(1/ĉ) vanish), so in practice a suitably large c provides an accurate approximation without iterative gradient steps.
  • Computational benefit: the per-iteration cost drops from O(T_grad · K) (gradient ascent) to O(K) (LSF closed form).
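The LSF update itself is a one-line rescaling; a minimal sketch (function name is my own):

```python
import numpy as np

def lsf_alpha_update(pi, c=50.0):
    """Linear Surrogate Function update: alpha ∝ pi, scaled so that
    max_j alpha_j = c. Closed form, O(K) per call."""
    c_hat = c / pi.max()
    return c_hat * pi
```

Because the update only rescales π, the relative class proportions encoded in α always match the current prior estimate.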

4. Online-FMAPLS for Streaming Data

The online-FMAPLS variant enables real-time adaptation to non-stationary or streaming data by employing stochastic approximation of sufficient statistics:

  • Stochastic responsibilities: at time step τ, for incoming x^τ, compute

B_j^\tau = \frac{f_j(x^\tau)\,\pi_j^\tau}{\sum_k f_k(x^\tau)\,\pi_k^\tau}

Maintain running statistics S_j^τ (per class) and s_0^τ (total), initialized as S_j^0 = 1, s_0^0 = 1.

  • Online update (with forgetting rate γ):

s_0^{\tau+1} = (1-\gamma)\,s_0^\tau + \gamma\cdot 1
S_j^{\tau+1} = (1-\gamma)\,S_j^\tau + \gamma\,B_j^\tau

  • M-step: Update

\pi_j^{\tau+1} = \frac{(\alpha_j^\tau - 1) + (1-\gamma) + \gamma B_j^\tau}{\sum_k \left[(\alpha_k^\tau - 1) + (1-\gamma) + \gamma B_k^\tau\right]}

and set α_j^{τ+1} = ĉ π_j^{τ+1}.

  • Complexity: O(K) per data sample, enabling scalable real-time operation.
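One online step, following the recursions stated above, can be sketched as (a hypothetical helper, not the authors' code; the running statistics s_0 and S_j are folded into the direct π update shown in the M-step formula):

```python
import numpy as np

def online_fmapls_step(pi, alpha, f_x, gamma=0.2, c=50.0, eps=1e-12):
    """One online-FMAPLS update for a single incoming sample (sketch).

    pi:    (K,) current prior estimate pi^tau
    alpha: (K,) current Dirichlet hyperparameters alpha^tau
    f_x:   (K,) classifier posterior f_j(x^tau) for the new sample
    """
    # stochastic E-step: B_j ∝ f_j(x^tau) pi_j^tau
    B = f_x * pi
    B /= B.sum() + eps
    # M-step: pi_j ∝ (alpha_j - 1) + (1 - gamma) + gamma B_j
    pi_new = np.clip(alpha - 1.0 + (1.0 - gamma) + gamma * B, eps, None)
    pi_new /= pi_new.sum()
    # LSF hyperparameter update: alpha ∝ pi with max_j alpha_j = c
    alpha_new = (c / pi_new.max()) * pi_new
    return pi_new, alpha_new
```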

5. Convergence–Accuracy Trade-Off

Under the LSF regime (α_j = ĉ π_j), the step size of the online algorithm is governed by c:

  • The iterative increment satisfies |π_j^{τ+1} − π_j^τ| = O(1/ĉ).
  • Interpretation: larger c yields more accurate (less biased) stationary points, but each update becomes smaller, slowing convergence.

A practical implication is that cc must be selected to balance estimation accuracy and adaptation speed: large enough for reliability, but not so large as to impede responsiveness, especially under concept drift or shifting priors.
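This trade-off is easy to observe numerically. The self-contained sketch below (illustrative names, assuming a flat p(α)) applies one online π update with α coupled to π via the LSF and compares the increment for a small versus a large c:

```python
import numpy as np

def online_step(pi, f_x, gamma=0.2, c=50.0):
    """One online pi update under the LSF coupling alpha_j = c_hat * pi_j."""
    c_hat = c / pi.max()
    alpha = c_hat * pi
    # stochastic responsibilities B_j ∝ f_j(x) pi_j
    B = f_x * pi
    B /= B.sum()
    pi_new = np.clip(alpha - 1.0 + (1.0 - gamma) + gamma * B, 1e-12, None)
    return pi_new / pi_new.sum()

pi0 = np.array([0.5, 0.5])
f_x = np.array([0.9, 0.1])
step_small_c = np.abs(online_step(pi0, f_x, c=10.0) - pi0).max()
step_large_c = np.abs(online_step(pi0, f_x, c=1000.0) - pi0).max()
# larger c gives a strictly smaller per-sample increment, matching O(1/c_hat)
```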

6. Empirical Performance Evaluation

Extensive experiments were conducted on long-tail variants of CIFAR100 (K = 100) and ImageNet-LT (K ≈ 1000):

  • Training priors: long-tail imbalanced, controlled by ρ ∈ {0.2, 0.1, 0.05, 0.02}.
  • Test priors: either shuffled long-tail or Dirichlet-drawn (symmetric α_test ∈ {1, 1.5, 2, 2.5, 3}).
  • Metrics: KL divergence D_KL(π_true ‖ π_est) and post-shift classification accuracy.
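The KL metric used above is the standard divergence between the true and estimated prior vectors; a small sketch (clipping added for numerical safety, which is my choice, not the paper's):

```python
import numpy as np

def kl_divergence(pi_true, pi_est, eps=1e-12):
    """D_KL(pi_true || pi_est) between two class-prior vectors."""
    p = np.clip(pi_true, eps, 1.0)
    q = np.clip(pi_est, eps, 1.0)
    return float(np.sum(p * np.log(p / q)))
```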

Results, averaged over 100 runs, confirm:

  • FMAPLS reduces KL divergence by up to 40% over MAPLS in settings of severe imbalance (ρ = 0.02) and high prior uncertainty (α_test = 1).
  • Up to 3–4% absolute accuracy gains over MAPLS in challenging cases.
  • Online-FMAPLS achieves up to 12% KL reduction over MAPLS, with only a 0.5–1.0% relative accuracy drop versus batch FMAPLS.
  • Convergence (measured by KL) stabilizes within 2000 iterations on CIFAR100 and 10,000 iterations on ImageNet-LT.
| Method | Update Complexity | KL Reduction vs MAPLS | Typical Accuracy Effect |
| --- | --- | --- | --- |
| FMAPLS + gradient | O(NK + T_grad K) | up to 40% | 3–4% absolute gain |
| FMAPLS + LSF | O(NK + K) | up to 40% | 3–4% absolute gain |
| Online-FMAPLS | O(K) | up to 12% | 0.5–1.0% drop vs batch |

7. Implementation and Practical Guidance

FMAPLS is particularly robust in scenarios with pronounced class imbalance and uncertain or dynamically shifting target priors. The dynamic α adaptation provides a significant advantage over static-hyperparameter MAPLS approaches.

  • The LSF hyperparameter c should be chosen in the range 10–100; c ≈ 50–100 achieves reliable stationary points with reasonable convergence speed.
  • The forgetting rate γ for online-FMAPLS should typically fall in [0.1, 0.3], with larger values used for more rapid adaptation in highly non-stationary streams.
  • For N ≫ K, batch FMAPLS is recommended due to its efficiency and statistical stability; online-FMAPLS is appropriate when N is small or data arrive as a stream.

FMAPLS offers a Bayesian-EM framework for label-shift correction, accommodating dynamic target priors, with both batch and online variants. Its combination of closed-form surrogate updates and scalable computation makes it suitable for large-scale, imbalanced, or temporally-evolving domains (Hu et al., 23 Nov 2025).
