
DeepKENN: Fast W2 Distance Approximation

Updated 9 April 2026
  • The paper presents DeepKENN, a neural architecture that approximates the Wasserstein-2 distance using aggregated multi-layer CNN features.
  • The methodology aggregates intermediate feature maps weighted by learnable nonnegative parameters, cutting per-pair inference cost by roughly 5,000× relative to an exact OT solver.
  • Empirical evaluation on MNIST shows a 13% reduction in MSE over a naive baseline, demonstrating its improved metric learning performance.

Deep Kuratowski Embedding Neural Network (DeepKENN) is a neural architecture designed to learn a fast, differentiable surrogate of the Wasserstein-2 distance ($W_2$) for pairs of data distributions. Inspired by the Kuratowski embedding theorem, DeepKENN aggregates information across all intermediate feature maps of a convolutional neural network (CNN) using learnable, nonnegative weights, yielding a metric that approximates the true $W_2$ geometry at significantly reduced computational cost. The approach is positioned as a scalable replacement for computationally expensive optimal transport (OT) solvers in metric learning pipelines, particularly for high-throughput applications on datasets such as MNIST (He, 6 Apr 2026).

1. Mathematical Foundations

DeepKENN is motivated by the Kuratowski–Wojdysławski embedding theorem. Any bounded metric space $(M,d)$ embeds isometrically in the Banach space $\ell^\infty(M)$, and by Wojdysławski's refinement the embedding can be chosen so that its image is closed in a convex subset of that space:

$$\exists\,\iota\colon (M,d)\hookrightarrow\ell^\infty(M)\ \text{such that}\ \|\iota(x)-\iota(y)\|_{\infty}=d(x,y),\quad\forall\,x,y\in M.$$
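
For a finite metric space, the theorem can be checked directly. The sketch below uses the standard Kuratowski map $\iota(x)=d(x,\cdot)$, in which each point becomes the vector of its distances to every point. The section does not spell this map out, so treat it as an illustration of the theorem rather than part of DeepKENN. The sup-norm of $\iota(x)-\iota(y)$ recovers $d(x,y)$: the triangle inequality bounds every coordinate by $d(x,y)$, and the coordinate at $y$ attains that bound.

```python
import itertools
import math


def kuratowski_embed(points, dist):
    """Map each point x to the vector (d(x, m)) over all m in points, i.e. iota(x) = d(x, .)."""
    return [[dist(x, m) for m in points] for x in points]


def sup_norm_diff(u, v):
    """ell^infinity distance between two equal-length vectors."""
    return max(abs(a - b) for a, b in zip(u, v))


def max_distortion(points, dist):
    """Largest | ||iota(x) - iota(y)||_inf - d(x, y) | over all pairs (0 means isometric)."""
    emb = kuratowski_embed(points, dist)
    return max(
        abs(sup_norm_diff(emb[i], emb[j]) - dist(points[i], points[j]))
        for i, j in itertools.combinations(range(len(points)), 2)
    )


# Toy metric spaces: Euclidean points in the plane, and a discrete metric on labels.
plane = [(0.0, 0.0), (3.0, 0.0), (0.0, 4.0), (1.0, 1.0), (2.5, 3.5)]
print(max_distortion(plane, math.dist))
print(max_distortion(["a", "b", "c"], lambda x, y: 0.0 if x == y else 1.0))
```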

In DeepKENN, the goal is to learn a data-driven map that produces a surrogate distance $\widehat W_2(x,y)$ for any sample pair $(x,y)$. This surrogate is defined over the concatenated intermediate feature spaces of a CNN. Let $F_\vartheta$ denote a CNN with $L$ layers, and let $F^{(k)}_\vartheta(x)\in\mathbb{R}^{d_k}$ be the (flattened) feature map at layer $k$. DeepKENN's learned distance is then

$$\widehat W_2(x,y)=\Big(\sum_{k=1}^{L}\alpha_k\,\big\|F^{(k)}_\vartheta(x)-F^{(k)}_\vartheta(y)\big\|_2^2\Big)^{1/2},$$

where each $\alpha_k$ is parameterized via $\alpha_k=\operatorname{softplus}(\beta_k)=\log(1+e^{\beta_k})$, and all $\beta_k$ are trainable. The network is explicitly trained to approximate the true Wasserstein-2 distance $W_2$ computed by OT solvers:

$$\widehat W_2(x,y)\approx W_2(x,y).$$
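
Here is a minimal plain-Python sketch of the surrogate for a single pair. It assumes the square-root form $\widehat W_2(x,y)=\big(\sum_k\alpha_k\|F^{(k)}_\vartheta(x)-F^{(k)}_\vartheta(y)\|_2^2\big)^{1/2}$ with $\alpha_k=\operatorname{softplus}(\beta_k)$, and that each layer's features are already flattened into a list of floats. The function name `deepkenn_distance` and the argument `raw_weights` (the unconstrained $\beta_k$) are illustrative, not taken from the paper. A real implementation would operate on framework tensors.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z)); strictly positive for every real z."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def deepkenn_distance(feats_x, feats_y, raw_weights):
    """Surrogate W2 for one pair (illustrative sketch).

    feats_x, feats_y: per-layer flattened features [F^(1)(x), ..., F^(L)(x)].
    raw_weights: unconstrained parameters beta_k, mapped to alpha_k = softplus(beta_k).
    Returns sqrt(sum_k alpha_k * ||F^(k)(x) - F^(k)(y)||_2^2).
    """
    if not (len(feats_x) == len(feats_y) == len(raw_weights)):
        raise ValueError("need one feature vector per layer for each input and one weight per layer")
    total = 0.0
    for fx, fy, beta in zip(feats_x, feats_y, raw_weights):
        if len(fx) != len(fy):
            raise ValueError("layer feature dimensions differ")
        sq_dist = sum((a - b) ** 2 for a, b in zip(fx, fy))  # ||F^(k)(x) - F^(k)(y)||_2^2
        total += softplus(beta) * sq_dist
    return math.sqrt(total)


# Two toy "layers" of dimensions 2 and 1; beta_k = 0 gives alpha_k = log 2.
x_feats = [[1.0, 2.0], [0.5]]
y_feats = [[1.0, 0.0], [1.5]]
print(deepkenn_distance(x_feats, y_feats, [0.0, 0.0]))  # equals sqrt(5 * log 2)
```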

2. Network Architecture

All compared architectures share the same encoder backbone, a 5-layer CNN:

| Layer | Operation | Output shape | Flat dim |
|---|---|---|---|
| Conv1 | 1→8 channels, 5×5, ReLU + MaxPool(2) | (8, 14, 14) | 1568 |
| Conv2 | 8→16, 3×3, ReLU + MaxPool(2) | (16, 7, 7) | 784 |
| Conv3 | 16→32, 3×3, ReLU + MaxPool(2) | (32, 3, 3) | 288 |
| FC1 | 288→128, ReLU | (128) | 128 |
| FC2 | 128→64 | (64) | 64 |

The total flat feature dimension is $1568+784+288+128+64=2832$. DeepKENN aggregates the squared $\ell_2$ distances between the two inputs' feature vectors at each layer, weighted by the learnable nonnegative weights $\alpha_k$. To ensure parameter-count parity with the alternative models and permit a meaningful comparison, both the Naive and DeepKENN models append an extra FC(64→64) + Tanh head, bringing the total to 55,430 parameters.
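
The flat dimensions and parameter budget can be re-derived from the table. This sketch makes two assumptions that the text does not state: 'same' padding (2 for the 5×5 kernel, 1 for the 3×3 kernels), which the listed output shapes imply, and a bias term in every layer. Under these assumptions, the backbone plus the FC(64→64) head account for 55,424 of the 55,430 reported parameters. The text does not itemize the remaining six.

```python
def conv_pool_shape(c_out, h, w, pool=2):
    """Conv with 'same' padding (assumed) keeps h x w; MaxPool(pool) then floors each spatial dim."""
    return (c_out, h // pool, w // pool)


def conv_params(c_in, c_out, k):
    """k x k convolution: weights plus one bias per output channel (bias assumed)."""
    return c_out * c_in * k * k + c_out


def fc_params(n_in, n_out):
    """Fully connected layer: weight matrix plus bias vector."""
    return n_out * n_in + n_out


# Shapes after each conv block, starting from a 1 x 28 x 28 MNIST image.
shape = (1, 28, 28)
conv_shapes = []
for c_out in (8, 16, 32):
    shape = conv_pool_shape(c_out, shape[1], shape[2])
    conv_shapes.append(shape)

flat_dims = [c * h * w for c, h, w in conv_shapes] + [128, 64]  # plus FC1 and FC2 outputs
backbone_params = (conv_params(1, 8, 5) + conv_params(8, 16, 3) + conv_params(16, 32, 3)
                   + fc_params(288, 128) + fc_params(128, 64))
head_params = fc_params(64, 64)  # extra FC(64->64) + Tanh head; Tanh has no parameters

print(conv_shapes)                                     # [(8, 14, 14), (16, 7, 7), (32, 3, 3)]
print(flat_dims, sum(flat_dims))                       # [1568, 784, 288, 128, 64] 2832
print(backbone_params, backbone_params + head_params)  # 51264 55424
```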

3. Training Methodology

Training uses direct supervised regression on precomputed, exact $W_2$ distances between pairs of MNIST images. Each image is normalized to a probability measure, so its pixel intensities sum to 1. The dataset comprises 55,000 random image pairs: 1,000 for each of the 55 unordered digit pairs (same-digit pairs included). These are split 49,500/2,750/2,750 into train/validation/test sets. Ground-truth $W_2$ is computed with the exact OT solver `ot.emd2` from the POT library.
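
The data preparation can be sketched as follows. Everything here is an assumption layered on the description. The split is a seeded random shuffle. The digit pairs stand in for actual MNIST image pairs. The squared-Euclidean cost between pixel coordinates is one plausible ground cost; the text does not state which cost was used. With POT installed, the label for two normalized images `a` and `b` (converted to NumPy arrays) would then be `ot.emd2(a, b, M)`. This call returns the optimal transport cost, which for a squared-Euclidean `M` equals $W_2^2$, so its square root gives $W_2$.

```python
import itertools
import random


def to_probability(image):
    """Flatten a 2D grayscale image and rescale so its pixel intensities sum to 1."""
    flat = [float(p) for row in image for p in row]
    total = sum(flat)
    if total <= 0:
        raise ValueError("image has no mass to normalise")
    return [p / total for p in flat]


def squared_euclidean_cost(h, w):
    """Assumed ground cost ||u_i - u_j||^2 between pixel coordinates on an h x w grid."""
    coords = [(r, c) for r in range(h) for c in range(w)]
    return [[(r1 - r2) ** 2 + (c1 - c2) ** 2 for (r2, c2) in coords] for (r1, c1) in coords]


def split_digit_pairs(pairs_per_digit_pair=1000, seed=0):
    """55 unordered digit pairs (same-digit pairs included), shuffled, then split 90/5/5."""
    digit_pairs = list(itertools.combinations_with_replacement(range(10), 2))
    pairs = [dp for dp in digit_pairs for _ in range(pairs_per_digit_pair)]
    random.Random(seed).shuffle(pairs)
    n_train = len(pairs) * 9 // 10
    n_val = (len(pairs) - n_train) // 2
    return pairs[:n_train], pairs[n_train:n_train + n_val], pairs[n_train + n_val:]


train_pairs, val_pairs, test_pairs = split_digit_pairs()
print(len(train_pairs), len(val_pairs), len(test_pairs))  # 49500 2750 2750
```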

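The loss function minimized during training is the mean squared error (MSE) between the network output and the true $W_2$ over $N$ training pairs $(x_i,y_i)$:

$$\mathcal{L}=\frac{1}{N}\sum_{i=1}^{N}\Big(\widehat W_2(x_i,y_i)-W_2(x_i,y_i)\Big)^2.$$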

The layer weights $\alpha_k$ are kept nonnegative through the softplus parameterization; no explicit regularizers are applied. Optimization uses Adam (β₁ = 0.9, β₂ = 0.999) with a learning rate of …, batch size 256, and default weight initialization (e.g., Kaiming for convolutional layers). Training runs for up to 2,000 epochs, and the checkpoint with the lowest validation MSE is kept.
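
As a toy illustration of the softplus constraint and the MSE objective, the sketch below freezes the features and fits only the raw weights $\beta_k$, where $\alpha_k=\operatorname{softplus}(\beta_k)$ and the surrogate is assumed to be $\big(\sum_k\alpha_k s_k\big)^{1/2}$ with $s_k$ the squared layer-$k$ distance. It uses full-batch Adam with the stated β₁ = 0.9 and β₂ = 0.999. This is not the paper's training loop, which also trains the CNN. The learning rate, step count, and synthetic data are arbitrary choices for the example.

```python
import math
import random


def softplus(z):
    """Numerically stable log(1 + exp(z)); keeps alpha_k > 0."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def sigmoid(z):
    """Derivative of softplus, computed without overflow."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(sq_dists, beta):
    """W2_hat = sqrt(sum_k softplus(beta_k) * s_k), with s_k the squared layer-k distance."""
    return math.sqrt(sum(softplus(b) * s for b, s in zip(beta, sq_dists)))


def mse(data, beta):
    return sum((predict(s, beta) - t) ** 2 for s, t in data) / len(data)


def fit_layer_weights(data, n_layers, lr=0.02, steps=500, b1=0.9, b2=0.999, eps=1e-8):
    """Full-batch Adam on the raw weights beta_k only; lr and steps are illustrative."""
    beta, m, v = [0.0] * n_layers, [0.0] * n_layers, [0.0] * n_layers
    for t in range(1, steps + 1):
        grad = [0.0] * n_layers
        for s, target in data:
            pred = predict(s, beta)
            # dMSE/dbeta_k = mean over pairs of (pred - target) / pred * sigmoid(beta_k) * s_k
            coeff = (pred - target) / (pred * len(data))
            for k in range(n_layers):
                grad[k] += coeff * sigmoid(beta[k]) * s[k]
        for k in range(n_layers):
            m[k] = b1 * m[k] + (1 - b1) * grad[k]
            v[k] = b2 * v[k] + (1 - b2) * grad[k] ** 2
            m_hat, v_hat = m[k] / (1 - b1 ** t), v[k] / (1 - b2 ** t)
            beta[k] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return beta


# Synthetic targets from hypothetical "true" layer weights (not data from the paper).
rng = random.Random(0)
true_alpha = [0.5, 2.0, 1.0]
data = []
for _ in range(64):
    s = [rng.uniform(0.1, 1.0) for _ in true_alpha]
    data.append((s, math.sqrt(sum(a * x for a, x in zip(true_alpha, s)))))

before = mse(data, [0.0] * 3)
beta = fit_layer_weights(data, n_layers=3)
after = mse(data, beta)
print(f"MSE {before:.4f} -> {after:.6f}; learned alpha = {[round(softplus(b), 3) for b in beta]}")
```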

4. Empirical Evaluation and Computational Complexity

On the MNIST test set, DeepKENN achieves improved metric learning performance relative to a naive, single-layer baseline, while being outperformed by the ODE-KENN variant. Test results on 2,750 sample pairs are:

| Model | Test MSE (…) | Test MAE (…) | Relative MAE (%) |
|---|---|---|---|
| Naive | 4.67 | 5.249 | 1.63 |
| DeepKENN | 4.07 | 4.873 | 1.52 |
| ODE-KENN | 3.35 | 4.406 | 1.37 |

DeepKENN yields a 13% reduction in test MSE relative to Naive, and ODE-KENN reduces MSE by a further 18% relative to DeepKENN. ODE-KENN also shows the smallest generalization gap between training and validation MSE, suggesting improved robustness from the smoothness implicitly imposed by its ODE formulation. DeepKENN shows the largest gap, with Naive in between.
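
The quoted percentages can be recomputed from the table. Any common scale factor on the MSE column cancels in the ratio. Unrounded, the reductions are about 12.8% and 17.7%.

```python
# Test MSE values from the results table.
test_mse = {"Naive": 4.67, "DeepKENN": 4.07, "ODE-KENN": 3.35}


def relative_reduction(old, new):
    """Percentage decrease going from `old` to `new`."""
    return 100.0 * (old - new) / old


for base, model in [("Naive", "DeepKENN"), ("DeepKENN", "ODE-KENN")]:
    pct = relative_reduction(test_mse[base], test_mse[model])
    print(f"{model} vs {base}: {pct:.1f}% lower MSE")
# DeepKENN vs Naive: 12.8% lower MSE
# ODE-KENN vs DeepKENN: 17.7% lower MSE
```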

Computationally, forward inference of $\widehat W_2$ with DeepKENN or ODE-KENN takes approximately 1–2 μs per pair on GPU hardware. This is roughly a 5,000× speed-up over the ≈5 ms per pair required by the OT solver for exact $W_2$. Training DeepKENN with standard hyperparameters takes about 1 hour for 2,000 epochs on a single GPU.
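
The speed-up follows from the per-pair timings. The ≈5,000× figure corresponds to the 1 μs end of the range; at 2 μs the ratio is about 2,500×. For scale, the sketch also converts the timings into the time needed to evaluate all 55,000 pairs of the MNIST dataset once, using only numbers quoted in this article.

```python
OT_SECONDS_PER_PAIR = 5e-3                 # ~5 ms per pair, exact OT solver
SURROGATE_SECONDS_PER_PAIR = (1e-6, 2e-6)  # ~1-2 us per pair, DeepKENN / ODE-KENN on GPU
N_PAIRS = 55_000                           # size of the MNIST pair dataset

speedups = [round(OT_SECONDS_PER_PAIR / t) for t in SURROGATE_SECONDS_PER_PAIR]
print(speedups)  # [5000, 2500]

# Time to evaluate every pair in the dataset once.
print(round(N_PAIRS * OT_SECONDS_PER_PAIR, 1))                       # 275.0 (seconds, OT)
print([round(N_PAIRS * t, 3) for t in SURROGATE_SECONDS_PER_PAIR])  # [0.055, 0.11] (seconds)
```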

5. Inference, Deployment, and Properties

The $\widehat W_2$ metric is evaluated post-training in three steps: extract the feature vectors at each layer for both inputs, compute the layer-wise squared $\ell_2$ distances between them, and aggregate these with the trained nonnegative weights $\alpha_k$. Efficient batching is obtained by stacking pairs along the batch dimension.

Post-training, the backbone and weights are frozen, permitting export as a single, reusable forward function.
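
Here is a minimal sketch of such an exported forward function, written in plain Python so it is self-contained. `make_frozen_distance` and `toy_extractor` are hypothetical names. The toy extractor, which uses the raw input as one "layer" and its sum as another, merely stands in for the trained CNN. The weights are snapshotted into a tuple when the function is built. Both members of every pair are stacked into one batch before feature extraction, mirroring the batching described above.

```python
import math


def make_frozen_distance(feature_extractor, alphas):
    """Bundle a (frozen) feature extractor and nonnegative layer weights into one function.

    feature_extractor(batch) must return, for each input in the batch, a list of
    per-layer flattened feature vectors.
    """
    alphas = tuple(float(a) for a in alphas)  # snapshot: the weights can no longer change
    if any(a < 0 for a in alphas):
        raise ValueError("layer weights must be nonnegative")

    def distance_batch(xs, ys):
        if len(xs) != len(ys):
            raise ValueError("xs and ys must contain the same number of inputs")
        n = len(xs)
        # Stack both members of every pair along the batch dimension: one extractor call.
        feats = feature_extractor(list(xs) + list(ys))
        out = []
        for fx, fy in zip(feats[:n], feats[n:]):
            total = sum(a * sum((u - v) ** 2 for u, v in zip(lx, ly))
                        for a, lx, ly in zip(alphas, fx, fy))
            out.append(math.sqrt(total))
        return out

    return distance_batch


def toy_extractor(batch):
    """Hypothetical two-'layer' extractor standing in for the trained CNN."""
    return [[list(x), [sum(x)]] for x in batch]


dist = make_frozen_distance(toy_extractor, alphas=[1.0, 0.5])
print(dist([[0.0, 0.0], [1.0, 2.0]], [[3.0, 4.0], [1.0, 2.0]]))  # first entry is sqrt(25 + 0.5 * 49)
```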

As designed, $\widehat W_2$ satisfies nonnegativity, symmetry, and the triangle inequality by construction, since it is a weighted $\ell_2$ norm of feature differences in the direct product of the intermediate feature spaces. Positive definiteness holds generically but is not imposed as a strict property. Because the layer-wise features are not explicitly constrained to be injective, two distinct inputs with identical features at every layer would be assigned distance zero.
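
These properties can be checked numerically, assuming the square-root form $\widehat W_2(x,y)=\big(\sum_k\alpha_k\|F^{(k)}_\vartheta(x)-F^{(k)}_\vartheta(y)\|_2^2\big)^{1/2}$. Write $\Phi(x)$ for the concatenation of the $\sqrt{\alpha_k}$-scaled layer features. Then $\widehat W_2(x,y)=\|\Phi(x)-\Phi(y)\|_2$, which yields nonnegativity, symmetry, and the triangle inequality. The random features and the one-layer ReLU extractor below are hypothetical stand-ins for the CNN. The ReLU case shows how two different inputs can receive distance zero.

```python
import itertools
import math
import random


def surrogate(fx, fy, alphas):
    """sqrt(sum_k alpha_k * ||F^(k)(x) - F^(k)(y)||_2^2) from per-layer feature lists."""
    return math.sqrt(sum(a * sum((u - v) ** 2 for u, v in zip(lx, ly))
                         for a, lx, ly in zip(alphas, fx, fy)))


def scaled_concat(f, alphas):
    """Phi(x): concatenate sqrt(alpha_k) * F^(k)(x) over layers."""
    return [math.sqrt(a) * u for a, layer in zip(alphas, f) for u in layer]


rng = random.Random(1)
alphas = [0.3, 1.7, 0.9]  # positive, as softplus would produce
feats = [[[rng.gauss(0.0, 1.0) for _ in range(d)] for d in (4, 3, 2)] for _ in range(15)]

# 1) The surrogate is a Euclidean distance between scaled, concatenated features.
gap = abs(surrogate(feats[0], feats[1], alphas)
          - math.dist(scaled_concat(feats[0], alphas), scaled_concat(feats[1], alphas)))

# 2) Hence the triangle inequality holds (up to rounding) for every ordered triple.
worst = max(surrogate(a, c, alphas) - surrogate(a, b, alphas) - surrogate(b, c, alphas)
            for a, b, c in itertools.permutations(feats, 3))


# 3) Positive definiteness can fail: a ReLU layer maps distinct inputs to equal features.
def relu_features(x):
    """Hypothetical one-layer extractor: elementwise ReLU."""
    return [[max(0.0, u) for u in x]]


collision = surrogate(relu_features([-1.0, 2.0]), relu_features([-5.0, 2.0]), [1.0])
print(gap < 1e-12, worst <= 1e-12, collision)  # True True 0.0
```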

6. Significance and Applications

DeepKENN provides an end-to-end trainable mechanism for approximating Wasserstein-2 distances through deep feature aggregation across CNN layers. By removing the computational bottleneck of exact OT solvers, it enables scalable deployment in pipelines that depend on repeated pairwise metric computation, such as clustering, retrieval, or Wasserstein-based evaluation of generative models. The Kuratowski embedding theorem motivates the design: in principle, the learned metric can preserve much of the geometric structure of $W_2$. Its architectural simplicity and computational efficiency also make it well suited to modern GPU-accelerated machine learning workflows (He, 6 Apr 2026).
