Wasserstein Distance Joint Estimation
- WDJE is a framework that employs the Wasserstein distance to jointly estimate and align structured distributions across domains.
- It uses a minimax adversarial training scheme with feature extractors, predictors, and joint critics to reduce per-domain risk gaps, and it comes with statistical convergence guarantees.
- Experiments on benchmarks such as MNIST↔SVHN and Office-Caltech report improved worst-case (minimum per-domain) accuracy.
Wasserstein Distance-based Joint Estimation (WDJE) is a class of statistical and machine learning methods that use the Wasserstein (optimal transport) distance as the core metric for joint estimation problems. WDJE provides rigorous, distribution-aware approaches for aligning, transferring, or estimating model parameters and representations when the relationships among structured distributions matter, especially distributions on product spaces. The framework spans domain adaptation and invariant representation learning in deep networks, projections onto algebraic statistical models, and robust estimation under distributional uncertainty (Andeol et al., 2021, Çelik et al., 2020, Su, 2019, Prabhat et al., 2024).
1. Mathematical Definition of Joint Wasserstein Distance
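WDJE is centered on the $p$-Wasserstein distance ($W_p$) between joint distributions, typically defined on product spaces. Let $\mathcal{X}$ denote the input/observation space, $\mathcal{Z}$ a representation (feature) space, and $\mathcal{Y}$ an output (label) space. Consider two domains $i$ and $j$ (e.g., a source and a target domain) with distributions $P_i$ and $P_j$ on $\mathcal{Z} \times \mathcal{Y}$. The $1$-Wasserstein distance is given by

$$W_1(P_i, P_j) = \inf_{\gamma \in \Pi(P_i, P_j)} \mathbb{E}_{((z, y), (z', y')) \sim \gamma}\big[d\big((z, y), (z', y')\big)\big],$$

where $\Pi(P_i, P_j)$ is the set of couplings (joint distributions with marginals $P_i$ and $P_j$) and $d$ is a ground metric on $\mathcal{Z} \times \mathcal{Y}$. Equivalently, via the dual (Kantorovich–Rubinstein) formulation,

$$W_1(P_i, P_j) = \sup_{\lVert f \rVert_L \le 1} \; \mathbb{E}_{(z, y) \sim P_i}[f(z, y)] - \mathbb{E}_{(z, y) \sim P_j}[f(z, y)],$$

where $f$ ranges over all real-valued 1-Lipschitz functions on $\mathcal{Z} \times \mathcal{Y}$ (Andeol et al., 2021).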
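For intuition, the primal problem can be solved exactly on tiny examples. The sketch below is illustrative only. It assumes two uniform empirical measures with the same number of samples on $\mathcal{Z} \times \mathcal{Y}$. In that case an optimal coupling can be taken to be a permutation (Birkhoff–von Neumann), so brute-force search over permutations is exact. It also assumes the hypothetical ground metric $d((z, y), (z', y')) = \lVert z - z' \rVert_2 + \lambda\,|y - y'|$. The function names and the label weight `lam` are hypothetical, and the factorial-time search is only suitable for a handful of points.

```python
import itertools
import math
from typing import Callable, Sequence, Tuple

# A joint sample (z, y): feature vector z and scalar label y.
Point = Tuple[Tuple[float, ...], float]


def joint_metric(a: Point, b: Point, lam: float = 1.0) -> float:
    """Assumed ground metric on Z x Y: Euclidean in z plus lam * |y - y'|."""
    (za, ya), (zb, yb) = a, b
    return math.dist(za, zb) + lam * abs(ya - yb)


def w1_uniform_empirical(
    p: Sequence[Point],
    q: Sequence[Point],
    metric: Callable[[Point, Point], float] = joint_metric,
) -> float:
    """Exact W1 between two uniform empirical measures with equal sample counts.

    An optimal coupling can be taken to be a permutation matrix scaled by 1/n,
    so searching all permutations is exact but costs O(n!).
    """
    n = len(p)
    if n == 0 or n != len(q):
        raise ValueError("need two non-empty samples of equal size")
    best = math.inf
    for perm in itertools.permutations(range(n)):
        cost = sum(metric(p[k], q[perm[k]]) for k in range(n)) / n
        best = min(best, cost)
    return best


# Domain j is domain i shifted by 1 in feature space, with matching labels.
P_i = [((0.0, 0.0), 0.0), ((1.0, 0.0), 1.0)]
P_j = [((0.0, 1.0), 0.0), ((1.0, 1.0), 1.0)]
print(w1_uniform_empirical(P_i, P_j))  # 1.0

# Same feature marginal as P_j, but one label changed.
P_j_relabeled = [((0.0, 1.0), 1.0), ((1.0, 1.0), 1.0)]
print(w1_uniform_empirical(P_i, P_j_relabeled))  # 1.5
```

The second call shows why a joint $(Z, Y)$ distance is used. The feature marginals are the same as in the first call, so a $Z$-only comparison would not change. The joint distance, however, grows because the labels no longer match.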
For algebraic models (e.g., independence models), WDJE seeks the projection of an observed joint distribution $\mu$ onto a statistical model $\mathcal{M}$ in Wasserstein distance. The resulting solutions are piecewise algebraic in $\mu$ (Çelik et al., 2020).
2. Key Theoretical Results and Risk Bounds
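WDJE establishes explicit upper bounds that connect the joint Wasserstein distance to standard supervised learning losses. The main result bounds the gap between domain risks by two kinds of terms:

- a joint Wasserstein distance between $\tilde{P}_i$ and $\tilde{P}_j$, which are convex mixtures incorporating the available labeled data and unlabeled data with imputed labels;
- supervised loss terms.

Here $P_i^h$ and $P_j^h$ denote the joint laws of $(Z, h(Z))$ induced by a predictor $h$. The supervised loss terms are instantiated separately for regression and for classification.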
This formalism lets WDJE control domain gaps through adversarial alignment (minimizing the joint Wasserstein term via a dual critic) together with weighted empirical losses (Andeol et al., 2021).
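A standard coupling argument illustrates how supervised losses can enter such bounds. It is not necessarily the exact theorem of Andeol et al. Pairing each sample $(z, y)$ with $(z, h(z))$ is a valid coupling of $P$ and the predictor-induced law $P^h$. Under the assumed metric $d((z, y), (z', y')) = |z - z'| + |y - y'|$, this gives $W_1(P, P^h) \le \mathbb{E}_P|Y - h(Z)|$. With the discrete label metric $\mathbb{1}[y \ne y']$ in place of $|y - y'|$, the right-hand side becomes the misclassification rate instead. The sketch below checks the regression case numerically on scalar features. The sample values and the predictor `h` are hypothetical.

```python
import itertools
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[float, float]  # (z, y) with scalar feature and scalar label


def d(a: Pair, b: Pair) -> float:
    """Assumed ground metric on Z x Y: |z - z'| + |y - y'|."""
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def w1_uniform(p: Sequence[Pair], q: Sequence[Pair]) -> float:
    """Exact W1 for equal-size uniform empirical measures (brute force, O(n!))."""
    n = len(p)
    return min(
        sum(d(p[k], q[perm[k]]) for k in range(n)) / n
        for perm in itertools.permutations(range(n))
    )


def predictor_law(p: Sequence[Pair], h: Callable[[float], float]) -> List[Pair]:
    """Empirical version of P^h: replace each label y by the prediction h(z)."""
    return [(z, h(z)) for z, _ in p]


def mean_abs_error(p: Sequence[Pair], h: Callable[[float], float]) -> float:
    return sum(abs(y - h(z)) for z, y in p) / len(p)


def h(z: float) -> float:
    """Hypothetical predictor: the identity map."""
    return z


P = [(0.0, 0.1), (1.0, 0.9), (2.0, 2.2), (3.0, 2.8)]
w = w1_uniform(P, predictor_law(P, h))
mae = mean_abs_error(P, h)
print(w <= mae)  # True: the identity coupling costs exactly the MAE
```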
Statistical convergence guarantees are also available. Under Lipschitz and compactness conditions, the per-domain risk gaps are bounded by a finite-sample term that shrinks as the number of samples grows (Andeol et al., 2021).
3. Algorithmic Frameworks and Architectures
WDJE is implemented using a minimax adversarial framework combining feature extractors, predictors, and joint or marginal critics. For domain alignment settings:
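- Feature encoder $g: \mathcal{X} \to \mathcal{Z}$ (parameterized by a neural network)
- Predictor/classifier $h: \mathcal{Z} \to \mathcal{Y}$ (MLP or linear head)
- Domain critic $f: \mathcal{Z} \times \mathcal{Y} \to \mathbb{R}$ (constrained to be 1-Lipschitz, e.g., via spectral normalization)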
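Spectral normalization is one way to enforce the critic's Lipschitz constraint layer by layer. Each weight matrix $W$ is divided by an estimate of its largest singular value $\sigma_{\max}(W)$, obtained by power iteration. A composition of such layers with 1-Lipschitz activations (e.g., ReLU) is then 1-Lipschitz. The dependency-free sketch below normalizes a single dense matrix stored as nested lists. It is a sketch under stated assumptions: deep-learning implementations typically keep one persistent power-iteration vector per layer and run a single iteration per training step. All helper names here are hypothetical.

```python
import math
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def mat_vec(w: Matrix, v: Vector) -> Vector:
    """Compute W v."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def mat_t_vec(w: Matrix, u: Vector) -> Vector:
    """Compute W^T u."""
    return [sum(w[i][j] * u[i] for i in range(len(w))) for j in range(len(w[0]))]


def unit(x: Vector) -> Vector:
    norm = math.sqrt(sum(t * t for t in x))
    return [t / norm for t in x]


def spectral_norm(w: Matrix, iters: int = 50, seed: int = 0) -> float:
    """Estimate the largest singular value of W by power iteration."""
    if iters < 1:
        raise ValueError("iters must be at least 1")
    rng = random.Random(seed)
    u = unit([rng.gauss(0.0, 1.0) for _ in range(len(w))])
    for _ in range(iters):
        v = unit(mat_t_vec(w, u))  # right singular vector estimate
        u = unit(mat_vec(w, v))    # left singular vector estimate
    # sigma is approximately u^T W v for the converged singular vectors
    return sum(ui * x for ui, x in zip(u, mat_vec(w, v)))


def spectrally_normalize(w: Matrix, iters: int = 50) -> Matrix:
    """Divide W by its estimated spectral norm so the linear map is ~1-Lipschitz."""
    sigma = spectral_norm(w, iters)
    return [[wij / sigma for wij in row] for row in w]


W = [[3.0, 0.0], [0.0, 1.0]]
print(round(spectral_norm(W), 6))                        # 3.0
print(round(spectral_norm(spectrally_normalize(W)), 6))  # 1.0
```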
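The loss combines a Wasserstein dual objective with supervised loss terms. The dual objective is

$$\mathcal{L}_W(g, h, f) = \mathbb{E}_{(z, y) \sim \tilde{P}_i}[f(z, y)] - \mathbb{E}_{(z, y) \sim \tilde{P}_j}[f(z, y)],$$

where $\tilde{P}_k$ is the joint law of $(g(x), y)$ in domain $k$. For unlabeled samples, the labels are imputed as $y = h(g(x))$. Adding the weighted supervised losses yields the full min-max objective

$$\min_{g, h} \; \max_{\lVert f \rVert_L \le 1} \; \sum_{k \in \{i, j\}} \lambda_k \, \mathbb{E}_{(x, y) \sim P_k^{\mathrm{lab}}}\big[\ell\big(h(g(x)), y\big)\big] + \gamma \, \mathcal{L}_W(g, h, f).$$

In this objective, $\ell$ is the supervised loss, $P_k^{\mathrm{lab}}$ is the distribution of labeled samples in domain $k$, and $\lambda_k, \gamma \ge 0$ are weights.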
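An empirical version of the dual term can be sketched as follows. The sketch makes several assumptions:

- features and labels are scalars;
- in each domain's mixture, labeled samples carry total mass `alpha`, and unlabeled samples (with labels imputed by a predictor `h`) carry mass `1 - alpha`;
- the critic is held fixed.

The mixing rule, `alpha`, and all function names are illustrative assumptions, not the paper's exact construction. Under the min-max objective, the critic would be updated to increase this value, and the encoder and predictor to decrease it.

```python
from typing import Callable, List, Sequence, Tuple

Sample = Tuple[float, float]           # (z, y): encoded feature and label
Weighted = List[Tuple[Sample, float]]  # each sample with its probability mass


def mixture(
    labeled: Sequence[Sample],
    unlabeled_z: Sequence[float],
    h: Callable[[float], float],
    alpha: float,
) -> Weighted:
    """Convex mixture: labeled samples (mass alpha) plus pseudo-labeled ones (mass 1 - alpha)."""
    if not labeled or not unlabeled_z or not 0.0 <= alpha <= 1.0:
        raise ValueError("need labeled and unlabeled samples and alpha in [0, 1]")
    out = [((z, y), alpha / len(labeled)) for z, y in labeled]
    out += [((z, h(z)), (1.0 - alpha) / len(unlabeled_z)) for z in unlabeled_z]
    return out


def expectation(f: Callable[[float, float], float], mix: Weighted) -> float:
    return sum(w * f(z, y) for (z, y), w in mix)


def dual_objective(f: Callable[[float, float], float], mix_i: Weighted, mix_j: Weighted) -> float:
    """E_i[f] - E_j[f]; for a 1-Lipschitz f this lower-bounds W1(mix_i, mix_j)."""
    return expectation(f, mix_i) - expectation(f, mix_j)


def h(z: float) -> float:
    """Hypothetical predictor used to impute labels."""
    return z


def marginal_critic(z: float, y: float) -> float:
    # Ignores y; 1-Lipschitz w.r.t. the assumed metric |z - z'| + |y - y'|.
    return -z


def joint_critic(z: float, y: float) -> float:
    # Uses both z and y; still 1-Lipschitz under the same metric.
    return -(z + y)


mix_i = mixture([(0.0, 0.0), (1.0, 1.0)], [2.0], h, alpha=0.5)
mix_j = mixture([(1.0, 1.0), (2.0, 2.0)], [3.0], h, alpha=0.5)
print(dual_objective(marginal_critic, mix_i, mix_j))  # 1.0
print(dual_objective(joint_critic, mix_i, mix_j))     # 2.0
```

In this toy case, `mix_j` is `mix_i` shifted by 1 in both $z$ and $y$. The true $W_1$ under the assumed metric is therefore 2, which the label-aware critic attains, while the $Z$-only critic only certifies 1.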
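Training alternates between two updates (Andeol et al., 2021):

- the critic is updated to maximize the dual objective, which tightens the Wasserstein estimate;
- the encoder and predictor are updated to minimize the supervised losses together with the Wasserstein term.

Semi-supervised regularizers such as entropy minimization and virtual adversarial training (VAT) can be added to exploit unlabeled data further.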
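Entropy minimization, one of the regularizers named above, adds the mean Shannon entropy of the predictor's class probabilities on unlabeled samples to the loss. This pushes predictions away from decision boundaries. The minimal sketch below assumes that predictions are already probability vectors (in practice, softmax outputs). The function names and the weights `gamma` and `beta` are hypothetical, and how WDJE weights these terms is an assumption here.

```python
import math
from typing import Sequence


def mean_entropy(probs: Sequence[Sequence[float]]) -> float:
    """Average Shannon entropy (in nats) of a batch of class-probability vectors."""
    if not probs:
        raise ValueError("empty batch")
    total = 0.0
    for p in probs:
        # 0 * log 0 is taken as 0, so zero-probability classes are skipped.
        total -= sum(q * math.log(q) for q in p if q > 0.0)
    return total / len(probs)


def regularized_loss(
    supervised: float,
    wasserstein: float,
    unlabeled_probs: Sequence[Sequence[float]],
    gamma: float = 1.0,
    beta: float = 0.1,
) -> float:
    """Hypothetical weighting of the supervised, Wasserstein, and entropy terms."""
    return supervised + gamma * wasserstein + beta * mean_entropy(unlabeled_probs)


print(mean_entropy([[1.0, 0.0]]))                               # 0.0 (confident)
print(abs(mean_entropy([[0.5, 0.5]]) - math.log(2)) < 1e-12)    # True (2-class maximum)
```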
For discrete settings (independence models), WDJE reduces to a structured minimax problem over the probability simplex and the polyhedral Lipschitz polytope, leading to piecewise-algebraic solution maps with tractable KKT conditions on small models (Çelik et al., 2020).
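On a finite state space, the Lipschitz polytope is the set of vectors $f$ with $f(x) - f(x') \le d(x, x')$ for all pairs of states. Every point of the polytope gives a lower bound $\sum_x f(x)\,(\mu(x) - \nu(x)) \le W_1(\mu, \nu)$. The sketch below assumes the binary state space $\{0,1\}^2$ with the Hamming metric. It checks polytope membership and evaluates this dual value. The function names are hypothetical.

```python
import itertools
from typing import Dict, Tuple

State = Tuple[int, int]
STATES = list(itertools.product((0, 1), repeat=2))  # (0,0), (0,1), (1,0), (1,1)


def hamming(a: State, b: State) -> int:
    return sum(ai != bi for ai, bi in zip(a, b))


def in_lipschitz_polytope(f: Dict[State, float], tol: float = 1e-12) -> bool:
    """Check f(x) - f(x') <= d(x, x') for every ordered pair of states."""
    return all(f[x] - f[y] <= hamming(x, y) + tol for x in STATES for y in STATES)


def dual_value(f: Dict[State, float], mu: Dict[State, float], nu: Dict[State, float]) -> float:
    """sum_x f(x) (mu(x) - nu(x)), a lower bound on W1(mu, nu) when f is feasible."""
    return sum(f[x] * (mu[x] - nu[x]) for x in STATES)


mu = {s: 1.0 if s == (0, 0) else 0.0 for s in STATES}  # point mass at (0, 0)
nu = {s: 1.0 if s == (1, 1) else 0.0 for s in STATES}  # point mass at (1, 1)
f = {s: -float(hamming(s, (0, 0))) for s in STATES}    # negative distance to (0, 0)

print(in_lipschitz_polytope(f), dual_value(f, mu, nu))  # True 2.0
```

The function $f = -d(\cdot, (0,0))$ is feasible by the triangle inequality. Here the bound is tight: $W_1$ between two point masses equals their distance, and $d((0,0),(1,1)) = 2$.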
4. Empirical Performance and Benchmarks
WDJE has been validated on diverse domain adaptation and representation learning benchmarks:
| Dataset | Architecture | Domains | Labeled samples per domain | Notable results |
|---|---|---|---|---|
| MNIST ↔ SVHN | Conv-Large | 2 | 1000–3000 | Highest minimum accuracy and smallest inter-domain gap; WDJE+VAT ≈ 94.3% (SVHN target) |
| Office-Caltech | ResNet-18 (DeCAF6) | 4 | 200 | WDJE improves minimum accuracy across all four domains |
| PACS | ResNet-18 | 4 | 500 | Joint $(Z, Y)$ critic improves worst-case and often average performance |
Ablations indicate that a joint $(Z, Y)$ critic outperforms marginal $Z$-only critics in reducing both the Wasserstein distance and the risk disparities between domains. Visualizations (UMAP, layer-wise relevance propagation) show domain-invariant feature clusters that remain discriminative (Andeol et al., 2021).
5. Relationship to Other Methodologies
WDJE generalizes and strengthens classical domain adaptation via the following distinctions:
- DANN/WDGRL: Enforce adversarial alignment only in the marginal feature ($Z$) space and provide no explicit bounds on the joint space or on per-domain risk gaps; their critic ignores the labels $Y$ (Andeol et al., 2021).
- JDOT: Solves joint optimal transport in the primal, which is computationally expensive and does not scale easily to high-dimensional deep networks.
- Semi-supervised domain adaptation/generalization (DA/DG): Use pseudo-labeling or meta-learning, but often lack formal control over the joint distribution discrepancy or the per-domain performance gap.
Andeol et al. (2021) thus present WDJE as the first GAN-style, scalable approach with explicit statistical guarantees for joint Wasserstein minimization between domains. In algebraic statistics, WDJE yields piecewise-algebraic estimators for independence models, which allows an explicit description of error regions and algebraic degrees (Çelik et al., 2020).
6. Extensions and Applications
WDJE's theoretical formalism suggests several avenues for future research and application:
- Domain generalization: Controlling performance over the convex hull of the seen domains may extend generalization to unseen domains that lie within that hull.
- Self-supervised/contrastive objectives: Such objectives could be combined with WDJE to strengthen representation invariance.
- Fairness/group-robustness: Alignment of group-wise performance or statistical parity can be directly tied to joint Wasserstein metrics.
- Multi-task/structured output: Extension to structured prediction tasks, e.g., aligning the joint law of features and all task outputs across domains to enforce task-consistent invariance.
Additionally, the WDJE formulation applies naturally to algebraic and combinatorial models in which minimization over model spaces (e.g., Segre/Veronese varieties) requires piecewise-algebraic optimization. It may also be extended to dependence or mixture models beyond independence (Çelik et al., 2020).
7. Computational and Theoretical Aspects in Algebraic and Discrete Models
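For discrete $n$-variable settings, WDJE formalizes the projection onto an independence model in the Wasserstein metric as a non-convex, piecewise-algebraic optimization problem:

$$\min_{\nu \in \mathcal{M}} W_1(\mu, \nu) = \min_{\nu \in \mathcal{M}} \; \max_{f \in \mathcal{P}_d} \; \sum_{x} f(x)\big(\mu(x) - \nu(x)\big).$$

The symbols are defined as follows:

- $\mu$ is the observed distribution;
- $\mathcal{P}_d = \{ f : f(x) - f(x') \le d(x, x') \text{ for all } x, x' \}$ is the Lipschitz polytope of the ground metric $d$;
- the model $\mathcal{M}$ is typically an algebraic variety (e.g., the Segre variety for independence models).

The combinatorial region containing the solution is determined by the faces of the Lipschitz polytope, and its algebraic complexity is described by the model's polar degrees. Symbolic and numerical algorithms (face enumeration, branch-and-cut, Gröbner basis computation) have been developed for small to moderate $n$ (Çelik et al., 2020).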
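For the smallest case, two binary variables, the projection described above can be approximated by brute force. The sketch assumes the Hamming metric on $\{0,1\}^2$. This metric coincides with the shortest-path metric of the 4-cycle $(0,0) \to (0,1) \to (1,1) \to (1,0) \to (0,0)$. On a cycle with unit edges, $W_1$ equals the minimum-cost circulation cost $\min_c \sum_k |F_k - c|$. Here $F_k$ are the cumulative sums of $\mu - \nu$ around the cycle, and the minimum is attained at a median of the $F_k$. The code then grid-searches the independence model $\nu = (s, 1-s) \otimes (t, 1-t)$. This numerical search is a stand-in for, not a reproduction of, the algebraic method of Çelik et al. The grid resolution and function names are assumptions.

```python
from statistics import median
from typing import Dict, Optional, Tuple

State = Tuple[int, int]
# Cyclic order in which consecutive states differ in one coordinate,
# so Hamming distance equals shortest-path distance on this 4-cycle.
CYCLE = [(0, 0), (0, 1), (1, 1), (1, 0)]


def w1_cycle(mu: Dict[State, float], nu: Dict[State, float]) -> float:
    """W1 under the Hamming metric on {0,1}^2, via min-cost flow on the 4-cycle."""
    cumulative, running = [], 0.0
    for s in CYCLE:
        running += mu[s] - nu[s]
        cumulative.append(running)
    c = median(cumulative)  # optimal circulation offset
    return sum(abs(F - c) for F in cumulative)


def independent(s: float, t: float) -> Dict[State, float]:
    """Product distribution with P(X1 = 0) = s and P(X2 = 0) = t."""
    a, b = (s, 1.0 - s), (t, 1.0 - t)
    return {(i, j): a[i] * b[j] for i in (0, 1) for j in (0, 1)}


def project_to_independence(
    mu: Dict[State, float], steps: int = 100
) -> Tuple[float, Optional[Tuple[float, float]]]:
    """Grid search for the minimum of W1(mu, nu) over the independence model."""
    best_d, best_point = float("inf"), None
    for i in range(steps + 1):
        for j in range(steps + 1):
            s, t = i / steps, j / steps
            d = w1_cycle(mu, independent(s, t))
            if d < best_d:
                best_d, best_point = d, (s, t)
    return best_d, best_point


mu_uniform = {s: 0.25 for s in CYCLE}  # already independent
print(project_to_independence(mu_uniform))  # (0.0, (0.5, 0.5))

mu_corr = {(0, 0): 0.5, (0, 1): 0.0, (1, 1): 0.5, (1, 0): 0.0}  # perfectly correlated
print(project_to_independence(mu_corr))  # strictly positive distance
```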
Computational experiments report face counts, type distributions, and degree patterns for tetrahedral and higher-dimensional models. They leave open questions about efficiency, sharpness, and the connection to maximum-likelihood (KL-divergence) projections under noise.
References:
- "Learning Domain Invariant Representations by Joint Wasserstein Distance Minimization" (Andeol et al., 2021)
- "Wasserstein Distance to Independence Models" (Çelik et al., 2020)
- "Wasserstein Distance Guided Cross-Domain Learning" (Su, 2019)
- "Optimal State Estimation in the Presence of Non-Gaussian Uncertainty via Wasserstein Distance Minimization" (Prabhat et al., 2024)