Giống như bộ mã hóa tự động truyền thống, kiến trúc VAE có hai phần: bộ mã hóa và bộ giải mã. Các mô hình AE truyền thống ánh xạ đầu vào vào một vectơ không gian tiềm ẩn và tái tạo lại đầu ra từ vectơ này.
VAE ánh xạ các đầu vào thành một phân phối chuẩn đa biến (bộ mã hóa đưa ra giá trị trung bình và phương sai của từng thứ nguyên tiềm ẩn).
Do bộ mã hóa VAE tạo ra một phân phối nên dữ liệu mới có thể được tạo bằng cách lấy mẫu từ phân phối này và chuyển vectơ tiềm ẩn được lấy mẫu vào bộ giải mã. Lấy mẫu từ phân phối được tạo ra để tạo ra hình ảnh đầu ra có nghĩa là VAE cho phép tạo ra dữ liệu mới tương tự nhưng giống hệt với dữ liệu đầu vào.
Bài viết này tìm hiểu các thành phần của kiến trúc VAE và cung cấp một số cách tạo hình ảnh mới (lấy mẫu) bằng mô hình VAE. Tất cả mã đều có sẵn trong Google Colab .
Bộ mã hóa tự động và Bộ mã hóa tự động biến thể đều có hai phần: bộ mã hóa và bộ giải mã. Mạng nơ-ron mã hóa của AE học cách ánh xạ từng hình ảnh thành một vectơ duy nhất trong không gian tiềm ẩn và bộ giải mã học cách tái tạo lại hình ảnh gốc từ vectơ tiềm ẩn được mã hóa.
Mạng thần kinh mã hóa của các tham số đầu ra VAE xác định phân bố xác suất cho từng chiều của không gian tiềm ẩn (phân phối đa biến). Đối với mỗi đầu vào, bộ mã hóa tạo ra giá trị trung bình và phương sai cho từng chiều của không gian tiềm ẩn.
Giá trị trung bình và phương sai đầu ra được sử dụng để xác định phân phối Gaussian đa biến. Mạng nơ-ron giải mã giống như trong các mô hình AE.
Mục tiêu của việc huấn luyện mô hình VAE là tối đa hóa khả năng tạo ra hình ảnh thực từ các vectơ tiềm ẩn được cung cấp. Trong quá trình huấn luyện, mô hình VAE giảm thiểu hai tổn thất:
Tổn thất tái thiết phổ biến là entropy chéo nhị phân (BCE) và sai số bình phương trung bình (MSE). Trong bài viết này, tôi sẽ sử dụng bộ dữ liệu MNIST cho bản demo. Hình ảnh MNIST chỉ có một kênh và pixel lấy giá trị từ 0 đến 1.
Trong trường hợp này, tổn thất BCE có thể được sử dụng làm tổn thất tái thiết để xử lý các pixel của ảnh MNIST dưới dạng biến ngẫu nhiên nhị phân tuân theo phân bố Bernoulli.
reconstruction_loss = nn.BCELoss(reduction='sum')
Như đã đề cập ở trên - Phân kỳ KL đánh giá sự khác biệt giữa hai phân phối. Lưu ý rằng nó không có tính chất đối xứng về khoảng cách: KL(P‖Q)!=KL(Q‖P).
Hai phân bố cần so sánh là:
không gian tiềm ẩn trước p(z) được giả định là phân phối chuẩn với giá trị trung bình bằng 0 và độ lệch chuẩn là 1 trong mỗi chiều không gian tiềm ẩn N(0, I ) .
Giả định như vậy đơn giản hóa việc tính toán phân kỳ KL và khuyến khích không gian tiềm ẩn tuân theo một phân phối đã biết và có thể quản lý được.
from torch.distributions.kl import kl_divergence def kl_divergence_loss(z_dist): return kl_divergence(z_dist, Normal(torch.zeros_like(z_dist.mean), torch.ones_like(z_dist.stddev)) ).sum(-1).sum()
class Encoder(nn.Module): def __init__(self, im_chan=1, output_chan=32, hidden_dim=16): super(Encoder, self).__init__() self.z_dim = output_chan self.encoder = nn.Sequential( self.init_conv_block(im_chan, hidden_dim), self.init_conv_block(hidden_dim, hidden_dim * 2), # double output_chan for mean and std with [output_chan] size self.init_conv_block(hidden_dim * 2, output_chan * 2, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=0, final_layer=False): layers = [ nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, padding=padding, stride=stride) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] return nn.Sequential(*layers) def forward(self, image): encoder_pred = self.encoder(image) encoding = encoder_pred.view(len(encoder_pred), -1) mean = encoding[:, :self.z_dim] logvar = encoding[:, self.z_dim:] # encoding output representing standard deviation is interpreted as # the logarithm of the variance associated with the normal distribution # take the exponent to convert it to standard deviation return mean, torch.exp(logvar*0.5)
class Decoder(nn.Module): def __init__(self, z_dim=32, im_chan=1, hidden_dim=64): super(Decoder, self).__init__() self.z_dim = z_dim self.decoder = nn.Sequential( self.init_conv_block(z_dim, hidden_dim * 4), self.init_conv_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1), self.init_conv_block(hidden_dim * 2, hidden_dim), self.init_conv_block(hidden_dim, im_chan, kernel_size=4, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False): layers = [ nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] else: layers += [nn.Sigmoid()] return nn.Sequential(*layers) def forward(self, z): # Ensure the input latent vector z is correctly reshaped for the decoder x = z.view(-1, self.z_dim, 1, 1) # Pass the reshaped input through the decoder network return self.decoder(x)
Để lan truyền ngược qua một mẫu ngẫu nhiên, bạn cần di chuyển các tham số của mẫu ngẫu nhiên ( μ và 𝝈) ra ngoài hàm để cho phép tính toán gradient thông qua các tham số. Bước này còn được gọi là “thủ thuật tham số hóa lại”.
Trong PyTorch, bạn có thể tạo Phân phối chuẩn với đầu ra của bộ mã hóa μ và 𝝈 và lấy mẫu từ nó bằng phương thức rsample() thực hiện thủ thuật tham số hóa lại: nó giống như torch.randn(z_dim) * stddev + mean)
class VAE(nn.Module): def __init__(self, z_dim=32, im_chan=1): super(VAE, self).__init__() self.z_dim = z_dim self.encoder = Encoder(im_chan, z_dim) self.decoder = Decoder(z_dim, im_chan) def forward(self, images): z_dist = Normal(self.encoder(images)) # sample from distribution with reparametarazation trick z = z_dist.rsample() decoding = self.decoder(z) return decoding, z_dist
Tải dữ liệu kiểm tra và đào tạo MNIST.
transform = transforms.Compose([transforms.ToTensor()]) # Download and load the MNIST training data trainset = datasets.MNIST('.', download=True, train=True, transform=transform) train_loader = DataLoader(trainset, batch_size=64, shuffle=True) # Download and load the MNIST test data testset = datasets.MNIST('.', download=True, train=False, transform=transform) test_loader = DataLoader(testset, batch_size=64, shuffle=True)
Tạo vòng lặp đào tạo theo các bước đào tạo VAE được hiển thị trong hình trên.
def train_model(epochs=10, z_dim = 16): model = VAE(z_dim=z_dim).to(device) model_opt = torch.optim.Adam(model.parameters()) for epoch in range(epochs): print(f"Epoch {epoch}") for images, step in tqdm(train_loader): images = images.to(device) model_opt.zero_grad() recon_images, encoding = model(images) loss = reconstruction_loss(recon_images, images)+ kl_divergence_loss(encoding) loss.backward() model_opt.step() show_images_grid(images.cpu(), title=f'Input images') show_images_grid(recon_images.cpu(), title=f'Reconstructed images') return model
z_dim = 8 vae = train_model(epochs=20, z_dim=z_dim)
def visualize_latent_space(model, data_loader, device, method='TSNE', num_samples=10000): model.eval() latents = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): if len(latents) > num_samples: break mu, _ = model.encoder(data.to(device)) latents.append(mu.cpu()) labels.append(label.cpu()) latents = torch.cat(latents, dim=0).numpy() labels = torch.cat(labels, dim=0).numpy() assert method in ['TSNE', 'UMAP'], 'method should be TSNE or UMAP' if method == 'TSNE': tsne = TSNE(n_components=2, verbose=1) tsne_results = tsne.fit_transform(latents) fig = px.scatter(tsne_results, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with TSNE', width=600, height=600) elif method == 'UMAP': reducer = umap.UMAP() embedding = reducer.fit_transform(latents) fig = px.scatter(embedding, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with UMAP', width=600, height=600 ) fig.show()
visualize_latent_space(vae, train_loader, device='cuda' if torch.cuda.is_available() else 'cpu', method='UMAP', num_samples=10000)
Việc lấy mẫu từ Bộ mã hóa tự động biến đổi (VAE) cho phép tạo ra dữ liệu mới tương tự như dữ liệu được thấy trong quá trình đào tạo và đó là khía cạnh độc đáo giúp tách VAE khỏi kiến trúc AE truyền thống.
Có một số cách lấy mẫu từ VAE:
truyền tải các kích thước tiềm ẩn : truyền tải các kích thước tiềm ẩn của phương sai không gian tiềm ẩn VAE của dữ liệu phụ thuộc vào từng thứ nguyên. Việc truyền tải được thực hiện bằng cách cố định tất cả các chiều của vectơ tiềm ẩn ngoại trừ một chiều đã chọn và các giá trị khác nhau của chiều đã chọn trong phạm vi của nó. Một số chiều của không gian tiềm ẩn có thể tương ứng với các thuộc tính cụ thể của dữ liệu (VAE không có cơ chế cụ thể để buộc hành vi đó nhưng nó có thể xảy ra).
Ví dụ, một chiều trong không gian tiềm ẩn có thể kiểm soát biểu hiện cảm xúc của khuôn mặt hoặc hướng của vật thể.
Mỗi phương pháp lấy mẫu cung cấp một cách khác nhau để khám phá và hiểu các thuộc tính dữ liệu được ghi lại bởi không gian tiềm ẩn của VAE.
def posterior_sampling(model, data_loader, n_samples=25): model.eval() images, _ = next(iter(data_loader)) images = images[:n_samples] with torch.no_grad(): _, encoding_dist = model(images.to(device)) input_sample=encoding_dist.sample() recon_images = model.decoder(input_sample) show_images_grid(images, title=f'input samples') show_images_grid(recon_images, title=f'generated posterior samples')
posterior_sampling(vae, train_loader, n_samples=30)
Lấy mẫu sau cho phép tạo ra các mẫu dữ liệu thực tế nhưng có độ biến thiên thấp: dữ liệu đầu ra tương tự như dữ liệu đầu vào.
def prior_sampling(model, z_dim=32, n_samples = 25): model.eval() input_sample=torch.randn(n_samples, z_dim).to(device) with torch.no_grad(): sampled_images = model.decoder(input_sample) show_images_grid(sampled_images, title=f'generated prior samples')
prior_sampling(vae, z_dim, n_samples=40)
Việc lấy mẫu trước với N(0, I ) không phải lúc nào cũng tạo ra dữ liệu hợp lý nhưng có độ biến thiên cao.
Mã hóa trung bình của mỗi lớp có thể được tích lũy từ toàn bộ tập dữ liệu và sau đó được sử dụng cho thế hệ được kiểm soát (thế hệ có điều kiện).
def get_data_predictions(model, data_loader): model.eval() latents_mean = [] latents_std = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): mu, std = model.encoder(data.to(device)) latents_mean.append(mu.cpu()) latents_std.append(std.cpu()) labels.append(label.cpu()) latents_mean = torch.cat(latents_mean, dim=0) latents_std = torch.cat(latents_std, dim=0) labels = torch.cat(labels, dim=0) return latents_mean, latents_std, labels
def get_classes_mean(class_to_idx, labels, latents_mean, latents_std): classes_mean = {} for class_name in train_loader.dataset.class_to_idx: class_id = train_loader.dataset.class_to_idx[class_name] labels_class = labels[labels==class_id] latents_mean_class = latents_mean[labels==class_id] latents_mean_class = latents_mean_class.mean(dim=0, keepdims=True) latents_std_class = latents_std[labels==class_id] latents_std_class = latents_std_class.mean(dim=0, keepdims=True) classes_mean[class_id] = [latents_mean_class, latents_std_class] return classes_mean
latents_mean, latents_stdvar, labels = get_data_predictions(vae, train_loader) classes_mean = get_classes_mean(train_loader.dataset.class_to_idx, labels, latents_mean, latents_stdvar) n_samples = 20 for class_id in classes_mean.keys(): latents_mean_class, latents_stddev_class = classes_mean[class_id] # create normal distribution of the current class class_dist = Normal(latents_mean_class, latents_stddev_class) percentiles = torch.linspace(0.05, 0.95, n_samples) # get samples from different parts of the distribution using icdf # https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.icdf class_z_sample = class_dist.icdf(percentiles[:, None].repeat(1, z_dim)) with torch.no_grad(): # generate image directly from mean class_image_prototype = vae.decoder(latents_mean_class.to(device)) # generate images sampled from Normal(class mean, class std) class_images = vae.decoder(class_z_sample.to(device)) show_image(class_image_prototype[0].cpu(), title=f'Class {class_id} prototype image') show_images_grid(class_images.cpu(), title=f'Class {class_id} images')
Lấy mẫu từ phân phối chuẩn với lớp trung bình μ đảm bảo tạo ra dữ liệu mới từ cùng một lớp.
def linear_interpolation(start, end, steps): # Create a linear path from start to end z = torch.linspace(0, 1, steps)[:, None].to(device) * (end - start) + start # Decode the samples along the path vae.eval() with torch.no_grad(): samples = vae.decoder(z) return samples
start = torch.randn(1, z_dim).to(device) end = torch.randn(1, z_dim).to(device) interpolated_samples = linear_interpolation(start, end, steps = 24) show_images_grid(interpolated_samples, title=f'Linear interpolation between two random latent vectors')
for start_class_id in range(1,10): start = classes_mean[start_class_id][0].to(device) for end_class_id in range(1, 10): if end_class_id == start_class_id: continue end = classes_mean[end_class_id][0].to(device) interpolated_samples = linear_interpolation(start, end, steps = 20) show_images_grid(interpolated_samples, title=f'Linear interpolation between classes {start_class_id} and {end_class_id}')
Mỗi chiều của vectơ tiềm ẩn biểu thị một phân phối chuẩn; phạm vi giá trị của thứ nguyên được kiểm soát bởi giá trị trung bình và phương sai của thứ nguyên. Một cách đơn giản để duyệt qua phạm vi giá trị là sử dụng CDF nghịch đảo (hàm phân phối tích lũy) của phân phối chuẩn.
ICDF nhận giá trị từ 0 đến 1 (biểu thị xác suất) và trả về giá trị từ phân phối. Đối với một xác suất p cho trước, ICDF tạo ra một giá trị p_icdf sao cho xác suất của một biến ngẫu nhiên là <= p_icdf bằng xác suất đã cho p ?”
Nếu bạn có phân phối bình thường, icdf(0,5) sẽ trả về giá trị trung bình của phân phối. icdf(0,95) sẽ trả về giá trị lớn hơn 95% dữ liệu từ bản phân phối.
def latent_space_traversal(model, input_sample, norm_dist, dim_to_traverse, n_samples, latent_dim, device): # Create a range of values to traverse assert input_sample.shape[0] == 1, 'input sample shape should be [1, latent_dim]' # Generate linearly spaced percentiles between 0.05 and 0.95 percentiles = torch.linspace(0.1, 0.9, n_samples) # Get the quantile values corresponding to the percentiles traversed_values = norm_dist.icdf(percentiles[:, None].repeat(1, z_dim)) # Initialize a latent space vector with zeros z = input_sample.repeat(n_samples, 1) # Assign the traversed values to the specified dimension z[:, dim_to_traverse] = traversed_values[:, dim_to_traverse] # Decode the latent vectors with torch.no_grad(): samples = model.decoder(z.to(device)) return samples
for class_id in range(0,10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') for i in range(z_dim): interpolated_samples = latent_space_traversal(vae, mu, norm_dist=Normal(mu, torch.ones_like(mu)), dim_to_traverse=i, n_samples=20, latent_dim=z_dim, device=device) show_images_grid(interpolated_samples, title=f'Class {class_id} dim={i} traversal')
Di chuyển ngang một chiều có thể dẫn đến thay đổi kiểu chữ số hoặc hướng chữ số điều khiển.
def traverse_two_latent_dimensions(model, input_sample, z_dist, n_samples=25, z_dim=16, dim_1=0, dim_2=1, title='plot'): digit_size=28 percentiles = torch.linspace(0.10, 0.9, n_samples) grid_x = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) grid_y = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) figure = np.zeros((digit_size * n_samples, digit_size * n_samples)) z_sample_def = input_sample.clone().detach() # select two dimensions to vary (dim_1 and dim_2) and keep the rest fixed for yi in range(n_samples): for xi in range(n_samples): with torch.no_grad(): z_sample = z_sample_def.clone().detach() z_sample[:, dim_1] = grid_x[xi, dim_1] z_sample[:, dim_2] = grid_y[yi, dim_2] x_decoded = model.decoder(z_sample.to(device)).cpu() digit = x_decoded[0].reshape(digit_size, digit_size) figure[yi * digit_size: (yi + 1) * digit_size, xi * digit_size: (xi + 1) * digit_size] = digit.numpy() plt.figure(figsize=(6, 6)) plt.imshow(figure, cmap='Greys_r') plt.title(title) plt.show()
for class_id in range(10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') traverse_two_latent_dimensions(vae, mu, z_dist=Normal(mu, torch.ones_like(mu)), n_samples=8, z_dim=z_dim, dim_1=3, dim_2=6, title=f'Class {class_id} traversing dimensions {(3, 6)}')
Việc duyệt qua nhiều chiều cùng một lúc cung cấp một cách có thể kiểm soát được để tạo ra dữ liệu có độ biến thiên cao.
Nếu một mô hình VAE được huấn luyện với z_dim =2, thì có thể hiển thị nhiều chữ số 2D từ không gian tiềm ẩn của nó. Để làm điều đó, tôi sẽ sử dụng hàm traverse_two_latent_dimensions với dim_1 =0 và dim_2 =2 .
vae_2d = train_model(epochs=10, z_dim=2)
z_dist = Normal(torch.zeros(1, 2), torch.ones(1, 2)) input_sample = torch.zeros(1, 2) with torch.no_grad(): decoding = vae_2d.decoder(input_sample.to(device)) traverse_two_latent_dimensions(vae_2d, input_sample, z_dist, n_samples=20, dim_1=0, dim_2=1, z_dim=2, title=f'traversing 2D latent space')