
Neural Local Wasserstein Regression

  • The paper introduces NLWR, a framework that integrates local Fréchet regression with unbalanced optimal transport and neural generative maps to accurately interpolate complex distributions.
  • It employs neural parameterizations to model barycenters as push-forwards from latent distributions, effectively capturing time-evolving phenomena in high-dimensional data.
  • Empirical evaluations in genomics and inverse problems demonstrate that NLWR outperforms baselines using metrics like MMD, EMD, and W2, ensuring robustness under noise.

Neural Local Wasserstein Regression (NLWR) designates a class of nonparametric statistical learning frameworks that deploy neural network parameterizations within kernel-localized regression schemes in Wasserstein space. NLWR primarily addresses estimation, interpolation, and trajectory recovery tasks where either covariates, responses, or both are probability measures—commonly arising in high-dimensional systems biology, distributional inverse problems, and uncertainty quantification. The defining innovation is the synthesis of local Fréchet regression principles, unbalanced optimal transport metrics, and generative neural network maps, enabling accurate, robust interpolation of complex time-evolving distributions such as single-cell RNA-seq profiles or stochastic dynamical model parameters.

1. Problem Formulation and Local Fréchet Regression

The canonical NLWR setting considers measures $\{\nu_i\}_{i=1}^N$ observed at covariate times $\{t_i\}$ (e.g., biological times or model states). The estimation target is a “conditional mean” or barycenter measure $\nu(t)$ at any query time $t$, formalized as a local Fréchet regression problem in Wasserstein space:

$$\nu(t) = \arg\min_{\mu \in \mathcal{P}_2(\mathbb{R}^d)} V_t(\mu) = \arg\min_{\mu} \sum_{i=1}^N \alpha_i(t)\, W_{2,ub}^2(\mu, \nu_i),$$

where $W_{2,ub}^2$ is the squared unbalanced Wasserstein-2 distance, used to account for potential mass variation and to resist outliers, and $\alpha_i(t)$ are kernel-based weights

$$\alpha_i(t) = \frac{s_i(t,h)}{\sum_{j=1}^N s_j(t,h)}, \qquad s_i(t,h) = \frac{1}{N} K_h(t_i - t)\,\bigl[\hat{\mu}_2 - \hat{\mu}_1 (t_i-t)\bigr],$$

with $K_h$ a bandwidth-$h$ kernel (e.g., Gaussian). This structure localizes estimation to timepoints near $t$ and is robustified using the local moments of Petersen–Müller (2019).
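
A minimal numerical sketch of these weights follows, assuming the local moments $\hat{\mu}_j = \frac{1}{N}\sum_i K_h(t_i - t)(t_i - t)^j$ of standard local linear Fréchet regression; function and variable names are illustrative.

```python
# A minimal sketch of the local Frechet weights alpha_i(t); the local moments
# mu_hat_j follow the standard local-linear construction of Petersen-Mueller
# (2019). Function and variable names are illustrative.
import numpy as np

def frechet_weights(t_obs, t_query, h):
    """Kernel-localized weights alpha_i(t_query) for observation times t_obs."""
    N = len(t_obs)
    d = t_obs - t_query                                          # t_i - t
    K = np.exp(-0.5 * (d / h) ** 2) / (h * np.sqrt(2 * np.pi))   # Gaussian K_h
    mu1 = np.mean(K * d)                                         # local moment mu_hat_1
    mu2 = np.mean(K * d ** 2)                                    # local moment mu_hat_2
    s = K * (mu2 - mu1 * d) / N                                  # s_i(t, h)
    return s / s.sum()                                           # alpha_i(t)

# Example: weights at query time t = 2.5 for observations at days 0..5
alpha = frechet_weights(np.arange(6, dtype=float), 2.5, h=1.0)
```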

2. Unbalanced Optimal Transport Distance

NLWR leverages the unbalanced optimal transport (UOT) metric for measure comparison. Classical optimal transport demands strict mass conservation between coupled measures; UOT introduces mass relaxation via penalized marginal divergences. For measures $\mu, \nu$:

$$W_{2,ub}^2(\mu, \nu) = \inf_{\pi:\, \pi_1 = \mu} \int \frac{1}{2}\, \| x - y \|^2\, d\pi(x,y) + \tau\, D_{\psi}(\pi_2 \,\Vert\, \nu),$$

where $\pi$ is a coupling, $\pi_2$ is its marginal on the second space, $\tau > 0$ modulates the relaxation, and $D_\psi(p \Vert q)$ is a Csiszár divergence, instantiated as the KL divergence via $\psi(t) = t\log t - t + 1$. The mass-relaxed coupling confers robustness against measurement noise and the sample-size heterogeneity typical of genomics and stochastic simulation.
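
The sketch below, under the assumption of discrete measures on a shared support, illustrates the generalized KL penalty $D_\psi(\pi_2 \Vert \nu)$ and the scaled conjugate $\psi_\tau^*(s) = \tau(e^{s/\tau} - 1)$ that reappears in the semi-dual losses of Section 4; names are illustrative.

```python
# A small numerical sketch, assuming discrete measures on a shared support:
# the generalized KL penalty D_psi(pi_2 || nu) with psi(t) = t*log(t) - t + 1,
# and the scaled conjugate psi_tau^*(s) = tau * (exp(s / tau) - 1) used in the
# semi-dual losses. Names are illustrative.
import numpy as np

def kl_penalty(p, q):
    """Csiszar divergence D_psi(p || q) for unnormalized nonnegative vectors."""
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask])) - p.sum() + q.sum()

def psi_tau_star(s, tau):
    """Convex conjugate of tau * psi, evaluated elementwise."""
    return tau * np.expm1(s / tau)

# The relaxed marginal may carry a different total mass than the target nu;
# the penalty stays finite instead of the coupling being infeasible.
pi_2 = np.array([0.2, 0.5, 0.4])   # second marginal of a coupling (mass 1.1)
nu   = np.array([0.3, 0.3, 0.4])   # observed measure (mass 1.0)
penalty = kl_penalty(pi_2, nu)
```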

3. Neural Generative and Transport Map Parameterization

To operationalize regression over spaces of measures, NLWR models the barycenter (conditional mean measure) as the push-forward of a tractable latent distribution through a generative neural network:

  • Generator $G_\xi: \mathcal{Z} \to \mathbb{R}^d$, where $\mathcal{Z}$ is the latent space and $z \sim \mathcal{N}(0, I)$; architecture: 4-layer fully connected network (width 256, ReLU activations, linear output).
  • Transport/Conditional Plan Networks $T_{\theta_i}: \mathbb{R}^d \to \mathbb{R}^d$ (4 layers, width 196, ReLU), parameterizing the optimal transport map from the barycenter to the observed $\nu_i$.
  • Potential Functions $v_{\omega_i}: \mathbb{R}^d \to \mathbb{R}$, same architecture, encoding the dual potentials in the semi-dual UOT formulation.

At convergence, $(G_\xi)_\# \mathcal{N}(0, I) \approx \nu(t)$, and each $T_{\theta_i}$ approximates the UOT plan from the barycenter to $\nu_i$. This generative approach circumvents direct measure-space optimization, exploiting the expressivity of neural networks and enabling backpropagation through the UOT objectives.
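
A minimal PyTorch sketch of the three parameterizations is given below; the widths and depths follow the description above, while the latent dimension, the number of observed distributions, and the helper names are illustrative assumptions.

```python
# A minimal PyTorch sketch of the generator, transport maps, and potentials;
# widths and depths follow the text, everything else is an assumption.
import torch
import torch.nn as nn

def mlp(in_dim, out_dim, width, depth=4):
    """Fully connected ReLU network with a linear output layer."""
    layers, d = [], in_dim
    for _ in range(depth - 1):
        layers += [nn.Linear(d, width), nn.ReLU()]
        d = width
    layers.append(nn.Linear(d, out_dim))
    return nn.Sequential(*layers)

dim, z_dim, N = 20, 20, 5                          # e.g. 20 PCs, 5 timepoints
G = mlp(z_dim, dim, width=256)                     # generator  G_xi : Z -> R^d
T = [mlp(dim, dim, width=196) for _ in range(N)]   # transport maps  T_theta_i
v = [mlp(dim, 1, width=196) for _ in range(N)]     # dual potentials v_omega_i

z = torch.randn(128, z_dim)                        # z ~ N(0, I)
x = G(z)                                           # samples from the barycenter estimate
```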

4. Training Objectives and Optimization Algorithm

Three nested objectives align the model and transport plans:

  • Marginal Semi-Dual UOT Objective: for each $i$,

$$\max_{v_{\omega_i}} \min_{T_{\theta_i}} \left[ \mathbb{E}_{x \sim \mu} \left\{\tfrac{1}{2}\|x - T_{\theta_i}(x)\|^2 - v_{\omega_i}(T_{\theta_i}(x))\right\} - \mathbb{E}_{y \sim \nu_i}\, \psi_\tau^*\bigl(-v_{\omega_i}(y)\bigr) \right],$$

leading to the practical losses

$$\mathcal{L}_T = \frac{1}{N} \sum_i \mathbb{E}_{x \sim G_\#\mu_z} \left[\tfrac{1}{2}\|x - T_{\theta_i}(x)\|^2 - v_{\omega_i}(T_{\theta_i}(x))\right],$$

$$\mathcal{L}_v = \frac{1}{N} \sum_i \left[\mathbb{E}_{x}\, v_{\omega_i}(T_{\theta_i}(x)) + \mathbb{E}_{y \sim \nu_i}\, \psi_\tau^*\bigl(-v_{\omega_i}(y)\bigr)\right].$$

  • Generator Fixed-Point Objective:

$$\mathcal{L}_G = \mathbb{E}_{z \sim \mu_z} \left\|G_\xi(z) - \sum_{i=1}^N \alpha_i\, T_{\theta_i}\bigl(G_{\xi_0}(z)\bigr)\right\|^2,$$

where $\xi_0$ denotes the generator parameters frozen at the previous iteration.

(Optional) Pretraining: $G_\xi$ can be initialized as a VAE–normalizing-flow model on the pooled data, with reconstruction and KL-divergence regularization.

The optimization alternates several steps on $\mathcal{L}_T$ and $\mathcal{L}_v$ for the transport plans and potentials with a step on $\mathcal{L}_G$ to update $G_\xi$. Complexity scales linearly with the number of observed distributions $N$; pretraining accelerates convergence and mitigates mode collapse in multimodal settings.
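
The sketch below condenses one outer iteration of this alternation in PyTorch, reusing the networks and weights from the earlier sketches; the optimizer settings, batch size, number of inner steps, and $\tau$ are illustrative assumptions rather than values from the paper.

```python
# A condensed sketch of one outer NLWR iteration, reusing G, T, v, and alpha
# from the earlier sketches; hyperparameters here are illustrative. In
# practice the optimizers would persist across outer iterations.
import copy
import torch

def psi_tau_star(s, tau):                          # conjugate of tau * KL
    return tau * torch.expm1(s / tau)

def nlwr_outer_step(G, T, v, nus, alpha, tau=1.0, inner_steps=10, batch=256, z_dim=20):
    """nus: list of (n_i, d) tensors of samples from each observed nu_i."""
    opt_T = torch.optim.Adam([p for Ti in T for p in Ti.parameters()], lr=1e-4)
    opt_v = torch.optim.Adam([p for vi in v for p in vi.parameters()], lr=1e-4)
    opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)

    # (i) several steps on L_T and L_v for the transport maps and potentials
    for _ in range(inner_steps):
        z = torch.randn(batch, z_dim)
        x = G(z).detach()                          # samples from the current barycenter
        L_T = sum(0.5 * ((x - Ti(x)) ** 2).sum(1).mean() - vi(Ti(x)).mean()
                  for Ti, vi in zip(T, v)) / len(T)
        opt_T.zero_grad(); L_T.backward(); opt_T.step()

        L_v = sum(vi(Ti(x).detach()).mean() + psi_tau_star(-vi(yi), tau).mean()
                  for Ti, vi, yi in zip(T, v, nus)) / len(T)
        opt_v.zero_grad(); L_v.backward(); opt_v.step()

    # (ii) one fixed-point step on L_G, with the generator frozen on the right
    G0 = copy.deepcopy(G)
    z = torch.randn(batch, z_dim)
    target = sum(a * Ti(G0(z)) for a, Ti in zip(alpha, T)).detach()
    L_G = ((G(z) - target) ** 2).sum(1).mean()
    opt_G.zero_grad(); L_G.backward(); opt_G.step()
```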

5. Theoretical Guarantees and Convergence Analysis

  • Fixed-point monotonicity: for the update $\mu_{n+1} = \bar{T}(\mu_n)$ (the map $\bar{T}$ is written out below), the total Fréchet objective $V(\mu) = \sum_i \alpha_i\, W_{2,ub}^2(\mu, \nu_i)$ is non-increasing, $V(\mu_{n+1}) \leq V(\mu_n)$, and minimizers satisfy $\bar{T}(\mu^*) = \mu^*$.
  • Global convergence: strict global convergence is impeded by the KL-divergence term, which breaks weak continuity; empirically, the neural solver reaches local minima under proper initialization.
  • This suggests that while global minima are hard to certify theoretically, empirical performance is robust for well-posed biological tasks and regularized dynamical systems.
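
Concretely, writing $T_i^{\mu}$ for the UOT map from $\mu$ to $\nu_i$ (notation assumed here for exposition, consistent with the generator objective above), the fixed-point update pushes the current iterate through the kernel-weighted average of these maps:

$$\bar{T}(\mu) = \Bigl(\sum_{i=1}^{N} \alpha_i\, T_i^{\mu}\Bigr)_{\#}\mu, \qquad \mu_{n+1} = \bar{T}(\mu_n).$$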

6. Empirical Evaluation and Applications

NLWR achieves strong performance in both single-cell genomics and generic high-dimensional inverse problems:

  • Datasets:
    • Embryoid (human ESC to EB, $d=20$ PCs, $\sim$16,800 cells, 5 timepoints).
    • Statefate (mouse HSPC differentiation, days 2/4/6, $\sim$130,000 cells).
    • Reprogramming (mouse MEF to iPSC, 39 timepoints, $\sim$259,000 cells).
  • Protocols: leave-one-timepoint-out prediction of the held-out distribution; metrics: maximum mean discrepancy (MMD), Earth Mover’s Distance (EMD, i.e. $W_1$), and $W_2$ (a minimal MMD sketch follows this list).
  • Baselines: MioFlow (dynamic neural ODEs + UOT), midpoint interpolation, nearest-neighbor error, classical OT ($\tau \to \infty$).
  • Findings:
    • NLWR (both OT and UOT variants) outperforms all baselines on MMD, EMD, and $W_2$.
    • The UOT variant offers enhanced robustness in extrapolation and noisy regimes.
    • VAE-NF pretraining is critical for multimodal datasets (e.g., the COVID-19 lung atlas).
    • Interpolated distributions trace smooth, biologically plausible cell transitions.
    • Composed UOT transport maps recover cell trajectories better aligned with empirical data geometry.
    • Downstream analyses: dynamic time warping and mixed-effects models applied to inferred trajectories uncover interpretable molecular branch markers (e.g., mesodermal/neuroectodermal signatures).
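
As referenced in the protocol above, a minimal sketch of the Gaussian-kernel MMD used to score a predicted distribution against the held-out one; the bandwidth and the biased estimator are illustrative assumptions.

```python
# A minimal sketch of a (biased) Gaussian-kernel MMD^2 estimate between a
# predicted sample and the held-out sample; bandwidth sigma is an assumption.
import numpy as np

def mmd2_gaussian(X, Y, sigma=1.0):
    """Biased MMD^2 estimate between samples X (n x d) and Y (m x d)."""
    def k(A, B):
        d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2 * sigma ** 2))
    return k(X, X).mean() + k(Y, Y).mean() - 2 * k(X, Y).mean()

# Example: score a prediction for a left-out timepoint against observed cells
rng = np.random.default_rng(0)
pred, held_out = rng.normal(size=(500, 20)), rng.normal(size=(400, 20))
score = mmd2_gaussian(pred, held_out)
```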

7. Practical Considerations and Deployment

  • Computational Complexity: Training scales linearly with $N$; GPU memory and OT-solver efficiency are critical for large-scale genomics.
  • Initialization: VAE-NF pretraining mitigates mode collapse and expedites convergence.
  • Robustness: The unbalanced-relaxation parameter $\tau$ should be tuned to the character of each dataset: higher values recover classical OT, while lower values allow greater flexibility in handling count noise, dropout, and sampling variation.
  • Extensibility: The framework is adaptable to other measure-valued regression problems, such as spatiotemporal model parameter reconstruction, high-dimensional image-to-image mappings, and nonparametric time-series analysis. Applications to gene regulation can systematically reveal genetic determinants of cellular differentiation trajectories.

NLWR thus presents a principled, effective fusion of nonparametric kernel smoothing, robust optimal transport, and generative neural modeling for distributional interpolation and regression in biology and beyond. Robustness to sampling variability, flexibility in map parameterization, and empirical accuracy on challenging biological data position NLWR as a reference model for neural distribution-on-distribution regression.
