Variational Autoencoder¶
An LVM is defined by the joint density between observed and latent variables
If we use the PCA recipe (Linear mapping, Gaussian likelihood and Gaussian prior) we obtain an analytical Gaussian posterior and evidence
If we use a more complex (non-linear) mapping, e.g. a neural network, the posterior and evidence may not be tractable
In the latter case, we can use variational inference (VI), i.e. we propose an approximate posterior \(q_\phi(z)\) and maximize the ELBO to find the best variational parameters \(\hat \phi\).
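Recall that for any choice of \(q_\phi(z)\) the log evidence decomposes as

\[
\log p_\theta(x) = \underbrace{\mathbb{E}_{q_\phi(z)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z)}\right]}_{\text{ELBO}} + D_{\text{KL}}\left[q_\phi(z) \| p_\theta(z|x)\right],
\]

so maximizing the ELBO with respect to \(\phi\) is equivalent to minimizing the KL divergence between the approximate and the exact posterior.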
Important
In this lesson we review the Variational AutoEncoder (VAE) which combines amortized inference and VI
The VAE is an LVM where deep neural networks are used to model the conditional distributions between latent \(z\) and observed \(x\) variables
It was proposed almost simultaneously by (Kingma and Welling, ICLR 2014, preprint Dec. 2013) and (Rezende et al., ICML 2014, preprint Jan. 2014), perhaps sparking the revived interest in deep learning with approximate Bayesian inference that we see today
The difference with a regular autoencoder is that the latent code is now a stochastic variable:
a prior distribution is placed on \(z\): \(p(z)\)
a neural network is used to model the parameters of the likelihood: \(p_\theta(x | z)\)
a neural network is used to model the parameters of the approximate posterior: \(q_\phi(z|x)\)
Note
The weights and biases of the networks (\(\theta\) and \(\phi\)) are deterministic, i.e. the VAE is not a “fully” Bayesian neural network
In what follows we will review the assumptions and the key contributions of this work to the field of Bayesian Neural Networks
Assumptions in VAE¶
VAE assumes a particular prior and approximate posterior:
The latent variable has a standard Gaussian prior (the same as in PCA)
The approximate posterior is a factorized (diagonal) Gaussian
Mathematically, this is

\[
p(z) = \mathcal{N}(z | 0, I_K)
\]

and

\[
q(z_i) = \mathcal{N}(z_i | \mu_i, \text{diag}(\sigma_i^2)),
\]

where \(\mu_i\) and \(\sigma_i\) are the variational parameters of the \(i\)-th sample
The likelihood function is chosen depending on the data
For continuous data the likelihood is typically set as a spherical Gaussian or diagonal Gaussian
For binary data the likelihood is typically set as Bernoulli
Mathematically (for the first case) this is

\[
p_\theta(x_i | z) = \mathcal{N}(x_i | \mu_\theta(z), \sigma^2 I_D)
\]

for \(x_i \in \mathbb{R}^D\).
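To make these assumptions concrete, the following minimal sketch (not part of the lesson code) writes the prior and the approximate posterior of a single sample as numpyro distributions, assuming a latent dimensionality \(K=2\) and placeholder values for \(\mu_i\) and \(\sigma_i\)

import jax.numpy as jnp
import numpyro.distributions as dists

K = 2                                                 # assumed latent dimensionality
prior = dists.Normal(0., 1.).expand([K]).to_event(1)  # p(z) = N(0, I_K)
mu_i, sigma_i = jnp.zeros(K), jnp.ones(K)             # per-sample variational parameters (placeholders)
posterior = dists.Normal(mu_i, sigma_i).to_event(1)   # q(z_i) = N(mu_i, diag(sigma_i^2))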
Details on the VAE training¶
The VAE is trained by maximizing the ELBO, which in this case is

\[
\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - D_{\text{KL}}\left[q_\phi(z|x) \| p(z)\right]
\]
By maximizing the ELBO we
Maximize the log likelihood when sampling from the approximate posterior: faithful data reconstructions
Minimize the divergence between the approximate posterior and prior: Regularization for the posterior
The ELBO is typically optimized via gradient ascent updates for \(\theta\) and \(\phi\)
Let’s review how to obtain the derivatives of the ELBO with respect to \(\theta\) and \(\phi\)
Amortization¶
In the previous formulation the number of variational parameters, i.e. \(\mu_i\) and \(\sigma_i\), scales linearly with the number of samples in the dataset. This is impractical for large datasets.
Amortization solves this problem by having a function that maps data to parameters. In the particular case of the VAE we have

\[
(\mu_i, \log \sigma_i) = g_\phi(x_i),
\]

where \(g_\phi(\cdot)\) is the encoder network, and

\[
(\mu_{x_i}, \log \sigma_{x_i}) = f_\theta(z_i),
\]

where \(f_\theta(\cdot)\) is the decoder network (for a diagonal Gaussian likelihood)
Note
VAE is an example of amortized variational inference
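The idea behind amortization can be sketched with a toy (hypothetical) encoder g_phi: instead of storing \((\mu_i, \sigma_i)\) for every datapoint, a single function maps \(x_i\) to them. The linear map below is purely illustrative; the actual VAE uses the deep flax networks defined later

import jax.numpy as jnp

def g_phi(x, W, b):
    # a single linear layer, for illustration only
    h = x @ W + b
    mu, log_sigma = jnp.split(h, 2, axis=-1)  # first half: mean, second half: log std
    return mu, jnp.exp(log_sigma)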
The derivative with respect to \(\theta\)¶
The only term of the ELBO that depends on \(\theta\) is \(\log p_\theta(x|z)\), then

\[
\nabla_\theta \mathcal{L}(\theta, \phi) = \nabla_\theta \, \mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] = \mathbb{E}_{q_\phi(z|x)}\left[\nabla_\theta \log p_\theta(x|z)\right],
\]

where the gradient can be swapped with the expectation operator because \(q_\phi(z|x)\) does not depend on \(\theta\)

The expectation can be estimated via Monte Carlo integration as

\[
\nabla_\theta \mathcal{L}(\theta, \phi) \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\theta \log p_\theta(x|z^{(m)}), \quad z^{(m)} \sim q_\phi(z|x)
\]
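A minimal JAX sketch of this estimator, assuming a hypothetical function log_likelihood(theta, x, z) that evaluates \(\log p_\theta(x|z)\) and a batch z_samples of draws from \(q_\phi(z|x)\)

import jax

def mc_grad_theta(theta, x, z_samples, log_likelihood):
    # per-sample gradients with respect to theta, averaged over the M samples
    per_sample = jax.vmap(lambda z: jax.grad(log_likelihood)(theta, x, z))(z_samples)
    return jax.tree_util.tree_map(lambda g: g.mean(axis=0), per_sample)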
The derivative with respect to \(\phi\)¶
Let’s write \(h(z) = \log p_\theta(x|z)\), a function of \(z\) that does not depend on \(\phi\). In this case

\[
\nabla_\phi \, \mathbb{E}_{q_\phi(z)}\left[h(z)\right] = \nabla_\phi \int h(z) \, q_\phi(z) \, dz \neq \mathbb{E}_{q_\phi(z)}\left[\nabla_\phi h(z)\right],
\]

i.e. because the distribution under the expectation depends on \(\phi\), we cannot push the gradient inside and approximate via Monte Carlo integration anymore.

A classical solution to this is the REINFORCE (score function) estimator, which is based on the identity

\[
\nabla_\phi q_\phi(z) = q_\phi(z) \nabla_\phi \log q_\phi(z),
\]

then

\[
\nabla_\phi \, \mathbb{E}_{q_\phi(z)}\left[h(z)\right] = \mathbb{E}_{q_\phi(z)}\left[h(z) \nabla_\phi \log q_\phi(z)\right],
\]

but in practice, training with this estimator is very hard: the REINFORCE estimator is unbiased but it has very high variance
For the particular case of VAE there is a low variance (and very elegant) alternative
The reparameterization trick¶
In the VAE the latent variable is distributed as

\[
z \sim \mathcal{N}\left(\mu_\phi(x), \text{diag}(\sigma_\phi^2(x))\right),
\]

this can be rewritten as first sampling from a standard Gaussian

\[
\epsilon \sim \mathcal{N}(0, I_K)
\]

and then applying a deterministic transformation

\[
z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon
\]
Note
This is usually referred to as “the reparameterization trick”
Using this, the expectation of \(h(z)\) is rewritten as

\[
\mathbb{E}_{q_\phi(z|x)}\left[h(z)\right] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I_K)}\left[h\left(\mu_\phi(x) + \sigma_\phi(x) \odot \epsilon\right)\right]
\]
Note
The distribution over which the expectation is taken no longer depends on \(\phi\)
Hence the following estimator for its gradient can be used

\[
\nabla_\phi \, \mathbb{E}_{q_\phi(z|x)}\left[h(z)\right] \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\phi \, h\left(\mu_\phi(x) + \sigma_\phi(x) \odot \epsilon^{(m)}\right), \quad \epsilon^{(m)} \sim \mathcal{N}(0, I_K),
\]

which has a much lower variance than REINFORCE
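The estimator can be sketched in a few lines of JAX, for a hypothetical scalar function h and a diagonal Gaussian parameterized by mu and log_sigma

import jax
import jax.numpy as jnp

def reparam_grad(key, mu, log_sigma, h, num_samples=100):
    eps = jax.random.normal(key, shape=(num_samples,) + mu.shape)  # eps ~ N(0, I)
    def objective(params):
        mu_, log_sigma_ = params
        z = mu_ + jnp.exp(log_sigma_) * eps     # z = mu + sigma * eps
        return jnp.mean(jax.vmap(h)(z))         # Monte Carlo estimate of E[h(z)]
    # gradients flow through mu and log_sigma because the randomness lives in eps
    return jax.grad(objective)((mu, log_sigma))

Note that the sampling is externalized to eps, exactly as in the equations above, so automatic differentiation can propagate through the deterministic transformation.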
More variance reduction: Closed-form terms
We have focused on the left-hand term of the ELBO (the expected log likelihood)

The right-hand term is the KL divergence between two multivariate Gaussian distributions. This has a closed analytical solution

\[
D_{\text{KL}}\left[q_\phi(z|x) \| p(z)\right] = \frac{1}{2} \sum_{k=1}^{K} \left(\mu_k^2 + \sigma_k^2 - \log \sigma_k^2 - 1\right),
\]

where \(K\) is the dimensionality of the latent variable

The derivatives of this expression are straightforward and the variance is low
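As a sanity check, the closed-form expression can be implemented in a few lines (the helper name is illustrative); this is the kind of term that the TraceMeanField_ELBO used later can evaluate analytically

import jax.numpy as jnp

def kl_diag_gaussian_std_normal(mu, sigma):
    # KL[ N(mu, diag(sigma^2)) || N(0, I) ], summed over the K latent dimensions
    return 0.5 * jnp.sum(mu**2 + sigma**2 - jnp.log(sigma**2) - 1.0, axis=-1)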
Writing a VAE in Numpyro¶
Neural networks written in flax can be “registered” into numpyro models using flax_module. This lifts the neural network parameters as numpyro.param sites
Let’s start by writing the generative model of the VAE. The model starts from the latent variable, obtains the logits of the images using the decoder network and then matches them with the observed data using a Bernoulli likelihood
import numpyro
numpyro.set_platform("cpu")
numpyro.set_host_device_count(2)
from typing import Sequence, Callable
import jax.numpy as jnp
import flax.linen as nn
import numpyro.distributions as dists
from numpyro.contrib.module import flax_module
class Decoder(nn.Module):
    output_dim: int
    hidden_units: Sequence[int]
    activation: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        for neurons in self.hidden_units:
            x = self.activation(nn.Dense(neurons)(x))
        return nn.Dense(self.output_dim)(x)

def model(batch, hidden_units, latent_dim, sample=False):
    batch_dim, out_dim = jnp.shape(batch)
    decoder = Decoder(hidden_units=hidden_units, output_dim=out_dim)
    # Register the flax network so that its weights become numpyro.param sites
    decode = flax_module("decoder", decoder, input_shape=(batch_dim, latent_dim))
    with numpyro.plate("batch", size=batch_dim):
        # Standard Gaussian prior over the latent variable
        z = numpyro.sample("z", dists.Normal(0, 1).expand([latent_dim]).to_event(1))
        logits = decode(z)
        # Bernoulli likelihood: condition on the data unless we are sampling
        if not sample:
            x = numpyro.sample("x", dists.BernoulliLogits(logits).to_event(1), obs=batch)
        else:
            x = numpyro.sample("x", dists.BernoulliLogits(logits).to_event(1))
Note
Plates are used to declare the samples as conditionally independent along the batch dimension (leftmost dimension)
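As a quick check (not part of the original notebook), one can inspect how expand and to_event split the batch and event dimensions of the latent distribution used above

import numpyro.distributions as dists

d = dists.Normal(0, 1).expand([2]).to_event(1)
print(d.batch_shape, d.event_shape)  # () (2,): the plate adds the batch dimension when sampling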
The encoder network amortizes the parameters of the approximate posterior (a diagonal Gaussian). The Encoder model has two outputs, the second of which passes through a softplus activation so that it is non-negative

The guide receives a minibatch of data, samples z from the approximate posterior and returns its location and scale
class Encoder(nn.Module):
    hidden_units: Sequence[int]
    latent_dim: int
    activation: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        for neurons in self.hidden_units:
            x = self.activation(nn.Dense(neurons)(x))
        loc = nn.Dense(self.latent_dim)(x)
        # softplus guarantees a non-negative scale
        scale = nn.softplus(nn.Dense(self.latent_dim)(x))
        return loc, scale

def guide(batch, hidden_units, latent_dim, sample=False):
    batch_dim, out_dim = jnp.shape(batch)
    encoder = Encoder(hidden_units=hidden_units, latent_dim=latent_dim)
    encode = flax_module("encoder", encoder, input_shape=(batch_dim, out_dim))
    with numpyro.plate("batch", size=batch_dim):
        # Amortized diagonal Gaussian approximate posterior
        z_loc, z_std = encode(batch)
        z = numpyro.sample("z", dists.Normal(z_loc, z_std).to_event(1))
    return z_loc, z_std
The optax library provides a function called chain that allows combining gradient transformations

The following shows how to implement a version of Adam that clips large gradients for improved training stability
import optax
clipped_adam = optax.chain(optax.clip_by_global_norm(1.0),
optax.scale_by_adam(),
optax.scale(-1e-3))
In preparation to train the VAE we create an SVI object using the model, guide, optimizer and cost function. In this case we are also passing the keyword arguments hidden_units and latent_dim, which are forwarded to the model and guide

The svi.update method is wrapped with jax.jit (“jittified”) for improved computational efficiency
import jax
from numpyro.infer import SVI
hidden_units = [128, 128]
latent_dim = 2
svi = SVI(model, guide, clipped_adam,
numpyro.infer.TraceMeanField_ELBO(num_particles=1),
hidden_units=hidden_units, latent_dim=latent_dim)
jit_update = jax.jit(svi.update)
The following trains the VAE using minibatches from the MNIST dataset
import holoviews as hv
import jax.random as random
from tqdm.notebook import tqdm
hv.extension('bokeh')
import os
import sys
sys.path.append(os.path.abspath(os.path.join('../utils')))
from mnist import mnist
from flax_dataloader import data_loader
train_images, train_labels, test_images, test_labels = mnist();
batch_size = 128
nepochs = 100
key = random.PRNGKey(12345)
state = svi.init(key, train_images[:batch_size])
loss = []
for epoch in tqdm(range(nepochs)):
    key, key_ = jax.random.split(key)
    train_loader = data_loader(key_, train_images, train_labels, batch_size, shuffle=True)
    minibatch_loss = []
    for batch, labels in train_loader:
        state, loss_val = jit_update(state, batch)
        minibatch_loss.append(loss_val.item())
    loss.append(jnp.array(minibatch_loss).mean())
hv.Curve((range(nepochs), loss), 'Epoch', 'Loss').opts(width=500, height=200)
The parameter dictionary has the trained weights and biases of the encoder and decoder networks:
svi.get_params(state).keys()
dict_keys(['decoder$params', 'encoder$params'])
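As an illustrative follow-up (assuming the svi and state objects from the cells above), the number of trained parameters in each network can be counted by flattening these dictionaries

import jax

params = svi.get_params(state)
for name in ["decoder$params", "encoder$params"]:
    n_params = sum(p.size for p in jax.tree_util.tree_leaves(params[name]))
    print(name, n_params)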
Let’s inspect the latent and generative space using the trained models
Note
In the VAE, latent projections are not deterministic points but (normal) distributions
import numpy as np
encoder = Encoder(hidden_units, latent_dim)
encoder_params = svi.get_params(state)["encoder$params"]
z_loc, z_scale = encoder.apply({'params': encoder_params}, test_images)
z_loc = np.array(z_loc)
z_scale = np.array(z_scale)
The following shows a scatter plot of the test-set data projected to latent space
The dots are centered on the mean of the projection
The errorbars represent the standard deviation of the projection
from bokeh.palettes import Category10
hv.opts.defaults(hv.opts.ErrorBars(lower_head=None, upper_head=None, show_legend=True))
def plot_digit(digit, color):
    mask = test_labels.argmax(axis=1) == digit
    center = hv.Scatter((z_loc[mask, 0], z_loc[mask, 1])).opts(color=color)
    error_x = hv.ErrorBars((z_loc[mask, 0], z_loc[mask, 1], z_scale[mask, 0]),
                           horizontal=True).opts(color=color)
    error_y = hv.ErrorBars((z_loc[mask, 0], z_loc[mask, 1], z_scale[mask, 1]),
                           horizontal=False).opts(color=color)
    return hv.Overlay([center, error_x, error_y])
digits = [plot_digit(digit, color) for digit,color in zip(range(10), Category10[10])]
hv.Overlay(digits).opts(width=550, height=400, legend_position='right')
The VAE is a generative model: one input does not have a single reconstruction but a distribution of reconstructions (Bernoulli)

The following uses Predictive to sample from the model through the guide. The plot shows five test-set examples, each with three samples from the generative distribution
predictive = numpyro.infer.Predictive(model,
guide=svi.guide,
params=svi.get_params(state),
return_sites=['x'],
num_samples=3)
xhat = predictive(random.PRNGKey(1),
test_images[:5],
hidden_units=hidden_units,
latent_dim=latent_dim,
sample=True)['x']
hv.opts.defaults(hv.opts.Image(cmap='gray', xaxis=None, yaxis=None))
examples = [hv.Image(example.reshape(28, 28)) for example in test_images[:5]]
reconstructed = []
for k in range(3):
    reconstructed += [hv.Image(example.reshape(28, 28)) for example in np.array(xhat[k])]
hv.Layout(examples + reconstructed).opts(hv.opts.Image(width=100, height=100)).cols(5)
The following computes the mean of the Bernoulli distribution as a function of the latent variable
decoder = Decoder(hidden_units=hidden_units, output_dim=28*28)
decoder_params = svi.get_params(state)["decoder$params"]
res = 30
z = jnp.linspace(-3, 3, num=res)
z1, z2 = jnp.meshgrid(z, z)
zz = jnp.stack((z1.ravel(), z2.ravel()))
p = nn.sigmoid(decoder.apply({'params': decoder_params}, zz.T)).reshape(res, res, 28, 28)
hv.Image(np.array(p.swapaxes(2,1).reshape(28*res, 28*res))).opts(width=600, height=600)