import holoviews as hv
hv.extension('bokeh')
import numpy as np

Matching probabilistic models

Probabilistic Generative Models

Let’s say we have \(N\) i.i.d. observations

\[ \mathcal{D} = \{x_1, x_2, \ldots, x_N\}, \]

which can be viewed as the result of sampling from an unknown distribution

\[ x_i \sim p^*(x) \]

Generative modeling provides a systematic framework to understand the data (observations) by modeling the underlying processes that generated the data

The framework can be summarized in three steps

  1. Propose a probabilistic parametric model \(p_\theta (x)\)

  2. Match the proposed model to \(p^*(x)\)

  3. Repeat the previous steps to compare different models and test assumptions

In this lecture we will focus on the second step: matching distributions

Practical uses of generative models

Generative models can be used for

  • data generation (data augmentation)

  • density estimation

  • representation learning

The latter is especially helpful for subsequent classification/prediction tasks in scenarios with large datasets but few labeled examples

In later chapters we will review the state of the art in generative models: deep generative models

Divergences

To match distributions we first need to assess how similar/dissimilar they are

Divergences or statistical distances are functions used to compare distributions. Some examples of divergences are

  • The Kullback-Leibler (KL) divergence (f-divergence family)

  • Renyi’s divergence (\(\alpha\)-divergence family)

  • Wasserstein or Earth mover’s distance

Caution

Divergences are not always proper metrics. For example, the KL divergence is not symmetric: in general \(D_{\text{KL}}[p || q] \neq D_{\text{KL}}[q || p]\)
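To see the asymmetry concretely, here is a small NumPy sketch using the standard closed form of the KL divergence between two univariate Gaussians (the helper name kl_gauss is ours, for illustration):

```python
import numpy as np

def kl_gauss(mu1, s1, mu2, s2):
    # KL[N(mu1, s1^2) || N(mu2, s2^2)], closed form for univariate Gaussians
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5

# Same mean, different widths: swapping the arguments changes the value
print(kl_gauss(0.0, 1.0, 0.0, 2.0))  # KL[N(0,1) || N(0,4)] ≈ 0.318
print(kl_gauss(0.0, 2.0, 0.0, 1.0))  # KL[N(0,4) || N(0,1)] ≈ 0.807
```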

In what follows we will explore how to use the KL divergence for generative model matching/fitting

Forward KL matching

The forward KL divergence between the underlying distribution and the proposed model is

\[ D_{\text{KL}} \left [ p^*(x) || p_\theta(x) \right] = \mathbb{E}_{x \sim p^*(x)} \left [ \log p^*(x) \right ] - \mathbb{E}_{x \sim p^*(x)} \left [ \log p_\theta(x) \right ] \]

Note

We can’t compute the first term because we can’t evaluate \(p^*(x)\)

Hint

The first term does not depend on \(\theta\), so we can ignore it when optimizing the forward KL as a function of \(\theta\)

\[\begin{split} \begin{align} \hat \theta &= \text{arg} \min_\theta D_{\text{KL}} \left [ p^*(x) || p_\theta(x) \right] \nonumber \\ &= \text{arg} \min_\theta - \mathbb{E}_{x \sim p^*(x)} \left [ \log p_\theta(x) \right ] \nonumber \\ &= \text{arg} \max_\theta \mathbb{E}_{x \sim p^*(x)} \left [ \log p_\theta(x) \right ] \end{align} \end{split}\]

If we approximate the expected value with an average over our finite dataset we obtain

\[ \mathbb{E}_{x \sim p^*(x)} \left [ \log p_\theta(x) \right ] \approx \frac{1}{N} \sum_{i=1}^N \log p_\theta(x_i), \]

i.e. the log likelihood of \(\theta\) divided by \(N\). The constant factor \(1/N\) does not change the maximizer
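A quick NumPy sketch of this Monte Carlo approximation, assuming for illustration that \(p^*(x) = \mathcal{N}(5, 2^2)\) and an arbitrary model \(p_\theta(x) = \mathcal{N}(4, 3^2)\). Because both are Gaussian, the exact expectation is available for comparison:

```python
import numpy as np

def log_normal(x, m, s):
    # log N(x | m, s^2)
    return -0.5 * np.log(2 * np.pi * s**2) - (x - m) ** 2 / (2 * s**2)

mu_true, sigma_true = 5.0, 2.0    # p*(x) = N(5, 2^2), assumed for illustration
mu_model, sigma_model = 4.0, 3.0  # an arbitrary model p_theta(x) = N(4, 3^2)

# Exact E_{x~p*}[log p_theta(x)], available here because both are Gaussian
exact = (-0.5 * np.log(2 * np.pi * sigma_model**2)
         - (sigma_true**2 + (mu_true - mu_model) ** 2) / (2 * sigma_model**2))

rng = np.random.default_rng(0)
estimates = {}
for n in (10, 1_000, 100_000):
    x_mc = rng.normal(mu_true, sigma_true, size=n)
    # (1/N) sum_i log p_theta(x_i): the sample-average approximation
    estimates[n] = log_normal(x_mc, mu_model, sigma_model).mean()
    print(n, estimates[n], exact)
```

The estimate fluctuates for small \(N\) and settles near the exact value as \(N\) grows.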

Important

Minimizing the forward KL divergence is equivalent to maximizing the log likelihood of the model

Example: Forward KL fitting of a univariate Gaussian model

Consider the following observations

import jax.numpy as jnp
import jax.random as random
import numpyro.distributions as distributions

key = random.PRNGKey(1234)
data = distributions.Normal(5, 2).rsample(key, sample_shape=(200,))

counts, bins = jnp.histogram(data, bins=20, density=True)
hist_data = hv.Histogram((counts, bins),  vdims='Density', label='data').opts(alpha=0.5, width=500)
hist_data

Let’s propose a model \(p_\theta(x) = \mathcal{N}(x|\mu, \sigma^2)\) with parameters \(\theta=(\mu, \sigma)\)

The log likelihood in this case is

\[ \mathcal{L}(\mu, \sigma) = \sum_{i=1}^N \log p_\theta(x_i) = -\frac{N}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^N (x_i - \mu)^2 \approx N \, \mathbb{E}_{x \sim p^*(x)} \left [ \log p_\theta(x) \right ] \]
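Setting \(\partial \mathcal{L}/\partial \mu = 0\) and \(\partial \mathcal{L}/\partial \sigma = 0\) gives the familiar closed-form maximizers \(\hat\mu = \frac{1}{N}\sum_i x_i\) and \(\hat\sigma^2 = \frac{1}{N}\sum_i (x_i - \hat\mu)^2\). A NumPy sketch that checks small perturbations of this solution lower the log likelihood (the data is a stand-in drawn with NumPy, not the JAX sample above):

```python
import numpy as np

def gauss_loglik(mu, sigma, x):
    # L(mu, sigma) = -N/2 log(2 pi sigma^2) - sum_i (x_i - mu)^2 / (2 sigma^2)
    n = x.size
    return -0.5 * n * np.log(2 * np.pi * sigma**2) - np.sum((x - mu) ** 2) / (2 * sigma**2)

rng = np.random.default_rng(1234)
obs = rng.normal(5.0, 2.0, size=200)  # stand-in data, not the JAX sample above

mu_hat = obs.mean()
sigma_hat = np.sqrt(np.mean((obs - mu_hat) ** 2))  # biased (1/N) estimator, same as obs.std()

best = gauss_loglik(mu_hat, sigma_hat, obs)
for d_mu, d_sigma in [(0.1, 0.0), (-0.1, 0.0), (0.0, 0.1), (0.0, -0.1)]:
    # Any small perturbation of the closed-form solution lowers the log likelihood
    assert gauss_loglik(mu_hat + d_mu, sigma_hat + d_sigma, obs) < best
print(mu_hat, sigma_hat)
```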

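Let’s maximize the log likelihood with gradient descent on the negative log likelihood \(-\mathcal{L}\) (each step here uses the full dataset rather than stochastic mini-batches). For example, we update \(\mu\) with

\[ \mu \leftarrow \mu - \eta \frac{\partial (-\mathcal{L})}{\partial \mu} = \mu + \eta \frac{\partial \mathcal{L}}{\partial \mu}, \]

where \(\eta\) is the learning rate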

We will use JAX’s automatic differentiation (jax.grad) for this

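import jax.numpy as jnp
from jax import grad

def negative_log_likelihood(mu, sigma, data):
    # -sum_i log N(x_i | mu, sigma^2)
    return -jnp.sum(distributions.Normal(mu, sigma).log_prob(data))

def gaussian_pdf(mu, sigma, x):
    return jnp.exp(distributions.Normal(mu, sigma).log_prob(x))

# Initial guess and learning rate
mu, sigma = 0., 10.
lr = 1e-2

params = {}  # epoch -> (mu, sigma) before that epoch's update
for epoch in range(50):
    params[epoch] = (mu, sigma)
    # Gradients of the negative log likelihood w.r.t. mu and sigma (argnums 0 and 1)
    grad_mu, grad_sigma = grad(negative_log_likelihood, argnums=(0,1))(mu, sigma, data)
    # Gradient descent on -L, i.e. gradient ascent on the log likelihood
    mu -= lr*grad_mu
    sigma -= lr*grad_sigma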
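As a cross-check on what grad computes, the same updates can be written with the analytic gradients \(\partial(-\mathcal{L})/\partial\mu = -\sum_i (x_i-\mu)/\sigma^2\) and \(\partial(-\mathcal{L})/\partial\sigma = N/\sigma - \sum_i (x_i-\mu)^2/\sigma^3\). This NumPy sketch uses its own stand-in data and runs more steps (500, our choice) so that it reaches the closed-form maximum likelihood estimate:

```python
import numpy as np

def nll_grads(mu, sigma, x):
    # Analytic gradients of the negative log likelihood -L(mu, sigma)
    r = x - mu
    d_mu = -np.sum(r) / sigma**2
    d_sigma = x.size / sigma - np.sum(r**2) / sigma**3
    return d_mu, d_sigma

rng = np.random.default_rng(1234)
obs = rng.normal(5.0, 2.0, size=200)  # stand-in data drawn with NumPy

mu_gd, sigma_gd, lr_gd = 0.0, 10.0, 1e-2  # same initialization and learning rate as above
for step in range(500):
    g_mu, g_sigma = nll_grads(mu_gd, sigma_gd, obs)
    mu_gd -= lr_gd * g_mu
    sigma_gd -= lr_gd * g_sigma

print(mu_gd, obs.mean())    # gradient descent vs closed-form MLE for mu
print(sigma_gd, obs.std())  # ... and for sigma (1/N estimator)
```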

The evolution of the parameters:

mu = np.array([param[0] for param in params.values()])
sigma = np.array([param[1] for param in params.values()])

mu_curve = hv.Curve((list(params.keys()), mu), label='mu').opts(width=500)
sigma_curve = hv.Curve((list(params.keys()), sigma), label='sigma')
mu_curve * sigma_curve

The final model closely follows the underlying distribution

x_plot = jnp.linspace(data.min(), data.max(), 100)
hv.Curve((x_plot, gaussian_pdf(params[49][0], params[49][1], x_plot)), label='model').opts(color='black') * hist_data

We can use the fitted model to sample new data

fitted_mu, fitted_sigma = params[49]
distributions.Normal(fitted_mu, fitted_sigma).rsample(key, sample_shape=(5,))
DeviceArray([6.9511614, 3.3132024, 5.3313627, 1.5486307, 6.072811 ], dtype=float32)

Example: Forward KL fitting of Gaussian model, misspecified case

What happens if we try our Gaussian model on data that does not follow a Gaussian distribution?

In what follows we will use a mixture of Gaussians

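key = random.PRNGKey(1234)
key, subkey = random.split(key)
# Component label for each sample: 0 with probability 0.8, 1 with probability 0.2
p = distributions.Categorical(probs=jnp.array([0.8, 0.2])).sample(subkey, sample_shape=(200,))
key, subkey = random.split(key)
G1 = distributions.Normal(6., 1.5).rsample(subkey, sample_shape=(200,))
key, subkey = random.split(key)
G2 = distributions.Normal(-6., 1.5).rsample(subkey, sample_shape=(200,))
# For each sample keep the draw from the component selected by p
data = jnp.concatenate((G1[p==0], G2[p==1]))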
counts, bins = jnp.histogram(data, bins=20, density=True)
hist_data = hv.Histogram((counts, bins),  vdims='Density', label='data').opts(alpha=0.5, width=500)
hist_data

The fitting procedure is the same as before (60 epochs this time)

mu, sigma = 0., 10.
lr = 1e-2

params = {}
for epoch in range(60):
    params[epoch] = (mu, sigma)
    grad_mu, grad_sigma = grad(negative_log_likelihood, argnums=(0,1))(mu, sigma, data)
    mu -= lr*grad_mu
    sigma -= lr*grad_sigma    

The evolution of the parameters in this case:

mu = np.array([param[0] for param in params.values()])
sigma = np.array([param[1] for param in params.values()])

mu_curve = hv.Curve((list(params.keys()), mu), label='mu').opts(width=500)
sigma_curve = hv.Curve((list(params.keys()), sigma), label='sigma')
mu_curve * sigma_curve

And the fitted model is

x_plot = jnp.linspace(data.min(), data.max(), 100)
hv.Curve((x_plot, gaussian_pdf(params[59][0], params[59][1], x_plot)), label='model').opts(color='black') * hist_data

Note

\(p_\theta\) spreads out, trying to cover all the mass of \(p^*\)
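For a Gaussian model this spreading has a closed form: minimizing the forward KL over Gaussians matches the mean and variance of \(p^*\) (moment matching, a standard property of exponential-family models). A NumPy sketch for the mixture used above, with a Monte Carlo check; the variable names are illustrative:

```python
import numpy as np

# Mixture used above: 0.8 N(6, 1.5^2) + 0.2 N(-6, 1.5^2)
weights = np.array([0.8, 0.2])
means = np.array([6.0, -6.0])
stds = np.array([1.5, 1.5])

# Forward-KL optimum within the Gaussian family: match the first two moments of p*
mu_mm = np.sum(weights * means)
second_moment = np.sum(weights * (stds**2 + means**2))
sigma_mm = np.sqrt(second_moment - mu_mm**2)
print(mu_mm, sigma_mm)  # ≈ 3.6, ≈ 5.029

# Monte Carlo check: sample the mixture by picking a component, then drawing from it
rng = np.random.default_rng(0)
component = rng.choice(2, size=200_000, p=weights)
mix_samples = rng.normal(means[component], stds[component])
print(mix_samples.mean(), mix_samples.std())
```

With only 200 observations and a limited number of epochs, the gradient-descent fit above can only be expected to land roughly near these values.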

If we go back to the definition of the KL divergence

\[ D_{\text{KL}} \left [ p^*(x) || p_\theta(x) \right] = \int p^*(x) \log \frac{p^*(x)}{p_\theta(x)} \,dx \]

we have that by definition

  • For \(x\) where \(p^* = 0\), the integrand vanishes, so it does not matter what \(p_\theta\) does

  • For \(x\) where \(p^* > 0\), if \(p_\theta \to 0\) then \(D_{\text{KL}} \left [ p^*(x) || p_\theta(x) \right] \to \infty\)

Hence the forward KL will put \(p_\theta\) mass in all places where \(p^* > 0\)

Because of this, the forward KL is often referred to as “mean seeking” and “zero-avoiding”

In a sense, the forward KL favors a diverse but not so realistic model

Reverse KL matching

The reverse KL between our model and the underlying distribution is

\[\begin{split} \begin{align} D_{\text{KL}}\left [ p_\theta(x) || p^*(x) \right] &= \int p_\theta(x) \log \frac{p_\theta(x)}{p^*(x)} \,dx \nonumber \\ &= \int p_\theta(x) \log p_\theta(x) \,dx - \int p_\theta(x) \log p^*(x) \,dx \nonumber \end{align} \end{split}\]

Note

The first term is the negative differential entropy of the model, \(\int p_\theta(x) \log p_\theta(x) \,dx = -H[p_\theta]\)
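A NumPy sketch confirming the sign: the Monte Carlo average of \(\log q\) over samples from \(q = \mathcal{N}(0, 3^2)\) (an arbitrary choice) approaches \(-H[q] = -\tfrac{1}{2}\log(2\pi e \sigma^2)\):

```python
import numpy as np

mu_q, sigma_q = 0.0, 3.0  # an arbitrary Gaussian q = N(0, 3^2)
rng = np.random.default_rng(0)
x_q = rng.normal(mu_q, sigma_q, size=200_000)

log_q = -0.5 * np.log(2 * np.pi * sigma_q**2) - (x_q - mu_q) ** 2 / (2 * sigma_q**2)
first_term_mc = log_q.mean()                           # Monte Carlo estimate of ∫ q log q dx
entropy = 0.5 * np.log(2 * np.pi * np.e * sigma_q**2)  # differential entropy H[q]
print(first_term_mc, -entropy)                         # the two should agree closely
```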

Example: What happens in the misspecified case if we use reverse KL instead of the forward KL?

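from jax.scipy.special import logsumexp

def log_mog_pdf(x):
    # log p*(x) for the mixture 0.8 N(6, 1.5^2) + 0.2 N(-6, 1.5^2), computed with
    # logsumexp so that samples far in the tails do not underflow to log(0)
    log_components = jnp.stack([jnp.log(0.8) + distributions.Normal(6., 1.5).log_prob(x),
                                jnp.log(0.2) + distributions.Normal(-6., 1.5).log_prob(x)])
    return logsumexp(log_components, axis=0)

def reverse_kl(mu, sigma, key):
    # Reparameterized samples from p_theta, so gradients flow through them
    samples = distributions.Normal(mu, sigma).rsample(key, sample_shape=(10,))
    # Sum over the 10 samples (10x the Monte Carlo estimate of the reverse KL;
    # the constant factor only rescales the gradient)
    return jnp.sum(distributions.Normal(mu, sigma).log_prob(samples)) - jnp.sum(log_mog_pdf(samples))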

mu, sigma = 0., 10.
lr = 1e-2
key = random.PRNGKey(1234)

params = {}
for epoch in range(501):
    params[epoch] = (mu, sigma)
    key, subkey = random.split(key)  # fresh randomness for every step
    grad_mu, grad_sigma = grad(reverse_kl, argnums=(0,1))(mu, sigma, subkey)
    mu -= lr*grad_mu
    sigma -= lr*grad_sigma
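The loop above relies on rsample drawing reparameterized samples, \(x = \mu + \sigma\epsilon\) with \(\epsilon \sim \mathcal{N}(0, 1)\), so that gradients with respect to \(\mu\) and \(\sigma\) flow through the samples. A NumPy sketch of this reparameterization gradient for a toy objective \(\mathbb{E}[x^2] = \mu^2 + \sigma^2\) (our choice, for illustration), whose exact gradients are \((2\mu, 2\sigma)\):

```python
import numpy as np

mu_r, sigma_r = 1.5, 0.5
rng = np.random.default_rng(0)
eps = rng.standard_normal(200_000)
x_r = mu_r + sigma_r * eps  # reparameterized samples from N(mu_r, sigma_r^2)

# Toy objective f(x) = x^2: E[f] = mu^2 + sigma^2, exact gradients (2 mu, 2 sigma).
# Chain rule through x = mu + sigma * eps: df/dmu = 2x, df/dsigma = 2x * eps
grad_mu_mc = np.mean(2 * x_r)
grad_sigma_mc = np.mean(2 * x_r * eps)
print(grad_mu_mc, 2 * mu_r)
print(grad_sigma_mc, 2 * sigma_r)
```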

The evolution of the parameters in this case:

mu = np.array([param[0] for param in params.values()])
sigma = np.array([param[1] for param in params.values()])

mu_curve = hv.Curve((list(params.keys()), mu), label='mu').opts(width=500)
sigma_curve = hv.Curve((list(params.keys()), sigma), label='sigma')
mu_curve * sigma_curve

And the fitted model:

x_plot = jnp.linspace(data.min(), data.max(), 100)
hv.Curve((x_plot, gaussian_pdf(params[epoch][0], params[epoch][1], x_plot)), label='model').opts(color='black') * hist_data

What do we need to get a small (minimum) reverse KL?

  • For \(x\) where \(p_\theta = 0\) we don’t need to fit \(p^*\) at all

  • For \(x\) where \(p_\theta > 0\) we need to be as close to \(p^*\) as possible

  • For \(x\) where \(p^* = 0\) we need to have \(p_\theta = 0\)

Because of this, the reverse KL is often referred to as “mode seeking” and “zero-forcing”

In a sense, the reverse KL favors a realistic but not so diverse model
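To make “mean seeking” versus “mode seeking” concrete, here is a NumPy sketch that evaluates both divergences by numerical integration for two candidate Gaussians: a broad one with the mixture’s mean and standard deviation (3.6 and about 5.03) and a narrow one sitting on the dominant mode. The grid range and resolution are assumptions chosen for illustration:

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    # log N(x | mu, sigma^2)
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def mixture_logpdf(x):
    # log p*(x) for 0.8 N(6, 1.5^2) + 0.2 N(-6, 1.5^2), via logaddexp for stability
    return np.logaddexp(np.log(0.8) + normal_logpdf(x, 6.0, 1.5),
                        np.log(0.2) + normal_logpdf(x, -6.0, 1.5))

# Uniform grid; integrals are approximated by Riemann sums
x_grid = np.linspace(-40.0, 40.0, 80001)
dx = x_grid[1] - x_grid[0]

def kl_on_grid(log_p, log_q):
    # KL[p || q] = ∫ p(x) (log p(x) - log q(x)) dx
    return np.sum(np.exp(log_p) * (log_p - log_q)) * dx

log_p_star = mixture_logpdf(x_grid)
candidates = {
    "broad": normal_logpdf(x_grid, 3.6, np.sqrt(25.29)),  # moment-matched Gaussian
    "narrow": normal_logpdf(x_grid, 6.0, 1.5),            # sits on the dominant mode
}
results = {}
for name, log_q in candidates.items():
    # (forward KL[p* || q], reverse KL[q || p*])
    results[name] = (kl_on_grid(log_p_star, log_q), kl_on_grid(log_q, log_p_star))
    print(f"{name:7s} forward KL = {results[name][0]:.3f}  reverse KL = {results[name][1]:.3f}")
```

The forward KL prefers the broad model, while the reverse KL prefers the narrow one. The narrow model’s reverse KL is close to \(-\log 0.8 \approx 0.223\), since under it essentially only the 0.8-weighted component of \(p^*\) is visible.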

Note

Minimizing the reverse KL requires maximizing the entropy of \(p_\theta\), i.e. its mass has to be as spread out as possible. This helps avoid the singular solution where all the mass of \(p_\theta\) collapses onto a single point of \(p^*\)
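A sketch of this trade-off, fixing \(\mu = 6\) and sweeping \(\sigma\) for a Gaussian on the dominant mode (grid integration is an assumption; the names are hypothetical). The negative-entropy term makes the reverse KL grow as \(\sigma \to 0\), and the minimum lands near the mode’s own width, \(\sigma \approx 1.5\):

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    # log N(x | mu, sigma^2)
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

# log p*(x) for 0.8 N(6, 1.5^2) + 0.2 N(-6, 1.5^2) on a uniform grid
x_grid = np.linspace(-40.0, 40.0, 80001)
dx = x_grid[1] - x_grid[0]
log_p_star = np.logaddexp(np.log(0.8) + normal_logpdf(x_grid, 6.0, 1.5),
                          np.log(0.2) + normal_logpdf(x_grid, -6.0, 1.5))

def reverse_kl_grid(sigma, mu=6.0):
    log_q = normal_logpdf(x_grid, mu, sigma)
    q = np.exp(log_q)
    neg_entropy = np.sum(q * log_q) * dx  # ∫ q log q = -H[q], grows as sigma -> 0
    cross = np.sum(q * log_p_star) * dx   # ∫ q log p*
    return neg_entropy - cross

sigmas = np.arange(0.1, 4.0, 0.01)
rkl_values = np.array([reverse_kl_grid(s) for s in sigmas])
best_sigma = sigmas[np.argmin(rkl_values)]
print(best_sigma)
```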

The reverse KL example is hacked!

Because I’m using \(p^*\) as if I knew it (second term in reverse_kl)

hv.Curve((x_plot, jnp.exp(log_mog_pdf(x_plot)))).opts(color='k') * hist_data

In general we don’t have \(p^*\), so we can’t evaluate it at samples \(x\sim p_\theta(x)\)

We will see how Variational Inference overcomes this problem by maximizing a lower bound on the log evidence (the ELBO), which is equivalent to minimizing the reverse KL to the posterior

See also

For more details see Chapter 28 of D. Barber’s book and this article by Colin Raffel: “GAN and Divergence Minimization”