When I first came across DeepMind’s paper “Learning to learn by gradient descent by gradient descent”, my reaction was “Wow, how on earth does that work?”. Unfortunately, my first read-through of the paper was not that illuminating, and looking at the code was quite daunting.
Thankfully, I was intrigued enough to force myself to re-read the paper in depth, and it actually turned out to be surprisingly simple in the end. Personally, what really helps me when I'm trying to understand something is to create the simplest non-trivial version of the problem and then scale up from there. So here it is: the simplest version of the idea I could create, with no maths equations! I hope you find it illuminating.
I suggest skimming the paper first, but this should all be understandable without it.
The IPython Notebook version of this article, with syntax highlighting, can be found here. The code was run on TF v1.0.0-rc1.
So let’s get started. Time to learn about learning to learn by gradient descent by gradient descent by reading my article!
```python
import tensorflow as tf
```
First of all, we need a problem for our meta-learning optimizer to solve. Let's take the simplest experiment from the paper: finding the minimum of a multi-dimensional quadratic function. We are going to randomly scale the parabolas and start at random locations; the minimum is always at the origin (every coordinate equal to zero).
```python
DIMS = 10  # Dimensions of the parabola
scale = tf.random_uniform([DIMS], 0.5, 1.5)

# This represents the network/function we are trying to optimize,
# the `optimizee' as it's called in the paper.
# Actually, it's more accurate to think of this as the error
# landscape.
def f(x):
    x = scale*x
    return tf.reduce_sum(x*x)
```
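Because the quadratic is so simple, the gradient that tf.gradients will compute for us has a closed form: the derivative of f with respect to x_i is 2·scale_i²·x_i, which is zero only at the origin. As a quick sanity check outside TensorFlow, here is a plain-Python sketch (the helper names grad_f and numerical_grad are hypothetical, not from the paper) that compares that formula with a central finite difference at a random point:

```python
import random

DIMS = 10


def f(x, scale):
    # Same error landscape as above: sum of (scale_i * x_i)^2.
    return sum((s * xi) ** 2 for s, xi in zip(scale, x))


def grad_f(x, scale):
    # Closed-form gradient: d/dx_i (s_i * x_i)^2 = 2 * s_i^2 * x_i.
    return [2 * s * s * xi for s, xi in zip(scale, x)]


def numerical_grad(x, scale, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    g = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        g.append((f(up, scale) - f(down, scale)) / (2 * eps))
    return g


random.seed(0)
scale = [random.uniform(0.5, 1.5) for _ in range(DIMS)]
x = [random.uniform(-1.0, 1.0) for _ in range(DIMS)]
analytic = grad_f(x, scale)
numeric = numerical_grad(x, scale)
# Prints a value close to zero: only floating-point rounding separates the two.
print(max(abs(a - n) for a, n in zip(analytic, numeric)))
```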
We can't easily use TensorFlow's built-in optimizers here, since the technique requires us to unroll the training loop inside the computation graph, as we'll see in a bit. So let's define a couple of simple hand-crafted optimizers to test against our learned optimizer. As discussed in the paper, an optimizer is a function g that takes the gradient of a parameter at a given step and returns the step you should take in parameter space for that parameter. Here's vanilla gradient descent (some optimizers need to keep track of state; here I just pass the state through):
```python
def g_sgd(gradients, state, learning_rate=0.1):
    return -learning_rate*gradients, state
```
For a stronger baseline let’s use RMSProp:
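```python
def g_rms(gradients, state, learning_rate=0.1, decay_rate=0.99):
    if state is None:
        state = tf.zeros(DIMS)
    # Exponentially decaying average of the squared gradients.
    state = decay_rate*state + (1-decay_rate)*tf.pow(gradients, 2)
    # Divide by the root of that average; 1e-5 avoids division by zero.
    update = -learning_rate*gradients / (tf.sqrt(state)+1e-5)
    return update, state
```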
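One consequence of starting the RMSProp state at zero is worth spelling out: after the first step the state is only (1 - decay_rate)·g² = 0.01·g², so with the defaults above the first update is roughly -0.1·g / (0.1·|g|), which is about -sign(g) whatever the size of the gradient. A plain-Python transcription of g_rms (a sketch working on lists instead of tensors; rms_step is a hypothetical name) makes this easy to check:

```python
import math


def rms_step(gradients, state, learning_rate=0.1, decay_rate=0.99):
    # Plain-Python mirror of g_rms, one coordinate at a time.
    if state is None:
        state = [0.0] * len(gradients)
    state = [decay_rate * s + (1 - decay_rate) * g * g
             for s, g in zip(state, gradients)]
    update = [-learning_rate * g / (math.sqrt(s) + 1e-5)
              for g, s in zip(gradients, state)]
    return update, state


first_update, state = rms_step([2.0, -0.5, 0.01], None)
print(first_update)  # roughly [-1.0, 1.0, -0.99]
```

The epsilon only matters once gradients get very small, which is why the third entry is slightly below 1 in magnitude.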
Great, now let's unroll all the training steps. Here, learn is a function that takes one of these optimizers, applies it in a loop for a fixed number of steps, and collects the value of the function f at each point, which we can think of as our loss.
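```python
TRAINING_STEPS = 20  # This is 100 in the paper

initial_pos = tf.random_uniform([DIMS], -1., 1.)

def learn(optimizer):
    losses = []
    x = initial_pos
    state = None
    # This is a Python loop, so every iteration adds new ops to the graph:
    # the whole optimization ends up unrolled inside a single graph.
    for _ in range(TRAINING_STEPS):
        loss = f(x)
        losses.append(loss)
        grads, = tf.gradients(loss, x)

        update, state = optimizer(grads, state)
        x += update
    return losses
```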
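Stripped of TensorFlow, learn is just an ordinary optimization loop. Below is a plain-Python analogue for illustration only: it assumes a fixed random scale and starting point, uses a hypothetical closed-form grad_f in place of tf.gradients, and actually executes each step, whereas the TensorFlow version only builds graph ops that run later inside sess.run.

```python
import random

DIMS = 10
TRAINING_STEPS = 20

random.seed(1)
scale = [random.uniform(0.5, 1.5) for _ in range(DIMS)]
initial_pos = [random.uniform(-1.0, 1.0) for _ in range(DIMS)]


def f(x):
    return sum((s * xi) ** 2 for s, xi in zip(scale, x))


def grad_f(x):
    # Closed-form gradient of f, standing in for tf.gradients.
    return [2 * s * s * xi for s, xi in zip(scale, x)]


def g_sgd(gradients, state, learning_rate=0.1):
    return [-learning_rate * g for g in gradients], state


def learn(optimizer):
    losses = []
    x = list(initial_pos)
    state = None
    for _ in range(TRAINING_STEPS):
        loss = f(x)
        losses.append(loss)
        update, state = optimizer(grad_f(x), state)
        x = [xi + u for xi, u in zip(x, update)]
    return losses


sgd_losses = learn(g_sgd)
print(sgd_losses[0], sgd_losses[-1])
```

With vanilla gradient descent each coordinate is multiplied by (1 - 0.2·scale_i²) at every step, a factor between 0.55 and 0.95 here, so the loss shrinks at every step.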
OK, now let’s test it out.
```python
sgd_losses = learn(g_sgd)
rms_losses = learn(g_rms)
```
And see what the losses look like.
```python
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

x = np.arange(TRAINING_STEPS)
for _ in range(3):
    # Each sess.run samples a new scale and starting point,
    # shared by both optimizers within that run.
    sgd_l, rms_l = sess.run([sgd_losses, rms_losses])
    p1, = plt.plot(x, sgd_l, label='SGD')
    p2, = plt.plot(x, rms_l, label='RMS')
    plt.legend(handles=[p1, p2])
    plt.title('Losses')
    plt.show()
```
RMSProp outperforms vanilla gradient descent here, as expected. Note that nothing out of the ordinary has happened so far: I've simply hand-rolled my own optimizers and unrolled the entire training into a single computational graph. That is generally not recommended, since for any real problem you would very quickly run out of memory!
Time to put together our meta-learning optimizer. We are going to use the same architecture as in the paper: an LSTM with 2 layers and 20 hidden units.
```python
LAYERS = 2
STATE_SIZE = 20

# A 2-layer LSTM with 20 units per layer, with linear projections on the
# input (up to STATE_SIZE) and on the output (down to a single number).
cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.LSTMCell(STATE_SIZE) for _ in range(LAYERS)])
cell = tf.contrib.rnn.InputProjectionWrapper(cell, STATE_SIZE)
cell = tf.contrib.rnn.OutputProjectionWrapper(cell, 1)
# make_template makes every call to the cell reuse the same variables.
cell = tf.make_template('cell', cell)

def g_rnn(gradients, state):
    # Make a `batch' of single gradients to create a
    # "coordinate-wise" RNN as the paper describes.
    gradients = tf.expand_dims(gradients, axis=1)

    if state is None:
        # One (c, h) pair of zero states per LSTM layer.
        state = [[tf.zeros([DIMS, STATE_SIZE])] * 2] * LAYERS
    update, state = cell(gradients, state)
    # Squeeze to make it a single batch again.
    return tf.squeeze(update, axis=[1]), state
```
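The reshaping in g_rnn is what makes the optimizer coordinate-wise: the [DIMS] gradient vector becomes a batch of DIMS one-element inputs, so every coordinate is processed by the same LSTM weights while keeping its own hidden state. The plain-Python sketch below illustrates the same idea with a deliberately tiny, hypothetical stand-in for the LSTM (a two-parameter recurrent rule, not the paper's architecture): the number of learned parameters does not depend on how many coordinates are being optimized.

```python
class TinyCoordinateWiseOptimizer:
    """Hypothetical stand-in for the LSTM: one shared recurrent rule
    applied independently to every coordinate.

    For each coordinate i:  h_i <- decay * h_i + g_i ;  update_i = -step * h_i
    """

    def __init__(self, decay=0.5, step=0.1):
        # The only "learnable" parameters, shared by all coordinates.
        self.decay = decay
        self.step = step

    def __call__(self, gradients, state):
        if state is None:
            state = [0.0] * len(gradients)  # one hidden value per coordinate
        state = [self.decay * h + g for h, g in zip(state, gradients)]
        update = [-self.step * h for h in state]
        return update, state


opt = TinyCoordinateWiseOptimizer()
# The same two parameters handle a 3-dimensional and a 10-dimensional problem.
u3, s3 = opt([1.0, -2.0, 0.5], None)
u10, s10 = opt([1.0] * 10, None)
print(u3)                  # [-0.1, 0.2, -0.05]
print(len(u10), len(s10))  # 10 10
```

Because it follows the same (gradients, state) -> (update, state) interface as g_sgd and g_rms, a rule like this could be passed to learn; the LSTM version simply has many more shared parameters and a richer state per coordinate.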
And that's it, that's our meta-learner. We can use it in exactly the same way:
```python
rnn_losses = learn(g_rnn)
sum_losses = tf.reduce_sum(rnn_losses)
```
Now here's the magic bit: we want sum_losses to be low, since the lower the losses, the better the optimizer, right? And because the entire training loop is in the graph, we can use Back-Propagation Through Time (BPTT) and a meta-optimizer to minimize this value!
And this is the main point: sum_losses is differentiable, and gradients flow through the graph we've defined just fine! TensorFlow is able to work out the gradients of this sum of losses with respect to the parameters of our LSTM. So let's optimize:
```python
def optimize(loss):
    optimizer = tf.train.AdamOptimizer(0.0001)
    gradients, v = zip(*optimizer.compute_gradients(loss))
    # Clip the meta-gradients to a global norm of 1 before applying them.
    gradients, _ = tf.clip_by_global_norm(gradients, 1.)
    return optimizer.apply_gradients(zip(gradients, v))

apply_update = optimize(sum_losses)
```
I found gradient clipping to be critical here since the values that go into our meta-learner can be very large at the start of training.
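The whole trick can be seen in miniature without an LSTM. In the hypothetical sketch below, the "learned optimizer" is plain gradient descent whose only parameter is its learning rate. We unroll TRAINING_STEPS steps, sum the losses just as sum_losses does, differentiate that sum with respect to the learning rate, and update the learning rate by gradient descent with clipping. One substitution to flag: TensorFlow computes this meta-gradient with reverse-mode BPTT, whereas here, with a single parameter, it is simpler to carry the derivative dx/d(lr) forward through the loop (forward-mode differentiation); both give the same derivative. The fixed scales and starting point are illustrative assumptions.

```python
TRAINING_STEPS = 20
SCALE = [0.5, 1.0, 1.5]         # illustrative, fixed parabola scales
INITIAL_POS = [1.0, -1.0, 0.5]  # illustrative, fixed starting point


def sum_losses_and_grad(lr):
    """Unroll gradient descent with learning rate `lr` on sum((s * x)^2).

    Returns (sum of losses over the unrolled steps, d(sum of losses)/d(lr)).
    """
    x = list(INITIAL_POS)
    dx = [0.0] * len(x)  # dx_i/d(lr), carried forward through the unrolled loop
    total, dtotal = 0.0, 0.0
    for _ in range(TRAINING_STEPS):
        # Coordinates are independent here, so they can be handled one at a time.
        for i, s in enumerate(SCALE):
            total += (s * x[i]) ** 2            # this coordinate's share of the loss
            dtotal += 2 * s * s * x[i] * dx[i]  # its derivative w.r.t. lr
            g = 2 * s * s * x[i]                # gradient of the optimizee
            # Differentiate x_new = x - lr * g(x) w.r.t. lr; g depends on x too.
            dx[i] = dx[i] * (1 - 2 * lr * s * s) - g
            x[i] -= lr * g
    return total, dtotal


def meta_train(lr=0.1, meta_lr=0.01, meta_steps=200):
    """Gradient descent on the learning rate itself (the meta-optimizer)."""
    for _ in range(meta_steps):
        _, grad = sum_losses_and_grad(lr)
        # For a single parameter, clipping by global norm 1 is clipping to [-1, 1].
        grad = max(-1.0, min(1.0, grad))
        lr -= meta_lr * grad
    return lr


learned_lr = meta_train()
print(learned_lr)
print(sum_losses_and_grad(0.1)[0], sum_losses_and_grad(learned_lr)[0])
```

With these settings the learned rate should settle well above its starting value of 0.1 but below about 1/1.5² ≈ 0.44, beyond which the steepest coordinate diverges, and the summed loss drops accordingly. Without the clipping, the very large meta-gradients at the start would throw the learning rate straight into that divergent region.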
```python
sess.run(tf.global_variables_initializer())

ave = 0
for i in range(3000):
    err, _ = sess.run([sum_losses, apply_update])
    ave += err
    # Print the average summed loss every 1000 meta-steps.
    if i % 1000 == 0:
        print(ave / 1000 if i!=0 else ave)
        ave = 0
print(ave / 1000)
```
```
> 223.577606201
> 15.9170453466
> 4.06150362206
> 3.94412120444
```
And see how it does:
```python
for _ in range(3):
    sgd_l, rms_l, rnn_l = sess.run([sgd_losses, rms_losses, rnn_losses])
    p1, = plt.plot(x, sgd_l, label='SGD')
    p2, = plt.plot(x, rms_l, label='RMS')
    p3, = plt.plot(x, rnn_l, label='RNN')
    plt.legend(handles=[p1, p2, p3])
    plt.title('Losses')
    plt.show()
```
Success! It looks like it's doing even better than RMSProp on this problem. Actually, these graphs are slightly misleading; a log scale shows something slightly different:
```python
for _ in range(3):
    sgd_l, rms_l, rnn_l = sess.run([sgd_losses, rms_losses, rnn_losses])
    p1, = plt.semilogy(x, sgd_l, label='SGD')
    p2, = plt.semilogy(x, rms_l, label='RMS')
    p3, = plt.semilogy(x, rnn_l, label='RNN')
    plt.legend(handles=[p1, p2, p3])
    plt.title('Losses')
    plt.show()
```
I think the reason for this, again as discussed in the paper, is that the magnitude of the values being fed into the LSTM can vary wildly, and neural networks generally do not perform well when that happens. Here, the gradients get so small that the LSTM isn't able to compute sensible updates. The paper uses a solution to this for the bigger experiments: feed in the log of the gradient's magnitude and its sign (direction) instead. See the paper for details.
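For reference, here is my reading of that preprocessing from the paper's appendix, as a plain-Python sketch (the function name is hypothetical; check the paper for the exact details). Each gradient coordinate g becomes a pair of inputs using a constant p (the paper uses p = 10): if |g| ≥ e^(-p), the pair is (log|g| / p, sign(g)); otherwise it is (-1, e^p · g).

```python
import math


def preprocess_gradient(g, p=10.0):
    """Map one gradient coordinate to a 2-element LSTM input.

    Large enough gradients: (log-magnitude scaled by 1/p, sign).
    Very small gradients (|g| < e^-p): (-1, g scaled up by e^p).
    """
    if abs(g) >= math.exp(-p):
        return (math.log(abs(g)) / p, math.copysign(1.0, g))
    return (-1.0, math.exp(p) * g)


for g in [1e3, 1.0, -1e-3, 1e-7, 0.0]:
    print(g, preprocess_gradient(g))
```

Whichever branch applies, the inputs stay in a modest range: the first component is at least -1 and grows only logarithmically, and in the small-gradient branch the second component lies strictly between -1 and 1. In g_rnn this would mean feeding the cell a [DIMS, 2] input instead of [DIMS, 1]; the InputProjectionWrapper projects whatever input size it receives up to STATE_SIZE.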
Hopefully, now that you understand how learning to learn by gradient descent by gradient descent works, you can see its limitations. It doesn't seem very scalable. I think it is quite telling that the experiments in the paper are very small. It takes thousands of steps for even our toy problem to converge, and for every single step of the meta-optimizer we had to run a complete optimization of the optimizee. For a large problem, we would have to run that whole optimization many more times, which would take a very long time. Unrolling the entire training loop in the graph is also not feasible for larger problems, although the paper only unrolls BPTT for 20 steps. There is also evidence in the paper that the RNN optimizer can generalize from smaller problems to larger ones.
I have left a lot of details out for simplicity, so reading the paper is definitely worth it. I expect meta-learning to become increasingly important; for more inspiration, I suggest watching these NIPS presentations.