
The Gumbel-max trick for the Bernoulli distribution

In deep learning, the Gumbel-max trick is used to sample from a categorical distribution during the forward pass. Combined with a softmax relaxation (the Gumbel-softmax), it keeps the whole model differentiable. This is useful when discrete values are needed as an intermediate representation, because the trick lets gradients flow through the sampling operation. In this sense, it is similar to the reparameterization trick used in VAEs, but for categorical distributions. It is popular in, for example, discrete VAEs[1]. I will not fully reintroduce the concepts and maths behind the Gumbel-max trick here. Instead, refer to this blog for a good introduction.

This blog post focuses on adapting the Gumbel-max trick to the Bernoulli distribution. Usually, the Gumbel-max trick is paired with the softmax to reparameterize a k-class categorical distribution. For example, the bottleneck of a discrete autoencoder could contain a categorical “concept” per patch of locally aggregated pixels. In other cases, the internal representation should be binary, i.e. either 0 or 1. Take, for instance, a case where the discretized bottleneck should encode whether a certain concept is present or not[2].

In neural networks, binary (i.e. Bernoulli) random variables are often modeled with the sigmoid function σ:

$$p = \sigma(x) = \frac{1}{1 + e^{-x}} = \frac{e^{x}}{e^{x} + 1},$$

and

$$1 - p = 1 - \sigma(x) = 1 - \frac{e^{x}}{e^{x} + 1} = \frac{e^{x} + 1 - e^{x}}{e^{x} + 1} = \frac{1}{e^{x} + 1}.$$

These two quantities represent the estimated probability of the event occurring ($p$) and of it not occurring ($1 - p$), respectively.
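As a quick numerical sanity check of the two equations above, here is a minimal sketch in PyTorch (the library used by the code later in this post; the grid of test points is an arbitrary choice). It also shows that $1 - \sigma(x)$ equals $\sigma(-x)$:

```python
import torch

# Arbitrary grid of logits; float64 keeps rounding errors negligible
x = torch.linspace(-6.0, 6.0, steps=13, dtype=torch.float64)

# Two equivalent forms of the sigmoid
p_a = 1.0 / (1.0 + torch.exp(-x))
p_b = torch.exp(x) / (torch.exp(x) + 1.0)

# The complementary probability 1 - p
q = 1.0 / (torch.exp(x) + 1.0)

print(torch.allclose(p_a, torch.sigmoid(x)))  # True
print(torch.allclose(p_a, p_b))               # True
print(torch.allclose(1.0 - p_a, q))           # True
print(torch.allclose(q, torch.sigmoid(-x)))   # True
```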

The sigmoid operation can be rewritten as a special case of the two-way softmax. Using two unnormalized log-probabilities (i.e. logits) $l_1$ and $l_2$, the two-way softmax computes the same probabilities:

$$p = \frac{e^{l_1}}{e^{l_1} + e^{l_2}},$$

and

$$1 - p = \frac{e^{l_2}}{e^{l_1} + e^{l_2}}.$$

These equations make it easy to see that the sigmoid function is equivalent to a two-way softmax whose two logits are $l_1 = x$ and $l_2 = 0$: substituting gives $e^{x} / (e^{x} + e^{0}) = e^{x} / (e^{x} + 1) = \sigma(x)$. As such, we can apply the Gumbel-max trick exactly as one does with the softmax operation, and afterwards rewrite the result in logistic (sigmoid) form.
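The equivalence can also be checked numerically. In this sketch (grid of logits chosen arbitrarily), the logits $(x, 0)$ are stacked along the last dimension and passed through `torch.softmax`; the first output matches `torch.sigmoid(x)` and the second matches $1 - \sigma(x)$:

```python
import torch

x = torch.linspace(-6.0, 6.0, steps=13, dtype=torch.float64)

# Two-way softmax with logits l1 = x and l2 = 0, stacked along the last dim
logits = torch.stack([x, torch.zeros_like(x)], dim=-1)
probs = torch.softmax(logits, dim=-1)

p = probs[..., 0]            # probability of the event
one_minus_p = probs[..., 1]  # probability of no event

print(torch.allclose(p, torch.sigmoid(x)))                # True
print(torch.allclose(one_minus_p, 1 - torch.sigmoid(x)))  # True
```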

The Gumbel-max trick simply involves adding Gumbel noise to the logits. Let us first draw two independent samples, $g_1$ and $g_2$, from the standard Gumbel distribution:

$$g_1, g_2 \sim \mathrm{Gumbel}(0, 1)$$
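For reference, standard Gumbel samples can be drawn by inverse-transform sampling, $g = -\log(-\log u)$ with $u \sim \mathrm{Uniform}(0, 1)$, or equivalently as $g = -\log e$ with $e \sim \mathrm{Exponential}(1)$, which is the form the `gumbel_sigmoid` code in this post uses. The sketch below draws samples both ways and compares their statistics with the known mean (the Euler–Mascheroni constant, about 0.5772) and variance ($\pi^2 / 6$, about 1.6449) of Gumbel(0, 1); the sample size and seed are arbitrary:

```python
import math
import torch

torch.manual_seed(0)
n = 1_000_000

# Inverse-CDF sampling: if u ~ Uniform(0, 1), then -log(-log(u)) ~ Gumbel(0, 1).
# torch.rand samples from [0, 1); clamp_min guards against log(0) if u == 0.
u = torch.rand(n, dtype=torch.float64).clamp_min(1e-12)
g_inv = -torch.log(-torch.log(u))

# Equivalent form: if e ~ Exponential(1), then -log(e) ~ Gumbel(0, 1)
g_exp = -torch.empty(n, dtype=torch.float64).exponential_().log()

euler_gamma = 0.5772156649015329  # mean of Gumbel(0, 1)
print(g_inv.mean().item(), g_exp.mean().item())  # both close to euler_gamma
print(g_inv.var().item(), math.pi ** 2 / 6)       # sample variance vs. pi^2 / 6
```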

Then, we add these to our logits $l_1 = x$ and $l_2 = 0$ of the two-way softmax and simplify back to the sigmoid:

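$$\text{Gumbel-}\sigma(x) = \frac{e^{x + g_1}}{e^{x + g_1} + e^{g_2}} = \frac{1}{\left(e^{x + g_1} + e^{g_2}\right) e^{-(x + g_1)}} = \frac{1}{1 + e^{-(x + g_1 - g_2)}} = \sigma(x + g_1 - g_2).$$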
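The identity above holds for every individual noise draw, not only in distribution. A small sketch that checks it elementwise (random logits and seed chosen arbitrarily); the left-hand side is computed with `torch.softmax` over the perturbed logits $(x + g_1, g_2)$ for numerical stability:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000, dtype=torch.float64)

# Two independent Gumbel(0, 1) samples per element
g1 = -torch.empty_like(x).exponential_().log()
g2 = -torch.empty_like(x).exponential_().log()

# Left-hand side: two-way softmax of the perturbed logits (x + g1, 0 + g2)
lhs = torch.softmax(torch.stack([x + g1, g2], dim=-1), dim=-1)[..., 0]
# Right-hand side: sigmoid of the single perturbed logit x + g1 - g2
rhs = torch.sigmoid(x + g1 - g2)

print(torch.allclose(lhs, rhs))  # True
```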

As such, applying the Gumbel-max reparameterization trick to Bernoulli random variables (parameterized with the sigmoid) amounts to adding the difference of two independent Gumbel samples to the logit. This has been described before by Maddison et al.[3].
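The claim can also be checked empirically. Taking the hard (argmax) decision of the perturbed two-way softmax, i.e. outputting 1 whenever $x + g_1 > g_2$, should give a Bernoulli($\sigma(x)$) sample. A Monte Carlo sketch with arbitrarily chosen logits and sample size:

```python
import torch

torch.manual_seed(0)
n = 200_000
logits = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0], dtype=torch.float64)

# n independent pairs of Gumbel(0, 1) noise per logit
g1 = -torch.empty(n, logits.numel(), dtype=torch.float64).exponential_().log()
g2 = -torch.empty(n, logits.numel(), dtype=torch.float64).exponential_().log()

# Gumbel-max decision for the logits (x, 0): output 1 if x + g1 > 0 + g2
samples = (logits + g1 - g2 > 0).to(torch.float64)

empirical = samples.mean(dim=0)  # empirical frequency of 1s per logit
expected = torch.sigmoid(logits)
print(empirical)  # close to expected, up to Monte Carlo error
print(expected)
```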

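The Python (PyTorch) code for the Gumbel-sigmoid operation is given below. The temperature tau controls how sharply the relaxed samples concentrate around 0 and 1, and hard=True returns exact 0/1 values in the forward pass while using the gradients of the soft sample in the backward pass (the straight-through estimator), mirroring the tau and hard arguments of torch.nn.functional.gumbel_softmax: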

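import torch


def gumbel_sigmoid(logits, tau=1.0, hard=False):
    """Relaxed Bernoulli (Gumbel-sigmoid) sample with temperature tau.

    With hard=True, the forward pass returns exact 0/1 values while the
    gradients of the soft sample are used (straight-through estimator).
    """
    # -log(e) with e ~ Exponential(1) is distributed as Gumbel(0, 1)
    gumbels_1 = -torch.empty_like(logits).exponential_().log()
    gumbels_2 = -torch.empty_like(logits).exponential_().log()

    # Gumbel-max trick for the two-way softmax with logits (x, 0),
    # rewritten as sigmoid(x + g1 - g2) and relaxed with temperature tau
    y_soft = torch.sigmoid((logits + gumbels_1 - gumbels_2) / tau)

    if hard:
        # Elementwise threshold at 0.5; works for tensors of any shape
        y_hard = (y_soft > 0.5).to(y_soft.dtype)
        # Forward value is y_hard, gradient is that of y_soft
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret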
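Since the difference of two independent Gumbel(0, 1) variables follows a standard logistic distribution, the two Gumbel draws can be replaced by a single logistic draw, $\log u - \log(1 - u)$ with $u \sim \mathrm{Uniform}(0, 1)$. The function below is a hypothetical variant (`logistic_sigmoid_sample` is a name chosen here, not an existing API) that does exactly this and, with `hard=True`, applies the same straight-through trick. It is a sketch; the usage lines show that the hard forward output is exactly 0 or 1 while gradients still reach the logits:

```python
import torch


def logistic_sigmoid_sample(logits, tau=1.0, hard=False):
    """Hypothetical variant of gumbel_sigmoid using one logistic sample.

    Logistic(0, 1) noise has the same distribution as g1 - g2 with
    g1, g2 ~ Gumbel(0, 1) independent.
    """
    # torch.rand_like samples from [0, 1); clamping away from 0 keeps log(u) finite
    u = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
    logistic_noise = torch.log(u) - torch.log1p(-u)

    y_soft = torch.sigmoid((logits + logistic_noise) / tau)
    if not hard:
        return y_soft
    # Forward: exact 0/1 values; backward: gradient of y_soft (straight-through)
    y_hard = (y_soft > 0.5).to(y_soft.dtype)
    return y_hard - y_soft.detach() + y_soft


# Usage: a batch of 4 examples with 3 binary "concepts" each
torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
z = logistic_sigmoid_sample(logits, tau=0.5, hard=True)
z.sum().backward()

print(z)            # entries are exactly 0.0 or 1.0
print(logits.grad)  # gradients flow through the soft relaxation
```

Drawing one uniform per element instead of two exponentials halves the number of random draws; in distribution, the result is the same as adding the difference of two Gumbel samples.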

References and footnotes

  1. Jang, Eric, Shixiang Gu, and Ben Poole. “Categorical reparameterization with Gumbel-softmax.” arXiv preprint arXiv:1611.01144 (2016).

  2. Of course, it is always possible to output two neurons instead and pair them with the usual Gumbel-softmax operation.

  3. Maddison, Chris J., Andriy Mnih, and Yee Whye Teh. “The concrete distribution: A continuous relaxation of discrete random variables.” arXiv preprint arXiv:1611.00712 (2016).