ML Research Student at NUS • rish-16.github.io
I’ve recently been reading up on Generative Adversarial Networks (GANs), and I knew I had to come up with an interesting way to put them to use. It was then that I noticed all the buzz about the Quick, Draw! dataset, published by Google for open-source tinkering and fooling around. I know, I’m slow…
GAN stands for Generative Adversarial Network. It is essentially a zero-sum game between a cop and a robber: the robber tries to print fake bank notes that fully replicate the real ones, while the cop tries to tell the real notes from the fakes. As both get better, the fakes become harder and harder to spot.
You might have guessed it, but this ML model consists of two major parts: a Generator and a Discriminator.
GANs learn a mapping from random noise to the distribution of the training data, forming internal representations of the dataset’s features along the way.
Alongside Variational Autoencoders (VAEs), GANs are one of the go-to ML models for image generation tasks. They are notoriously difficult to train and require lots of data and compute power, especially when built from convolutional layers. Even so, a notable number of the papers accepted at NIPS last December experimented with or applied GANs and similar generative models, a sign of their worth and capability.
The Generator is the one spitting out images of new unseen cat faces that are not present in the original dataset. Essentially, it is a feed-forward neural network that transforms random noise into images of a certain size during training.
The generative network learns to map from a latent space to a particular data distribution of interest, while the discriminative network discriminates between instances from the true data distribution and fake instances produced by the generator.
The generator takes in a random noise vector of size 100 and transforms it into a vector of size 784, which can be reshaped into a 28 by 28 image for visualisation purposes. This is the same size as MNIST images. For the Leaky ReLU activation functions, we used alpha values of 0.3 and 0.2.
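To make the shapes concrete, here is a minimal framework-free sketch of a fully connected generator forward pass in NumPy. It is not the code from my repo: the hidden layer sizes (256 and 512), the weight initialisation, the tanh output and the assignment of alpha 0.3 to the first hidden layer and 0.2 to the second are all assumptions made for illustration.

```python
import numpy as np

def leaky_relu(x, alpha):
    """Leaky ReLU: passes positive values through, scales negatives by alpha."""
    return np.where(x > 0, x, alpha * x)

def make_generator(rng, latent_dim=100, hidden=(256, 512), out_dim=784):
    """Randomly initialise dense layers 100 -> 256 -> 512 -> 784 (hypothetical sizes)."""
    sizes = [latent_dim, *hidden, out_dim]
    return [(rng.normal(0.0, 0.02, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def generate(params, z, alphas=(0.3, 0.2)):
    """Forward pass: noise of shape (batch, 100) -> flat images of shape (batch, 784)."""
    h = z
    # One alpha per hidden layer (assumed mapping of the 0.3 and 0.2 values).
    for (W, b), alpha in zip(params[:-1], alphas):
        h = leaky_relu(h @ W + b, alpha)
    W, b = params[-1]
    return np.tanh(h @ W + b)  # assumed tanh output, so pixels lie in [-1, 1]

rng = np.random.default_rng(0)
generator = make_generator(rng)
z = rng.normal(size=(16, 100))      # a batch of 16 noise vectors
flat = generate(generator, z)       # shape (16, 784)
images = flat.reshape(-1, 28, 28)   # shape (16, 28, 28), ready for plotting
```

Before training, each entry of `images` is just structured noise; training is what pushes these outputs towards the look of real cat doodles.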
This part of the model consists of a feed-forward network that takes the output of the generator as input and produces a sigmoid probability between 0 and 1, estimating whether the given instance is real or fake.
Given an array of size 784 from the generator, the input passes through the discriminator network, where a final sigmoid layer calculates the probability that the input is real rather than fake.
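A matching sketch of the discriminator’s forward pass, again in plain NumPy rather than whatever framework the repo uses. The hidden sizes (512 and 256), the alpha of 0.2 and the initialisation are assumptions, and Dropout is left out to keep the example short.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """Leaky ReLU with an assumed alpha of 0.2."""
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    """Squash scores into probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def make_discriminator(rng, in_dim=784, hidden=(512, 256)):
    """Randomly initialise dense layers 784 -> 512 -> 256 -> 1 (hypothetical sizes)."""
    sizes = [in_dim, *hidden, 1]
    return [(rng.normal(0.0, 0.02, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def discriminate(params, x, alpha=0.2):
    """Return P(real) for each flat 784-pixel image in the batch, shape (batch, 1)."""
    h = x
    for W, b in params[:-1]:
        h = leaky_relu(h @ W + b, alpha)
    W, b = params[-1]
    return sigmoid(h @ W + b)

rng = np.random.default_rng(1)
discriminator = make_discriminator(rng)
batch = rng.uniform(-1.0, 1.0, size=(8, 784))  # stand-in for 8 generated images
p_real = discriminate(discriminator, batch)    # shape (8, 1)
```

An untrained discriminator with small weights outputs probabilities close to 0.5 for everything; training sharpens these towards 1 for real doodles and 0 for fakes.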
We use Dropout to make learning more uniform across the network during training. Dropout randomly drops (zeroes out) a proportion of the neurones in a layer at every training step, which is where the name comes from. Because no neurone can count on particular others always being present, the neurones cannot become interdependent or co-adapt into overly complex functions that merely memorise the mapping between inputs and outputs. This helps prevent overfitting.
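The mechanism can be sketched in a few lines of NumPy. This is the common "inverted dropout" variant, which scales the surviving activations by 1 / (1 - rate) so that their expected value is unchanged; the rate of 0.3 below is an arbitrary choice for illustration, not necessarily the value used in my model.

```python
import numpy as np

def dropout(x, rate, rng, training=True):
    """Inverted dropout: zero each unit with probability `rate` during training
    and scale the survivors by 1 / (1 - rate). At inference time the input
    passes through untouched."""
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate  # True for units that survive this step
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(42)
acts = np.ones((1000, 256))                  # a batch of dummy activations
dropped = dropout(acts, rate=0.3, rng=rng)   # roughly 30% of entries become 0
inference = dropout(acts, rate=0.3, rng=rng, training=False)
```

Frameworks such as Keras apply the same idea inside their Dropout layers and switch it off automatically outside training.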
Also, as in the generator, we use Leaky ReLU (Leaky Rectified Linear Unit) activations here to keep the gradients from vanishing in the deep network. A standard ReLU outputs zero, and has a zero gradient, for every negative input; a neurone stuck in that region stops learning altogether, which can render training ineffective.
Leaky ReLU avoids this by giving negative inputs a small positive slope (the alpha value). Its graph therefore dips slightly downward into the third quadrant instead of lying flat at zero, and the gradient never becomes exactly zero when the inputs to the function are negative.
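The difference is easiest to see in the derivatives. This NumPy sketch compares the gradient of a standard ReLU with that of a Leaky ReLU (alpha = 0.2, an assumed value); at exactly x = 0 both functions are treated as having the negative-side slope, which is one common convention.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """x for positive inputs, alpha * x for the rest."""
    return np.where(x > 0, x, alpha * x)

def leaky_relu_grad(x, alpha=0.2):
    """Derivative of Leaky ReLU: 1 for positive inputs, alpha otherwise."""
    return np.where(x > 0, 1.0, alpha)

def relu_grad(x):
    """Derivative of standard ReLU: 1 for positive inputs, 0 otherwise."""
    return np.where(x > 0, 1.0, 0.0)

x = np.array([-3.0, -0.5, 0.5, 3.0])
outputs = leaky_relu(x)           # negative inputs are scaled down, not zeroed
plain_grads = relu_grad(x)        # zero for the two negative inputs
leaky_grads = leaky_relu_grad(x)  # 0.2 for the two negative inputs
```

With a plain ReLU, the two negative inputs contribute no gradient at all, whereas the leaky version still passes a scaled-down gradient back to the weights.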
The loss function has two parts: the discriminator loss J(D) and the generator loss J(G). Because a GAN is a zero-sum (min-max) game, the generator’s loss is defined as the negative of the discriminator’s, J(G) = -J(D), so the two always sum to zero. Both are built on the same log-probability (cross-entropy) formula that binary classifiers use. Log-probability loss is also called log loss.
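Written out, the standard zero-sum formulation from the GAN literature looks like this. The notation is an assumption on my part rather than a copy of any figure: D(x) is the discriminator’s probability that x is real, G(z) is the generator’s output for a noise vector z, p_data is the real data distribution and p_z is the noise distribution.

```latex
\begin{aligned}
J^{(D)} &= -\tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
           - \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right] \\
J^{(G)} &= -J^{(D)}
\end{aligned}
```

Each expectation is a log loss: the first treats real images as label 1, the second treats generated images as label 0. Only the second term depends on G, so minimising J(G) amounts to making log(1 - D(G(z))) as small as possible, i.e. fooling the discriminator.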
Coming back to the cop-robber analogy, the robber aims to minimise the likelihood of the cop predicting that the printed note is a counterfeit.
Similarly, the generator aims to minimise the log-probability of the discriminator correctly predicting that the generated data is fake. To put it another way, the generator tries to make the discriminator wrong by improving the quality of the generated images over the course of training.
When the generator produces an image, it is evaluated by the discriminator, which is trained on both real images from the dataset and images from the generator. Since we know which images are real and which were generated, we get the real/fake labels for free, and the discriminator loss can be calculated from them.
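Here is a NumPy sketch of how those labels turn into loss values. The function names are hypothetical. Besides the zero-sum generator loss, it includes the "non-saturating" variant (label the generated images as real and minimise the log loss), a widely used practical alternative that gives the generator stronger gradients early in training; I am not claiming this is what my repo uses.

```python
import numpy as np

EPS = 1e-12  # keeps log() away from log(0)

def binary_cross_entropy(probs, labels):
    """Mean log loss: -[y * log(p) + (1 - y) * log(1 - p)]."""
    probs = np.clip(probs, EPS, 1.0 - EPS)
    return float(-np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))

def discriminator_loss(d_real, d_fake):
    """J(D): real images are labelled 1, generated images 0."""
    real_loss = binary_cross_entropy(d_real, np.ones_like(d_real))
    fake_loss = binary_cross_entropy(d_fake, np.zeros_like(d_fake))
    return 0.5 * (real_loss + fake_loss)

def generator_loss_minimax(d_fake):
    """J(G) = -J(D), dropping the real-image term because it does not depend on G."""
    return -0.5 * binary_cross_entropy(d_fake, np.zeros_like(d_fake))

def generator_loss_non_saturating(d_fake):
    """Practical variant: score the generated images against the label 1 ("real")."""
    return binary_cross_entropy(d_fake, np.ones_like(d_fake))

d_real = np.array([0.9, 0.8])  # D's outputs on two real images
d_fake = np.array([0.1, 0.2])  # D's outputs on two generated images
print(discriminator_loss(d_real, d_fake))  # roughly 0.164: D is doing well
```

When the discriminator confidently rejects a fake (an output near 0), the non-saturating loss is large and steep, which is exactly when the generator most needs a useful gradient.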
The script I’ve written features a convolutional and a de-convolutional block; the discriminator receives the generator’s outputs and predicts whether each one belongs to the original data distribution or is fake. Over time, the quality of the generated images improves, and it becomes ever harder for the discriminator to judge an image’s authenticity.
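Training typically alternates two steps: update the discriminator on a mix of real and generated images, then update the generator through the (frozen) discriminator. The helpers below are hypothetical and only prepare the inputs and labels for each step; the actual weight updates would be handled by whichever deep learning framework the model is built in. Labelling the generator’s noise batch as 1 corresponds to the non-saturating generator loss, an assumption about the setup rather than a detail taken from my repo.

```python
import numpy as np

def make_discriminator_batch(real, fake, rng):
    """Stack real images (label 1) and generated images (label 0), then shuffle."""
    x = np.concatenate([real, fake], axis=0)
    y = np.concatenate([np.ones(len(real)), np.zeros(len(fake))])
    order = rng.permutation(len(x))  # shuffle images and labels together
    return x[order], y[order]

def make_generator_batch(batch_size, latent_dim, rng):
    """Fresh noise labelled 1: the generator is rewarded when D calls its output real."""
    z = rng.normal(size=(batch_size, latent_dim))
    return z, np.ones(batch_size)

rng = np.random.default_rng(0)
real = np.ones((4, 784))    # stand-in for 4 real doodles
fake = -np.ones((6, 784))   # stand-in for 6 generated images
x, y = make_discriminator_batch(real, fake, rng)  # step 1: train D on (x, y)
z, g_labels = make_generator_batch(8, 100, rng)   # step 2: train G via D on (z, g_labels)
```

Shuffling keeps each batch from being all-real followed by all-fake, which some optimisers and batch-statistics layers handle poorly.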
As the code is too long for a full tutorial, I’ve put it in my GitHub project repo. It contains the Cat Faces dataset that I used for this project along with the related source code.
There was no straightforward dataset that could be used off the shelf in this case. So, following the instructions given in the Quick, Draw! GitHub repo, I downloaded the numpy file for the cat faces from the cloud storage link provided in the README.md file.
I wrote a function that reads the numpy file cats.npy and creates a training and testing set from it. It is similar to how one handles the MNIST dataset.
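A loader along those lines might look like the sketch below. `load_quickdraw`, the 80/20 split and the scaling to [-1, 1] (to match a tanh generator output) are my assumptions, not necessarily what the repo does. It relies on each Quick, Draw! numpy bitmap file storing one flattened 28 by 28 grayscale drawing per row as 784 uint8 values.

```python
import os
import tempfile

import numpy as np

def load_quickdraw(path, test_fraction=0.2, seed=0):
    """Hypothetical loader: read a Quick, Draw! .npy bitmap file, scale pixels
    from [0, 255] to [-1, 1], shuffle, and split into training and testing sets."""
    data = np.load(path).astype(np.float32)
    data = data.reshape(len(data), 784)   # one flattened 28x28 drawing per row
    data = (data - 127.5) / 127.5         # 0 -> -1.0, 255 -> 1.0
    rng = np.random.default_rng(seed)
    data = data[rng.permutation(len(data))]
    n_test = int(len(data) * test_fraction)
    return data[n_test:], data[:n_test]   # (train, test)

# Usage with a small synthetic stand-in for cats.npy.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cats.npy")
    fake_pixels = np.random.default_rng(1).integers(0, 256, size=(100, 784), dtype=np.uint8)
    np.save(path, fake_pixels)
    x_train, x_test = load_quickdraw(path)
print(x_train.shape, x_test.shape)  # (80, 784) (20, 784)
```

As with MNIST, `x_train[i].reshape(28, 28)` recovers a single drawing for display.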
The GAN was able to generalise to the features in the training set and transform the noise into a legible image (of sorts).
We can see the faint outline of a cat face by the 100th and 200th epochs, which shows that the generator is starting to learn a mapping between the random noise and the features of a cat’s face. Getting near-perfect results from the generator would, however, take considerably longer training.
We can see that there is some similarity between the generator outputs and original dataset.
The results were encouraging and showed that a GAN can create images from scratch after training for several hours on a standard-issue MacBook Pro CPU.
This journey of building a GAN was amazing! I learnt a lot more about GANs in general and what makes them so versatile as creative tools. Projects like these help you discover what makes complex architectures work under the hood.
Original article by Rishabh Anand