ByPE-VAE: Bayesian Pseudocoresets in VAEs
- The paper introduces ByPE-VAE, which replaces a full-data mixture prior with a learned pseudocoreset prior to reduce computational cost while maintaining high performance.
- ByPE-VAE employs an alternating training procedure that dynamically updates both VAE parameters and pseudocoreset elements using efficient gradient-based KL minimization.
- Empirical results on datasets like MNIST, CIFAR-10, and CelebA demonstrate that ByPE-VAE improves density estimation, latent clustering, and training efficiency compared to traditional VAE variants.
ByPE-VAE (Bayesian Pseudocoresets Exemplar Variational Autoencoder) is a deep generative modeling framework that extends the standard variational autoencoder (VAE) architecture by introducing a data-dependent mixture prior based on a small, learned pseudocoreset, rather than the full dataset. By dynamically optimizing both the pseudocoreset elements and their mixture weights to closely match the ideal full-data-based mixture prior in Kullback-Leibler (KL) divergence, ByPE-VAE combines the expressive power of exemplar-based priors with significant computational efficiency and implicit regularization against overfitting. The approach demonstrates improvements across density estimation, unsupervised representation learning, and generative data augmentation tasks compared to traditional VAE variants and other advanced prior formulations (Ai et al., 2021).
1. Model Formulation and Theoretical Foundations
Given a sample $x$ from a dataset $X = \{x_n\}_{n=1}^N$ and latent variable $z$, the standard VAE maximizes the evidence lower bound (ELBO):
$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \mathrm{KL}\left(q_\phi(z \mid x) \,\|\, p(z)\right),$$
where $p(z)$ is typically the standard normal prior $\mathcal{N}(0, I)$.
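When the approximate posterior is a diagonal Gaussian and the prior is standard normal, the KL term of this ELBO has a well-known closed form; a minimal NumPy sketch with toy values (not from the paper):

```python
import numpy as np

def gaussian_kl_to_std_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dims."""
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)

# The KL vanishes exactly when the posterior equals the prior,
# and grows as the posterior mean moves away from zero.
print(gaussian_kl_to_std_normal(np.zeros(4), np.zeros(4)))  # 0.0
print(gaussian_kl_to_std_normal(np.ones(4), np.zeros(4)))   # 2.0
```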
Exemplar VAE (Norouzi et al., 2020) replaces the standard normal prior with a data-dependent mixture over all $N$ training exemplars:
$$p(z \mid X) = \frac{1}{N} \sum_{n=1}^{N} r_\phi(z \mid x_n),$$
where each component $r_\phi(z \mid x_n)$ is a Gaussian whose mean and variance are produced by a learned network applied to the exemplar $x_n$.
ByPE-VAE introduces a further innovation: it replaces the sum over all $N$ data points (computationally prohibitive for large $N$) by a mixture over a learned, weighted pseudocoreset $U = \{u_m\}_{m=1}^M$ with $M \ll N$ and nonnegative weights $w = (w_1, \ldots, w_M)$, $\sum_m w_m = 1$:
$$p(z \mid U, w) = \sum_{m=1}^{M} w_m \, r_\phi(z \mid u_m).$$
This pseudocoreset mixture prior is trained to minimize $\mathrm{KL}\left(p(z \mid U, w) \,\|\, p(z \mid X)\right)$, i.e., to approximate the full data-dependent mixture.
The explicit ELBO optimized in ByPE-VAE is (dropping additive constants):
$$\mathcal{L}(x) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \mathrm{KL}\left(q_\phi(z \mid x) \,\|\, p(z \mid U, w)\right).$$
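Evaluating such a pseudocoreset mixture prior inside the ELBO only requires the log-density of an $M$-component Gaussian mixture, which is cheap when $M$ is small. A hedged NumPy sketch, assuming isotropic components with a shared variance for illustration (in the model, the component means come from the learned network applied to the pseudopoints):

```python
import numpy as np

def log_gaussian(z, means, var):
    """Row-wise log density of isotropic Gaussians N(means[m], var * I) at a point z."""
    d = z.shape[-1]
    return -0.5 * (d * np.log(2.0 * np.pi * var)
                   + np.sum((z - means) ** 2, axis=-1) / var)

def log_mixture_prior(z, pseudo_means, weights, var=1.0):
    """log p(z | U, w) = log sum_m w_m N(z; mu_m, var I), computed stably via log-sum-exp."""
    comp = log_gaussian(z[None, :], pseudo_means, var) + np.log(weights)
    return np.logaddexp.reduce(comp)

# With a single component of weight 1, the mixture reduces to that Gaussian.
z = np.zeros(2)
mu = np.zeros((1, 2))
print(np.isclose(log_mixture_prior(z, mu, np.array([1.0])),
                 log_gaussian(z, mu, 1.0)[0]))  # True
```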
2. Pseudocoreset Construction and KL Minimization
The pseudocoreset is chosen to minimize the KL divergence between the pseudocoreset mixture prior and the (intractable) full-data mixture prior:
$$(U^*, w^*) = \arg\min_{U, w} \; \mathrm{KL}\left(p(z \mid U, w) \,\|\, p(z \mid X)\right).$$
The KL divergence admits a gradient representation via expectations and covariances under the mixture components. Writing $\ell(z) = \log p(z \mid U, w) - \log p(z \mid X)$ for the log-density ratio, one convenient form is:
- Gradient w.r.t. pseudocoreset points $u_m$:
$$\nabla_{u_m} \mathrm{KL} = w_m \, \mathrm{Cov}_{r_\phi(z \mid u_m)}\left[\nabla_{u_m} \log r_\phi(z \mid u_m), \; \ell(z)\right]$$
- Gradient w.r.t. weights $w_m$ (up to a constant absorbed by the simplex constraint):
$$\nabla_{w_m} \mathrm{KL} = \mathbb{E}_{r_\phi(z \mid u_m)}\left[\ell(z)\right]$$
Specializing to Gaussian components $r_\phi(z \mid u_m)$, these gradients are efficiently computed by backpropagating through the network that maps pseudopoints to component means.
In practice, stochastic gradients are formed by sampling $z$ from the coreset mixture and drawing minibatches from $X$ to estimate $\log p(z \mid X)$, enabling scalable KL minimization even for large $N$.
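The stochastic KL estimation described above can be sketched as a simple Monte Carlo routine: draw latents from the coreset mixture and average the log-density ratio against a (sub)sampled full-data mixture. Function names and the shared-variance Gaussian components are illustrative assumptions, not the paper's exact implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_mix(z, means, weights, var=1.0):
    """Stable log-density of an isotropic Gaussian mixture at a single point z."""
    d = means.shape[1]
    comp = (-0.5 * (d * np.log(2.0 * np.pi * var)
                    + np.sum((z[None, :] - means) ** 2, axis=-1) / var)
            + np.log(weights))
    return np.logaddexp.reduce(comp)

def mc_kl_estimate(core_means, core_w, data_means, n_samples=2000, var=1.0):
    """Monte Carlo estimate of KL(p_coreset || p_full-data): draw z from the
    coreset mixture, then average the log-density ratio between the two priors."""
    m, d = core_means.shape
    ks = rng.choice(m, size=n_samples, p=core_w)               # pick mixture components
    z = core_means[ks] + np.sqrt(var) * rng.standard_normal((n_samples, d))
    data_w = np.full(len(data_means), 1.0 / len(data_means))   # uniform full-data weights
    ratios = [log_mix(zi, core_means, core_w, var)
              - log_mix(zi, data_means, data_w, var) for zi in z]
    return float(np.mean(ratios))

core = np.array([[0.0, 0.0], [3.0, 3.0]])
w = np.array([0.5, 0.5])
print(mc_kl_estimate(core, w, core))        # 0.0 (identical mixtures)
print(mc_kl_estimate(core, w, core + 5.0))  # > 0 (mismatched priors)
```

In ByPE-VAE the pseudopoints live in data space and are pushed through the encoder network, so the per-sample ratio would be differentiated with respect to the pseudopoints by backpropagation rather than evaluated in closed form as here.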
3. Alternating Training Procedure
ByPE-VAE employs a two-step alternating optimization:
- VAE Parameter Step: With the pseudocoreset $(U, w)$ fixed, update the encoder and decoder parameters by maximizing the ELBO above over minibatches.
- Pseudocoreset Step: Periodically (empirically, an infrequent schedule of once every several epochs suffices), update both the pseudocoreset points and weights using stochastic estimates of the KL gradients described above.
Initialization uniformly samples the $M$ pseudopoints from the training data, with equal weights ($w_m = 1/M$). Weight updates are projected to maintain nonnegativity and the sum-to-one constraint. Overall complexity is dominated by the VAE step; the pseudocoreset step is amortized across epochs and remains cheap when $M$ is small.
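The weight projection mentioned above can be as simple as clipping to the nonnegative orthant and renormalizing; a minimal NumPy sketch (the paper's exact projection scheme may differ):

```python
import numpy as np

def project_weights(w):
    """After a gradient step, clip weights to be nonnegative and renormalize to sum to 1.
    Falls back to uniform weights in the degenerate all-clipped case."""
    w = np.maximum(w, 0.0)
    s = w.sum()
    if s == 0.0:
        return np.full_like(w, 1.0 / len(w))
    return w / s

w = project_weights(np.array([0.5, -0.1, 0.8]))
print(w)  # nonnegative, sums to 1
```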
4. Experimental Evaluation and Results
ByPE-VAE is evaluated on density estimation, representation learning, and generative data augmentation over several standard datasets: Dynamic MNIST, Fashion MNIST, CIFAR-10, and CelebA. Architectures include both MLP and CNN backbones, with the latent dimension and pseudocoreset size chosen per dataset (a larger configuration is used for CelebA).
Key empirical findings include:
- Density Estimation: On Dynamic MNIST, Fashion MNIST, and CIFAR-10, ByPE-VAE achieves the best test negative log-likelihood (NLL), outperforming VAE+Gaussian, VAE+VampPrior, and Exemplar VAE mixtures.
- Training Efficiency: ByPE-VAE cuts per-epoch training time on Dynamic MNIST from 35.5 s (Exemplar VAE) to 13.2 s, roughly a 2.7× speedup, since the prior mixes over $M$ pseudopoints rather than all $N$ exemplars.
- Latent Representations: t-SNE visualization of MNIST embeddings reveals tighter clusters and improved inter-class separation. kNN classification accuracy on the learned codes is consistently higher under ByPE-VAE across the tested numbers of neighbors $k$.
- Data Augmentation: On permutation-invariant MNIST with augmented samples, ByPE-VAE attains lower discriminative test error than the Exemplar VAE under both posterior-based and prior-based sample generation.
- Ablation Analyses: Performance remains stable as pseudocoreset update interval increases up to 50; ByPE-VAE outperforms Exemplar variants with equal coreset size.
Empirical Comparisons
| Method | Dyn MNIST (NLL) | Fash MNIST (NLL) | CIFAR-10 (NLL) | Training Time/Epoch (Dyn MNIST, s) |
|---|---|---|---|---|
| VAE+Gauss prior | 24.41 | 21.43 | 72.21 | — |
| VAE+VampPrior | 23.65 | 20.87 | 71.97 | 13.0 |
| VAE+Exemplar | 23.83 | 21.00 | 72.55 | 35.5 |
| ByPE-VAE | 23.61 | 20.85 | 71.91 | 13.2 |
This demonstrates that ByPE-VAE matches or exceeds the benchmark methods on both density estimation quality and training efficiency (Ai et al., 2021).
5. Analysis of Computational and Statistical Properties
Computational Efficiency: ByPE-VAE's computational savings result from replacing a full sum over exemplars in the prior with a sum over learned pseudopoints. Because pseudocoreset updates are amortized (run only once every several epochs), the additional cost is negligible relative to the overall VAE training cycle.
Regularization and Overfitting Avoidance: Pseudocoreset construction acts as an implicit regularizer. Only a small set of points and their weights are used to assemble the mixture prior, mitigating risks of overfitting or memorization associated with full-dataset-dependent priors.
Optimization and Sensitivity: The KL-based pseudocoreset adaptation is sensitive to the coreset size, update interval, and learning rates, as well as to initialization; if the coreset is too small or poorly initialized, it may underfit the true data distribution.
6. Limitations and Possible Extensions
While ByPE-VAE achieves substantial improvements, limitations remain:
- Coreset Update Overhead: The requisite stochastic updates for pseudocoreset points and weights still incur nontrivial cost, scaling with the coreset size and the number of gradient steps and samples used per update.
- Hyperparameter Sensitivity: Careful selection of the coreset size, update interval, and learning rates is necessary to avoid degraded performance.
- Potential Underfitting: If the coreset size is too small, the induced prior may not adequately match the full-data mixture, especially for complex data or highly multi-modal latent structures.
Proposed extensions and areas for future work include:
- Amortized pseudocoresets: replacing static pseudopoints with a generative model that produces them.
- Hierarchical pseudocoreset priors: employing clusters within clusters to further enhance prior expressiveness.
- Application to advanced likelihood models: incorporating pseudocoresets into normalizing flow architectures.
- Dynamic coreset adjustment: growing or pruning the pseudocoreset dynamically during training to maintain model flexibility.
A plausible implication is that pseudocoresets could be adapted for other Bayesian deep learning settings that benefit from expressive, computationally efficient data-dependent priors.
7. Context and Broader Impact in Variational Learning
ByPE-VAE demonstrates that judiciously optimizing a lightweight, learnable pseudocoreset can enable practical approximation of highly flexible, data-dependent mixture priors, expanding the toolkit available for expressive probabilistic modeling without incurring prohibitive costs. Its success highlights the tradeoff between model complexity, inference tractability, and regularization—a central concern for variational learning frameworks. The methodology is broadly applicable across density estimation, unsupervised embedding, and generative modeling, and serves as a foundation for ongoing research into scalable, highly expressive Bayesian generative models (Ai et al., 2021).