Introducción: Transformaciones básicas de JAX

JAX es un framework reciente enfocado en computación científica cuyos componentes principales son:

  • Autograd: Un sistema para obtener derivadas de todo orden.

  • Accelerated Linear Algebra (XLA): Un compilador enfocado en álgebra lineal.

Ambos aplican para rutinas escritas en Python nátivo o basadas en la librería NumPy. En conjunto permiten diferenciar, vectorizar, paralelizar y compilar a coprocesador GPU/TPU rutinas matemáticas.

JAX es ideal para Machine Learning.

Veremos algunas de sus bondades a continuación.

Instalación

JAX soporta oficialmente Linux y macOS. Hay soporte limitado y no oficial para Windows, pero requiere compilación manual de las fuentes.

Nota

Si tienes SO Windows puedes experimentar con JAX en google colab. En lo que sigue se asume una instalación local de JAX.

Para instalar JAX lo más simple es utilizar el manejador de paquetas conda. Para instalar JAX con soporte GPU, crea un ambiente de conda, actívalo y ejecuta el siguiente comando:

conda install jax cuda-nvcc -c conda-forge -c nvidia

Si sólo vas a usar arquitecturas CPU puedes hacer una instalación más ligera con:

conda install jax -c conda-forge

JAX también puede instalarse con pip (distintos sabores):

pip install "jax[cpu]"
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Una vez instalado podemos verificar los dispositivos compatibles para hacer cómputo con:

import jax
jax.devices()
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

que en este caso indica que sólo está disponible un dispositivo CPU.

JAX arrays y jax.numpy

Consideremos un ndarray con 10000 filas y 2 columnas y una función basada en la API de NumPy que calcula la distancia euclidiana entre todos los pares de filas:

import numpy as np

data = np.random.randn(10000, 2)

def distancia_pares_np(data):
    return np.sqrt(np.sum((data.reshape(-1, 1, 2) - data.reshape(1, -1, 2))**2, axis=-1))

print(type(data))

%timeit -r3 -n1 distancia_pares_np(data)
<class 'numpy.ndarray'>
6.86 s ± 181 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)

Para acelerar esta función con JAX haremos dos cambios:

  • Reemplazar los llamados a numpy por llamados a jax.numpy

  • Expresar los datos de entrada como un arreglo de JAX (DeviceArray)

El segundo es en realidad opcional (pero necesario si queremos usar JIT)

import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def distancia_pares_jnp(data):
    return jnp.sqrt(jnp.sum((data.reshape(-1, 1, 2) - data.reshape(1, -1, 2))**2, axis=-1))

jax_data = jnp.array(data, dtype=jnp.float64)
print(type(jax_data))

print(np.allclose(distancia_pares_np(data), distancia_pares_jnp(jax_data)))
%timeit -r3 -n3 distancia_pares_jnp(data).block_until_ready()
%timeit -r3 -n3 distancia_pares_jnp(jax_data).block_until_ready()
<class 'jaxlib.xla_extension.DeviceArray'>
True
2.82 s ± 20.1 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
637 ms ± 2.1 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)

El mismo resultado con un orden de magnitud de diferencia en términos de velocidad. Además, la función en JAX es idéntica a la original.

Es muy simple utilizar JAX si ya estamos utilizando NumPy

Nota

Por defecto JAX trabaja con flotantes de 32 bits (single precision) ya que está pensando en arquitecturas GPU y TPU. Se puede activar manualmente la precisión doble como se mostró en el ejemplo anterior.

Los arreglos en JAX se representan con objectos DeviceArray. Este objeto puede vivir en memoria de CPU, GPU y TPU. JAX realizará los cómputos de acuerdo a donde están alojados los arreglos.

Podemos consultar donde está alojado un DeviceArray con:

x = jnp.array([0., 1., 2.])
x.device_buffer.device()
CpuDevice(id=0)

Y podemos transferirlo a otros dispositivos disponibles con device_put:

jax.device_put(x, device=jax.devices()[0])
DeviceArray([0., 1., 2.], dtype=float64)

jax.numpy y DeviceArray siguen de forma bastante directa a numpy y ndarray con algunas excepciones notables, como por ejemplo:

Advertencia

Los arreglos de tipo DeviceArray son inmutables. Una vez creados no pueden modificarse.

Por ejemplo ejecutar

x = jnp.array([0., 1., 2.])
x[0] = 1.

Levantará una excepción de tipo TypeError

Compilación JIT

JAX permite compilación de tipo Just-in-time (JIT) en base a la tecnología XLA. Esto se logra con la transformación jax.jit, que también puede invocarse como decorador @jit.

Por ejemplo si “jitificamos” distancia_pares_jnp:

distancia_pares_jit = jax.jit(distancia_pares_jnp)

print(np.allclose(distancia_pares_np(data), distancia_pares_jit(jax_data)))
%timeit -r3 -n3 distancia_pares_jit(jax_data).block_until_ready()
True
229 ms ± 17.2 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)

Resultado equivalente pero tres veces más rápido que la versión de jax.numpy

Nota

La compilación ocurre al ejecutar por primera vez una función transformada con jax.jit. Internamente JAX realiza un trazado de la función en base a su entrada para verificar si es pura y luego la lleva a un lenguaje intermedio llamado jaxpr. La compilación tiene un overhead importante pero ejecuciones subsecuentes serán mucho más veloces.

Sin embargo, existen ciertas restricciones para utilizar jit.

Advertencia

jit sólo puede aplicarse a funciones que “funcionalmente puras”. Una función pura:

  • Debe entregar siempre el mismo resultado si se entregan las mismas entradas.

  • No debe cambiar su comportamiento en base al valor de su entrada (prohibido los if/else a menos que sean estos)

  • No debe utilizar variables globales o fuera de su scope.

  • No debe tener side-effects (por ejemplo llamadas a I/O).

Se pueden marcar algunas variables como estáticas (no “jitificables”) con el argumento static_argnums de jax.jit

Autodiferenciación

La siguiente transformación que revisaremos es jax.grad. Esta transformación recibe una función y retorna otra función que corresponde a su gradiente. Algunas restricciones de la función original:

  • Debe recibir argumentos flotantes (single o double).

  • Debe retornar un valor escalar.

Veamos un ejemplo, y comparemos el resultado contra la deriviada analítica de la función:

def sigmoid(x): 
    return 1.0/(1.0 + jnp.exp(-x))

diff_sigmoid = jax.grad(sigmoid)

for x in jnp.linspace(-3., 3., 11):
    print(f"x:{x:0.4f}\tdf(x)/dx:{diff_sigmoid(x):0.4f}\t f(x)(1-f(x)):{sigmoid(x)*(1.-sigmoid(x)):0.4f}")
x:-3.0000	df(x)/dx:0.0452	 f(x)(1-f(x)):0.0452
x:-2.4000	df(x)/dx:0.0763	 f(x)(1-f(x)):0.0763
x:-1.8000	df(x)/dx:0.1217	 f(x)(1-f(x)):0.1217
x:-1.2000	df(x)/dx:0.1779	 f(x)(1-f(x)):0.1779
x:-0.6000	df(x)/dx:0.2288	 f(x)(1-f(x)):0.2288
x:0.0000	df(x)/dx:0.2500	 f(x)(1-f(x)):0.2500
x:0.6000	df(x)/dx:0.2288	 f(x)(1-f(x)):0.2288
x:1.2000	df(x)/dx:0.1779	 f(x)(1-f(x)):0.1779
x:1.8000	df(x)/dx:0.1217	 f(x)(1-f(x)):0.1217
x:2.4000	df(x)/dx:0.0763	 f(x)(1-f(x)):0.0763
x:3.0000	df(x)/dx:0.0452	 f(x)(1-f(x)):0.0452

JAX está diseñado para componer sus transformaciones. Por ejemplo es muy simple componer grad con jit:

jit_diff_sigmoid = jax.jit(diff_sigmoid)
print(np.allclose(diff_sigmoid(0.01), jit_diff_sigmoid(0.01)))

%timeit -r10 -n10 diff_sigmoid(0.01).block_until_ready()
%timeit -r10 -n10 jit_diff_sigmoid(0.01).block_until_ready()
True
4.21 ms ± 382 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
The slowest run took 6.81 times longer than the fastest. This could mean that an intermediate result is being cached.
8.66 µs ± 9.09 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

Podemos también componer grad consigo mismo para obtener derivadas de mayor orden:

def poly(x): 
    return x**2 + 2*x + 3.

dpoly = jax.grad(poly)
d2poly = jax.grad(dpoly)
d3poly = jax.grad(d2poly)

for x in jnp.linspace(-3., 3., 11):
    print(f"x:{x:0.4f}\tdf:{dpoly(x):0.4f}\t df2:{d2poly(x):0.4f}\t df3:{d3poly(x):0.4f}")
x:-3.0000	df:-4.0000	 df2:2.0000	 df3:0.0000
x:-2.4000	df:-2.8000	 df2:2.0000	 df3:0.0000
x:-1.8000	df:-1.6000	 df2:2.0000	 df3:0.0000
x:-1.2000	df:-0.4000	 df2:2.0000	 df3:0.0000
x:-0.6000	df:0.8000	 df2:2.0000	 df3:0.0000
x:0.0000	df:2.0000	 df2:2.0000	 df3:0.0000
x:0.6000	df:3.2000	 df2:2.0000	 df3:0.0000
x:1.2000	df:4.4000	 df2:2.0000	 df3:0.0000
x:1.8000	df:5.6000	 df2:2.0000	 df3:0.0000
x:2.4000	df:6.8000	 df2:2.0000	 df3:0.0000
x:3.0000	df:8.0000	 df2:2.0000	 df3:0.0000

Por defecto la diferenciación se hace contra el primer argumento de la función. Para funciones con múltiples argumentos podemos modificar este comportamiento con el argumento argnums de jax.grad.

def f(x, w): 
    return jnp.dot(x,w)

x = jnp.array([-2., 2.])
w = jnp.array([1., -1.])

dfdx = jax.grad(f, argnums=0)
print(dfdx(x, w))

dfdw = jax.grad(f, argnums=1)
print(dfdw(x, w))
[ 1. -1.]
[-2.  2.]

Nota

JAX también tiene transformaciones para calcular jacobianos y hessianos, lo cual vuelve bastante factible la utilización de algoritmos de segundo orden para optimizar.

Vectorización

La última transformación que revisaremos es jax.vmap. Esta transformación recibe una función para valores escalares y retorna otra función que se puede aplicar sobre arreglos, es decir vectoriza automaticamente la función original.

Esta transformación se puede componer con grad y con jit:

def f(x): 
    return 1.0/(1.0 + jnp.exp(-x))

df = jax.grad(f) # Recibe y retorna escalares
vdf = jax.vmap(df) # Recibe y retorna arreglos
vdf_jit = jax.jit(vdf) 

x = jnp.linspace(-5, 5, 1000)
print(np.allclose(vdf(x), vdf_jit(x)))
%timeit -r3 -n3 jnp.array([df(x_) for x_ in x])
%timeit -r3 -n3 vdf(x).block_until_ready()
%timeit -r3 -n3 vdf_jit(x).block_until_ready()
True
4.43 s ± 108 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
5.96 ms ± 241 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
The slowest run took 5.58 times longer than the fastest. This could mean that an intermediate result is being cached.
46.9 µs ± 39.1 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
import holoviews as hv
hv.extension('bokeh')
hv.opts.defaults(hv.opts.Curve(width=400))

hv.Curve((x, f(x)), label='f(x)') * hv.Curve((x, vdf_jit(x)), label='df/dx')

Típicamente utilizaremos esta transformación cuando tengamos una función que debe aplicarse sobre un batch de datos de forma independiente: Single Instruction Multiple Data (SIMD)

Digamos que tenemos un dataset \(x\in \mathbb{R}^{NxD}\) y para cada fila \(i\) necesitamos calcular:

\[ y_i = w_0 + \sum_{j=1}^D w_j x_{ij} \]

con JAX esto sería:

def Dense(x, w): 
    return w[0] + jnp.sum(w[len(x):]*x)

N, D = 5, 3
data = jnp.ones(shape=(N, D))
param = jnp.ones(shape=(D+1)) 

vDense = jax.jit(jax.vmap(Dense, in_axes=[0, None]))

vDense(data, param)
DeviceArray([4., 4., 4., 4., 4.], dtype=float64)

Podemos especificar cual dimensión del tensor de entrada y salida queremos vectorizar con los argumentos in_axes y out_axes de jax.vmap, respectivamente.

Generación de números aleatorios

Existen algunas diferencias importantes entre el generador de números pseudo-aleatorios (PRNG) de NumPy y de JAX. La primera es que las rutins de PRNG de JAX están en jax.random en vez de jax.numpy.random que sería lo esperable.

Lo segundo y más fundamental, es que JAX no mantiene el estado del PRNG de manera global. En su lugar, es el usuario el que debe mantener el estado.

Veamos que significa en la práctica:

np.random.seed(1234) # Configura el estado inicial de forma global
for i in range(3):
    print(np.random.randn()) # Cada llamada actualiza el estado global
0.47143516373249306
-1.1909756947064645
1.4327069684260973

En JAX, las funciones de jax.random esperan una llave:

import jax.random as random

key = random.PRNGKey(1234) # Estado inicial
for i in range(3):
    print(random.normal(key)) # Ocupa la llave, pero la actualiza
0.4395758528880331
0.4395758528880331
0.4395758528880331

Si utilizamos la misma llave obtendremos el mismo resultado. Para no reutilizar una llave podemos utilizar random.split:

import jax.random as random

key = random.PRNGKey(1234)
for i in range(3):
    key, subkey = random.split(key) # Genera dos llaves nuevas
    print(random.normal(subkey)) 
-1.7366326757754018
-0.74630100088392
-0.17782682174856315

random.split recibe como argumento opcional num la cantidad de llaves nuevas a generar

Esta “complicación extra” es fundamental para poder vectorizar eficientemente cómputos que involucren PRNG

Otros detalles sobre JAX

  • Las llamadas en JAX se realizan de forma asíncrona. Es por esto que para medir tiempos forzabamos sincronía con .block_until_ready().

  • jax.numpy está construido en base al módulo de bajo nivel jax.lax, el cual es más estricto pero más eficiente.

  • JAX soporta paralelismo, es decir cómputos que usan más de un nucleo de CPU o más de una GPU. La transformación para implementar paralelismo es jax.pmap.

Reviews y comparativas: