import holoviews as hv
hv.extension('bokeh')
from tqdm.notebook import tqdm
import jax
import jax.numpy as jnp
import jax.random as random
import numpyro.distributions as dists
My first Neural Network with Flax
Flax is a neural network library with a focus on flexibility. Flax uses JAX as its computing backend and interoperates with NumPyro
Flax provides flax.linen, an API with primitives to design neural networks. To train the models you will also need an optimizer. State-of-the-art optimizers that can be used with Flax models are available in a companion library called optax
import flax.linen as nn
import optax
In what follows we will refer to flax.linen as nn
We will learn how to use these libraries by building a polynomial regressor and a multi-layer perceptron
Polynomial (linear) regression
The polynomial regressor is a special case of the linear regressor in which the input is expanded using a polynomial basis
The model is defined as

\[
f_\theta(x) = b + \sum_{k=1}^{d} w_k \, x^k
\]

where \(d\) is the degree of the polynomial basis and \(\theta = (w_1, \ldots, w_d, b)\) are the trainable parameters
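For illustration (this snippet is not part of the original notebook), the polynomial basis expansion can be built by stacking powers of the input column-wise:

# Sketch: polynomial basis expansion of a (N, 1) input for degree 3
x_demo = jnp.array([[0.5], [-1.0]])
phi_demo = jnp.concatenate([x_demo**(k+1) for k in range(3)], axis=1)
# phi_demo has shape (N, 3); its columns are x, x^2 and x^3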
To test the model we will use an irregularly-sampled synthetic signal with Gaussian noise
key = random.PRNGKey(0)
x = jnp.sort(dists.Uniform(-1, 1).sample(key, sample_shape=(40, 1)))
f = lambda x : x*jnp.sin(5*x)
sub_key, key = random.split(key)
y = f(x) + dists.Normal(0., 0.1).rsample(sub_key, sample_shape=x.shape)
x_test = jnp.linspace(-1.5, 1.5, num=200)[:, jnp.newaxis]
data_plot = hv.Points((x.ravel(), y.ravel()), label='Train data').opts(width=500, size=5, color='black')
data_plot
A model written in Flax is a Python class that inherits from nn.Module
The class has to implement a __call__ method that receives the input data and returns the model prediction. The class may also implement a setup method to declare the variables and submodules that compose the model. This is the “explicit” way of defining a model in Flax.
There is also the “inline” way of defining models, where instead of setup the nn.compact decorator is applied to __call__. In this case the module attributes are declared in dataclass style, as with Python 3.7 dataclasses
See also
See should-i-use-setup-or-nn-compact in the Flax documentation for more details
The following is the polynomial regressor written following the “inline” pattern
class PolynomialRegressor(nn.Module):
    degree: int  # Dataclass style of declaring arguments

    @nn.compact
    def __call__(self, x):
        phi = jnp.concatenate([x**(k+1) for k in range(self.degree)], axis=1)
        return nn.Dense(1, use_bias=True)(phi)
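For comparison, the same model could be written following the “explicit” pattern, declaring the Dense submodule in setup. This is only a sketch (not part of the original notebook):

class PolynomialRegressorExplicit(nn.Module):
    degree: int

    def setup(self):
        # Submodules declared in setup are registered as part of the model
        self.output_layer = nn.Dense(1, use_bias=True)

    def __call__(self, x):
        phi = jnp.concatenate([x**(k+1) for k in range(self.degree)], axis=1)
        return self.output_layer(phi)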
The Dense submodule expects an integer positional argument \(m\) and implements a linear transformation \(xw+b\), with parameters \(w\in\mathbb{R}^{n\times m}\) and \(b\in\mathbb{R}^m\), where \(n\) is the dimension of the input
model = PolynomialRegressor(degree=6)
The created object can be initialized using the init method. This requires a PRNG key and a dummy input with the shape of an input minibatch
To evaluate the model given parameter values and input data the apply method is used. Here we use apply to plot the predictions of several randomly initialized models
prior_plots = []
for seed in range(20):
    params = model.init(random.PRNGKey(seed), jnp.zeros(shape=(1, 1)))
    y_test = model.apply(params, x_test)
    prior_plots.append(hv.Curve((x_test, y_test)).opts(alpha=0.1, width=500, color='#30a2da'))
hv.Overlay(prior_plots)
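As a side note, the parameters returned by init form a nested pytree; a quick way to inspect their shapes (this snippet is not part of the original notebook) is:

# Sketch: print the shape of every parameter array in the pytree returned by init
params_example = model.init(random.PRNGKey(0), jnp.zeros(shape=(1, 1)))
print(jax.tree_util.tree_map(jnp.shape, params_example))
# For degree=6 this should show a (6, 1) kernel and a (1,) bias for the Dense layer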
To train the model we need to define a cost function (loss) and an optimizer
In this case we will use the Mean Square Error as cost and the Adam optimizer. For convenience we write a helper function that receives parameters, input and target and returns the loss
optimizer = optax.adam(learning_rate=1e-2)
compute_loss = lambda params, x, y : jnp.mean(optax.l2_loss(model.apply(params, x), y))
To train we will use the following optimizer methods:
init: receives the initialized parameters and returns an optimizer state
update: receives the gradient of the cost function, the optimizer state and the parameters, and returns the parameter updates (deltas) and the updated state
Note
The optimizer state is in charge of maintaining the running statistics of the optimizer, e.g. moving averages for algorithms like Adagrad, Adam, etc.
We will also use the apply_updates helper function from optax to update the parameters given the output of the update optimizer method
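To make the role of the optimizer state concrete, the following snippet (a sketch, not part of the original notebook) inspects the Adam state, which holds the step count and the per-parameter first and second moment estimates:

# Sketch: inspect the state created by optax.adam(...).init
params_example = model.init(random.PRNGKey(0), jnp.zeros(shape=(1, 1)))
state_example = optimizer.init(params_example)
print(jax.tree_util.tree_map(jnp.shape, state_example))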
Finally we plot the loss as a function of the training epoch and the final test prediction
model = PolynomialRegressor(degree=6)
params = model.init(random.PRNGKey(1234), jnp.zeros(shape=(1, 1)))
state = optimizer.init(params)
loss = []
for epoch in tqdm(range(100)):
    # Evaluate the loss and its gradient with respect to the parameters (argnums=0)
    loss_val, grads = jax.value_and_grad(compute_loss, argnums=0)(params, x, y)
    updates, state = optimizer.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    loss.append(loss_val)
y_test = model.apply(params, x_test)
loss_plot = hv.Curve((range(len(loss)), loss), 'Epoch', 'MSE')
test_plot = hv.Curve((x_test, y_test), label='Test prediction').opts(ylim=(-1, 1))
loss_plot + hv.Overlay([data_plot, test_plot])
In this particular case we chose polynomials as the basis to represent the data, but how do we choose a basis in the general case?
One possibility: don’t choose, let the model find the basis
Proposed activity for the polynomial regressor
Change the degree of the basis and describe the results
Increase the noise and repeat the previous step
Concepts to discuss: Complexity, generalization, overfitting, regularization
Multilayer Perceptrons
Artificial neural networks (ANN) are non-linear parametric function approximators built by connecting simple units
These units are simplified models of biological neurons. The artificial neuron is a linear regressor followed by a non-linear activation function

\[
h = g\left(\sum_{d} w_d x_d + b\right) = g(w^T x + b)
\]

where for example

\[
g(z) = \frac{1}{1+e^{-z}}
\]

which is known as the sigmoid activation.
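In code, a single artificial neuron could be sketched as follows (this snippet is not part of the original notebook; neuron, w and b are illustrative names):

# Sketch: one artificial neuron, a linear regressor followed by a sigmoid activation
def neuron(w, b, x):
    return nn.sigmoid(jnp.dot(x, w) + b)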
Feed-forward ANNs are organized in layers and each layer has a certain number of units. The most classical ANN architecture is the MultiLayer Perceptron (MLP)
In the MLP architecture every unit is connected to all units of its previous and next layers. Consider for example a network with four inputs, two outputs and one hidden layer with eight units.
Mathematically, the output of the hidden layer is

\[
h = g_1(W_1 x + b_1)
\]

and the output of the network is

\[
\hat{y} = g_2(W_2 h + b_2)
\]

where the activation functions \(g_1\) and \(g_2\) may be different (user choice)
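To connect the equations with code, here is a sketch in plain jnp (not part of the original notebook; parameter names are illustrative) of the forward pass of a one-hidden-layer MLP:

# Sketch: forward pass of an MLP with one hidden layer and identity output activation
def mlp_forward(x, W1, b1, W2, b2, g1=nn.sigmoid):
    h = g1(jnp.dot(x, W1) + b1)   # hidden layer: g1(W1 x + b1)
    return jnp.dot(h, W2) + b2    # output layer: W2 h + b2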
Important
The parameter vector \(\theta\) of the network includes the weights \(w\) and biases \(b\) of the linear regressors within the architecture
Flax implementation
Some useful submodules and functions of nn, by category, include: layers such as nn.Dense and nn.Conv, activation functions such as nn.sigmoid, nn.tanh and nn.relu, and parameter initializers in nn.initializers
The following code implements an MLP regressor with an arbitrary number of Dense
hidden layers
from typing import Sequence, Callable

class MultiLayerPerceptron(nn.Module):
    hidden_neurons: Sequence[int]
    kernel_init: Callable = nn.initializers.lecun_normal()
    activation: Callable = nn.sigmoid

    @nn.compact
    def __call__(self, x):
        for k, neurons in enumerate(self.hidden_neurons):
            x = nn.Dense(neurons, kernel_init=self.kernel_init)(x)
            if k != len(self.hidden_neurons) - 1:
                x = self.activation(x)
        return x
Predictive prior distribution of the MLP
Let’s consider a Normal prior with zero mean and \(\sigma=5\) for \(\theta\) and study the space of possible models
A model with one hidden layer
model = MultiLayerPerceptron([100, 1], nn.initializers.normal(stddev=5.))
prior_x = jnp.linspace(-1, 1, num=500)[:, jnp.newaxis]
prior_plots = []
for seed in range(100):
    params = model.init(random.PRNGKey(seed), jnp.zeros(shape=(1, 1)))
    prior_plots.append(hv.Curve((prior_x, model.apply(params, prior_x))).opts(width=500, alpha=0.1, color='#30a2da'))
hv.Overlay(prior_plots)
And a model with two hidden layers
model = MultiLayerPerceptron([100, 100, 1], nn.initializers.normal(stddev=5.))
prior_plots = []
for seed in range(100):
    params = model.init(random.PRNGKey(seed), jnp.zeros(shape=(1, 1)))
    prior_plots.append(hv.Curve((prior_x, model.apply(params, prior_x))).opts(width=500, alpha=0.1, color='#30a2da'))
hv.Overlay(prior_plots)
Note
By increasing the number of layers and units (capacity) the model becomes more flexible
Proposed activity with the predictive prior of the MLP
What happens if you remove the nonlinearity?
What happens if you change the nonlinearity?
What happens when you change the number of units?
What happens when you add even more layers?
Training the Neural Network
Let’s return to our original regression problem, now using an MLP model
How many hyperbolic tangent basis functions do we need to fit this data?
model = MultiLayerPerceptron([10, 10, 1], activation=nn.tanh)
optimizer = optax.adam(learning_rate=1e-2)
compute_loss = lambda params, x, y : jnp.mean(optax.l2_loss(model.apply(params, x), y))
params = model.init(random.PRNGKey(0), jnp.zeros(shape=(1,1)))
state = optimizer.init(params)
loss = []
for epoch in tqdm(range(200)):
    loss_val, grads = jax.value_and_grad(compute_loss)(params, x, y)
    updates, state = optimizer.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    loss.append(loss_val)
y_test = model.apply(params, x_test)
loss_plot = hv.Curve((range(len(loss)), loss), 'Epoch', 'MSE')
test_plot = hv.Curve((x_test, y_test), label='Test prediction')
loss_plot + hv.Overlay([data_plot, test_plot])
Proposed activity
Explore the solutions using a different number of units
Explore the solutions using a different number of layers
In both cases tune the number of training epochs appropriately
Probabilistic interpretation of an MLP
Let’s consider an MLP architecture for multivariate regression, where:
\((x^{(n)}, y^{(n)})\) for \(n=1,2,\ldots,N\) are the training tuples with \(x^{(n)}\in \mathbb{R}^D\) and \(y^{(n)} \in \mathbb{R}^O\)
\(f_i(\cdot)\) is the i-th output of the model, for \(i=1, 2, \ldots, O\)
No activation function in the output layer
\(\theta_k\) for \(k=1,2,\ldots, K\) is the collection of trainable parameters (weights and biases)
We fit the parameters by minimizing the Mean Square Error cost function

\[
\mathcal{L}(\theta) = \frac{1}{N} \sum_{n=1}^{N} \sum_{i=1}^{O} \left( y_i^{(n)} - f_i(x^{(n)}) \right)^2
\]
This is equivalent to the MLE solution with Normal likelihood assuming known variance
We typically include an L2 regularizer to penalize complexity and improve generalization; in this case the cost function is

\[
\mathcal{L}(\theta) = \frac{1}{N} \sum_{n=1}^{N} \sum_{i=1}^{O} \left( y_i^{(n)} - f_i(x^{(n)}) \right)^2 + \lambda \sum_{k=1}^{K} \theta_k^2
\]
This is equivalent to the MAP solution with Normal likelihood and zero-mean Normal prior
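As a sketch (not part of the original notebook), the L2 penalty could be added to the loss helper used above by summing the squared parameters over the parameter pytree; lam is a hypothetical regularization strength:

# Sketch: MSE loss with an explicit L2 penalty over all parameters (biases included)
def compute_loss_l2(params, x, y, lam=1e-3):
    mse = jnp.mean(optax.l2_loss(model.apply(params, x), y))
    l2 = sum(jnp.sum(p**2) for p in jax.tree_util.tree_leaves(params))
    return mse + lam * l2

Alternatively, optax provides optimizers with built-in weight decay, such as optax.adamw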
Important
Conventional neural network training yields MLE/MAP point estimates
Note
There is no closed-form solution for these cost functions because of the non-linearities in the model. ANNs are generally trained with iterative methods such as gradient descent
In the case of ANNs for classification we arrive at similar conclusions, except that (see the sketch below)
a sigmoid or softmax activation is used in the output layer
the cross-entropy cost function is used instead of the MSE: Bernoulli/Categorical likelihood
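For example (a sketch, not part of the original notebook; it assumes a model whose last Dense layer outputs one logit per class, integer class labels and a recent version of optax):

# Sketch: cross-entropy loss for a classifier; the softmax is folded into the loss
def compute_xent(params, x, labels):
    logits = model.apply(params, x)  # shape (N, num_classes), no output activation
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, labels))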
Deep Learning
More complex and flexible models are obtained by increasing the number of hidden layers (depth) and the number of neurons (width)
Warning
But the more parameters, the more difficult to train: Overfitting, vanishing gradients, …
Nevertheless, very deep neural network models are the current state of the art in pattern recognition problems. We can train deep neural networks effectively thanks to
Having lots of data available
The implementation of clever architectures and activations: Convolutional Neural Networks for images, Long Short-Term Memory networks for time series, residual connections, ReLU activations, …
The application of regularization schemes: penalties on parameters, data augmentation, dropout, …
The use of faster hardware: GPUs and tensor cores
Why are deep models needed?
The MLP with one hidden layer (shallow network) is a universal approximator
Note
(In theory) we could obtain a shallow network that is as flexible as a deep network
Warning
(But) it may require an extremely large number of hidden-layer neurons (even infinite)
In practice you need flexible but also compact models