Introducción: Transformaciones básicas de JAX
Contenido
Introducción: Transformaciones básicas de JAX¶
JAX es un framework reciente enfocado en computación científica cuyos componentes principales son:
Autograd: Un sistema para obtener derivadas de todo orden.
Accelerated Linear Algebra (XLA): Un compilador enfocado en álgebra lineal.
Ambos aplican para rutinas escritas en Python nátivo o basadas en la librería NumPy. En conjunto permiten diferenciar, vectorizar, paralelizar y compilar a coprocesador GPU/TPU rutinas matemáticas.
JAX es ideal para Machine Learning.
Veremos algunas de sus bondades a continuación.
Instalación¶
JAX soporta oficialmente Linux y macOS. Hay soporte limitado y no oficial para Windows, pero requiere compilación manual de las fuentes.
Nota
Si tienes SO Windows puedes experimentar con JAX en google colab. En lo que sigue se asume una instalación local de JAX.
Para instalar JAX lo más simple es utilizar el manejador de paquetas conda. Para instalar JAX con soporte GPU, crea un ambiente de conda, actívalo y ejecuta el siguiente comando:
conda install jax cuda-nvcc -c conda-forge -c nvidia
Si sólo vas a usar arquitecturas CPU puedes hacer una instalación más ligera con:
conda install jax -c conda-forge
JAX también puede instalarse con pip (distintos sabores):
pip install "jax[cpu]"
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Una vez instalado podemos verificar los dispositivos compatibles para hacer cómputo con:
import jax
jax.devices()
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
que en este caso indica que sólo está disponible un dispositivo CPU.
JAX arrays y jax.numpy
¶
Consideremos un ndarray
con 10000 filas y 2 columnas y una función basada en la API de NumPy que calcula la distancia euclidiana entre todos los pares de filas:
import numpy as np
data = np.random.randn(10000, 2)
def distancia_pares_np(data):
return np.sqrt(np.sum((data.reshape(-1, 1, 2) - data.reshape(1, -1, 2))**2, axis=-1))
print(type(data))
%timeit -r3 -n1 distancia_pares_np(data)
<class 'numpy.ndarray'>
6.86 s ± 181 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
Para acelerar esta función con JAX haremos dos cambios:
Reemplazar los llamados a
numpy
por llamados ajax.numpy
Expresar los datos de entrada como un arreglo de JAX (
DeviceArray
)
El segundo es en realidad opcional (pero necesario si queremos usar JIT)
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
def distancia_pares_jnp(data):
return jnp.sqrt(jnp.sum((data.reshape(-1, 1, 2) - data.reshape(1, -1, 2))**2, axis=-1))
jax_data = jnp.array(data, dtype=jnp.float64)
print(type(jax_data))
print(np.allclose(distancia_pares_np(data), distancia_pares_jnp(jax_data)))
%timeit -r3 -n3 distancia_pares_jnp(data).block_until_ready()
%timeit -r3 -n3 distancia_pares_jnp(jax_data).block_until_ready()
<class 'jaxlib.xla_extension.DeviceArray'>
True
2.82 s ± 20.1 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
637 ms ± 2.1 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
El mismo resultado con un orden de magnitud de diferencia en términos de velocidad. Además, la función en JAX es idéntica a la original.
Es muy simple utilizar JAX si ya estamos utilizando NumPy
Nota
Por defecto JAX trabaja con flotantes de 32 bits (single precision) ya que está pensando en arquitecturas GPU y TPU. Se puede activar manualmente la precisión doble como se mostró en el ejemplo anterior.
Los arreglos en JAX se representan con objectos DeviceArray
. Este objeto puede vivir en memoria de CPU, GPU y TPU. JAX realizará los cómputos de acuerdo a donde están alojados los arreglos.
Podemos consultar donde está alojado un DeviceArray
con:
x = jnp.array([0., 1., 2.])
x.device_buffer.device()
CpuDevice(id=0)
Y podemos transferirlo a otros dispositivos disponibles con device_put
:
jax.device_put(x, device=jax.devices()[0])
DeviceArray([0., 1., 2.], dtype=float64)
jax.numpy
y DeviceArray
siguen de forma bastante directa a numpy
y ndarray
con algunas excepciones notables, como por ejemplo:
Advertencia
Los arreglos de tipo DeviceArray
son inmutables. Una vez creados no pueden modificarse.
Por ejemplo ejecutar
x = jnp.array([0., 1., 2.])
x[0] = 1.
Levantará una excepción de tipo TypeError
Compilación JIT¶
JAX permite compilación de tipo Just-in-time (JIT) en base a la tecnología XLA. Esto se logra con la transformación jax.jit
, que también puede invocarse como decorador @jit
.
Por ejemplo si “jitificamos” distancia_pares_jnp
:
distancia_pares_jit = jax.jit(distancia_pares_jnp)
print(np.allclose(distancia_pares_np(data), distancia_pares_jit(jax_data)))
%timeit -r3 -n3 distancia_pares_jit(jax_data).block_until_ready()
True
229 ms ± 17.2 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
Resultado equivalente pero tres veces más rápido que la versión de jax.numpy
Nota
La compilación ocurre al ejecutar por primera vez una función transformada con jax.jit
. Internamente JAX realiza un trazado de la función en base a su entrada para verificar si es pura y luego la lleva a un lenguaje intermedio llamado jaxpr. La compilación tiene un overhead importante pero ejecuciones subsecuentes serán mucho más veloces.
Sin embargo, existen ciertas restricciones para utilizar jit
.
Advertencia
jit
sólo puede aplicarse a funciones que “funcionalmente puras”. Una función pura:
Debe entregar siempre el mismo resultado si se entregan las mismas entradas.
No debe cambiar su comportamiento en base al valor de su entrada (prohibido los if/else a menos que sean estos)
No debe utilizar variables globales o fuera de su scope.
No debe tener side-effects (por ejemplo llamadas a I/O).
Se pueden marcar algunas variables como estáticas (no “jitificables”) con el argumento static_argnums
de jax.jit
Autodiferenciación¶
La siguiente transformación que revisaremos es jax.grad
. Esta transformación recibe una función y retorna otra función que corresponde a su gradiente. Algunas restricciones de la función original:
Debe recibir argumentos flotantes (single o double).
Debe retornar un valor escalar.
Veamos un ejemplo, y comparemos el resultado contra la deriviada analítica de la función:
def sigmoid(x):
return 1.0/(1.0 + jnp.exp(-x))
diff_sigmoid = jax.grad(sigmoid)
for x in jnp.linspace(-3., 3., 11):
print(f"x:{x:0.4f}\tdf(x)/dx:{diff_sigmoid(x):0.4f}\t f(x)(1-f(x)):{sigmoid(x)*(1.-sigmoid(x)):0.4f}")
x:-3.0000 df(x)/dx:0.0452 f(x)(1-f(x)):0.0452
x:-2.4000 df(x)/dx:0.0763 f(x)(1-f(x)):0.0763
x:-1.8000 df(x)/dx:0.1217 f(x)(1-f(x)):0.1217
x:-1.2000 df(x)/dx:0.1779 f(x)(1-f(x)):0.1779
x:-0.6000 df(x)/dx:0.2288 f(x)(1-f(x)):0.2288
x:0.0000 df(x)/dx:0.2500 f(x)(1-f(x)):0.2500
x:0.6000 df(x)/dx:0.2288 f(x)(1-f(x)):0.2288
x:1.2000 df(x)/dx:0.1779 f(x)(1-f(x)):0.1779
x:1.8000 df(x)/dx:0.1217 f(x)(1-f(x)):0.1217
x:2.4000 df(x)/dx:0.0763 f(x)(1-f(x)):0.0763
x:3.0000 df(x)/dx:0.0452 f(x)(1-f(x)):0.0452
JAX está diseñado para componer sus transformaciones. Por ejemplo es muy simple componer grad
con jit
:
jit_diff_sigmoid = jax.jit(diff_sigmoid)
print(np.allclose(diff_sigmoid(0.01), jit_diff_sigmoid(0.01)))
%timeit -r10 -n10 diff_sigmoid(0.01).block_until_ready()
%timeit -r10 -n10 jit_diff_sigmoid(0.01).block_until_ready()
True
4.21 ms ± 382 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
The slowest run took 6.81 times longer than the fastest. This could mean that an intermediate result is being cached.
8.66 µs ± 9.09 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
Podemos también componer grad
consigo mismo para obtener derivadas de mayor orden:
def poly(x):
return x**2 + 2*x + 3.
dpoly = jax.grad(poly)
d2poly = jax.grad(dpoly)
d3poly = jax.grad(d2poly)
for x in jnp.linspace(-3., 3., 11):
print(f"x:{x:0.4f}\tdf:{dpoly(x):0.4f}\t df2:{d2poly(x):0.4f}\t df3:{d3poly(x):0.4f}")
x:-3.0000 df:-4.0000 df2:2.0000 df3:0.0000
x:-2.4000 df:-2.8000 df2:2.0000 df3:0.0000
x:-1.8000 df:-1.6000 df2:2.0000 df3:0.0000
x:-1.2000 df:-0.4000 df2:2.0000 df3:0.0000
x:-0.6000 df:0.8000 df2:2.0000 df3:0.0000
x:0.0000 df:2.0000 df2:2.0000 df3:0.0000
x:0.6000 df:3.2000 df2:2.0000 df3:0.0000
x:1.2000 df:4.4000 df2:2.0000 df3:0.0000
x:1.8000 df:5.6000 df2:2.0000 df3:0.0000
x:2.4000 df:6.8000 df2:2.0000 df3:0.0000
x:3.0000 df:8.0000 df2:2.0000 df3:0.0000
Por defecto la diferenciación se hace contra el primer argumento de la función. Para funciones con múltiples argumentos podemos modificar este comportamiento con el argumento argnums
de jax.grad
.
def f(x, w):
return jnp.dot(x,w)
x = jnp.array([-2., 2.])
w = jnp.array([1., -1.])
dfdx = jax.grad(f, argnums=0)
print(dfdx(x, w))
dfdw = jax.grad(f, argnums=1)
print(dfdw(x, w))
[ 1. -1.]
[-2. 2.]
Nota
JAX también tiene transformaciones para calcular jacobianos y hessianos, lo cual vuelve bastante factible la utilización de algoritmos de segundo orden para optimizar.
Vectorización¶
La última transformación que revisaremos es jax.vmap
. Esta transformación recibe una función para valores escalares y retorna otra función que se puede aplicar sobre arreglos, es decir vectoriza automaticamente la función original.
Esta transformación se puede componer con grad
y con jit
:
def f(x):
return 1.0/(1.0 + jnp.exp(-x))
df = jax.grad(f) # Recibe y retorna escalares
vdf = jax.vmap(df) # Recibe y retorna arreglos
vdf_jit = jax.jit(vdf)
x = jnp.linspace(-5, 5, 1000)
print(np.allclose(vdf(x), vdf_jit(x)))
%timeit -r3 -n3 jnp.array([df(x_) for x_ in x])
%timeit -r3 -n3 vdf(x).block_until_ready()
%timeit -r3 -n3 vdf_jit(x).block_until_ready()
True
4.43 s ± 108 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
5.96 ms ± 241 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
The slowest run took 5.58 times longer than the fastest. This could mean that an intermediate result is being cached.
46.9 µs ± 39.1 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
import holoviews as hv
hv.extension('bokeh')
hv.opts.defaults(hv.opts.Curve(width=400))
hv.Curve((x, f(x)), label='f(x)') * hv.Curve((x, vdf_jit(x)), label='df/dx')
Típicamente utilizaremos esta transformación cuando tengamos una función que debe aplicarse sobre un batch de datos de forma independiente: Single Instruction Multiple Data (SIMD)
Digamos que tenemos un dataset \(x\in \mathbb{R}^{NxD}\) y para cada fila \(i\) necesitamos calcular:
con JAX esto sería:
def Dense(x, w):
return w[0] + jnp.sum(w[len(x):]*x)
N, D = 5, 3
data = jnp.ones(shape=(N, D))
param = jnp.ones(shape=(D+1))
vDense = jax.jit(jax.vmap(Dense, in_axes=[0, None]))
vDense(data, param)
DeviceArray([4., 4., 4., 4., 4.], dtype=float64)
Podemos especificar cual dimensión del tensor de entrada y salida queremos vectorizar con los argumentos in_axes
y out_axes
de jax.vmap
, respectivamente.
Generación de números aleatorios¶
Existen algunas diferencias importantes entre el generador de números pseudo-aleatorios (PRNG) de NumPy y de JAX. La primera es que las rutins de PRNG de JAX están en jax.random
en vez de jax.numpy.random
que sería lo esperable.
Lo segundo y más fundamental, es que JAX no mantiene el estado del PRNG de manera global. En su lugar, es el usuario el que debe mantener el estado.
Veamos que significa en la práctica:
np.random.seed(1234) # Configura el estado inicial de forma global
for i in range(3):
print(np.random.randn()) # Cada llamada actualiza el estado global
0.47143516373249306
-1.1909756947064645
1.4327069684260973
En JAX, las funciones de jax.random
esperan una llave:
import jax.random as random
key = random.PRNGKey(1234) # Estado inicial
for i in range(3):
print(random.normal(key)) # Ocupa la llave, pero la actualiza
0.4395758528880331
0.4395758528880331
0.4395758528880331
Si utilizamos la misma llave obtendremos el mismo resultado. Para no reutilizar una llave podemos utilizar random.split
:
import jax.random as random
key = random.PRNGKey(1234)
for i in range(3):
key, subkey = random.split(key) # Genera dos llaves nuevas
print(random.normal(subkey))
-1.7366326757754018
-0.74630100088392
-0.17782682174856315
random.split
recibe como argumento opcional num
la cantidad de llaves nuevas a generar
Esta “complicación extra” es fundamental para poder vectorizar eficientemente cómputos que involucren PRNG
Otros detalles sobre JAX¶
Las llamadas en JAX se realizan de forma asíncrona. Es por esto que para medir tiempos forzabamos sincronía con
.block_until_ready()
.jax.numpy
está construido en base al módulo de bajo niveljax.lax
, el cual es más estricto pero más eficiente.JAX soporta paralelismo, es decir cómputos que usan más de un nucleo de CPU o más de una GPU. La transformación para implementar paralelismo es
jax.pmap
.
Reviews y comparativas: