From TF to TFLite: Deploying ML Models on Mobile [Part 1] by @oxymoron_31

Nagarakshitha B R

tl;dr - Link to code: TensorFlow GAN model.

So the other day I was talking to my rubber ducky about how Gboard predicts my next word, even when the words are entirely made up by me; that is, how it actually learns on-device. And about how amazingly Netflix, Amazon and Google Maps make use of machine learning in their apps. How does machine learning in apps even work? Does the model keep learning after being deployed? Can I deploy a GAN model on mobile?


Baby steps, ducky, baby steps. In this two-article series, we are gonna learn about deploying two basic models on your mobile phone, in a single app. It involves deploying a GAN model to generate images of handwritten digits and a classifier model to recognize the digit. I have split this into two articles for the sake of clarity and modularity. These are the two stories I will talk about:

  • Part 1: The TensorFlow Story: the intricacies involved in TensorFlow model conversion and information extraction.
  • Part 2: The Android Story: an Android (Java) code-snippet walk-through showing how to use the converted .tflite model in our example application.

Before jumping right into it, let us look at the whys and whats of deploying a model, and at what a GAN is, so that we understand the stories better.

What does it mean to have an ML model deployed in an app?

It means that you can bring the power of prediction, recommendation and personalization into your app, which is a key factor for major tech companies all over the world: Netflix, Spotify, Tinder, Snapchat; the list is endless.

So what are the technical advantages of this?

  1. Data privacy is improved, since data doesn't have to leave your device at all!
  2. No internet connection is needed if your model is fully deployed on the phone. Skipping the network traffic can also considerably reduce power consumption.
  3. Results are quicker, since there is no need to communicate with a server for every prediction.

Well, those amazing apps mentioned above use much more complex architectures and state-of-the-art recommender systems to bring ML into their apps. But the basic crux is knowing how to deploy a model in an app.

I was also fascinated by GANs. So for our use case, we are gonna use the generator model of a GAN, plus a classifier model that recognizes the handwritten digits generated by the GAN. Basically, the generator is pitted against the classifier, which has to recognize the digits in the images the generator produces.
Cool stuff :p To understand this better, let us look at an overview of what a GAN is.

What is a GAN?

In ELI5 terms, a Generative Adversarial Network (GAN) is a neural network that generates never-before-seen data resembling the data it was trained on. It is like a mockingbird that trains to perfection to mimic a certain set of features. The difference is that the sound it produces is actually completely new, yet convincing enough to seem familiar. It was first proposed in 2014 by Ian Goodfellow and his fellow researchers in this paper. It was a really refreshing development in the deep learning field after a pretty long time.

"the most interesting idea in the last 10 years in ML."
- Yann LeCun (the guy who invented CNNs)

There are two networks in it:

  • Discriminator network: a classifier that distinguishes between real and fake images.
  • Generator network: tries to generate images that seem as real as possible, in order to hoodwink the discriminator.

The diagram shows the inputs to each network. The discriminator is fed with both the training set and the generator's output, while the generator is given random noise as input. Each newly generated image is sent to the discriminator for a verdict. If the image is very close to the features of the training data, the discriminator classifies it as real. In that case, the generator has managed to fool it and has produced an altogether new image. Otherwise, the image is classified as fake, and the errors are back-propagated to the generator so that it can improve. The math behind GANs is pretty interesting, and I strongly encourage you to go ahead and learn about it.
Watch this for more info.
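To make this feedback loop a bit more concrete, here is a minimal plain-Python sketch of the binary cross-entropy losses commonly used to train a GAN. It assumes the discriminator outputs the probability that an image is real. The function names are hypothetical. This is the textbook formulation, not necessarily the exact loss used in the linked notebook.

```python
import math

def binary_cross_entropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(labels)

def discriminator_loss(d_real, d_fake):
    """The discriminator wants real images scored as 1 and generated ones as 0."""
    return (binary_cross_entropy([1] * len(d_real), d_real)
            + binary_cross_entropy([0] * len(d_fake), d_fake))

def generator_loss(d_fake):
    """The generator 'wins' when the discriminator scores its fakes as real (1)."""
    return binary_cross_entropy([1] * len(d_fake), d_fake)

# Discriminator outputs (probability of "real") for a small hypothetical batch
d_real = [0.9, 0.8]  # scores for real MNIST images
d_fake = [0.2, 0.1]  # scores for generated images
print(discriminator_loss(d_real, d_fake))  # relatively low: discriminator is winning
print(generator_loss(d_fake))              # relatively high: generator must improve
```

The gradients of the discriminator loss update the discriminator. The gradients of the generator loss flow back through the discriminator into the generator. Labelling the fakes as 1 in `generator_loss` is the "non-saturating" variant suggested in the original GAN paper: instead of minimizing log(1 - D(G(z))), the generator maximizes log D(G(z)).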

For our example, I have trained the GAN on the MNIST handwritten digits dataset, so the generator basically mimics handwritten digits.

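There is no fundamental difference between deploying a GAN model and deploying any other ML model. Once trained, the generator is just a network that maps an input tensor (random noise) to an output tensor (an image).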

Let us get back to Part 1: the story of how to convert your model to a TensorFlow Lite model and how to extract the right information from it for deployment.

Part 1: The TensorFlow Story

So you have trained and saved your TensorFlow model file (extension .h5 or .pb). Find out more about TensorFlow saved models here. TensorFlow provides a set of tools to help deploy TensorFlow models on mobile, IoT and embedded devices: lo and behold, TensorFlow Lite. To use our model on an Android device, we have to convert it to a TensorFlow Lite model.
How does it work? Look at the following conversion diagram.


    The TfLite converter takes a TensorFlow model and converts it into a TfLite FlatBuffer file (.tflite). This file can then be deployed to a client device (e.g. a mobile phone) and run locally using the TensorFlow Lite interpreter.

    I have trained a GAN to generate images of handwritten digits. Here is the link to my Google Colab notebook: GAN-MNIST example.
    I have saved the model as a .h5 file (refer to the Google Colab notebook above).

    So the next step is to convert your regular TensorFlow model to a .tflite model using this little piece of code:

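    import tensorflow as tf

    # Load the trained Keras model saved as .h5 (placeholder path); no optimizer is needed for conversion
    model = tf.keras.models.load_model("/path_to_saved_model/model.h5", compile=False)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.experimental_new_converter = True  # MLIR-based converter, already the default in newer TF releases
    tflite_model = converter.convert()  # the .tflite FlatBuffer as bytes
    with open("/path_to_save_tflite_model/converted_model.tflite", "wb") as f:
        f.write(tflite_model)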

    But what if your model is a PyTorch saved model? We got you covered. Refer to this article for converting it into a TfLite model: Pytorch to TensorFlow model with ONNX. Yay! You now have a TfLite model ready to be added to your Android app!

    Next, we need to find out the exact format of the input to be provided to the model and of the output it returns. We can use this really amazing neural network visualiser, Netron, to get the model architecture diagram and model properties. The following are the properties of the models used in the app we are building (GAN on the left, classifier on the right):


    All we need are these pieces of information.

    GAN model (left):

    Input: A float32 array of size [1x100], holding the latent points fed as input to the generator.

    Output: A float32 array of size [1x28x28x1], which is rendered as an image in the app.

    Classifier model (right):

    Input: A float32 array of size [1x28x28x1] (so convenient, since this is exactly the shape of the GAN's output).

    Output: A float32 array of size [1x10], holding the probability calculated for each digit (0 to 9).
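Here is a small NumPy sketch of what these shapes mean in practice. It builds a [1x100] float32 latent input, turns a [1x28x28x1] generator output into a 28x28 image, and reads the predicted digit from a [1x10] classifier output. Drawing the latent points from a standard normal distribution is an assumption (a common choice for GANs). The helper names are hypothetical, and the arrays passed in are stand-ins rather than real model outputs.

```python
import numpy as np

LATENT_DIM = 100  # the GAN input is a float32 array of shape [1, 100]

def make_latent_input(seed=None):
    """Sample one latent vector; standard normal noise is an assumed choice."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((1, LATENT_DIM)).astype(np.float32)

def to_image(generator_output):
    """Drop the batch and channel axes: [1, 28, 28, 1] -> [28, 28]."""
    return np.squeeze(np.asarray(generator_output), axis=(0, 3))

def predicted_digit(classifier_output):
    """Return (digit, probability) from a [1, 10] array of per-digit probabilities."""
    probs = np.asarray(classifier_output).reshape(-1)
    digit = int(np.argmax(probs))
    return digit, float(probs[digit])

z = make_latent_input(seed=0)
print(z.shape, z.dtype)  # (1, 100) float32

fake_output = np.zeros((1, 28, 28, 1), dtype=np.float32)  # stand-in for the GAN output
print(to_image(fake_output).shape)  # (28, 28)

probs = np.array([[0.01] * 7 + [0.9, 0.02, 0.01]], dtype=np.float32)  # stand-in classifier output
print(predicted_digit(probs))  # digit 7, with probability of about 0.9
```

The generator's output shape matches the classifier's input shape, so the GAN output can be passed straight to the classifier. `to_image` is only needed when rendering the digit on screen.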

    The same information about the model can also be found using the TensorFlow Lite Interpreter in Python, with this snippet of code:

    import tensorflow as tf

    # Load the TFLite model and allocate memory for its tensors.
    interpreter = tf.lite.Interpreter(model_path="/path/converted_model1.tflite")
    interpreter.allocate_tensors()

    # Get the details (name, index, shape, dtype, ...) of the input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Print the shapes the model expects and produces.
    print("Input shape:", input_details[0]['shape'])
    print("Output shape:", output_details[0]['shape'])

    The interpreter loads the model and allocates its tensors, after which we can easily print out the input and output details. The above code generates the following output:



    This gives us the shapes of the input and output. Try printing input_details and output_details to get all the parameters of the model. For the GAN model, they will look something like this:

    [{'name': 'dense_2_input', 'index': 0, 'shape': array([  1, 100], dtype=int32), 'shape_signature': array([  1, 100], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
    [{'name': 'Identity', 'index': 22, 'shape': array([ 1, 28, 28,  1], dtype=int32), 'shape_signature': array([ 1, 28, 28,  1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
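Before moving on to Android, it can be handy to check the whole pipeline in Python. The idea is to feed a latent vector to the generator, run it with `invoke()`, and pass the generated image to the classifier. To keep the sketch self-contained, it builds two tiny, untrained stand-in Keras models with the same input/output shapes as the real ones and converts them in memory. With the real models, you would create each interpreter with `model_path=` instead. The helpers `to_tflite` and `run_tflite` are hypothetical. The sketch assumes a TensorFlow 2.x release in which `TFLiteConverter.from_keras_model` accepts `tf.keras` models, as in the conversion snippet above.

```python
import numpy as np
import tensorflow as tf

def to_tflite(keras_model):
    """Convert a Keras model to TFLite FlatBuffer bytes."""
    return tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

def run_tflite(interpreter, x):
    """Run one inference on a single input array and return the first output tensor."""
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]['index'], x.astype(input_details[0]['dtype']))
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])

# Untrained stand-ins with the same shapes as the GAN generator and the classifier.
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(100,)),
    tf.keras.layers.Dense(28 * 28, activation="tanh"),
    tf.keras.layers.Reshape((28, 28, 1)),
])
classifier = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

gen_interp = tf.lite.Interpreter(model_content=to_tflite(generator))
cls_interp = tf.lite.Interpreter(model_content=to_tflite(classifier))
gen_interp.allocate_tensors()
cls_interp.allocate_tensors()

latent = np.random.default_rng(0).standard_normal((1, 100)).astype(np.float32)
image = run_tflite(gen_interp, latent)  # shape (1, 28, 28, 1)
probs = run_tflite(cls_interp, image)   # shape (1, 10)
print(image.shape, probs.shape, int(np.argmax(probs)))
```

The Android side will follow the same set-input, invoke, read-output sequence, which is why the shapes and dtypes printed above are all we need to carry over.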

    We will dive into Android (Java) in the next article, where we go through code snippets for loading the models and using them in the app discussed above. See you soon in the next part!