
Max-Sliced Wasserstein Distance and its use for GANs (1904.05877v1)

Published 11 Apr 2019 in cs.LG, cs.CV, and stat.ML

Abstract: Generative adversarial nets (GANs) and variational auto-encoders have significantly improved our distribution modeling capabilities, showing promise for dataset augmentation, image-to-image translation and feature learning. However, to model high-dimensional distributions, sequential training and stacked architectures are common, increasing the number of tunable hyper-parameters as well as the training time. Nonetheless, the sample complexity of the distance metrics remains one of the factors affecting GAN training. We first show that the recently proposed sliced Wasserstein distance has compelling sample complexity properties when compared to the Wasserstein distance. To further improve the sliced Wasserstein distance we then analyze its 'projection complexity' and develop the max-sliced Wasserstein distance which enjoys compelling sample complexity while reducing projection complexity, albeit necessitating a max estimation. We finally illustrate that the proposed distance trains GANs on high-dimensional images up to a resolution of 256x256 easily.

Citations (189)

Summary

  • The paper introduces max-Sliced Wasserstein Distance (max-SWD), which optimizes projection selection to focus on the most informative differences between high-dimensional distributions.
  • It reduces computational complexity by replacing multiple random projections with a single surrogate optimal direction, enhancing training stability.
  • Experiments on high-resolution (256x256) image generation and unsupervised word translation show comparable or better results than the sliced Wasserstein distance while using far fewer projections.

Max-Sliced Wasserstein Distance and its Application to GANs

The paper "Max-Sliced Wasserstein Distance and its Use for GANs" addresses the difficulty of training generative adversarial networks (GANs) on high-dimensional data. GANs have become a central tool in generative modeling, with applications such as dataset augmentation, image-to-image translation, and feature learning. A persistent challenge, however, is reliably measuring the distance between high-dimensional distributions from finite samples, a task traditionally handled with metrics such as the Jensen-Shannon divergence or the Wasserstein distance.

Core Contributions

The paper introduces the Max-Sliced Wasserstein Distance (max-SWD), a metric that improves on the sliced Wasserstein distance (SWD) by reducing its projection complexity. The SWD compares two distributions by projecting samples onto one-dimensional subspaces, typically along randomly chosen directions, and averaging the Wasserstein distances between the resulting one-dimensional distributions, which are cheap to compute by sorting. Although the SWD has favorable sample complexity (polynomial, compared with the exponential sample complexity of the Wasserstein distance), its practical effectiveness is limited by its projection complexity: many random directions are needed to capture the differences between distributions reliably.
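The random-projection estimate of the SWD described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's code: the function name `sliced_wasserstein` and its `num_projections` parameter are hypothetical, and it assumes both sample sets have the same size so that the optimal one-dimensional coupling is obtained simply by sorting.

```python
import numpy as np

def sliced_wasserstein(x, y, num_projections=100, p=2, rng=None):
    """Monte Carlo estimate of the sliced p-Wasserstein distance (illustrative sketch).

    x, y: arrays of shape (n, d) with the same number of samples n,
    so the optimal 1D coupling is given by matching sorted projections.
    """
    rng = np.random.default_rng(rng)
    d = x.shape[1]
    # Random directions drawn uniformly on the unit sphere S^{d-1}.
    theta = rng.standard_normal((num_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Project every sample onto every direction (shape (n, num_projections)),
    # then sort each column so the i-th smallest values are matched.
    x_proj = np.sort(x @ theta.T, axis=0)
    y_proj = np.sort(y @ theta.T, axis=0)
    # 1D W_p^p per direction, then average over directions and take the p-th root.
    wpp = np.mean(np.abs(x_proj - y_proj) ** p, axis=0)
    return float(np.mean(wpp) ** (1.0 / p))
```

For two isotropic Gaussians in 10 dimensions whose means differ by 3 along one axis, only directions with a large component along that axis see the difference, so the estimate is roughly 3/sqrt(10), well below the difference of 3 visible along the best direction. This gap is the projection-complexity problem in miniature.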

The max-SWD extends this approach by optimizing the choice of projection direction. Instead of averaging over a large number of randomly selected directions, it uses the single direction that yields the largest Wasserstein distance, that is, the direction along which the two distributions differ the most. This makes more efficient use of computation and, by focusing on the most informative direction, also improves training stability and effectiveness.
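To make the "maximum over directions" concrete, here is a minimal sketch that approximates the max-SWD for p = 2 by projected gradient ascent on the unit sphere. This is not the paper's estimator (the paper uses a discriminator-based surrogate direction in the GAN setting); the function names, step size `lr`, and iteration count `steps` are hypothetical choices, and equal sample sizes are assumed.

```python
import numpy as np

def projected_w2_sq(x, y, w):
    """Squared 2-Wasserstein distance between the 1D projections x @ w and y @ w."""
    return float(np.mean((np.sort(x @ w) - np.sort(y @ w)) ** 2))

def max_sliced_wasserstein(x, y, steps=200, lr=0.1, rng=None):
    """Approximate max-SW_2 by gradient ascent over unit directions w (sketch).

    Assumes x and y have the same number of samples so that sorting gives the
    optimal 1D coupling. The objective is non-concave, so the result is the
    distance along a locally best direction, not a certified global maximum.
    Returns (distance, direction).
    """
    rng = np.random.default_rng(rng)
    w = rng.standard_normal(x.shape[1])
    w /= np.linalg.norm(w)
    for _ in range(steps):
        # Optimal 1D coupling for the current direction: match sorted projections.
        ix = np.argsort(x @ w)
        iy = np.argsort(y @ w)
        diff = x[ix] - y[iy]                      # (n, d) matched differences
        # Gradient of mean((diff @ w)^2) with respect to w, coupling held fixed.
        grad = 2.0 * diff.T @ (diff @ w) / len(diff)
        w = w + lr * grad
        w /= np.linalg.norm(w)                    # project back onto the unit sphere
    return float(np.sqrt(projected_w2_sq(x, y, w))), w
```

On the same pair of shifted Gaussians used for the SWD estimate, this search should recover a direction close to the shifted axis and a distance close to 3. Holding the sorted coupling fixed while differentiating is a common simplification: the coupling only changes at isolated directions where projected values tie.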

Theoretical Implications

From a theoretical standpoint, the paper proves that the Max-Sliced Wasserstein Distance is a valid metric. Like the sliced Wasserstein distance, it is also shown to have polynomial sample complexity for certain distribution families, such as Gaussian distributions.
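For reference, the two distances and their relationship to the Wasserstein distance can be written as follows. The notation is ours, not necessarily the paper's: P and Q are distributions on R^d with finite p-th moments, ω_# P is the law of ⟨ω, X⟩ for X ~ P, and σ is the uniform probability measure on the unit sphere. This is a sketch of standard arguments, not a reproduction of the paper's proof.

```latex
\mathrm{SW}_p(P,Q) = \Big( \int_{\mathbb{S}^{d-1}} W_p^p\big(\omega_{\#}P,\ \omega_{\#}Q\big)\, d\sigma(\omega) \Big)^{1/p},
\qquad
\max\text{-}\mathrm{SW}_p(P,Q) = \max_{\omega \in \mathbb{S}^{d-1}} W_p\big(\omega_{\#}P,\ \omega_{\#}Q\big),

% x \mapsto \langle \omega, x \rangle is 1-Lipschitz for unit \omega, and an average never exceeds a maximum:
\mathrm{SW}_p(P,Q) \;\le\; \max\text{-}\mathrm{SW}_p(P,Q) \;\le\; W_p(P,Q).
```

Symmetry and the triangle inequality follow because each d_ω(P, Q) = W_p(ω_# P, ω_# Q) is a pseudometric, and a pointwise maximum of pseudometrics is again a pseudometric. If max-SWD(P, Q) = 0, every one-dimensional projection of P and Q coincides, so P = Q by the Cramér–Wold theorem.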

Computing the max-SWD requires solving a maximization over projection directions. In the GAN setting, the paper handles this with a surrogate optimal direction based on the discriminator's features. Learning this direction with a surrogate loss function keeps the max-SWD effective for training the generator without substantial additional computational cost.
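The generator-side computation can be sketched as follows, under explicit assumptions: the discriminator is taken to produce a feature vector per sample, and its last linear layer's weight vector is treated as the (hypothetical) surrogate direction. The names `surrogate_direction` and `generator_loss` are illustrative. NumPy has no automatic differentiation, so this only computes the loss value; a training loop would express the same steps in an autodiff framework.

```python
import numpy as np

def surrogate_direction(last_layer_weight):
    """Normalize a (hypothetical) discriminator last-layer weight into a unit direction."""
    w = np.asarray(last_layer_weight, dtype=float)
    return w / np.linalg.norm(w)

def generator_loss(feat_real, feat_fake, last_layer_weight):
    """Squared W2 between real and fake features projected onto the surrogate direction.

    feat_real, feat_fake: (n, k) discriminator features with equal batch size n.
    The direction is treated as fixed here; in an alternating scheme it would be
    updated on the discriminator step and held constant on the generator step.
    """
    w = surrogate_direction(last_layer_weight)
    real_1d = np.sort(feat_real @ w)   # sorted 1D projections give the optimal coupling
    fake_1d = np.sort(feat_fake @ w)
    return float(np.mean((real_1d - fake_1d) ** 2))
```

Because only one direction is used, each generator step costs a single projection and sort per batch. A feature shift of size 2 along the direction gives a loss of 4, while a shift orthogonal to it gives 0, which is why the quality of the surrogate direction matters.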

Experimental Evaluation

The practical benefits of max-SWD are illustrated on two tasks: unsupervised word translation and high-resolution image generation. In unsupervised word translation, max-SWD improves the alignment of word embeddings learned without parallel corpora, with competitive retrieval precision across multiple language pairs. This suggests the metric is useful beyond images, for example when learning a transformation between the embedding spaces of two languages.

For image generation, max-SWD substantially reduces computational cost compared with SWD because it requires far fewer projection directions, while achieving comparable or better visual fidelity. On the CelebA-HQ and LSUN Bedrooms datasets at a resolution of 256x256, GANs trained with max-SWD produce visually compelling results.
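A rough operation count shows where the savings come from. It uses standard costs for matrix-vector projection and comparison sorting, not figures reported in the paper. Here n is the batch size, d the dimension of the projected samples or features, K the number of random directions, and C_ω a placeholder for whatever it costs to update the single direction (for example, a discriminator step in the GAN setting).

```latex
\underbrace{O(K n d)}_{\text{project}} + \underbrace{O(K n \log n)}_{\text{sort}} \quad \text{(SWD)}
\qquad \text{vs.} \qquad
O(n d) + O(n \log n) + C_{\omega} \quad \text{(max-SWD)}
```

For fixed n and d, the projection and sorting work drops by roughly a factor of K. The net saving therefore depends on keeping C_ω small, which is the role of the surrogate direction.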

Concluding Remarks

This paper advances GAN training by introducing a principled way to measure and minimize distributional distance in high-dimensional spaces. The proposed max-SWD yields GANs that are easier to train, computationally efficient, and capable of high-quality generation across diverse data domains, and it motivates further exploration of optimal-transport-based metrics in deep generative models. Future work might extend this metric to other generative frameworks and to other applications that require efficient modeling of high-dimensional distributions.