The Gumbel-max trick for the Bernoulli distribution
In deep learning, the Gumbel-Max trick is used to sample from a categorical distribution during the forward pass while keeping the whole operation differentiable. It is useful in situations where discrete values are needed as some intermediate representation. In such cases, the Gumbel-Max trick enables differentiable sampling (i.e. it lets gradients flow through the sampling operation). In this sense, it is similar to the reparameterization trick used in VAEs, but for categorical distributions. It is popular in, for example, discrete VAEs[1]. I will not fully reintroduce the concepts and maths behind the Gumbel-Max trick here; instead, refer to this blog for a good introduction.
This blogpost will focus on adapting the Gumbel-max trick for the Bernoulli distribution.
Usually, the Gumbel-max trick is paired with the softmax to reparameterize a categorical distribution over multiple classes.
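For reference, this is roughly what that usual categorical version looks like; a minimal sketch (the use of torch.distributions.Gumbel, the shapes, and the temperature value are illustrative choices, not taken from the references):

import torch
import torch.nn.functional as F

logits = torch.randn(4)  # unnormalized log-probabilities for 4 classes

# Hard Gumbel-max: the argmax of logits + Gumbel(0, 1) noise is a categorical sample
gumbel_noise = torch.distributions.Gumbel(0.0, 1.0).sample(logits.shape)
hard_sample = torch.argmax(logits + gumbel_noise)  # discrete, not differentiable

# Gumbel-softmax relaxation: replace the argmax with a temperature-scaled softmax
tau = 0.5
soft_sample = F.softmax((logits + gumbel_noise) / tau, dim=-1)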
In neural networks, binary (i.e. Bernoulli) random variables are often modeled with the sigmoid function

$$p = \sigma(z) = \frac{1}{1 + e^{-z}}$$

and

$$1 - p = 1 - \sigma(z) = \sigma(-z).$$

Both quantities represent the estimated probability ($p$ of the variable being 1, and $1 - p$ of it being 0), computed from a single logit $z$ output by the network[2].
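For contrast, sampling from this Bernoulli directly is not differentiable; a minimal sketch of the naive approach (variable names are illustrative):

import torch

z = torch.randn(3, requires_grad=True)  # logits produced by some network layer
p = torch.sigmoid(z)                    # estimated probabilities P(b = 1)

b = torch.bernoulli(p)  # hard 0/1 samples; no useful gradient reaches z through this step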
The sigmoid operation can be rewritten as a special case of the two-way softmax.
Using 2 unnormalized log-probabilities (i.e. logits) $z_1$ and $z_2$, the two-way softmax gives

$$p = \frac{e^{z_1}}{e^{z_1} + e^{z_2}} = \frac{1}{1 + e^{-(z_1 - z_2)}} = \sigma(z_1 - z_2)$$

and

$$1 - p = \frac{e^{z_2}}{e^{z_1} + e^{z_2}} = \sigma(z_2 - z_1).$$

These equations allow us to trivially see that the sigmoid function is equivalent to a two-way softmax where the two unnormalized log-probabilities are $z_1 = z$ and $z_2 = 0$.
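A quick numerical check of this equivalence (a self-contained sketch; the tensor shapes are arbitrary):

import torch
import torch.nn.functional as F

z = torch.randn(5)  # arbitrary logits
two_way_logits = torch.stack([z, torch.zeros_like(z)], dim=-1)

p_sigmoid = torch.sigmoid(z)                           # sigma(z)
p_softmax = F.softmax(two_way_logits, dim=-1)[..., 0]  # first entry of softmax([z, 0])

print(torch.allclose(p_sigmoid, p_softmax))  # True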
The Gumbel-Max trick simply involves adding Gumbel noise to the logits. Let us first sample two points ($g_1$ and $g_2$) from the standard Gumbel distribution, $g_1, g_2 \sim \text{Gumbel}(0, 1)$. Then, we add these to our logits, which for the sigmoid parameterization ($z_1 = z$, $z_2 = 0$) yields the relaxed sample

$$y = \sigma\!\left(\frac{(z + g_1) - (0 + g_2)}{\tau}\right) = \sigma\!\left(\frac{z + g_1 - g_2}{\tau}\right),$$

where $\tau$ is the temperature of the relaxation.
As such, applying the Gumbel-Max reparameterization trick to Bernoulli random variables (parameterized using the sigmoid operation) involves adding the difference of two Gumbel samples to the logit. This has been described before by Maddison et al.[3].
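As a sanity check of this derivation, the hard decision $z + g_1 > g_2$ should come out as 1 with probability $\sigma(z)$; a minimal Monte Carlo sketch (sampling the Gumbels the same way as the implementation below):

import torch

z = torch.tensor(0.7)  # a single logit
n = 200_000

# Gumbel(0, 1) samples via -log(Exponential(1))
g1 = -torch.empty(n).exponential_().log()
g2 = -torch.empty(n).exponential_().log()

# Hard Gumbel-max decision between the logits z and 0
hard = (z + g1 - g2 > 0).float()

print(hard.mean())       # close to sigmoid(z)
print(torch.sigmoid(z))  # approx. 0.668

Note that the difference of two independent standard Gumbel samples follows a standard Logistic distribution, so the noise could equivalently be drawn from a single Logistic variable.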
The Python code for the Gumbel-sigmoid operation is:
import torch


def gumbel_sigmoid(logits, tau=1, hard=False):
    # Sample two independent Gumbel(0, 1) noise tensors via -log(Exponential(1)).
    gumbels_1 = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
        .exponential_()
        .log()
    )
    gumbels_2 = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
        .exponential_()
        .log()
    )

    # Soft (relaxed) sample: sigmoid of the noisy logit, scaled by the temperature tau.
    y_soft = torch.sigmoid((logits + gumbels_1 - gumbels_2) / tau)

    if hard:
        # Straight-through estimator: return hard 0/1 values in the forward pass,
        # while gradients flow through the soft sample in the backward pass.
        indices = (y_soft > 0.5).nonzero(as_tuple=True)
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
        y_hard[indices[0], indices[1]] = 1.0  # assumes 2-dimensional logits (e.g. batch x variables)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret
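A small usage sketch (the shapes and the temperature value are just illustrative):

logits = torch.randn(8, 16, requires_grad=True)  # e.g. a batch of 8 examples with 16 binary variables

y = gumbel_sigmoid(logits, tau=0.5, hard=True)   # hard 0/1 samples in the forward pass
y.sum().backward()                               # gradients still reach the logits via the soft sample
print(logits.grad.shape)                         # torch.Size([8, 16])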
References and footnotes
Jang, Eric, Shixiang Gu, and Ben Poole. “Categorical reparameterization with gumbel-softmax.” arXiv preprint arXiv:1611.01144 (2016). ↩︎
Of course, it is also always possible to output 2 neurons, pairing it with the usual Gumbel-softmax operation. ↩︎
Maddison, Chris J., Andriy Mnih, and Yee Whye Teh. “The concrete distribution: A continuous relaxation of discrete random variables.” arXiv preprint arXiv:1611.00712 (2016). ↩︎