Bayesian Neural Networks with numpyro¶
Deep Neural Networks are non-linear function approximators and the state of the art in pattern recognition for unstructured data (audio, images, text, video)
But they do have limitations
Very deep models require lots of data to train
Selecting an architecture requires a lot of experimentation
They are poor at representing uncertainty
Note
We can address some of these limitations by going Bayesian
A Bayesian neural network (BNN) places a prior distribution on its parameters. Training the BNN is equivalent to learning the posterior distribution of the parameters given the data.
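In symbols, with \(\mathcal{D} = \{(x_i, y_i)\}\) denoting the training data, this is just Bayes' theorem applied to the network parameters:
\[ p(\theta | \mathcal{D}) = \frac{p(\mathcal{D}|\theta)\, p(\theta)}{p(\mathcal{D})} \]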
Important
The uncertainty on the data and the parameters can be propagated to estimate the uncertainty on our predictions (see the predictive distribution written after this list)
- Aleatoric uncertainty
Uncertainty on the data, generally associated with the noise
- Epistemic uncertainty
Uncertainty associated with the model, either its structure or its parameters
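Both sources appear in the posterior predictive distribution for a new input \(x^*\): the likelihood carries the aleatoric part and the posterior over \(\theta\) carries the epistemic part,
\[ p(y^* | x^*, \mathcal{D}) = \int p(y^* | x^*, \theta)\, p(\theta | \mathcal{D})\, d\theta \]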
Having uncertainty-aware neural networks is especially useful to
Choose when to use a simpler or more complex model (complexity control)
Make critical decisions, e.g. autonomous cars, cancer diagnosis
A bit of history
A brief timeline of Bayesian neural networks:
1980s: Bayes' theorem is applied to neural networks (John Hopfield and Naftali Tishby)
1990s: Monte Carlo and VI for Bayesian neural networks are studied extensively by David MacKay and Radford Neal (also Bishop, Barber, Hinton, Ghahramani and many others). Neal shows that Gaussian processes are Bayesian neural networks with infinitely many neurons
2011: Alex Graves' VI for neural networks. Explosion of practical deep Bayesian networks:
Durk Kingma, Danilo Jimenez Rezende, Shakir Mohamed, José Miguel Hernández-Lobato
History in video by Zoubin Ghahramani at NIPS 2016 and an interesting panel discussion at the same conference
A Bayesian neural model for regression¶
As an example let's consider an architecture for univariate regression based on fully connected layers. With a single hidden layer and activation function \(g\) the output of the network can be written as
\[ f_\theta(x) = \hat b + \hat w^T g(w x + b), \]
where \(\theta = (b, w, \hat b, \hat w)\) are the parameters of the neural network (hidden biases and weights, and output bias and weights, respectively)
This is implemented as a Python dataclass with flax as follows
from typing import Sequence, Callable
from dataclasses import field

import flax.linen as nn


class MultiLayerPerceptron(nn.Module):
    hidden_neurons: Sequence[int] = field(default_factory=lambda: [10])
    output_neurons: int = 1
    activation: Callable = nn.tanh

    @nn.compact
    def __call__(self, x):
        for neurons in self.hidden_neurons:
            x = self.activation(nn.Dense(neurons)(x))
        return nn.Dense(self.output_neurons)(x)
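Before going Bayesian it may help to see how the deterministic network is used on its own. The following is a minimal sketch (the random key and the dummy batch of five scalar inputs are illustrative choices, not part of the lecture code):

import jax.numpy as jnp
import jax.random as random

net = MultiLayerPerceptron(hidden_neurons=[10])
dummy_x = jnp.ones((5, 1))                         # batch of 5 univariate inputs
params = net.init(random.PRNGKey(0), dummy_x)      # initialize the parameters
output = net.apply(params, dummy_x)                # forward pass, shape (5, 1)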
The first step to make this model Bayesian is to choose a prior for the parameters of the network
A very typical, albeit not the best, prior is the diagonal normal
\[ p(\theta) = \mathcal{N}(\theta | 0, \Sigma_\theta), \]
where \(\Sigma_\theta\) is generally a diagonal covariance matrix (independent prior)
Note
We will discuss more about priors for neural network parameters in the future
The likelihood, on the other hand, is chosen based on the task; the typical choices (sketched in numpyro syntax after this list) are
Gaussian for regression problems
Bernoulli for binary classification problems
Categorical for multi-class classification problems
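For reference, here is a hedged sketch of how each choice looks inside a numpyro model. The helper names are hypothetical; `f` stands for the network output (mean or logits), `s` for a noise scale and `y` for the observations:

import numpyro
import numpyro.distributions as dists

# Hypothetical helpers (not from the lecture) mapping each task to a likelihood
def regression_likelihood(f, s, y=None):
    return numpyro.sample('y', dists.Normal(f, s), obs=y)           # Gaussian

def binary_likelihood(logits, y=None):
    return numpyro.sample('y', dists.Bernoulli(logits=logits), obs=y)    # Bernoulli

def multiclass_likelihood(logits, y=None):
    return numpyro.sample('y', dists.CategoricalLogits(logits), obs=y)   # Categorical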
This Bayesian model can be implemented easily with numpyro. In particular, we can place a prior on the parameters of a neural network implemented with flax using the random_flax_module primitive
This primitive expects
A name (string)
An object that inherits from nn.Module
A distribution for the prior
A tuple with the size of the expected input
import numpyro
import numpyro.distributions as dists
from numpyro.contrib.module import random_flax_module

numpyro.set_platform("cpu")
numpyro.set_host_device_count(2)

def model(x, y=None, n_hidden=[10], prior_std=1., activation=nn.tanh):
    batch_size, data_dim = x.shape
    net = MultiLayerPerceptron(n_hidden, activation=activation)
    # Diagonal normal prior over all the parameters of the flax network
    bayesian_net = random_flax_module("net", net,
                                      prior=dists.Normal(0, prior_std),
                                      input_shape=(batch_size, data_dim))
    # Noise standard deviation of the Gaussian likelihood
    s = numpyro.sample('s', fn=dists.HalfCauchy(scale=0.5))
    with numpyro.plate('data', size=batch_size):
        f = numpyro.deterministic('f', value=bayesian_net(x)[:, 0])
        numpyro.sample('y', fn=dists.Normal(f, s), obs=y)
    return f
Additionally, the standard deviation of the likelihood, which represents the noise in the data, is given a half-Cauchy prior
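In math, the observation model implemented above is
\[ s \sim \text{HalfCauchy}(0.5), \qquad y_i \mid x_i, \theta, s \sim \mathcal{N}\left(f_\theta(x_i),\, s^2\right), \quad i=1,\ldots,N, \]
where \(s\) is the noise standard deviation and \(f_\theta(x_i)\) is the output of the Bayesian network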
The final step is to obtain the posterior of the model
Note
Even though the likelihood and prior are normal, the posterior in this case is not normal because of the nested nonlinearities in the neural network architecture; hence we resort to MCMC or SVI
In the following examples we will use the same synthetic data from the linear regression lecture:
import jax.numpy as jnp
import jax.random as random
import holoviews as hv
hv.extension('bokeh')
key = random.PRNGKey(0)
x = jnp.sort(dists.Uniform(-1, 1).sample(key, sample_shape=(40, 1)))
f = lambda x : x*jnp.sin(5*x)
sub_key, key = random.split(key)
y = f(x) + dists.Normal(0., 0.1).rsample(sub_key, sample_shape=(x.shape))
x_test = jnp.linspace(-1.5, 1.5, num=200)[:, jnp.newaxis]
data_plot = hv.Points((x.ravel(), y.ravel()), label='Train data').opts(width=500, size=5, color='black')
data_plot
“Training” the model using MCMC¶
Samples from the posterior of the neural network can be obtained using MCMC. Then these samples can be used to compute the predictive posterior of the network over new data
In what follows we run two MCMC chains using the NUTS sampler, a suitable choice given that the neural network is differentiable
from numpyro.infer import MCMC, NUTS
from functools import partial

partial_model = partial(model, n_hidden=[10], prior_std=1., activation=nn.tanh)

sampler = MCMC(sampler=NUTS(partial_model, adapt_step_size=True),
               num_chains=2, num_samples=2000, num_warmup=200,
               jit_model_args=True)
sampler.run(random.PRNGKey(1234), x, y[:, 0])
As a preliminary convergence check, the \(\hat r\) statistic and the effective number of samples are inspected:
sampler.print_summary(0.9)
mean std median 5.0% 95.0% n_eff r_hat
net/Dense_0.bias[0] -0.01 1.13 -0.02 -1.92 1.80 267.37 1.01
net/Dense_0.bias[1] -0.08 1.13 -0.10 -1.98 1.76 371.68 1.01
net/Dense_0.bias[2] -0.11 1.10 -0.11 -1.98 1.60 415.80 1.00
net/Dense_0.bias[3] 0.01 1.13 -0.00 -1.82 1.92 455.13 1.01
net/Dense_0.bias[4] -0.06 1.06 -0.06 -1.75 1.73 323.52 1.02
net/Dense_0.bias[5] 0.10 1.11 0.13 -1.68 1.98 573.15 1.00
net/Dense_0.bias[6] 0.02 1.08 0.02 -1.80 1.78 623.22 1.00
net/Dense_0.bias[7] 0.13 1.16 0.09 -1.62 2.14 172.98 1.01
net/Dense_0.bias[8] 0.01 1.15 -0.03 -1.79 2.02 270.71 1.02
net/Dense_0.bias[9] 0.19 1.09 0.37 -1.63 1.94 77.59 1.02
net/Dense_0.kernel[0,0] -0.02 1.79 0.01 -3.36 2.86 123.17 1.01
net/Dense_0.kernel[0,1] -0.01 1.78 -0.00 -3.25 2.92 134.30 1.01
net/Dense_0.kernel[0,2] 0.18 1.81 0.19 -2.62 3.61 138.60 1.02
net/Dense_0.kernel[0,3] -0.12 1.71 -0.08 -3.32 2.61 132.57 1.02
net/Dense_0.kernel[0,4] 0.04 1.77 0.02 -2.74 3.53 155.77 1.02
net/Dense_0.kernel[0,5] 0.22 1.69 0.14 -2.29 3.50 193.99 1.01
net/Dense_0.kernel[0,6] -0.01 1.71 0.02 -3.21 2.85 234.77 1.01
net/Dense_0.kernel[0,7] -0.24 1.61 -0.14 -3.16 2.21 147.33 1.01
net/Dense_0.kernel[0,8] -0.23 1.69 -0.17 -3.45 2.35 240.14 1.00
net/Dense_0.kernel[0,9] 0.46 2.03 0.36 -2.43 4.15 80.91 1.01
net/Dense_1.bias[0] -0.85 0.81 -0.87 -2.19 0.41 2826.44 1.00
net/Dense_1.kernel[0,0] 0.06 1.35 0.07 -2.15 2.02 216.45 1.00
net/Dense_1.kernel[1,0] -0.04 1.34 -0.04 -2.15 2.02 192.77 1.02
net/Dense_1.kernel[2,0] -0.03 1.35 -0.02 -2.22 1.97 289.23 1.01
net/Dense_1.kernel[3,0] 0.02 1.35 0.01 -2.08 2.12 255.50 1.02
net/Dense_1.kernel[4,0] -0.05 1.37 -0.08 -2.21 2.07 256.22 1.01
net/Dense_1.kernel[5,0] -0.07 1.32 -0.09 -2.13 2.02 433.07 1.00
net/Dense_1.kernel[6,0] -0.01 1.34 -0.01 -2.07 2.09 384.83 1.00
net/Dense_1.kernel[7,0] 0.21 1.32 0.27 -1.89 2.23 193.57 1.02
net/Dense_1.kernel[8,0] 0.08 1.37 0.08 -2.06 2.18 389.27 1.00
net/Dense_1.kernel[9,0] -0.05 1.42 -0.09 -2.25 2.05 111.25 1.01
s 0.09 0.01 0.09 0.07 0.11 2429.69 1.00
Number of divergences: 3
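The same diagnostics can also be obtained programmatically instead of printed. A minimal sketch, assuming numpyro.diagnostics.summary accepts the chain-grouped samples as shown here:

from numpyro.diagnostics import summary

# dictionary keyed by site name with mean, std, quantiles, n_eff and r_hat
stats = summary(sampler.get_samples(group_by_chain=True), prob=0.9)
print(stats['s'])  # diagnostics for the noise standard deviation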
Let’s inspect some of the traces:
samples = sampler.get_samples(group_by_chain=False)
samples.keys()
def plot_trace_and_distribution(trace: jnp.array, name: str):
    trace_plot = hv.Curve((trace), 'Step', name).opts(width=400, height=200)
    dist_plot = hv.Distribution((trace), name).opts(width=200, height=200).opts(bandwidth=0.1)
    return trace_plot + dist_plot
print(samples['net/Dense_1.bias'].shape)
plot_trace_and_distribution(samples['net/Dense_1.bias'][:, 0], 'Output bias')
(4000, 1)
print(samples['net/Dense_1.kernel'].shape)
plot_trace_and_distribution(samples['net/Dense_1.kernel'][:, 0, 0], 'First output weight')
(4000, 10, 1)
print(samples['net/Dense_0.kernel'].shape)
plot_trace_and_distribution(samples['net/Dense_0.kernel'][:, 0, 0], 'First hidden weight')
(4000, 1, 10)
print(samples['net/Dense_0.bias'].shape)
plot_trace_and_distribution(samples['net/Dense_0.bias'][:, 0], 'First hidden bias')
(4000, 10)
print(samples['s'].shape)
plot_trace_and_distribution(samples['s'], 'Likelihood std')
(4000,)
And the predictive posterior over the test data:
predictive = numpyro.infer.Predictive(partial_model,
                                      posterior_samples=sampler.get_samples(),
                                      return_sites=['f'])
posterior_samples = predictive(random.PRNGKey(1), x_test)

def plot_predictive(posterior_samples):
    low, mid, upper = jnp.quantile(posterior_samples['f'], jnp.array([0.025, 0.5, 0.975]), axis=0)
    median = hv.Curve((x_test, mid), label='Median').opts(width=500, color='#30a2da')
    uncertainty = hv.Area((x_test, low, upper), vdims=['y', 'y2'], label='95% CI').opts(color='#30a2da', alpha=0.3)
    return hv.Overlay([median, uncertainty, data_plot]).opts(legend_position='bottom_right')

plot_predictive(posterior_samples)
The effects of the hyperparameters¶
def predictive_mcmc(n_hidden, prior_std, activation):
    partial_model = partial(model, n_hidden=n_hidden,
                            prior_std=prior_std,
                            activation=activation)
    sampler = MCMC(sampler=NUTS(partial_model, adapt_step_size=True),
                   num_chains=2, num_samples=2000, num_warmup=200,
                   jit_model_args=True)
    sampler.run(random.PRNGKey(1234), x, y[:, 0])
    predictive = numpyro.infer.Predictive(partial_model,
                                          posterior_samples=sampler.get_samples(),
                                          return_sites=['f'])
    posterior_samples = predictive(random.PRNGKey(1), x_test)
    return plot_predictive(posterior_samples)
Standard deviation of the prior too small: Extreme regularization
predictive_mcmc([10], 0.1, nn.tanh)
Standard deviation of the prior too large: Extreme flexibility
predictive_mcmc([10], 10., nn.tanh)
Too few hidden units: Not enough flexibility
predictive_mcmc([2], 1., nn.tanh)
Too many hidden units: More flexibility, but no overfitting thanks to the regularization induced by the prior
predictive_mcmc([100], 1., nn.tanh)
Different activation function: ReLU
predictive_mcmc([10], 1., nn.relu)
Training the model using SVI¶
In SVI we propose an approximate (simpler) posterior \(q_\nu(\theta)\) and fit its variational parameters \(\nu\) by maximizing the evidence lower bound (ELBO)
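Written out explicitly (standard form of the ELBO, with \(\mathcal{D}\) denoting the training data):
\[ \mathcal{L}(\nu) = \mathbb{E}_{q_\nu(\theta)}\left[\log p(\mathcal{D}|\theta)\right] - \text{KL}\left[q_\nu(\theta)\,\|\,p(\theta)\right] \leq \log p(\mathcal{D}) \]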
import jax
import optax
from tqdm.notebook import tqdm
from numpyro.infer.autoguide import AutoNormal

clipped_adam = optax.chain(optax.clip_by_global_norm(10.0),
                           optax.scale_by_adam(),
                           optax.scale(-1e-2))

def train_svi(model, guide, key, nepochs=10000):
    svi = numpyro.infer.SVI(model,
                            guide,
                            clipped_adam,
                            loss=numpyro.infer.TraceMeanField_ELBO(num_particles=10))
    state = svi.init(key, x, y[:, 0])
    loss_evolution = []
    jit_update = jax.jit(svi.update)
    for epoch in tqdm(range(nepochs)):
        state, loss = jit_update(state, x, y[:, 0])
        loss_evolution.append(loss.item())
    return svi, state, loss_evolution
partial_model = partial(model, n_hidden=[10], prior_std=10., activation=nn.tanh)
guide = AutoNormal(partial_model, init_scale=1e-3)
svi, state, loss_evolution = train_svi(partial_model, guide, random.PRNGKey(1234))
hv.Curve(loss_evolution, 'Epoch', 'Loss').opts(width=500, logy=True)
Tip
Use a small initial scale for the parameters of the guide (the init_scale argument). This avoids numerical instabilities when training begins
Tip
Another way to decrease variance is to increase num_particles in the ELBO cost function (the number of Monte Carlo samples). Note that this trades lower gradient variance for higher computational cost.
Then we use \(q_\nu(\theta)\) as our replacement for \(p(\theta|\mathcal{D})\) to calculate the posterior predictive distribution
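Concretely, the posterior predictive is approximated by Monte Carlo using \(S\) samples drawn from the fitted guide:
\[ p(y^*|x^*, \mathcal{D}) \approx \frac{1}{S}\sum_{s=1}^{S} p(y^*|x^*, \theta_s), \qquad \theta_s \sim q_\nu(\theta) \]
This is what numpyro.infer.Predictive does in the next cell: it draws num_samples parameter vectors from the guide and pushes each one through the model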
predictive = numpyro.infer.Predictive(partial_model,
                                      guide=svi.guide,
                                      params=svi.get_params(state),
                                      num_samples=1000,
                                      return_sites=['f'])
posterior_samples = predictive(random.PRNGKey(1), x_test)
plot_predictive(posterior_samples)
A Bayesian neural network for multi-class classification¶
Let’s create a synthetic 2D dataset with 3 classes
import numpy as np

N = 100  # number of points per class
D = 2    # dimensionality
K = 3    # number of classes
X = np.zeros((N*K, D))
y = np.zeros(N*K, dtype='int')
for j in range(K):
    ix = range(N*j, N*(j+1))
    r = np.linspace(0.0, 0.5, N)  # radius
    t = np.linspace(j*4, (j+1)*4, N) + np.random.randn(N)*0.2  # theta
    X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
    y[ix] = j

hv.Scatter((X[:, 0], X[:, 1], y),
           'x0', vdims=['x1', 'y']).opts(color='y', width=400, cmap='tab10')
The following is a Bayesian model for multi-class classification using the same neural network architecture as before. A CategoricalLogits distribution is used for the likelihood
def model_classification(x, y=None, n_hidden=[10], prior_std=1., activation=nn.tanh):
    batch_size, data_dim = x.shape
    net = MultiLayerPerceptron(n_hidden, output_neurons=3, activation=activation)
    # Diagonal normal prior over the parameters; one output neuron per class
    bayesian_net = random_flax_module("net", net,
                                      prior=dists.Normal(0, prior_std),
                                      input_shape=(batch_size, data_dim))
    with numpyro.plate('data', size=batch_size):
        logitp = numpyro.deterministic('p', value=bayesian_net(x))
        numpyro.sample('y', fn=dists.CategoricalLogits(logitp), obs=y)
Again we use an automatic diagonal normal guide (no off-diagonal covariance)
partial_model = partial(model_classification, n_hidden=[100], prior_std=10., activation=nn.relu)
guide = AutoNormal(partial_model, init_scale=1e-2)
def train_svi(model, guide, key, nepochs=10000):
    svi = numpyro.infer.SVI(model,
                            guide,
                            clipped_adam,
                            loss=numpyro.infer.TraceMeanField_ELBO(num_particles=10))
    state = svi.init(key, X, y)
    loss_evolution = []
    jit_update = jax.jit(svi.update)
    for epoch in tqdm(range(nepochs)):
        state, loss = jit_update(state, X, y)
        loss_evolution.append(loss.item())
    return svi, state, loss_evolution
svi, state, loss_evolution = train_svi(partial_model, guide, random.PRNGKey(1234))
hv.Curve(loss_evolution, 'Epoch', 'Loss').opts(width=500)
And use the fitted guide to obtain samples from the predictive distribution
predictive = numpyro.infer.Predictive(partial_model,
                                      guide=svi.guide,
                                      params=svi.get_params(state),
                                      num_samples=100,
                                      return_sites=['y'])
xx1, xx2 = jnp.meshgrid(jnp.arange(-0.5, 0.5, 0.01), jnp.arange(-0.5, 0.5, 0.01))
posterior_samples = predictive(random.PRNGKey(1), jnp.stack((xx1.ravel(), xx2.ravel())).T)
The first three samples from the predictive distribution:
def plot_predictive(prediction, cmap='tab10'):
    pred = hv.Image(np.array(prediction.reshape(xx1.shape)).T).opts(width=300, cmap=cmap)
    data = hv.Scatter((X[:, 0], X[:, 1])).opts(color='k')
    return pred * data
hv.Layout([plot_predictive(posterior_samples['y'][k, :]) for k in range(3)]).cols(3)
We can summarize the categorical predictive distribution. The plot on the left shows the mode of the distribution (most repeated class) and the plot on the right shows the entropy of the distribution
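In terms of the \(S\) predictive samples \(y_s(x)\) drawn above, the class frequencies and the entropy computed in the next cell are
\[ \hat p_k(x) = \frac{1}{S}\sum_{s=1}^{S} \mathbb{1}\left[y_s(x) = k\right], \qquad H(x) = -\sum_{k} \hat p_k(x) \log \hat p_k(x) \]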
freqs = jnp.stack([np.sum(posterior_samples['y']==k, axis=0) for k in range(3)])/posterior_samples['y'].shape[0]
mode = jnp.argmax(freqs, axis=0)
entropy = -(freqs*jnp.log(freqs+1e-10)).sum(axis=0)
plot_predictive(mode) + plot_predictive(entropy, cmap='Blues')
Note
The higher the entropy, the more the predictive samples disagree with each other (higher uncertainty)
For comparison let’s see the result of an equivalent deterministic neural network
from flax.training import train_state

net = MultiLayerPerceptron([100], output_neurons=3, activation=nn.relu)
nepochs = 10000
dummy_batch = jnp.ones(shape=X.shape)
state = train_state.TrainState.create(apply_fn=net.apply,
                                      params=net.init(random.PRNGKey(12345), dummy_batch)['params'],
                                      tx=clipped_adam)

@jax.jit
def train_step(state, batch, label):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch)
        one_hot = jax.nn.one_hot(label, num_classes=logits.shape[1])
        return optax.softmax_cross_entropy(logits, one_hot).mean()
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    return loss, state.apply_gradients(grads=grads)

loss = []
for epoch in tqdm(range(nepochs)):
    loss_value, state = train_step(state, X, y)
    loss.append(loss_value.item())

hv.Curve((range(nepochs), loss), 'Epoch', 'Loss').opts(width=500, height=200)
If we treat the output of the network as probabilities we can also compute its entropy
Is it the same as before?
freqs = nn.softmax(net.apply({'params': state.params}, jnp.stack((xx1.ravel(), xx2.ravel())).T))
mode = jnp.argmax(freqs, axis=1)
entropy = -(freqs*jnp.log(freqs+1e-10)).sum(axis=1)
plot_predictive(mode) + plot_predictive(entropy, cmap='Blues')
This is related to the phenomenon of uncertainty miscalibration in neural networks, i.e. the uncertainty of the predictions tends to be very low even far away from the data
“after (almost) all training samples are correctly classified, cross-entropy (negative log likelihood) can be further minimized by increasing the confidence of the predictions”, i.e. by reducing the entropy of the softmax output
The uncertainty obtained from model averaging (Bayesian) and the one derived from the softmax output are not equivalent and should not be confused
Further reading and references on this topic: