DeepKENN: Fast W2 Distance Approximation
- The paper presents DeepKENN, a neural architecture that approximates the Wasserstein-2 distance using aggregated multi-layer CNN features.
- The methodology integrates intermediate feature maps weighted by learnable nonnegative parameters, drastically reducing computational cost compared to OT solvers.
- Empirical evaluation on MNIST shows a 13% reduction in test MSE relative to a naive single-layer baseline.
Deep Kuratowski Embedding Neural Network (DeepKENN) is a neural architecture designed to learn a fast, differentiable surrogate of the Wasserstein-2 distance ($W_2$) between pairs of data distributions. Inspired by the Kuratowski embedding theorem, DeepKENN aggregates information across all intermediate feature maps of a convolutional neural network (CNN) using learnable, nonnegative weights, yielding a metric that approximates the true $W_2$ geometry at a fraction of the computational cost. The approach is positioned as a scalable replacement for computationally expensive optimal transport (OT) solvers in metric learning pipelines, particularly for high-throughput applications on datasets such as MNIST (He, 6 Apr 2026).
1. Mathematical Foundations
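DeepKENN is motivated by the Kuratowski–Wojdysławski embedding theorem, which states that every metric space is isometric to a closed subset of a convex subset of a Banach space. For a bounded metric space $(X, d)$, the Kuratowski map into the Banach space $\ell^\infty(X)$ of bounded real-valued functions on $X$ (with the sup norm) realizes the isometry explicitly:
$$\Phi(x) = d(x, \cdot), \qquad \|\Phi(x) - \Phi(y)\|_\infty = \sup_{z \in X} |d(x, z) - d(y, z)| = d(x, y).$$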
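As a concrete check of the isometry, the sketch below applies the Kuratowski map $x \mapsto d(x, \cdot)$ to a finite metric space, where $\ell^\infty(X)$ reduces to $\mathbb{R}^{|X|}$ with the max norm. The four grid points and the Manhattan metric are arbitrary illustrative choices, and the helper names are hypothetical.

```python
from itertools import product


def kuratowski_embed(points, dist):
    """Map each point x to the coordinate list (d(x, z) for z in points).

    For a finite space X, l_inf(X) is just R^|X| with the max norm.
    """
    return {x: [dist(x, z) for z in points] for x in points}


def sup_norm_distance(u, v):
    """l_inf distance between two equal-length coordinate lists."""
    return max(abs(a - b) for a, b in zip(u, v))


def manhattan(p, q):
    """l1 metric on the integer grid."""
    return abs(p[0] - q[0]) + abs(p[1] - q[1])


# Toy metric space: four grid points under the Manhattan metric.
points = [(0, 0), (1, 2), (3, 1), (2, 2)]
phi = kuratowski_embed(points, manhattan)

# The sup over z is attained at z = y, so embedded distances equal d(x, y).
for x, y in product(points, repeat=2):
    assert sup_norm_distance(phi[x], phi[y]) == manhattan(x, y)
print("Kuratowski map is an isometry on this toy space")
```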
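In DeepKENN, the goal is to learn a data-driven map that produces, for any sample pair $(x, y)$, a surrogate distance $d_\theta(x, y)$. This surrogate is defined over the concatenated intermediate feature spaces of a CNN. Letting $f$ denote a CNN with $L$ layers and $f_\ell(x)$ the flattened feature map at layer $\ell$, DeepKENN's learned distance is
$$d_\theta(x, y) = \Big( \sum_{\ell=1}^{L} \alpha_\ell \, \big\| f_\ell(x) - f_\ell(y) \big\|_2^2 \Big)^{1/2},$$
where each weight $\alpha_\ell \ge 0$ is parameterized as $\alpha_\ell = \operatorname{softplus}(\beta_\ell) = \log(1 + e^{\beta_\ell})$, and all $\beta_\ell$ are trainable alongside the CNN parameters. The network is explicitly trained to approximate the true Wasserstein-2 distance $W_2(\mu_x, \mu_y)$ computed by OT solvers, where $\mu_x$ and $\mu_y$ are the probability measures obtained by normalizing the pixel intensities of $x$ and $y$:
$$d_\theta(x, y) \approx W_2(\mu_x, \mu_y).$$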
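The learned distance can be sketched in plain Python on precomputed per-layer feature vectors. This is a minimal sketch under stated assumptions: the features are made-up lists standing in for real CNN activations, and `softplus`, `deepkenn_distance`, and `betas` are illustrative names, not the paper's code.

```python
import math


def softplus(b):
    """Numerically stable log(1 + exp(b)); positive for every real b."""
    return max(b, 0.0) + math.log1p(math.exp(-abs(b)))


def deepkenn_distance(feats_x, feats_y, betas):
    """sqrt(sum_l softplus(beta_l) * ||f_l(x) - f_l(y)||_2^2).

    feats_x, feats_y: one flattened feature vector (list of floats) per layer.
    betas: one unconstrained parameter per layer; alpha_l = softplus(beta_l).
    """
    if not (len(feats_x) == len(feats_y) == len(betas)):
        raise ValueError("need one feature vector and one beta per layer")
    total = 0.0
    for fx, fy, beta in zip(feats_x, feats_y, betas):
        squared = sum((a - b) ** 2 for a, b in zip(fx, fy))
        total += softplus(beta) * squared
    return math.sqrt(total)


# Hypothetical two-layer features for three inputs (made-up numbers).
x = [[1.0, 0.0], [2.0]]
y = [[0.0, 0.0], [0.0]]
z = [[0.5, 1.0], [1.0]]
betas = [0.0, 0.0]  # alpha_l = softplus(0) = ln 2 for both layers

print(deepkenn_distance(x, y, betas))  # equals sqrt(5 * ln 2)
```

Working with unconstrained $\beta_\ell$ lets the optimizer move freely, while the softplus keeps every $\alpha_\ell$ strictly positive.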
2. Network Architecture
DeepKENN uses an encoder backbone shared by all compared architectures, a 5-layer CNN operating on 28×28 MNIST inputs:
| Layer | Operation | Output shape | Flat Dim |
|---|---|---|---|
| Conv1 | 1→8 channels, 5×5, ReLU + MaxPool(2) | (8,14,14) | 1568 |
| Conv2 | 8→16, 3×3, ReLU + MaxPool(2) | (16,7,7) | 784 |
| Conv3 | 16→32, 3×3, ReLU + MaxPool(2) | (32,3,3) | 288 |
| FC1 | 288→128, ReLU | (128) | 128 |
| FC2 | 128→64 | (64) | 64 |
The total flattened feature dimension is $1568 + 784 + 288 + 128 + 64 = 2832$. DeepKENN aggregates the squared $\ell_2$ distances between the two inputs' feature vectors at each layer, weighted by the learnable nonnegative weights $\alpha_\ell$. To keep parameter counts matched across the compared models and permit a fair comparison, both the Naive and DeepKENN models append an extra FC(64→64) + Tanh head, bringing the total to 55,430 parameters.
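The shapes in the table above can be reproduced with a little arithmetic. The sketch assumes stride-1 "same" padding (kernel // 2) and floor-mode 2×2 max-pooling, which is what the listed output shapes imply, and assumes every convolution and FC layer carries a bias; these are assumptions, not details confirmed by the source.

```python
def conv_pool_side(side, kernel, pool=2):
    """Spatial size after a stride-1 conv with padding kernel // 2, then a
    floor-mode max-pool of size `pool` (the padding is an assumption)."""
    conv_side = side + 2 * (kernel // 2) - kernel + 1  # unchanged for odd kernels
    return conv_side // pool


def conv_params(c_in, c_out, k):
    return c_out * c_in * k * k + c_out  # weights + biases (biases assumed)


def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases


side = 28  # MNIST inputs are 28x28 with one channel
flat_dims, backbone_params = [], 0
for c_in, c_out, k in [(1, 8, 5), (8, 16, 3), (16, 32, 3)]:  # Conv1-Conv3
    side = conv_pool_side(side, k)
    flat_dims.append(c_out * side * side)
    backbone_params += conv_params(c_in, c_out, k)

for n_out in [128, 64]:  # FC1, FC2; each consumes the previous flat output
    backbone_params += linear_params(flat_dims[-1], n_out)
    flat_dims.append(n_out)

head_params = linear_params(64, 64)  # extra FC(64->64); Tanh adds no parameters

print(flat_dims)                      # [1568, 784, 288, 128, 64]
print(sum(flat_dims))                 # 2832
print(backbone_params)                # 51264
print(backbone_params + head_params)  # 55424
```

Under these assumptions the convolution and FC layers plus the FC(64→64) head account for 55,424 parameters, six short of the reported 55,430; the source does not itemize where the remaining parameters sit.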
3. Training Methodology
Training is direct supervised regression onto precomputed, exact $W_2$ distances between pairs of MNIST images, each normalized to a probability measure (pixel intensities summing to 1). The dataset comprises 55,000 random image pairs, 1,000 for each of the 55 unordered digit pairs (same-digit pairs included), split 49,500/2,750/2,750 (90%/5%/5%) into train/validation/test sets. Ground-truth $W_2$ values are computed with the exact OT solver ot.emd2 from the POT (Python Optimal Transport) library.
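Preparing the inputs to an exact solver amounts to turning each image into a discrete probability measure and building a ground-cost matrix. A minimal sketch, assuming pixel-grid coordinates in pixel units and a squared Euclidean ground cost (the source states neither choice); the helper names are hypothetical. With POT installed, passing the two weight vectors and the cost matrix to `ot.emd2(a, b, M)` returns the optimal transport cost, which under a squared Euclidean cost is $W_2^2$, so $W_2$ is its square root. The pure-Python demo only covers the trivial single-atom case, where no solver is needed.

```python
import math


def image_to_measure(image):
    """Turn a 2D grid of nonnegative intensities into a discrete probability
    measure: support points (row, col) of nonzero pixels, weights summing to 1."""
    total = sum(sum(row) for row in image)
    if total <= 0:
        raise ValueError("image must have positive total intensity")
    support, weights = [], []
    for r, row in enumerate(image):
        for c, value in enumerate(row):
            if value > 0:
                support.append((r, c))
                weights.append(value / total)
    return support, weights


def squared_euclidean_cost(support_a, support_b):
    """Ground-cost matrix M[i][j] = ||p_i - q_j||^2 between two supports."""
    return [[(p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2 for q in support_b]
            for p in support_a]


# Two single-pixel "images": all mass at (0, 0) versus all mass at (3, 4).
img_a = [[1, 0, 0, 0, 0]] + [[0] * 5 for _ in range(3)]
img_b = [[0] * 5 for _ in range(3)] + [[0, 0, 0, 0, 1]]

sa, wa = image_to_measure(img_a)
sb, wb = image_to_measure(img_b)
M = squared_euclidean_cost(sa, sb)
# With one atom on each side the only coupling is trivial, so W2^2 = M[0][0].
print(math.sqrt(M[0][0]))  # 5.0
```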
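The loss minimized during training is the mean squared error (MSE) between the network output and the true $W_2$ over the $N$ training pairs:
$$\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \big( d_\theta(x_i, y_i) - W_2(\mu_{x_i}, \mu_{y_i}) \big)^2,$$
where $\mu_{x_i}$ and $\mu_{y_i}$ are the normalized images of the $i$-th training pair.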
The weights $\alpha_\ell$ are kept nonnegative through the softplus parameterization; no explicit regularizers are applied. Optimization uses Adam (β₁=0.9, β₂=0.999) with a learning rate of …, batch size 256, and default weight initialization (e.g., Kaiming initialization for convolutional layers). Training runs for up to 2,000 epochs, and the checkpoint with the lowest validation MSE is selected.
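A toy version of the training loop shows how the softplus parameterization, the MSE objective, Adam, and validation-based checkpoint selection fit together. The simplifications are deliberate assumptions: the per-layer squared distances are frozen synthetic numbers (the real model also trains the CNN), targets come from hypothetical weights rather than an OT solver, steps are full-batch, and the learning rate of 0.01 is an arbitrary toy value. Adam's β₁ and β₂ match the values stated above.

```python
import math
import random


def softplus(b):
    """Stable log(1 + exp(b)); keeps every layer weight strictly positive."""
    return max(b, 0.0) + math.log1p(math.exp(-abs(b)))


def sigmoid(b):
    """Derivative of softplus, computed without overflow."""
    if b >= 0:
        return 1.0 / (1.0 + math.exp(-b))
    e = math.exp(b)
    return e / (1.0 + e)


def predict(sq_dists, betas):
    """d = sqrt(sum_l softplus(beta_l) * s_l) for per-layer squared distances s_l."""
    return math.sqrt(sum(softplus(b) * s for b, s in zip(betas, sq_dists)))


def mse(data, betas):
    return sum((predict(s, betas) - w) ** 2 for s, w in data) / len(data)


def make_pairs(rng, n, true_alphas):
    """Synthetic (squared layer distances, target distance) pairs."""
    pairs = []
    for _ in range(n):
        s = [rng.random() for _ in true_alphas]
        pairs.append((s, math.sqrt(sum(a * x for a, x in zip(true_alphas, s)))))
    return pairs


rng = random.Random(0)
true_alphas = [0.5, 2.0, 1.0]  # hypothetical weights that generate the toy data
n_layers = len(true_alphas)
train, val = make_pairs(rng, 200, true_alphas), make_pairs(rng, 50, true_alphas)

betas = [0.0] * n_layers
m, v = [0.0] * n_layers, [0.0] * n_layers
lr, b1, b2, eps = 0.01, 0.9, 0.999, 1e-8  # lr is an arbitrary toy value
initial_val = mse(val, betas)
best_val, best_betas = initial_val, list(betas)

for step in range(1, 501):  # full-batch Adam steps stand in for epochs
    grads = [0.0] * n_layers
    for s, w in train:
        d = predict(s, betas)
        if d == 0.0:
            continue  # sqrt is not differentiable at 0; skip degenerate pairs
        for k in range(n_layers):
            # d/d beta_k of (d - w)^2 = 2 (d - w) * s_k * sigmoid(beta_k) / (2 d)
            grads[k] += (d - w) * s[k] * sigmoid(betas[k]) / d
    for k in range(n_layers):
        g = grads[k] / len(train)
        m[k] = b1 * m[k] + (1 - b1) * g
        v[k] = b2 * v[k] + (1 - b2) * g * g
        m_hat = m[k] / (1 - b1 ** step)  # bias-corrected first moment
        v_hat = v[k] / (1 - b2 ** step)  # bias-corrected second moment
        betas[k] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    val_loss = mse(val, betas)
    if val_loss < best_val:  # keep the checkpoint with the lowest validation MSE
        best_val, best_betas = val_loss, list(betas)

print(f"val MSE {initial_val:.4f} -> {best_val:.6f}")
print("learned alphas:", [round(softplus(b), 3) for b in best_betas])
```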
4. Empirical Evaluation and Computational Complexity
On the MNIST test set, DeepKENN achieves lower error than a naive single-layer baseline, while being outperformed by the ODE-KENN variant. Results on the 2,750 test pairs are:
| Model | Test MSE (×…) | Test MAE (×…) | Relative MAE (%) |
|---|---|---|---|
| Naive | 4.67 | 5.249 | 1.63 |
| DeepKENN | 4.07 | 4.873 | 1.52 |
| ODE-KENN | 3.35 | 4.406 | 1.37 |
DeepKENN reduces test MSE by 13% relative to Naive, and ODE-KENN reduces it by a further 18% relative to DeepKENN. ODE-KENN also shows the smallest generalization gap (the difference between training and validation MSE), suggesting that the smoothness implicit in the ODE formulation improves robustness; DeepKENN's gap is the largest, with Naive's falling in between.
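The percentages quoted above follow from the table; a quick check with the values copied from it (any common scale factor cancels in the ratios):

```python
# Test metrics copied from the results table (scale factors omitted).
results = {
    "Naive": {"mse": 4.67, "mae": 5.249},
    "DeepKENN": {"mse": 4.07, "mae": 4.873},
    "ODE-KENN": {"mse": 3.35, "mae": 4.406},
}


def relative_reduction(baseline, model, key="mse"):
    """Percentage drop in `key` when moving from `baseline` to `model`."""
    b, m = results[baseline][key], results[model][key]
    return 100.0 * (b - m) / b


print(round(relative_reduction("Naive", "DeepKENN"), 1))            # 12.8
print(round(relative_reduction("DeepKENN", "ODE-KENN"), 1))         # 17.7
print(round(relative_reduction("Naive", "DeepKENN", "mae"), 1))     # 7.2
print(round(relative_reduction("DeepKENN", "ODE-KENN", "mae"), 1))  # 9.6
```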
Computationally, evaluating $d_\theta$ with DeepKENN or ODE-KENN takes roughly 1–2 μs per pair on GPU hardware, an approximately 5,000× speed-up over the ≈5 ms per pair required by the OT solver for exact $W_2$. Training DeepKENN with the standard hyperparameters takes about 1 hour for 2,000 epochs on a single GPU.
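To see what such a speed-up means for a pipeline that needs a full pairwise distance matrix, the sketch below counts the $n(n-1)/2$ pairs for a hypothetical collection of 10,000 images. It assumes about 1 μs per pair for the network and about 5 ms per pair for the exact solver, the pairing implied by a 5,000× ratio; treat both as illustrative inputs rather than measurements.

```python
def pairwise_cost_seconds(n_items, seconds_per_pair):
    """Time to fill the upper triangle of an n x n distance matrix (n(n-1)/2 pairs)."""
    n_pairs = n_items * (n_items - 1) // 2
    return n_pairs * seconds_per_pair


n = 10_000            # hypothetical collection size
ot_per_pair = 5e-3    # assumed ~5 ms per pair for the exact solver
net_per_pair = 1e-6   # assumed ~1 us per pair for a network forward pass

ot_hours = pairwise_cost_seconds(n, ot_per_pair) / 3600
net_seconds = pairwise_cost_seconds(n, net_per_pair)
print(f"exact OT: {ot_hours:.1f} h, DeepKENN: {net_seconds:.1f} s, "
      f"speed-up: {ot_per_pair / net_per_pair:,.0f}x")
# exact OT: 69.4 h, DeepKENN: 50.0 s, speed-up: 5,000x
```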
5. Inference, Deployment, and Properties
After training, $d_\theta$ is evaluated by extracting the feature vectors at each layer for both inputs, computing the layerwise $\ell_2$ distances, and aggregating them with the trained nonnegative weights $\alpha_\ell$. Pairs are batched efficiently by stacking them along the batch dimension.
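One way to realize this batching is to stack the $B$ first inputs and the $B$ second inputs into a single batch of $2B$, run the frozen backbone once, and split the per-layer features back into the two halves. The sketch assumes a hypothetical `toy_backbone` in place of the trained CNN (a ReLU layer followed by a summing layer) and plain lists in place of tensors.

```python
import math


def softplus(b):
    """Stable log(1 + exp(b)); turns a trained beta into a weight alpha > 0."""
    return max(b, 0.0) + math.log1p(math.exp(-abs(b)))


def toy_backbone(batch):
    """Hypothetical stand-in for the frozen CNN: for each input vector, return
    its per-layer features (a ReLU layer, then the sum of the ReLU outputs)."""
    out = []
    for x in batch:
        h1 = [max(0.0, value) for value in x]  # "layer 1"
        h2 = [sum(h1)]                          # "layer 2"
        out.append([h1, h2])
    return out


def batched_distances(xs, ys, backbone, betas):
    """Evaluate d_theta for B pairs with ONE backbone call by stacking
    [x_1..x_B, y_1..y_B] along the batch dimension, then splitting."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    alphas = [softplus(b) for b in betas]
    feats = backbone(list(xs) + list(ys))  # single forward pass over 2B inputs
    fx, fy = feats[: len(xs)], feats[len(xs):]
    dists = []
    for layers_x, layers_y in zip(fx, fy):
        total = 0.0
        for a, hx, hy in zip(alphas, layers_x, layers_y):
            total += a * sum((p - q) ** 2 for p, q in zip(hx, hy))
        dists.append(math.sqrt(total))
    return dists


xs = [[1.0, -2.0], [0.5, 0.5]]
ys = [[0.0, 3.0], [0.5, 0.5]]
betas = [0.0, 1.0]  # hypothetical trained parameters
print(batched_distances(xs, ys, toy_backbone, betas))  # second pair is identical, so 0.0
```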
Post-training, the backbone and weights are frozen, permitting export as a single, reusable forward function.
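As designed, $d_\theta$ is nonnegative, symmetric, and satisfies the triangle inequality, because it is the weighted $\ell_2$ norm of the difference between the two inputs' images in the direct product of the intermediate feature spaces. Positive definiteness is not imposed: nothing forces distinct inputs to have distinct features at every layer, so $d_\theta(x, y) = 0$ with $x \neq y$ is possible and $d_\theta$ is in general a pseudometric, even though distinct inputs typically receive distinct features in practice.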
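These properties can be checked numerically. The sketch uses a hypothetical two-layer ReLU feature map with fixed weights $\alpha = (1, 0.5)$: symmetry, nonnegativity, and the triangle inequality hold on every sampled triple, while two distinct inputs that differ only in entries zeroed by the ReLU receive identical features and hence distance 0.

```python
import itertools
import math
import random


def relu_features(x):
    """Hypothetical two-layer feature map: ReLU, then the sum of the ReLU outputs."""
    h1 = [max(0.0, value) for value in x]
    return [h1, [sum(h1)]]


def d_theta(x, y, alphas=(1.0, 0.5)):
    """Weighted multi-layer distance with fixed (assumed) layer weights."""
    fx, fy = relu_features(x), relu_features(y)
    return math.sqrt(sum(a * sum((p - q) ** 2 for p, q in zip(hx, hy))
                         for a, hx, hy in zip(alphas, fx, fy)))


rng = random.Random(1)
pts = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(20)]

# Nonnegativity, symmetry, and the triangle inequality on every sampled triple.
for x, y, z in itertools.product(pts, repeat=3):
    assert d_theta(x, y) >= 0 and d_theta(x, y) == d_theta(y, x)
    assert d_theta(x, z) <= d_theta(x, y) + d_theta(y, z) + 1e-12

# Positive definiteness can fail: distinct inputs, identical features.
a, b = [-1.0, 2.0, 0.0], [-5.0, 2.0, -3.0]
print(a != b, d_theta(a, b))  # True 0.0
```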
6. Significance and Applications
DeepKENN provides an end-to-end trainable mechanism for approximating Wasserstein-2 distances by aggregating deep features across CNN layers. By removing the computational bottleneck of exact OT solvers, it enables scalable deployment in pipelines that depend on repeated pairwise metric computation, such as clustering, retrieval, or Wasserstein-based evaluation of generative models. Because the design builds on the Kuratowski embedding theorem, the learned metric can, in principle, preserve much of the geometric structure of $W_2$, while its architectural simplicity and computational efficiency make it easy to integrate into modern GPU-accelerated machine learning workflows (He, 6 Apr 2026).