पारंपरिक ऑटोएनकोडर के समान, VAE आर्किटेक्चर के दो भाग होते हैं: एक एनकोडर और एक डिकोडर। पारंपरिक एई मॉडल इनपुट को एक अव्यक्त-स्पेस वेक्टर में मैप करते हैं और इस वेक्टर से आउटपुट का पुनर्निर्माण करते हैं।
VAE एक बहुभिन्नरूपी सामान्य वितरण में इनपुट को मैप करता है (एनकोडर प्रत्येक अव्यक्त आयाम के माध्य और विचरण को आउटपुट करता है)।
चूंकि वीएई एनकोडर एक वितरण उत्पन्न करता है, इसलिए इस वितरण से नमूना लेकर और नमूना किए गए अव्यक्त वेक्टर को डिकोडर में पास करके नया डेटा उत्पन्न किया जा सकता है। आउटपुट छवियों को उत्पन्न करने के लिए उत्पादित वितरण से नमूना लेने का मतलब है कि वीएई नए डेटा को उत्पन्न करने की अनुमति देता है जो इनपुट डेटा के समान है, लेकिन समान है।
यह आलेख वीएई वास्तुकला के घटकों की खोज करता है और वीएई मॉडल के साथ नई छवियां (नमूना) उत्पन्न करने के कई तरीके प्रदान करता है। सभी कोड Google Colab पर उपलब्ध हैं।
ऑटोएनकोडर और वेरिएशनल ऑटोएनकोडर दोनों के दो भाग होते हैं: एनकोडर और डिकोडर। एई का एनकोडर न्यूरल नेटवर्क प्रत्येक छवि को अव्यक्त स्थान में एक वेक्टर में मैप करना सीखता है और डिकोडर एन्कोडेड अव्यक्त वेक्टर से मूल छवि को फिर से बनाना सीखता है।
वीएई का एनकोडर तंत्रिका नेटवर्क पैरामीटर आउटपुट करता है जो अव्यक्त स्थान (बहुभिन्नरूपी वितरण) के प्रत्येक आयाम के लिए संभाव्यता वितरण को परिभाषित करता है। प्रत्येक इनपुट के लिए, एनकोडर अव्यक्त स्थान के प्रत्येक आयाम के लिए एक माध्य और एक भिन्नता उत्पन्न करता है।
आउटपुट माध्य और विचरण का उपयोग बहुभिन्नरूपी गाऊसी वितरण को परिभाषित करने के लिए किया जाता है। डिकोडर न्यूरल नेटवर्क AE मॉडल के समान है।
वीएई मॉडल को प्रशिक्षित करने का लक्ष्य प्रदान किए गए अव्यक्त वैक्टर से वास्तविक छवियां उत्पन्न करने की संभावना को अधिकतम करना है। प्रशिक्षण के दौरान, VAE मॉडल दो नुकसानों को कम करता है:
सामान्य पुनर्निर्माण हानियाँ बाइनरी क्रॉस-एन्ट्रॉपी (बीसीई) और माध्य वर्ग त्रुटि (एमएसई) हैं। इस लेख में, मैं डेमो के लिए एमएनआईएसटी डेटासेट का उपयोग करूंगा। एमएनआईएसटी छवियों में केवल एक चैनल होता है, और पिक्सेल 0 और 1 के बीच मान लेते हैं।
इस मामले में, बीसीई हानि का उपयोग एमएनआईएसटी छवियों के पिक्सल को एक द्विआधारी यादृच्छिक चर के रूप में इलाज करने के लिए पुनर्निर्माण हानि के रूप में किया जा सकता है जो बर्नौली वितरण का अनुसरण करता है।
reconstruction_loss = nn.BCELoss(reduction='sum')
जैसा कि ऊपर बताया गया है - केएल विचलन दो वितरणों के बीच अंतर का मूल्यांकन करता है। ध्यान दें कि इसमें दूरी का सममित गुण नहीं है: KL(P‖Q)!=KL(Q‖P)।
जिन दो वितरणों की तुलना करने की आवश्यकता है वे हैं:
p(z) से पहले अव्यक्त स्थान जिसे शून्य के माध्य और प्रत्येक अव्यक्त स्थान आयाम N(0, I ) में एक के मानक विचलन के साथ एक सामान्य वितरण माना जाता है।
ऐसी धारणा केएल विचलन गणना को सरल बनाती है और अव्यक्त स्थान को ज्ञात, प्रबंधनीय वितरण का पालन करने के लिए प्रोत्साहित करती है।
from torch.distributions.kl import kl_divergence def kl_divergence_loss(z_dist): return kl_divergence(z_dist, Normal(torch.zeros_like(z_dist.mean), torch.ones_like(z_dist.stddev)) ).sum(-1).sum()
class Encoder(nn.Module): def __init__(self, im_chan=1, output_chan=32, hidden_dim=16): super(Encoder, self).__init__() self.z_dim = output_chan self.encoder = nn.Sequential( self.init_conv_block(im_chan, hidden_dim), self.init_conv_block(hidden_dim, hidden_dim * 2), # double output_chan for mean and std with [output_chan] size self.init_conv_block(hidden_dim * 2, output_chan * 2, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=0, final_layer=False): layers = [ nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, padding=padding, stride=stride) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] return nn.Sequential(*layers) def forward(self, image): encoder_pred = self.encoder(image) encoding = encoder_pred.view(len(encoder_pred), -1) mean = encoding[:, :self.z_dim] logvar = encoding[:, self.z_dim:] # encoding output representing standard deviation is interpreted as # the logarithm of the variance associated with the normal distribution # take the exponent to convert it to standard deviation return mean, torch.exp(logvar*0.5)
class Decoder(nn.Module): def __init__(self, z_dim=32, im_chan=1, hidden_dim=64): super(Decoder, self).__init__() self.z_dim = z_dim self.decoder = nn.Sequential( self.init_conv_block(z_dim, hidden_dim * 4), self.init_conv_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1), self.init_conv_block(hidden_dim * 2, hidden_dim), self.init_conv_block(hidden_dim, im_chan, kernel_size=4, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False): layers = [ nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] else: layers += [nn.Sigmoid()] return nn.Sequential(*layers) def forward(self, z): # Ensure the input latent vector z is correctly reshaped for the decoder x = z.view(-1, self.z_dim, 1, 1) # Pass the reshaped input through the decoder network return self.decoder(x)
एक यादृच्छिक नमूने के माध्यम से बैक-प्रोपेगेट करने के लिए आपको मापदंडों के माध्यम से ग्रेडिएंट गणना की अनुमति देने के लिए यादृच्छिक नमूने ( μ और 𝝈) के मापदंडों को फ़ंक्शन के बाहर ले जाना होगा। इस चरण को "पुनरावर्तन चाल" भी कहा जाता है।
PyTorch में, आप एनकोडर के आउटपुट μ और 𝝈 के साथ एक सामान्य वितरण बना सकते हैं और उसमें से rsample() विधि का नमूना ले सकते हैं जो रिपैरामीटराइजेशन ट्रिक को लागू करता है: यह torch.randn(z_dim) * stddev + mean)
class VAE(nn.Module): def __init__(self, z_dim=32, im_chan=1): super(VAE, self).__init__() self.z_dim = z_dim self.encoder = Encoder(im_chan, z_dim) self.decoder = Decoder(z_dim, im_chan) def forward(self, images): z_dist = Normal(self.encoder(images)) # sample from distribution with reparametarazation trick z = z_dist.rsample() decoding = self.decoder(z) return decoding, z_dist
एमएनआईएसटी ट्रेन और परीक्षण डेटा लोड करें।
transform = transforms.Compose([transforms.ToTensor()]) # Download and load the MNIST training data trainset = datasets.MNIST('.', download=True, train=True, transform=transform) train_loader = DataLoader(trainset, batch_size=64, shuffle=True) # Download and load the MNIST test data testset = datasets.MNIST('.', download=True, train=False, transform=transform) test_loader = DataLoader(testset, batch_size=64, shuffle=True)
एक प्रशिक्षण लूप बनाएं जो ऊपर चित्र में दिखाए गए वीएई प्रशिक्षण चरणों का पालन करता हो।
def train_model(epochs=10, z_dim = 16): model = VAE(z_dim=z_dim).to(device) model_opt = torch.optim.Adam(model.parameters()) for epoch in range(epochs): print(f"Epoch {epoch}") for images, step in tqdm(train_loader): images = images.to(device) model_opt.zero_grad() recon_images, encoding = model(images) loss = reconstruction_loss(recon_images, images)+ kl_divergence_loss(encoding) loss.backward() model_opt.step() show_images_grid(images.cpu(), title=f'Input images') show_images_grid(recon_images.cpu(), title=f'Reconstructed images') return model
z_dim = 8 vae = train_model(epochs=20, z_dim=z_dim)
def visualize_latent_space(model, data_loader, device, method='TSNE', num_samples=10000): model.eval() latents = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): if len(latents) > num_samples: break mu, _ = model.encoder(data.to(device)) latents.append(mu.cpu()) labels.append(label.cpu()) latents = torch.cat(latents, dim=0).numpy() labels = torch.cat(labels, dim=0).numpy() assert method in ['TSNE', 'UMAP'], 'method should be TSNE or UMAP' if method == 'TSNE': tsne = TSNE(n_components=2, verbose=1) tsne_results = tsne.fit_transform(latents) fig = px.scatter(tsne_results, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with TSNE', width=600, height=600) elif method == 'UMAP': reducer = umap.UMAP() embedding = reducer.fit_transform(latents) fig = px.scatter(embedding, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with UMAP', width=600, height=600 ) fig.show()
visualize_latent_space(vae, train_loader, device='cuda' if torch.cuda.is_available() else 'cpu', method='UMAP', num_samples=10000)
वेरिएशनल ऑटोएन्कोडर (वीएई) से नमूनाकरण नए डेटा की पीढ़ी को सक्षम बनाता है जो प्रशिक्षण के दौरान देखे गए डेटा के समान है और यह एक अनूठा पहलू है जो वीएई को पारंपरिक एई आर्किटेक्चर से अलग करता है।
VAE से नमूना लेने के कई तरीके हैं:
अव्यक्त आयामों का ट्रैवर्सल : वीएई के अव्यक्त आयामों का ट्रैवर्सिंग डेटा का अव्यक्त स्थान विचरण प्रत्येक आयाम पर निर्भर करता है। ट्रैवर्सल एक चुने हुए आयाम को छोड़कर अव्यक्त वेक्टर के सभी आयामों और उसकी सीमा में चुने गए आयाम के अलग-अलग मूल्यों को तय करके किया जाता है। अव्यक्त स्थान के कुछ आयाम डेटा की विशिष्ट विशेषताओं के अनुरूप हो सकते हैं (वीएई के पास उस व्यवहार को मजबूर करने के लिए विशिष्ट तंत्र नहीं हैं लेकिन ऐसा हो सकता है)।
उदाहरण के लिए, अव्यक्त स्थान में एक आयाम किसी चेहरे की भावनात्मक अभिव्यक्ति या किसी वस्तु के अभिविन्यास को नियंत्रित कर सकता है।
प्रत्येक नमूनाकरण विधि वीएई के अव्यक्त स्थान द्वारा कैप्चर किए गए डेटा गुणों की खोज और समझने का एक अलग तरीका प्रदान करती है।
def posterior_sampling(model, data_loader, n_samples=25): model.eval() images, _ = next(iter(data_loader)) images = images[:n_samples] with torch.no_grad(): _, encoding_dist = model(images.to(device)) input_sample=encoding_dist.sample() recon_images = model.decoder(input_sample) show_images_grid(images, title=f'input samples') show_images_grid(recon_images, title=f'generated posterior samples')
posterior_sampling(vae, train_loader, n_samples=30)
पश्च नमूनाकरण यथार्थवादी डेटा नमूने उत्पन्न करने की अनुमति देता है लेकिन कम परिवर्तनशीलता के साथ: आउटपुट डेटा इनपुट डेटा के समान होता है।
def prior_sampling(model, z_dim=32, n_samples = 25): model.eval() input_sample=torch.randn(n_samples, z_dim).to(device) with torch.no_grad(): sampled_images = model.decoder(input_sample) show_images_grid(sampled_images, title=f'generated prior samples')
prior_sampling(vae, z_dim, n_samples=40)
N(0, I ) के साथ पूर्व नमूनाकरण हमेशा विश्वसनीय डेटा उत्पन्न नहीं करता है लेकिन इसमें उच्च परिवर्तनशीलता होती है।
प्रत्येक वर्ग के माध्य एन्कोडिंग को संपूर्ण डेटासेट से संचित किया जा सकता है और बाद में नियंत्रित (सशर्त पीढ़ी) के लिए उपयोग किया जा सकता है।
def get_data_predictions(model, data_loader): model.eval() latents_mean = [] latents_std = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): mu, std = model.encoder(data.to(device)) latents_mean.append(mu.cpu()) latents_std.append(std.cpu()) labels.append(label.cpu()) latents_mean = torch.cat(latents_mean, dim=0) latents_std = torch.cat(latents_std, dim=0) labels = torch.cat(labels, dim=0) return latents_mean, latents_std, labels
def get_classes_mean(class_to_idx, labels, latents_mean, latents_std): classes_mean = {} for class_name in train_loader.dataset.class_to_idx: class_id = train_loader.dataset.class_to_idx[class_name] labels_class = labels[labels==class_id] latents_mean_class = latents_mean[labels==class_id] latents_mean_class = latents_mean_class.mean(dim=0, keepdims=True) latents_std_class = latents_std[labels==class_id] latents_std_class = latents_std_class.mean(dim=0, keepdims=True) classes_mean[class_id] = [latents_mean_class, latents_std_class] return classes_mean
latents_mean, latents_stdvar, labels = get_data_predictions(vae, train_loader) classes_mean = get_classes_mean(train_loader.dataset.class_to_idx, labels, latents_mean, latents_stdvar) n_samples = 20 for class_id in classes_mean.keys(): latents_mean_class, latents_stddev_class = classes_mean[class_id] # create normal distribution of the current class class_dist = Normal(latents_mean_class, latents_stddev_class) percentiles = torch.linspace(0.05, 0.95, n_samples) # get samples from different parts of the distribution using icdf # https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.icdf class_z_sample = class_dist.icdf(percentiles[:, None].repeat(1, z_dim)) with torch.no_grad(): # generate image directly from mean class_image_prototype = vae.decoder(latents_mean_class.to(device)) # generate images sampled from Normal(class mean, class std) class_images = vae.decoder(class_z_sample.to(device)) show_image(class_image_prototype[0].cpu(), title=f'Class {class_id} prototype image') show_images_grid(class_images.cpu(), title=f'Class {class_id} images')
औसत वर्ग μ के साथ सामान्य वितरण से नमूनाकरण उसी वर्ग से नए डेटा की पीढ़ी की गारंटी देता है।
def linear_interpolation(start, end, steps): # Create a linear path from start to end z = torch.linspace(0, 1, steps)[:, None].to(device) * (end - start) + start # Decode the samples along the path vae.eval() with torch.no_grad(): samples = vae.decoder(z) return samples
start = torch.randn(1, z_dim).to(device) end = torch.randn(1, z_dim).to(device) interpolated_samples = linear_interpolation(start, end, steps = 24) show_images_grid(interpolated_samples, title=f'Linear interpolation between two random latent vectors')
for start_class_id in range(1,10): start = classes_mean[start_class_id][0].to(device) for end_class_id in range(1, 10): if end_class_id == start_class_id: continue end = classes_mean[end_class_id][0].to(device) interpolated_samples = linear_interpolation(start, end, steps = 20) show_images_grid(interpolated_samples, title=f'Linear interpolation between classes {start_class_id} and {end_class_id}')
अव्यक्त वेक्टर का प्रत्येक आयाम एक सामान्य वितरण का प्रतिनिधित्व करता है; आयाम के मानों की सीमा को आयाम के माध्य और विचरण द्वारा नियंत्रित किया जाता है। मूल्यों की सीमा को पार करने का एक सरल तरीका सामान्य वितरण के व्युत्क्रम सीडीएफ (संचयी वितरण फ़ंक्शन) का उपयोग करना होगा।
ICDF 0 और 1 के बीच मान लेता है (संभावना का प्रतिनिधित्व करता है) और वितरण से एक मान लौटाता है। किसी दी गई प्रायिकता p के लिए ICDF एक p_icdf मान इस प्रकार आउटपुट करता है कि एक यादृच्छिक चर के <= p_icdf होने की प्रायिकता दी गई प्रायिकता p के बराबर होती है?
यदि आपके पास सामान्य वितरण है, तो icdf(0.5) को वितरण का माध्य लौटाना चाहिए। icdf(0.95) को वितरण से 95% डेटा से बड़ा मान लौटाना चाहिए।
def latent_space_traversal(model, input_sample, norm_dist, dim_to_traverse, n_samples, latent_dim, device): # Create a range of values to traverse assert input_sample.shape[0] == 1, 'input sample shape should be [1, latent_dim]' # Generate linearly spaced percentiles between 0.05 and 0.95 percentiles = torch.linspace(0.1, 0.9, n_samples) # Get the quantile values corresponding to the percentiles traversed_values = norm_dist.icdf(percentiles[:, None].repeat(1, z_dim)) # Initialize a latent space vector with zeros z = input_sample.repeat(n_samples, 1) # Assign the traversed values to the specified dimension z[:, dim_to_traverse] = traversed_values[:, dim_to_traverse] # Decode the latent vectors with torch.no_grad(): samples = model.decoder(z.to(device)) return samples
for class_id in range(0,10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') for i in range(z_dim): interpolated_samples = latent_space_traversal(vae, mu, norm_dist=Normal(mu, torch.ones_like(mu)), dim_to_traverse=i, n_samples=20, latent_dim=z_dim, device=device) show_images_grid(interpolated_samples, title=f'Class {class_id} dim={i} traversal')
किसी एकल आयाम को पार करने से अंक शैली या नियंत्रण अंक अभिविन्यास में परिवर्तन हो सकता है।
def traverse_two_latent_dimensions(model, input_sample, z_dist, n_samples=25, z_dim=16, dim_1=0, dim_2=1, title='plot'): digit_size=28 percentiles = torch.linspace(0.10, 0.9, n_samples) grid_x = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) grid_y = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) figure = np.zeros((digit_size * n_samples, digit_size * n_samples)) z_sample_def = input_sample.clone().detach() # select two dimensions to vary (dim_1 and dim_2) and keep the rest fixed for yi in range(n_samples): for xi in range(n_samples): with torch.no_grad(): z_sample = z_sample_def.clone().detach() z_sample[:, dim_1] = grid_x[xi, dim_1] z_sample[:, dim_2] = grid_y[yi, dim_2] x_decoded = model.decoder(z_sample.to(device)).cpu() digit = x_decoded[0].reshape(digit_size, digit_size) figure[yi * digit_size: (yi + 1) * digit_size, xi * digit_size: (xi + 1) * digit_size] = digit.numpy() plt.figure(figsize=(6, 6)) plt.imshow(figure, cmap='Greys_r') plt.title(title) plt.show()
for class_id in range(10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') traverse_two_latent_dimensions(vae, mu, z_dist=Normal(mu, torch.ones_like(mu)), n_samples=8, z_dim=z_dim, dim_1=3, dim_2=6, title=f'Class {class_id} traversing dimensions {(3, 6)}')
एक साथ कई आयामों को पार करना उच्च परिवर्तनशीलता के साथ डेटा उत्पन्न करने का एक नियंत्रणीय तरीका प्रदान करता है।
यदि VAE मॉडल को z_dim =2 के साथ प्रशिक्षित किया जाता है, तो इसके अव्यक्त स्थान से अंकों का 2D मैनिफोल्ड प्रदर्शित करना संभव है। ऐसा करने के लिए, मैं dim_1 =0 और dim_2 =2 के साथ travers_two_latent_dimensions फ़ंक्शन का उपयोग करूंगा।
vae_2d = train_model(epochs=10, z_dim=2)
z_dist = Normal(torch.zeros(1, 2), torch.ones(1, 2)) input_sample = torch.zeros(1, 2) with torch.no_grad(): decoding = vae_2d.decoder(input_sample.to(device)) traverse_two_latent_dimensions(vae_2d, input_sample, z_dist, n_samples=20, dim_1=0, dim_2=1, z_dim=2, title=f'traversing 2D latent space')