Modified Symmetric KL Divergence (MSKL)
- Modified Symmetric KL (MSKL) divergence is a tractable measure that combines forward KL with a learned proxy for reverse KL to address sample-based estimation challenges.
- It employs a normalizing flow as the main model and an energy-based proxy to facilitate joint optimization in a constrained, adaptive framework.
- Empirical results demonstrate robust training with stable convergence, effective mode recovery, and competitive performance in density estimation and image generation.
The Modified Symmetric Kullback-Leibler (MSKL) divergence is a divergence measure designed to enable tractable symmetric training of probabilistic models from samples, by combining forward Kullback-Leibler (KL) divergence and a learned proxy for the intractable reverse KL. MSKL is especially pertinent when the data-generating distribution is only available through samples, making direct computation of the reverse KL with respect to the true data distribution infeasible. Through the introduction of a flexible yet tractable proxy model, MSKL allows for both forward and reverse KL to be optimized jointly in a constrained, adaptive fashion. This framework provides a learned, data-driven alternative to fixed-weight symmetric divergences or adversarial min-max formulations in generative modeling, and has demonstrated broad empirical utility across density estimation, image generation, and simulation-based inference (Ben-Dov et al., 14 Nov 2025).
1. Mathematical Definition and Motivation
Let $\pi$ denote the target data distribution, $p_\theta$ the main generative model (parameterized, for example, as a normalizing flow), and $q_\psi$ a proxy (auxiliary) model. The classic symmetric divergence to minimize is the Jeffreys divergence:

$$D_{\mathrm{J}}(\pi, p_\theta) = D_{\mathrm{KL}}(\pi \,\Vert\, p_\theta) + D_{\mathrm{KL}}(p_\theta \,\Vert\, \pi).$$

For many applications, particularly in high dimensions, only samples from $\pi$ are accessible, so estimating the reverse KL $D_{\mathrm{KL}}(p_\theta \,\Vert\, \pi)$ is generally infeasible.
The MSKL divergence addresses this by introducing a proxy model $q_\psi$, resulting in the following divergence:

$$D_{\mathrm{MSKL}}(\pi, p_\theta; q_\psi) = D_{\mathrm{KL}}(\pi \,\Vert\, p_\theta) + D_{\mathrm{KL}}(p_\theta \,\Vert\, q_\psi).$$

In this construction, $q_\psi$ is trained both to fit $\pi$ and to serve as a tractable target for the reverse KL term from $p_\theta$. As $q_\psi$ is adaptively aligned to $\pi$, the second term becomes a faithful surrogate for the intractable $D_{\mathrm{KL}}(p_\theta \,\Vert\, \pi)$.
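Both terms can be estimated purely from samples. The following PyTorch sketch is illustrative rather than a reference implementation: the `flow` and `proxy` objects, their `log_prob` methods, and the flow's reparameterized `rsample` are assumed interfaces, not part of the original paper.

```python
import torch

def mskl_terms(flow, proxy, data_batch, num_model_samples=256):
    """Monte Carlo estimates of the two MSKL terms.

    The forward KL D(pi || p_theta) is estimated up to the (constant)
    entropy of pi, i.e., as the negative log-likelihood of the data.
    The reverse term D(p_theta || q_psi) uses reparameterized samples
    from the flow, so it is differentiable w.r.t. both models.
    """
    # E_pi[-log p_theta(x)]  (forward KL up to the additive constant H(pi))
    forward_kl = -flow.log_prob(data_batch).mean()

    # E_{p_theta}[log p_theta(y) - log q_psi(y)]  (reverse KL to the proxy)
    y = flow.rsample((num_model_samples,))
    reverse_kl = (flow.log_prob(y) - proxy.log_prob(y)).mean()

    return forward_kl, reverse_kl
```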
2. Model Architecture: Main Model and Proxy
The utility of MSKL derives from the complementary parameterizations of the main and proxy models:
- Main Model $p_\theta$: Implemented as a normalizing flow, ensuring that $\log p_\theta(x)$ and exact i.i.d. sampling are tractable. The primary optimization targets are the forward KL divergence to $\pi$ and the reverse KL to $q_\psi$.
- Proxy Model $q_\psi$: Implemented as an energy-based model, where $q_\psi(x) = \exp(f_\psi(x)) / Z_\psi$, allowing greater representational expressivity at the cost of normalization-constant estimation. The proxy is optimized both to fit $\pi$ through forward KL and to act as the reverse KL anchor for $p_\theta$.
By constraining $q_\psi$ to remain close to $\pi$, $D_{\mathrm{KL}}(p_\theta \,\Vert\, q_\psi)$ effectively acts as a practical approximation to $D_{\mathrm{KL}}(p_\theta \,\Vert\, \pi)$.
| Model | Parameterization | Role |
|---|---|---|
| Main model $p_\theta$ | Normalizing flow (NF) | Approximates $\pi$ and samples; minimizes $D_{\mathrm{KL}}(\pi \,\Vert\, p_\theta)$ and $D_{\mathrm{KL}}(p_\theta \,\Vert\, q_\psi)$ |
| Proxy model $q_\psi$ | Energy-based model (EBM) | Fits $\pi$ and provides the reverse-KL target $D_{\mathrm{KL}}(p_\theta \,\Vert\, q_\psi)$ |
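A minimal PyTorch sketch of the proxy parameterization follows; the network architecture and the assumed `flow` interface (`rsample`, `log_prob`) are illustrative choices, while the log-partition estimator mirrors the importance-sampling scheme described in Section 4.

```python
import math
import torch
import torch.nn as nn

class EnergyProxy(nn.Module):
    """Proxy q_psi(x) = exp(f_psi(x)) / Z_psi with a neural energy f_psi."""

    def __init__(self, dim, hidden=128):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 1),
        )

    def energy(self, x):
        """Unnormalized log-density f_psi(x)."""
        return self.f(x).squeeze(-1)

    def log_Z(self, flow, num_samples=512):
        """Importance-sampled estimate of log Z_psi with the flow as proposal:
        log Z ≈ LogSumExp_j(f_psi(y_j) - log p_theta(y_j)) - log M,  y_j ~ p_theta.
        """
        y = flow.rsample((num_samples,))
        log_w = self.energy(y) - flow.log_prob(y)
        return torch.logsumexp(log_w, dim=0) - math.log(num_samples)

    def log_prob(self, x, flow):
        """Normalized log q_psi(x) using the estimated log Z."""
        return self.energy(x) - self.log_Z(flow)
```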
3. Constrained Optimization Formulation
The MSKL training objective is formalized using constrained optimization to flexibly balance the contributions of each divergence term. Two principal formulations are introduced:
Proxy-Only Constrained Problem (P-Proxy):

$$\min_{\theta,\psi}\; D_{\mathrm{KL}}(\pi \,\Vert\, p_\theta) \quad \text{s.t.} \quad D_{\mathrm{KL}}(p_\theta \,\Vert\, q_\psi) \le \epsilon_r, \qquad D_{\mathrm{KL}}(\pi \,\Vert\, q_\psi) \le \epsilon_p.$$

Adaptive (Resilient) Symmetrization (P-DYN):

Introducing slack variables $u = (u_f, u_r, u_p) \ge 0$,

$$\min_{\theta,\psi,\,u \ge 0}\; h(u) \quad \text{s.t.} \quad D_{\mathrm{KL}}(\pi \,\Vert\, p_\theta) \le \epsilon_f + u_f, \quad D_{\mathrm{KL}}(p_\theta \,\Vert\, q_\psi) \le \epsilon_r + u_r, \quad D_{\mathrm{KL}}(\pi \,\Vert\, q_\psi) \le \epsilon_p + u_p,$$

where $h(u)$ is an increasing penalty on the slacks. This adaptive approach replaces fixed trade-off weights with learnable slack variables, dynamically adjusting the relative strictness of each divergence constraint throughout optimization.
Dual variables $\lambda = (\lambda_f, \lambda_r, \lambda_p, \lambda_h) \ge 0$ are introduced to form the Lagrangian, resulting in an empirical dual problem:

$$\max_{\lambda \ge 0}\; \min_{\theta,\psi,\,u \ge 0}\; \mathcal{L}(\theta, \psi, u, \lambda),$$

where $\mathcal{L}$ aggregates the main objective and the constraint violations.
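As a concrete illustration of how the Lagrangian aggregates these pieces, the sketch below assumes a quadratic penalty for $h(u)$ (one common choice, not confirmed by the source) and covers only the three KL constraints; the constraint associated with $\lambda_h$ is omitted.

```python
def lagrangian(divs, eps, u, lam):
    """Empirical Lagrangian L(theta, psi, u, lambda) for the P-DYN problem.

    divs: Monte Carlo estimates of the KL terms, keyed 'f', 'r', 'p'
    eps:  constraint levels epsilon_i
    u:    non-negative slack tensors u_i (learnable)
    lam:  non-negative dual tensors lambda_i (learnable)
    """
    h = 0.5 * sum(u[k] ** 2 for k in u)  # quadratic slack penalty (illustrative)
    violations = sum(lam[k] * (divs[k] - eps[k] - u[k]) for k in divs)
    return h + violations
```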
4. Training Algorithm
Optimization is performed by primal–dual gradient descent–ascent (GDA), alternating between parameter updates for via gradient descent and multipliers via ascent. The procedure respects non-negativity for slack and dual variables through projection:
```
Input: learning rates α_θ, α_ψ, α_u, α_λ
Initialize θ, ψ, u = {u_f, u_r, u_p} ≥ 0, λ = {λ_f, λ_r, λ_p, λ_h} ≥ 0
repeat
    # 1) Sample minibatch {x_i} from π and {y_j} from p_θ
    #    For EBMs, estimate log Z of q_ψ by importance sampling:
    #    y_j ∼ p_θ;  log Z ≈ LogSumExp_j(f_ψ(y_j) − log p_θ(y_j)) − log M
    # 2) Compute L(θ, ψ, u, λ) and its gradients
    θ ← θ − α_θ · ∇_θ L            # gradient descent on the main model
    ψ ← ψ − α_ψ · ∇_ψ L            # gradient descent on the proxy
    u ← max(0, u − α_u · ∇_u L)    # projected descent on slacks
    λ ← max(0, λ + α_λ · ∇_λ L)    # projected ascent on dual variables
until convergence
```
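Wiring the pseudocode into PyTorch, a single update might look like the sketch below; the `flow` and `proxy` interfaces follow the earlier sketches, and the quadratic slack penalty and three-constraint Lagrangian are the same illustrative assumptions (the $\lambda_h$ term is again omitted).

```python
import math
import torch

def gda_step(flow, proxy, u, lam, data_batch, eps, lrs, M=256):
    """One projected primal-dual GDA step (illustrative).

    u, lam: dicts of scalar tensors with requires_grad=True, keys 'f', 'r', 'p'
    eps:    dict of constraint levels; lrs: dict of the four learning rates
    """
    # Divergence estimates (forward/proxy terms up to the constant H(pi));
    # a real implementation might detach the log Z estimate.
    y = flow.rsample((M,))                                        # y_j ~ p_theta
    log_Z = torch.logsumexp(proxy.energy(y) - flow.log_prob(y), 0) - math.log(M)
    divs = {
        "f": -flow.log_prob(data_batch).mean(),                   # D(pi || p_theta) + const
        "r": (flow.log_prob(y) - proxy.energy(y) + log_Z).mean(), # D(p_theta || q_psi)
        "p": -(proxy.energy(data_batch) - log_Z).mean(),          # D(pi || q_psi) + const
    }
    # Lagrangian: quadratic slack penalty plus weighted constraint violations
    L = 0.5 * sum(u[k] ** 2 for k in u) \
        + sum(lam[k] * (divs[k] - eps[k] - u[k]) for k in divs)

    params = list(flow.parameters()) + list(proxy.parameters())
    for t in params + list(u.values()) + list(lam.values()):
        t.grad = None
    L.backward()

    with torch.no_grad():
        for p in flow.parameters():
            p -= lrs["theta"] * p.grad                            # descent on theta
        for p in proxy.parameters():
            p -= lrs["psi"] * p.grad                              # descent on psi
        for k in u:                                               # projected descent on slacks
            u[k].copy_((u[k] - lrs["u"] * u[k].grad).clamp_min(0.0))
        for k in lam:                                             # projected ascent on duals
            lam[k].copy_((lam[k] + lrs["lam"] * lam[k].grad).clamp_min(0.0))
```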
5. Theoretical Guarantees and Properties
- Duality Gap Bound: Under mild conditions (e.g., Lipschitz continuity and universal approximation properties of the model classes), the gap between the primal optimum $P^\star$ and the dual optimum $D^\star$ is bounded as $0 \le P^\star - D^\star \le C\,\nu$, where $\nu$ measures the total-variation approximation error of the parameterized model classes and the constant $C$ is finite.
- Gradient Forms:
  - $\nabla_\theta \mathcal{L}$ employs normalizing-flow score gradients and the reverse KL to $q_\psi$.
  - $\nabla_\psi \mathcal{L}$ incorporates samples from $\pi$ and $p_\theta$, with estimates of $\log Z_\psi$ obtained via importance sampling.
  - $\nabla_\lambda \mathcal{L}$ directly corresponds to constraint violations.
- Optimization Dynamics: Training stabilizes via joint cooperation between $p_\theta$ and $q_\psi$, avoiding the adversarial instability typical of min-max setups such as GANs.
- Convergence: Although the objectives are nonconvex and global convergence guarantees are absent, empirical results indicate robust and stable convergence in diverse settings.
6. Empirical Evaluation and Applications
The MSKL framework has been empirically benchmarked across the following regimes (Ben-Dov et al., 14 Nov 2025):
- Synthetic 2D GMM (40 components): Adaptive MSKL achieves lower and more stable test negative log-likelihood (NLL) than fixed-weight baselines, while maintaining approximate partition-function normalization for the EBM proxy.
- Structured 2D Data (rings, moons, grid, spiral): MSKL consistently recovers all modes and attains lower held-out NLL compared to normalizing flow only or fixed-weight symmetric penalties, demonstrating robust mode coverage.
- High-dimensional Image Latents (CelebA, 100-d CAE latent): MSKL achieves Fréchet Inception Distance (FID) comparable to baseline NFs (∼48), indicating preservation of generative sample quality through the NF-EBM collaboration.
- Simulation-Based Inference (two-moons, GMM SBI): Posterior samples become statistically indistinguishable from ground truth (as per C2ST ≈ 0.5), with fewer simulator calls than forward-KL-only baselines.
| Task | Baseline | MSKL Outcome |
|---|---|---|
| 2D GMM (40 comp.) | NF, fixed-weights | Lower NLL, stable normalization |
| Manifold 2D (rings, moons, grid, spiral) | NF, fixed-weights | All modes, lowest held-out NLL |
| CelebA 100-d latent (FID) | NF | Matches baseline FID (∼48) |
| Simulation-Based Inference (SBI) | NF | Statistically indistinguishable posteriors |
Across scenarios, the adaptive MSKL mechanism consistently yields more robust and stable results than fixed penalty weights or adversarial min-max methods, supporting its applicability to density estimation, image generation, and likelihood-free inference.
7. Significance and Practical Considerations
MSKL provides a systematic and tractable methodology for symmetrizing statistical divergences via learned adaptation, circumventing the intractability of reverse KL terms when the data distribution is accessible only through samples. The introduction of a proxy model, parameterized as a flexible EBM and trained in concert with a tractable generative model (NF), enables approximate Jeffreys divergence minimization without adversarial instability or brittle penalty weighting. The primal–dual GDA training procedure further grounds the approach in constrained optimization, offering theoretical guarantees on duality gap under reasonable assumptions and demonstrating practical stability across diverse experimental regimes (Ben-Dov et al., 14 Nov 2025).