paint-brush
Cómo muestrear desde el espacio latente con el codificador automático variacionalpor@owlgrey
7,717 lecturas
7,717 lecturas

Cómo muestrear desde el espacio latente con el codificador automático variacional

por Dmitrii Matveichev 17m2024/02/29
Read on Terminal Reader

Demasiado Largo; Para Leer

A diferencia de los modelos AE tradicionales, los codificadores automáticos variacionales (VAE) asignan entradas a una distribución normal multivariada, lo que permite la generación de datos novedosos a través de varios métodos de muestreo. Los métodos de muestreo tratados en este artículo son muestreo posterior, muestreo previo, interpolación entre dos vectores y recorrido de dimensión latente.
featured image - Cómo muestrear desde el espacio latente con el codificador automático variacional
Dmitrii Matveichev  HackerNoon profile picture

Al igual que los codificadores automáticos tradicionales, la arquitectura VAE tiene dos partes: un codificador y un decodificador. Los modelos AE tradicionales asignan entradas a un vector de espacio latente y reconstruyen la salida de este vector.


VAE asigna entradas a una distribución normal multivariada (el codificador genera la media y la varianza de cada dimensión latente).


Dado que el codificador VAE produce una distribución, los nuevos datos se pueden generar tomando muestras de esta distribución y pasando el vector latente muestreado al decodificador. El muestreo de la distribución producida para generar imágenes de salida significa que VAE permite generar datos novedosos que son similares, pero idénticos, a los datos de entrada.


Este artículo explora los componentes de la arquitectura VAE y proporciona varias formas de generar nuevas imágenes (muestreo) con modelos VAE. Todo el código está disponible en Google Colab .

1 Implementación del modelo VAE


El modelo AE se entrena minimizando la pérdida de reconstrucción (por ejemplo, BCE o MSE)


Los codificadores automáticos y los codificadores automáticos variacionales tienen dos partes: codificador y decodificador. La red neuronal codificadora de AE aprende a mapear cada imagen en un único vector en el espacio latente y el decodificador aprende a reconstruir la imagen original a partir del vector latente codificado.


El modelo VAE se entrena minimizando la pérdida de reconstrucción y la pérdida de divergencia KL


La red neuronal codificadora de VAE genera parámetros que definen una distribución de probabilidad para cada dimensión del espacio latente (distribución multivariada). Para cada entrada, el codificador produce una media y una varianza para cada dimensión del espacio latente.


La media y la varianza de salida se utilizan para definir una distribución gaussiana multivariada. La red neuronal del decodificador es la misma que en los modelos AE.

1.1 Pérdidas VAE

El objetivo de entrenar un modelo VAE es maximizar la probabilidad de generar imágenes reales a partir de vectores latentes proporcionados. Durante el entrenamiento, el modelo VAE minimiza dos pérdidas:


  • Pérdida de reconstrucción : la diferencia entre las imágenes de entrada y la salida del decodificador.


  • Pérdida de divergencia de Kullback-Leibler (Divergencia KL, una distancia estadística entre dos distribuciones de probabilidad): la distancia entre la distribución de probabilidad de la salida del codificador y una distribución previa (una distribución normal estándar), que ayuda a regularizar el espacio latente.

1.2 Pérdida de reconstrucción

Las pérdidas de reconstrucción comunes son la entropía cruzada binaria (BCE) y el error cuadrático medio (MSE). En este artículo, utilizaré el conjunto de datos MNIST para la demostración. Las imágenes MNIST tienen un solo canal y los píxeles toman valores entre 0 y 1.


En este caso, la pérdida BCE se puede utilizar como pérdida de reconstrucción para tratar los píxeles de imágenes MNIST como una variable aleatoria binaria que sigue la distribución de Bernoulli.

 reconstruction_loss = nn.BCELoss(reduction='sum')

1.3 Divergencia Kullback-Leibler

Como se mencionó anteriormente, la divergencia KL evalúa la diferencia entre dos distribuciones. Tenga en cuenta que no tiene una propiedad simétrica de distancia: KL(P‖Q)!=KL(Q‖P).


Las dos distribuciones que deben compararse son:

  • el espacio latente de la salida del codificador dadas las imágenes de entrada x: q(z|x)


  • espacio latente anterior p(z) que se supone que es una distribución normal con una media de cero y una desviación estándar de uno en cada dimensión del espacio latente N(0, I ) .


    Tal suposición simplifica el cálculo de la divergencia de KL y fomenta que el espacio latente siga una distribución conocida y manejable.

 from torch.distributions.kl import kl_divergence def kl_divergence_loss(z_dist): return kl_divergence(z_dist, Normal(torch.zeros_like(z_dist.mean), torch.ones_like(z_dist.stddev)) ).sum(-1).sum()

1.4 Codificador

 class Encoder(nn.Module): def __init__(self, im_chan=1, output_chan=32, hidden_dim=16): super(Encoder, self).__init__() self.z_dim = output_chan self.encoder = nn.Sequential( self.init_conv_block(im_chan, hidden_dim), self.init_conv_block(hidden_dim, hidden_dim * 2), # double output_chan for mean and std with [output_chan] size self.init_conv_block(hidden_dim * 2, output_chan * 2, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=0, final_layer=False): layers = [ nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, padding=padding, stride=stride) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] return nn.Sequential(*layers) def forward(self, image): encoder_pred = self.encoder(image) encoding = encoder_pred.view(len(encoder_pred), -1) mean = encoding[:, :self.z_dim] logvar = encoding[:, self.z_dim:] # encoding output representing standard deviation is interpreted as # the logarithm of the variance associated with the normal distribution # take the exponent to convert it to standard deviation return mean, torch.exp(logvar*0.5)

1.5 Decodificador

 class Decoder(nn.Module): def __init__(self, z_dim=32, im_chan=1, hidden_dim=64): super(Decoder, self).__init__() self.z_dim = z_dim self.decoder = nn.Sequential( self.init_conv_block(z_dim, hidden_dim * 4), self.init_conv_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1), self.init_conv_block(hidden_dim * 2, hidden_dim), self.init_conv_block(hidden_dim, im_chan, kernel_size=4, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False): layers = [ nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] else: layers += [nn.Sigmoid()] return nn.Sequential(*layers) def forward(self, z): # Ensure the input latent vector z is correctly reshaped for the decoder x = z.view(-1, self.z_dim, 1, 1) # Pass the reshaped input through the decoder network return self.decoder(x)

1.6 Modelo VAE

Para propagar hacia atrás a través de una muestra aleatoria, debe mover los parámetros de la muestra aleatoria ( μ y 𝝈) fuera de la función para permitir el cálculo del gradiente a través de los parámetros. Este paso también se denomina "truco de reparametrización".


En PyTorch, puede crear una distribución normal con la salida μ y 𝝈 del codificador y muestrearla con el método rsample() que implementa el truco de reparametrización: es lo mismo que torch.randn(z_dim) * stddev + mean)

 class VAE(nn.Module): def __init__(self, z_dim=32, im_chan=1): super(VAE, self).__init__() self.z_dim = z_dim self.encoder = Encoder(im_chan, z_dim) self.decoder = Decoder(z_dim, im_chan) def forward(self, images): z_dist = Normal(self.encoder(images)) # sample from distribution with reparametarazation trick z = z_dist.rsample() decoding = self.decoder(z) return decoding, z_dist

1.7 Entrenamiento de un VAE

Cargue el tren MNIST y los datos de prueba.

 transform = transforms.Compose([transforms.ToTensor()]) # Download and load the MNIST training data trainset = datasets.MNIST('.', download=True, train=True, transform=transform) train_loader = DataLoader(trainset, batch_size=64, shuffle=True) # Download and load the MNIST test data testset = datasets.MNIST('.', download=True, train=False, transform=transform) test_loader = DataLoader(testset, batch_size=64, shuffle=True) 

Pasos del entrenamiento VAE

Cree un bucle de entrenamiento que siga los pasos de entrenamiento de VAE visualizados en la figura anterior.

 def train_model(epochs=10, z_dim = 16): model = VAE(z_dim=z_dim).to(device) model_opt = torch.optim.Adam(model.parameters()) for epoch in range(epochs): print(f"Epoch {epoch}") for images, step in tqdm(train_loader): images = images.to(device) model_opt.zero_grad() recon_images, encoding = model(images) loss = reconstruction_loss(recon_images, images)+ kl_divergence_loss(encoding) loss.backward() model_opt.step() show_images_grid(images.cpu(), title=f'Input images') show_images_grid(recon_images.cpu(), title=f'Reconstructed images') return model
 z_dim = 8 vae = train_model(epochs=20, z_dim=z_dim) 

1.8 Visualizar el espacio latente

 def visualize_latent_space(model, data_loader, device, method='TSNE', num_samples=10000): model.eval() latents = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): if len(latents) > num_samples: break mu, _ = model.encoder(data.to(device)) latents.append(mu.cpu()) labels.append(label.cpu()) latents = torch.cat(latents, dim=0).numpy() labels = torch.cat(labels, dim=0).numpy() assert method in ['TSNE', 'UMAP'], 'method should be TSNE or UMAP' if method == 'TSNE': tsne = TSNE(n_components=2, verbose=1) tsne_results = tsne.fit_transform(latents) fig = px.scatter(tsne_results, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with TSNE', width=600, height=600) elif method == 'UMAP': reducer = umap.UMAP() embedding = reducer.fit_transform(latents) fig = px.scatter(embedding, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with UMAP', width=600, height=600 ) fig.show()
 visualize_latent_space(vae, train_loader, device='cuda' if torch.cuda.is_available() else 'cpu', method='UMAP', num_samples=10000) 

Espacio latente del modelo VAE visualizado con UMAP

2 Muestreo con VAE

El muestreo de un codificador automático variacional (VAE) permite la generación de nuevos datos similares a los vistos durante el entrenamiento y es un aspecto único que separa al VAE de la arquitectura AE tradicional.


Hay varias formas de tomar muestras de un VAE:

  • muestreo posterior: muestreo de la distribución posterior dada una entrada proporcionada.


  • muestreo previo: muestreo del espacio latente asumiendo una distribución multivariada normal estándar. Esto es posible debido a la suposición (utilizada durante el entrenamiento VAE) de que las variables latentes se distribuyen normalmente. Este método no permite generar datos con propiedades específicas (por ejemplo, generar datos de una clase específica).


  • interpolación : la interpolación entre dos puntos en el espacio latente puede revelar cómo los cambios en la variable del espacio latente corresponden a cambios en los datos generados.


  • recorrido de dimensiones latentes : atravesar dimensiones latentes de VAE la variación del espacio latente de los datos depende de cada dimensión. El recorrido se realiza fijando todas las dimensiones del vector latente excepto una dimensión elegida y variando los valores de la dimensión elegida en su rango. Algunas dimensiones del espacio latente pueden corresponder a atributos específicos de los datos (VAE no tiene mecanismos específicos para forzar ese comportamiento pero puede suceder).


    Por ejemplo, una dimensión en el espacio latente puede controlar la expresión emocional de un rostro o la orientación de un objeto.


Cada método de muestreo proporciona una forma diferente de explorar y comprender las propiedades de los datos capturados por el espacio latente de VAE.

2.1 Muestreo posterior (a partir de una imagen de entrada determinada)

El codificador genera una distribución (μ_x y 𝝈_x de distribución normal) en el espacio latente. Muestreo de la distribución normal N(μ_x, _x) y pasar el vector muestreado al decodificador genera una imagen similar a la imagen de entrada dada.

 def posterior_sampling(model, data_loader, n_samples=25): model.eval() images, _ = next(iter(data_loader)) images = images[:n_samples] with torch.no_grad(): _, encoding_dist = model(images.to(device)) input_sample=encoding_dist.sample() recon_images = model.decoder(input_sample) show_images_grid(images, title=f'input samples') show_images_grid(recon_images, title=f'generated posterior samples')
 posterior_sampling(vae, train_loader, n_samples=30)

El muestreo posterior permite generar muestras de datos realistas pero con baja variabilidad: los datos de salida son similares a los datos de entrada.

2.2 Muestreo previo (a partir de un vector espacial latente aleatorio)

Tomar muestras de la distribución y pasar el vector muestreado al decodificador permite generar nuevos datos.

 def prior_sampling(model, z_dim=32, n_samples = 25): model.eval() input_sample=torch.randn(n_samples, z_dim).to(device) with torch.no_grad(): sampled_images = model.decoder(input_sample) show_images_grid(sampled_images, title=f'generated prior samples')
 prior_sampling(vae, z_dim, n_samples=40)

El muestreo previo con N(0, I ) no siempre genera datos plausibles pero tiene una alta variabilidad.

2.3 Muestreo de los centros de clase

Las codificaciones medias de cada clase se pueden acumular a partir de todo el conjunto de datos y luego usarse para una generación controlada (condicional).

 def get_data_predictions(model, data_loader): model.eval() latents_mean = [] latents_std = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): mu, std = model.encoder(data.to(device)) latents_mean.append(mu.cpu()) latents_std.append(std.cpu()) labels.append(label.cpu()) latents_mean = torch.cat(latents_mean, dim=0) latents_std = torch.cat(latents_std, dim=0) labels = torch.cat(labels, dim=0) return latents_mean, latents_std, labels
 def get_classes_mean(class_to_idx, labels, latents_mean, latents_std): classes_mean = {} for class_name in train_loader.dataset.class_to_idx: class_id = train_loader.dataset.class_to_idx[class_name] labels_class = labels[labels==class_id] latents_mean_class = latents_mean[labels==class_id] latents_mean_class = latents_mean_class.mean(dim=0, keepdims=True) latents_std_class = latents_std[labels==class_id] latents_std_class = latents_std_class.mean(dim=0, keepdims=True) classes_mean[class_id] = [latents_mean_class, latents_std_class] return classes_mean
 latents_mean, latents_stdvar, labels = get_data_predictions(vae, train_loader) classes_mean = get_classes_mean(train_loader.dataset.class_to_idx, labels, latents_mean, latents_stdvar) n_samples = 20 for class_id in classes_mean.keys(): latents_mean_class, latents_stddev_class = classes_mean[class_id] # create normal distribution of the current class class_dist = Normal(latents_mean_class, latents_stddev_class) percentiles = torch.linspace(0.05, 0.95, n_samples) # get samples from different parts of the distribution using icdf # https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.icdf class_z_sample = class_dist.icdf(percentiles[:, None].repeat(1, z_dim)) with torch.no_grad(): # generate image directly from mean class_image_prototype = vae.decoder(latents_mean_class.to(device)) # generate images sampled from Normal(class mean, class std) class_images = vae.decoder(class_z_sample.to(device)) show_image(class_image_prototype[0].cpu(), title=f'Class {class_id} prototype image') show_images_grid(class_images.cpu(), title=f'Class {class_id} images')

El muestreo de una distribución normal con clase μ promediada garantiza la generación de nuevos datos de la misma clase.

imagen generada desde el centro de clase 3

imagen generada desde el centro de clase 4

Los valores de percentil alto y bajo utilizados en icdf dan como resultado una alta variación de datos

2.4 Interpolación

 def linear_interpolation(start, end, steps): # Create a linear path from start to end z = torch.linspace(0, 1, steps)[:, None].to(device) * (end - start) + start # Decode the samples along the path vae.eval() with torch.no_grad(): samples = vae.decoder(z) return samples

2.4.1 Interpolación entre dos vectores latentes aleatorios

 start = torch.randn(1, z_dim).to(device) end = torch.randn(1, z_dim).to(device) interpolated_samples = linear_interpolation(start, end, steps = 24) show_images_grid(interpolated_samples, title=f'Linear interpolation between two random latent vectors') 

2.4.2 Interpolación entre dos centros de clase

 for start_class_id in range(1,10): start = classes_mean[start_class_id][0].to(device) for end_class_id in range(1, 10): if end_class_id == start_class_id: continue end = classes_mean[end_class_id][0].to(device) interpolated_samples = linear_interpolation(start, end, steps = 20) show_images_grid(interpolated_samples, title=f'Linear interpolation between classes {start_class_id} and {end_class_id}') 

2.5 Recorrido del espacio latente

Cada dimensión del vector latente representa una distribución normal; el rango de valores de la dimensión está controlado por la media y la varianza de la dimensión. Una forma sencilla de recorrer el rango de valores sería utilizar CDF (funciones de distribución acumulativa) inversas de la distribución normal.


ICDF toma un valor entre 0 y 1 (que representa la probabilidad) y devuelve un valor de la distribución. Para una probabilidad dada p, ICDF genera un valor p_icdf tal que la probabilidad de que una variable aleatoria sea <= p_icdf es igual a la probabilidad dada p .


Si tiene una distribución normal, icdf(0.5) debería devolver la media de la distribución. icdf(0.95) debería devolver un valor mayor que el 95% de los datos de la distribución.

Visualización de CDF y valores devueltos por ICDF dadas probabilidades 0,025, 0,5, 0,975

2.5.1 Recorrido del espacio latente unidimensional

 def latent_space_traversal(model, input_sample, norm_dist, dim_to_traverse, n_samples, latent_dim, device): # Create a range of values to traverse assert input_sample.shape[0] == 1, 'input sample shape should be [1, latent_dim]' # Generate linearly spaced percentiles between 0.05 and 0.95 percentiles = torch.linspace(0.1, 0.9, n_samples) # Get the quantile values corresponding to the percentiles traversed_values = norm_dist.icdf(percentiles[:, None].repeat(1, z_dim)) # Initialize a latent space vector with zeros z = input_sample.repeat(n_samples, 1) # Assign the traversed values to the specified dimension z[:, dim_to_traverse] = traversed_values[:, dim_to_traverse] # Decode the latent vectors with torch.no_grad(): samples = model.decoder(z.to(device)) return samples
 for class_id in range(0,10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') for i in range(z_dim): interpolated_samples = latent_space_traversal(vae, mu, norm_dist=Normal(mu, torch.ones_like(mu)), dim_to_traverse=i, n_samples=20, latent_dim=z_dim, device=device) show_images_grid(interpolated_samples, title=f'Class {class_id} dim={i} traversal')

Atravesar una única dimensión puede provocar un cambio en el estilo de los dígitos o en la orientación de los dígitos de control.

2.5.3 Recorrido del espacio latente en dos dimensiones

 def traverse_two_latent_dimensions(model, input_sample, z_dist, n_samples=25, z_dim=16, dim_1=0, dim_2=1, title='plot'): digit_size=28 percentiles = torch.linspace(0.10, 0.9, n_samples) grid_x = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) grid_y = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) figure = np.zeros((digit_size * n_samples, digit_size * n_samples)) z_sample_def = input_sample.clone().detach() # select two dimensions to vary (dim_1 and dim_2) and keep the rest fixed for yi in range(n_samples): for xi in range(n_samples): with torch.no_grad(): z_sample = z_sample_def.clone().detach() z_sample[:, dim_1] = grid_x[xi, dim_1] z_sample[:, dim_2] = grid_y[yi, dim_2] x_decoded = model.decoder(z_sample.to(device)).cpu() digit = x_decoded[0].reshape(digit_size, digit_size) figure[yi * digit_size: (yi + 1) * digit_size, xi * digit_size: (xi + 1) * digit_size] = digit.numpy() plt.figure(figsize=(6, 6)) plt.imshow(figure, cmap='Greys_r') plt.title(title) plt.show()
 for class_id in range(10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') traverse_two_latent_dimensions(vae, mu, z_dist=Normal(mu, torch.ones_like(mu)), n_samples=8, z_dim=z_dim, dim_1=3, dim_2=6, title=f'Class {class_id} traversing dimensions {(3, 6)}')

Atravesar varias dimensiones a la vez proporciona una forma controlable de generar datos con alta variabilidad.

Bonificación 2.6: variedad 2D de dígitos del espacio latente

Si se entrena un modelo VAE con z_dim =2, es posible mostrar una variedad 2D de dígitos desde su espacio latente. Para hacer eso, usaré la función traverse_two_latent_dimensions con dim_1 =0 y dim_2 =2 .

 vae_2d = train_model(epochs=10, z_dim=2)
 z_dist = Normal(torch.zeros(1, 2), torch.ones(1, 2)) input_sample = torch.zeros(1, 2) with torch.no_grad(): decoding = vae_2d.decoder(input_sample.to(device)) traverse_two_latent_dimensions(vae_2d, input_sample, z_dist, n_samples=20, dim_1=0, dim_2=1, z_dim=2, title=f'traversing 2D latent space') 

espacio latente 2D