Bayesian Neural Networks with numpyro

Deep Neural Networks are non-linear function approximators and the state of the art in pattern recognition for unstructured data (audio, images, text, video)

But they do have limitations

  • Very deep models require lots of data to train

  • Selecting an architecture requires a lot of experimentation

  • They can be easily fooled

  • They are poor at representing uncertainty

Note

We can address some of these limitations by going Bayesian

A Bayesian neural network (BNN) places a prior distribution on its parameters. Training the BNN is equivalent to learning the posterior distribution of the parameters given the data.

Important

The uncertainty on the data and the parameters can be propagated to estimate the uncertainty on our predictions

Aleatoric uncertainty

Uncertainty on the data, generally associated with observation noise

Epistemic uncertainty

Uncertainty associated with the model, either its structure or its parameters
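As a worked example, consider regression with a Gaussian likelihood, \(y = f_\theta(x) + \epsilon\) with \(\epsilon \sim \mathcal{N}(0, \sigma^2)\), where \(f_\theta\) is the network output and \(\sigma\) the noise standard deviation (both are introduced formally below). This is a sketch for that Gaussian case only. By the law of total variance, the predictive variance splits into one term per type of uncertainty:

\[ \text{Var}[y | x, \mathcal{D}] = \underbrace{\mathbb{E}_{p(\theta, \sigma | \mathcal{D})}\left[\sigma^2\right]}_{\text{aleatoric}} + \underbrace{\text{Var}_{p(\theta | \mathcal{D})}\left[f_\theta(x)\right]}_{\text{epistemic}} \]

The first term remains even with unlimited data. The second shrinks as the posterior over \(\theta\) concentrates.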

Having uncertainty-aware neural networks is especially useful to

  • Choose when to use a simpler or a more complex model (complexity control)

  • Make critical decisions, e.g. autonomous cars, cancer diagnosis

A bit of history

A brief timeline of Bayesian neural networks:

History in video by Zoubin Ghahramani at NIPS 2016, and an interesting panel discussion at the same conference

A bayesian neural model for regression

As an example let’s consider an architecture for univariate regression based on fully connected layers

\[ \begin{split} f_\theta(x) &= \hat b + \sum_{j=1}^{10} \hat w_{j} h_j \\ &= \hat b + \sum_{j=1}^{10} \hat w_{j} \tanh \left( b_j + w_{j} x \right) \end{split} \]

where \(\theta = (b, w, \hat b, \hat w)\) are the parameters of the neural network

This is implemented as a Python dataclass with flax as follows

from typing import Sequence, Callable
from dataclasses import field
import flax.linen as nn

class MultiLayerPerceptron(nn.Module):
    hidden_neurons: Sequence[int] = field(default_factory=lambda: [10])
    output_neurons: int = 1
    activation: Callable = nn.tanh
    
    @nn.compact
    def __call__(self, x):
        for neurons in self.hidden_neurons:
            x = self.activation(nn.Dense(neurons)(x)) 
        return nn.Dense(self.output_neurons)(x)
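To connect the equation with the flax module, here is a hand-written NumPy version of the same forward pass. It is an illustrative sketch, not part of flax, and the function name f_theta is hypothetical. The hidden nn.Dense layer (Dense_0) holds \(w\) and \(b\), and the output layer (Dense_1) holds \(\hat w\) and \(\hat b\). With 10 hidden units this gives 10 + 10 + 10 + 1 = 31 scalar parameters:

```python
import numpy as np

def f_theta(x, b, w, b_hat, w_hat):
    """Hand-written forward pass of the 10-unit tanh network (illustrative only).

    x: inputs of shape (N,); b, w: hidden biases and weights, shape (10,);
    b_hat: output bias (scalar); w_hat: output weights, shape (10,).
    """
    h = np.tanh(b[None, :] + w[None, :] * x[:, None])  # hidden activations h_j, shape (N, 10)
    return b_hat + h @ w_hat                            # network output, shape (N,)

rng = np.random.default_rng(0)
b, w, w_hat = rng.normal(size=(3, 10))   # arbitrary parameter values for the demo
b_hat = rng.normal()
x = np.linspace(-1.0, 1.0, 5)
out = f_theta(x, b, w, b_hat, w_hat)     # shape (5,)

# b and w play the role of Dense_0 (bias, kernel); b_hat and w_hat of Dense_1
n_params = b.size + w.size + w_hat.size + 1   # 10 + 10 + 10 + 1 = 31
```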

The first step to make this model bayesian is to choose a prior for the parameters of the network

A very typical, albeit not the best, prior is the diagonal normal:

\[ p(\theta) = \mathcal{N}(\theta | 0, \Sigma_\theta) \]

where \(\Sigma_\theta\) is generally a diagonal covariance (independent prior)
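With \(\Sigma_\theta = \sigma_p^2 I\), the log prior is a sum of independent one-dimensional terms. The minimal NumPy sketch below uses the hypothetical helper name log_prior_diag_normal. It shows that, up to an additive constant, \(-\log p(\theta) = \|\theta\|^2 / (2\sigma_p^2)\), which is the L2 (weight decay) penalty of a deterministic network:

```python
import numpy as np

def log_prior_diag_normal(theta, prior_std):
    """log N(theta | 0, diag(prior_std**2)), summed over all parameters."""
    theta = np.asarray(theta, dtype=float)
    var = np.broadcast_to(np.asarray(prior_std, dtype=float) ** 2, theta.shape)
    return np.sum(-0.5 * np.log(2 * np.pi * var) - 0.5 * theta**2 / var)

rng = np.random.default_rng(0)
theta = rng.normal(size=31)              # e.g. the 31 parameters of the 10-unit network
for prior_std in [0.1, 1.0, 10.0]:
    # penalty relative to theta = 0 is -||theta||^2 / (2 prior_std^2)
    print(prior_std, log_prior_diag_normal(theta, prior_std) - log_prior_diag_normal(np.zeros(31), prior_std))
```

The smaller the prior standard deviation, the stronger the pull of the weights towards zero.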

Note

We will discuss more about priors for neural network parameters in the future

The likelihood, on the other hand, is chosen based on the task. Typical choices are

  • Gaussian for regression problems

  • Bernoulli for binary classification problems

  • Categorical for multi-class classification problems
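As a sketch, here is the log-density of each of these likelihoods written in terms of the network output. For regression the output is a mean f. For classification it is a set of logits. These are hand-written NumPy versions for illustration only. In numpyro the same roles are played by dists.Normal, dists.BernoulliLogits and dists.CategoricalLogits:

```python
import numpy as np

def gaussian_loglik(y, f, sigma):
    # Regression: y ~ N(f, sigma^2)
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * (y - f) ** 2 / sigma**2

def bernoulli_loglik(y, logit):
    # Binary classification: y in {0, 1}, p(y=1) = sigmoid(logit)
    # log sigmoid(z) = -log(1 + exp(-z)), computed stably with logaddexp
    return -(y * np.logaddexp(0.0, -logit) + (1 - y) * np.logaddexp(0.0, logit))

def categorical_loglik(y, logits):
    # Multi-class: y in {0, ..., K-1}, p(y=k) = softmax(logits)[k]
    logits = np.asarray(logits, dtype=float)
    log_norm = np.logaddexp.reduce(logits, axis=-1)  # log sum_k exp(logits_k)
    return np.take_along_axis(logits, np.asarray(y)[..., None], axis=-1)[..., 0] - log_norm
```

With two classes and logits [0, z], the categorical likelihood reduces to the Bernoulli one with logit z.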

This Bayesian model can be implemented easily with numpyro. In particular, we can place a prior on the parameters of a flax neural network using the random_flax_module primitive

This primitive expects

  • A name (string)

  • An object that inherits from nn.Module

  • A distribution for the prior

  • A tuple with the shape of the expected input (input_shape argument)

import numpyro
import numpyro.distributions as dists
from numpyro.contrib.module import random_flax_module

numpyro.set_platform("cpu") 
numpyro.set_host_device_count(2) 

def model(x, y=None, n_hidden=[10], prior_std=1., activation=nn.tanh):
    batch_size, data_dim = x.shape
    net = MultiLayerPerceptron(n_hidden, activation=activation)
    bayesian_net = random_flax_module("net", net,
                                      prior = dists.Normal(0, prior_std),
                                      input_shape=(batch_size, data_dim))
    
    s = numpyro.sample('s', fn=dists.HalfCauchy(scale=0.5))
    with numpyro.plate('data', size=batch_size):
        f = numpyro.deterministic('f', value=bayesian_net(x)[:, 0])
        numpyro.sample('y', fn=dists.Normal(f, s), obs=y)
    return f

In addition, the standard deviation of the likelihood, which represents the noise in the data, is given a Half-Cauchy prior (with scale 0.5)

The final step is to obtain the posterior of the model

\[ p(\theta | \mathcal{D}) = \frac{p(\mathcal{D}|\theta) p(\theta)}{p(\mathcal{D})} = \frac{1}{p(\mathcal{D})} \prod_n \mathcal{N}(y^{(n)} | f_\theta(x^{(n)}), \sigma^2) \, \mathcal{N}(\theta | 0, \Sigma_\theta) \]

Note

Even though the likelihood and the prior are normal, the posterior in this case is not normal, because \(f_\theta(x)\) depends on \(\theta\) through nested nonlinearities. The evidence \(p(\mathcal{D})\) has no closed form either, hence we resort to approximate inference with MCMC or SVI

In the following examples we will use the same synthetic data from the linear regression lecture:

import jax.numpy as jnp
import jax.random as random
import holoviews as hv
hv.extension('bokeh')

key = random.PRNGKey(0)
x = jnp.sort(dists.Uniform(-1, 1).sample(key, sample_shape=(40, 1)), axis=0)  # sort along the sample axis
f = lambda x : x*jnp.sin(5*x)
sub_key, key = random.split(key)
y = f(x) + dists.Normal(0., 0.1).rsample(sub_key, sample_shape=(x.shape))
x_test = jnp.linspace(-1.5, 1.5, num=200)[:, jnp.newaxis]

data_plot = hv.Points((x.ravel(), y.ravel()), label='Train data').opts(width=500, size=5, color='black')
data_plot

“Training” the model using MCMC

Samples from the posterior of the neural network parameters can be obtained using MCMC. These samples can then be used to compute the posterior predictive distribution of the network over new data

\[ p(\mathbf{y}|\mathbf{x}, \mathcal{D}) = \int p(\mathbf{y}|\mathbf{x}, \theta) p(\theta| \mathcal{D}) \,d\theta \]
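With \(S\) posterior samples the integral is replaced by a Monte Carlo average over the samples, written here in the notation above. numpyro.infer.Predictive, used further down, evaluates the model once per posterior sample, which amounts to sampling from this mixture:

\[ p(\mathbf{y}|\mathbf{x}, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p(\mathbf{y}|\mathbf{x}, \theta^{(s)}), \quad \theta^{(s)} \sim p(\theta | \mathcal{D}) \]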

In what follows we run two MCMC chains using the NUTS sampler, a suitable choice given that the neural network is differentiable

from numpyro.infer import MCMC, NUTS
from functools import partial

partial_model = partial(model, n_hidden=[10], prior_std=1., activation=nn.tanh)

sampler = MCMC(sampler=NUTS(partial_model, adapt_step_size=True), 
               num_chains=2, num_samples=2000, num_warmup=200, 
               jit_model_args=True)

sampler.run(random.PRNGKey(1234), x, y[:, 0])

As a preliminary convergence check, we inspect the \(\hat r\) statistic (r_hat) and the effective number of samples (n_eff):

sampler.print_summary(0.9)
                             mean       std    median      5.0%     95.0%     n_eff     r_hat
    net/Dense_0.bias[0]     -0.01      1.13     -0.02     -1.92      1.80    267.37      1.01
    net/Dense_0.bias[1]     -0.08      1.13     -0.10     -1.98      1.76    371.68      1.01
    net/Dense_0.bias[2]     -0.11      1.10     -0.11     -1.98      1.60    415.80      1.00
    net/Dense_0.bias[3]      0.01      1.13     -0.00     -1.82      1.92    455.13      1.01
    net/Dense_0.bias[4]     -0.06      1.06     -0.06     -1.75      1.73    323.52      1.02
    net/Dense_0.bias[5]      0.10      1.11      0.13     -1.68      1.98    573.15      1.00
    net/Dense_0.bias[6]      0.02      1.08      0.02     -1.80      1.78    623.22      1.00
    net/Dense_0.bias[7]      0.13      1.16      0.09     -1.62      2.14    172.98      1.01
    net/Dense_0.bias[8]      0.01      1.15     -0.03     -1.79      2.02    270.71      1.02
    net/Dense_0.bias[9]      0.19      1.09      0.37     -1.63      1.94     77.59      1.02
net/Dense_0.kernel[0,0]     -0.02      1.79      0.01     -3.36      2.86    123.17      1.01
net/Dense_0.kernel[0,1]     -0.01      1.78     -0.00     -3.25      2.92    134.30      1.01
net/Dense_0.kernel[0,2]      0.18      1.81      0.19     -2.62      3.61    138.60      1.02
net/Dense_0.kernel[0,3]     -0.12      1.71     -0.08     -3.32      2.61    132.57      1.02
net/Dense_0.kernel[0,4]      0.04      1.77      0.02     -2.74      3.53    155.77      1.02
net/Dense_0.kernel[0,5]      0.22      1.69      0.14     -2.29      3.50    193.99      1.01
net/Dense_0.kernel[0,6]     -0.01      1.71      0.02     -3.21      2.85    234.77      1.01
net/Dense_0.kernel[0,7]     -0.24      1.61     -0.14     -3.16      2.21    147.33      1.01
net/Dense_0.kernel[0,8]     -0.23      1.69     -0.17     -3.45      2.35    240.14      1.00
net/Dense_0.kernel[0,9]      0.46      2.03      0.36     -2.43      4.15     80.91      1.01
    net/Dense_1.bias[0]     -0.85      0.81     -0.87     -2.19      0.41   2826.44      1.00
net/Dense_1.kernel[0,0]      0.06      1.35      0.07     -2.15      2.02    216.45      1.00
net/Dense_1.kernel[1,0]     -0.04      1.34     -0.04     -2.15      2.02    192.77      1.02
net/Dense_1.kernel[2,0]     -0.03      1.35     -0.02     -2.22      1.97    289.23      1.01
net/Dense_1.kernel[3,0]      0.02      1.35      0.01     -2.08      2.12    255.50      1.02
net/Dense_1.kernel[4,0]     -0.05      1.37     -0.08     -2.21      2.07    256.22      1.01
net/Dense_1.kernel[5,0]     -0.07      1.32     -0.09     -2.13      2.02    433.07      1.00
net/Dense_1.kernel[6,0]     -0.01      1.34     -0.01     -2.07      2.09    384.83      1.00
net/Dense_1.kernel[7,0]      0.21      1.32      0.27     -1.89      2.23    193.57      1.02
net/Dense_1.kernel[8,0]      0.08      1.37      0.08     -2.06      2.18    389.27      1.00
net/Dense_1.kernel[9,0]     -0.05      1.42     -0.09     -2.25      2.05    111.25      1.01
                      s      0.09      0.01      0.09      0.07      0.11   2429.69      1.00

Number of divergences: 3
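To see what r_hat measures, here is a simplified NumPy sketch of the split \(\hat r\) statistic. Each chain is split in half, and the variance between the half-chain means is compared with the average variance within them. numpyro reports a split variant, but its exact implementation may differ from this sketch, and the helper name split_rhat is hypothetical. Per-parameter chains of shape (num_chains, num_samples) can be obtained with sampler.get_samples(group_by_chain=True):

```python
import numpy as np

def split_rhat(chains):
    """Simplified split R-hat for samples of shape (num_chains, num_samples)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    half = n // 2
    # split every chain into two halves -> 2m chains of length `half`
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = half
    chain_means = split.mean(axis=1)
    W = split.var(axis=1, ddof=1).mean()          # within-chain variance
    B = n * chain_means.var(ddof=1)               # between-chain variance
    var_hat = (n - 1) / n * W + B / n             # pooled variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(2, 2000))                # two chains from the same distribution
stuck = mixed + np.array([[0.0], [3.0]])          # second chain shifted: chains disagree
print(split_rhat(mixed), split_rhat(stuck))
```

Values close to 1 indicate that the chains agree. Rules of thumb such as r_hat < 1.01 or r_hat < 1.1 are commonly used as thresholds.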

Let’s inspect some of the traces:

samples = sampler.get_samples(group_by_chain=False)
samples.keys()

def plot_trace_and_distribution(trace: jnp.ndarray, name: str):
    trace_plot = hv.Curve((trace), 'Step', name).opts(width=400, height=200)
    dist_plot = hv.Distribution((trace), name).opts(width=200, height=200).opts(bandwidth=0.1)
    return trace_plot + dist_plot

print(samples['net/Dense_1.bias'].shape)
plot_trace_and_distribution(samples['net/Dense_1.bias'][:, 0], 'Output bias')
(4000, 1)
print(samples['net/Dense_1.kernel'].shape)
plot_trace_and_distribution(samples['net/Dense_1.kernel'][:, 0, 0], 'First output weight')
(4000, 10, 1)
print(samples['net/Dense_0.kernel'].shape)
plot_trace_and_distribution(samples['net/Dense_0.kernel'][:, 0, 0], 'First hidden weight')
(4000, 1, 10)
print(samples['net/Dense_0.bias'].shape)
plot_trace_and_distribution(samples['net/Dense_0.bias'][:, 0], 'First hidden bias')
(4000, 10)
print(samples['s'].shape)
plot_trace_and_distribution(samples['s'], 'Likelihood std')
(4000,)

And the posterior predictive distribution of f over the test data, summarized by the median and the 2.5% and 97.5% quantiles of the samples:

predictive = numpyro.infer.Predictive(partial_model, 
                                      posterior_samples=sampler.get_samples(), 
                                      return_sites=['f'])
posterior_samples = predictive(random.PRNGKey(1), x_test)

def plot_predictive(posterior_samples):
    low, mid, upper = jnp.quantile(posterior_samples['f'], jnp.array([0.025, 0.5, 0.975]), axis=0)
    median = hv.Curve((x_test, mid), label='Median').opts(width=500, color='#30a2da')
    uncertainty = hv.Area((x_test, low, upper), vdims=['y', 'y2'], label='95% CI').opts(color='#30a2da', alpha=0.3)
    return hv.Overlay([median, uncertainty, data_plot]).opts(legend_position='bottom_right')

plot_predictive(posterior_samples)
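The band above is built from samples of f only, so it reflects epistemic uncertainty and leaves out the observation noise s. The following NumPy sketch shows how the aleatoric part could be added. It uses synthetic stand-in arrays (an assumption), which in the notebook would be posterior_samples['f'] and sampler.get_samples()['s']. Alternatively, adding 'y' to return_sites makes Predictive sample the noisy observations directly, since y is not observed at prediction time:

```python
import numpy as np

rng = np.random.default_rng(0)
S, N = 4000, 200
x_grid = np.linspace(-1.5, 1.5, N)
# Synthetic stand-ins: in the notebook these would be posterior_samples['f'] (S, N)
# and sampler.get_samples()['s'] (S,)
f_samples = np.sin(3 * x_grid)[None, :] + 0.05 * rng.normal(size=(S, 1))
s_samples = np.abs(0.09 + 0.01 * rng.normal(size=S))

epistemic_var = f_samples.var(axis=0)        # Var over the posterior of f_theta(x)
aleatoric_var = np.mean(s_samples**2)        # E over the posterior of sigma^2
total_var = epistemic_var + aleatoric_var    # law of total variance

# Adding the likelihood noise gives draws of y from the posterior predictive
y_samples = f_samples + s_samples[:, None] * rng.normal(size=(S, N))
f_low, f_high = np.quantile(f_samples, [0.025, 0.975], axis=0)   # band of f only
y_low, y_high = np.quantile(y_samples, [0.025, 0.975], axis=0)   # band including noise
```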

The effects of the hyperparameters

def predictive_mcmc(n_hidden, prior_std, activation):
    partial_model = partial(model, n_hidden=n_hidden, 
                            prior_std=prior_std, 
                            activation=activation)

    sampler = MCMC(sampler=NUTS(partial_model, adapt_step_size=True), 
                   num_chains=2, num_samples=2000, num_warmup=200, 
                   jit_model_args=True)

    sampler.run(random.PRNGKey(1234), x, y[:, 0])

    predictive = numpyro.infer.Predictive(partial_model, 
                                          posterior_samples=sampler.get_samples(), 
                                          return_sites=['f'])
    posterior_samples = predictive(random.PRNGKey(1), x_test)
    return plot_predictive(posterior_samples)
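A quick way to build intuition for the cases below is to draw functions from the prior before seeing any data. This NumPy sketch mirrors MultiLayerPerceptron with a single tanh hidden layer and a Normal(0, prior_std) prior on every weight and bias, as in model. The helper name sample_prior_functions is hypothetical:

```python
import numpy as np

def sample_prior_functions(x, n_hidden, prior_std, num_samples, rng):
    """Draw f_theta(x) for a one-hidden-layer tanh network with every weight
    and bias ~ N(0, prior_std^2). x: shape (N,); returns (num_samples, N)."""
    w = prior_std * rng.normal(size=(num_samples, 1, n_hidden))    # hidden weights
    b = prior_std * rng.normal(size=(num_samples, 1, n_hidden))    # hidden biases
    w_hat = prior_std * rng.normal(size=(num_samples, n_hidden))   # output weights
    b_hat = prior_std * rng.normal(size=(num_samples, 1))          # output bias
    h = np.tanh(b + w * x[None, :, None])                         # (S, N, H)
    return b_hat + np.einsum('snh,sh->sn', h, w_hat)              # (S, N)

rng = np.random.default_rng(0)
x = np.linspace(-1.5, 1.5, 200)
for prior_std in [0.1, 1.0, 10.0]:
    f = sample_prior_functions(x, 10, prior_std, 500, rng)
    # overall spread of the prior functions and their average step-to-step change
    print(prior_std, f.std(), np.abs(np.diff(f, axis=1)).mean())
```

With this unscaled prior, the spread of the prior functions also grows with the number of hidden units, because the output sums one independent term per unit.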

Standard deviation of the prior too small: Extreme regularization

predictive_mcmc([10], 0.1, nn.tanh)

Standard deviation of the prior too large: Extreme flexibility

predictive_mcmc([10], 10., nn.tanh)

Too few hidden units: Not enough flexibility

predictive_mcmc([2], 1., nn.tanh)

Too many hidden units: more flexibility, but the prior regularizes the fit, so there is no overfitting in this example

predictive_mcmc([100], 1., nn.tanh)

Different activation function: ReLU

predictive_mcmc([10], 1., nn.relu)

Training the model using SVI

In SVI we choose a simple family of approximate posteriors \(q_\nu(\theta)\) and optimize its parameters \(\nu\) to maximize the evidence lower bound (ELBO)

\[ \mathcal{L}(\nu) = \mathbb{E}_{q_\nu(\theta)}[ \log p(\mathcal{D}|\theta)] - \text{KL}[q_\nu(\theta) \| p(\theta)] \]
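When both \(q_\nu\) and the prior are diagonal Gaussians, which is the case for an AutoNormal guide and a Normal prior on the network weights, the KL term has a closed form. TraceMeanField_ELBO is documented to use analytic KL divergences where they are available, while the expected log-likelihood is always estimated by Monte Carlo. The NumPy sketch below compares the closed form with a Monte Carlo estimate. The helper names and the guide values are hypothetical:

```python
import numpy as np

def kl_diag_normal(mu_q, sigma_q, sigma_p):
    """KL[N(mu_q, sigma_q^2) || N(0, sigma_p^2)], summed over independent dimensions."""
    return np.sum(np.log(sigma_p / sigma_q)
                  + (sigma_q**2 + mu_q**2) / (2 * sigma_p**2) - 0.5)

def kl_monte_carlo(mu_q, sigma_q, sigma_p, num_samples, rng):
    """Monte Carlo estimate of E_q[log q(theta) - log p(theta)] with theta = mu + sigma * eps."""
    eps = rng.normal(size=(num_samples,) + np.shape(mu_q))
    theta = mu_q + sigma_q * eps
    log_q = -0.5 * np.log(2 * np.pi * sigma_q**2) - 0.5 * eps**2
    log_p = -0.5 * np.log(2 * np.pi * sigma_p**2) - 0.5 * theta**2 / sigma_p**2
    return np.mean(np.sum(log_q - log_p, axis=-1))

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0, 2.0])        # hypothetical guide locations
sig = np.array([0.1, 0.5, 1.0])        # hypothetical guide scales
analytic = kl_diag_normal(mu, sig, 10.0)                 # prior_std = 10 as in the SVI example
estimate = kl_monte_carlo(mu, sig, 10.0, 200_000, rng)   # close to `analytic`, but noisy
```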
import jax
import optax
from tqdm.notebook import tqdm
from numpyro.infer.autoguide import AutoNormal

clipped_adam = optax.chain(optax.clip_by_global_norm(10.0),  
                           optax.scale_by_adam(),
                           optax.scale(-1e-2))

def train_svi(model, guide, key, nepochs=10000):
    svi = numpyro.infer.SVI(model, 
                            guide, 
                            clipped_adam, 
                            loss=numpyro.infer.TraceMeanField_ELBO(num_particles=10))
    state = svi.init(key, x, y[:, 0])
    
    loss_evolution = []
    jit_update = jax.jit(svi.update)
    for epoch in tqdm(range(nepochs)):
        state, loss = jit_update(state, x, y[:, 0])
        loss_evolution.append(loss.item())
    
    return svi, state, loss_evolution
 
partial_model = partial(model, n_hidden=[10], prior_std=10., activation=nn.tanh)
guide = AutoNormal(partial_model, init_scale=1e-3)

svi, state, loss_evolution = train_svi(partial_model, guide, random.PRNGKey(1234))  
hv.Curve(loss_evolution, 'Epoch', 'Loss').opts(width=500, logy=True)

Tip

Use a small initial scale for the guide (init_scale argument, the initial standard deviation of each latent variable). This avoids numerical instabilities when training begins

Tip

Another way to reduce the variance of the gradient estimates is to increase num_particles in the ELBO (the number of Monte Carlo samples). This trades higher computational cost for lower gradient variance.
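A toy illustration under stated assumptions: a single parameter with \(q = \mathcal{N}(\mu, \sigma^2)\) and a quadratic stand-in for the log-likelihood, both hypothetical. The reparameterized gradient with respect to \(\mu\) averages \(g'(\theta)\) over num_particles draws. Its variance therefore drops as \(\sigma^2/\)num_particles, while the cost grows linearly with num_particles:

```python
import numpy as np

def grad_mu_estimate(num_particles, rng, mu=0.3, sigma=0.5):
    """Reparameterized Monte Carlo estimate of d/dmu E_q[g(theta)], q = N(mu, sigma^2).

    Toy objective g(theta) = -0.5 * (theta - 1)^2, so g'(theta) = -(theta - 1) and the
    exact gradient is 1 - mu.
    """
    theta = mu + sigma * rng.normal(size=num_particles)  # theta = mu + sigma * eps
    return np.mean(-(theta - 1.0))                          # average of g'(theta) over particles

rng = np.random.default_rng(0)
variances = {}
for k in [1, 10, 100]:
    estimates = np.array([grad_mu_estimate(k, rng) for _ in range(2000)])
    variances[k] = estimates.var()   # close to sigma^2 / k = 0.25 / k
```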

Then we use \(q_\nu(\theta)\) as our replacement for \(p(\theta|\mathcal{D})\) to calculate the posterior predictive distribution

predictive = numpyro.infer.Predictive(partial_model, 
                                      guide=svi.guide, 
                                      params=svi.get_params(state), 
                                      num_samples=1000,
                                      return_sites=['f'])

posterior_samples = predictive(random.PRNGKey(1), x_test)
plot_predictive(posterior_samples)

A Bayesian neural network for multi-class classification

Let’s create a synthetic 2D dataset with 3 classes

import numpy as np

N = 100 # number of points per class
D = 2 # dimensionality
K = 3 # number of classes
X = np.zeros((N*K,D)) 
y = np.zeros(N*K, dtype='int')

for j in range(K):
    ix = range(N*j,N*(j+1))
    r = np.linspace(0.0, 0.5, N) # radius
    t = np.linspace(j*4, (j+1)*4, N) + np.random.randn(N)*0.2 # theta
    X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
    y[ix] = j
    
hv.Scatter((X[:, 0], X[:, 1], y), 
           'x0', vdims=['x1', 'y']).opts(color='y', width=400, cmap='tab10')

The following is a Bayesian model for multi-class classification that uses the same neural network architecture as before, now with three outputs

A CategoricalLogits distribution is used for the likelihood, so the network outputs are interpreted as unnormalized log-probabilities (logits)

def model_classification(x, y=None, n_hidden=[10], prior_std=1., activation=nn.tanh):
    batch_size, data_dim = x.shape
    net = MultiLayerPerceptron(n_hidden, output_neurons=3, activation=activation)
    bayesian_net = random_flax_module("net", net, 
                                      prior=dists.Normal(0, prior_std),
                                      input_shape=(batch_size, data_dim))
    
    with numpyro.plate('data', size=batch_size):
        logitp = numpyro.deterministic('p', value=bayesian_net(x))
        numpyro.sample('y', fn=dists.CategoricalLogits(logitp), obs=y)

Again we use an automatic diagonal normal guide (no covariance)
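A mean-field (diagonal) normal guide keeps one location and one scale per network parameter and ignores correlations between parameters. For the 100-unit network with 2 inputs and 3 outputs, that is 603 parameters and hence 1206 variational parameters. Below is a small NumPy sketch of the count and of one reparameterized draw. The helper name is hypothetical, and AutoNormal's internal parameterization may differ:

```python
import numpy as np

def dense_param_count(layer_sizes):
    """Number of weights + biases of a fully connected net, e.g. [2, 100, 3]."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

n_params = dense_param_count([2, 100, 3])   # 2*100 + 100 + 100*3 + 3
n_variational = 2 * n_params                # one location and one scale per parameter

# Drawing one parameter vector from a mean-field Gaussian guide (reparameterized)
rng = np.random.default_rng(0)
loc = np.zeros(n_params)
scale = np.full(n_params, 1e-2)             # small initial scale, as with init_scale=1e-2
theta = loc + scale * rng.normal(size=n_params)
```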

partial_model = partial(model_classification, n_hidden=[100], prior_std=10., activation=nn.relu)
guide = AutoNormal(partial_model, init_scale=1e-2)

def train_svi(model, guide, key, nepochs=10000):
    svi = numpyro.infer.SVI(model, 
                            guide, 
                            clipped_adam, 
                            loss=numpyro.infer.TraceMeanField_ELBO(num_particles=10))
    state = svi.init(key, X, y)
    
    loss_evolution = []
    jit_update = jax.jit(svi.update)
    for epoch in tqdm(range(nepochs)):
        state, loss = jit_update(state, X, y)
        loss_evolution.append(loss.item())
    
    return svi, state, loss_evolution

svi, state, loss_evolution = train_svi(partial_model, guide, random.PRNGKey(1234))  
hv.Curve(loss_evolution, 'Epoch', 'Loss').opts(width=500)

And use the fitted guide to obtain samples from the predictive distribution

predictive = numpyro.infer.Predictive(partial_model, 
                                      guide=svi.guide, 
                                      params=svi.get_params(state), 
                                      num_samples=100,
                                      return_sites=['y'])

xx1, xx2 = jnp.meshgrid(jnp.arange(-0.5, 0.5, 0.01), jnp.arange(-0.5, 0.5, 0.01))
posterior_samples = predictive(random.PRNGKey(1), jnp.stack((xx1.ravel(), xx2.ravel())).T)

The first three samples from the predictive distribution:

def plot_predictive(prediction, cmap='tab10'):
    pred = hv.Image(np.array(prediction.reshape(xx1.shape)).T).opts(width=300, cmap=cmap) 
    data = hv.Scatter((X[:, 0], X[:, 1])).opts(color='k')
    return pred * data
                      
hv.Layout([plot_predictive(posterior_samples['y'][k, :]) for k in range(3)]).cols(3)

We can summarize the categorical predictive distribution. The plot on the left shows its mode (the most frequent class among the samples), and the plot on the right shows its entropy

freqs = jnp.stack([np.sum(posterior_samples['y']==k, axis=0) for k in range(3)])/posterior_samples['y'].shape[0]
mode = jnp.argmax(freqs, axis=0)
entropy = -(freqs*jnp.log(freqs+1e-10)).sum(axis=0)
plot_predictive(mode) + plot_predictive(entropy, cmap='Blues')

Note

The higher the entropy, the more the predictive samples disagree (higher uncertainty)
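The entropy ranges from 0, when all predictive samples agree, to \(\log K\), when the \(K\) classes appear equally often. With \(K = 3\) the maximum is \(\log 3 \approx 1.10\) nats. The sketch below repeats the freqs/entropy computation as a NumPy function (hypothetical name) and treats \(0 \log 0\) as 0 instead of adding 1e-10:

```python
import numpy as np

def predictive_entropy(class_samples, num_classes):
    """Entropy of the empirical class distribution, per input.

    class_samples: int array (num_samples, num_points) of sampled labels.
    """
    freqs = np.stack([(class_samples == k).mean(axis=0) for k in range(num_classes)])
    # log(1) = 0 where a frequency is zero, so those classes contribute 0 * 0 = 0
    logs = np.log(np.where(freqs > 0, freqs, 1.0))
    return -(freqs * logs).sum(axis=0)

agree = np.zeros((90, 1), dtype=int)             # all samples predict class 0
disagree = np.repeat(np.arange(3), 30)[:, None]  # 30 samples of each class
entropies = (predictive_entropy(agree, 3), predictive_entropy(disagree, 3))  # 0 and log(3)
```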

For comparison let’s see the result of an equivalent deterministic neural network

from flax.training import train_state

net = MultiLayerPerceptron([100], output_neurons=3, activation=nn.relu)
nepochs = 10000
dummy_batch = jnp.ones(shape=X.shape)
state = train_state.TrainState.create(apply_fn=net.apply, 
                                      params=net.init(random.PRNGKey(12345), dummy_batch)['params'],
                                      tx=clipped_adam)
@jax.jit
def train_step(state, batch, label):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch)
        one_hot = jax.nn.one_hot(label, num_classes=logits.shape[1])
        return optax.softmax_cross_entropy(logits, one_hot).mean()
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return loss, state.apply_gradients(grads=grads)

loss = []
for epoch in tqdm(range(nepochs)):
    loss_value, state = train_step(state, X, y)
    loss.append(loss_value.item())
    
hv.Curve((range(nepochs), loss), 'Epoch', 'Loss').opts(width=500, height=200)

If we apply a softmax to the output of the network and treat it as class probabilities, we can also compute its entropy

Is it the same as before?

freqs = nn.softmax(net.apply({'params': state.params}, jnp.stack((xx1.ravel(), xx2.ravel())).T))
mode = jnp.argmax(freqs, axis=1)
entropy = -(freqs*jnp.log(freqs+1e-10)).sum(axis=1)
plot_predictive(mode) + plot_predictive(entropy, cmap='Blues')

This is related to the phenomenon of uncertainty miscalibration in neural networks: the predicted uncertainty tends to be very low, even far from the training data

“after (almost) all training samples are correctly classified, crossentropy (neg log likelihood) can be further minimized by increasing the confidence of the predictions”, i.e. by reducing the entropy of the softmax output
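A small NumPy sketch of this effect, using hypothetical logits for a single correctly classified example. Multiplying the logits by a factor larger than one keeps the predicted class. It also lowers both the cross-entropy and the entropy of the softmax output, so training can keep reducing the loss simply by becoming more confident:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())          # shift for numerical stability
    return e / e.sum()

def cross_entropy(logits, label):
    return -np.log(softmax(logits)[label])

def entropy(p):
    return -np.sum(p * np.log(p))

logits = np.array([2.0, 0.5, -1.0])  # hypothetical logits; label 0 is already predicted
results = []
for scale in [1.0, 2.0, 5.0]:
    scaled = scale * logits          # same argmax, larger margins
    results.append((scale, cross_entropy(scaled, 0), entropy(softmax(scaled))))
```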

The uncertainty obtained from model averaging (bayesian) and the one derived from the softmax output are not equivalent and should not be confused

Further reading and references on this topic: