Matching probabilistic models
import holoviews as hv
hv.extension('bokeh')
import numpy as np
Probabilistic Generative Models
Let’s say we have \(N\) i.i.d. observations \(\{x_1, x_2, \ldots, x_N\}\), which can be viewed as the result of sampling from an unknown distribution \(x_i \sim p^*(x)\)
Generative modeling provides a systematic framework to understand the data (observations) by modeling the underlying processes that generated the data
The framework can be summarized in three steps
Propose a probabilistic parametric model \(p_\theta (x)\)
Match the proposed model to \(p^*(x)\)
Repeat the previous steps to compare different models and test assumptions
In this lecture we will focus on the second step: matching distributions
Practical uses of generative models
Generative models can be used for
data generation (data augmentation)
density estimation
representation learning
The latter is especially helpful for subsequent classification/prediction tasks in scenarios with large datasets but few labeled examples
In later chapters we will review the state of the art in generative models: deep generative models
Divergences
To match distributions we first need to assess how similar/dissimilar they are
Divergences or statistical distances are functions used to compare distributions. Some examples of divergences are
The Kullback-Leibler (KL) divergence (f-divergence family)
Renyi’s divergence (\(\alpha\)-divergence family)
The Wasserstein or Earth mover's distance
Caution
Divergences are not always proper metrics. For example the KL divergence is not symmetric
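To see the asymmetry concretely, we can evaluate the closed-form KL divergence between two univariate Gaussians in both directions (the helper kl_gaussians below is just an illustrative sketch, not a library function)
import jax.numpy as jnp

def kl_gaussians(mu1, sigma1, mu2, sigma2):
    # Closed-form KL divergence between N(mu1, sigma1^2) and N(mu2, sigma2^2)
    return jnp.log(sigma2/sigma1) + (sigma1**2 + (mu1 - mu2)**2)/(2*sigma2**2) - 0.5

# Swapping the arguments gives a different value: KL(p||q) != KL(q||p)
# (approximately 0.443 and 1.307, respectively)
kl_gaussians(0., 1., 1., 2.), kl_gaussians(1., 2., 0., 1.)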
In what follows we will explore how to use the KL divergence for generative model matching/fitting
Forward KL matching
The forward KL divergence between the underlying distribution and the proposed model is

\[
D_{\text{KL}} \left[ p^*(x) || p_\theta(x) \right] = \mathbb{E}_{x \sim p^*(x)} \left[ \log p^*(x) \right] - \mathbb{E}_{x \sim p^*(x)} \left[ \log p_\theta(x) \right]
\]
Note
We can’t compute the left-hand term because we can’t evaluate \(p^*(x)\)
Hint
The left-hand term does not depend on \(\theta\), so we can ignore it when optimizing the forward KL as a function of \(\theta\)
If we approximate the expected value with an average over our finite dataset we obtain

\[
\mathbb{E}_{x \sim p^*(x)} \left[ \log p_\theta(x) \right] \approx \frac{1}{N} \sum_{i=1}^N \log p_\theta(x_i)
\]

i.e. up to the \(1/N\) factor, the log likelihood of \(\theta\), \(\mathcal{L}(\theta) = \sum_{i=1}^N \log p_\theta(x_i)\)
Important
Minimizing the forward KL divergence is equivalent to maximizing the log likelihood of the model
Example: Forward KL fitting of a univariate Gaussian model
Consider the following observations
import jax.numpy as jnp
import jax.random as random
import numpyro.distributions as distributions
key = random.PRNGKey(1234)
data = distributions.Normal(5, 2).rsample(key, sample_shape=(200,))
counts, bins = jnp.histogram(data, bins=20, density=True)
hist_data = hv.Histogram((counts, bins), vdims='Density', label='data').opts(alpha=0.5, width=500)
hist_data
Let’s propose a model \(p_\theta(x) = \mathcal{N}(x|\mu, \sigma^2)\) with parameters \(\theta=(\mu, \sigma)\)
The log likelihood in this case is

\[
\mathcal{L}(\mu, \sigma) = \sum_{i=1}^N \log \mathcal{N}(x_i | \mu, \sigma^2) = -\frac{N}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^N (x_i - \mu)^2
\]
Let's maximize the log likelihood using stochastic gradient descent (SGD). For example, we update \(\mu\) with

\[
\mu \leftarrow \mu - \eta \frac{\partial}{\partial \mu} \left( - \mathcal{L}(\mu, \sigma) \right)
\]

where \(\eta\) is the learning rate (the update for \(\sigma\) is analogous)
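For reference, the gradients that SGD needs have a simple closed form in this model (a quick derivation from the log likelihood above):

\[
\frac{\partial (-\mathcal{L})}{\partial \mu} = -\frac{1}{\sigma^2} \sum_{i=1}^N (x_i - \mu), \qquad \frac{\partial (-\mathcal{L})}{\partial \sigma} = \frac{N}{\sigma} - \frac{1}{\sigma^3} \sum_{i=1}^N (x_i - \mu)^2
\]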
We will use jax's autodiff capabilities for this
import jax.numpy as jnp
from jax import grad

def negative_log_likelihood(mu, sigma, data):
    # Negative log likelihood of the data under the N(mu, sigma^2) model
    return -jnp.sum(distributions.Normal(mu, sigma).log_prob(data))

def gaussian_pdf(mu, sigma, x):
    # Density of the N(mu, sigma^2) model evaluated at x
    return jnp.exp(distributions.Normal(mu, sigma).log_prob(x))

mu, sigma = 0., 10.  # initial parameter values
lr = 1e-2  # learning rate
params = {}
for epoch in range(50):
    params[epoch] = (mu, sigma)
    # Gradients of the negative log likelihood with respect to mu and sigma
    grad_mu, grad_sigma = grad(negative_log_likelihood, argnums=(0, 1))(mu, sigma, data)
    # Gradient descent updates
    mu -= lr*grad_mu
    sigma -= lr*grad_sigma
The evolution of the parameters:
mu = np.array([param[0] for param in params.values()])
sigma = np.array([param[1] for param in params.values()])
mu_curve = hv.Curve((list(params.keys()), mu), label='mu').opts(width=500)
sigma_curve = hv.Curve((list(params.keys()), sigma), label='sigma')
mu_curve * sigma_curve
The final model closely follows the underlying distribution
x_plot = jnp.linspace(data.min(), data.max(), 100)
hv.Curve((x_plot, gaussian_pdf(params[49][0], params[49][1], x_plot)), label='model').opts(color='black') * hist_data
We can use the fitted model to sample new data
fitted_mu, fitted_sigma = params[49]
distributions.Normal(fitted_mu, fitted_sigma).rsample(key, sample_shape=(5,))
DeviceArray([6.9511614, 3.3132024, 5.3313627, 1.5486307, 6.072811 ], dtype=float32)
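Since the maximum likelihood estimates of a Gaussian are available in closed form (the sample mean and the sample standard deviation), we can check that SGD landed close to them
# The SGD estimates should be close to the closed-form MLE below
print(fitted_mu, fitted_sigma)
print(jnp.mean(data), jnp.std(data))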
Example: Forward KL fitting of Gaussian model, misspecified case
What happens if we try our Gaussian model on data that does not follow a Gaussian distribution?
In what follows we will use a mixture of Gaussians
key = random.PRNGKey(1234)
# Component assignments: 80% from the first Gaussian, 20% from the second
p = distributions.Categorical(probs=jnp.array([0.8, 0.2])).sample(key, sample_shape=(200,))
key, subkey = random.split(key)
G1 = distributions.Normal(6., 1.5).rsample(key, sample_shape=(200,))
key, subkey = random.split(key)
G2 = distributions.Normal(-6, 1.5).rsample(key, sample_shape=(200,))
# Keep each sample from the component indicated by its assignment
data = jnp.concatenate((G1[p==0], G2[p==1]))
counts, bins = jnp.histogram(data, bins=20, density=True)
hist_data = hv.Histogram((counts, bins), vdims='Density', label='data').opts(alpha=0.5, width=500)
hist_data
Fitting the model proceeds exactly as before
mu, sigma = 0., 10.
lr = 1e-2
params = {}
for epoch in range(60):
params[epoch] = (mu, sigma)
grad_mu, grad_sigma = grad(negative_log_likelihood, argnums=(0,1))(mu, sigma, data)
mu -= lr*grad_mu
sigma -= lr*grad_sigma
The evolution of the parameters in this case:
mu = np.array([param[0] for param in params.values()])
sigma = np.array([param[1] for param in params.values()])
mu_curve = hv.Curve((list(params.keys()), mu), label='mu').opts(width=500)
sigma_curve = hv.Curve((list(params.keys()), sigma), label='sigma')
mu_curve * sigma_curve
And the fitted model is
x_plot = jnp.linspace(data.min(), data.max(), 100)
hv.Curve((x_plot, gaussian_pdf(params[59][0], params[59][1], x_plot)), label='model').opts(color='black') * hist_data
Note
\(p_\theta\) spreads out trying to cover all the mass of \(p^*\)
If we go back to the definition of the KL divergence

\[
D_{\text{KL}} \left[ p^*(x) || p_\theta(x) \right] = \int p^*(x) \log \frac{p^*(x)}{p_\theta(x)} \,dx = \mathbb{E}_{x \sim p^*(x)} \left[ \log \frac{p^*(x)}{p_\theta(x)} \right]
\]

we have that, by definition:
For \(x\) where \(p^* = 0\) we don't care what \(p_\theta\) does
For \(x\) where \(p^* > 0\), if \(p_\theta \to 0\) then \(D_{\text{KL}} \left [ p^*(x) || p_\theta(x) \right] \to \infty\)
Hence the forward KL will put \(p_\theta\) mass in all places where \(p^* > 0\)
Because of this, the forward KL is often referred to as “mean seeking” and “zero-avoiding”
In a sense, the forward KL favors a diverse but not so realistic model
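The spreading has a simple explanation: for a Gaussian model the maximum likelihood (forward KL) solution matches the first two moments of the data, no matter how non-Gaussian they are, so on bimodal data the fitted mean falls between the modes and the fitted standard deviation grows to cover both. The SGD iterates above move toward exactly these values
# Moment matching: the targets of the forward KL (maximum likelihood) fit
# are the sample mean and standard deviation of the bimodal data
print(jnp.mean(data), jnp.std(data))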
Reverse KL matching
The reverse KL between our model and the underlying distribution is

\[
D_{\text{KL}} \left[ p_\theta(x) || p^*(x) \right] = \mathbb{E}_{x \sim p_\theta(x)} \left[ \log p_\theta(x) \right] - \mathbb{E}_{x \sim p_\theta(x)} \left[ \log p^*(x) \right] = -H(p_\theta) - \mathbb{E}_{x \sim p_\theta(x)} \left[ \log p^*(x) \right]
\]
Note
The first term corresponds to the (negative) differential entropy of the model, \(-H(p_\theta)\)
Example: What happens in the misspecified case if we use reverse KL instead of the forward KL?
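The code below builds a Monte Carlo estimate of the reverse KL by drawing a few samples from the model with rsample (the reparameterization trick), so that gradients can flow through the sampling step:

\[
D_{\text{KL}} \left[ p_\theta(x) || p^*(x) \right] \approx \frac{1}{S} \sum_{s=1}^S \left[ \log p_\theta(x_s) - \log p^*(x_s) \right], \quad x_s \sim p_\theta(x)
\]

(the implementation uses a plain sum over \(S=10\) samples instead of the average, which only rescales the gradients)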
def log_mog_pdf(x):
    # Log density of the true mixture of Gaussians (this assumes we know p*)
    mog_pdf = 0.8*jnp.exp(distributions.Normal(6, 1.5).log_prob(x)) + 0.2*jnp.exp(distributions.Normal(-6, 1.5).log_prob(x))
    return jnp.log(mog_pdf[mog_pdf>0.])

def reverse_kl(mu, sigma, key):
    # Monte Carlo estimate of the reverse KL using samples from the model
    samples = distributions.Normal(mu, sigma).rsample(key, sample_shape=(10,))
    return jnp.sum(jnp.log(gaussian_pdf(mu, sigma, samples))) - jnp.sum(log_mog_pdf(samples))

mu, sigma = 0., 10.  # initial parameter values
lr = 1e-2  # learning rate
key = random.PRNGKey(1234)
params = {}
for epoch in range(501):
    params[epoch] = (mu, sigma)
    key, subkey = random.split(key)  # fresh key for each Monte Carlo estimate
    grad_mu, grad_sigma = grad(reverse_kl, argnums=(0,1))(mu, sigma, key)
    mu -= lr*grad_mu
    sigma -= lr*grad_sigma
The evolution of the parameters in this case:
mu = np.array([param[0] for param in params.values()])
sigma = np.array([param[1] for param in params.values()])
mu_curve = hv.Curve((list(params.keys()), mu), label='mu').opts(width=500)
sigma_curve = hv.Curve((list(params.keys()), sigma), label='sigma')
mu_curve * sigma_curve
And the fitted model:
x_plot = jnp.linspace(data.min(), data.max(), 100)
hv.Curve((x_plot, gaussian_pdf(params[epoch][0], params[epoch][1], x_plot)), label='model').opts(color='black') * hist_data
What do we need to get a small (minimum) reverse KL?
For \(x\) where \(p_\theta = 0\) we don’t need to fit \(p^*\) at all
For \(x\) where \(p_\theta > 0\) we need to be as close to \(p^*\) as possible
For \(x\) where \(p^* = 0\) we need to have \(p_\theta = 0\)
Because of this, the reverse KL is often referred to as "mode seeking" and "zero-forcing"
In a sense, the reverse KL favors a realistic but not so diverse model
Note
Minimizing the reverse KL requires maximizing the entropy of \(p_\theta\), i.e. its mass has to be as spread out as possible. This helps avoid the singular solution where all the mass of \(p_\theta\) collapses onto a single point of \(p^*\)
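For the Gaussian model this is easy to see: if \(p_\theta\) collapses onto a single point, i.e. \(\sigma \to 0\), the entropy term diverges while the cross term stays finite (assuming \(p^*(\mu) > 0\)),

\[
-H(p_\theta) = -\log \left( \sigma \sqrt{2 \pi e} \right) \to \infty, \qquad -\mathbb{E}_{x \sim p_\theta(x)} \left[ \log p^*(x) \right] \to -\log p^*(\mu),
\]

so the reverse KL itself diverges and the collapsed solution is ruled out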
The reverse KL example is hacked!
Because I'm using \(p^*\) as if I knew it (the second term in reverse_kl)
hv.Curve((x_plot, jnp.exp(log_mog_pdf(x_plot)))).opts(color='k') * hist_data
In general we don't have \(p^*\), so we can't evaluate the samples \(x \sim p_\theta(x)\) under it
We will see how Variational Inference overcomes this problem: instead of the reverse KL itself, it optimizes a related lower bound (the ELBO) that only involves quantities we can evaluate
See also
For more details see Chapter 28 of D. Barber’s book and this article by Colin Raffel: “GAN and Divergence Minimization”