
Global Convergence of Gradient EM for Over-Parameterized Gaussian Mixtures

Published 6 Jun 2025 in cs.LG and stat.ML | (2506.06584v1)

Abstract: Learning Gaussian Mixture Models (GMMs) is a fundamental problem in machine learning, with the Expectation-Maximization (EM) algorithm and its popular variant gradient EM being arguably the most widely used algorithms in practice. In the exact-parameterized setting, where both the ground truth GMM and the learning model have the same number of components $m$, a vast line of work has aimed to establish rigorous recovery guarantees for EM. However, global convergence has only been proven for the case of $m=2$, and EM is known to fail to recover the ground truth when $m\geq 3$. In this paper, we consider the $\textit{over-parameterized}$ setting, where the learning model uses $n>m$ components to fit an $m$-component ground truth GMM. In contrast to the exact-parameterized case, we provide a rigorous global convergence guarantee for gradient EM. Specifically, for any well separated GMMs in general position, we prove that with only mild over-parameterization $n = \Omega(m\log m)$, randomly initialized gradient EM converges globally to the ground truth at a polynomial rate with polynomial samples. Our analysis proceeds in two stages and introduces a suite of novel tools for Gaussian Mixture analysis. We use Hermite polynomials to study the dynamics of gradient EM and employ tensor decomposition to characterize the geometric landscape of the likelihood loss. This is the first global convergence and recovery result for EM or Gradient EM beyond the special case of $m=2$.

Summary

  • The paper establishes global convergence of a gradient EM variant for over-parameterized Gaussian mixtures, proving that extra components are automatically pruned.
  • It utilizes innovative techniques including convex optimization, Hermite polynomial analysis, and tensor decomposition to guarantee parameter recovery.
  • The method achieves an O(1/T^2) convergence rate and extends to finite-sample settings with polynomial sample complexity.

This paper, "Global Convergence of Gradient EM for Over-Parameterized Gaussian Mixtures" (2506.06584), establishes the first global convergence guarantee for a gradient variant of the Expectation-Maximization (EM) algorithm when learning Gaussian Mixture Models (GMMs) with more components than the true underlying distribution. This is a significant theoretical step, as standard EM and Gradient EM are known to fail (get stuck in suboptimal local minima) in the exact-parameterized case ($n=m$) for $m \geq 3$ components.

The core problem is learning an $m$-component ground truth GMM, $p_*(x) = \sum_{i=1}^m \pi_i^* \phi(x; \mu_i^*, I_d)$, using a model GMM with $n$ components, $p(x) = \sum_{i=1}^n \pi_i \phi(x; \mu_i, I_d)$, where $n > m$ (over-parameterization) and $\phi(x; \mu, I_d)$ is the Gaussian density with mean $\mu$ and covariance $I_d$. The paper focuses on isotropic covariances (the identity matrix $I_d$).
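To make the notation concrete, here is a minimal NumPy sketch of the two quantities the analysis keeps returning to: the model log-density $\log p(x)$ and the posterior responsibilities $\psi_i(x) = \pi_i \phi(x; \mu_i, I_d) / p(x)$. The function names are illustrative (not from the paper), and the log-sum-exp trick is used so that far-away components do not underflow.

```python
import numpy as np


def log_gaussian_isotropic(X, mus):
    """log phi(x; mu_i, I_d) for every row x of X and every row mu_i of mus."""
    d = X.shape[1]
    # Squared distances ||x - mu_i||^2, shape (num_samples, num_components).
    sq = ((X[:, None, :] - mus[None, :, :]) ** 2).sum(axis=-1)
    return -0.5 * sq - 0.5 * d * np.log(2 * np.pi)


def log_density_and_posteriors(X, pis, mus):
    """Return log p(x) and psi_i(x) = pi_i phi(x; mu_i, I_d) / p(x) for each row x of X."""
    with np.errstate(divide="ignore"):  # a weight of 0 gives log 0 = -inf, which is harmless
        log_w = np.log(pis)
    a = log_gaussian_isotropic(X, mus) + log_w[None, :]
    a_max = a.max(axis=1, keepdims=True)
    # log-sum-exp: log p(x) = a_max + log sum_i exp(a_i - a_max)
    log_p = a_max[:, 0] + np.log(np.exp(a - a_max).sum(axis=1))
    psi = np.exp(a - log_p[:, None])
    return log_p, psi


# Example: two components in d = 2, each point placed at one of the means.
X = np.array([[0.0, 0.0], [3.0, 0.0]])
mus = np.array([[0.0, 0.0], [3.0, 0.0]])
log_p, psi = log_density_and_posteriors(X, np.array([0.5, 0.5]), mus)
```

Each row of `psi` sums to one, and a zero weight $\pi_i = 0$ simply yields $\psi_i(x) = 0$, which matters once redundant components are pruned.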

Algorithm and Implementation

The paper analyzes a specific variant of the Gradient EM algorithm, detailed in Algorithm 1 of the paper:

Algorithm 1: Population Gradient-EM with near-optimal weight updates

  1. Input: Stepsize $\eta$, iterations $T$, target accuracy $\epsilon_{acc}$.
  2. Initialization:
    • Means $\mu_i^{(0)}$ are initialized as i.i.d. samples from the ground truth distribution $p_*$.
    • Mixing weights $\pi_i^{(0)} = 1/n$ for all $i \in [n]$.
  3. For $t = 0$ to $T$:
    • Update Mixing Weights $\boldsymbol{\pi}$:

      $\boldsymbol{\pi}^{(t+1)} \gets \text{poly}(\epsilon_{acc})\text{-optimal solution of the convex subproblem } \arg\min_{\boldsymbol{\pi} \in \Delta} \mathcal{L}(\boldsymbol{\pi}, \boldsymbol{\mu}^{(t)})$

      (where $\Delta$ is the probability simplex and $\mathcal{L}$ is the KL divergence loss).

    • Update Means $\boldsymbol{\mu}$:

      $\boldsymbol{\mu}^{(t+1)} \gets \boldsymbol{\mu}^{(t)} + \eta \nabla_{\boldsymbol{\mu}} Q(\boldsymbol{\pi}^{(t+1)}, \boldsymbol{\mu}^{(t)} \mid \boldsymbol{\pi}^{(t+1)}, \boldsymbol{\mu}^{(t)})$

      which is equivalent to:

      $\boldsymbol{\mu}^{(t+1)} \gets \boldsymbol{\mu}^{(t)} - \eta \nabla_{\boldsymbol{\mu}} \mathcal{L}(\boldsymbol{\pi}^{(t+1)}, \boldsymbol{\mu}^{(t)})$

  4. Output: $\boldsymbol{\mu}^{(T)}, \boldsymbol{\pi}^{(T)}$
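A minimal NumPy sketch of Algorithm 1 follows. It is not the authors' implementation and rests on two labeled assumptions: the population expectation over $p_*$ is approximated by an average over a fixed sample `X` drawn from $p_*$, and the convex weight subproblem is solved approximately by a few EM fixed-point iterations $\pi_i \gets \mathbb{E}_x[\psi_i(x)]$ (a standard ascent method for mixture weights with fixed components, swapped in here for a generic convex solver). The mean step uses $\nabla_{\mu_i}\mathcal{L} = -\mathbb{E}_{x \sim p_*}[\psi_i(x)(x - \mu_i)]$. All names are hypothetical.

```python
import numpy as np


def posteriors(X, pis, mus):
    """psi_i(x) for isotropic unit-covariance components (log-sum-exp for stability)."""
    sq = ((X[:, None, :] - mus[None, :, :]) ** 2).sum(axis=-1)
    with np.errstate(divide="ignore"):  # zero weights map to log 0 = -inf
        a = -0.5 * sq + np.log(pis)[None, :]
    a -= a.max(axis=1, keepdims=True)
    w = np.exp(a)
    return w / w.sum(axis=1, keepdims=True)


def gradient_em(X, n, eta=0.5, T=200, weight_iters=50, rng=None):
    """Sketch of Algorithm 1 with expectations over p_* replaced by averages over X.

    ASSUMPTIONS (not from the paper): X is a fixed sample from p_*, and the convex
    weight subproblem is approximated by `weight_iters` EM fixed-point steps.
    """
    rng = np.random.default_rng(rng)
    # Initialization: means are distinct draws from the sample (a stand-in for
    # i.i.d. draws from p_*), weights are uniform.
    mus = X[rng.choice(len(X), size=n, replace=False)]
    pis = np.full(n, 1.0 / n)
    for _ in range(T):
        # Weight update: approximately minimize L(pi, mu^(t)) over the simplex.
        for _ in range(weight_iters):
            pis = posteriors(X, pis, mus).mean(axis=0)
        # Mean update: mu_i <- mu_i + eta * E[psi_i(x) (x - mu_i)] = mu_i - eta * grad L.
        psi = posteriors(X, pis, mus)
        step = (psi[:, :, None] * (X[:, None, :] - mus[None, :, :])).mean(axis=0)
        mus = mus + eta * step
    return pis, mus
```

With $\eta \le 1$ every iteration is a generalized EM step, so the sample log-likelihood never decreases; whether a given run recovers all true means still depends on the separation and on the random initialization.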

Implementation Considerations for Algorithm 1:

  • Weight Update: The update for $\boldsymbol{\pi}$ involves solving a convex optimization problem in each iteration to find a near-optimal set of weights given the current means. This can be done with standard convex optimization methods (e.g., projected gradient descent, Frank-Wolfe). The paper remarks that this is a limiting case of standard Gradient EM when the learning rate for the weights, $\eta_\pi$, is much larger than the learning rate for the means, $\eta_\mu$.
  • Mean Update: This is a standard gradient descent step on the KL divergence loss with respect to the means, using the newly updated weights.
  • Loss Function: The analysis relies on the KL divergence $\mathcal{L}(\boldsymbol{\pi}, \boldsymbol{\mu}) = \text{KL}(p_* \,\|\, p)$. Gradient EM is equivalent to gradient descent on this loss.
  • Random Initialization: Initializing means by sampling from the data is a common heuristic. The theory requires that, with high probability, at least one initial model component is close to each true component, which over-parameterization ($n > m$) makes likely.
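For the weight update, the text names projected gradient descent as one option. The sketch below (an illustration, not the paper's solver) minimizes the sample version of the convex subproblem, $f(\boldsymbol{\pi}) = -\frac{1}{N}\sum_x \log \sum_i \pi_i \phi(x; \mu_i, I_d)$, over the simplex, using the standard sort-based Euclidean projection onto the simplex and a simple backtracking rule so that an accepted step never increases $f$. The input `log_phi` (the $N \times n$ matrix of component log-densities at the current means) and both function names are assumptions made for this example.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection of v onto {p : p_i >= 0, sum_i p_i = 1} (sort-based method)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    k = np.arange(1, len(v) + 1)
    rho = k[u - (css - 1.0) / k > 0][-1]  # largest k whose shifted entry stays positive
    theta = (css[rho - 1] - 1.0) / rho
    return np.maximum(v - theta, 0.0)


def weight_subproblem_pgd(log_phi, iters=200, step=1.0):
    """Approximately minimize f(pi) = -mean_x log sum_i pi_i phi_i(x) over the simplex.

    log_phi: (num_samples, n) array of log phi(x; mu_i, I_d) at the current means.
    """
    # Rescaling each row by its max shifts f by a constant and leaves the minimizer unchanged.
    Phi = np.exp(log_phi - log_phi.max(axis=1, keepdims=True))

    def f(p):
        with np.errstate(divide="ignore"):  # a row with zero mass gives f = +inf
            return -np.mean(np.log(Phi @ p))

    pi = np.full(Phi.shape[1], 1.0 / Phi.shape[1])
    for _ in range(iters):
        grad = -np.mean(Phi / (Phi @ pi)[:, None], axis=0)
        t = step
        # Backtracking: halve the step until the projected point does not increase f.
        while t > 1e-10:
            cand = project_to_simplex(pi - t * grad)
            if f(cand) <= f(pi):
                pi = cand
                break
            t *= 0.5
    return pi
```

Only a $\text{poly}(\epsilon_{acc})$-optimal solution is required, so running a fixed iteration budget is in the spirit of the algorithm.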

Main Theoretical Result and Practical Implications

The paper's main theorem states that, under certain assumptions (well-separated ground truth means, boundedness, non-degeneracy):

  • If the over-parameterization is mild ($n \gtrsim \log m / \pi_{\min}^*$, which simplifies to $n = \Omega(m \log m)$ when the weights are balanced, i.e. $\pi_{\min}^* = \Theta(1/m)$),
  • and the step size $\eta$ is polynomially small,
  • then randomly initialized Gradient EM (Algorithm 1) converges globally to the ground truth GMM.
  • The convergence rate for the loss is $\mathcal{O}(1/T^2)$.
  • Crucially, for components $i$ that do not converge to a ground-truth mean $\mu_\ell^*$, the corresponding mixing weights $\pi_i$ converge to 0. This amounts to automatic pruning of redundant components, a desirable practical property.
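The $m \log m$ scaling for balanced weights has a coupon-collector flavor. As a back-of-the-envelope sketch, assuming each initial mean is an i.i.d. draw from $p_*$ and counting a ground-truth component as covered once at least one draw comes from it, a union bound gives $\Pr[\text{some component uncovered}] \le \sum_\ell (1 - \pi_\ell^*)^n \le m (1 - \pi_{\min}^*)^n$, so $n \ge \log(m/\delta) / \log\frac{1}{1 - \pi_{\min}^*} \approx \log(m/\delta)/\pi_{\min}^*$ suffices. The helper names are hypothetical, and this covers only the initialization part of the argument, not the paper's full condition.

```python
import math


def min_components_union_bound(pi_min, m, delta):
    """Smallest n with m * (1 - pi_min)^n <= delta (union bound over the m components)."""
    return math.ceil(math.log(m / delta) / -math.log1p(-pi_min))


def prob_all_covered_balanced(m, n):
    """Exact P(each of m equally likely components receives at least one of n i.i.d. draws)."""
    # Inclusion-exclusion over the set of k components that receive no draw.
    return sum((-1) ** k * math.comb(m, k) * (1 - k / m) ** n for k in range(m + 1))
```

For instance, with $m = 10$ balanced components and $\delta = 0.05$, the union bound asks for $n = 51$ initial means, on the order of $m \log(m/\delta)$.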

Practical Implications:

  • Over-parameterization is beneficial: Fitting the data with more components than the true mixture has lets Gradient EM escape the spurious local minima that trap the exact-parameterized model, leading to better GMM fits.
  • No explicit pruning needed: The algorithm naturally zeros out weights of unnecessary components.
  • Sample Complexity: The results extend to a finite-sample setting (an online variant of the algorithm that uses $N$ fresh samples per iteration) with polynomial sample complexity, stated in the paper's sample-complexity theorem. The total number of samples scales roughly as $\epsilon^{-5}$ to achieve a population loss of $\epsilon$.

Analytical Techniques and Their Significance

The proof employs a two-stage analysis:

  1. Global Convergence Stage: Shows that the loss $\mathcal{L}$ decreases below a threshold $\epsilon_0 = e^{-\Theta(\Delta^2)}$ (where $\Delta$ is the minimum separation between true means). This stage relies on showing that the gradient norm $\|\nabla_{\boldsymbol{\mu}}\mathcal{L}\|_F$ is sufficiently large whenever the loss is high.
  2. Local Convergence Stage: Shows that once $\mathcal{L} \le \epsilon_0$, the algorithm converges to an arbitrarily small error $\epsilon$. This phase requires a more delicate analysis of the gradient, establishing $\|\nabla_{\boldsymbol{\mu}}\mathcal{L}\|_F \gtrsim \mathcal{L}^{3/4}/\text{poly}(\dots)$.
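A standard calculation shows how a local bound of the form $\|\nabla_{\boldsymbol{\mu}}\mathcal{L}\|_F \gtrsim \mathcal{L}^{3/4}$ turns into the $\mathcal{O}(1/T^2)$ loss rate. This sketch assumes the loss is smooth in the means, so that a step with small enough $\eta$ decreases it by at least $\frac{\eta}{2}\|\nabla_{\boldsymbol{\mu}}\mathcal{L}\|_F^2$, treats the weight update as not increasing $\mathcal{L}$, and absorbs the polynomial factors into a constant $c > 0$; $\mathcal{L}_t$ denotes the loss at iteration $t$.

```latex
\begin{aligned}
\mathcal{L}_{t+1} &\le \mathcal{L}_t - \tfrac{\eta}{2}\,\|\nabla_{\boldsymbol{\mu}}\mathcal{L}_t\|_F^2
  \le \mathcal{L}_t - c\,\mathcal{L}_t^{3/2}
  = \mathcal{L}_t\bigl(1 - c\sqrt{\mathcal{L}_t}\bigr), \\
\frac{1}{\sqrt{\mathcal{L}_{t+1}}} &\ge \frac{1}{\sqrt{\mathcal{L}_t}}\bigl(1 - c\sqrt{\mathcal{L}_t}\bigr)^{-1/2}
  \ge \frac{1}{\sqrt{\mathcal{L}_t}}\Bigl(1 + \tfrac{c}{2}\sqrt{\mathcal{L}_t}\Bigr)
  = \frac{1}{\sqrt{\mathcal{L}_t}} + \frac{c}{2}
  \qquad \text{(using } (1-u)^{-1/2} \ge 1 + u/2 \text{ for } u \in [0,1)\text{)}, \\
\frac{1}{\sqrt{\mathcal{L}_T}} &\ge \frac{1}{\sqrt{\mathcal{L}_0}} + \frac{cT}{2} \ge \frac{cT}{2}
  \quad\Longrightarrow\quad \mathcal{L}_T \le \frac{4}{c^2 T^2} = \mathcal{O}(1/T^2).
\end{aligned}
```

The second line needs $c\sqrt{\mathcal{L}_t} < 1$; since $\mathcal{L}_{t+1} \ge 0$, the first line already forces $c\sqrt{\mathcal{L}_t} \le 1$, and equality would mean $\mathcal{L}_{t+1} = 0$.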

Novel analytical tools are introduced/adapted for GMM analysis:

  • Hermite Polynomials: Used to analyze the dynamics of Gradient EM, particularly to handle the interaction terms and the time-varying weights. The posterior probabilities $\psi_i(x)$ are expanded in Hermite polynomials, which are orthogonal with respect to the Gaussian measure. This allows an order-by-order decomposition and helps lower-bound terms like $\mathbb{E}_{x \sim \ell}[\|\tilde{\mu}_\ell(x)\|_2^2]$, where $x \sim \ell$ denotes a sample from the $\ell$-th ground-truth component and $\tilde{\mu}_\ell(x) = \sum_{i \in S_\ell} \psi_i(x)(x - \mu_\ell^*)$, with $S_\ell$ the group of model components associated with $\mu_\ell^*$.
  • Tensor Decomposition & Test Functions for Identifiability: To prove that a small KL loss implies parameter recovery (the paper's informal identifiability theorem), the paper connects KL divergence to tensor decomposition.
    • Test functions (specifically, Hermite tensors $g_k(x) = \langle \mathrm{He}_k(x), v^{\otimes k} \rangle$) are used to relate the KL divergence to differences in the moments of $p(x)$ and $p_*(x)$.
    • The KL divergence is lower-bounded by the squared difference of expectations: $\mathcal{L} \gtrsim (\mathbb{E}_{p_*}[g_k(x)] - \mathbb{E}_p[g_k(x)])^2$.
    • This difference in expectations translates into the norm of the difference between moment tensors: $\|\sum_i \pi_i^* (\mu_i^*)^{\otimes k} - \sum_i \pi_i \mu_i^{\otimes k}\|_2$.
    • Using a whitening matrix (derived from the second moment of $p_*$), this is converted into an orthogonal tensor decomposition problem, whose solution properties (even under perturbation, i.e., a small loss $\mathcal{L} > 0$) guarantee that the model parameters are close to the true parameters (or that the weights of redundant components go to zero).
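The identity behind the test-function step can be checked numerically. For a unit vector $v$, the contraction $\langle \mathrm{He}_k(x), v^{\otimes k} \rangle$ reduces to the scalar probabilists' Hermite polynomial $\mathrm{He}_k(\langle v, x \rangle)$, and since $\langle v, x \rangle \sim \mathcal{N}(\langle v, \mu \rangle, 1)$ when $x \sim \mathcal{N}(\mu, I_d)$, one has $\mathbb{E}[\mathrm{He}_k(\langle v, x\rangle)] = \langle v, \mu \rangle^k$. Hence $\mathbb{E}_{p_*}[g_k] - \mathbb{E}_p[g_k] = \langle \sum_i \pi_i^* (\mu_i^*)^{\otimes k} - \sum_i \pi_i \mu_i^{\otimes k}, v^{\otimes k} \rangle$. The sketch below (hypothetical helper names, using NumPy's `hermite_e` module) verifies the one-dimensional identity with Gauss-Hermite quadrature and evaluates the moment gap in closed form.

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermeval, hermegauss


def gaussian_expect_hermite(k, a, deg=60):
    """E_{z ~ N(a, 1)}[He_k(z)] by Gauss-Hermite_e quadrature (weight exp(-y^2 / 2))."""
    nodes, weights = hermegauss(deg)  # exact for polynomials of degree <= 2 * deg - 1
    coeffs = np.zeros(k + 1)
    coeffs[k] = 1.0  # coefficient vector selecting the single polynomial He_k
    # z = a + y with y ~ N(0, 1); the quadrature weights sum to sqrt(2 * pi).
    return float(np.sum(weights * hermeval(nodes + a, coeffs)) / math.sqrt(2 * math.pi))


def hermite_moment_gap(k, v, pis_star, mus_star, pis, mus):
    """E_{p_*}[g_k] - E_p[g_k] for g_k(x) = He_k(<v, x>) with v a unit vector.

    Uses E_{x ~ N(mu, I_d)}[He_k(<v, x>)] = <v, mu>^k, so the result equals
    <sum_i pi*_i mu*_i^{(x)k} - sum_i pi_i mu_i^{(x)k}, v^{(x)k}>.
    """
    return float(pis_star @ (mus_star @ v) ** k - pis @ (mus @ v) ** k)
```

A single-component example makes the link concrete: with $v = e_1$, a true mean $(2, 0)$ and a model mean $(1, 0)$, the $k = 3$ gap is $2^3 - 1^3 = 7$.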

Addressing Key Challenges:

  • Dynamic Weights: Unlike analyses that fix the weights, this work allows $\pi_i$ to change and potentially go to zero. The identifiability proof (via tensor decomposition) handles this by showing collective recovery of component groups.
  • Cross Terms: Interactions between different ground-truth components (e.g., $\mathbb{E}_{x \sim j}[\psi_i(x)]$ for $i \in S_\ell$, $j \neq \ell$) must be controlled. In the global phase, these are small ($\exp(-\Theta(\Delta^2))$). In the local phase, where $\mathcal{L}$ can be smaller than these cross terms, a more precise Taylor expansion and the identifiability results (showing parameter errors are $\sim \sqrt{\mathcal{L}}$) are used to control them.
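The $\exp(-\Theta(\Delta^2))$ size of the cross terms is easy to see in the simplest case, assumed here purely for illustration: two equal-weight components at distance $\Delta$, with the model means placed at the true means. For $x \sim \mathcal{N}(\mu_j^*, I_d)$ the posterior of the other component is $\psi_i(x) = \sigma(\Delta Z - \Delta^2/2)$ with $Z \sim \mathcal{N}(0,1)$ and $\sigma$ the logistic function, and splitting at $Z = \Delta/2$ gives $\mathbb{E}_{x \sim j}[\psi_i(x)] \le 2\Phi(-\Delta/2) \le e^{-\Delta^2/8}$. The sketch below computes the one-dimensional integral on a fine grid and compares it with that bound; the function names are hypothetical.

```python
import math

import numpy as np


def cross_term(delta, num_grid=400_001, z_max=12.0):
    """E_{x ~ N(mu_j, I)}[psi_i(x)] for two equal-weight unit-variance components.

    Reduces to the 1-D integral E_{Z ~ N(0,1)}[sigmoid(delta * Z - delta^2 / 2)].
    """
    z = np.linspace(-z_max, z_max, num_grid)
    dz = z[1] - z[0]
    logits = delta * z - 0.5 * delta**2
    sig = 0.5 * (1.0 + np.tanh(0.5 * logits))  # overflow-free form of the logistic sigmoid
    pdf = np.exp(-0.5 * z**2) / math.sqrt(2 * math.pi)
    return float(np.sum(sig * pdf) * dz)  # Riemann sum; the tails beyond z_max are negligible


def cross_term_bound(delta):
    """Upper bound 2 * Phi(-delta / 2) = erfc(delta / (2 sqrt 2)) <= exp(-delta^2 / 8)."""
    return math.erfc(delta / (2 * math.sqrt(2)))
```

At $\Delta = 0$ the two components coincide and the cross term is exactly $1/2$; it then falls off at a Gaussian rate in $\Delta$.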

Assumptions

The theoretical guarantees rely on standard but important assumptions:

  • Non-degeneracy: The ground truth means $\{\mu_i^*\}$ are linearly independent; equivalently, the second-moment matrix $M_2^* = \sum_{i=1}^m \pi_i^* \mu_i^* (\mu_i^*)^\top$ has rank $m$.
  • Boundedness: Norms of ground truth means are bounded, $D_{\min} \le \|\mu_i^*\|_2 \le D_{\max}$.
  • Well-separatedness: Ground truth means are well separated, $\Delta = \min_{i \neq j} \|\mu_i^* - \mu_j^*\|_2 \ge C \max\{\sqrt{\log(dnm/\dots)}, \sqrt{D_{\max}\sqrt{dn}}, \sqrt{d/\epsilon_{acc}}\}$. This is a strong separation condition, common in GMM theory.
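The non-degeneracy assumption is what makes the whitening step in the identifiability argument work. In the standard moment-method construction sketched below (an illustration, not necessarily the paper's exact construction), $W \in \mathbb{R}^{d \times m}$ is built from the top $m$ eigenpairs of $M_2^* = \sum_i \pi_i^* \mu_i^* (\mu_i^*)^\top$ so that $W^\top M_2^* W = I_m$. The vectors $y_i = \sqrt{\pi_i^*}\, W^\top \mu_i^*$ are then orthonormal, and the whitened third moment $\sum_i \pi_i^* (W^\top \mu_i^*)^{\otimes 3} = \sum_i (\pi_i^*)^{-1/2} y_i^{\otimes 3}$ is an orthogonal decomposition. If the means were linearly dependent, $M_2^*$ would have fewer than $m$ nonzero eigenvalues and $W$ could not be formed. Function names are hypothetical.

```python
import numpy as np


def whitening_matrix(pis, mus, m):
    """W (d x m) with W^T M2 W = I_m, where M2 = sum_i pi_i mu_i mu_i^T has rank m.

    pis: (m,) positive weights; mus: (m, d) array whose rows are the means.
    """
    M2 = (mus.T * pis) @ mus  # sum_i pi_i mu_i mu_i^T, shape (d, d)
    evals, evecs = np.linalg.eigh(M2)  # symmetric eigendecomposition, ascending eigenvalues
    top = np.argsort(evals)[::-1][:m]  # indices of the m largest eigenvalues
    return evecs[:, top] / np.sqrt(evals[top])


def whitened_third_moment(pis, mus, W):
    """sum_i pi_i (W^T mu_i)^{(x)3}, an orthogonally decomposable m x m x m tensor."""
    Z = mus @ W  # row i is W^T mu_i
    return np.einsum("i,ia,ib,ic->abc", pis, Z, Z, Z)
```

Contracting the whitened tensor twice with $y_i$ returns $(\pi_i^*)^{-1/2} y_i$, so the components appear as its tensor eigenvectors, which is what tensor power iteration exploits.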

Experimental Validation

The paper includes experiments (shown in its main figure) on a 5-component GMM ($m=5$) with models of size $n=5, 10, 15$.

  • The over-parameterized models ($n=10, 15$) converge to the true GMM.
  • The exact-parameterized model ($n=5$) gets stuck in a spurious local minimum, where one model component attempts to fit multiple true components.
  • In the over-parameterized cases, the weights $\pi_i$ of redundant components are driven to zero, supporting the "automatic pruning" claim.

In summary, this paper provides strong theoretical evidence for the practical success of using over-parameterized models with Gradient EM for learning GMMs. It demonstrates that sufficient over-parameterization helps navigate the complex optimization landscape and achieve global convergence, with the algorithm automatically identifying and down-weighting superfluous components. The analytical techniques connecting EM to Hermite polynomials and tensor decompositions are novel in this context and may find applications in analyzing other latent variable models.
