When thinking of Machine Learning, the first frameworks that come to mind are TensorFlow and PyTorch, which are currently the state-of-the-art frameworks if you want to work with Deep Neural Networks. Technology is changing rapidly and more flexibility is needed, so Google researchers are developing a new high-performance framework for the open source community: Flax. The base for the calculations is JAX instead of NumPy; JAX is also a Google research project. One of the biggest advantages of JAX is its use of XLA, a special compiler for linear algebra that also enables execution on GPUs and TPUs. For those who do not know, a TPU (tensor processing unit) is a specific chip optimized for Machine Learning. JAX reimplements parts of NumPy so that your functions can run on a GPU/TPU.

Flax focuses on key points like:

- easy-to-read code
- preferring duplication over bad abstractions or inflated functions
- helpful error messages (it seems they learned from the TensorFlow error messages)
- easy expandability of basic implementations

Enough praise, now let's start coding. Because the MNIST example has become boring, I will build an image classifier for the Simpsons family. Unfortunately, Maggie is missing from the dataset :-(

[Figure: Sample images of the dataset]

First, we install the necessary libraries and unzip our dataset. Unfortunately, you still need TensorFlow at this point, because Flax lacks a good data input pipeline.

```shell
# Install the jaxlib wheel that matches the local CUDA and Python versions, plus jax
pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
pip install -q git+https://github.com/google/flax.git@dev-setup
# TensorFlow is only needed for the data input pipeline
pip install tensorflow
pip install tensorflow_datasets
unzip simpsons_faces.zip
```

Now we import the libraries. You can see that we have two "versions" of NumPy: the normal numpy library and the part of the NumPy API that JAX implements.
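To make the two NumPy flavours concrete, here is a small sketch that is not part of the original notebook: the same function runs on classic NumPy (onp) and on jax.numpy (jnp), and it also shows one notable difference, namely that JAX arrays are immutable. The function name standardize is purely illustrative.

```python
import numpy as onp
import jax.numpy as jnp

def standardize(x, xp):
    # xp is the array module: either onp (classic NumPy) or jnp (JAX)
    return (x - xp.mean(x)) / xp.std(x)

data = [1.0, 2.0, 3.0, 4.0]
a = standardize(onp.array(data), onp)  # plain NumPy on the CPU
b = standardize(jnp.array(data), jnp)  # compiled by XLA for the available backend
print(onp.allclose(a, b))  # True

# JAX arrays are immutable: item assignment like x[0] = 5.0 raises an error,
# while .at[...].set(...) returns an updated copy and leaves x unchanged.
x = jnp.zeros(3)
y = x.at[0].set(5.0)
print(x)  # [0. 0. 0.]
print(y)  # [5. 0. 0.]
```

Because JAX follows the NumPy API closely, most array code can be switched from onp to jnp by changing the module prefix.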
The print statement outputs cpu, gpu or tpu, depending on the available hardware.

```python
from jax.lib import xla_bridge
import jax
import flax

import numpy as onp
import jax.numpy as jnp
import csv

import tensorflow as tf
import tensorflow_datasets as tfds

# Prints the backend JAX will use: cpu, gpu or tpu
print(xla_bridge.get_backend().platform)
```

For training and evaluation, we first have to create two TensorFlow datasets and convert them to NumPy/JAX arrays, because Flax does not accept TF data types. This is currently a bit hacky, because the evaluation method does not take batches. I had to create one large batch for the eval step and build a feature dictionary from it, which can then be fed to our eval step after each epoch.

```python
def train():
    # create_dataset() is a helper defined in the notebook (not shown in this post)
    train_ds = create_dataset(tf.estimator.ModeKeys.TRAIN)
    test_ds = create_dataset(tf.estimator.ModeKeys.EVAL)

    test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
    # test_ds is one giant batch
    test_ds = test_ds.batch(1000)
    test_ds = tf.compat.v1.data.experimental.get_single_element(test_ds)
    test_ds = tfds.as_numpy(test_ds)
    # test_ds is now a feature dictionary!
    test_ds = {'image': test_ds[0].astype(jnp.float32),
               'label': test_ds[1].astype(jnp.int32)}

    # Initialize the parameters from a dummy input of shape (batch, height, width, channels)
    _, initial_params = CNN.init_by_shape(jax.random.PRNGKey(0),
                                          [((1, 160, 120, 3), jnp.float32)])
    model = flax.nn.Model(CNN, initial_params)

    optimizer = flax.optim.Momentum(learning_rate=0.01, beta=0.9,
                                    weight_decay=0.0005).create(model)

    for epoch in range(50):
        for batch in tfds.as_numpy(train_ds):
            optimizer = train_step(optimizer, batch)

        # Evaluate on the whole test set after each epoch
        metrics = eval(optimizer.target, test_ds)
        print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
              % (epoch + 1, metrics['loss'], metrics['accuracy'] * 100))
```

The Model

The CNN class contains our convolutional neural network. If you are familiar with TensorFlow/PyTorch, you will see that it is pretty straightforward. Every call of flax.nn.Conv defines a learnable kernel. I used the MNIST example and extended it with some additional layers.
In the end, we have our Dense layer with four output neurons, because we have a four-class problem.

```python
class CNN(flax.nn.Module):
    def apply(self, x):
        # Five blocks of convolution -> ReLU -> 2x2 average pooling
        x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = flax.nn.Conv(x, features=16, kernel_size=(3, 3))
        x = flax.nn.relu(x)
        x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the batch dimension
        x = x.reshape((x.shape[0], -1))
        x = flax.nn.Dense(x, features=256)
        x = flax.nn.relu(x)
        x = flax.nn.Dense(x, features=64)
        x = flax.nn.relu(x)
        # One output neuron per class; softmax turns the scores into probabilities
        x = flax.nn.Dense(x, features=4)
        x = flax.nn.softmax(x)
        return x
```

Unlike in TensorFlow, the activation function is called explicitly, which makes it very easy to test new or self-written activation functions. Flax is based on the module abstraction, and both initializing and calling the network are done with the apply function.

Metrics in Flax

Of course, we want to measure how good our network becomes. Therefore, we compute metrics like loss and accuracy. The metrics are computed with jax.numpy instead of NumPy, because JAX can run on a TPU/GPU.

```python
def compute_metrics(logits, labels):
    # Note: the CNN ends with a softmax, so "logits" holds class probabilities here
    loss = jnp.mean(cross_entropy_loss(logits, labels))
    # Fraction of examples whose highest-scoring class matches the label
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}
```

To measure our loss we use the cross-entropy loss. Unlike in TensorFlow, you have to implement it yourself; there are no ready-made loss objects yet.
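Because activation functions are ordinary functions called explicitly inside apply, trying a self-written one only takes a few lines of jax.numpy. A minimal sketch, assuming Swish (x times sigmoid(x)) as the example activation; my_swish is a hypothetical name, not part of Flax:

```python
import jax
import jax.numpy as jnp

def my_swish(x, beta=1.0):
    # Hypothetical self-written activation: x * sigmoid(beta * x)
    return x * jax.nn.sigmoid(beta * x)

x = jnp.linspace(-3.0, 3.0, 7)
print(my_swish(x))

# Since it is plain jax.numpy code, JAX can differentiate and vectorize it:
# jax.grad gives the scalar derivative, jax.vmap applies it element-wise.
d_swish = jax.vmap(jax.grad(my_swish))
print(d_swish(x))  # the derivative at x = 0 is sigmoid(0) = 0.5
```

Inside CNN.apply one could then write x = my_swish(x) in place of x = flax.nn.relu(x); no layer class or registration is needed.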
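The accuracy line in compute_metrics can be traced by hand on a tiny, made-up batch of network outputs (four images, four classes; the numbers are illustrative only):

```python
import jax.numpy as jnp

outputs = jnp.array([[0.6, 0.2, 0.1, 0.1],
                     [0.1, 0.7, 0.1, 0.1],
                     [0.3, 0.3, 0.2, 0.2],   # tie: argmax returns the first index, 0
                     [0.1, 0.1, 0.1, 0.7]])
labels = jnp.array([0, 1, 2, 3])

predicted = jnp.argmax(outputs, -1)  # index of the highest score per row
correct = predicted == labels        # boolean array, one entry per example
accuracy = jnp.mean(correct)         # mean of booleans = fraction correct
print(predicted)  # [0 1 0 3]
print(accuracy)   # 0.75
```

Three of the four predictions match their labels, so the accuracy is 0.75; train() multiplies it by 100 to print a percentage.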
We use @jax.vmap as a function decorator for our loss function. This vectorizes our code so that it runs efficiently on batches.

```python
@jax.vmap
def cross_entropy_loss(logits, label):
    # Per-example loss: the CNN ends with a softmax, so "logits" holds class
    # probabilities, and we take the negative log-probability of the true class
    return -jnp.log(logits[label])
```

How does @jax.vmap work? It takes both arrays, logits and label, and applies cross_entropy_loss to each pair of rows, which allows the whole batch to be computed in parallel. The cross-entropy formula for a single example is:

$$H(y, \hat{y}) = -\sum_{i=1}^{4} y_i \log(\hat{y}_i)$$

Our ground truth y is 1 for exactly one of the four output neurons and 0 for all others, so we do not need the sum in our code: we just calculate log(ŷ) for the correct label. The mean in our loss calculation is used because we work with batches.

Training

In our train step we again use a function decorator, @jax.jit, to speed up our function. This works very similarly to TensorFlow. Please keep in mind that batch[0] is our image data and batch[1] our labels.

```python
@jax.jit
def train_step(optimizer, batch):
    def loss_fn(model):
        logits = model(batch[0])
        loss = jnp.mean(cross_entropy_loss(logits, batch[1]))
        return loss
    # Gradient of the loss with respect to the current model (optimizer.target)
    grad = jax.grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer
```

The loss function loss_fn returns the loss for our current model, optimizer.target, and jax.grad() calculates its gradient. After the calculation we apply the gradient, like in TensorFlow.

The eval step is very simple and minimalistic in Flax. Please note that the complete evaluation dataset is passed to this function.

```python
@jax.jit
def eval(model, eval_ds):
    # eval_ds is the feature dictionary holding the whole test set
    logits = model(eval_ds['image'])
    return compute_metrics(logits, eval_ds['label'])
```

After 50 epochs we reach a very high accuracy. Of course, we can continue to tweak the model and optimize the hyperparameters. For this experiment I used Google Colab, so if you want to test it yourself, create a new environment with a GPU/TPU and import my notebook from GitHub. Please note that Flax does not work under Windows at the moment.
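To check that the vmapped loss really equals the full-sum cross-entropy formula with a one-hot ground truth, here is a small sketch with made-up probabilities for a batch of three examples (the numbers are illustrative, not from the dataset); jax.nn.one_hot builds y:

```python
import jax
import jax.numpy as jnp

@jax.vmap
def cross_entropy_loss(probs, label):
    # Same per-example loss as above: -log of the true class probability
    return -jnp.log(probs[label])

probs = jnp.array([[0.7, 0.1, 0.1, 0.1],
                   [0.2, 0.5, 0.2, 0.1],
                   [0.25, 0.25, 0.4, 0.1]])
labels = jnp.array([0, 1, 3])

# vmap maps over axis 0 of both arguments: one loss per (row, label) pair
per_example = cross_entropy_loss(probs, labels)  # shape (3,)

# Full formula -sum_i y_i * log(y_hat_i) with a one-hot y: only the
# true-class term survives, which is why the indexing version suffices
y = jax.nn.one_hot(labels, 4)
full_sum = -jnp.sum(y * jnp.log(probs), axis=-1)

batch_loss = jnp.mean(per_example)  # the mean used in compute_metrics/train_step
print(per_example, full_sum, batch_loss)
```

The third example puts only 0.1 on its true class, so it contributes by far the largest loss (about 2.30), which pulls the batch mean up to roughly 1.12.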
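The jit/grad pattern of train_step can also be sketched without the Flax optimizer wrapper. The example below is a hypothetical toy problem (fitting a 3-element weight vector with a squared-error loss), and the update is one common formulation of SGD with momentum and L2 weight decay, reusing the hyperparameters from train(); it is not necessarily identical to what flax.optim.Momentum does internally.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Hypothetical toy model: linear prediction x @ w with squared error
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_step(w, velocity, x, y, lr=0.01, beta=0.9, weight_decay=0.0005):
    grad = jax.grad(loss_fn)(w, x, y)   # gradient w.r.t. the first argument, w
    grad = grad + weight_decay * w      # L2 weight decay (one common formulation)
    velocity = beta * velocity + grad   # momentum buffer
    w = w - lr * velocity
    return w, velocity

# Usage: synthetic data generated from known weights
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w

w = jnp.zeros(3)
velocity = jnp.zeros(3)
for _ in range(500):
    w, velocity = train_step(w, velocity, x, y)
print(w)  # close to true_w; weight decay shrinks it very slightly
```

As in the post, the state (here w and velocity, there the optimizer) is passed in and returned rather than mutated, which is what allows jax.jit to compile the whole step as a pure function.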
Conclusions

It is important to note that Flax is currently still in alpha and is not an official Google product. The work so far gives hope for a fast, lightweight and highly customizable ML framework. What is completely missing so far is a data input pipeline, so TensorFlow still has to be used. The current set of optimizers is unfortunately limited to Adam and SGD with Momentum. I especially liked how straightforward this framework is to use, and its high flexibility. My next plan is to develop some activation functions that are not yet available. A speed comparison between TensorFlow, PyTorch and Flax would also be very interesting. If you want to experiment a bit with Flax, check out the documentation and their GitHub page. And if you want to download my example with the dataset, just clone SimpsonsFaceRecognitionFlax.