import holoviews as hv
from tqdm.notebook import tqdm
import jax
import jax.numpy as jnp
import jax.random as random
import numpyro.distributions as dists

My first neural network with Flax

Flax is a neural network library with a focus on flexibility. Flax uses JAX as its computing backend and interoperates with NumPyro

Flax provides flax.linen, an API with primitives to design neural networks. To train the models you also need an optimizer. State-of-the-art optimizers that can be used with Flax models are available in a companion library called optax

import flax.linen as nn
import optax

In what follows we will refer to flax.linen as nn

We will learn how to use these libraries using a polynomial regressor and a multi-layer perceptron

Polynomial (linear) regression

The polynomial regressor is a special case of the linear regressor in which the input is expanded using a polynomial basis

The model is defined as

\[ f_\theta(x) = \langle w , \phi(x) \rangle + b \]

where
\[ \phi(x) = [x, x^2, x^3, \ldots] \]
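As a concrete check of the basis expansion, the sketch below evaluates \(\phi(x)\) with degree 3 for two scalar inputs. The helper polynomial_features is hypothetical; it mirrors the concatenation used inside the model further down.

```python
import jax.numpy as jnp

def polynomial_features(x, degree):
    """Hypothetical helper: map inputs of shape (N, 1) to [x, x^2, ..., x^degree], shape (N, degree)."""
    return jnp.concatenate([x**(k + 1) for k in range(degree)], axis=1)

x_demo = jnp.array([[2.0], [-1.0]])
phi_demo = polynomial_features(x_demo, degree=3)
# each row is [x, x^2, x^3]: [2, 4, 8] and [-1, 1, -1]
```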

To test the model we will use an irregularly sampled synthetic signal with Gaussian noise

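key = random.PRNGKey(0)
key, x_key, noise_key = random.split(key, 3)  # one fresh key per random draw
# 40 irregularly spaced inputs in [-1, 1], sorted along the sample axis
x = jnp.sort(dists.Uniform(-1., 1.).sample(x_key, sample_shape=(40, 1)), axis=0)
f = lambda x : x*jnp.sin(5*x)  # ground-truth signal
# Additive Gaussian noise with standard deviation 0.1
y = f(x) + dists.Normal(0., 0.1).sample(noise_key, sample_shape=x.shape)
# Regular test grid that extends beyond the training range
x_test = jnp.linspace(-1.5, 1.5, num=200)[:, jnp.newaxis]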

data_plot = hv.Points((x.ravel(), y.ravel()), label='Train data').opts(width=500, size=5, color='black')

A model written in flax is a Python class that inherits from nn.Module

  • The class has to implement a __call__ method that receives the input data and returns the model prediction

  • The class may also implement a setup method to declare the variables and submodules that compose the model. This is the “explicit” way of defining a model in Flax.

  • There is also the “inline” way of defining models: instead of using setup, the submodules are created directly inside __call__, which is decorated with nn.compact. In both styles the module's arguments are declared as annotated class attributes, as in Python 3.7 dataclasses

See also

See “Should I use setup or nn.compact?” in the Flax documentation for more details

Below is the polynomial regressor written in the “inline” style

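class PolynomialRegressor(nn.Module):
    degree: int # Dataclass style of declaring arguments

    @nn.compact  # "inline" style: submodules are created inside __call__
    def __call__(self, x):
        # Basis expansion phi(x) = [x, x^2, ..., x^degree] along the feature axis
        phi = jnp.concatenate([x**(k+1) for k in range(self.degree)], axis=1)
        # Linear regressor on the expanded features: <w, phi(x)> + b
        return nn.Dense(1, use_bias=True)(phi)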

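The Dense submodule takes the number of output features \(m\) as its first positional argument and implements the affine transformation \(xw+b\), with kernel \(w\in\mathbb{R}^{n\times m}\) and bias \(b\in\mathbb{R}^m\), where \(n\) is the dimension of the input. The input dimension \(n\) is not passed explicitly; Flax infers it from the input when the module is initialized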

model = PolynomialRegressor(degree=6)

The model object can be initialized with the init method. This requires a PRNG key and an example input minibatch; only its shape (and dtype) is used to infer the parameter shapes, so an array of zeros is enough

To evaluate the model for given parameter values and input data, the apply method is used. Here we use apply to plot the predictions of several randomly initialized models

prior_plots = []
for seed in range(20):
    params = model.init(random.PRNGKey(seed), jnp.zeros(shape=(1, 1)))
    y_test = model.apply(params, x_test)
    prior_plots.append(hv.Curve((x_test, y_test)).opts(alpha=0.1, width=500, color='#30a2da'))

To train the model we need to define a cost function (loss) and an optimizer

In this case we will use the mean squared error as cost and the Adam optimizer. For convenience we write a helper function that receives the parameters, input and target, and returns the loss

optimizer = optax.adam(learning_rate=1e-2)
compute_loss = lambda params, x, y : jnp.mean(optax.l2_loss(model.apply(params, x), y))

To train we will use the following optimizer methods:

  • init: Receives the initial parameters and returns an optimizer state

  • update: Receives the gradient of the cost function, the optimizer state and the current parameters. Returns the parameter updates and the updated state


The optimizer state is in charge of maintaining the optimizer's running statistics, e.g. the accumulated squared gradients in Adagrad or the moving averages of the gradient and of its square in Adam

We will also use the apply_updates helper function from optax to update the parameters given the output of the update optimizer method

Finally we plot the loss as a function of the training epoch and the final test prediction

model = PolynomialRegressor(degree=6)
params = model.init(random.PRNGKey(1234), jnp.zeros(shape=(1,1)))
state = optimizer.init(params)
# jit-compiled function returning the loss and its gradient w.r.t. params
loss_and_grad = jax.jit(jax.value_and_grad(compute_loss, argnums=0))
loss = []
for epoch in tqdm(range(100)):
    loss_val, grads = loss_and_grad(params, x, y)
    updates, state = optimizer.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    loss.append(float(loss_val))  # record the training loss for the plot
y_test = model.apply(params, x_test)

loss_plot = hv.Curve((range(len(loss)), loss), 'Epoch', 'Loss')
test_plot = hv.Curve((x_test.ravel(), y_test.ravel()), label='Test prediction').opts(ylim=(-1, 1))
loss_plot + hv.Overlay([data_plot, test_plot])

In this particular case we chose polynomials as the basis to represent the data

How can we choose a basis in a more general case?

One possibility:

Don’t choose, let the model learn the basis

Proposed activity for the polynomial regressor

  1. Change the degree of the basis and describe the results

  2. Increase the noise and repeat the previous step

Concepts to discuss: Complexity, generalization, overfitting, regularization

Multilayer Perceptrons

Artificial neural networks (ANNs) are non-linear parametric function approximators built by connecting simple units

These units are simplified models of biological neurons. The artificial neuron is a linear regressor followed by a non-linear activation function

\[ y = g \left( b + \sum_{d=1}^D w_{d} x_d \right) \]

where for example

\[ g(x) = \frac{1}{1 + e^{-x}} \]

which is known as the sigmoid activation.
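The single neuron above is a weighted sum followed by the activation. A minimal sketch with jax.nn.sigmoid as \(g\); the inputs, weights and bias are made-up values:

```python
import jax
import jax.numpy as jnp

def neuron(x, w, b):
    # y = g(b + sum_d w_d x_d) with g the sigmoid
    return jax.nn.sigmoid(b + jnp.dot(w, x))

inputs = jnp.array([1.0, 2.0, -1.0])
weights = jnp.array([0.5, -0.25, 1.0])
neuron_out = neuron(inputs, weights, b=1.0)
# b + w.x = 1.0 + (0.5 - 0.5 - 1.0) = 0.0, and g(0) = 0.5
```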

Feed-forward ANNs are organized in layers, and each layer has a certain number of units. The most classical ANN architecture is the Multilayer Perceptron (MLP)

In the MLP architecture every unit is connected to all units of its previous and next layers. For example, consider a network with four inputs, one hidden layer with eight units and two outputs


Mathematically, the output of the hidden layer is

\[ h_j = g_1 \left( b_j + \sum_{d=1}^{N_x} w_{jd} x_d \right), j=1, 2, \ldots, N_h \]

and the output layer

\[ \begin{aligned} f_i &= g_2 \left(b_i + \sum_{j=1}^{N_h} w_{ij} h_j \right) \\ &= g_2 \left(b_i + \sum_{j=1}^{N_h} w_{ij} g_1 \left( b_j + \sum_{d=1}^{N_x} w_{jd} x_d \right) \right), \quad i = 1, \ldots, N_o \end{aligned} \]

where \(g_1\), \(g_2\) may be different (user choice)
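Written in matrix form, with \(W^{(1)}\in\mathbb{R}^{N_h\times N_x}\) collecting the \(w_{jd}\) and \(W^{(2)}\in\mathbb{R}^{N_o\times N_h}\) collecting the \(w_{ij}\), the two equations become \(h = g_1(b^{(1)} + W^{(1)} x)\) and \(f = g_2(b^{(2)} + W^{(2)} h)\). The sketch below implements this for the example network with \(N_x=4\), \(N_h=8\), \(N_o=2\), choosing sigmoid for \(g_1\) and the identity for \(g_2\); the random weights are placeholders and the names are hypothetical.

```python
import jax
import jax.numpy as jnp

N_x, N_h, N_o = 4, 8, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W1 = jax.random.normal(k1, (N_h, N_x))   # w_jd: hidden-layer weights
b1 = jnp.zeros(N_h)                      # b_j: hidden-layer biases
W2 = jax.random.normal(k2, (N_o, N_h))   # w_ij: output-layer weights
b2 = jnp.zeros(N_o)                      # b_i: output-layer biases

def mlp_forward(x, g1=jax.nn.sigmoid, g2=lambda z: z):
    h = g1(b1 + W1 @ x)      # h_j = g_1(b_j + sum_d w_jd x_d)
    return g2(b2 + W2 @ h)   # f_i = g_2(b_i + sum_j w_ij h_j)

x_example = jax.random.normal(k3, (N_x,))
f_out = mlp_forward(x_example)  # one value per output unit
```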


The parameter vector \(\theta\) of the network includes the weights \(w\) and biases \(b\) of the linear regressors within the architecture

Flax implementation

Some useful submodules of nn by category are

The following code implements an MLP regressor with an arbitrary number of Dense hidden layers

from typing import Sequence, Callable

class MultiLayerPerceptron(nn.Module):
    hidden_neurons: Sequence[int]  # units per layer; the last entry is the output size
    kernel_init: Callable = nn.initializers.lecun_normal()
    activation: Callable = nn.sigmoid

    @nn.compact
    def __call__(self, x):
        for k, neurons in enumerate(self.hidden_neurons):
            x = nn.Dense(neurons, kernel_init=self.kernel_init)(x)
            # Apply the nonlinearity to every layer except the (linear) output layer
            if k != len(self.hidden_neurons) - 1:
                x = self.activation(x)
        return x

Predictive prior distribution of the MLP

Let’s consider a zero-mean Normal prior with \(\sigma=5\) on the weights in \(\theta\) (the biases keep the zero initialization that nn.Dense uses by default) and study the space of possible models

A model with one hidden layer

model = MultiLayerPerceptron([100, 1], nn.initializers.normal(stddev=5.))

prior_x = jnp.linspace(-1, 1, num=500)[:, jnp.newaxis]
prior_plots = []
for seed in range(100):
    params = model.init(random.PRNGKey(seed), jnp.zeros(shape=(1,1)))
    prior_plots.append(hv.Curve((prior_x, model.apply(params, prior_x))).opts(width=500, alpha=0.1, color='#30a2da'))

And a model with two hidden layers

model = MultiLayerPerceptron([100, 100, 1], nn.initializers.normal(stddev=5.))

prior_plots = []
for seed in range(100):
    params = model.init(random.PRNGKey(seed), jnp.zeros(shape=(1,1)))
    prior_plots.append(hv.Curve((prior_x, model.apply(params, prior_x))).opts(width=500, alpha=0.1, color='#30a2da'))


By increasing the number of layers and units (capacity) the model becomes more flexible

Proposed activity with the predictive prior of the MLP

  • What happens if you remove the nonlinearity?

  • What happens if you change the nonlinearity?

  • What happens when you change the number of units?

  • What happens when you add even more layers?

Training the Neural Network

Let’s return to our original regression problem using an MLP model

How many hyperbolic tangent basis functions do we need to fit this data?

model = MultiLayerPerceptron([10, 10, 1], activation=nn.tanh)
optimizer = optax.adam(learning_rate=1e-2)
compute_loss = lambda params, x, y : jnp.mean(optax.l2_loss(model.apply(params, x), y))
# jit-compiled function returning the loss and its gradient w.r.t. params
loss_and_grad = jax.jit(jax.value_and_grad(compute_loss))

params = model.init(random.PRNGKey(0), jnp.zeros(shape=(1,1)))
state = optimizer.init(params)
loss = []
for epoch in tqdm(range(200)):
    loss_val, grads = loss_and_grad(params, x, y)
    updates, state = optimizer.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    loss.append(float(loss_val))  # record the training loss for the plot
y_test = model.apply(params, x_test)

loss_plot = hv.Curve((range(len(loss)), loss), 'Epoch', 'Loss')
test_plot = hv.Curve((x_test.ravel(), y_test.ravel()), label='Test prediction')
loss_plot + hv.Overlay([data_plot, test_plot])

Proposed activity

  • Explore the solutions using a different number of units

  • Explore the solutions using a different number of layers

In both cases tune the number of training epochs appropriately

Probabilistic interpretation of a MLP

Let’s consider an MLP architecture for multivariate regression

  • \((x^{(n)}, y^{(n)})\) for \(n=1,2,\ldots,N\) are the training tuples with \(x^{(n)}\in \mathbb{R}^D\) and \(y^{(n)} \in \mathbb{R}^O\)

  • \(f_i(\cdot)\) is the i-th output of the model, for \(i=1, 2, \ldots, O\)

  • No activation function in the output layer

  • \(\theta_k\) for \(k=1,2,\ldots, K\) is the collection of trainable parameters (weights and biases)

We fit the parameters by minimizing the sum of squared errors (proportional to the mean squared error)

\[ \min_\theta \sum_{n=1}^N \sum_{i=1}^O \left(y_{i}^{(n)} - f_i(x^{(n)}) \right)^2 \]

This is equivalent to the MLE solution with Normal likelihood assuming known variance
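The equivalence can be seen from the negative log-likelihood of independent observations \(y_i^{(n)} \sim \mathcal{N}\left(f_i(x^{(n)}), \sigma^2\right)\) with fixed \(\sigma\), where \(\mathcal{D}\) denotes the training set:

\[ -\log p(\mathcal{D} \mid \theta) = \frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{i=1}^O \left(y_{i}^{(n)} - f_i(x^{(n)}) \right)^2 + \frac{NO}{2} \log \left(2\pi\sigma^2\right) \]

Neither the positive factor \(1/(2\sigma^2)\) nor the last term depends on \(\theta\), so the likelihood is maximized by the same \(\theta\) that minimizes the sum of squared errors.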

We typically include an L2 regularizer to penalize complexity and improve generalization. In this case the cost function is

\[ \min_\theta \sum_{n=1}^N \sum_{i=1}^O \left(y_{i}^{(n)} - f_i(x^{(n)}) \right)^2 + \lambda \sum_k \theta_k^2 \]

This is equivalent to the MAP solution with Normal likelihood and zero-mean Normal prior
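Similarly, a prior \(\theta_k \sim \mathcal{N}(0, \sigma_\theta^2)\) adds \(\frac{1}{2\sigma_\theta^2}\sum_k \theta_k^2\) to the negative log-posterior; multiplying by \(2\sigma^2\) recovers the regularized cost with \(\lambda = \sigma^2/\sigma_\theta^2\), so a narrower prior means a larger \(\lambda\). A minimal JAX sketch of that cost for any parameter pytree follows; the function names and the toy linear model in the usage lines are hypothetical.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(params):
    # sum_k theta_k^2 over every leaf (weights and biases) of the parameter pytree
    return sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))

def regularized_cost(params, predict, x, y, lam):
    # sum of squared errors + lambda * sum_k theta_k^2
    residuals = y - predict(params, x)
    return jnp.sum(residuals ** 2) + lam * sum_of_squares(params)

# Usage with a toy linear model y_hat = x w + b
ridge_params = {'w': jnp.array([[2.0]]), 'b': jnp.array([1.0])}
linear = lambda p, x: x @ p['w'] + p['b']
x_toy = jnp.array([[1.0], [2.0]])
y_toy = jnp.array([[3.0], [5.0]])
cost = regularized_cost(ridge_params, linear, x_toy, y_toy, lam=0.1)
# the residuals are zero, so cost = 0.1 * (2^2 + 1^2) = 0.5
```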


Conventional neural network training yields MLE/MAP point estimates


There is no closed-form solution for these cost functions because of the non-linearities in the model, so ANNs are generally trained with iterative methods such as gradient descent
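A minimal sketch of plain gradient descent in JAX, \(\theta \leftarrow \theta - \eta \nabla_\theta L\), on a hypothetical one-parameter problem whose data follow \(y = 3x\). jax.grad returns gradients with the same pytree structure as the parameters, so the update can be applied leaf by leaf with jax.tree_util.tree_map; Adam and the other optax optimizers refine this basic step.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data generated by y = 3x (no noise)
x_gd = jnp.array([1.0, 2.0, 3.0])
y_gd = 3.0 * x_gd
GD_LEARNING_RATE = 0.1

def gd_loss(params, x, y):
    # Mean squared error of the one-parameter model y_hat = w * x
    return jnp.mean((params['w'] * x - y) ** 2)

@jax.jit
def gd_step(params):
    grads = jax.grad(gd_loss)(params, x_gd, y_gd)
    # theta <- theta - learning_rate * dL/dtheta, applied to every leaf of the pytree
    return jax.tree_util.tree_map(lambda p, g: p - GD_LEARNING_RATE * g, params, grads)

gd_params = {'w': jnp.array(0.0)}
for _ in range(100):
    gd_params = gd_step(gd_params)
# gd_params['w'] is now very close to 3
```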

In the case of ANNs for classification we arrive at similar conclusions, except that

  • sigmoid or softmax activation is used in the output layer

  • cross-entropy cost function is used instead of MSE: Bernoulli/Categorical likelihood

Deep Learning

More complex and flexible models are obtained by increasing the number of hidden layers (depth) and the number of neurons (width)



But the more parameters a model has, the more difficult it is to train: overfitting, vanishing gradients, …

Nevertheless, very deep neural network models are the current state of the art in pattern recognition problems. We can train deep neural networks effectively thanks to

  • Having lots of data available

  • The implementation of clever architectures and activations: Convolutional Neural Networks for images, Long Short-Term Memory (LSTM) networks for time series, residual connections, ReLU activations, …

  • The application of regularization schemes: penalty on parameters, data augmentation, dropout, …

  • The use of faster hardware: GPUs and tensor-cores

Why are deep models needed?

The MLP with one hidden layer (shallow network) is a universal approximator


(In theory) we could obtain a shallow network that is as flexible as a deep network


(But) it may require an extremely large number of hidden-layer neurons (even infinite)

In practice you need flexible but also compact models