As simple as possible in TensorFlow

When I first came across DeepMind’s paper “Learning to learn by gradient descent by gradient descent”, my reaction was “Wow, how on earth does that work?”. Unfortunately, my first read-through of the paper was not that illuminating, and looking at the code was quite daunting. Thankfully, I was intrigued enough to force myself to re-read the paper in depth, and it actually turned out to be surprisingly simple in the end. Personally, what really helps me when I’m trying to understand something is to create the simplest non-trivial version of the problem and then scale up from there. So here it is, the simplest version of the idea I could create, also with no maths equations! I hope you find it illuminating.

I suggest skimming the paper first, but this should all be understandable without it. The IPython Notebook version of this article, with syntax highlighting, can be found here. The code was run on TF v1.0.0-rc1.

So let’s get started. Time to learn about learning to learn by gradient descent by gradient descent by reading my article!

import tensorflow as tf

First of all we need a problem for our meta-learning optimizer to solve. Let’s take the simplest experiment from the paper: finding the minimum of a multi-dimensional quadratic function. We are going to randomly scale the parabolas and start at random locations; the solution is always at the origin.

DIMS = 10  # Dimensions of the parabola
scale = tf.random_uniform([DIMS], 0.5, 1.5)

# This represents the network/function we are trying to optimize,
# the `optimizee' as it's called in the paper.
# Actually, it's more accurate to think of this as the error
# landscape.
def f(x):
    x = scale*x
    return tf.reduce_sum(x*x)

We can’t easily use TensorFlow’s built-in optimizers here since the technique requires us to unroll the training loop inside the computation graph, as we’ll see in a bit. So let’s define a couple of simple hand-crafted optimizers to test against our learned optimizer. As discussed in the paper, an optimizer is a function g that takes the gradient of a parameter at a given step and returns back the step you should take in parameter space for that parameter. Here’s vanilla gradient descent (some optimizers need to keep track of state; here I just pass it through untouched):

def g_sgd(gradients, state, learning_rate=0.1):
    return -learning_rate*gradients, state

For a stronger baseline let’s use RMSProp:

def g_rms(gradients, state, learning_rate=0.1, decay_rate=0.99):
    if state is None:
        state = tf.zeros([DIMS])
    state = decay_rate*state + (1-decay_rate)*tf.pow(gradients, 2)
    update = -learning_rate*gradients / (tf.sqrt(state)+1e-5)
    return update, state

Great, now let’s unroll all the training steps. Here is a function learn which takes one of these optimizers and applies it in a loop for a number of steps, collecting the value of the function f at each point, which we can think of as our loss.

TRAINING_STEPS = 20  # This is 100 in the paper

initial_pos = tf.random_uniform([DIMS], -1., 1.)

def learn(optimizer):
    losses = []
    x = initial_pos
    state = None
    for _ in range(TRAINING_STEPS):
        loss = f(x)
        losses.append(loss)
        grads, = tf.gradients(loss, x)
        update, state = optimizer(grads, state)
        x += update
    return losses

OK, now let’s test it out.

sgd_losses = learn(g_sgd)
rms_losses = learn(g_rms)
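By the way, any update rule that fits this (gradients, state) -> (update, state) interface can be dropped into learn in exactly the same way. As an illustration, here is a sketch of what Adam would look like in this style; the g_adam name and the hyperparameters are my own choices, and nothing below depends on it:

# Adam in the same interface as g_sgd/g_rms (my own sketch).
# The state is a (step count, first moment, second moment) tuple.
def g_adam(gradients, state, learning_rate=0.1, beta1=0.9, beta2=0.999):
    if state is None:
        state = (tf.zeros([]), tf.zeros([DIMS]), tf.zeros([DIMS]))
    t, m, v = state
    t = t + 1
    m = beta1*m + (1-beta1)*gradients             # running mean of gradients
    v = beta2*v + (1-beta2)*tf.pow(gradients, 2)  # running mean of squared gradients
    m_hat = m / (1 - tf.pow(beta1, t))            # bias correction
    v_hat = v / (1 - tf.pow(beta2, t))
    update = -learning_rate*m_hat / (tf.sqrt(v_hat) + 1e-8)
    return update, (t, m, v)

It would slot straight in as adam_losses = learn(g_adam), but I’ll stick to the two baselines above for the plots.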
And see what the losses look like.

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

x = np.arange(TRAINING_STEPS)
for _ in range(3):
    sgd_l, rms_l = sess.run([sgd_losses, rms_losses])
    p1, = plt.plot(x, sgd_l, label='SGD')
    p2, = plt.plot(x, rms_l, label='RMS')
    plt.legend(handles=[p1, p2])
    plt.title('Losses')
    plt.show()

RMSProp outperforms vanilla gradient descent here, as expected. Note that nothing out of the ordinary has happened so far: I’ve simply hand-rolled my own optimizers and unrolled the entire training into a single computational graph, which is generally not recommended since you would very quickly run out of memory!

Meta-Learning

Time to put together our meta-learning optimizer. We are going to use the same architecture as in the paper: an LSTM with 2 layers and 20 hidden units.

LAYERS = 2
STATE_SIZE = 20

cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.LSTMCell(STATE_SIZE) for _ in range(LAYERS)])
cell = tf.contrib.rnn.InputProjectionWrapper(cell, STATE_SIZE)
cell = tf.contrib.rnn.OutputProjectionWrapper(cell, 1)
cell = tf.make_template('cell', cell)

def g_rnn(gradients, state):
    # Make a `batch' of single gradients to create a
    # "coordinate-wise" RNN as the paper describes.
    gradients = tf.expand_dims(gradients, axis=1)
    if state is None:
        state = [[tf.zeros([DIMS, STATE_SIZE])] * 2] * LAYERS
    update, state = cell(gradients, state)
    # Squeeze to make it a single batch again.
    return tf.squeeze(update, axis=[1]), state

And that’s it, that’s our meta-learner. We can use it in exactly the same way:

rnn_losses = learn(g_rnn)
sum_losses = tf.reduce_sum(rnn_losses)

Now here’s the magic bit: we want sum_losses to be low, since the lower the losses, the better the optimizer, right? And because the entire training loop is in the graph we can use Back-Propagation Through Time (BPTT) and a meta-optimizer to minimize this value!

And this is the main point: sum_losses is differentiable, so gradients flow through the graph we’ve defined just fine! TensorFlow is able to work out the gradients of the parameters in our LSTM with respect to this sum of losses. So let’s optimize:

def optimize(loss):
    optimizer = tf.train.AdamOptimizer(0.0001)
    gradients, v = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 1.)
    return optimizer.apply_gradients(zip(gradients, v))

apply_update = optimize(sum_losses)

I found gradient clipping to be critical here since the values that go into our meta-learner can be very large at the start of training.

sess.run(tf.global_variables_initializer())

ave = 0
for i in range(3000):
    err, _ = sess.run([sum_losses, apply_update])
    ave += err
    if i % 1000 == 0:
        print(ave / 1000 if i != 0 else ave)
        ave = 0
print(ave / 1000)

> 223.577606201
> 15.9170453466
> 4.06150362206
> 3.94412120444

And see how it does:

for _ in range(3):
    sgd_l, rms_l, rnn_l = sess.run([sgd_losses, rms_losses, rnn_losses])
    p1, = plt.plot(x, sgd_l, label='SGD')
    p2, = plt.plot(x, rms_l, label='RMS')
    p3, = plt.plot(x, rnn_l, label='RNN')
    plt.legend(handles=[p1, p2, p3])
    plt.title('Losses')
    plt.show()

Success! Looks like it’s doing even better than RMSProp on this problem.
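If you’d rather have numbers than pictures, you can also compare the final losses directly. This quick check is my own addition, not part of the original experiment, and the values will vary from run to run since every sess.run draws a fresh random parabola and starting point:

# Average the final-step loss of each optimizer over a few random
# restarts. Within a single sess.run all three share the same draw
# of scale and initial_pos, so the comparison is fair per restart.
finals = {'SGD': sgd_losses[-1], 'RMS': rms_losses[-1], 'RNN': rnn_losses[-1]}
totals = {name: 0.0 for name in finals}
restarts = 10
for _ in range(restarts):
    for name, value in sess.run(finals).items():
        totals[name] += value
for name in totals:
    print(name, totals[name] / restarts)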
Actually, these graphs are slightly misleading; a log scale shows something slightly different:

for _ in range(3):
    sgd_l, rms_l, rnn_l = sess.run([sgd_losses, rms_losses, rnn_losses])
    p1, = plt.semilogy(x, sgd_l, label='SGD')
    p2, = plt.semilogy(x, rms_l, label='RMS')
    p3, = plt.semilogy(x, rnn_l, label='RNN')
    plt.legend(handles=[p1, p2, p3])
    plt.title('Losses')
    plt.show()

I think the reason for this, again as discussed in the paper, is that the magnitude of the values being fed into the LSTM can vary wildly, and neural networks generally do not perform well when that happens. Here the gradients get so small that the LSTM isn’t able to compute sensible updates. The paper uses a solution to this for the bigger experiments: feed in the log of the gradient’s magnitude and its direction instead. See the paper for details (a rough sketch of this preprocessing appears at the end of this post).

Hopefully, now that you understand how to learn to learn by gradient descent by gradient descent, you can see the limitations. It doesn’t seem very scalable. I think it is quite telling that the experiments in the paper are very small. It took thousands of meta-optimization steps for even our toy problem to converge, and each of those steps required training the optimizee completely, just to give the meta-learner a single update. For a large problem we would have to do this many more times, and it would take a very long time. Also, unrolling the entire training loop in the graph is not feasible for larger problems, although in the paper they only unroll the BPTT to 20 steps. There is, though, also evidence in the paper that the RNN optimizer can generalize from smaller problems to larger ones.

I have left a lot of details out for simplicity, so reading the paper is worth it for sure. I expect meta-learning to become increasingly important; for more inspiration I suggest watching these NIPS presentations.
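As a footnote, here is roughly what that gradient preprocessing looks like. This is my own TensorFlow translation of the scheme described in the paper’s appendix (with their suggested p = 10), not code from the original experiments:

# Preprocess gradients into (log magnitude, direction) pairs, as the
# paper suggests for the larger experiments: (log|g|/p, sign(g)) when
# |g| >= exp(-p), and (-1, exp(p)*g) otherwise, so the LSTM's inputs
# stay in a sensible range however small the gradients get.
P = 10.0

def preprocess(gradients):
    abs_g = tf.maximum(tf.abs(gradients), 1e-20)  # guard against log(0)
    big = tf.abs(gradients) >= tf.exp(-P)
    log_part = tf.where(big, tf.log(abs_g) / P, -tf.ones_like(gradients))
    sign_part = tf.where(big, tf.sign(gradients), tf.exp(P) * gradients)
    return tf.stack([log_part, sign_part], axis=1)  # shape [DIMS, 2]

This gives the meta-learner two input features per coordinate instead of one, but since g_rnn already projects its input through an InputProjectionWrapper, you could swap this call in for the tf.expand_dims line without touching the rest of the network.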