Instead of NumPy, the calculations are based on JAX, which is also a Google research project. One of JAX's biggest advantages is its use of XLA, a special compiler for linear algebra that enables execution on GPUs and TPUs. For those who do not know: a TPU (tensor processing unit) is a chip specifically optimized for machine learning. JAX reimplements parts of NumPy so that your functions run on a GPU/TPU.
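As a small illustration of my own (not part of the article's code): the same NumPy-style code runs unchanged on whatever backend is available once it goes through jax.numpy, jax.jit compiles it with XLA, and jax.grad gives you gradients on top.

import jax
import jax.numpy as jnp

@jax.jit                      # compiled with XLA for CPU, GPU or TPU
def mse(w, x, y):
  return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.grad(mse)       # gradient with respect to the first argument w
x, y, w = jnp.ones((4, 3)), jnp.ones(4), jnp.zeros(3)
print(mse(w, x, y), grad_fn(w, x, y))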
Flax focuses on key points like:
Enough praise, now let's start coding.
Because the MNIST example gets boring, I will build an image classifier for the Simpsons family. Unfortunately, Maggie is missing from the dataset :-( .
Sample Images of the Dataset
First, we install the necessary libraries and unzip our dataset. Unfortunately, you will still need TensorFlow at this point, because Flax lacks a good data input pipeline of its own.
pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
pip install -q git+https://github.com/google/flax.git@dev-setup
pip install tensorflow
pip install tensorflow_datasets
unzip simpsons_faces.zip
Now we import the libraries. As you can see, we have two "versions" of NumPy: the regular NumPy library and the part of its API that JAX reimplements. The print statement outputs cpu, gpu or tpu, depending on the available hardware.
from jax.lib import xla_bridge
import jax
import flax
import numpy as onp
import jax.numpy as jnp
import csv
import tensorflow as tf
import tensorflow_datasets as tfds
print(xla_bridge.get_backend().platform)
For training and evaluation, we first have to create two TensorFlow datasets and convert them to NumPy/JAX arrays, because Flax does not accept TF data types. This is currently a bit hacky, because the evaluation method does not take batches: I had to create one large batch for the eval step and build a feature dictionary from it, which can then be fed to the eval step after each epoch. (The create_dataset helper that builds these datasets is sketched after the training loop below.)
def train():
  train_ds = create_dataset(tf.estimator.ModeKeys.TRAIN)
  test_ds = create_dataset(tf.estimator.ModeKeys.EVAL)
  test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
  # test_ds is one giant batch containing the whole eval set
  test_ds = test_ds.batch(1000)
  test_ds = tf.compat.v1.data.experimental.get_single_element(test_ds)
  test_ds = tfds.as_numpy(test_ds)
  # test_ds is now a feature dictionary
  test_ds = {'image': test_ds[0].astype(jnp.float32),
             'label': test_ds[1].astype(jnp.int32)}

  _, initial_params = CNN.init_by_shape(
      jax.random.PRNGKey(0), [((1, 160, 120, 3), jnp.float32)])
  model = flax.nn.Model(CNN, initial_params)
  optimizer = flax.optim.Momentum(
      learning_rate=0.01, beta=0.9, weight_decay=0.0005).create(model)

  for epoch in range(50):
    for batch in tfds.as_numpy(train_ds):
      optimizer = train_step(optimizer, batch)
    metrics = eval(optimizer.target, test_ds)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
          % (epoch + 1, metrics['loss'], metrics['accuracy'] * 100))
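The create_dataset helper used above is not shown here; the full version is in the linked repository. A possible sketch, assuming the unzipped archive contains one folder per character with separate train/eval splits, could look like the following (the folder names, file pattern and image size are assumptions for illustration):

# Hypothetical sketch of create_dataset(); layout and names are assumed.
CLASS_NAMES = ['homer', 'marge', 'bart', 'lisa']  # assumed folder names

def create_dataset(mode):
  split = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'eval'
  files = tf.data.Dataset.list_files('simpsons_faces/%s/*/*.jpg' % split)

  def parse(path):
    # label = index of the parent folder name in CLASS_NAMES
    folder = tf.strings.split(path, '/')[-2]
    label = tf.argmax(tf.cast(tf.equal(CLASS_NAMES, folder), tf.int32))
    image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (160, 120))   # matches the CNN input shape
    return image, label

  ds = files.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if split == 'train':
    ds = ds.shuffle(1000).batch(32)
  return ds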
The CNN class contains our convolutional neural network. If you are familiar with TensorFlow or PyTorch, you will see that it is pretty straightforward. Every call to flax.nn.Conv defines a learnable kernel.
I took the MNIST example and extended it with some additional layers. At the end we have a Dense layer with four output neurons, because we have a four-class problem.
class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=16, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=64)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=4)
    x = flax.nn.softmax(x)
    return x
Unlike in TensorFlow, the activation function is called explicitly, which makes it very easy to test new or self-written activation functions. Flax is built around the module abstraction, and both initializing and calling the network go through the apply function.
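To make that concrete, here is a minimal sketch using the same flax.nn API as above (the dummy batch is made up for illustration): init_by_shape creates the parameters from the input shape, and the resulting Model runs apply() with those parameters whenever it is called on data.

_, params = CNN.init_by_shape(jax.random.PRNGKey(0),
                              [((1, 160, 120, 3), jnp.float32)])
model = flax.nn.Model(CNN, params)

dummy_images = jnp.zeros((8, 160, 120, 3), jnp.float32)  # assumed batch of 8
probs = model(dummy_images)  # calls CNN.apply -> shape (8, 4) softmax outputs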
Of course, we want to measure how well our network performs, so we compute metrics such as loss and accuracy. The accuracy is computed with JAX instead of NumPy, because JAX can run on GPU/TPU.
def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}
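As a quick sanity check with made-up numbers (not part of the original code), this is what the metrics look like for a tiny batch where the first prediction is correct and the second is not:

# Uses the cross_entropy_loss defined below; one of two predictions is correct.
logits = jnp.array([[0.7, 0.1, 0.1, 0.1],   # argmax = 0, label 0 -> correct
                    [0.1, 0.2, 0.6, 0.1]])  # argmax = 2, label 3 -> wrong
labels = jnp.array([0, 3])
print(compute_metrics(logits, labels))  # loss ~1.33, accuracy 0.5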
To measure our loss we use the cross-entropy loss. Unlike in TensorFlow, we have to calculate it ourselves; ready-made loss objects are not available yet. Note that because our network ends in a softmax, the values passed in here are already probabilities. As you can see, we use @jax.vmap as a function decorator for our loss function, which vectorizes the code so that it runs efficiently on batches.

@jax.vmap
def cross_entropy_loss(logits, label):
  return -jnp.log(logits[label])
How does cross_entropy_loss work? @jax.vmap takes both arrays, logits and label, and applies cross_entropy_loss to each pair, which allows the whole batch to be computed in parallel. The cross-entropy formula for a single example is:

loss(y, ŷ) = -∑_{c=1}^{4} y_c · log(ŷ_c)

Our ground truth y is 1 for the correct one of the four output neurons and 0 for the others, so we do not need the sum in our code; we simply take -log(ŷ) of the correct label. The mean in our loss calculation is there because we work with batches.
In our train step we again use a function decorator, @jax.jit, to speed up the function. This works very similarly to TensorFlow. Keep in mind that batch[0] contains our image data and batch[1] our labels.

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch[0])
    loss = jnp.mean(cross_entropy_loss(logits, batch[1]))
    return loss
  grad = jax.grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  return optimizer
The inner function loss_fn returns the loss for our current model, optimizer.target, and jax.grad() computes its gradient. After that, we apply the gradient, just like in TensorFlow.

The eval step is very simple and minimalistic in Flax. Please note that the complete evaluation dataset is passed to this function.
@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'])
  return compute_metrics(logits, eval_ds['label'])
After 50 epochs we reach a very high accuracy. Of course, we could keep tweaking the model and optimizing the hyperparameters.
For this experiment I used Google Colab, so if you want to try it yourself, create a new environment with a GPU/TPU and import my notebook from GitHub. Please note that Flax does not work under Windows at the moment.
It is important to note that Flax is currently still in alpha and is not an official Google product.
The work so far gives hope for a fast, lightweight and highly customizable ML framework. What is completely missing so far is a data input pipeline, so TensorFlow still has to be used for that.
The current set of optimizers is unfortunately limited to Adam and SGD with momentum. I especially liked the framework's very clear, straightforward way of doing things and its high flexibility.
My next plan is to implement some activation functions that are not yet available. A speed comparison between TensorFlow, PyTorch and Flax would also be very interesting.
If you want to experiment a little with Flax, check out the documentation and the GitHub page.
And if you want to download my example with the dataset, just clone SimpsonsFaceRecognitionFlax.