Shakespeare Meets Google's Flax

By Fabian Deuser (@fabian)

Some are born great, some achieve greatness, and some have greatness thrust upon them.

William Shakespeare, Twelfth Night, or What You Will

Google researchers introduced Flax, a rising star in machine learning, a few months ago. A lot has happened since then and the pre-release has improved tremendously. My own experiments with CNNs on Flax are bearing fruit and I am still amazed by its flexibility compared to TensorFlow. Today I will show you an application of RNNs in Flax: a character-level language model.

In many learning tasks, we do not have to consider temporal dependencies on the previous inputs. 

But what can we do if we do not have independent, fixed-size input and output vectors? What if we have sequences of vectors? The solution is the recurrent neural network (RNN). RNNs allow us to operate on sequences of vectors, as described below.

Recurrent Neural Network

In the above picture you can see different types of input and output architectures:

  • one-to-one is our typical CNN or multilayer perceptron: one input vector is mapped to one output vector.
  • one-to-many is a good RNN architecture for image captioning. The input is our image and the output is a sequence of words that describes it.
  • many-to-many:
    The first architecture maps an input sequence to an output sequence, as in machine translation, e.g. German to English. The second is good for video captioning at frame level.

The main advantage of RNNs is that they rely not only on the current input but also on the previous inputs.

An RNN is a cell with an internal hidden state h, initialized with zeros, whose dimension is given by the hidden size. In each timestep t we feed the input x_t into our RNN cell and also update the hidden state. In the next timestep t+1 the hidden state is not reset to zeros but carried over from the previous timestep. RNNs can therefore keep information across several timesteps and generate sequences.
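
To make the carried-over hidden state concrete, here is a minimal sketch of a vanilla RNN cell in plain Python. It assumes the common textbook update h_t = tanh(W_xh x_t + W_hh h_{t-1} + b), which this article does not spell out, and the weights are made-up toy values rather than anything learned.

```python
import math

def rnn_step(x, h_prev, W_xh, W_hh, b):
    """One vanilla RNN step: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b)."""
    h_new = []
    for i in range(len(h_prev)):
        total = b[i]
        total += sum(W_xh[i][j] * x[j] for j in range(len(x)))
        total += sum(W_hh[i][k] * h_prev[k] for k in range(len(h_prev)))
        h_new.append(math.tanh(total))
    return h_new

# Hypothetical toy weights: input size 2, hidden size 2.
W_xh = [[0.5, -0.3], [0.1, 0.8]]
W_hh = [[0.2, 0.0], [0.0, 0.2]]
b = [0.0, 0.0]

h = [0.0, 0.0]  # the hidden state starts at zero
sequence = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
states = []
for x_t in sequence:
    # The state returned here is fed back in at the next timestep.
    h = rnn_step(x_t, h, W_xh, W_hh, b)
    states.append(h)
```

The first and third inputs are identical, yet the resulting hidden states differ, because the third step also sees what came before it.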

Character-Level Language Model

With our new knowledge we now want to build a first application for our RNN. The character-level language model is the foundation for many tasks, e.g. image captioning or text generation. The input to the RNN cells is huge chunks of text in the form of character sequences. The training task is to learn to predict the next character, given a sequence of previous characters. So we generate one character at each timestep t, and our previous characters are x_{t-1}, x_{t-2}, … .

As an example let's take the word FUZZY as our training sequence; the vocabulary is then {'f','u','z','y'}. Because the RNN only works with vectors, we convert every character to a so-called one-hot vector. A one-hot vector consists of zeros with a single one at the character's position in the vocabulary; for 'Z' the converted vector is [0,0,1,0]. In the following picture you can see an example for the given input "FUZZ", from which we want to predict the rest of the word, "UZZY". The hidden size of our neurons is four and we want the green numbers in the output layer to be high and the red ones low.
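
The encoding can be sketched in a few lines of plain Python. The helper name `one_hot` and the variable names are only illustrative; the real pipeline in the repository may build its vectors differently.

```python
vocab = ['f', 'u', 'z', 'y']  # vocabulary of the word "fuzzy"
char_to_index = {c: i for i, c in enumerate(vocab)}

def one_hot(char):
    """Return a vector of zeros with a one at the character's vocabulary index."""
    vec = [0] * len(vocab)
    vec[char_to_index[char]] = 1
    return vec

word = "fuzzy"
inputs = [one_hot(c) for c in word[:-1]]   # "fuzz"
targets = [one_hot(c) for c in word[1:]]   # "uzzy"
```

Each target is simply the input shifted by one character, which is exactly what the network has to learn to predict.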

If you are interested in the math behind RNNs please follow the link.

Finally, we are coding

Please note that I explained some basic concepts of Flax in the previous article about CNNs. As dataset we use tiny-shakespeare, which consists of dialogue like this:

Tis even so; yet you are Warwick still.
Come, Warwick, take the time; kneel down, kneel down: Nay, when? strike now, or else the iron cools.

I again used Google Colab for the training, so we have to install the necessary pip packages again:

pip install -q --upgrade`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
pip install -q git+

You should use a runtime with GPU support, because the training task is extremely demanding. You can check that the GPU is available with:

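from jax.lib import xla_bridge
# Print the platform of the active backend (CPU, GPU or TPU).
print(xla_bridge.get_backend().platform)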

Now we are ready to create our RNN from scratch:

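import flax
import jax
import jax.numpy as jnp
from flax import jax_utils, nn, optim

class RNN(flax.nn.Module):
    """Three stacked LSTM layers followed by a dense output layer."""
    def apply(self, carry, inputs):
        # Scan each LSTM cell over the time axis (axis=1); the output
        # sequence of one layer is the input sequence of the next.
        carry1, outputs = jax_utils.scan_in_dim(
            nn.LSTMCell.partial(name='lstm1'), carry[0], inputs, axis=1)
        carry2, outputs = jax_utils.scan_in_dim(
            nn.LSTMCell.partial(name='lstm2'), carry[1], outputs, axis=1)
        carry3, outputs = jax_utils.scan_in_dim(
            nn.LSTMCell.partial(name='lstm3'), carry[2], outputs, axis=1)
        # `params` is assumed to be a global dict of hyperparameters.
        x = nn.Dense(outputs, features=params['vocab_length'], name='dense')
        return [carry1, carry2, carry3], x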

In a real training situation like this we do not use vanilla RNN cells but LSTM cells. These are a further development that copes better with the vanishing gradient problem. To achieve a higher accuracy I use three stacked LSTM cells. It is very important that we pass the output of each cell on to the next and that we initialize each LSTM cell with its own hidden state. Otherwise we lose track of the temporal dependencies.

The output of the last LSTM cell is passed to our dense layer. The dense layer has the size of our vocabulary. In our previous example with 'FUZZY' the number of neurons would be four. If 'FUZZ' is fed into our RNN, the neurons should ideally produce an output like [1.7,0.1,-1.0,3.1], because this output indicates 'Y' as the most probable character.
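
To see why that output points to 'Y', the raw scores can be turned into probabilities with a softmax. This is a small plain-Python sketch, not the code the model runs, and it assumes the vocabulary order f, u, z, y from the earlier example.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to one."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ['f', 'u', 'z', 'y']
logits = [1.7, 0.1, -1.0, 3.1]
probs = softmax(logits)
predicted = vocab[probs.index(max(probs))]
```

The largest score, 3.1, ends up with roughly three quarters of the probability mass, so the predicted character is 'y'.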

Because we have two different modes, we wrap our RNN in another module that handles both cases.

class charRNN(flax.nn.Module):
    """Char Generator"""
    def apply(self, inputs, carry_pred=None, train=True):
        batch_size = params['batch_size']
        vocab_size = params['vocab_length']
        hidden_size = 512
        if train:
            # Training: start every batch from a fresh state for each layer.
            carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
            carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
            carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),hidden_size)
            carry = [carry1, carry2, carry3]
            _, x = RNN(carry, inputs)
            return x
        else:
            # Prediction: continue from the state passed in by the caller.
            carry, x = RNN(carry_pred, inputs)
            return carry, x

These cases are:

  • Training-Mode, where we want to learn how to predict
  • Predict-Mode, where we actually sample some text

Before we can train our model we need to create it with the following function:

def create_model(rng):
    """Creates a model."""
    vocab_size = params['vocab_length']
    _, initial_params = charRNN.init_by_shape(
        rng, [((1, params['seq_length'], vocab_size), jnp.float32)])
    model = nn.Model(charRNN, initial_params)
    return model

Each of our sequences has a length of 50 chars and we have a vocabulary of 65 different characters.

As optimizer for our RNN I chose Adam with an initial learning rate of 0.002 and weight decay to keep the weights from growing too large.

def create_optimizer(model, learning_rate):
    """Creates an Adam optimizer for model."""
    optimizer_def = optim.Adam(learning_rate=learning_rate, weight_decay=1e-1)
    optimizer = optimizer_def.create(model)
    return optimizer

The Training Mode

In the training mode we feed a batch of 32 sequences into our RNN. Every sequence is taken from our dataset and contains two subsequences, one with the characters from 0 to 49 and one with the characters from 1 to 50. With this simple split, our network can learn the most likely next character. In each batch we initialize the hidden states and feed the sequences to our RNN.
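
The split into input and target subsequences can be sketched as follows. The function name `make_pairs` is hypothetical and the window length is kept tiny for readability; the article uses sequences of 50 characters, and the actual dataset code in the repository is written with TensorFlow.

```python
def make_pairs(text, seq_length):
    """Cut text into (input, target) pairs where the target is shifted by one."""
    pairs = []
    for start in range(0, len(text) - seq_length, seq_length):
        chunk = text[start:start + seq_length + 1]  # one extra character
        pairs.append((chunk[:-1], chunk[1:]))
    return pairs

text = "Tis even so; yet you are Warwick still."
pairs = make_pairs(text, seq_length=10)
first_input, first_target = pairs[0]
```

For every position in an input window, the character at the same position in the target window is the one that follows it in the text.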

def train_step(optimizer, batch):
    """Train one step."""
    def loss_fn(model):
        """Compute cross-entropy loss and predict logits of the current batch"""

        logits = model(batch[0])        
        loss = jnp.mean(cross_entropy_loss(logits, batch[1])) / params['batch_size']
        return loss, logits

    def exponential_decay(steps):
        """Decrease the learning rate every 5 epochs"""
        x_decay = (steps / params['step_decay']).astype('int32')
        ret = params['learning_rate']* jax.lax.pow((params['learning_rate_decay']), x_decay.astype('float32'))
        return jnp.asarray(ret, dtype=jnp.float32)

    current_step = optimizer.state.step
    new_lr = exponential_decay(current_step)
    # calculate and apply the gradient 
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grad = grad_fn(optimizer.target)
    new_optimizer = optimizer.apply_gradient(grad, learning_rate=new_lr)

    metrics = compute_metrics(logits, batch[1])
    metrics['learning_rate'] = new_lr
    return new_optimizer, metrics

Within our training step we have two subfunctions. loss_fn calculates the cross-entropy loss by comparing the output neurons, interpreted as a vector, with the desired one-hot vector. So again in our 'FUZZY' example we would have an output [1.7,0.1,-1.0,3.1] and a one-hot vector [0,0,0,1]. We now calculate the loss with this formula: L = -Σ_i y_i · log(softmax(z)_i), where z is the output vector and y the one-hot vector.

I had to rewrite the code from the CNN example a bit, because we now work with sequences rather than single classes:

def cross_entropy_loss(logits, labels):
    """Returns cross-entropy loss."""
    return -jnp.mean(jnp.sum(nn.log_softmax(logits) * labels))
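
As a worked example, the same computation can be written in plain Python for the 'FUZZY' case. This sketch mirrors the summation in cross_entropy_loss over a list of timesteps; the helper names are illustrative and it leaves out the batching of the Flax version.

```python
import math

def log_softmax(logits):
    """Log of the softmax probabilities, computed in a numerically stable way."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def sequence_cross_entropy(logits_seq, labels_seq):
    """Negative sum of the log-probabilities at the one-hot target positions."""
    total = 0.0
    for logits, labels in zip(logits_seq, labels_seq):
        log_probs = log_softmax(logits)
        total += sum(lp * y for lp, y in zip(log_probs, labels))
    return -total

# One timestep: the target 'y' gets the highest score, so the loss is small.
loss = sequence_cross_entropy([[1.7, 0.1, -1.0, 3.1]], [[0, 0, 0, 1]])
# Same output, but the target is 'z', which got the lowest score.
bad_loss = sequence_cross_entropy([[1.7, 0.1, -1.0, 3.1]], [[0, 0, 1, 0]])
```

The loss is low when the network puts most of its probability on the correct character and grows quickly when it does not.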

The other subfunction in the training step is exponential_decay. I use the Adam optimizer with an initial learning rate of 0.002, but every five epochs I decrease the learning rate to avoid overly strong oscillations. After every five epochs the initial learning rate is multiplied by the factor 0.97ˣ, where x is the number of times we have completed five epochs.
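
The schedule is easy to check by hand with a plain-Python version. The function name and the `steps_per_decay` value are placeholders for this sketch; in the training code the corresponding values come from the params dictionary.

```python
def decayed_learning_rate(step, initial_lr=0.002, decay=0.97, steps_per_decay=1000):
    """Staircase decay: multiply by `decay` once per completed decay period."""
    x = step // steps_per_decay  # how many decay periods have been completed
    return initial_lr * decay ** x

lr_start = decayed_learning_rate(0)
lr_same_period = decayed_learning_rate(999)
lr_after_19 = decayed_learning_rate(19 * 1000)
```

The rate stays constant within a period and only drops at its boundary. After 19 decays, which is where a 100-epoch run with a decay every five epochs ends up, it is about 0.00112.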

Again you can see the strength of Flax: the easy and flexible way in which you can integrate your own learning rate schedulers on the fly.

The Predict-Mode

Now we want to evaluate our trained model, so we pick one random character from our vocabulary as an entry point. As in training we initialize our hidden state, but this time only at the beginning of the sampling. The subfunction inference takes one character as input. It returns the hidden states after every timestep, and we feed them back into our RNN in the next timestep. Thus we do not lose our temporal dependencies.

def sample(inputs, optimizer):
    next_inputs = inputs
    output = []
    batch_size = 1 
    carry1 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)
    carry2 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)
    carry3 = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,),512)
    carry = [carry1, carry2, carry3]

    def inference(model, carry):
        carry, rnn_output = model(inputs=next_inputs, train=False, carry_pred=carry)
        return carry, rnn_output
    for i in range(200):
        carry, rnn_output = inference(optimizer.target, carry)
        output.append(jnp.argmax(rnn_output, axis=-1))
        # Select the argmax as the next input.
        next_inputs = jnp.expand_dims(common_utils.onehot(jnp.argmax(rnn_output), params['vocab_length']), axis=0)
    return output      

This method is called "greedy sampling", because we always take the character with the highest probability in our output vector. There are better sampling methods, such as beam search, which I do not cover here.
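
The greedy choice itself can be illustrated without a neural network. The sketch below replaces the trained model with a made-up table of next-character scores, so it has no hidden state and is only meant to show the argmax loop.

```python
# Hypothetical next-character scores for a three-character vocabulary.
scores = {
    'a': {'a': 0.1, 'b': 0.7, 'c': 0.2},
    'b': {'a': 0.6, 'b': 0.1, 'c': 0.3},
    'c': {'a': 0.5, 'b': 0.4, 'c': 0.1},
}

def greedy_sample(start, length):
    """Always append the highest-scoring next character."""
    text = start
    current = start
    for _ in range(length):
        candidates = scores[current]
        current = max(candidates, key=candidates.get)
        text += current
    return text

sample_text = greedy_sample('c', 6)
```

Because the most likely character is always chosen, the output is fully deterministic and quickly falls into a loop, which is one reason greedy sampling tends to produce repetitive text.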

The training and sample loop

At last we can call all the functions we have written in our training and sampling loop.

def train_model():
    """Train and inference """
    rng = jax.random.PRNGKey(0)
    model = create_model(rng)
    optimizer = create_optimizer(model, params['learning_rate'])

    del model
    for epoch in range(100):

        for text in tfds.as_numpy(ds):
            optimizer, metrics = train_step(optimizer, text)

        print('epoch: %d, loss: %.4f, accuracy: %.2f, LR: %.8f' % (epoch+1,metrics['loss'], metrics['accuracy'] * 100, metrics['learning_rate']))
        test = test_ds(params['vocab_length'])
        sampled_text = ""

        if ((epoch+1)%10 == 0):
            for i in test:
                sampled_text += vocab[int(jnp.argmax(i.numpy(),-1))]
                start = np.expand_dims(i, axis=0)
                text = sample(start, optimizer)

            for i in text:
                sampled_text += vocab[int(i)]
            # Print the generated sample; otherwise the text is discarded.
            print(sampled_text)

After every 10 epochs we generate a text sample, and at the beginning it looks very repetitive:

peak the mariners all the merchant of the meaning of the meaning of the meaning of the meaning of the meaning of the meaning…

But the model gets better and better, and after 100 epochs of training the output looks as if Shakespeare were still alive and writing new texts!

This is a shift respected woman to the king's forth,
To this most dangerous soldier there and fortune.

If she would concount a sight on honour
Of the moon, why,...

The training accuracy after 100 epochs is 86.10% and our learning rate decayed to 0.00112123.


At its core, the character-level language model is a powerful tool for completing texts and can be used for autocompletion. The sentiment of a given text can also be learned with this concept. But generating completely new texts is a very hard task, as you can see. The output sentences of our model look like Shakespeare, but they lack meaning. In a future article I will use this kind of model to create more meaningful sentences based on a meaningful input.

Despite its powerful and numerous tools, Flax is still at an early stage of development, but its developers are well on the way to building a framework I like. What I found really ingenious was that I only had to change my "old" CNN code a little to use an RNN on the existing foundation.
But Flax is still missing its own input pipeline, so I had to write this part with TensorFlow. You can find the code for the dataset creation and the complete RNN in the Github Repo.

If you want to try my code yourself, just have a look at my Github Repo. I can also recommend the Flax Github Repo and its documentation.

Images are inspired by this blog.

