Wasserstein Metric-based Dataset Distillation
- WMDD is a dataset distillation method that leverages the Wasserstein metric from optimal transport to closely match the geometric structure of true data distributions.
- It employs both feature-space and latent-space formulations, using barycenter computation and quantization techniques to optimize synthetic sample generation.
- Extensive experiments on high-resolution datasets showcase WMDD's superior sample efficiency and accuracy compared to traditional distillation approaches.
Wasserstein Metric-based Dataset Distillation (WMDD) refers to a family of dataset distillation methods that employ objectives derived from optimal transport theory, primarily leveraging the Wasserstein metric and its generalizations to directly match the geometric structure of the true data distribution in either feature or latent space. These approaches seek to construct compact synthetic datasets on which a model, when trained, exhibits efficacy comparable to training on the full dataset, with a primary focus on high-fidelity distribution matching and sample efficiency across challenging domains such as high-resolution image recognition.
1. Mathematical Foundations: Optimal Transport and the Wasserstein Metric
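The central mathematical underpinning of WMDD is the $p$-Wasserstein distance $W_p$, which quantifies the minimum “effort” required to transport mass from one probability measure to another over a metric space. For two distributions $\mu, \nu$ on a metric space $(\mathcal{X}, d)$ with cost $c(x, y) = d(x, y)^p$, the definition is:
$$W_p(\mu, \nu) = \left( \inf_{\pi \in \Pi(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{X}} d(x, y)^p \, \mathrm{d}\pi(x, y) \right)^{1/p},$$
where $\Pi(\mu, \nu)$ denotes all couplings with marginals $\mu$ and $\nu$ (Liu et al., 2023). This distance reflects the intrinsic geometry of data distributions, unlike maximum mean discrepancy or other non-geometric divergences.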
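The definition becomes concrete in one dimension. There, the optimal coupling between two uniform empirical measures of equal size pairs the sorted samples. The Python sketch below assumes uniform weights and equal sample sizes. `wasserstein_1d` is a hypothetical helper for illustration, not code from the cited work.

```python
import numpy as np

def wasserstein_1d(x, y, p=2):
    """W_p between two uniform empirical measures on the real line that
    have the same number of samples.

    In 1-D the optimal coupling is monotone: it pairs the i-th smallest
    point of x with the i-th smallest point of y.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))

# Shifting a distribution by 3 moves every unit of mass a distance of 3.
print(wasserstein_1d([0.0, 1.0, 2.0], [3.0, 4.0, 5.0]))  # 3.0
```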
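In multi-class or latent-embedding scenarios, the concept of the Wasserstein barycenter arises. Given distributions $\mu_1, \dots, \mu_K$ and weights $\lambda_k \ge 0$ with $\sum_{k=1}^{K} \lambda_k = 1$, the barycenter $\bar{\mu}$ minimizes the weighted average of $W_p^p(\nu, \mu_k)$. Explicitly:
$$\bar{\mu} = \arg\min_{\nu \in \mathcal{P}_p(\mathcal{X})} \sum_{k=1}^{K} \lambda_k \, W_p^p(\nu, \mu_k).$$
This centroid-like object enables synthetic features to approximate the global arrangement of true data (Liu et al., 2023).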
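The barycenter also has a closed form in one dimension, with $p = 2$. The barycenter's quantile function is the $\lambda$-weighted average of the input quantile functions. The minimal sketch below keeps the previous assumptions of 1-D, equal-size uniform samples. `barycenter_1d` is a hypothetical name. Higher-dimensional features, as used in WMDD, would instead need an iterative free-support barycenter solver.

```python
import numpy as np

def barycenter_1d(samples, weights):
    """W_2 barycenter of K uniform empirical measures on the real line,
    each given by the same number n of samples.

    The barycenter's quantile function is the weighted average of the input
    quantile functions. For equal-size samples, that means averaging the
    sorted samples position by position.
    """
    sorted_samples = np.sort(np.asarray(samples, dtype=float), axis=1)  # (K, n)
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # normalise the barycentric weights lambda_k
    return w @ sorted_samples  # (n,) support points, each with mass 1/n

mu1 = [0.0, 1.0, 2.0]
mu2 = [10.0, 11.0, 12.0]
print(barycenter_1d([mu1, mu2], [0.5, 0.5]))  # [5. 6. 7.]
```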
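Extensions to higher-order structures, such as probability distributions over distributions, have led to the Wasserstein-over-Wasserstein (WoW) metric. It operates on the space $\mathcal{P}_2(\mathcal{P}_2(\mathbb{R}^d))$ and is defined for $\mathbb{P}, \mathbb{Q} \in \mathcal{P}_2(\mathcal{P}_2(\mathbb{R}^d))$ as:
$$\mathrm{WoW}(\mathbb{P}, \mathbb{Q}) = \left( \inf_{\Gamma \in \Pi(\mathbb{P}, \mathbb{Q})} \int W_2^2(\mu, \nu) \, \mathrm{d}\Gamma(\mu, \nu) \right)^{1/2}.$$
This metric enables analysis of, and gradient flows on, datasets described as mixtures of class-conditional measures (Bonet et al., 9 Jun 2025).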
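To make the WoW definition concrete, the sketch below treats a labeled dataset as a uniform measure over its class-conditional distributions. It assumes:

- 1-D classes with equal sample sizes, so the inner $W_2$ reduces to sorting;
- the same number of classes in both datasets, so the outer problem reduces to a search over class matchings.

The function names are hypothetical. The brute-force search is only viable for a handful of classes.

```python
import itertools
import numpy as np

def w2_1d(x, y):
    """W_2 between equal-size uniform empirical measures on the real line."""
    x, y = np.sort(x), np.sort(y)
    return float(np.sqrt(np.mean((x - y) ** 2)))

def wow_distance(dataset_p, dataset_q):
    """WoW distance between two datasets, each viewed as a uniform measure
    over its K class-conditional measures (lists of 1-D samples).

    Between two uniform measures with K atoms each, an optimal coupling can be
    taken to be a permutation, so all K! class matchings are enumerated.
    """
    k = len(dataset_p)
    if len(dataset_q) != k:
        raise ValueError("this sketch assumes the same number of classes")
    # cost[i, j] = W_2^2 between class i of P and class j of Q
    cost = np.array([[w2_1d(np.asarray(p, float), np.asarray(q, float)) ** 2
                      for q in dataset_q] for p in dataset_p])
    best = min(sum(cost[i, perm[i]] for i in range(k))
               for perm in itertools.permutations(range(k)))
    return float(np.sqrt(best / k))  # each class carries mass 1/k

d1 = [[0.0, 1.0], [10.0, 11.0]]
d2 = [[12.0, 13.0], [2.0, 3.0]]  # shifted copies of d1's classes, swapped order
print(wow_distance(d1, d2))  # 2.0
```

In this example, WoW matches each class of `d1` to its nearby shifted counterpart in `d2` (a cost of 2 per point). It does not pair classes by list index.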
2. Loss Formulations and Optimization in WMDD
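WMDD loss functions distinctively incorporate the Wasserstein geometry. The archetypal feature-space approach (Liu et al., 2023) considers a synthetic set $\tilde{\mathcal{S}}_k = \{\tilde{x}_{k,i}\}_{i=1}^{m}$ per class $k$, with barycenter features $\{b_{k,i}\}_{i=1}^{m}$, and minimizes:
$$\mathcal{L}_k = \sum_{i=1}^{m} \big\| f(\tilde{x}_{k,i}) - b_{k,i} \big\|_2^2 + \lambda_{\mathrm{BN}} \sum_{l} \Big( \big\| \mu_l(\tilde{\mathcal{S}}_k) - \mu_{l,k} \big\|_2 + \big\| \sigma_l^2(\tilde{\mathcal{S}}_k) - \sigma_{l,k}^2 \big\|_2 \Big).$$
Here, $f$ is a feature extractor, and $\mu_l(\cdot)$ and $\sigma_l^2(\cdot)$ are the batch mean and variance at BatchNorm layer $l$. The targets $\mu_{l,k}$ and $\sigma_{l,k}^2$ are the corresponding per-class, per-layer BatchNorm statistics of the real data. Together, the two terms ensure that both the global arrangement and the intra-class variation are matched.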
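The per-class objective above can be evaluated with a short NumPy sketch. It computes the loss value only, under these assumptions:

- synthetic image $i$ is paired with barycenter atom $b_{k,i}$ by index;
- BatchNorm statistics arrive as precomputed (mean, variance) arrays.

`wmdd_class_loss` and `lam_bn` are hypothetical names. In practice the loss would be written in an autodiff framework so that gradients reach the synthetic images.

```python
import numpy as np

def wmdd_class_loss(synth_feats, bary_feats, synth_bn, real_bn, lam_bn=1.0):
    """Feature-space WMDD-style loss for a single class (illustrative sketch).

    synth_feats : (m, D) features f(x~_i) of the m synthetic images of this class.
    bary_feats  : (m, D) barycenter support points; row i is the target of image i.
    synth_bn, real_bn : lists of (mean, var) pairs, one per BatchNorm layer l,
                        for the synthetic batch and for the real images of the class.
    lam_bn      : weight of the BatchNorm term (hypothetical default).
    """
    # Barycenter alignment: squared distance of each synthetic feature to its atom.
    align = float(np.sum((synth_feats - bary_feats) ** 2))
    # Per-class, per-layer matching of BatchNorm means and variances.
    bn = sum(np.linalg.norm(mu_s - mu_r) + np.linalg.norm(var_s - var_r)
             for (mu_s, var_s), (mu_r, var_r) in zip(synth_bn, real_bn))
    return align + lam_bn * float(bn)

feats = np.array([[0.0, 0.0], [1.0, 1.0]])
bary = np.array([[0.0, 1.0], [1.0, 1.0]])
stats = [(np.zeros(3), np.ones(3))]  # identical BN statistics, so the BN term is 0
print(wmdd_class_loss(feats, bary, stats, stats))  # 1.0
```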
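In the latent-quantization formulation (Tan et al., 13 Jan 2025), the distilled measure $\mu_m$ in latent space approximates the true latent measure $\mu$ via optimal quantization:
$$\mu_m^{\star} = \sum_{i=1}^{m} w_i \, \delta_{z_i^{\star}}, \qquad (z_1^{\star}, \dots, z_m^{\star}) \in \arg\min_{z_1, \dots, z_m} \mathbb{E}_{Z \sim \mu} \Big[ \min_{1 \le i \le m} \| Z - z_i \|_2^2 \Big],$$
where $w_i = \mu(C_i)$ is the mass of the Voronoi cell $C_i$ of $z_i^{\star}$. The minimizer places the support points so as to minimize $W_2(\mu, \mu_m)$ over all measures supported on at most $m$ atoms. A decoder $D$ then generates the distilled images by pushing the quantized latent measure forward to image space as $D_{\#} \mu_m^{\star}$.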
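The quantization step can be approximated with Lloyd's algorithm (batch k-means), a standard way to compute an approximate optimal quantizer. CLVQ is its stochastic-gradient counterpart. The sketch assumes the latent codes are already available as an (n, d) array. `optimal_quantization` is a hypothetical helper, not the authors' implementation.

It returns the atoms $z_i$, the Voronoi weights $w_i$ and the distortion. Once the weights equal the cell masses, the distortion equals $W_2^2$ between the empirical latent measure and $\mu_m$.

```python
import numpy as np

def optimal_quantization(z, m, n_iter=50, seed=0):
    """Approximate the m-point optimal quantizer of the empirical measure of z.

    Lloyd iterations alternate Voronoi assignment and centroid updates; no step
    increases the distortion E[min_i ||Z - z_i||^2].
    Returns atoms (m, d), Voronoi weights (m,), and the final distortion.
    """
    rng = np.random.default_rng(seed)
    atoms = z[rng.choice(len(z), size=m, replace=False)].copy()
    for _ in range(n_iter):
        d2 = ((z[:, None, :] - atoms[None, :, :]) ** 2).sum(-1)  # (n, m)
        assign = d2.argmin(1)
        for i in range(m):
            members = z[assign == i]
            if len(members) > 0:  # leave an empty cell's atom where it is
                atoms[i] = members.mean(0)
    d2 = ((z[:, None, :] - atoms[None, :, :]) ** 2).sum(-1)
    assign = d2.argmin(1)
    weights = np.bincount(assign, minlength=m) / len(z)  # masses w_i = mu(C_i)
    distortion = float(d2.min(1).mean())                 # = W_2^2(mu, mu_m)
    return atoms, weights, distortion

rng = np.random.default_rng(1)
z = np.concatenate([rng.normal(-5.0, 0.1, (100, 2)), rng.normal(5.0, 0.1, (300, 2))])
atoms, w, err = optimal_quantization(z, 2)
```

In this usage example, the two atoms land near the two latent clusters, and their weights reflect the 100/300 imbalance. Uniform weights would discard this information.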
Alternative functionals, such as the maximum mean discrepancy (MMD) lifted to the space of measures via Sliced-Wasserstein kernels, allow WMDD-style distillation to be cast as a Wasserstein-over-Wasserstein gradient flow in the space of random measures. This gives finer control over both intra-class and inter-class alignment (Bonet et al., 9 Jun 2025).
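The MMD objective with a sliced-Wasserstein kernel can be evaluated directly on two datasets, each represented as a list of class point clouds. The sketch below makes three assumptions:

- the Riesz-type kernel $k(\mu, \nu) = -SW_2(\mu, \nu)$, one choice consistent with the Riesz-SW kernels mentioned in the experiments;
- point clouds of equal size;
- a fixed set of random projection directions.

It computes only the objective, not the WoW gradient flow. All names are hypothetical.

```python
import numpy as np

def sliced_w2(x, y, thetas):
    """Monte Carlo sliced W_2 between equal-size point clouds x, y of shape (n, d)."""
    # Project onto each direction, then use the sorted (1-D optimal) matching.
    px = np.sort(x @ thetas.T, axis=0)  # (n, L)
    py = np.sort(y @ thetas.T, axis=0)
    return float(np.sqrt(np.mean((px - py) ** 2)))

def mmd2_riesz_sw(dataset_p, dataset_q, n_proj=64, seed=0):
    """Squared MMD (V-statistic) between two datasets, i.e. lists of (n, d)
    class point clouds, with kernel k(mu, nu) = -SW_2(mu, nu)."""
    d = dataset_p[0].shape[1]
    rng = np.random.default_rng(seed)
    thetas = rng.normal(size=(n_proj, d))
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)  # unit directions

    def mean_sw(a, b):
        return np.mean([sliced_w2(u, v, thetas) for u in a for v in b])

    # MMD^2 = E k(P,P') + E k(Q,Q') - 2 E k(P,Q), with k = -SW_2
    return float(2 * mean_sw(dataset_p, dataset_q)
                 - mean_sw(dataset_p, dataset_p) - mean_sw(dataset_q, dataset_q))
```

Because the same projections are used for every pair, the estimate is zero for identical datasets and positive otherwise.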
3. Algorithmic Procedures for WMDD
A generalized algorithmic archetype for WMDD methods includes:
- Feature or Latent Extraction: Obtain embeddings for all real data, either as features from pretrained classifiers (Liu et al., 2023) or as latent codes from autoencoders/diffusion encoders (Tan et al., 13 Jan 2025).
- Wasserstein Barycenter or Quantization: Compute classwise barycenter features, or perform k-means or competitive learning vector quantization (CLVQ) in latent space, to define the atoms of the synthetic distribution.
- Synthetic Sample Optimization:
  - In feature-based WMDD, initialize synthetic images and update them by gradient descent to minimize the barycenter-alignment and BatchNorm losses, propagating gradients through the feature extractor (Liu et al., 2023).
  - In latent-space WMDD, generate synthetic images by decoding the quantized latent points. Each image's weight is the fraction of real latents assigned to its cluster (Tan et al., 13 Jan 2025).
- Gradient Flow Over Measure Spaces: In WoW-based methods, update atomic representatives of each class's measure by backpropagating discrete Wasserstein or MMD-SW gradients (Bonet et al., 9 Jun 2025).
- Student Training: Train learners solely on the synthetic set, optionally using soft labels from a teacher (Tan et al., 13 Jan 2025).
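The feature-based steps above can be sketched end to end for a single class in a deliberately simplified setting. The sketch makes three simplifications:

- the feature extractor is linear, $f(x) = Wx$, so the gradient $2W^\top(Wx_i - b_i)$ is available in closed form;
- the barycenter atoms are approximated by k-means centroids of the real features;
- the BatchNorm term is omitted.

`distill_class` and its parameters are hypothetical. A real pipeline would use a pretrained network and automatic differentiation.

```python
import numpy as np

def distill_class(real_x, m, feat_matrix, steps=500, lr=0.05, seed=0):
    """Toy feature-space distillation for one class.

    Hypothetical simplifications: linear extractor f(x) = feat_matrix @ x,
    k-means centroids as barycenter atoms, no BatchNorm term.
    Returns synthetic samples (m, d), atoms (m, D), and the final alignment loss.
    """
    rng = np.random.default_rng(seed)
    feats = real_x @ feat_matrix.T                        # 1. feature extraction
    bary = feats[rng.choice(len(feats), m, replace=False)].copy()
    for _ in range(20):                                   # 2. k-means atoms
        assign = ((feats[:, None] - bary[None]) ** 2).sum(-1).argmin(1)
        for i in range(m):
            if np.any(assign == i):
                bary[i] = feats[assign == i].mean(0)
    synth = rng.normal(size=(m, real_x.shape[1]))         # 3. random init
    for _ in range(steps):                                # 4. gradient descent
        resid = synth @ feat_matrix.T - bary              # f(x~_i) - b_i
        synth -= lr * 2 * resid @ feat_matrix             # grad of sum ||f(x~_i) - b_i||^2
    loss = float(((synth @ feat_matrix.T - bary) ** 2).sum())
    return synth, bary, loss
```

For example, `distill_class(real_x, 5, np.eye(8)[:4])` returns five 8-dimensional synthetic samples. Their first four coordinates converge to the k-means atoms of the real data. The step size must stay below $1/\lambda_{\max}(W^\top W)$ for plain gradient descent to be stable.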
Efficient implementation relies on specialized solvers for OT (e.g., entropic Sinkhorn (Liu et al., 2023)), cluster quantization, or automatic differentiation for measure operations in WoW flows (Bonet et al., 9 Jun 2025).
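For reference, the entropic Sinkhorn iteration alternately rescales the rows and columns of the Gibbs kernel $K = e^{-C/\varepsilon}$ until the transport plan's marginals match. This is the textbook algorithm, not the cited authors' solver, and the parameter defaults are assumptions. For small $\varepsilon$, a log-domain variant, or a library such as POT, is numerically safer.

```python
import numpy as np

def sinkhorn(a, b, cost, eps=0.05, n_iter=1000):
    """Entropic OT plan P = diag(u) K diag(v) with Gibbs kernel K = exp(-cost / eps).

    a, b : source and target weight vectors (each summing to 1).
    cost : (len(a), len(b)) ground-cost matrix.
    """
    K = np.exp(-cost / eps)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # rescale rows toward marginal a
        v = b / (K.T @ u)  # rescale columns toward marginal b
    return u[:, None] * K * v[None, :]

x = np.array([0.0, 1.0, 2.0])
y = np.array([0.5, 1.5])
P = sinkhorn(np.full(3, 1 / 3), np.full(2, 1 / 2), (x[:, None] - y[None, :]) ** 2)
```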
4. Theoretical Guarantees and Consistency
WMDD establishes rigorous convergence and approximation guarantees under several frameworks:
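- Finite Support Wasserstein Approximation: Optimal quantization in latent space is equivalent to minimizing $W_2(\mu, \nu)$ over measures $\nu$ supported on at most $m$ atoms. Consistency results show that the induced empirical risk and gradient statistics on the distilled set converge to those of the full distribution as $m$ grows. The convergence follows the optimal quantization rate, which for measures on $\mathbb{R}^d$ scales as $\mathcal{O}(m^{-1/d})$ (Tan et al., 13 Jan 2025).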
- Diffusion-based Propagation: Under latent diffusion priors, the pushforward of the quantized latent measure contracts errors under reverse SDE dynamics, maintaining alignment with the target distribution over the generative process (Tan et al., 13 Jan 2025).
- Gradient Flows in Random-Measure Spaces: The WoW framework provides a formal differential structure and justifies that time-discretized flows decrease the objective (MMD with Sliced-Wasserstein kernels), producing synthetic sets that align with the true mixture (Bonet et al., 9 Jun 2025).
These theoretical structures distinguish Wasserstein-based formulations from earlier MMD-based or purely heuristic distillation schemes.
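As a worked instance of a quantization rate, take $\mu$ uniform on $[0, 1]$, so $d = 1$. Its optimal $m$-point quantizer places atoms at the cell midpoints $z_i = (2i - 1)/(2m)$ with weights $1/m$, and the error follows from one integral:
$$W_2^2(\mu, \mu_m) = \sum_{i=1}^{m} \int_{(i-1)/m}^{i/m} (t - z_i)^2 \, \mathrm{d}t = m \int_{-1/(2m)}^{1/(2m)} s^2 \, \mathrm{d}s = m \cdot \frac{2}{3} \left( \frac{1}{2m} \right)^3 = \frac{1}{12 m^2}.$$
Hence $W_2(\mu, \mu_m) = 1/(2\sqrt{3}\, m)$ decays like $m^{-1}$. This matches the classical $m^{-1/d}$ rate of optimal quantization (Zador's theorem) for $d = 1$. The example is illustrative and is not taken from the cited work.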
5. Experimental Results and Empirical Comparison
Empirical validation across multiple high-resolution datasets demonstrates the efficacy and efficiency of WMDD:
- ImageNette, TinyImageNet, ImageNet-1K: WMDD consistently outperforms MMD-based methods (e.g., DM), trajectory matching (e.g., MTT), and the previous state of the art, SRe²L. It does so in both the low (1–10 images per class, IPC) and moderate (50–100 IPC) synthetic-data regimes (Liu et al., 2023, Tan et al., 13 Jan 2025).
- Quantitative Benchmarks: For example, on ImageNet-1K with 50 IPC, WMDD reaches 57.6% top-1 accuracy versus 52.8% for SRe²L, with similar trends on other datasets (Liu et al., 2023).
- Ablations: Both barycenter (feature) alignment and per-class BatchNorm matching are critical; omitting either sharply reduces downstream student accuracy (Liu et al., 2023).
- WoW Flows and Domain Adaptation: On toy and domain adaptation benchmarks (MNIST/Fashion-MNIST, SVHN/CIFAR10), Wasserstein-over-Wasserstein flows using Riesz-Sliced-Wasserstein kernels form semantically correct clusters efficiently, outperforming prior OTDD and product-MMD methods (Bonet et al., 9 Jun 2025).
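- Complexity: All variants scale linearly in $n$ (real data size) and have constant or sublinear dependence on $m$ (number of synthetic samples). The latent-space variant requires no backpropagation through the decoder. The feature-space variant backpropagates only through a fixed, pretrained feature extractor rather than through a teacher's training trajectory (Liu et al., 2023, Tan et al., 13 Jan 2025).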
| Dataset & Setting | Method | Top-1 Accuracy (%) |
|---|---|---|
| ImageNet-1K, 50 IPC | SRe²L | 52.8 |
| ImageNet-1K, 50 IPC | WMDD | 57.6 |
| ImageNette, 10 IPC | SRe²L | 54.2 |
| ImageNette, 10 IPC | WMDD | 64.8 |
6. Practical Considerations, Scalability, and Extensions
WMDD methods are designed for scalability and practical deployment on large-scale problems:
- Computation: Free-support barycenter computation and latent k-means are fast and GPU-batchable. The overall pipeline is feasible for full ImageNet with 2000–5000 distillation steps (Liu et al., 2023, Tan et al., 13 Jan 2025).
- Memory: WMDD's cost is dominated by the synthetic set size (number of support points) rather than by the size of the original data. This is because backpropagation through the entire real dataset is avoided and the OT/quantization steps are memory-efficient (Tan et al., 13 Jan 2025).
- Feature Extractor/Autoencoder Choice: The geometric quality and semantic richness of distilled data depend on the feature or latent encoder. Pretrained ResNet-18 features are common, but deeper or more discriminative models can further improve the coverage of barycenter features or latent codes (Liu et al., 2023).
- BatchNorm Statistics: Matching per-class, not global, BN statistics preserves intra-class diversity and prevents collapse of synthetic distributions (Liu et al., 2023).
- Kernel Choices in WoW: Sliced-Wasserstein (SW) kernels inject geometric sensitivity into MMD objectives, offering a balance between computational tractability and OT faithfulness (Bonet et al., 9 Jun 2025).
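Per-class BatchNorm targets, as described above, can be computed by grouping real activations by label before pooling over the batch and spatial axes, as BatchNorm does. The NumPy sketch assumes activations in (N, C, H, W) layout and uses the biased variance. `per_class_bn_stats` is a hypothetical helper.

```python
import numpy as np

def per_class_bn_stats(acts, labels):
    """Per-class, per-channel mean and (biased) variance of activations.

    acts   : (N, C, H, W) activations entering one BatchNorm layer.
    labels : (N,) integer class labels.
    Returns {class: (mean (C,), var (C,))}, pooled over batch and spatial axes.
    """
    stats = {}
    for k in np.unique(labels):
        a = acts[labels == k]
        stats[int(k)] = (a.mean(axis=(0, 2, 3)), a.var(axis=(0, 2, 3)))
    return stats

# Two constant classes: zero within-class variance, but global variance 1.
acts = np.concatenate([np.zeros((4, 1, 2, 2)), np.full((4, 1, 2, 2), 2.0)])
labels = np.array([0] * 4 + [1] * 4)
stats = per_class_bn_stats(acts, labels)  # class 1: mean 2, variance 0
global_var = acts.var(axis=(0, 2, 3))     # 1: between-class spread folded in
```

The example shows the difference between the two kinds of target. Global statistics fold the between-class spread into the variance target, whereas per-class statistics describe each class's own spread.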
7. Extensions and Connections to Related Methods
WMDD unifies and generalizes several dataset distillation and distribution-matching paradigms:
- Disentangled Methods: Existing distribution-matching techniques, such as D4M, can be interpreted as uniform-weight specializations of WMDD's optimal-quantization or barycenter objectives (Tan et al., 13 Jan 2025).
- Gradient Flows and Random Measures: The introduction of gradient flows in $\mathcal{P}_2(\mathcal{P}_2(\mathbb{R}^d))$ opens paths to nonparametric, fully measure-theoretic approaches to dataset condensation, with applications in transfer learning, domain adaptation, and generative modeling (Bonet et al., 9 Jun 2025).
- Theoretical Linkages: The pushforward quantization consistency analysis links WMDD to optimal transport theory, quantization rates, and diffusion priors, distinguishing it from maximum-mean-discrepancy baselines (Tan et al., 13 Jan 2025).
- Empirical Performance: WMDD has achieved state-of-the-art results in high-resolution, multi-class, and cross-architecture distillation with robust generalization to deep and transformer-based learners (Liu et al., 2023).
Wasserstein Metric-based Dataset Distillation thus constitutes a rigorous, geometrically principled, and computationally efficient framework for dataset condensation, setting benchmarks for both theoretical understanding and empirical performance across modern machine learning tasks.