17. Aumentación de datos con torchvision#

Si tenemos un dataset de imágenes muy pequeño nuestro modelo podría sobreajustarse. Podemos intentar incrementar nuestro dataset usando transformaciones.

Si rotamos, trasladamos o cambiamos el brillo de una imagen obtendremos una nueva imagen “casi siempre” de la misma clase. La librería torchvision tiene funciones implementadas para hacer transformaciones en imágenes:

Cada transformación permite especificar límites, por ejemplo “máximo ángulo de rotación”, “máxima distorsión de brillo”, etc.

Nota

Las transformaciones también sirven para hacer que la red gane “invarianzas”. Por ejemplo, si entrenamos con copias rotadas de nuestras imágenes, la red se volverá invariante a la rotación.

Importante

Las transformaciones que apliquemos no deben cambiar la interpretación de clase. Ejemplos:

  • Si rotas un seis en 180 grados se convierte en un nueve

  • Si cambias demasiado el tono (hue) podrías obtener colores distintos a la realidad (¿perro verde?)

17.1. Transformaciones aleatorias#

La mayoría de las transformaciones están diseñadas para aplicarse sobre imágenes en formato PIL.

Podemos componer varias transformaciones usando torchvision.transforms.Compose.

%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from torchvision import transforms
from PIL import Image

img = Image.open("img/dog.jpg")

my_transform = transforms.Compose([transforms.Resize(200),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomRotation(degrees=30),
                                   transforms.ColorJitter(brightness=0.5, contrast=0.5,
                                                          saturation=0.5, hue=0.0),
                                  ])

display(transforms.Resize(200)(img), 
        my_transform(img))
../../_images/19fa1e237971d3af977c6417a2439e3f2cccb5a5f64f00cb178c9bc9219bb095.png ../../_images/4ee13721cfbde5d6af4fad5bb03a7162a0ce6d58c50caffdff65af9938cf2c72.png

17.2. Entrenando con datos aumentados#

Podemos componer una transformación y añadirla a un Dataset. Luego cuando usamos el DataLoader se generarán imágenes con transformaciones aleatorias de forma automática.

Importante

¡Sólo se aumenta el conjunto de entrenamiento! Los conjuntos de validación y prueba no deben tener aumentación sintética.

A modo de ejemplos aplicaremos una composición de transformaciones al dataset de dígitos manuscritos MNIST. Primero creamos la secuencia de transformaciones para el conjunto de entrenamiento:

mnist_transform = transforms.Compose([transforms.RandomAffine(degrees=30, translate=(0.2, 0.2), 
                                                              scale=(0.5, 1.5), shear=None, 
                                                              interpolation=0, fill=0),
                                      transforms.ColorJitter(brightness=0.5, contrast=0.5, 
                                                             saturation=0.5, hue=0.0),
                                      transforms.ToTensor()
                                     ])

Luego cargamos los datos de entrenamiento y le aplicamos la composición utilizando el argumento transform

from torchvision.datasets import MNIST

mnist_train_data = MNIST(root='~/datasets', train=True, download=True, 
                         transform=mnist_transform)

Si creamos un dataloader a partir de este dataset entonces se aplicarán las transformaciones de forma aleatoria a cada minibatch

from torch.utils.data import DataLoader
train_loader = DataLoader(mnist_train_data, shuffle=False, batch_size=32)

for image, label in train_loader:
    break # El primer minibatch

fig, ax = plt.subplots(4, 8, figsize=(7, 4), tight_layout=True)
for k in range(32):
    i, j = np.unravel_index(k, (4, 8))
    ax[i, j].axis('off')
    ax[i, j].set_title(label[k].numpy())
    ax[i, j].imshow(image[k].numpy()[0, :, :], cmap=plt.cm.Greys_r)
../../_images/c2b0e8dd5919376d62c5e7db420afc1e56d2bf4337e1cec9ffc708ad88f3a22e.png

Por ejemplo si volvemos a pedir el primer minibatch obtenemos

for image, label in train_loader:
    break # El primer minibatch

fig, ax = plt.subplots(4, 8, figsize=(7, 4), tight_layout=True)
for k in range(32):
    i, j = np.unravel_index(k, (4, 8))
    ax[i, j].axis('off')
    ax[i, j].set_title(label[k].numpy())
    ax[i, j].imshow(image[k].numpy()[0, :, :], cmap=plt.cm.Greys_r)
../../_images/916513d12a357e6adc6288d1d9c9369d7c311324d93357875df604e2b982d382.png

Notar que en este caso shuffle=False por lo que los ejemplos originales son los mismos. Sin embargo, las imágenes se ven distintas porque las transformaciones son aleatorias.

Importante

El dataset de entrenamiento es ligeramente distinto en cada época. Es como si tuvieramos un dataset virtualmente más grande y más variado.