Wasserstein Projection Estimator (WPE)
- WPE is a projection-based optimal transport technique that leverages the 2-Wasserstein distance to robustly extract lower-dimensional structures from high-dimensional distributions.
- It employs entropy regularization and Sinkhorn iterations to reformulate the nested optimization, ensuring smooth and computationally efficient solutions.
- Riemannian block coordinate descent algorithms significantly reduce complexity, making WPE practical for applications in robust statistics, dimensionality reduction, and large-scale data analysis.
The Wasserstein Projection Estimator (WPE) is a statistical and optimization framework that leverages the Wasserstein distance to perform robust, low-dimensional, or constraint-respecting inference on probability distributions, particularly in high-dimensional settings. WPEs have emerged in diverse formulations—from optimally projecting empirical distributions onto parametric or structural sets, to serving as the computational core of dimensionality-reducing robust statistics, to acting as a key ingredient in statistical tests and optimization tasks. The formalization and efficient computation of WPEs are vital both for overcoming the curse of dimensionality and for extending the reach of optimal transport–based methods in contemporary data sciences.
1. Core Definition and Formulation
Let $\mu, \nu$ be probability measures on $\mathbb{R}^d$ with finite second moments. For $k \le d$, let $U$ be an element of the Stiefel manifold $\mathrm{St}(d,k)$ (the set of real $d \times k$ matrices with $U^\top U = I_k$). Define the orthogonal projection $\pi_U(x) = U^\top x$, and denote the push-forward distribution $\pi_{U\#}\mu$ as the law of $U^\top X$ for $X \sim \mu$. The $k$-dimensional projection-robust Wasserstein distance (PRW) is

$$\mathcal{P}_k(\mu, \nu) \;=\; \max_{U \in \mathrm{St}(d,k)} W_2\big(\pi_{U\#}\mu,\; \pi_{U\#}\nu\big).$$
This defines the WPE: it selects the $k$-dimensional subspace maximizing the 2-Wasserstein distance between the projected versions of $\mu$ and $\nu$. In empirical settings with finite-support measures $\hat\mu = \sum_{i=1}^{n} r_i \delta_{x_i}$ and $\hat\nu = \sum_{j=1}^{n} c_j \delta_{y_j}$, the discrete optimization becomes

$$\max_{U \in \mathrm{St}(d,k)} \;\min_{\pi \in \Pi(r,c)} \;\sum_{i,j} \pi_{ij}\, \big\| U^\top x_i - U^\top y_j \big\|^2 .$$
This is a max–min problem over matrices $U \in \mathrm{St}(d,k)$ and transportation plans $\pi \in \mathbb{R}_{+}^{n \times n}$ under the marginal constraints $\pi \mathbf{1} = r$ and $\pi^\top \mathbf{1} = c$.
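For concreteness, the projected cost matrix $C(U)_{ij} = \|U^\top x_i - U^\top y_j\|^2$ entering this max–min objective can be assembled directly from the samples. The snippet below is a minimal NumPy sketch (the helper name `projected_cost_matrix` is illustrative, not from the source):

```python
# Minimal sketch: the projected cost matrix C(U)_{ij} = ||U^T x_i - U^T y_j||^2
# that appears in the discrete max-min objective, for a fixed U with
# orthonormal columns.
import numpy as np

def projected_cost_matrix(X, Y, U):
    """X: (n, d) samples of mu, Y: (m, d) samples of nu, U: (d, k) with U^T U = I_k."""
    XU, YU = X @ U, Y @ U                      # project both samples to R^k
    sq = ((XU**2).sum(axis=1)[:, None]
          + (YU**2).sum(axis=1)[None, :]
          - 2.0 * XU @ YU.T)                   # pairwise squared distances
    return np.maximum(sq, 0.0)                 # clip tiny negatives from round-off
```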
2. Regularization and Computational Reformulation
Direct solution is computationally intensive due to the nested max–min structure. Entropy regularization is introduced by adding the term $-\eta H(\pi)$ with $H(\pi) = -\sum_{i,j} \pi_{ij}(\log \pi_{ij} - 1)$, yielding the regularized problem

$$\max_{U \in \mathrm{St}(d,k)} \;\min_{\pi \in \Pi(r,c)} \;\sum_{i,j} \pi_{ij}\, \big\|U^\top x_i - U^\top y_j\big\|^2 \;-\; \eta H(\pi),$$
where $\eta > 0$ controls the entropic bias. The regularized formulation facilitates smoothness and efficient computation, notably enabling the application of Sinkhorn iteration for the inner OT problem.
A dual reformulation introduces dual variables $u, v \in \mathbb{R}^n$ such that the transport plan takes the Gibbs form

$$\pi_{ij}(u, v, U) \;=\; \exp\!\Big( \tfrac{1}{\eta}\big( u_i + v_j - \|U^\top x_i - U^\top y_j\|^2 \big) \Big),$$

with the objective function

$$g(u, v, U) \;=\; \eta \sum_{i,j} \pi_{ij}(u, v, U) \;-\; \langle u, r\rangle \;-\; \langle v, c\rangle .$$

The problem is reframed as minimizing $g$ jointly in $(u, v, U)$, enabling block coordinate descent strategies with closed-form Sinkhorn updates in $u$ (rows) and $v$ (columns).
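The closed-form block updates in $u$ and $v$ are the classical log-domain Sinkhorn steps: minimizing $g$ over $u$ (respectively $v$) with the other variables fixed exactly restores the row (respectively column) marginal constraint. A hedged sketch, assuming the Gibbs-form plan above; the function name `sinkhorn_block_updates` is illustrative:

```python
# Sketch of the closed-form block updates in (u, v) for the Gibbs-form plan
# pi_ij = exp((u_i + v_j - C_ij) / eta). Each update exactly restores one
# marginal constraint, i.e. the classical Sinkhorn step in the log domain.
import numpy as np
from scipy.special import logsumexp

def sinkhorn_block_updates(C, r, c, eta, u, v):
    """One row (u) and one column (v) update; C is the projected cost matrix."""
    # Row update: afterwards sum_j pi_ij = r_i for every i.
    u = eta * (np.log(r) - logsumexp((v[None, :] - C) / eta, axis=1))
    # Column update: afterwards sum_i pi_ij = c_j for every j.
    v = eta * (np.log(c) - logsumexp((u[:, None] - C) / eta, axis=0))
    pi = np.exp((u[:, None] + v[None, :] - C) / eta)
    return u, v, pi
```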
3. Riemannian Block Coordinate Descent Algorithm
The RBCD algorithm alternates between Sinkhorn updates for the dual variables $(u, v)$ and a Riemannian gradient step for the projection matrix $U$, exploiting the geometry of the Stiefel manifold $\mathrm{St}(d,k)$. The pseudocode is given below, followed by a minimal Python sketch:
- Initialize $U^0 \in \mathrm{St}(d,k)$ and set $u^0 = v^0 = \mathbf{0}$.
- Repeat (for each iteration $t = 0, 1, 2, \dots$):
  - Compute the dual updates in closed form (Sinkhorn steps): $u^{t+1}_i = \eta\log r_i - \eta\log\sum_j \exp\!\big(\tfrac{1}{\eta}\big(v^{t}_j - \|(U^{t})^\top x_i - (U^{t})^\top y_j\|^2\big)\big)$, analogously for $v^{t+1}$ from the column sums; let $\pi^{t+1} = \pi(u^{t+1}, v^{t+1}, U^{t})$.
  - Euclidean gradient: $\xi^{t} = \nabla_U g(u^{t+1}, v^{t+1}, U^{t}) = -2\,V_{\pi^{t+1}} U^{t}$, where $V_\pi = \sum_{i,j}\pi_{ij}(x_i - y_j)(x_i - y_j)^\top$.
  - Project $\xi^{t}$ to the tangent space $T_{U^{t}}\mathrm{St}(d,k)$ to get the Riemannian gradient $\mathrm{grad}\,g(U^{t}) = \mathrm{Proj}_{T_{U^{t}}}(\xi^{t})$.
  - Retract via QR or polar decomposition: $U^{t+1} = \mathrm{Retr}_{U^{t}}\!\big(-\tau\,\mathrm{grad}\,g(U^{t})\big)$.

Stop when stationarity and marginal-constraint violations reach the prescribed tolerances.
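The following is an illustrative end-to-end loop, not the authors' reference implementation. It strings together the projected cost matrix, the dual block updates, the Euclidean gradient $-2V_\pi U$ (assembled without forming $V_\pi$ explicitly), the tangent-space projection, and a QR retraction. It reuses `projected_cost_matrix` and `sinkhorn_block_updates` from the sketches above; the step size `tau`, tolerance, and optional initialization `U_init` are assumed hyperparameters:

```python
# Illustrative RBCD loop under the reformulation above (a sketch only).
import numpy as np

def rbcd_prw(X, Y, r, c, k, eta=0.1, tau=1e-3, max_iter=500, tol=1e-6,
             U_init=None, seed=0):
    n, d = X.shape
    m = Y.shape[0]
    if U_init is None:
        # Random point on St(d, k): Q-factor of a Gaussian matrix.
        rng = np.random.default_rng(seed)
        U_init, _ = np.linalg.qr(rng.standard_normal((d, k)))
    U = U_init
    u, v = np.zeros(n), np.zeros(m)
    for _ in range(max_iter):
        C = projected_cost_matrix(X, Y, U)
        # Closed-form Sinkhorn block updates of the dual variables.
        u, v, pi = sinkhorn_block_updates(C, r, c, eta, u, v)
        # Euclidean gradient of g in U: xi = -2 * V_pi @ U, assembled without
        # ever forming the d x d matrix V_pi.
        a, b = pi.sum(axis=1), pi.sum(axis=0)
        XU, YU = X @ U, Y @ U
        VU = (X.T @ (a[:, None] * XU) + Y.T @ (b[:, None] * YU)
              - X.T @ (pi @ YU) - Y.T @ (pi.T @ XU))
        xi = -2.0 * VU
        # Riemannian gradient: project xi onto the tangent space at U,
        # grad = xi - U * sym(U^T xi), for the embedded Euclidean metric.
        sym = 0.5 * (U.T @ xi + xi.T @ U)
        grad = xi - U @ sym
        if np.linalg.norm(grad) < tol:
            break
        # Descent step on g followed by a QR retraction back onto St(d, k).
        Q, R = np.linalg.qr(U - tau * grad)
        U = Q * np.where(np.diag(R) < 0, -1.0, 1.0)  # fix column signs
    # Final projected (regularized) transport cost as the PRW estimate.
    return U, float((pi * projected_cost_matrix(X, Y, U)).sum())
```

The QR retraction is used here because it is cheap and numerically stable; a polar retraction (via SVD of the updated matrix) is an equally valid choice, as noted in the pseudocode.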
This combination yields a scalable algorithm: the per-iteration cost is dominated by one closed-form dual block update and one Riemannian gradient step on the projected data, and the total arithmetic complexity of finding an $\epsilon$-stationary point is polynomial in $n$, $d$, and $1/\epsilon$.
4. Convergence Analysis and Complexity
Under mild Lipschitz-smoothness conditions, RBCD ensures a sufficient decrease in $g$ at every iteration, and telescoping yields that an $\epsilon$-stationary point is obtained in a number of iterations polynomial in $1/\epsilon$ (a schematic version of the argument follows below). Setting the regularization parameter $\eta$ on the order of $\epsilon$ (up to logarithmic factors) keeps the entropic bias within the target accuracy and gives the overall arithmetic-complexity bound, a significant improvement over the previous RGAS method, whose complexity is of higher order in $1/\epsilon$.
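The schematic form of this argument, with $g^{t} := g(u^{t}, v^{t}, U^{t})$ and $c$ a generic constant absorbing the smoothness and retraction parameters (not a constant quoted from the source), is

$$g^{t} - g^{t+1} \;\ge\; c\,\big\|\mathrm{grad}\, g^{t}\big\|^{2}
\quad\Longrightarrow\quad
\min_{0 \le t < T}\big\|\mathrm{grad}\, g^{t}\big\|^{2} \;\le\; \frac{g^{0} - \inf g}{c\,T},$$

so $T = O\big((g^{0} - \inf g)/(c\,\epsilon^{2})\big)$ iterations suffice to reach a point whose Riemannian gradient norm is at most $\epsilon$.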
5. Practical Implementation Considerations
- Sample Preparation: Given data $\{x_i\}_{i=1}^{n} \sim \mu$ and $\{y_j\}_{j=1}^{n} \sim \nu$, form empirical measures $\hat\mu = \sum_i r_i \delta_{x_i}$ and $\hat\nu = \sum_j c_j \delta_{y_j}$ with weights $r_i$, $c_j$ (often uniform, $1/n$).
- Regularization Parameter: Set $\eta$ small enough that the entropic bias stays within the target accuracy; in practice a moderate value is typical, often tuned by cross-validation.
- Step Size Tuning: Choose the Riemannian step size $\tau$ in accordance with the Lipschitz and retraction constants of $g$.
- Initialization: Use a random or PCA-based $U^0 \in \mathrm{St}(d,k)$ and $u^0 = v^0 = \mathbf{0}$ (a usage sketch follows this list).
- Iterative Alternation: Alternate Sinkhorn with Riemannian gradient steps, monitoring dual feasibility and stationarity.
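As referenced in the initialization item above, a hedged usage sketch with uniform weights and a PCA-based $U^0$, calling the `rbcd_prw` sketch defined earlier; all shapes and parameter values here are illustrative, not settings reported in the source:

```python
# Usage example: uniform weights, PCA-based initialization of U, and a call
# to the rbcd_prw sketch above. Data and parameters are purely illustrative.
import numpy as np

n, d, k = 200, 50, 5
rng = np.random.default_rng(1)
X = rng.standard_normal((n, d))
Y = rng.standard_normal((n, d)) + 0.5          # shifted second sample
r = c = np.full(n, 1.0 / n)                    # uniform empirical weights

# PCA-based U^0: top-k right singular vectors of the centered, pooled data.
Z = np.vstack([X, Y])
Z = Z - Z.mean(axis=0)
_, _, Vt = np.linalg.svd(Z, full_matrices=False)
U0 = Vt[:k].T                                  # (d, k), orthonormal columns

U_hat, prw_value = rbcd_prw(X, Y, r, c, k, eta=0.1, tau=1e-3, U_init=U0)
print(f"estimated projected transport cost: {prw_value:.4f}")
```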
In practice, the adaptive variant RABCD, which incorporates per-coordinate learning rates similar to AdaGrad/Adam, further improves empirical runtime beyond RBCD and yields an additional speedup over RGAS-type algorithms.
6. Numerical Performance and Empirical Evaluation
Empirical studies on both synthetic and real datasets validate efficiency and accuracy:
- Synthetic Fragmented Hypercube (dimension up to 500, sample size up to 1000): RBCD recovers the ground-truth PRW projection, matching the projection quality of RGAS and yielding runtimes at least $2\times$ faster on the larger instances.
- Gaussian Covariances on Low-rank Subspaces: RBCD maintains projection fidelity under moderate noise, matches RGAS in accuracy, and outperforms it in speed.
- Real Data: On movie scripts and Shakespeare plays (word2vec embeddings), RBCD achieves a consistent speedup over RGAS. On MNIST CNN features, RBCD again accelerates computation while preserving projection-distance quality.
- Adaptive Variant: In large-scale settings, RABCD empirically yields a further speedup of roughly an order of magnitude over RGAS and an additional factor of $2$ or more over RBCD.
7. Significance and Applications
WPEs implemented via (adaptive) Riemannian block coordinate descent represent the first family of algorithms for the projection-robust Wasserstein distance whose arithmetic complexity is low-order polynomial in the target accuracy, substantially improving computational feasibility for large-scale, high-dimensional optimal transport tasks. The formulation is general and encompasses applications in two-sample testing, robust dimensionality reduction, and large-scale data comparison, where identification of maximally separating linear projections under the Wasserstein metric is crucial. The architecture is readily extensible, compatible with GPU acceleration (due to intensive Sinkhorn calls), and can be parallelized over independent random initializations for global optimization. Empirical validation demonstrates that these methods retain the statistical power of projection-based OT distances while rendering computations tractable at scales unattainable by earlier approaches.
Table: Algorithm Comparison and Complexity
| Algorithm | Per-iteration cost | Iteration bound | Arithmetic complexity |
|---|---|---|---|
| RBCD | one closed-form dual block update plus one Riemannian gradient step | polynomial in $1/\epsilon$ | lower-order polynomial in $1/\epsilon$ |
| RGAS | requires an (approximate) Sinkhorn solve per iteration | polynomial in $1/\epsilon$ | higher-order polynomial in $1/\epsilon$ |
The table highlights the substantial improvement of RBCD over RGAS in terms of polynomial dependence on the accuracy $\epsilon$, with empirical speed-ups of $2\times$ and more and robust performance on both synthetic and real high-dimensional datasets (Huang et al., 2020).