@dsevero
Last active June 17, 2021 16:08
Gumbel-Softmax Trick
# We can produce samples from an un-normalized distribution by adding
# iid Gumbel noise to the log-probabilities and taking the argmax; this is known as the Gumbel-Max Trick.
# However, argmax doesn't produce meaningful gradient signals, so we replace argmax
# with softmax, controlled by a temperature parameter (Gumbel-Softmax Trick).
#
# I didn't invent any of this, all credit goes to the following papers:
# - https://arxiv.org/abs/1611.00712
# - https://arxiv.org/abs/1611.01144
import numpy as np
np.random.seed(1337)
def softmax(p):
    p = p - p.max(axis=1, keepdims=True)  # subtract row-wise max for numerical stability (avoids mutating the input)
    e = np.exp(p)
    return e / e.sum(axis=1, keepdims=True)
n_samples = 100_000
p = np.array([2, 4, 6, 8]) # un-normalized PMF, shape (4,)
g = -np.log(-np.log(np.random.rand(n_samples, 4)+1e-20)) # iid Gumbel samples, shape (n_samples, 4)
gmt = np.argmax(np.log(p) + g, axis=1) # Gumbel-Max Trick, shape (n_samples,)
gst = softmax((np.log(p) + g)/1e-3) # Gumbel-Softmax Trick, shape (n_samples, 4), meaningful gradients!
# Verify
_, counts = np.unique(gmt, return_counts=True)
gst.mean(axis=0), counts/n_samples, p/p.sum()
# Output:
# (array([0.0992882, 0.1995264, 0.2990391, 0.4021463]),
# array([0.09929, 0.19955, 0.29897, 0.40219]),
# array([0.1, 0.2, 0.3, 0.4]))
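The temperature above is fixed at 1e-3, which makes the softmax output nearly one-hot. As a minimal sketch (not part of the original gist) of how the temperature trades off between discrete and uniform samples, one can sweep a few values of tau on the same Gumbel draws:

```python
import numpy as np

np.random.seed(1337)

def softmax(p):
    p = p - p.max(axis=1, keepdims=True)  # stabilize before exponentiating
    e = np.exp(p)
    return e / e.sum(axis=1, keepdims=True)

p = np.array([2, 4, 6, 8])                          # un-normalized PMF
g = -np.log(-np.log(np.random.rand(5, 4) + 1e-20))  # 5 iid Gumbel draws

for tau in [0.01, 1.0, 10.0]:
    s = softmax((np.log(p) + g) / tau)
    # low tau -> samples are nearly one-hot; high tau -> nearly uniform
    print(f"tau={tau}: {s[0].round(3)}")
```

As tau approaches 0 the softmax output converges to the one-hot argmax sample, recovering the Gumbel-Max Trick exactly; as tau grows, the rows flatten toward uniform, while the gradient with respect to log(p) stays well-defined at every temperature.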