Herramientas: Mi primera red neuronal basada en JAX
Contenido
Herramientas: Mi primera red neuronal basada en JAX¶
Para definir y entrenar redes neuronales artificiales con JAX utilizaremos dos librerías:
Flax es una librería de redes neuronales artificiales cuyo foco es la flexibilidad y que utiliza JAX como backend de cómputo, Además tiene interoperabilidad con la librería de programación probabilística NumPyro.
Flax provee
flax.linen
, una API con primitivas para diseñar redes neuronales (similar atorch.nn
)
optax es una librería de optimización numérica para modelos paramétricos basada en JAX. Esta librería provee los algoritmos basados en gradiente descedente que se ocupan típicamente para entrenar redes neuronales artificiales.
Instalación¶
Instala utilizando pip:
pip install flax
pip install optax
import flax.linen as nn
import optax
import jax
Definiendo un modelo en flax
¶
Un modelo en flax
es una clase de Python que hereda de flax.linen.Module
. Existen dos formas de escribir un modelo: explicita o inline
En la forma explícitala clase que representa el modelo debe implementar:
Un método
__call__
que recibe los datos de entrada y retorna una predicciónUn método
__setup__
que declara las variables y submódulos que componen el modelo
En la forma inline sólo se define __call__
con un decorador nn.compact
. Por ejemplo un regresor logístico
se implementaría como:
class LogisticRegressor(nn.Module):
@nn.compact
def __call__(self, x):
return nn.sigmoid(nn.Dense(1)(x))
LogisticRegressor()
LogisticRegressor()
donde nn.Dense
realiza una transformación lineal de tipo WX + B.
El decorador se hace cargo de registrar los parámetros de los submódulos como
nn.Dense
.
Implementemos ahora el siguiente modelo tipo multi layer perceptron con una capa oculta y sin activación de salida (modelo regresor)
class MLP_singlehidden(nn.Module):
hidden_neurons: int
output_neurons: int
@nn.compact
def __call__(self, x):
z = nn.relu(nn.Dense(self.hidden_neurons)(x))
return nn.Dense(self.output_neurons)(z)
MLP_singlehidden(10, 2)
MLP_singlehidden(
# attributes
hidden_neurons = 10
output_neurons = 2
)
Podemos pasar argumentos al momento de construir el objeto definiendolos dentro de la clase con la notación
nombre_variable : tipo_variable
Nota
flax
implementa clases de tipo dataclass
(introducidas en Python 3.7)
Veamos ahora como se implementaría un modelo MLP con
número arbitrario de capas ocultas
función de activación a elección (por defecto relu)
from typing import Sequence, Callable
class MLP(nn.Module):
neurons_per_layer: Sequence[int]
activation: Callable = nn.relu
@nn.compact
def __call__(self, x):
for k, neurons in enumerate(self.neurons_per_layer):
x = nn.Dense(neurons)(x)
if k != len(self.neurons_per_layer) - 1:
x = self.activation(x)
return x
MLP(neurons_per_layer=[10, 5, 3])
MLP(
# attributes
neurons_per_layer = [10, 5, 3]
activation = relu
)
Otros submódulos y funciones útiles de flax.linen
por categoría:
Métodos init
y apply
¶
Consideremos el último modelo definido:
model = MLP(neurons_per_layer=[10, 10, 1])
Importante
model
guarda la arquitectura del modelo pero no los valores de sus parámetros (peso). En flax
los parámetros se manejan por separado
Para inicializar los pesos del modelo utilizamos el método init
, el cual espera
una llave pseudo aleatoria consumible. Podemos generarla con
jax.random.PRNGKey
un tensor de ejemplo que tenga las dimensiones de nuestros datos
import jax.random as random
import jax.numpy as jnp
key = random.PRNGKey(12345) # 12345 es la semilla para el PRNG
params = model.init(key, jnp.zeros(shape=(1, 1)))
params
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
FrozenDict({
params: {
Dense_0: {
kernel: DeviceArray([[-1.0674667 , -0.06264212, 0.42956862, 1.2786232 ,
-0.74584633, 1.1691928 , 1.4219344 , -0.5895175 ,
0.5620119 , 0.83076704]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: DeviceArray([[ 0.08284923, -0.38116872, -0.12968694, 0.36474282,
-0.42905015, -0.03718745, -0.2860297 , 0.07171737,
0.00070555, -0.10588341],
[-0.20831482, 0.2779922 , 0.05534944, 0.09165573,
0.28945303, -0.10700346, 0.01279914, -0.5825432 ,
0.18495017, 0.03577514],
[ 0.03020477, -0.4052631 , -0.0723141 , 0.3378119 ,
0.00362646, 0.18239951, 0.10729557, 0.6282217 ,
-0.34536844, 0.4048866 ],
[-0.13189079, 0.36574128, -0.18147682, -0.04529037,
-0.22486708, 0.3561906 , 0.28744978, 0.69744307,
-0.33395192, 0.05422174],
[-0.15280393, -0.4834616 , 0.11331128, 0.08707297,
0.24745238, 0.33581656, 0.01574057, -0.38723463,
-0.35440663, 0.62023455],
[ 0.24958734, 0.10895315, 0.31780478, -0.24457307,
0.13650458, 0.14596146, -0.30161887, 0.14169854,
-0.01182529, 0.16038571],
[ 0.13808939, 0.2962448 , 0.16631499, 0.19445856,
0.05240866, -0.6165339 , 0.22212191, -0.4067405 ,
0.6950366 , -0.6181112 ],
[-0.16121888, -0.2954174 , -0.45188877, -0.04252925,
-0.4712695 , -0.05335376, 0.19730896, 0.19199689,
-0.17002471, -0.06398077],
[ 0.14025646, -0.09301604, -0.49561942, 0.16174501,
0.14713125, -0.05796253, -0.24432367, 0.05698485,
0.54087746, 0.12389056],
[ 0.14551008, 0.23458011, 0.34282166, -0.29984108,
0.7048362 , -0.19122204, -0.14576082, -0.48230988,
0.12584084, 0.39400163]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
},
Dense_2: {
kernel: DeviceArray([[-0.4023553 ],
[ 0.17063352],
[-0.20969704],
[ 0.04097234],
[ 0.02392739],
[ 0.30933133],
[ 0.52074623],
[-0.25418895],
[ 0.1303169 ],
[-0.11591556]], dtype=float32),
bias: DeviceArray([0.], dtype=float32),
},
},
})
Luego para hacer una inferencia utilizamos el método apply
de model
Este método recibe el diccionario de parámetros y los datos a evaluar:
x_test = jnp.linspace(-1.25, 1.25, num=200)[:, jnp.newaxis]
y_test = model.apply(params, x_test)
x_test.shape, y_test.shape
((200, 1), (200, 1))
Ajustando el modelo con optax
¶
Entrenaremos un regresor basado en el modelo MLP que vimos anteriormente sobre los siguientes datos sintéticos:
key = random.PRNGKey(0)
key, key_ = random.split(key)
x = jnp.sort(random.uniform(key_, minval=-1, maxval=1, shape=(40, 1)))
f = lambda x : x*jnp.sin(5*x)
key, key_ = random.split(key)
y = f(x) + 0.1*random.normal(key_, shape=x.shape)
import holoviews as hv
hv.extension('bokeh')
data_plot = hv.Points((x.ravel(), y.ravel()), label='Train data').opts(width=500, size=5, color='black')
data_plot
Para ajustar el modelo a estos datos utilizaremos la librería de JAX optax
:
En este caso (regresión de funciones) utilizaremos el error medio cuadrático como función de costo:
donde \(f_\theta\) representa el modelo y \(\theta\) la colección de parámetros del mismo. Tradicionalmente, el entrenamiento de redes neuronales artificiales se realiza minimizando la función de costo con el algoritmo de gradiente descendente:
En este ejemplo particular utilizaremos el algoritmo de optimización Adam, un algoritmo de gradiente descedente con paso (tasa de aprendizaje) adaptativa.
La función de costo y el optimizador serían:
optimizer = optax.adam(learning_rate=1e-2)
def loss(params, x, y):
return jnp.mean((model.apply(params, x) - y)**2)
El optimizador tiene los siguientes métodos
init
: Recibe los parámetros iniciales y retorna el “estado” inicial del optimizadorupdate
: Recibe los gradientes de la función de costo, el “estado” actual del optimizador y los parámetros actuales del modelo. Retorna la diferencia entre los parámetros nuevos y actuales, y también el estado actualizado
Nota
El “estado” del optimizador se encarga de mantener las variables particulares de cada optimizador, como por ejemplo los estadísticos de los gradientes en el caso de Adam
Inicializamos el modelo y el optimizador con sus métodos init
respectivos:
key = random.PRNGKey(12345)
model = MLP(neurons_per_layer=[10, 10, 1], activation=nn.tanh)
params = model.init(key, jnp.zeros(shape=(1, 1)))
state = optimizer.init(params)
Ahora lo que necesitamos es calcular los gradientes de la función de costo. Para esto utilizaremos el autodiferenciador de jax a través de la función value_and_grad
.
jax.value_and_grad
recibe una función y retorna una nueva función que recibe los mismos argumentos que la original pero que retorna los gradientes en función de uno de sus argumentos (indicado por argnums
)
Por ejemplo:
grad_loss = jax.value_and_grad(loss, argnums=0)
loss_val, grads = grad_loss(params, x, y)
loss_val
DeviceArray(0.24306464, dtype=float32)
donde grads
es un diccionario con los gradientes de loss
con respecto a params
Podemos ganar considerable desempeño si compilamos grad_loss
utilizando jax.jit
, por ejemplo:
# sin JIT
grad_loss = jax.value_and_grad(loss, argnums=0)
%timeit -r3 -n3 grad_loss(params, x, y)
88.2 ms ± 2.38 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
# con JIT
grad_loss_jit = jax.jit(jax.value_and_grad(loss, argnums=0))
grad_loss_jit(params, x, y) # La primera llamada invoca la compilación
%timeit -r3 -n3 grad_loss_jit(params, x, y)
54.3 µs ± 12.3 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
Utilicemos grad_loss_jit
para entrenar el modelo por 200 épocas como muestra el siguiente bucle.
Luego de calcular los gradientes usamos el método update
del optimizador para calcular los \(\Delta \theta\) de acuerdo al algoritmo seleccionado (adam).
Finalmente ocupamos la función utilitaria apply_updates
para actualizar los parámetros, es decir \(\theta_{t+1} = \theta_t + \Delta \theta\).
from tqdm import tqdm
loss = []
for epoch in tqdm(range(200)):
loss_val, grads = grad_loss_jit(params, x, y)
updates, state = optimizer.update(grads, state, params)
params = optax.apply_updates(params, updates)
loss.append(loss_val.item())
y_test = model.apply(params, x_test)
100%|███████| 200/200 [00:01<00:00, 104.22it/s]
A continuación se muestra la evolución de la función de costo (MSE) y el resultado de la predicción con el regresor aprendido:
loss_plot = hv.Curve((range(len(loss)), loss), 'Epoch', 'MSE')
test_plot = hv.Curve((x_test, y_test), label='Test prediction')#.opts(ylim=(-1, 1))
loss_plot + hv.Overlay([data_plot, test_plot])