# Gumbel-Softmax

## Why are we interested in Gumbel-Softmax?

Gumbel-Softmax makes categorical sampling *differentiable*. Why is that important? Because sometimes we want to optimize discrete or categorical choices with gradient methods. One application is differentiable neural architecture search, where we need to decide which operation to place on each edge.

The original DARTS paper did not use Gumbel-Softmax. To decide which operation should be on an edge, it does the following:

- During search, the outputs of all candidate operations are combined in a weighted sum, with a group of learnable parameters ($\alpha_1, \alpha_2, \ldots, \alpha_n$, where $n$ is the number of operation choices), like the colored lines in the above figure.
- When search is done, a final architecture is derived, meaning that we have to choose a single operation for each edge. How do we do that? We use $\arg\max$ to choose the operation with the largest $\alpha$.

But doing so creates a “gap”: during search the network computes a weighted sum of all operations, while the derived network discards all information from the unchosen operations, so the derived network may perform differently than expected.

“If only we could train *one* network each time during search!” one might say. Indeed, selecting one operation per edge during search instead of using a weighted sum is a good idea. But sampling itself is not differentiable, so we can’t update the parameters with gradient descent.

Don’t worry, Gumbel-Softmax is coming to the rescue.

## Gumbel distribution

First we need to understand what the Gumbel distribution is.

The Gumbel distribution is also known as the **Generalized Extreme Value distribution Type-I**; that is, it is a particular case of the generalized extreme value distribution. It is used to model the distribution of the maximum of samples from various distributions.

Gumbel distribution has two parameters:

- $\mu$: location
- $\beta$: scale; a larger $\beta$ leads to a wider, flatter distribution.

The PDF of the Gumbel distribution is:

$$f(x; \mu, \beta) = \frac{1}{\beta} e^{-(z + e^{-z})}$$

where $z=\frac{x-\mu}{\beta}$. The exponent itself contains an exponential function, so the Gumbel distribution is also known as the double-exponential distribution.
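In practice we rarely need the PDF directly: $\text{Gumbel}(\mu, \beta)$ samples can be drawn by inverse-transform sampling as $\mu - \beta \log(-\log u)$ with $u \sim \text{Uniform}(0,1)$. A minimal NumPy sketch (the sample count is arbitrary):

```python
import numpy as np

def sample_gumbel(shape, mu=0.0, beta=1.0, rng=None):
    """Draw Gumbel(mu, beta) samples via inverse transform sampling."""
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(low=1e-12, high=1.0, size=shape)  # avoid log(0)
    return mu - beta * np.log(-np.log(u))

rng = np.random.default_rng(0)
g = sample_gumbel(100_000, rng=rng)
# The mean of Gumbel(0, 1) is the Euler-Mascheroni constant (~0.5772).
print(g.mean())
```

With 100k samples the empirical mean lands close to the theoretical value, which is a quick sanity check that the inverse transform is right.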

## What is Gumbel-Softmax?

When we “choose an operation for one edge”, what we are actually doing is drawing a sample from a categorical distribution. A categorical distribution means a random variable takes one of several discrete categories, each with a known probability.

Let’s say we have a categorical variable $z$ with class probabilities $\pi_1, \pi_2, \ldots, \pi_k$. To draw samples $z$ from the categorical distribution, we use the Gumbel-Max trick:

$$z = \text{one\_hot}\left(\underset{i}{\arg\max} \left[\log \pi_i + g_i\right]\right)$$

where $g_1, \ldots, g_k$ are i.i.d. samples drawn from the $\text{Gumbel}(0,1)$ distribution. $\arg \max$ is defined as:

$$\underset{x}{\arg\max}\, f(x) = \{x \mid f(x) = M\}$$

where $M$ is the maximum of $f(x)$. The problem here is that $\arg \max$ is not differentiable at all, so we use softmax instead, which is a *softened* argmax function and therefore differentiable. So we get a softened $z$, denoted as $y$, whose elements are:

$$y_i = \frac{\exp\left((\log \pi_i + g_i)/\tau\right)}{\sum_{j=1}^{k} \exp\left((\log \pi_j + g_j)/\tau\right)}, \quad i = 1, \ldots, k$$

This is the Gumbel-Softmax trick.
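The Gumbel-Max part of the trick can be checked empirically: taking $\arg\max_i(\log \pi_i + g_i)$ over many Gumbel noise draws should reproduce the class probabilities. A quick sketch (the probabilities here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.2, 0.5, 0.3])  # illustrative class probabilities
n = 200_000

# Gumbel(0, 1) noise, one row per sample
g = -np.log(-np.log(rng.uniform(size=(n, len(pi)))))
samples = np.argmax(np.log(pi) + g, axis=1)  # the Gumbel-Max trick

freq = np.bincount(samples, minlength=len(pi)) / n
print(freq)  # empirical frequencies, close to [0.2, 0.5, 0.3]
```

The noise perturbs the log-probabilities just enough that the index of the maximum is distributed exactly according to $\pi$.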

Here we added a temperature variable $\tau$ to control the *softness* of softmax.

Let me explain. Softmax is just a normalized exponential function. At high temperature, every element $(\log \pi_i + g_i)$ is divided by a large number, shrinking them all, so the differences between elements also shrink and the distribution gets closer to uniform. In contrast, at low temperature (smaller than 1), dividing by $\tau$ magnifies the elements, so the differences between them grow, making the distribution “sharper”.
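The temperature behavior can be sketched in NumPy. The probabilities are placeholders, and `g` is a fixed noise vector for reproducibility; in practice it would be a fresh $-\log(-\log u)$ draw each time:

```python
import numpy as np

def gumbel_softmax(log_pi, g, tau):
    """One Gumbel-Softmax sample: softmax((log_pi + g) / tau)."""
    x = (log_pi + g) / tau
    x = x - x.max()          # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum()

log_pi = np.log(np.array([0.2, 0.5, 0.3]))  # illustrative probabilities
g = np.array([0.5, -0.2, 1.0])              # one fixed Gumbel(0,1) draw

y_sharp = gumbel_softmax(log_pi, g, tau=0.1)   # low tau: near one-hot
y_flat  = gumbel_softmax(log_pi, g, tau=10.0)  # high tau: near uniform
print(y_sharp)  # ~[0.00, 0.00, 1.00]
print(y_flat)   # ~[0.32, 0.33, 0.35]
```

The same logits and noise give a nearly one-hot vector at $\tau = 0.1$ and a nearly uniform one at $\tau = 10$, which is exactly the melting behavior described above.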

We can think of it as heating a crystal. Higher temperature melts the crystal and it becomes softer (closer to a uniform distribution). When it cools down, it hardens and becomes sharp.

We often use an annealing schedule with Gumbel-Softmax, starting from a high temperature and gradually cooling it down. This is because we want every candidate operation to be sufficiently trained in the early stage, and the distribution to gradually form a preference in the later stage.
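One common shape for such a schedule (a sketch; the constants here are illustrative, not taken from any particular paper) is exponential decay clipped at a floor:

```python
import math

def anneal_tau(step, tau0=5.0, tau_min=0.5, rate=1e-3):
    """Exponentially decay the temperature, clipped at a minimum value."""
    return max(tau_min, tau0 * math.exp(-rate * step))

print(anneal_tau(0))       # 5.0  (high temperature: near-uniform mixing)
print(anneal_tau(10_000))  # 0.5  (floor reached: near-discrete choices)
```

The floor `tau_min` keeps the softmax from becoming so sharp that gradients vanish late in training.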

## The Gumbel-Softmax gradient estimator

In some literature we may see the jargon “Gumbel-Softmax gradient estimator”. It is well explained in the original paper:

“The Gumbel-Softmax distribution is smooth for $\tau > 0$, and therefore has a well-defined gradient $\frac{\partial y}{\partial \pi}$ with respect to the parameter $\pi$. Thus, by replacing categorical samples with Gumbel-Softmax samples we can use backpropagation to compute gradients. We denote this procedure of replacing non-differentiable categorical samples with a differentiable approximation during training as the Gumbel-Softmax estimator.”

### Straight-through Gumbel-Softmax gradient estimator

“Straight-through” means that only the backward pass (gradient computation) uses the differentiable soft sample; the forward pass still uses the discrete categorical variable.
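In autograd frameworks this is typically implemented with a “detach” trick: the output is computed as `y_hard - stop_grad(y) + y`, whose forward value equals the one-hot `y_hard` but whose gradient is that of the soft `y`. Here is a NumPy sketch of the forward value only; NumPy has no autograd, so `stop_grad` is a placeholder identity (in PyTorch it would be `.detach()`):

```python
import numpy as np

def stop_grad(x):
    # Placeholder: in an autograd framework this would block gradient flow
    # (e.g. tensor.detach() in PyTorch); numerically it is the identity.
    return x

y = np.array([0.1, 0.7, 0.2])          # a soft Gumbel-Softmax sample
y_hard = np.eye(len(y))[np.argmax(y)]  # its discrete one-hot version

# Straight-through output: forward value equals y_hard, while gradients
# (in an autograd framework) would flow through the soft y.
st = y_hard - stop_grad(y) + y
print(st)  # [0. 1. 0.]
```

Because `-stop_grad(y) + y` cancels numerically, the forward pass sees exactly the one-hot vector, yet the only differentiable term in the expression is the soft `y`, so backpropagation treats the output as if it were `y`.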