Variational Autoencoder

An LVM is defined by the joint density between observed and latent variables

\[ p(\textbf{x}, \textbf{z}) = \prod_{i=1}^N p(\textbf{x}_i|\textbf{z}_i) p(\textbf{z}_i) \]
  • If we use the probabilistic PCA recipe (linear mapping, Gaussian likelihood and Gaussian prior) we obtain an analytical Gaussian posterior and evidence

  • If we use a more complex (non-linear) mapping, e.g. a neural network, the posterior and evidence may not be tractable

In the latter case, we can use variational inference (VI), i.e. we propose an approximate posterior and maximize the ELBO

\[\begin{split} \log p(x) \geq \mathcal{L}(\phi) &= \mathbb{E}_{z\sim q_\phi(z|x)} \left[\log \frac{p(x, z)}{q_\phi(z|x)}\right] \\ &= \int q_\phi(z|x) \log \frac{p(x, z)}{q_\phi(z|x)} dz \end{split}\]

to find the best parameters \(\hat \phi\).
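To see why maximizing the ELBO yields the best approximate posterior, recall the identity

\[ \log p(x) = \mathcal{L}(\phi) + D_{KL}\left[ q_\phi(z|x) || p(z|x) \right] \]

Since \(\log p(x)\) does not depend on \(\phi\), maximizing \(\mathcal{L}(\phi)\) is equivalent to minimizing the KL divergence between the approximate and the true posterior.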

Important

In this lesson we review the Variational AutoEncoder (VAE) which combines amortized inference and VI

The VAE is an LVM where deep neural networks are used to model the conditional distributions between latent \(z\) and observed \(x\) variables

It was proposed independently and almost simultaneously by (Kingma and Welling, ICLR, Dec. 2013) and (Rezende et al., ICML, Jan. 2014), arguably sparking the revived interest in Deep Learning with Approximate Bayesian Inference that we see today

The difference with a regular autoencoder is that the latent (code) is now a stochastic variable

  • a prior distribution is placed on \(z\): \(p(z)\)

  • a neural network is used to model the parameters of the likelihood: \(p_\theta(x | z)\)

  • a neural network is used to model the parameters of the approximate posterior: \(q_\phi(z|x)\)

Note

The weights and biases of the networks (\(\theta\) and \(\phi\)) are deterministic, i.e. the VAE is not a “fully” Bayesian neural network

In what follows we will review the assumptions and the key contributions of this work to the field of Bayesian Neural Networks

Assumptions in VAE

VAE assumes a particular prior and approximate posterior:

  • The latent variable has a standard Gaussian prior (the same as in probabilistic PCA)

  • The approximate posterior is a factorized (diagonal) Gaussian

Mathematically this is

\[ p(z_i) = \mathcal{N}(0, I) \]

and

\[ q(z_i|x_i) = \mathcal{N}(\mu_i, \text{diag}(\sigma_i^2)) \]

The likelihood function is chosen depending on the data

  • For continuous data the likelihood is typically set as a spherical Gaussian or diagonal Gaussian

  • For binary data the likelihood is typically set as Bernoulli

Mathematically (for the first case) this is

\[ p(x_i|z_i) = \mathcal{N}(\hat \mu_i, \text{diag}(\hat \sigma_i^2)) \]

for \(x_i \in \mathbb{R}^D\).
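As a rough sketch (not part of the original code; array shapes and variable names are purely illustrative), these two likelihood choices correspond to the following numpyro distributions, where loc, scale and logits stand in for decoder outputs:

import jax.numpy as jnp
import numpyro.distributions as dists

# Illustrative decoder outputs for a batch of 4 samples with D=3 features
loc, scale = jnp.zeros((4, 3)), jnp.ones((4, 3))  # hypothetical Gaussian decoder outputs
logits = jnp.zeros((4, 3))                        # hypothetical Bernoulli decoder outputs

# Continuous data: diagonal Gaussian likelihood (to_event treats the D features as one event)
gaussian_likelihood = dists.Normal(loc, scale).to_event(1)

# Binary data: Bernoulli likelihood parameterized by logits
bernoulli_likelihood = dists.BernoulliLogits(logits).to_event(1)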

Details on the VAE training

The VAE is trained by maximizing the ELBO, which in this case is

\[\begin{split} \mathcal{L}(\theta, \phi) &= \mathbb{E}_{z\sim q_\phi(z|x)} \left [\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x) \right ] \\ &= \mathbb{E}_{z\sim q_\phi(z|x)} \left [\log p_\theta(x|z) \right ] - D_{KL}\left[ q_\phi(z|x) || p(z) \right] \end{split}\]

By maximizing the ELBO we

  • Maximize the log likelihood when sampling from the approximate posterior: Faithful data reconstructions

  • Minimize the divergence between the approximate posterior and prior: Regularization for the posterior

The ELBO is typically optimized via gradient ascent updates for \(\theta\) and \(\phi\)

\[ \theta_{t+1} = \theta_{t} + \eta \nabla_\theta \mathcal{L}(\theta_{t}, \phi_{t}) \]
\[ \phi_{t+1} = \phi_{t} + \eta \nabla_\phi \mathcal{L}(\theta_{t}, \phi_{t}) \]

Let’s review how to obtain the derivatives of the ELBO with respect to \(\theta\) and \(\phi\)

Amortization

In the previous formulation the number of variational parameters, i.e. \(\mu_i\) and \(\sigma_i\), scales linearly with the number of samples in the dataset. This is impractical for large datasets.

Amortization solves this problem by having a function that maps data to parameters. In the particular case of VAE we have

\[ \mu_i, \sigma_i = g_\phi(x_i) \]

where \(g_\phi(\cdot)\) is the encoder network and

\[ \hat \mu_i, \hat \sigma_i = f_\theta(z_i) \]

where \(f_\theta(\cdot)\) is the decoder network (for a diagonal Gaussian likelihood)

Note

VAE is an example of amortized variational inference
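To make the contrast concrete, below is a minimal sketch (hypothetical, not part of the original notebook) of a non-amortized numpyro guide, where a separate \((\mu_i, \sigma_i)\) pair is stored for every datapoint; in the VAE guide shown later these quantities are instead produced by the encoder \(g_\phi\):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dists

def non_amortized_guide(data, latent_dim):
    # One (loc, scale) pair per datapoint: the number of variational
    # parameters grows linearly with the dataset size
    # (with minibatching one would additionally need to index into these arrays)
    n = data.shape[0]
    locs = numpyro.param("z_locs", jnp.zeros((n, latent_dim)))
    scales = numpyro.param("z_scales", jnp.ones((n, latent_dim)),
                           constraint=dists.constraints.positive)
    with numpyro.plate("batch", n):
        numpyro.sample("z", dists.Normal(locs, scales).to_event(1))

# In the amortized (VAE) guide, locs and scales are instead computed as
# g_phi(data) by the encoder network, so the number of variational
# parameters is fixed by the encoder size, not by the dataset size.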

The derivative with respect to \(\theta\)

The only term that depends on \(\theta\) is \(p_\theta(x|z)\), then

\[\begin{split} \nabla_\theta \mathcal{L}(\theta, \phi) &= \nabla_\theta \mathbb{E}_{z\sim q_\phi(z|x)}\left [\log p_\theta(x|z)\right ] \\ &= \mathbb{E}_{z\sim q_\phi(z|x)} \left [\nabla_\theta \log p_\theta(x|z)\right ] \end{split}\]

where the gradient and the expectation operator can be swapped because \(q_\phi(z|x)\) does not depend on \(\theta\)

The expectation can be estimated via Monte Carlo integration as

\[\begin{split} \mathbb{E}_{z\sim q_\phi(z|x)} \left [\nabla_\theta \log p_\theta(x|z)\right ] &= \int q_\phi(z|x) \nabla_\theta \log p_\theta(x|z) \,dz \\ &\approx \frac{1}{S} \sum_{k=1}^S \nabla_\theta \log p_\theta(x|z^{(k)}), \quad z^{(k)} \sim q_\phi(z|x) \end{split}\]
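In code, this estimator comes almost for free from automatic differentiation: we draw samples from the approximate posterior and differentiate the averaged log-likelihood with respect to \(\theta\). A minimal JAX sketch (with a toy linear “decoder” standing in for \(\log p_\theta(x|z)\), not the VAE decoder used later) is shown below:

import jax
import jax.numpy as jnp
import jax.random as random

def log_likelihood(theta, x, z):
    # Toy log p_theta(x|z): a unit-variance Gaussian whose mean is a
    # linear "decoder" z @ theta (purely for illustration)
    return -0.5 * jnp.sum((x - z @ theta) ** 2)

def grad_theta_estimate(theta, x, z_samples):
    # (1/S) * sum_k grad_theta log p_theta(x | z^(k))
    mc_objective = lambda th: jnp.mean(
        jax.vmap(lambda z: log_likelihood(th, x, z))(z_samples))
    return jax.grad(mc_objective)(theta)

key = random.PRNGKey(0)
theta = jnp.zeros((2, 3))                  # toy decoder parameters
x = jnp.ones(3)
z_samples = random.normal(key, (100, 2))   # stand-in for z ~ q_phi(z|x)
print(grad_theta_estimate(theta, x, z_samples))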

The derivative with respect to \(\phi\)

Let’s focus on the first (reconstruction) term of the ELBO, since the KL term will be treated in closed form later, and write \(h(z) = \log p_\theta(x|z)\), a function of \(z\) that does not depend on \(\phi\). In this case

\[\begin{split} \nabla_\phi \mathbb{E}_{z\sim q_\phi(z|x)}\left [h(z) \right ] &= \nabla_\phi \int q_\phi(z|x) h(z) dz \\ &= \int h(z) \nabla_\phi q_\phi(z|x) dz \end{split}\]

but this is no longer an expectation with respect to \(q_\phi(z|x)\), so we cannot approximate it via Monte Carlo integration anymore.

A classical solution to this is the REINFORCE algorithm, which is based on the identity

\[ \nabla_\phi\log q_\phi(z) = \frac{1}{ q_\phi(z)} \nabla_\phi q_\phi(z) \]

then

\[\begin{split} \nabla_\phi \mathbb{E}_{z\sim q_\phi(z|x)}\left [h(z)\right ] &= \int h(z) \nabla_\phi q_\phi(z|x) dz \\ &= \int q_\phi(z|x) h(z) \nabla_\phi\log q_\phi(z|x) dz \\ &\approx \frac{1}{S} \sum_{k=1}^S h(z^{(k)}) \nabla_\phi \log q_\phi(z^{(k)}|x), \quad z^{(k)} \sim q_\phi(z|x) \end{split}\]

but in practice, training with this estimator is very hard: the REINFORCE estimator is unbiased but it has a very high variance
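The high variance is easy to see on a toy example. The following self-contained JAX sketch (not from the original text) uses the REINFORCE identity to estimate \(\nabla_\mu \mathbb{E}_{z\sim\mathcal{N}(\mu,1)}[z^2]\), whose true value is \(2\mu\):

import jax
import jax.numpy as jnp
import jax.random as random

h = lambda z: z ** 2                          # toy h(z); true gradient of E[h] is 2*mu
log_q = lambda mu, z: -0.5 * (z - mu) ** 2    # log N(z; mu, 1) up to a constant

def reinforce_grad(key, mu, n_samples=100):
    z = mu + random.normal(key, (n_samples,))                      # z ~ N(mu, 1)
    score = jax.vmap(jax.grad(log_q), in_axes=(None, 0))(mu, z)    # grad_mu log q(z)
    return jnp.mean(h(z) * score)                                  # REINFORCE estimate

mu = 1.0
keys = random.split(random.PRNGKey(0), 50)
estimates = jnp.array([reinforce_grad(k, mu) for k in keys])
print(estimates.mean(), estimates.std())  # centered near 2.0 but with a large spread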

For the particular case of the VAE there is a low-variance (and very elegant) alternative

The reparameterization trick

In VAE the latent variable is distributed as

\[ z \sim \mathcal{N}(\mu_\phi(x), \text{diag}(\sigma_\phi^2 (x)) ) \]

this can be rewritten as first sampling from a standard Gaussian

\[ \epsilon \sim \mathcal{N}(0, I) \]

and then applying a transformation

\[ z = r(\phi, \epsilon) = \mu_\phi (x) + \epsilon \odot \sigma_\phi (x) \]

Note

This is usually referred to as “the reparameterization trick”

Using this the expectation of \(h(z)\) is rewritten as

\[ \mathbb{E}_{z\sim q_\phi(z|x)}\left [h(z) \right ] = \mathbb{E}_{\epsilon\sim \mathcal{N}(0, I)}\left [ h(r(\phi, \epsilon)) \right ] \]

Note

The distribution over which the expectation is taken does not depend on \(\phi\)

Hence the following estimator for its gradient can be used

\[\begin{split} \nabla_\phi \mathbb{E}_{\epsilon\sim \mathcal{N}(0, I)}\left [ h(r(\phi, \epsilon)) \right ] &= \mathbb{E}_{\epsilon\sim \mathcal{N}(0, I)}\left [ h'(r(\phi, \epsilon)) \nabla_\phi r(\phi, \epsilon) \right ] \\ &\approx \frac{1}{S} \sum_{k=1}^S h'(r(\phi, \epsilon^{(k)})) \nabla_\phi r(\phi, \epsilon^{(k)}), \quad \epsilon^{(k)} \sim \mathcal{N}(0,I) \end{split}\]

which has a much lower variance than REINFORCE
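On the same toy problem as above (\(\nabla_\mu \mathbb{E}_{z\sim\mathcal{N}(\mu,1)}[z^2]\), true gradient \(2\mu\)), a minimal JAX sketch of the reparameterized estimator looks as follows; its estimates are noticeably more concentrated than the REINFORCE ones:

import jax
import jax.numpy as jnp
import jax.random as random

h = lambda z: z ** 2                          # same toy h(z) as before

def reparam_grad(key, mu, n_samples=100):
    eps = random.normal(key, (n_samples,))          # eps ~ N(0, 1)
    objective = lambda m: jnp.mean(h(m + eps))      # z = r(mu, eps) = mu + eps
    return jax.grad(objective)(mu)                  # differentiate through the samples

mu = 1.0
keys = random.split(random.PRNGKey(0), 50)
estimates = jnp.array([reparam_grad(k, mu) for k in keys])
print(estimates.mean(), estimates.std())  # centered near 2.0 with a much smaller spread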

More variance reduction: Closed-form terms

We have focused on the left hand term of the ELBO

\[ \mathcal{L}(\theta, \phi) = \mathbb{E}_{z\sim q_\phi(z|x)} \left [\log p_\theta(x|z) \right ] - D_{KL}\left[ q_\phi(z|x) || p(z) \right] \]

The right-hand term is the KL divergence between two multivariate Gaussian distributions, which has a closed-form analytical solution

\[ D_\text{KL}\left[q_\phi(z|x) || p(z) \right] = \frac{1}{2}\sum_{j=1}^K \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right) \]

where \(K\) is the dimensionality of the latent variable

The derivatives of this expression are straightforward, and using the closed-form KL instead of a Monte Carlo estimate further reduces the variance of the ELBO gradient
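As a quick sanity check (a small sketch, not part of the original notebook), the closed-form expression can be compared against a Monte Carlo estimate of the same KL divergence:

import jax.numpy as jnp
import jax.random as random
import numpyro.distributions as dists

mu = jnp.array([0.5, -1.0])
sigma = jnp.array([0.8, 1.5])

# Closed-form KL between N(mu, diag(sigma^2)) and N(0, I)
kl_closed = 0.5 * jnp.sum(mu**2 + sigma**2 - jnp.log(sigma**2) - 1.0)

# Monte Carlo estimate: E_q[log q(z) - log p(z)]
q = dists.Normal(mu, sigma).to_event(1)
p = dists.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
z = q.sample(random.PRNGKey(0), (10000,))
kl_mc = jnp.mean(q.log_prob(z) - p.log_prob(z))

print(kl_closed, kl_mc)  # the two values should agree closely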

Writing a VAE in Numpyro

Neural networks written in flax can be “registered” into numpyro models using flax_module. This lifts the neural network parameters into numpyro.param sites

Let’s start by writing the generative model of the VAE. The model starts from the latent variable, obtains the logits of the images using the decoder network, and then matches them against the observed data through a Bernoulli likelihood

import numpyro 
numpyro.set_platform("cpu") 
numpyro.set_host_device_count(2) 
from typing import Sequence, Callable
import jax.numpy as jnp
import flax.linen as nn
import numpyro.distributions as dists
from numpyro.contrib.module import flax_module

class Decoder(nn.Module):
    output_dim: int
    hidden_units: Sequence[int]
    activation: Callable = nn.relu
        
    @nn.compact
    def __call__(self, x):
        for neurons in self.hidden_units:
            x = self.activation(nn.Dense(neurons)(x))
        return nn.Dense(self.output_dim)(x)    
    
def model(batch, hidden_units, latent_dim, sample=False):
    batch_dim, out_dim = jnp.shape(batch)
    decoder = Decoder(hidden_units=hidden_units, output_dim=out_dim)
    decode = flax_module("decoder", decoder, input_shape=(batch_dim, latent_dim))
    with numpyro.plate("batch", size=batch_dim):
        z = numpyro.sample("z", dists.Normal(0, 1).expand([latent_dim]).to_event(1))
        logits = decode(z)
        if not sample:
            x = numpyro.sample("x", dists.BernoulliLogits(logits).to_event(1), obs=batch)
        else:
            x = numpyro.sample("x", dists.BernoulliLogits(logits).to_event(1))

Note

Plates are used to declare the model conditionally independent across the batch dimension (the leftmost dimension)

The encoder network amortizes the parameters of the approximate posterior (a diagonal normal). The Encoder module has two outputs; the second is passed through a softplus activation so that it is non-negative

The guide receives a minibatch of data, draws a sample from the approximate posterior, and returns its location and scale

class Encoder(nn.Module):
    hidden_units: Sequence[int]
    latent_dim: int
    activation: Callable = nn.relu
    
    @nn.compact 
    def __call__(self, x):
        for neurons in self.hidden_units:            
            x = self.activation(nn.Dense(neurons)(x))
        loc = nn.Dense(self.latent_dim)(x)
        scale = nn.softplus(nn.Dense(self.latent_dim)(x))
        return loc, scale 
    
def guide(batch, hidden_units, latent_dim, sample=False):
    batch_dim, out_dim = jnp.shape(batch)
    encoder = Encoder(hidden_units=hidden_units, latent_dim=latent_dim)
    encode = flax_module("encoder", encoder, input_shape=(batch_dim, out_dim))
    with numpyro.plate("batch", size=batch_dim):
        z_loc, z_std = encode(batch)
        z = numpyro.sample("z", dists.Normal(z_loc, z_std).to_event(1))
        return z_loc, z_std

The optax library provides a function called chain that allows for combining gradient transformations

The following shows how to implement a version of ADAM that clips large gradients for improved training stability

import optax

# Clip the global gradient norm to 1.0, rescale with Adam statistics, and finally
# multiply by a negative learning rate (optax updates are added to the parameters,
# so the negative sign gives gradient descent with learning rate 1e-3)
clipped_adam = optax.chain(optax.clip_by_global_norm(1.0),
                           optax.scale_by_adam(),
                           optax.scale(-1e-3))

In preparation to train the VAE we create an SVI object using the model, guide, optimizer and cost function. In this case we are also passing keyword arguments (hidden_units and latent_dim) that are forwarded to the model and guide

The svi.update function is compiled with jax.jit for improved computational efficiency

import jax
from numpyro.infer import SVI

hidden_units = [128, 128]
latent_dim = 2

svi = SVI(model, guide, clipped_adam, 
          numpyro.infer.TraceMeanField_ELBO(num_particles=1), 
          hidden_units=hidden_units, latent_dim=latent_dim)

jit_update = jax.jit(svi.update)

The following trains the VAE using minibatches from the MNIST dataset

import holoviews as hv
import jax.random as random
from tqdm.notebook import tqdm
hv.extension('bokeh')

import os
import sys
sys.path.append(os.path.abspath(os.path.join('../utils')))
from mnist import mnist
from flax_dataloader import data_loader

train_images, train_labels, test_images, test_labels = mnist();
batch_size = 128
nepochs = 100
key = random.PRNGKey(12345)
state = svi.init(key, train_images[:batch_size])
            
loss = []
for epoch in tqdm(range(nepochs)):
    key, key_ = jax.random.split(key)
    train_loader = data_loader(key_, train_images, train_labels, batch_size, shuffle=True)
    minibatch_loss = []
    for batch, labels in train_loader:
        state, loss_val = jit_update(state, batch)
        minibatch_loss.append(loss_val.item())
    
    loss.append(jnp.array(minibatch_loss).mean())
        
hv.Curve((range(nepochs), loss), 'Epoch', 'Loss').opts(width=500, height=200)

The parameter dictionary has the trained weights and biases of the encoder and decoder networks:

svi.get_params(state).keys()
dict_keys(['decoder$params', 'encoder$params'])

Let’s inspect the latent and generative space using the trained models

Note

In the VAE, latent projections are not deterministic points but (normal) distributions

import numpy as np

encoder = Encoder(hidden_units, latent_dim)
encoder_params = svi.get_params(state)["encoder$params"]
z_loc, z_scale = encoder.apply({'params': encoder_params}, test_images)
z_loc = np.array(z_loc)
z_scale = np.array(z_scale)

The following shows a scatter plot of the test-set data projected to latent space

  • The dots are centered on the mean of the projection

  • The errorbars represent the standard deviation of the projection

from bokeh.palettes import Category10

hv.opts.defaults(hv.opts.ErrorBars(lower_head=None, upper_head=None, show_legend=True))

def plot_digit(digit, color):
    mask = test_labels.argmax(axis=1) == digit
    center = hv.Scatter((z_loc[mask, 0], z_loc[mask, 1])).opts(color=color)
    error_x = hv.ErrorBars((z_loc[mask, 0], z_loc[mask, 1], z_scale[mask, 0]), 
                           horizontal=True).opts(color=color)
    error_y = hv.ErrorBars((z_loc[mask, 0], z_loc[mask, 1], z_scale[mask, 1]), 
                           horizontal=False).opts(color=color)
    return hv.Overlay([center, error_x, error_y])

digits = [plot_digit(digit, color) for digit,color in zip(range(10), Category10[10])]
hv.Overlay(digits).opts(width=550, height=400, legend_position='right')

The VAE is a generative model: a given sample does not have a single reconstruction but a distribution of reconstructions (Bernoulli)

The following uses Predictive to sample from the model through the guide. The plot shows five test-set examples each with three samples from the generative distribution

predictive = numpyro.infer.Predictive(model, 
                                      guide=svi.guide, 
                                      params=svi.get_params(state), 
                                      return_sites=['x'],
                                      num_samples=3)

xhat = predictive(random.PRNGKey(1), 
                  test_images[:5], 
                  hidden_units=hidden_units, 
                  latent_dim=latent_dim, 
                  sample=True)['x']

hv.opts.defaults(hv.opts.Image(cmap='gray', xaxis=None, yaxis=None))
examples = [hv.Image(example.reshape(28, 28)) for example in test_images[:5]]
reconstructed = []
for k in range(3):
    reconstructed += [hv.Image(example.reshape(28, 28)) for example in np.array(xhat[k])]

hv.Layout(examples + reconstructed).opts(hv.opts.Image(width=100, height=100)).cols(5)

The following computes the mean of the Bernoulli distribution as a function of the latent variable

decoder = Decoder(hidden_units=hidden_units, output_dim=28*28)
decoder_params = svi.get_params(state)["decoder$params"]

res = 30
z = jnp.linspace(-3, 3, num=res)
z1, z2 = jnp.meshgrid(z, z)
zz = jnp.stack((z1.ravel(), z2.ravel()))
p = nn.sigmoid(decoder.apply({'params': decoder_params}, zz.T)).reshape(res, res, 28, 28)
hv.Image(np.array(p.swapaxes(2,1).reshape(28*res, 28*res))).opts(width=600, height=600)