- The paper introduces the Gumbel-Softmax distribution as a continuous relaxation of the categorical distribution, enabling gradient-based optimization for discrete latent variables.
- It presents both the Gumbel-Softmax and Straight-Through estimators, which outperform existing single-sample gradient estimators on Bernoulli and categorical tasks.
- The approach achieves significant training speedups in semi-supervised models and competitive performance in structured prediction and VAE experiments.
The paper introduces the Gumbel-Softmax distribution as a continuous relaxation of the categorical distribution, enabling the reparameterization trick for stochastic neural networks with discrete latent variables. It addresses the core difficulty of training such networks: backpropagation cannot be applied directly because the discrete sampling layers are non-differentiable.
Key contributions include:
- The introduction of the Gumbel-Softmax distribution, a continuous distribution over the simplex that approximates categorical samples, allowing gradients to be computed via the reparameterization trick (a sampling sketch in code follows this list). The Gumbel-Softmax distribution is parameterized by class probabilities $\pi_1, \pi_2, \ldots, \pi_k$ and a temperature parameter $\tau$. Samples $y$ are generated using:
$$y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j=1}^{k} \exp\big((\log \pi_j + g_j)/\tau\big)}, \qquad i = 1, \ldots, k$$
where the $g_i$ are i.i.d. samples drawn from Gumbel(0, 1). The density of the Gumbel-Softmax distribution is given by:
$$p_{\pi,\tau}(y_1, \ldots, y_k) = \Gamma(k)\,\tau^{k-1} \left(\sum_{i=1}^{k} \pi_i / y_i^{\tau}\right)^{-k} \prod_{i=1}^{k} \left(\pi_i / y_i^{\tau+1}\right)$$
where
- $k$ is the number of categories
- $\tau$ is the temperature parameter
- $\pi_i$ is the probability of category $i$
- $y_i$ is the $i$-th component of the sample from the Gumbel-Softmax distribution
- Experimental results demonstrating that the Gumbel-Softmax estimator outperforms existing single-sample gradient estimators on both Bernoulli and categorical variables in structured output prediction and unsupervised generative modeling tasks.
- An efficient approach to training semi-supervised models by using the Gumbel-Softmax estimator to avoid costly marginalization over unobserved categorical latent variables, achieving significant speedups.
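As a concrete illustration of the sampling formula above, here is a minimal NumPy sketch that draws a Gumbel-Softmax sample from class probabilities `pi` at temperature `tau`. The function name and interface are our own, not from the paper.

```python
import numpy as np

def gumbel_softmax_sample(pi, tau, rng):
    """Draw one Gumbel-Softmax sample on the (k-1)-simplex."""
    pi = np.asarray(pi, dtype=float)
    g = rng.gumbel(loc=0.0, scale=1.0, size=pi.shape)  # i.i.d. Gumbel(0, 1) noise
    logits = (np.log(pi) + g) / tau
    logits -= logits.max()                             # stabilize the softmax
    y = np.exp(logits)
    return y / y.sum()

rng = np.random.default_rng(0)
print(gumbel_softmax_sample([0.1, 0.6, 0.3], tau=0.5, rng=rng))
# At small tau the sample is nearly one-hot; at large tau it approaches uniform.
```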
The Gumbel-Softmax distribution is derived using the Gumbel-Max trick, where samples $z$ from a categorical distribution with class probabilities $\pi$ are obtained as:
$$z = \operatorname{one\_hot}\left(\arg\max_i \left[g_i + \log \pi_i\right]\right)$$
where the $g_i$ are i.i.d. samples drawn from Gumbel(0, 1). The softmax function is then used as a continuous, differentiable approximation to the argmax. As the temperature $\tau$ approaches 0, samples from the Gumbel-Softmax distribution become one-hot and the distribution becomes identical to the categorical distribution $p(z)$.
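A quick empirical check of the Gumbel-Max trick, using example probabilities of our own choosing: the argmax of Gumbel-perturbed log-probabilities reproduces the categorical distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.1, 0.6, 0.3])              # example class probabilities

# Gumbel-Max trick: argmax over Gumbel-perturbed log-probabilities
# yields exact samples from Categorical(pi).
g = rng.gumbel(size=(100_000, pi.size))
samples = np.argmax(np.log(pi) + g, axis=-1)
print(np.bincount(samples) / samples.size)  # empirical frequencies ~ pi
```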
The paper introduces two estimators: the Gumbel-Softmax estimator, where categorical samples are replaced with Gumbel-Softmax samples and gradients are computed using backpropagation, and the Straight-Through (ST) Gumbel-Softmax estimator, where samples are discretized using argmax during the forward pass, but the continuous Gumbel-Softmax approximation is used in the backward pass.
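A minimal PyTorch sketch of the two estimators described above, assuming `logits` (unnormalized log class probabilities) are produced by an upstream network; `hard=True` gives the Straight-Through variant, with an argmax in the forward pass and the Gumbel-Softmax gradient in the backward pass. This is an illustrative implementation, not the authors' code.

```python
import torch
import torch.nn.functional as F

def gumbel_softmax(logits, tau=1.0, hard=False):
    # Gumbel(0, 1) noise via -log of Exponential(1) samples.
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = F.softmax((logits + gumbels) / tau, dim=-1)
    if not hard:
        return y_soft                                # Gumbel-Softmax estimator
    # Straight-Through: one-hot argmax in the forward pass,
    # gradients flow through the soft sample in the backward pass.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return y_hard + (y_soft - y_soft.detach())
```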
The paper compares the proposed Gumbel-Softmax and ST Gumbel-Softmax estimators to other stochastic gradient estimators, including Score-Function (SF), DARN, MuProp, Straight-Through (ST), and Slope-Annealed ST. The estimators are evaluated on structured output prediction and variational training of generative models using the MNIST dataset.
The structured output prediction task involves predicting the lower half of a 28×28 MNIST digit given the top half. The models are trained using an importance-sampled estimate of the likelihood objective. The results show that ST Gumbel-Softmax performs on par with other estimators for Bernoulli variables and outperforms them on categorical variables, while Gumbel-Softmax outperforms other estimators on both Bernoulli and categorical variables.
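Schematically, the importance-sampled (multi-sample) estimate of the likelihood being maximized has the form below, with $m$ samples $z_1, \ldots, z_m$ drawn from the stochastic layers conditioned on the top half of the image; the specific values of $m$ used for training and evaluation are not restated here.

$$\mathbb{E}\left[\log \frac{1}{m} \sum_{i=1}^{m} p_\theta\big(x_{\text{lower}} \mid z_i\big)\right], \qquad z_i \sim p_\theta\big(z \mid x_{\text{upper}}\big)$$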
In the variational autoencoder (VAE) experiments, the objective is to learn a generative model of binary MNIST images. The latent code is a single stochastic layer of 200 Bernoulli variables or 20 categorical variables, and the temperature is annealed during training. The results show that ST Gumbel-Softmax outperforms the other estimators for categorical variables, and Gumbel-Softmax outperforms the other estimators for both Bernoulli and categorical variables.
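Temperature annealing can be implemented as a simple exponential decay toward a floor, as sketched below; the rate `r` and floor `tau_min` are illustrative placeholder values, not necessarily the paper's exact settings.

```python
import numpy as np

def annealed_tau(step, r=1e-4, tau_min=0.5):
    # Exponentially decay the Gumbel-Softmax temperature toward a floor.
    # r and tau_min are illustrative values for this sketch.
    return max(tau_min, float(np.exp(-r * step)))
```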
The paper also applies the Gumbel-Softmax estimator to semi-supervised classification on binary MNIST. The original marginalization-based inference approach is compared to single-sample inference with Gumbel-Softmax and ST Gumbel-Softmax. The models are trained on 100 labeled examples and 50,000 unlabeled examples. The results demonstrate that Gumbel-Softmax achieves significant training speedups without compromising generative or classification performance: training with the Gumbel-Softmax estimator is 2× faster for 10 classes and 9.9× faster for 100 classes.
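The source of the speedup can be sketched as follows: marginalizing over an unobserved categorical class requires one decoder pass per class, whereas Gumbel-Softmax needs a single pass through one sampled (relaxed) class vector. The `decoder` and the class posterior `q_y` below are hypothetical stand-ins, not the paper's semi-supervised architecture.

```python
import torch

k = 100                                            # number of classes
decoder = torch.nn.Linear(k, 784)                  # hypothetical stand-in for the generative network
q_y = torch.softmax(torch.randn(1, k), dim=-1)     # class posterior q(y|x) from a hypothetical encoder

# Marginalization: k decoder passes, one per possible class.
recon_marg = sum(q_y[0, i] * decoder(torch.eye(k)[i]) for i in range(k))

# Gumbel-Softmax: one decoder pass on a single relaxed sample.
gumbels = -torch.empty(1, k).exponential_().log()
y = torch.softmax((q_y.log() + gumbels) / 1.0, dim=-1)
recon_single = decoder(y)
```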
The paper concludes by highlighting the effectiveness of the Gumbel-Softmax distribution and its corresponding estimators for training stochastic neural networks with discrete latent variables, enabling efficient inference and competitive performance on various tasks.