Herramientas: Mi primera red neuronal basada en JAX

Para definir y entrenar redes neuronales artificiales con JAX utilizaremos dos librerías:

Flax es una librería de redes neuronales artificiales cuyo foco es la flexibilidad y que utiliza JAX como backend de cómputo, Además tiene interoperabilidad con la librería de programación probabilística NumPyro.

Flax provee flax.linen, una API con primitivas para diseñar redes neuronales (similar a torch.nn)

optax es una librería de optimización numérica para modelos paramétricos basada en JAX. Esta librería provee los algoritmos basados en gradiente descedente que se ocupan típicamente para entrenar redes neuronales artificiales.

Instalación

Instala utilizando pip:

pip install flax
pip install optax
import flax.linen as nn
import optax
import jax

Definiendo un modelo en flax

Un modelo en flax es una clase de Python que hereda de flax.linen.Module. Existen dos formas de escribir un modelo: explicita o inline

En la forma explícitala clase que representa el modelo debe implementar:

  • Un método __call__ que recibe los datos de entrada y retorna una predicción

  • Un método __setup__ que declara las variables y submódulos que componen el modelo

En la forma inline sólo se define __call__ con un decorador nn.compact. Por ejemplo un regresor logístico

\[ y = \text{sigmoid}\left(\sum_j w_j x_j + b \right) \]

se implementaría como:

class LogisticRegressor(nn.Module):
    
    @nn.compact
    def __call__(self, x):
        return nn.sigmoid(nn.Dense(1)(x))
    
LogisticRegressor()
LogisticRegressor()

donde nn.Dense realiza una transformación lineal de tipo WX + B.

El decorador se hace cargo de registrar los parámetros de los submódulos como nn.Dense.

Implementemos ahora el siguiente modelo tipo multi layer perceptron con una capa oculta y sin activación de salida (modelo regresor)

\[ y = \left(\sum_j w_j \text{ReLU}\left( \sum_i w_{ij} x_i + b_i\right) + b \right) \]
class MLP_singlehidden(nn.Module):
    
    hidden_neurons: int
    output_neurons: int
        
    @nn.compact
    def __call__(self, x):
        z = nn.relu(nn.Dense(self.hidden_neurons)(x))
        return nn.Dense(self.output_neurons)(z)
   
MLP_singlehidden(10, 2)
MLP_singlehidden(
    # attributes
    hidden_neurons = 10
    output_neurons = 2
)

Podemos pasar argumentos al momento de construir el objeto definiendolos dentro de la clase con la notación

nombre_variable : tipo_variable

Nota

flax implementa clases de tipo dataclass (introducidas en Python 3.7)

Veamos ahora como se implementaría un modelo MLP con

  • número arbitrario de capas ocultas

  • función de activación a elección (por defecto relu)

from typing import Sequence, Callable

class MLP(nn.Module):
    
    neurons_per_layer: Sequence[int]
    activation: Callable = nn.relu
    
    @nn.compact
    def __call__(self, x):
        for k, neurons in enumerate(self.neurons_per_layer):
            x = nn.Dense(neurons)(x)
            if k != len(self.neurons_per_layer) - 1:
                x = self.activation(x)
        return x
    
MLP(neurons_per_layer=[10, 5, 3])
MLP(
    # attributes
    neurons_per_layer = [10, 5, 3]
    activation = relu
)

Otros submódulos y funciones útiles de flax.linen por categoría:

Métodos init y apply

Consideremos el último modelo definido:

model = MLP(neurons_per_layer=[10, 10, 1])

Importante

model guarda la arquitectura del modelo pero no los valores de sus parámetros (peso). En flax los parámetros se manejan por separado

Para inicializar los pesos del modelo utilizamos el método init, el cual espera

  • una llave pseudo aleatoria consumible. Podemos generarla con jax.random.PRNGKey

  • un tensor de ejemplo que tenga las dimensiones de nuestros datos

import jax.random as random
import jax.numpy as jnp

key = random.PRNGKey(12345) # 12345 es la semilla para el PRNG
params = model.init(key, jnp.zeros(shape=(1, 1)))
params
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
FrozenDict({
    params: {
        Dense_0: {
            kernel: DeviceArray([[-1.0674667 , -0.06264212,  0.42956862,  1.2786232 ,
                          -0.74584633,  1.1691928 ,  1.4219344 , -0.5895175 ,
                           0.5620119 ,  0.83076704]], dtype=float32),
            bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
        },
        Dense_1: {
            kernel: DeviceArray([[ 0.08284923, -0.38116872, -0.12968694,  0.36474282,
                          -0.42905015, -0.03718745, -0.2860297 ,  0.07171737,
                           0.00070555, -0.10588341],
                         [-0.20831482,  0.2779922 ,  0.05534944,  0.09165573,
                           0.28945303, -0.10700346,  0.01279914, -0.5825432 ,
                           0.18495017,  0.03577514],
                         [ 0.03020477, -0.4052631 , -0.0723141 ,  0.3378119 ,
                           0.00362646,  0.18239951,  0.10729557,  0.6282217 ,
                          -0.34536844,  0.4048866 ],
                         [-0.13189079,  0.36574128, -0.18147682, -0.04529037,
                          -0.22486708,  0.3561906 ,  0.28744978,  0.69744307,
                          -0.33395192,  0.05422174],
                         [-0.15280393, -0.4834616 ,  0.11331128,  0.08707297,
                           0.24745238,  0.33581656,  0.01574057, -0.38723463,
                          -0.35440663,  0.62023455],
                         [ 0.24958734,  0.10895315,  0.31780478, -0.24457307,
                           0.13650458,  0.14596146, -0.30161887,  0.14169854,
                          -0.01182529,  0.16038571],
                         [ 0.13808939,  0.2962448 ,  0.16631499,  0.19445856,
                           0.05240866, -0.6165339 ,  0.22212191, -0.4067405 ,
                           0.6950366 , -0.6181112 ],
                         [-0.16121888, -0.2954174 , -0.45188877, -0.04252925,
                          -0.4712695 , -0.05335376,  0.19730896,  0.19199689,
                          -0.17002471, -0.06398077],
                         [ 0.14025646, -0.09301604, -0.49561942,  0.16174501,
                           0.14713125, -0.05796253, -0.24432367,  0.05698485,
                           0.54087746,  0.12389056],
                         [ 0.14551008,  0.23458011,  0.34282166, -0.29984108,
                           0.7048362 , -0.19122204, -0.14576082, -0.48230988,
                           0.12584084,  0.39400163]], dtype=float32),
            bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
        },
        Dense_2: {
            kernel: DeviceArray([[-0.4023553 ],
                         [ 0.17063352],
                         [-0.20969704],
                         [ 0.04097234],
                         [ 0.02392739],
                         [ 0.30933133],
                         [ 0.52074623],
                         [-0.25418895],
                         [ 0.1303169 ],
                         [-0.11591556]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})

Luego para hacer una inferencia utilizamos el método apply de model

Este método recibe el diccionario de parámetros y los datos a evaluar:

x_test = jnp.linspace(-1.25, 1.25, num=200)[:, jnp.newaxis]
y_test = model.apply(params, x_test)

x_test.shape, y_test.shape
((200, 1), (200, 1))

Ajustando el modelo con optax

Entrenaremos un regresor basado en el modelo MLP que vimos anteriormente sobre los siguientes datos sintéticos:

key = random.PRNGKey(0)
key, key_ = random.split(key)

x = jnp.sort(random.uniform(key_, minval=-1, maxval=1, shape=(40, 1)))
f = lambda x : x*jnp.sin(5*x)

key, key_ = random.split(key)
y = f(x) + 0.1*random.normal(key_, shape=x.shape)
import holoviews as hv
hv.extension('bokeh')

data_plot = hv.Points((x.ravel(), y.ravel()), label='Train data').opts(width=500, size=5, color='black')
data_plot

Para ajustar el modelo a estos datos utilizaremos la librería de JAX optax:

En este caso (regresión de funciones) utilizaremos el error medio cuadrático como función de costo:

\[ L = \frac{1}{N} \sum_{i=1}^N (y_i - f_\theta(x_i))^2 \]

donde \(f_\theta\) representa el modelo y \(\theta\) la colección de parámetros del mismo. Tradicionalmente, el entrenamiento de redes neuronales artificiales se realiza minimizando la función de costo con el algoritmo de gradiente descendente:

\[ \theta_{t+1} = \theta_t - \eta \frac{dL}{d\theta} \]

En este ejemplo particular utilizaremos el algoritmo de optimización Adam, un algoritmo de gradiente descedente con paso (tasa de aprendizaje) adaptativa.

La función de costo y el optimizador serían:

optimizer = optax.adam(learning_rate=1e-2)

def loss(params, x, y):
    return jnp.mean((model.apply(params, x) - y)**2)

El optimizador tiene los siguientes métodos

  • init: Recibe los parámetros iniciales y retorna el “estado” inicial del optimizador

  • update: Recibe los gradientes de la función de costo, el “estado” actual del optimizador y los parámetros actuales del modelo. Retorna la diferencia entre los parámetros nuevos y actuales, y también el estado actualizado

Nota

El “estado” del optimizador se encarga de mantener las variables particulares de cada optimizador, como por ejemplo los estadísticos de los gradientes en el caso de Adam

Inicializamos el modelo y el optimizador con sus métodos init respectivos:

key = random.PRNGKey(12345)
model = MLP(neurons_per_layer=[10, 10, 1], activation=nn.tanh)
params = model.init(key, jnp.zeros(shape=(1, 1)))
state = optimizer.init(params)

Ahora lo que necesitamos es calcular los gradientes de la función de costo. Para esto utilizaremos el autodiferenciador de jax a través de la función value_and_grad.

jax.value_and_grad recibe una función y retorna una nueva función que recibe los mismos argumentos que la original pero que retorna los gradientes en función de uno de sus argumentos (indicado por argnums)

Por ejemplo:

grad_loss = jax.value_and_grad(loss, argnums=0)
loss_val, grads = grad_loss(params, x, y)
loss_val
DeviceArray(0.24306464, dtype=float32)

donde grads es un diccionario con los gradientes de loss con respecto a params

Podemos ganar considerable desempeño si compilamos grad_loss utilizando jax.jit, por ejemplo:

# sin JIT
grad_loss = jax.value_and_grad(loss, argnums=0)
%timeit -r3 -n3 grad_loss(params, x, y)
88.2 ms ± 2.38 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
# con JIT
grad_loss_jit = jax.jit(jax.value_and_grad(loss, argnums=0))
grad_loss_jit(params, x, y) # La primera llamada invoca la compilación
%timeit -r3 -n3 grad_loss_jit(params, x, y)
54.3 µs ± 12.3 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)

Utilicemos grad_loss_jit para entrenar el modelo por 200 épocas como muestra el siguiente bucle.

Luego de calcular los gradientes usamos el método update del optimizador para calcular los \(\Delta \theta\) de acuerdo al algoritmo seleccionado (adam).

Finalmente ocupamos la función utilitaria apply_updates para actualizar los parámetros, es decir \(\theta_{t+1} = \theta_t + \Delta \theta\).

from tqdm import tqdm

loss = []
for epoch in tqdm(range(200)):
    loss_val, grads = grad_loss_jit(params, x, y) 
    updates, state = optimizer.update(grads, state, params) 
    params = optax.apply_updates(params, updates)
    loss.append(loss_val.item())
    
y_test = model.apply(params, x_test)
100%|███████| 200/200 [00:01<00:00, 104.22it/s]

A continuación se muestra la evolución de la función de costo (MSE) y el resultado de la predicción con el regresor aprendido:

loss_plot = hv.Curve((range(len(loss)), loss), 'Epoch', 'MSE')
test_plot = hv.Curve((x_test, y_test), label='Test prediction')#.opts(ylim=(-1, 1))
loss_plot + hv.Overlay([data_plot, test_plot])