The purpose of this post is to implement and understand Google DeepMind’s paper DRAW: A Recurrent Neural Network For Image Generation. The code is based on the work of Eric Jang, whose original implementation took only 158 lines of Python code.

Let’s begin by explaining what DRAW stands for. Deep Recurrent Attentive Writer (DRAW) is a neural network architecture for image generation. DRAW networks combine a novel spatial attention mechanism that mimics the foveation of the human eye with a sequential variational auto-encoding framework that allows for the iterative construction of complex images.

The system substantially improves on the state of the art for generative models on MNIST, and, when trained on the Street View House Numbers dataset, it generates images that cannot be distinguished from real data with the naked eye.

The core of the DRAW architecture is a pair of recurrent neural networks: an encoder network that compresses the real images presented during training, and a decoder that reconstitutes images after receiving codes. The combined system is trained end-to-end with stochastic gradient descent, where the loss function is a variational upper bound on the log-likelihood of the data.

DRAW Architecture

The DRAW network is similar to other variational auto-encoders: it contains an encoder network that determines a distribution over latent codes that capture salient information about the input data, and a decoder network that receives samples from the code distribution and uses them to condition its own distribution over images.

3 Key Differences Between DRAW and Auto-Encoders

1. In DRAW, both the encoder and the decoder are recurrent networks.
2. The decoder’s outputs are added successively to the distribution in order to generate the data, instead of generating this distribution in a single step.
3. A dynamically updated attention mechanism is used to restrict both the input region observed by the encoder and the output region modified by the decoder.

In simple terms, the network decides at each time-step “where to read” and “where to write” as well as “what to write”.

Figure: left, a conventional variational auto-encoder; right, the DRAW network.

Left: Conventional Variational Auto-Encoder. During generation, a sample $z$ is drawn from a prior $P(z)$ and passed through the feedforward decoder network to compute the probability of the input $P(x|z)$ given the sample. During inference the input $x$ is passed to the encoder network, producing an approximate posterior $Q(z|x)$ over latent variables. During training, $z$ is sampled from $Q(z|x)$ and then used to compute the total description length $KL\big(Q(Z|x)\,\|\,P(Z)\big) - \log P(x|z)$, which is minimized with stochastic gradient descent.

Right: DRAW Network. At each time-step a sample $z_t$ from the prior $P(z_t)$ is passed to the recurrent decoder network, which then modifies part of the canvas matrix. The final canvas matrix $c_T$ is used to compute $P(x|z_{1:T})$. During inference the input is read at every time-step and the result is passed to the encoder RNN. The RNNs at the previous time-step specify where to read. The output of the encoder RNN is used to compute the approximate posterior over the latent variables at that time-step.

Loss Function

The final canvas matrix $c_T$ is used to parametrize a model $D(X|c_T)$ of the input data. If the input is binary, the natural choice for $D$ is a Bernoulli distribution with means given by $\sigma(c_T)$. The reconstruction loss $\mathcal{L}^x$ is defined as the negative log probability of $x$ under $D$; for the Bernoulli model this is the binary cross-entropy summed over the pixels $i$ (the `generation_loss` term in the code):

$$\mathcal{L}^x = -\log D(x \mid c_T) = -\sum_i \Big[x_i \log \sigma(c_{T,i}) + (1 - x_i)\log\big(1 - \sigma(c_{T,i})\big)\Big]$$

The latent loss $\mathcal{L}^z$ for a sequence of latent distributions $Q(Z_t|h_t^{enc})$ is defined as the summed Kullback-Leibler divergence of some latent prior $P(Z_t)$ from $Q(Z_t|h_t^{enc})$:

$$\mathcal{L}^z = \sum_{t=1}^{T} KL\big(Q(Z_t|h_t^{enc})\,\|\,P(Z_t)\big)$$

Note that this loss depends upon the latent samples $z_t$ drawn from $Q(Z_t|h_t^{enc})$, which depend in turn on the input $x$. If the latent distribution is a diagonal Gaussian with mean $\mu_t$ and standard deviation $\sigma_t$, where

$$\mu_t = W(h_t^{enc}), \qquad \sigma_t = \exp\big(W(h_t^{enc})\big)$$

(with separate weights for the mean and the log standard deviation), a simple choice for $P(Z_t)$ is a standard Gaussian with mean zero and standard deviation one, in which case the equation becomes:

$$\mathcal{L}^z = \frac{1}{2}\sum_{t=1}^{T}\left(\mu_t^2 + \sigma_t^2 - \log\sigma_t^2 - 1\right)$$

where the squares and the logarithm act element-wise and each term is summed over the dimensions of $z_t$.

The total loss $\mathcal{L}$ for the network is the expectation of the sum of the reconstruction and latent losses:

$$\mathcal{L} = \big\langle \mathcal{L}^x + \mathcal{L}^z \big\rangle_{z \sim Q}$$

which we optimize using a single sample of $z$ for each stochastic gradient descent step.

$\mathcal{L}^z$ can be interpreted as the number of nats required to transmit the latent sample sequence $z_{1:T}$ to the decoder from the prior, and (if $x$ is discrete) $\mathcal{L}^x$ is the number of nats required for the decoder to reconstruct $x$ given $z_{1:T}$. The total loss is therefore equivalent to the expected compression of the data by the decoder and prior.

Improving Images

As Eric Jang mentions in his post, it’s easier to ask our neural network to merely “improve the image” rather than “finish the image in one shot”. Human artists work by iterating on their canvas, and infer from their drawing what to fix and what to paint next.

Improving an image, or progressive refinement, simply means breaking up our joint distribution $P(C)$ over and over again, resulting in a chain of latent variables $C_1, C_2, \dots, C_{T-1}$ leading to a new observed variable distribution $P(C_T)$.

The trick is to sample from the iterative refinement distribution $P(C_t|C_{t-1})$ several times rather than straight-up sampling from $P(C)$.

In the DRAW model, $P(C_t|C_{t-1})$ is the same distribution for all $t$, so we can compactly represent this as a recurrence relation, $C_t \sim P(C_t \mid C_{t-1})$ for $t = 1, \dots, T$ (if not, then we have a Markov chain instead of a recurrent network).

The DRAW model applied

Imagine you are trying to encode an image of the number 8. Every handwritten number is drawn differently: some portions may be thicker, others may be longer. Without attention, the encoder would be forced to try to capture all these small variations at the same time.

But what if the encoder could choose a small crop of the image on every frame and examine each portion of the number one at a time? That would make the job much easier, right?

The same logic applies to generating the number. The attention unit will determine where to draw the next portion of the number 8 (or any other digit), while the latent vector passed to the decoder will determine whether it generates a thicker or a thinner stroke.

Basically, if we think of the latent code in a VAE (variational auto-encoder) as a vector that represents the entire image, the latent codes in DRAW can be thought of as vectors that represent a pen stroke. Eventually, a sequence of these vectors creates a recreation of the original image.

Ok, but how does it really work?

In a recurrent VAE model, the encoder takes in the entire input image at every single timestep. In DRAW, we place an attention gate between the input image and the encoder, so the encoder only receives the portion of the image that the network deems important at that timestep. That first attention gate is called the “read” attention.

The “read” attention consists of two parts:

1. Choosing the important portion
2. Cropping the image and forgetting about the other parts

Choosing the important portion of an image

In order to determine which part of the image to focus on, we need some sort of observation on which to base the decision. In DRAW, we use the previous timestep’s decoder hidden state.
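To make this step concrete, here is a minimal NumPy sketch of how a decoder hidden state can be projected to the five attention parameters the DRAW paper uses (grid centre, filter variance, stride and intensity) and rescaled into pixel coordinates. The helper name `attention_params` is hypothetical, and the weight matrix `W` and bias `b` are assumptions standing in for a trained fully-connected layer.

```python
import numpy as np


def attention_params(h_dec, W, b, img_size=28, n=5):
    """Map decoder hidden states (batch, n_hidden) to DRAW attention parameters.

    Hypothetical helper: W (n_hidden, 5) and b (5,) stand in for the
    weights of a trained fully-connected layer.
    """
    # one linear layer produces five raw numbers per example, each (batch, 1)
    gx_, gy_, log_sigma2, log_delta, log_gamma = np.split(h_dec @ W + b, 5, axis=1)
    # grid centre: map raw values near [-1, 1] onto pixel coordinates
    gx = (img_size + 1) / 2 * (gx_ + 1)
    gy = (img_size + 1) / 2 * (gy_ + 1)
    # exponentials keep variance, stride and intensity strictly positive
    sigma2 = np.exp(log_sigma2)
    delta = (img_size - 1) / (n - 1) * np.exp(log_delta)
    gamma = np.exp(log_gamma)
    return gx, gy, sigma2, delta, gamma


# usage: an all-zero layer centres the 5 x 5 grid on the 28 x 28 image
rng = np.random.default_rng(0)
h = rng.normal(size=(2, 256))
gx, gy, sigma2, delta, gamma = attention_params(h, np.zeros((256, 5)), np.zeros(5))
print(gx.ravel(), delta.ravel())  # [14.5 14.5] [6.75 6.75]
```

With an all-zero layer the grid centre lands at 14.5 and the stride is 27/4 = 6.75, so the five filter centres, at $g_X + (i - N/2 - 0.5)\delta$ for $i = 1, \dots, 5$, sit at 1, 7.75, 14.5, 21.25 and 28 in the paper’s 1-based pixel coordinates: an untrained network starts with a coarse view of the whole image.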
Using a simple fully-connected layer, we can map the hidden state to three parameters that represent our square crop: center x, center y, and the scale. (In the full model implemented below, this layer outputs five numbers instead: the two grid-centre coordinates, the variance of the filters, the stride between them, and an intensity scale.)

Cropping the image

Now, instead of encoding the entire image, we crop it so only a small part of the image is encoded. This code is then passed through the system and decoded back into a small patch.

We now arrive at the second part of our attention gate, the “write” attention, which has the same setup as the “read” attention, except that the “write” attention gate uses the current decoder hidden state instead of the previous timestep’s.

Wait… is that really done in practice?

While describing the attention mechanism as a crop makes sense intuitively, in practice a different method is used. The model structure described above is still accurate, but a matrix of Gaussian filters is used instead of a crop.

In DRAW, we take an array of Gaussian filters, each with their centers spaced apart evenly. One bank of N filters runs along the x axis and another along the y axis; multiplying the image by both banks produces an N × N glimpse (N = 5 in the code below), where each glimpse pixel is a Gaussian-weighted average of a small neighbourhood of the image. The stride between the centers controls the zoom (how much of the image the grid covers), and the variance controls how blurry each glimpse pixel is.

Show me the money… or the code instead

We will use Eric Jang’s code as a base, but we will clean it up a bit and comment it to make it easier to understand. The listing assumes Python 3 and the TensorFlow 1.x graph API (for example TensorFlow 1.15); it will not run unmodified on TensorFlow 2, whose default API no longer provides `tf.placeholder` or `tf.Session`.

```python
# first we import our libraries
import os

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
```

Eric provides us with some great functions that will help us build our “read” and “write” attention gates, as well as the Gaussian filterbank they rely on. But first, we need to add a few new functions for our updated code: one that creates a dense layer, and two that merge the generated images into a grid and save them to our local machine.
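Before looking at the helpers, the Gaussian-filter idea described above can be checked in isolation. This is a minimal NumPy sketch under stated assumptions, not Eric Jang’s code: `filterbank` and `read_glimpse` are hypothetical names, pixel and filter indices start at 0 (so the paper’s centre offset $i - N/2 - 0.5$ becomes $i - (N-1)/2$), and the parameters are batched arrays of shape `(batch, 1)`. With a tiny variance every filter collapses onto a single pixel, so the glimpse reduces to an exact strided crop, which makes a handy sanity check.

```python
import numpy as np


def filterbank(gx, gy, sigma2, delta, n=5, img_size=28):
    """Horizontal (Fx) and vertical (Fy) Gaussian filterbanks, shape (batch, n, img_size).

    gx, gy, sigma2 and delta are arrays of shape (batch, 1).
    """
    offsets = np.arange(n, dtype=np.float64) - (n - 1) / 2  # e.g. [-2, -1, 0, 1, 2]
    # evenly spaced filter centres around the grid centre
    mu_x = (gx + offsets * delta)[:, :, None]                # (batch, n, 1)
    mu_y = (gy + offsets * delta)[:, :, None]
    pixels = np.arange(img_size, dtype=np.float64)[None, None, :]  # (1, 1, img_size)
    s2 = sigma2.reshape(-1, 1, 1)
    Fx = np.exp(-np.square(pixels - mu_x) / (2 * s2))
    Fy = np.exp(-np.square(pixels - mu_y) / (2 * s2))
    # normalise every filter so its weights sum to one
    Fx /= np.maximum(Fx.sum(axis=2, keepdims=True), 1e-8)
    Fy /= np.maximum(Fy.sum(axis=2, keepdims=True), 1e-8)
    return Fx, Fy


def read_glimpse(images, Fx, Fy, gamma):
    """N x N glimpse = gamma * Fy @ image @ Fx^T for a batch of square images."""
    return gamma.reshape(-1, 1, 1) * (Fy @ images @ Fx.transpose(0, 2, 1))


# usage: with a tiny variance each filter picks one pixel, giving a strided crop
img = np.arange(28 * 28, dtype=np.float64).reshape(1, 28, 28)
centre = np.array([[14.0]])
Fx, Fy = filterbank(centre, centre, sigma2=np.array([[1e-4]]), delta=np.array([[2.0]]))
glimpse = read_glimpse(img, Fx, Fy, gamma=np.array([[1.0]]))
print(np.allclose(glimpse[0], img[0, 10:19:2, 10:19:2]))  # True
```

The write step applies the same filters in the opposite direction: the decoder’s N × N patch $w$ is spread onto the canvas as $\frac{1}{\gamma} F_Y^\top w F_X$, which is what `write_attention` computes in the listing.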
Here are the three new helper functions:

```python
from PIL import Image  # Pillow replaces scipy.misc.toimage, which was removed in SciPy 1.2

# fully-connected layer
def dense(x, inputFeatures, outputFeatures, scope=None, with_w=False):
    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32,
                                 tf.random_normal_initializer(stddev=0.02))
        bias = tf.get_variable("bias", [outputFeatures],
                               initializer=tf.constant_initializer(0.0))
        if with_w:
            return tf.matmul(x, matrix) + bias, matrix, bias
        return tf.matmul(x, matrix) + bias

# merge a batch of images into one size[0] x size[1] grid image
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1]))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]  # integer division keeps the slice indices ints in Python 3
        img[j*h:j*h+h, i*w:i*w+w] = image
    return img

# save image on local machine
def ims(name, img):
    # map [0, 1] floats to 8-bit grey levels, like scipy.misc.toimage(img, cmin=0, cmax=1)
    Image.fromarray((np.clip(img, 0, 1) * 255 + 0.5).astype(np.uint8)).save(name)
```

Let’s now put all the code together for the sake of completeness.

```python
# DRAW implementation
class draw_model():
    def __init__(self):
        # First we download the MNIST dataset into our local machine.
        self.mnist = input_data.read_data_sets("data/", one_hot=True)
        print("------------------------------------")
        print("MNIST Dataset Successfully Imported")
        print("------------------------------------")
        self.n_samples = self.mnist.train.num_examples

        # We set up the model parameters
        # ------------------------------
        self.img_size = 28           # image width,height
        self.attention_n = 5         # read glimpse grid width/height
        self.n_hidden = 256          # number of hidden units / output size in LSTM
        self.n_z = 10                # QSampler output size
        self.sequence_length = 10    # MNIST generation sequence length
        self.batch_size = 64         # training minibatch size
        self.share_parameters = False  # workaround for variable_scope(reuse=True)

        # Build our model
        self.images = tf.placeholder(tf.float32, [None, 784])  # input (batch_size * img_size**2)
        self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)  # encoder Op
        self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)  # decoder Op

        # Define our state variables
        self.cs = [0] * self.sequence_length  # sequence of canvases
        self.mu = [0] * self.sequence_length
        self.logsigma = [0] * self.sequence_length
        self.sigma = [0] * self.sequence_length

        # Initial states
        h_dec_prev = tf.zeros((self.batch_size, self.n_hidden))
        enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32)
        dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32)

        # Construct the unrolled computational graph
        x = self.images
        for t in range(self.sequence_length):
            # error image + original image
            c_prev = tf.zeros((self.batch_size, self.img_size**2)) if t == 0 else self.cs[t - 1]
            x_hat = x - tf.sigmoid(c_prev)
            # read the image (swap in self.read_attention for the attentive read)
            r = self.read_basic(x, x_hat, h_dec_prev)
            print(r.get_shape())  # sanity check
            # encode to a Gaussian distribution
            self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(
                enc_state, tf.concat([r, h_dec_prev], 1))
            # sample from the distribution to get z
            z = self.sampleQ(self.mu[t], self.sigma[t])
            print(z.get_shape())  # sanity check
            # retrieve the hidden layer of RNN
            h_dec, dec_state = self.decode_layer(dec_state, z)
            print(h_dec.get_shape())  # sanity check
            # map from hidden layer to a canvas update (or use self.write_attention)
            self.cs[t] = c_prev + self.write_basic(h_dec)
            h_dec_prev = h_dec
            self.share_parameters = True  # from now on, share variables

        # Loss function: Bernoulli reconstruction loss on the final canvas
        self.generated_images = tf.nn.sigmoid(self.cs[-1])
        self.generation_loss = tf.reduce_mean(-tf.reduce_sum(
            self.images * tf.log(1e-10 + self.generated_images)
            + (1 - self.images) * tf.log(1e-10 + 1 - self.generated_images), 1))

        kl_terms = [0] * self.sequence_length
        for t in range(self.sequence_length):
            mu2 = tf.square(self.mu[t])
            sigma2 = tf.square(self.sigma[t])
            logsigma = self.logsigma[t]
            # KL(N(mu, sigma^2) || N(0, 1)) summed over the n_z latent dimensions;
            # each kl term is (1 x minibatch)
            kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2 * logsigma, 1) - self.n_z * 0.5
        self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms))
        self.cost = self.generation_loss + self.latent_loss

        # Optimization: Adam with per-variable gradient clipping
        optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5)
        grads = optimizer.compute_gradients(self.cost)
        for i, (g, v) in enumerate(grads):
            if g is not None:
                grads[i] = (tf.clip_by_norm(g, 5), v)
        self.train_op = optimizer.apply_gradients(grads)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    # Our training function
    def train(self):
        os.makedirs("results", exist_ok=True)  # ims() saves the canvases here
        for i in range(20000):
            xtrain, _ = self.mnist.train.next_batch(self.batch_size)
            cs, gen_loss, lat_loss, _ = self.sess.run(
                [self.cs, self.generation_loss, self.latent_loss, self.train_op],
                feed_dict={self.images: xtrain})
            print("iter %d genloss %f latloss %f" % (i, gen_loss, lat_loss))
            if i % 500 == 0:
                cs = 1.0 / (1.0 + np.exp(-np.array(cs)))  # x_recons=sigmoid(canvas)
                for cs_iter in range(self.sequence_length):
                    results = cs[cs_iter]
                    results_square = np.reshape(results, [-1, self.img_size, self.img_size])
                    print(results_square.shape)
                    ims("results/" + str(i) + "-step-" + str(cs_iter) + ".jpg",
                        merge(results_square, [8, 8]))

    # Eric Jang's main functions
    # --------------------------
    # locate where to put attention filters on hidden layers
    def attn_window(self, scope, h_dec):
        with tf.variable_scope(scope, reuse=self.share_parameters):
            parameters = dense(h_dec, self.n_hidden, 5)
        # center of 2d gaussian on a scale of -1 to 1
        gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(parameters, 5, 1)
        # move gx/gy from roughly [-1, 1] to pixel coordinates
        gx = (self.img_size + 1) / 2 * (gx_ + 1)
        gy = (self.img_size + 1) / 2 * (gy_ + 1)
        sigma2 = tf.exp(log_sigma2)
        # distance between patches, as in the paper: (A - 1) / (N - 1) * exp(log_delta)
        delta = (self.img_size - 1) / (self.attention_n - 1) * tf.exp(log_delta)
        # returns [Fx, Fy, gamma]
        return self.filterbank(gx, gy, sigma2, delta) + (tf.exp(log_gamma),)

    # Construct patches of gaussian filters
    def filterbank(self, gx, gy, sigma2, delta):
        # 1 x N, looks like [[0,1,2,3,4]]
        grid_i = tf.reshape(tf.cast(tf.range(self.attention_n), tf.float32), [1, -1])
        # individual patch centres (0-based form of the paper's i - N/2 - 0.5)
        mu_x = gx + (grid_i - (self.attention_n - 1) / 2) * delta
        mu_y = gy + (grid_i - (self.attention_n - 1) / 2) * delta
        mu_x = tf.reshape(mu_x, [-1, self.attention_n, 1])
        mu_y = tf.reshape(mu_y, [-1, self.attention_n, 1])
        # 1 x 1 x imgsize, looks like [[[0,1,2,3,4,...,27]]]
        im = tf.reshape(tf.cast(tf.range(self.img_size), tf.float32), [1, 1, -1])
        # list of gaussian curves for x and y: exp(-(a - mu)^2 / (2 sigma^2))
        sigma2 = tf.reshape(sigma2, [-1, 1, 1])
        Fx = tf.exp(-tf.square(im - mu_x) / (2 * sigma2))
        Fy = tf.exp(-tf.square(im - mu_y) / (2 * sigma2))
        # normalize area-under-curve
        Fx = Fx / tf.maximum(tf.reduce_sum(Fx, 2, keepdims=True), 1e-8)
        Fy = Fy / tf.maximum(tf.reduce_sum(Fy, 2, keepdims=True), 1e-8)
        return Fx, Fy

    # read operation without attention
    def read_basic(self, x, x_hat, h_dec_prev):
        return tf.concat([x, x_hat], 1)

    # read operation with attention
    def read_attention(self, x, x_hat, h_dec_prev):
        Fx, Fy, gamma = self.attn_window("read", h_dec_prev)

        # apply parameters for patch of gaussian filters
        def filter_img(img, Fx, Fy, gamma):
            Fxt = tf.transpose(Fx, perm=[0, 2, 1])
            img = tf.reshape(img, [-1, self.img_size, self.img_size])
            # apply the gaussian patches (tf.matmul batches over the first dimension)
            glimpse = tf.matmul(Fy, tf.matmul(img, Fxt))
            glimpse = tf.reshape(glimpse, [-1, self.attention_n**2])
            # scale using the gamma parameter
            return glimpse * tf.reshape(gamma, [-1, 1])

        x = filter_img(x, Fx, Fy, gamma)
        x_hat = filter_img(x_hat, Fx, Fy, gamma)
        return tf.concat([x, x_hat], 1)

    # encoder function for attention patch
    def encode(self, prev_state, image):
        # update the RNN with our image
        with tf.variable_scope("encoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_enc(image, prev_state)
        # map the RNN hidden state to latent variables
        with tf.variable_scope("mu", reuse=self.share_parameters):
            mu = dense(hidden_layer, self.n_hidden, self.n_z)
        with tf.variable_scope("sigma", reuse=self.share_parameters):
            logsigma = dense(hidden_layer, self.n_hidden, self.n_z)
            sigma = tf.exp(logsigma)
        return mu, logsigma, sigma, next_state

    # reparameterised sample z = mu + sigma * e, with fresh noise e ~ N(0, 1) at every timestep
    def sampleQ(self, mu, sigma):
        e = tf.random_normal((self.batch_size, self.n_z), mean=0, stddev=1)
        return mu + sigma * e

    # decoder function
    def decode_layer(self, prev_state, latent):
        # update decoder RNN using our latent variable
        with tf.variable_scope("decoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_dec(latent, prev_state)
        return hidden_layer, next_state

    # write operation without attention
    def write_basic(self, hidden_layer):
        # map RNN hidden state to image
        with tf.variable_scope("write", reuse=self.share_parameters):
            decoded_image_portion = dense(hidden_layer, self.n_hidden, self.img_size**2)
        return decoded_image_portion

    # write operation with attention
    def write_attention(self, hidden_layer):
        with tf.variable_scope("writeW", reuse=self.share_parameters):
            w = dense(hidden_layer, self.n_hidden, self.attention_n**2)
        w = tf.reshape(w, [self.batch_size, self.attention_n, self.attention_n])
        Fx, Fy, gamma = self.attn_window("write", hidden_layer)
        Fyt = tf.transpose(Fy, perm=[0, 2, 1])
        wr = tf.matmul(Fyt, tf.matmul(w, Fx))
        wr = tf.reshape(wr, [self.batch_size, self.img_size**2])
        return wr * tf.reshape(1.0 / gamma, [-1, 1])


model = draw_model()
model.train()
```

You can see the full notebook on my github page.

About the author: Samuel Noriega is a Master of Data Science graduate from the University of Barcelona. He is the head of Data Science at Shugert Analytics and city lead at Saturdays.ai. He recently co-founded Roomies.es.