Categorical Reparameterization with Gumbel-Softmax (1611.01144v5)

Published 3 Nov 2016 in stat.ML and cs.LG

Abstract: Categorical variables are a natural choice for representing discrete structure in the world. However, stochastic neural networks rarely use categorical latent variables due to the inability to backpropagate through samples. In this work, we present an efficient gradient estimator that replaces the non-differentiable sample from a categorical distribution with a differentiable sample from a novel Gumbel-Softmax distribution. This distribution has the essential property that it can be smoothly annealed into a categorical distribution. We show that our Gumbel-Softmax estimator outperforms state-of-the-art gradient estimators on structured output prediction and unsupervised generative modeling tasks with categorical latent variables, and enables large speedups on semi-supervised classification.

Authors (3)
  1. Eric Jang (19 papers)
  2. Shixiang Gu (23 papers)
  3. Ben Poole (46 papers)
Citations (5,031)

Summary

  • The paper introduces the Gumbel-Softmax distribution as a continuous relaxation of the categorical distribution, enabling gradient-based optimization for discrete latent variables.
  • It presents both the Gumbel-Softmax and Straight-Through estimators, which outperform existing single-sample gradient estimators on Bernoulli and categorical tasks.
  • The approach achieves significant training speedups in semi-supervised models and competitive performance in structured prediction and VAE experiments.

The paper introduces the Gumbel-Softmax distribution as a continuous relaxation of the categorical distribution, enabling the use of the reparameterization trick for stochastic neural networks with discrete latent variables. It addresses the challenge of training stochastic networks with discrete variables, where backpropagation cannot be applied directly because the sampling layer is non-differentiable.

Key contributions include:

  • The introduction of the Gumbel-Softmax distribution, a continuous distribution over the simplex that approximates categorical samples, allowing for gradient computation via the reparameterization trick. The Gumbel-Softmax distribution is parameterized by class probabilities $\pi_1, \pi_2, \ldots, \pi_k$ and a temperature parameter $\tau$. Samples $y$ are generated using the formula (see the sampling sketch after this list):

    $$y_i = \frac{\exp\big((\log(\pi_i)+g_i)/\tau\big)}{\sum_{j=1}^k \exp\big((\log(\pi_j)+g_j)/\tau\big)}$$

    where $g_i$ are i.i.d. samples drawn from $\text{Gumbel}(0,1)$. The density of the Gumbel-Softmax distribution is given by:

    $$p_{\pi, \tau}(y_1, \ldots, y_k) = \Gamma(k)\,\tau^{k-1}\left(\sum_{i=1}^{k} \pi_i / y_i^{\tau}\right)^{-k} \prod_{i=1}^k \left(\pi_i / y_i^{\tau+1}\right)$$

    where

    • $k$ is the number of categories
    • $\tau$ is the temperature parameter
    • $\pi_i$ is the probability of category $i$
    • $y_i$ is the $i$-th component of the sample from the Gumbel-Softmax distribution
  • Experimental results demonstrating that the Gumbel-Softmax estimator outperforms existing single-sample gradient estimators on both Bernoulli and categorical variables in structured output prediction and unsupervised generative modeling tasks.
  • An efficient approach to training semi-supervised models by using the Gumbel-Softmax estimator to avoid costly marginalization over unobserved categorical latent variables, achieving significant speedups.
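The sampling formula above is straightforward to implement. The following is a minimal sketch in JAX (illustrative code, not the authors' released implementation); the function name, example logits, and temperature value are assumptions chosen for demonstration:

```python
import jax
import jax.numpy as jnp

def sample_gumbel_softmax(key, logits, tau):
    """Draw one relaxed categorical sample y on the probability simplex."""
    g = jax.random.gumbel(key, logits.shape)            # g_i ~ Gumbel(0, 1), i.i.d.
    return jax.nn.softmax((logits + g) / tau, axis=-1)  # y_i from the formula above

key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.7]))            # log(pi_i) for three classes
y = sample_gumbel_softmax(key, logits, tau=0.5)         # differentiable sample, sums to 1
```

Because `y` is a deterministic, differentiable function of the logits given the Gumbel noise, gradients with respect to the class probabilities can be computed by ordinary backpropagation.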

The Gumbel-Softmax distribution is derived using the Gumbel-Max trick, where samples $z$ from a categorical distribution with class probabilities $\pi$ are obtained as:

$$z = \text{one\_hot}\left(\arg\max_{i}\left[\, g_i + \log \pi_i \,\right]\right)$$

where $g_i$ are i.i.d. samples drawn from $\text{Gumbel}(0,1)$. The softmax function is then used as a continuous, differentiable approximation to $\arg\max$. As the temperature $\tau$ approaches 0, samples from the Gumbel-Softmax distribution become one-hot, and the distribution becomes identical to the categorical distribution $p(z)$.
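For comparison, the exact (non-differentiable) Gumbel-Max sample can be sketched in the same style; the helper below is illustrative and follows the same conventions as the previous snippet:

```python
import jax
import jax.numpy as jnp

def sample_gumbel_max(key, logits):
    """Exact one-hot categorical sample z via the Gumbel-Max trick."""
    g = jax.random.gumbel(key, logits.shape)
    return jax.nn.one_hot(jnp.argmax(logits + g, axis=-1), logits.shape[-1])

# Replacing the argmax with softmax((logits + g) / tau) recovers the relaxed
# sample; as tau -> 0 the two coincide, which is the annealing property
# described above.
```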

The paper introduces two estimators: the Gumbel-Softmax estimator, where categorical samples are replaced with Gumbel-Softmax samples and gradients are computed using backpropagation, and the Straight-Through (ST) Gumbel-Softmax estimator, where samples are discretized using $\arg\max$ during the forward pass, but the continuous Gumbel-Softmax approximation is used in the backward pass.
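A common way to realize the ST variant is to return the hard sample in the forward pass while routing gradients through the soft sample; the sketch below (an assumed implementation using `stop_gradient`, not the authors' code) illustrates the idea:

```python
import jax
import jax.numpy as jnp

def st_gumbel_softmax(key, logits, tau):
    """Straight-Through Gumbel-Softmax: hard forward value, soft gradients."""
    g = jax.random.gumbel(key, logits.shape)
    y_soft = jax.nn.softmax((logits + g) / tau, axis=-1)
    y_hard = jax.nn.one_hot(jnp.argmax(y_soft, axis=-1), logits.shape[-1])
    # The returned value equals y_hard, but its gradient w.r.t. logits is that of y_soft.
    return y_soft + jax.lax.stop_gradient(y_hard - y_soft)
```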

The paper compares the proposed Gumbel-Softmax and ST Gumbel-Softmax estimators to other stochastic gradient estimators, including Score-Function (SF), DARN, MuProp, Straight-Through (ST), and Slope-Annealed ST. The estimators are evaluated on structured output prediction and variational training of generative models using the MNIST dataset.

The structured output prediction task involves predicting the lower half of a $28 \times 28$ MNIST digit given the top half. The models are trained using an importance-sampled estimate of the likelihood objective. The results show that ST Gumbel-Softmax performs on par with other estimators for Bernoulli variables and outperforms them on categorical variables, while Gumbel-Softmax outperforms other estimators on both Bernoulli and categorical variables.

In the variational autoencoder (VAE) experiments, the objective is to learn a generative model of binary MNIST images. The latent variable is modeled as a single hidden layer with 200 Bernoulli variables or 20 categorical variables. The temperature is annealed during training. The results show that ST Gumbel-Softmax outperforms other estimators for categorical variables, and Gumbel-Softmax outperforms other estimators for both Bernoulli and categorical variables.
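The annealing schedule itself can be as simple as an exponential decay of $\tau$ toward a floor; the constants in the sketch below are illustrative placeholders rather than the paper's exact settings:

```python
import jax.numpy as jnp

def anneal_tau(step, rate=1e-4, tau_min=0.5):
    """Exponentially decay the Gumbel-Softmax temperature toward a minimum value."""
    return jnp.maximum(tau_min, jnp.exp(-rate * step))
```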

The paper also applies the Gumbel-Softmax estimator to semi-supervised classification on the binary MNIST dataset. The original marginalization-based inference approach is compared to single-sample inference with Gumbel-Softmax and ST Gumbel-Softmax. The models are trained on a dataset consisting of 100 labeled examples and 50,000 unlabeled examples. The results demonstrate that Gumbel-Softmax achieves significant speedups in training without compromising generative or classification performance. For instance, training the model with the Gumbel-Softmax estimator is $2\times$ as fast for $10$ classes and $9.9\times$ as fast for $100$ classes.
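The source of the speedup is that marginalizing the unlabeled-data term requires one ELBO evaluation per class, whereas a single Gumbel-Softmax sample requires only one. The contrast can be sketched as follows; here `elbo_fn`, `q_probs`, and `q_logits` are hypothetical placeholders for the model's ELBO and the inference network's output:

```python
import jax
import jax.numpy as jnp

def unlabeled_term_marginalized(elbo_fn, q_probs, x, k):
    """Sum_y q(y|x) * ELBO(x, y): cost grows linearly with the number of classes k."""
    ys = jnp.eye(k)                                   # all one-hot class vectors
    return jnp.sum(q_probs * jax.vmap(lambda y: elbo_fn(x, y))(ys))

def unlabeled_term_single_sample(elbo_fn, q_logits, x, key, tau=1.0):
    """One Gumbel-Softmax sample replaces the k-term sum: cost is independent of k."""
    g = jax.random.gumbel(key, q_logits.shape)
    y = jax.nn.softmax((q_logits + g) / tau, axis=-1)
    return elbo_fn(x, y)
```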

The paper concludes by highlighting the effectiveness of the Gumbel-Softmax distribution and its corresponding estimators for training stochastic neural networks with discrete latent variables, enabling efficient inference and competitive performance on various tasks.
