Transfer learning with MXNet Gluon

by Jason Xu, January 6th, 2018

A year ago, I started learning neural networks with TensorFlow. The journey was not as smooth as I expected. Thanks to Andrew Ng’s online course and several books, I have a basic understanding of the theory. However, when I tried to apply it in real-life projects, the syntax and API of TensorFlow sometimes confused me (maybe I wasn’t spending enough time with it).

To get early results for my project, I directly applied the TensorFlow computer vision API. However, when it came to customising models, it took me a long time to get everything sorted.

Until one day, I found MXNet Gluon.

What is MXNet Gluon

Gluon is an interface for MXNet, the deep learning framework supported by Amazon. Gluon is similar to other higher-level APIs such as Keras, PyTorch and Chainer, but it has its own points to love:

  1. Imperative and symbolic: Gluon lets you enjoy the best of both imperative and symbolic frameworks with its HybridBlock. You can develop and debug easily in a Jupyter notebook, then get the performance optimisation when you convert the model to symbolic.
  2. Simple API without compromising flexibility: unlike PyTorch or Chainer, you don’t need to remember the output size of each layer in order to declare the next one, because Gluon can infer input sizes. The API definition is very similar to Keras, and you can extend and build your own blocks with ease.

What is transfer learning

Transfer learning is a technique to reuse the learned representation of an existing model and apply it to a different but related domain.

The reason to use transfer learning is that it takes a long time and a lot of resources to train a neural network from scratch. Usually we use transfer learning in 2 ways:

  1. Initialise parameters with a pretrained model
  2. Use a pretrained model as a fixed feature extractor and build a model on top of the extracted features

Transfer learning as a topic itself can involve a long discussion. In this article, we will mainly look at using it to initialise parameters.

Let’s start

Problem definition

husky or akita

Given a photo of our favourite pet dog, is it a husky or an akita?

Why this problem? It’s better to work on a problem that interests you, and the dataset should be easy to find. (They are cute, and it’s easy to find their photos.)


  1. install MXNet: pip install mxnet
  2. install Jupyter Notebook: pip install jupyter

Prepare Dataset

Data source: Google Images, searched with the keywords husky and akita.

The dataset will be divided into 4 parts:

  1. sample: 16 images for development and debug
  2. train: 70%, to train model
  3. validation: 15%, to diagnose model, detect overfitting/underfitting and choose models if we have multiple versions
  4. test: 15%, to evaluate the accuracy of the model
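
A minimal sketch of how such a split could be produced, assuming the image paths are already collected in a list. The helper name `split_dataset`, the fixed seed and the file names are illustrative assumptions and are not taken from the original notebook. The small 16-image sample set is left out for brevity.

```python
import random


def split_dataset(paths, train_pct=70, val_pct=15, seed=42):
    """Shuffle paths and split them into train/validation/test lists.

    Hypothetical helper: the percentages follow the article (70/15/15),
    everything else is an assumption made for illustration.
    """
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)  # fixed seed keeps the split reproducible
    n_train = len(shuffled) * train_pct // 100
    n_val = len(shuffled) * val_pct // 100
    train = shuffled[:n_train]
    validation = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # whatever remains, roughly 15%
    return train, validation, test


# Example with made-up file names
paths = ["dog_%03d.jpg" % i for i in range(100)]
train, validation, test = split_dataset(paths)
print(len(train), len(validation), len(test))  # 70 15 15
```

Shuffling before slicing matters here: images downloaded from a search engine tend to be ordered by keyword, so an unshuffled split could put only one breed in the test set.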

MXNet provides a script to generate a rec file. It requires a lst file that specifies file locations and labels. The notebook for generating the lst file and rec file can be found here.
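
For orientation, a lst file is a tab-separated text file with one image per line: an integer index, the numeric label, and the image path relative to the image root. The sketch below builds such lines with the standard library only. The helper name `make_lst_lines`, the label mapping (0 = akita, 1 = husky) and the file names are assumptions for illustration, not the contents of the author's notebook.

```python
def make_lst_lines(samples):
    """Return one lst line per sample: index<TAB>label<TAB>relative path."""
    return ["%d\t%d\t%s" % (index, label, path)
            for index, (path, label) in enumerate(samples)]


# Assumed label mapping: 0 = akita, 1 = husky (hypothetical file names)
samples = [("akita/akita_001.jpg", 0), ("husky/husky_001.jpg", 1)]
lst_lines = make_lst_lines(samples)
print("\n".join(lst_lines))
```

In practice these lines would be written to a file, one lst file per split, before running the rec generation script on each.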

After generating the dataset, your directory structure will look like this:

Build the model

We will use MobileNet for our transfer learning task. MobileNet is an efficient convolutional neural network architecture. It replaces each regular convolution layer with a 3x3 depthwise convolution followed by a 1x1 pointwise convolution, which reduces the computational cost.

regular convolution vs depthwise separable convolution (reference)
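
To make the saving concrete, here is a small back-of-the-envelope sketch that counts multiply-add operations for both layer types. It assumes stride 1 and an output feature map of the same size as the input; the layer dimensions are made up for illustration.

```python
def standard_conv_cost(k, m, n, f):
    """Multiply-adds of a k x k convolution with m input channels,
    n output channels and an f x f output feature map."""
    return k * k * m * n * f * f


def depthwise_separable_cost(k, m, n, f):
    """Multiply-adds of a k x k depthwise convolution followed by
    a 1x1 pointwise convolution, for the same layer dimensions."""
    depthwise = k * k * m * f * f   # one k x k filter per input channel
    pointwise = m * n * f * f       # 1x1 convolution that mixes channels
    return depthwise + pointwise


# Illustrative layer: 3x3 kernel, 256 -> 256 channels, 14x14 feature map
k, m, n, f = 3, 256, 256, 14
std = standard_conv_cost(k, m, n, f)
sep = depthwise_separable_cost(k, m, n, f)
ratio = sep / std  # equals 1/n + 1/k**2
print(std, sep, round(ratio, 3))
```

For a 3x3 kernel the ratio is dominated by the 1/9 term, so the separable version needs roughly 8 to 9 times fewer multiply-adds than the regular convolution.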

Thanks to the Gluon community, we have MobileNets pretrained on ImageNet. Let’s explore one of them a bit in a Jupyter notebook.

from mxnet.gluon.model_zoo.vision import mobilenet1_0

# Download the ImageNet weights and print the network structure
pretrained_net = mobilenet1_0(pretrained=True)
print(pretrained_net)

Once you run it, you will find that the model consists of two high-level blocks: features and output. As the ImageNet version has 1000 output classes, we need to create a model with 2 output classes.

net = mobilenet1_0(classes=2)

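To reuse the weights of the pretrained model, we can directly assign its features block to the new model, and then initialise only the new output block.

from mxnet import init

# Reuse the pretrained feature layers; only the 2-class output layer starts from scratch
net.features = pretrained_net.features
net.output.initialize(init.Xavier())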

That’s it ~

Train the model

Before training the model, let’s create the data loader for our husky & akita dataset. We will use the image augmentation functions provided by MXNet to generate new samples and help avoid overfitting.

from mxnet import image, nd
from mxnet.image import color_normalize

train_augs = [
    image.ResizeAug(224),
    image.HorizontalFlipAug(0.5),   # flip the image horizontally with probability 0.5
    image.BrightnessJitterAug(.3),  # randomly change the brightness
    image.HueJitterAug(.1),         # randomly change the hue
]
test_augs = [image.ResizeAug(224)]

def transform(data, label, augs):
    data = data.astype('float32')
    for aug in augs:
        data = aug(data)
    # height x width x channel -> channel x height x width
    data = nd.transpose(data, (2, 0, 1))
    return data, nd.array([label]).asscalar().astype('float32')
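
The transpose in the transform function is easy to overlook: decoded images are stored as height x width x channel, while Gluon's convolution layers expect channel x height x width by default. The following plain-Python sketch shows the same reordering on a tiny nested list. The helper name `hwc_to_chw` is hypothetical and only mirrors what the `(2, 0, 1)` transpose does.

```python
def hwc_to_chw(image):
    """Reorder a nested-list image from [height][width][channel]
    to [channel][height][width]."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [[[image[h][w][c] for w in range(width)]
             for h in range(height)]
            for c in range(channels)]


# A tiny image with one row and two RGB pixels
image = [[[255, 0, 0], [0, 128, 64]]]
chw = hwc_to_chw(image)
print(chw)  # [[[255, 0]], [[0, 128]], [[0, 64]]]
```

Each of the three inner lists is now a full single-channel image (red, green, blue), which is the layout a convolution layer reads.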

Now we can create a data iterator with the augmentations we defined.

from mxnet.gluon.data.vision import ImageRecordDataset

train_rec = './data/train/dog.rec'
validation_rec = './data/validation/dog.rec'

# Augmentations are applied on the fly; validation only resizes
trainIterator = ImageRecordDataset(
    filename=train_rec,
    transform=lambda X, y: transform(X, y, train_augs))
validationIterator = ImageRecordDataset(
    filename=validation_rec,
    transform=lambda X, y: transform(X, y, test_augs))

Let’s define our training function.

from mxnet import gluon

def train(net, ctx, batch_size=64, epochs=10, learning_rate=0.01, wd=0.001):
    # Wrap the datasets in loaders that produce mini-batches
    train_data = gluon.data.DataLoader(
        trainIterator, batch_size, shuffle=True)
    validation_data = gluon.data.DataLoader(
        validationIterator, batch_size)

    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    # SGD with the given learning rate and weight decay
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {
        'learning_rate': learning_rate, 'wd': wd})
    train_util(net, train_data, validation_data,
               loss, trainer, ctx, epochs, batch_size)

The train_util function can be found here. It is generic and usually reused across projects, so I will not post it in this article.

There are some key points to highlight:

  1. hybridize: the function that transforms your imperative model into a symbolic model, so that the computation graph can be optimised for faster training.
  2. Trainer: where we define the optimisation algorithm, the learning rate and the weight decay.
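
To show what the learning rate and weight decay passed to the Trainer actually control, here is a simplified sketch of a single SGD update on one scalar weight. It is an illustration of the update rule only, not MXNet's implementation: momentum, gradient rescaling and clipping are ignored, and the function name `sgd_step` is hypothetical.

```python
def sgd_step(weight, grad, learning_rate=0.01, wd=0.001):
    """One plain SGD update with weight decay (no momentum)."""
    return weight - learning_rate * (grad + wd * weight)


w = 2.0
# With a zero gradient only weight decay acts, pulling the weight towards 0
w_decayed = sgd_step(w, grad=0.0)
# A larger wd shrinks the weight faster, which is the regularisation effect
w_decayed_more = sgd_step(w, grad=0.0, wd=0.1)
print(w_decayed, w_decayed_more)
```

The learning rate scales the whole step, while wd adds a pull towards zero that is proportional to the weight itself, which is why raising it acts as stronger regularisation.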

With everything in place, we can kick start the training!

After 5 epochs, we can see that the training accuracy and validation accuracy both went up.

training process


So we achieved 0.97 accuracy on the training set. However, we can’t use training accuracy to evaluate a model, because the model may simply be overfitting the distribution of the training data (I once purposely overfitted the training set and reached 99.9%, but the model did poorly on real data).

To evaluate the performance of our model, we need to use our test dataset, which is never used in training, model selection or hyperparameter tuning.

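# testIterator is an assumed name: an ImageRecordDataset over the test rec file,
# built the same way as validationIterator
test_data_loader = gluon.data.DataLoader(testIterator, 64)
test_acc = evaluate_accuracy(test_data_loader, net)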

The test accuracy is 0.93, so our model is not as good as the training accuracy suggested.
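
The evaluate_accuracy helper is not shown in this article. As a rough idea of what such a metric computes, the sketch below takes per-class scores and true labels and returns the fraction of correct predictions. The function name `accuracy`, the scores and the label mapping (0 = akita, 1 = husky) are all made up for illustration.

```python
def accuracy(scores, labels):
    """Fraction of samples whose highest-scoring class matches the label."""
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=lambda i: row[i])  # argmax over classes
        correct += int(predicted == label)
    return correct / len(labels)


# Made-up scores for 4 images and 2 classes (assumed: 0 = akita, 1 = husky)
scores = [[2.1, -0.3], [0.2, 1.5], [0.9, 0.4], [-1.0, 0.7]]
labels = [0, 1, 1, 1]
print(accuracy(scores, labels))  # 0.75
```

The third image is scored as an akita although it is labelled husky, so 3 of the 4 predictions are correct.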


The validation accuracy reached 0.88 in the first epoch, which feels unreal

Yes. The reason is that husky is an existing label in ImageNet, and our dataset is not very big.

There’s always a gap between training accuracy and validation accuracy

It shows that we are overfitting our model. We can apply more regularisation by increasing the weight decay (the wd argument of the training function), or we should try to get more training data.


This article shows you how simple it is to use MXNet Gluon for transfer learning. It can also serve as a guide when you apply this technique to your own problem.

There are still many interesting topics I haven’t covered, such as customising blocks, building your own architecture and deploying models in different environments. I will try to cover them in future posts.

Source code:


  1. On the importance of democratizing Artificial Intelligence
  2. CS231n Convolutional Neural Networks for Visual Recognition
  3. MXNet Gluon documentation
  5. MobileNets
  6. Google’s MobileNets on the iPhone