Aplicación: Visualización de espacio latente de estrellas variables

Datos

Existen muchos tipos de objeto astronómico cuyo brillo varía en el tiempo. Una herramienta fundamental para estudiar estos objetos son las curvas de luz.

Una curva de luz es una serie de tiempo del brillo aparente de un objeto astronómico.

El brillo aparente o flujo se estima a partir de las imágenes de telescopio utilizando técnicas de fotometría. Un mismo objeto es “seguido” por varias noches como muestra el siguiente esquema:

../_images/ZTF.png ../_images/intro-sources.png
../_images/intro-sources-time.png

La colección de mediciones para un mismo objeto se grafica en el tiempo como muestra el siguiente ejemplo:

import sys
import holoviews as hv
hv.extension('bokeh')

sys.path.append('../src/')
from preprocessing import load_ztf_data
from plotting import plot_light_curve

lcs, periods, labels = load_ztf_data()
plot_light_curve(lcs[0])

donde:

  • El eje horizontal corresponde al tiempo, que se expresa en días julianos.

  • El eje vertical corresponde a la magnitud, una medida relativa de brillo aparente (menor magnitud, más brillante).

  • Las barras de error corresponden al error fotométrico, un estimado de la calidad de la medición.

  • El color representa al filtro óptico utilizado. En este caso se utilizaron dos filtros.

De la figura también podemos apreciar las características típicas de una curva de luz:

  • Irregularmente muestreadas: El tiempo entre observaciones no es constante.

  • Multivariadas: Una serie de tiempo por filtro, observadas de forma no simultanea.

  • Heteroscedástica: La varianza del ruido cambia en el tiempo:.

Estrellas variables periódicas

Las curvas de luz son particularmente útiles para estudiar las estrellas variables.

Las estrellas variables son estrellas cuyo brillo varía en el tiempo ya sea de forma regular o estocástica.

Un ejemplo de variabilidad regular son las estrellas pulsantes como las RR Lyrae y las Cefeidas.

from IPython.display import YouTubeVideo
YouTubeVideo('sXJBrRmHPj8')

Estas estrellas pulsan radialmente y de forma regular, se expanden (calientan) y contraen (enfrían) con un periódo estable.

Si conocemos el período \(P\) de la estrella se puede utilizar la siguiente transformación (epoch folding):

\[ \phi_i = \mod(t_i, P)/P \]

para convertir el eje de tiempo en un eje de fase, como muestra el siguiente esquema:

../_images/folding.png

Por ejemplo, la curva de luz que se mostró anteriormente tiene un periodo de:

periods[0]
0.6692088

dias, y si lo utilizamos para “doblar” tenemos:

data_plot = plot_light_curve(lcs[0], periods[0])
data_plot

La curva “doblada” representa de forma clara lo que ocurren durante una fase de la periodicidad de la estrella

Para entrenar un modelo sobre las curvas dobladas primero las interpolaremos a una grilla regular mediante suavizado por kernel.

El resultado del suavizado para la curva anterior se muestra a continuación:

import numpy as np
from preprocessing import kernel_smoothing
from plotting import plot_smoothed

pha_interp = np.linspace(0, 1, num=40)
mag_interp, err_interp, _ = kernel_smoothing(lcs[0], periods[0], pha_interp, align=False, normalize=False)
data_plot * plot_smoothed(pha_interp, mag_interp, err_interp)

Modelo

En este ejemplo práctico utilizaremos un AutoEncoder Variacional (Variational Autoencoder, VAE) (Kingma et al., 2014) para reducir la dimensionalidad de un dataset de curvas de luz de estrellas variables.

El siguiente es un diagrama del modelo:

../_images/VAE.png

donde

  • \(x\) se refiere a los datos observados, en este caso las curvas de luz.

  • \(z\) se refiere a la variable latente de dimensión reducida que queremos inferir.

  • \(g_\phi\) es una red neuronal artificial que llamaremos Codificador.

  • \(f_\theta\) es una red neuronal artificial que llamaremos Decodificador.

Un VAE es un modelo probabilístico generativo donde se consideran los siguientes supuestos:

  • \(p(z) = \mathcal{N}(0, I)\), es decir la distribución a priori de \(z\) es normal estándar.

  • \(p(x|z) = \mathcal{N}(\hat \mu, \hat \sigma^2)\), una verosimilitud normal para \(x\).

  • Se utiliza el decodificador para modelar \(\hat \mu = f_\theta(z)\).

Lo que buscamos es inferir \(z\) a partir de \(x\), es decir el posterior de \(z\):

\[ p(z|x) = \frac{p(x|z) p(z)}{\int p(x|z) p(z) dz}. \]

Como lo anterior es muy difícil de calcular lo reemplazamos por una aproximación variacional. En este caso el posterior variacional es una distribución normal multivariada con covarianza diagonal (sin correlaciones):

\[\begin{split} \begin{split} p(z|x) \approx q_\phi(z|x) &= \mathcal{N}(\mu(x), \sigma(x)^2) \\ &= \mu(x) + \sigma(x) \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \end{split} \end{split}\]

donde el codificador se utiliza para modelar \(\mu_i, \sigma_i = g_\phi(x_i)\) de forma amortizada. Además se utiliza el truco de reparametrización (segunda linea de la ecuación).

Implementación

Implementaremos un VAE para curvas de luz de estrellas periódicas con dos bandas utilizando en flax, considerando lo siguiente:

  • El codificador procesa cada banda por separado y luego las combina en un único espacio latente.

  • El codificador retorna la media y la desviación estándar de la variable latente.

  • La desviación estándar debe ser no-negativa.

  • El decodificador recibe la variable latente y genera las curvas de cada banda.

from typing import Sequence, Callable
import jax
import jax.numpy as jnp
import flax.linen as nn

class Encoder(nn.Module):
    hidden_units: int
    latent_dim: int
    activation: Callable = nn.relu
    
    @nn.compact 
    def __call__(self, x):
        g_0 = self.activation(nn.Dense(self.hidden_units)(x[:, 0, :])) # g-band
        g_1 = self.activation(nn.Dense(self.hidden_units)(x[:, 1, :])) # r-band
        g = self.activation(nn.Dense(self.hidden_units*2)(jnp.concatenate([g_0, g_1], axis=1)))
        g = self.activation(nn.Dense(self.hidden_units)(g))
        z_loc = nn.Dense(self.latent_dim)(g)
        z_scale = nn.softplus(nn.Dense(self.latent_dim)(g))
        return z_loc, z_scale
    
class Decoder(nn.Module):
    output_dim: int
    hidden_units: int
    activation: Callable = nn.relu
        
    @nn.compact
    def __call__(self, z):
        
        f = self.activation(nn.Dense(self.hidden_units)(z))
        f = self.activation(nn.Dense(self.hidden_units*2)(f))
        f_loc0 = self.activation(nn.Dense(self.output_dim)(f))
        f_loc1 = self.activation(nn.Dense(self.output_dim)(f))
        x_loc0 = nn.Dense(self.output_dim)(f_loc0) # g-band
        x_loc1 = nn.Dense(self.output_dim)(f_loc1) # r-band
        return x_loc0, x_loc1

Por conveniencia implementaremos también un módulo que llame a los anteriores y realice el truco re-parametrización del espacio latente:

class VAE(nn.Module):
    hidden_units: int
    latent_dim: int
    output_dim: int
    
    @nn.compact    
    def __call__(self, x, z_rng_key):
        z_loc, z_scale = Encoder(self.hidden_units, self.latent_dim)(x)
        eps = random.normal(z_rng_key, z_scale.shape)
        z = z_loc + eps*z_scale
        x_loc0, x_loc1 = Decoder(self.output_dim, self.hidden_units)(z)
        return x_loc0, x_loc1, z_loc, z_scale

Preparación de datos

A continuación se preparan los datos para entrenar el modelo.

  • El modelo se entrena sobre curvas en fase (dobladas) que han sido interpoladas a una grilla regular mediante suavizado con kernels.

  • Se aplica un reescalamiento de tipo MinMax a las curvas interpoladas.

  • Las curvas en fase se alinean para partir en el mínimo de magnitud (máximo brillo)

import sys
import numpy as np
from sklearn.preprocessing import LabelEncoder  

le = LabelEncoder()
labels_int = le.fit_transform(labels)

pha_interp = np.linspace(0, 1, num=40)
mag_interp = np.zeros(shape=(len(lcs), 2, len(pha_interp)))
err_interp = np.zeros(shape=(len(lcs), 2, len(pha_interp)))

for k, (lc, period) in enumerate(zip(lcs, periods)):
    mag_interp[k], err_interp[k], _ = kernel_smoothing(lc, period, pha_interp)

mags_interp_jax = jnp.array(mag_interp)
errs_interp_jax = jnp.array(err_interp)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Entrenamiento

Para entrenar el modelo utilizaremos el algoritmo de optimización AMSgrad, que es una extensión de gradiente descedente con tasa de aprendizaje adaptiva (Adam).

A continuación se muestra la función chain de la librería optax la cual permite implementar optimizadores customizados. En este caso el optimizador primero satura las gradientes con norma mayor a 10.0, luego aplica escala los gradients con las reglas adaptivas de AMSgrad y finalmente multiplica por la tasa de aprendizaje inicial:

import optax

optimizer = optax.chain(optax.clip_by_global_norm(10.0),  
                        optax.scale_by_amsgrad(),
                        optax.scale(-1e-3))

Luego implementaremos la función de costo de VAE, el Evidence Lower Bound (ELBO) de la aproximación variacional. Matemáticamente esto se define como

\[ \mathcal{L}(\theta, \phi) = \mathbb{E}_{z\sim q_\phi(z|x)} \left [\log p_\theta(x|z) \right ] - D_{KL}\left[ q_\phi(z|x) || p(z) \right] \]

donde, bajo los supuestos considerados, el término de la mano derecha tiene la siguiente solución analítica:

\[ D_\text{KL}\left[q_\phi(z|x) || p(z) \right] = \frac{1}{2}\sum_{j=1}^K \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right) \]

Para optimizar el VAE se busca maximizar el ELBO en función de los parámetros del codificador y decodificador.

Utilizaremos compilación JIT para acelerar el cálculo de los gradientes del ELBO

def negELBO(params, mag, err, key):
    x_loc0, x_loc1, z_loc, z_scale = model.apply(params, mag, key)
    d0 = 0.5*jnp.sum(jnp.square((x_loc0 - mag[:, 0, :])/err[:, 0, :]), axis=-1)
    d1 = 0.5*jnp.sum(jnp.square((x_loc1 - mag[:, 1, :])/err[:, 1, :]), axis=-1)
    kl_div = 0.5 * jnp.sum(-1 - 2*jnp.log(z_scale) + jnp.square(z_loc) + jnp.square(z_scale), axis=-1)
    return jnp.sum(d0 + d1 + kl_div)

grad_loss_jit = jax.jit(jax.value_and_grad(negELBO, argnums=0))

Finalmente inicializamos el modelo y el optimizador y lanzamos la rutina de entrenamiento.

Se entrena por 300 épocas con minibatches de tamaño 32:

import jax.random as random
from tqdm import tqdm
from train_utils import data_loader

key = random.PRNGKey(12345)
model = VAE(output_dim=40, hidden_units=100, latent_dim=2)
key, key_ = random.split(key)
params = model.init(key, jnp.zeros(shape=(1, 2, 40)), key_)
state = optimizer.init(params)

loss_history = []
for epoch in tqdm(range(300)):
    loss_epoch = 0.0
    key, key_ = random.split(key)
    for bmag, berr in data_loader(key_, mags_interp_jax, errs_interp_jax, 
                                  batch_size=32, shuffle=True):
        key, key_ = random.split(key)
        loss_val, grads = grad_loss_jit(params, bmag, berr, key_) 
        loss_epoch += loss_val.item()
        updates, state = optimizer.update(grads, state, params) 
        params = optax.apply_updates(params, updates)
    loss_history.append(loss_epoch/len(lcs))

hv.Curve(loss_history, 'Epoch', 'negative ELBO').opts(width=500, logy=True)
100%|████████| 300/300 [18:57<00:00,  3.79s/it]

Evaluación y visualizaciones

Primero utilizamos el modelo para inferir las reconstrucciones y variables latentes del dataset completo:

key, key_ = random.split(key)
x_loc0, x_loc1, z_loc, z_scale = model.apply(params, mag_interp, key_)

A continuación se muestra ejemplos de distintos tipos de estrella variable. Las lineas corresponden a las reconstrucciones del modelo y las datos de entrada (curvas de luz interpoladas).

from plotting import plot_reconstruction, plot_latent_space, plot_latent_generation

hv.Layout([plot_reconstruction(x_loc0[idx], x_loc1[idx], pha_interp, 
                               mag_interp[idx], err_interp[idx], labels[idx]) for idx in [0, 2000, 3550, 3385]]).cols(2)

Luego se visualiza el espacio latente bidimensional. Cada color representa un tipo de estrella variable.

Podemos notar que las clases tienden a separarse en el espacio latente. Notar que el entrenamiento fue no supervisado, la clase no se utilizó para ajustar el modelo.

plot_latent_space(z_loc, z_scale, labels_int, le)

VAE es un modelo generativo. Podemos aprovechar esta capacidad para interpolar en el espacio latente y visualizar las formas de curva de luz que el modelo aprendió a reconstruir.

from functools import partial

z1 = jnp.linspace(-1, 1, 9)
z2 = jnp.linspace(-2, 2, 9)

decoder = Decoder(model.output_dim, model.hidden_units)
generate_lc = partial(jax.jit(decoder.apply), {'params': params['params']['Decoder_0']})

plot_latent_generation(z1, z2, pha_interp, generate_lc)

Comentarios:

  • La reducción no-lineal de dimensionalidad nos permite hacer visualizaciones y explorar un dataset cuando no se tienen etiquetas. También puede utilizarse en estrategias de aprendizaje activo (active learning) para etiquetar un dataset de forma eficiente.

  • El espacio latente puede utilizarse también como entrada a un clasificador si se cuenta con algunas etiquetas (semi-supervisado).

  • La capacidad generativa del modelo puede aprovecharse para hacer aumentación de datos y también para interpretar lo que ha aprendido el modelo durante el entrenamiento.

¿Te interesa la astroinformática?

Visita el sitio web del proyecto ALeRCE:

../_images/alerce.png