Neural Local Wasserstein Regression
- The paper introduces NLWR, a framework that integrates local Fréchet regression with unbalanced optimal transport and neural generative maps to accurately interpolate complex distributions.
- It employs neural parameterizations to model barycenters as push-forwards from latent distributions, effectively capturing time-evolving phenomena in high-dimensional data.
- Empirical evaluations in genomics and inverse problems show that NLWR outperforms baselines on metrics such as MMD, EMD, and $W_2$, and remains robust under noise.
Neural Local Wasserstein Regression (NLWR) designates a class of nonparametric statistical learning frameworks that deploy neural network parameterizations within kernel-localized regression schemes in Wasserstein space. NLWR primarily addresses estimation, interpolation, and trajectory recovery tasks where either covariates, responses, or both are probability measures—commonly arising in high-dimensional systems biology, distributional inverse problems, and uncertainty quantification. The defining innovation is the synthesis of local Fréchet regression principles, unbalanced optimal transport metrics, and generative neural network maps, enabling accurate, robust interpolation of complex time-evolving distributions such as single-cell RNA-seq profiles or stochastic dynamical model parameters.
1. Problem Formulation and Local Fréchet Regression
The canonical NLWR setting considers measures $\mu_1, \dots, \mu_n$ observed at covariate times $t_1, \dots, t_n$ (e.g., biological times or model states). The estimation target is a "conditional mean" or barycenter measure $\bar\mu(t)$ at any query time $t$, formalized as a local Fréchet regression problem in Wasserstein space:
$$\bar\mu(t) = \arg\min_{\nu} \sum_{i=1}^{n} w_i(t)\, \mathrm{UW}_2^2(\nu, \mu_i),$$
where $\mathrm{UW}_2^2$ is the unbalanced squared Wasserstein-2 distance, used to account for potential mass variation and outlier resistance, and the $w_i(t)$ are kernel-based weights:
$$w_i(t) = \frac{K_h(t_i - t)}{\sum_{j=1}^{n} K_h(t_j - t)},$$
with $K_h$ a bandwidth-$h$ kernel (e.g., Gaussian). This structure localizes estimation to timepoints near $t$, and is robustified using local moments as in Petersen–Müller (2019).
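As a minimal illustration of this localization step, the sketch below computes Gaussian-kernel weights $w_i(t)$ over observed timepoints; the function name, bandwidth value, and timepoints are illustrative choices, not taken from the paper.

```python
import numpy as np

def kernel_weights(t_query, t_obs, bandwidth=0.5):
    """Gaussian kernel weights w_i(t) localizing the Frechet objective at t_query.

    t_query : float, query time t
    t_obs   : (n,) array of observed covariate times t_1, ..., t_n
    """
    k = np.exp(-0.5 * ((t_obs - t_query) / bandwidth) ** 2)  # K_h(t_i - t)
    return k / k.sum()                                       # normalize to sum to 1

# Example: 5 observed timepoints, query time between the 2nd and 3rd
t_obs = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
print(kernel_weights(1.5, t_obs))
```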
2. Unbalanced Optimal Transport Distance
NLWR leverages the unbalanced optimal transport (UOT) metric for measure comparison. Classical optimal transport demands strict mass conservation between coupled measures; UOT introduces mass-relaxation via penalized marginal divergences. For measures $\mu$ on $\mathcal{X}$ and $\nu$ on $\mathcal{Y}$:
$$\mathrm{UW}_2^2(\mu, \nu) = \inf_{\pi \in \Pi(\mu, \cdot)} \int \|x - y\|^2\, d\pi(x, y) + \tau\, D_\phi(\pi_2 \,\|\, \nu),$$
where $\pi \in \Pi(\mu, \cdot)$ is a coupling whose first marginal equals $\mu$, $\pi_2$ is the marginal on the second space, $\tau > 0$ modulates relaxation, and $D_\phi$ is a Csiszár divergence, instantiated as KL-divergence via $\phi(s) = s \log s - s + 1$. Mass-relaxed coupling confers robustness against measurement noise and sample-size heterogeneity typical in genomics and stochastic simulation.
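The following discrete toy sketch illustrates this cost on samples via scaling iterations: the first marginal is enforced exactly while the second is KL-relaxed with weight $\tau$. The entropic smoothing $\varepsilon$ and all numerical values are computational conveniences assumed here, not part of the NLWR formulation, which optimizes neural semi-dual objectives instead (Section 4).

```python
import numpy as np

def semi_relaxed_uot(x, y, a, b, tau=1.0, eps=0.1, iters=500):
    """Scaling iterations for min_pi <C, pi> + tau * KL(pi_2 || b), subject to pi_1 = a."""
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)    # squared-Euclidean cost matrix
    K = np.exp(-C / eps)                                   # eps: entropic smoothing (assumed here)
    u, v = np.ones(len(a)), np.ones(len(b))
    fi = tau / (tau + eps)                                 # damped update for the KL-relaxed side
    for _ in range(iters):
        v = (b / (K.T @ u)) ** fi                          # relaxed second marginal
        u = a / (K @ v)                                    # exact first marginal
    P = u[:, None] * K * v[None, :]                        # transport plan
    return P, (P * C).sum()

rng = np.random.default_rng(0)
x, y = rng.normal(size=(50, 2)), rng.normal(loc=1.0, size=(60, 2))
a, b = np.full(50, 1 / 50), np.full(60, 1 / 60)
P, cost = semi_relaxed_uot(x, y, a, b)
print(cost, np.abs(P.sum(1) - a).max())   # transport cost; first marginal matches a
```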
3. Neural Generative and Transport Map Parameterization
To operationalize regression over spaces of measures, NLWR models the barycenter (conditional mean measure) as the push-forward of a tractable latent distribution $\rho$ through a generative neural network:
- Generator $G_\theta : \mathbb{R}^{d_z} \to \mathbb{R}^{d}$, taking latent input $z \sim \rho$; architecture: 4-layer fully connected (width 256, ReLU, linear output).
- Transport/Conditional Plan Networks $T_i : \mathbb{R}^{d} \to \mathbb{R}^{d}$ (4 layers, width 196, ReLU); each $T_i$ parameterizes the optimal transport map from the barycenter to the observed $\mu_i$.
- Potential Functions $v_i : \mathbb{R}^{d} \to \mathbb{R}$, same architecture, encoding the dual potentials in the semi-dual UOT formulation.
At convergence, $G_{\theta\#}\rho \approx \bar\mu(t)$, and each $T_i$ approximates the UOT plan from the barycenter to $\mu_i$. This generative approach circumvents direct measure-space optimization, exploiting the expressivity of neural networks and enabling backpropagation through the UOT objectives.
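As a concrete rendering of these parameterizations, a minimal PyTorch sketch is given below: it builds the generator, transport, and potential networks with the stated depths and widths, and draws barycenter samples as push-forwards of latent samples. The standard-Gaussian latent, the dimensions, and the helper name `mlp` are illustrative assumptions.

```python
import torch
import torch.nn as nn

def mlp(d_in, d_out, width, depth=4):
    """Fully connected network: (depth - 1) hidden ReLU layers, linear output."""
    layers, d = [], d_in
    for _ in range(depth - 1):
        layers += [nn.Linear(d, width), nn.ReLU()]
        d = width
    layers.append(nn.Linear(d, d_out))
    return nn.Sequential(*layers)

d, d_z, n_timepoints = 50, 16, 5                          # illustrative dimensions
G = mlp(d_z, d, width=256)                                # generator G_theta
Ts = [mlp(d, d, width=196) for _ in range(n_timepoints)]  # transport maps T_i
vs = [mlp(d, 1, width=196) for _ in range(n_timepoints)]  # dual potentials v_i

z = torch.randn(1024, d_z)                                # latent samples z ~ rho (assumed Gaussian)
barycenter_samples = G(z)                                 # samples from the push-forward G_theta # rho
```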
4. Training Objectives and Optimization Algorithm
Three nested objectives align the model and transport plans:
- Marginal Semi-Dual UOT Objective: For each observed $\mu_i$, the UOT cost from the current barycenter $G_{\theta\#}\rho$ is expressed in semi-dual form,
$$\mathrm{UW}_2^2\big(G_{\theta\#}\rho,\ \mu_i\big) = \sup_{v_i}\ \mathbb{E}_{z\sim\rho}\big[v_i^{c}\big(G_\theta(z)\big)\big] - \tau\,\mathbb{E}_{y\sim\mu_i}\big[e^{-v_i(y)/\tau} - 1\big], \qquad v_i^{c}(x) = \inf_{y}\big[\|x - y\|^2 - v_i(y)\big],$$
leading to practical losses in which the $c$-transform infimum is replaced by the transport network $T_i$:
$$\mathcal{L}_i(v_i, T_i) = \mathbb{E}_{z\sim\rho}\Big[\big\|G_\theta(z) - T_i\big(G_\theta(z)\big)\big\|^2 - v_i\big(T_i(G_\theta(z))\big)\Big] - \tau\,\mathbb{E}_{y\sim\mu_i}\big[e^{-v_i(y)/\tau} - 1\big],$$
minimized over $T_i$ and maximized over $v_i$ on minibatches.
- Generator Fixed-Point Objective:
$$\mathcal{L}_G(\theta) = \mathbb{E}_{z\sim\rho}\Big\|\,G_\theta(z) - \sum_{i=1}^{n} w_i(t)\, T_i\big(G_{\theta^{\mathrm{old}}}(z)\big)\Big\|^2,$$
where $G_{\theta^{\mathrm{old}}}$ is the generator frozen at the previous iteration, so the updated generator pushes $\rho$ onto the weighted average of the current transport maps applied to the previous barycenter.
- (Optional) Pretraining: $G_\theta$ can be initialized as a VAE-normalizing flow on pooled data, with reconstruction and KL-divergence regularization.
The optimization alternates several gradient steps on the losses $\mathcal{L}_i$ for the transport plans and potentials $(T_i, v_i)$, then a step on $\mathcal{L}_G$ to update $G_\theta$. Complexity scales linearly with the number of observed distributions $n$; pretraining enhances convergence and mitigates mode collapse in multimodal settings.
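A condensed sketch of this alternation is given below, assuming the semi-dual and fixed-point losses reconstructed above; the toy data, kernel weights, network sizes, learning rates, and step counts are placeholders rather than values from the paper.

```python
import torch
import torch.nn as nn

def net(d_in, d_out, width=64):
    return nn.Sequential(nn.Linear(d_in, width), nn.ReLU(),
                         nn.Linear(width, width), nn.ReLU(),
                         nn.Linear(width, d_out))

d, d_z, tau = 2, 2, 1.0
mus = [torch.randn(500, d) + s for s in (0.0, 1.0, 2.0)]   # toy observed measures mu_i
w_t = torch.tensor([0.25, 0.50, 0.25])                     # kernel weights w_i(t) at query t

G = net(d_z, d)                                            # generator for the barycenter
Ts = [net(d, d) for _ in mus]                              # transport maps T_i
vs = [net(d, 1) for _ in mus]                              # dual potentials v_i
opt_G = torch.optim.Adam(G.parameters(), lr=1e-3)
opt_T = torch.optim.Adam([p for T in Ts for p in T.parameters()], lr=1e-3)
opt_v = torch.optim.Adam([p for v in vs for p in v.parameters()], lr=1e-3)

for outer in range(200):
    z = torch.randn(256, d_z)
    for _ in range(5):                                     # inner steps on (T_i, v_i)
        x = G(z).detach()                                  # barycenter samples, generator frozen
        # ascend L_i in the potentials v_i
        loss_v = 0.0
        for T, v, mu in zip(Ts, vs, mus):
            y = mu[torch.randint(len(mu), (256,))]
            Tx = T(x).detach()
            transport = ((x - Tx) ** 2).sum(1) - v(Tx).squeeze(-1)
            relax = tau * (torch.exp(-v(y).squeeze(-1) / tau) - 1.0)
            loss_v = loss_v - (transport.mean() - relax.mean())
        opt_v.zero_grad(); loss_v.backward(); opt_v.step()
        # descend the transport term in the maps T_i
        loss_T = 0.0
        for T, v in zip(Ts, vs):
            Tx = T(x)
            loss_T = loss_T + (((x - Tx) ** 2).sum(1) - v(Tx).squeeze(-1)).mean()
        opt_T.zero_grad(); loss_T.backward(); opt_T.step()
    # generator fixed-point step: match the weighted average of transported points
    with torch.no_grad():                                  # previous generator is frozen here
        target = sum(wi * T(G(z)) for wi, T in zip(w_t, Ts))
    loss_G = ((G(z) - target) ** 2).sum(1).mean()
    opt_G.zero_grad(); loss_G.backward(); opt_G.step()
```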
5. Theoretical Guarantees and Convergence Analysis
- Fixed-point monotonicity: For the update $\bar\mu^{k+1} = \big(\sum_{i} w_i(t)\, T_i^{k}\big)_{\#}\, \bar\mu^{k}$, the total Fréchet objective $F(\nu) = \sum_{i} w_i(t)\, \mathrm{UW}_2^2(\nu, \mu_i)$ is non-increasing: $F(\bar\mu^{k+1}) \le F(\bar\mu^{k})$; minimizers satisfy the fixed-point condition $\big(\sum_{i} w_i(t)\, T_i\big)_{\#}\, \bar\mu^{\star} = \bar\mu^{\star}$.
- Global Convergence: Strict global convergence guarantees are impeded by the KL-divergence term, which breaks weak continuity of the objective; empirically, the neural solver reaches good local minima with proper initialization.
- This suggests that while global minima are hard to certify theoretically, empirical performance is robust for well-posed biological tasks or regularized dynamical systems.
6. Empirical Evaluation and Applications
NLWR achieves strong performance in both single-cell genomics and generic high-dimensional inverse problems:
- Datasets:
- Embryoid (human ESC-to-EB differentiation; PCA-reduced expression profiles; 5 timepoints).
- Statefate (mouse HSPC differentiation, sampled at days 2/4/6).
- Reprogramming (mouse MEF-to-iPSC reprogramming, 39 timepoints).
- Protocols: leave-one-timepoint-out evaluation, predicting the held-out distribution; metrics: maximum mean discrepancy (MMD), Earth Mover's Distance (EMD, i.e., $W_1$), and the 2-Wasserstein distance $W_2$ (a minimal MMD sketch appears after this list).
- Baselines: MioFlow (dynamic neural ODEs + UOT), midpoint interpolation, nearest-neighbor error, and classical balanced OT.
- Findings:
- NLWR (both OT and UOT variants) outperforms all baselines on MMD, EMD, and $W_2$.
- UOT offers enhanced robustness in extrapolation and noisy regimes.
- VAE-NF pretraining is critical on multimodal datasets (e.g., COVID-19 lung atlas).
- Interpolated distributions trace smooth, biologically plausible cell transitions.
- Composed UOT transport maps recover cell trajectories better aligned with empirical data geometry.
- Downstream analyses: dynamic time warping and mixed-effects models applied to inferred trajectories uncover interpretable molecular branch markers (e.g., mesodermal/neuroectodermal signatures).
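As an evaluation aid, the sketch below gives a biased RBF-kernel MMD$^2$ estimator between a predicted and a held-out sample set; the kernel choice and median-heuristic bandwidth are assumptions, since the exact evaluation kernels are not specified here.

```python
import numpy as np

def mmd2_rbf(x, y, bandwidth=None):
    """Biased estimate of squared MMD between sample sets x (n, d) and y (m, d)."""
    def sq_dists(a, b):
        return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    dxx, dyy, dxy = sq_dists(x, x), sq_dists(y, y), sq_dists(x, y)
    if bandwidth is None:   # median heuristic on pooled squared distances (assumed choice)
        bandwidth = np.sqrt(np.median(np.concatenate([dxx.ravel(), dyy.ravel(), dxy.ravel()])) / 2)
    k = lambda d: np.exp(-d / (2 * bandwidth ** 2))
    return k(dxx).mean() + k(dyy).mean() - 2 * k(dxy).mean()

rng = np.random.default_rng(0)
pred, held_out = rng.normal(size=(300, 5)), rng.normal(loc=0.2, size=(300, 5))
print(mmd2_rbf(pred, held_out))
```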
7. Practical Considerations and Deployment
- Computational Complexity: Training scales linearly with the number of observed distributions $n$; GPU memory and OT solver efficiency are critical for large-scale genomics.
- Initialization: VAE-NF pretraining mitigates mode-collapse and expedites convergence.
- Robustness: The unbalanced tolerance (marginal-relaxation parameter $\tau$) should be tuned to the dataset: higher values recover classical OT, while lower values allow greater flexibility in handling count noise, dropout, and sampling variation.
- Extensibility: The framework is adaptable to other measure-valued regression problems, including spatiotemporal model-parameter reconstruction, high-dimensional image-to-image mappings, and nonparametric time-series analysis. Applications to gene regulation can systematically reveal genetic determinants of cellular differentiation trajectories.
NLWR thus presents a principled, effective fusion of nonparametric kernel smoothing, robust optimal transport, and generative neural modeling for distributional interpolation and regression in biology and beyond. Robustness to sampling variability, flexibility in map parameterization, and empirical accuracy on challenging biological data position NLWR as a reference model for neural distribution-on-distribution regression.