Stephen Rimac

Image Classification with Convolutional Neural Networks

“Close up of modern metal sculpture of human face at Canary Wharf” by Clem Onojeghuo on Unsplash

Train and evaluate a world-class deep learning model in 4 lines of code, and under 21 seconds

by Stephen Rimac

The following contains notes and Python code compiled from a lecture given by Jeremy Howard, co-founder of fast.ai. Many thanks to Jeremy and Rachel Thomas for building fast.ai and the fastai library, a high-level wrapper for PyTorch. The following code is based on the fastai library. For more information, watch the first lesson (of seven) in Practical Deep Learning For Coders, Part 1, which is publicly available free of charge. If you are keen to learn deep learning, you won’t regret it!

Table of Contents

  1. Introduction to our first task: ‘Dogs vs Cats’
  2. First look at cat pictures
  3. Model
  4. Analyzing results: looking at pictures

1. Introduction to our first task: Dogs vs Cats

We’re going to use convolutional neural networks (CNNs) to allow our computer to see — something that is only possible thanks to deep learning.

We’re going to try to create a deep learning CNN model based on data from a previous Kaggle competition called Dogs vs Cats. There are 25,000 labelled dog and cat photos available for training, and 12,500 in the test set that we have to try to label for this competition. According to the Kaggle website, when this competition was launched (end of 2013):

“State of the art: The current literature suggests machine classifiers can score above 80% accuracy on this task”.

So if we can beat 80%, then we will be at the cutting edge as of 2013! OK, let’s get going.

Put these at the top of every notebook, to get automatic reloading and inline plotting:

%reload_ext autoreload
%autoreload 2
%matplotlib inline

Here we import the libraries we need:

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
PATH = "data/dogscats/"

We set the size below to 224 because resnet uses 224 x 224 images. More on this later:

sz = 224

Data download

You can download the dataset directly to your server by running wget in your terminal. Put the data in a subdirectory of your Jupyter notebook's directory called data/.

2. First look at cat pictures

The library will assume that you have train and valid directories. It also assumes that each dir will have subdirs for each class you wish to recognize (in this case, ‘cats’ and ‘dogs’).
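As a quick sanity check before training, the folder layout described above can be verified with a few lines of plain Python. This is only a sketch: the helper name check_layout is hypothetical (it is not part of the library), and the default split and class names are simply the ones used in this post.

```python
from pathlib import Path

def check_layout(root, splits=("train", "valid"), classes=("cats", "dogs")):
    """Return the expected sub-folders that are missing under root."""
    root = Path(root)
    missing = []
    for split in splits:
        for cls in classes:
            folder = root / split / cls
            if not folder.is_dir():
                missing.append(str(folder))
    return missing

# An empty list means every train/valid class folder was found.
print(check_layout("data/dogscats"))
```

If the list is not empty, the printed paths show which folders still need to be created or unzipped.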

The commands below show the contents of the PATH folder; a leading ! means run the command in bash.

!ls {PATH}
!ls {PATH}valid

The following code shows what’s inside the validation cats folder. This is a standard way to share or provide image classification files.

files = !ls {PATH}valid/cats | head
# Example: show the first cat image in the cats folder
img = plt.imread(f'{PATH}valid/cats/{files[0]}')  # f-string: the {} fields are filled in from variables
plt.imshow(img);

I am cute

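The output below shows what the raw data looks like: first the shape of the image array, then its top-left 4 x 4 pixels. The image is a rank 3 tensor, that is, a three-dimensional array of height x width x colour channels. Each innermost triple holds the red, green, and blue values of one pixel, each between 0 and 255.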

(198, 179, 3)
array([[[ 29,  20,  23],
        [ 31,  22,  25],
        [ 34,  25,  28],
        [ 37,  28,  31]],

       [[ 60,  51,  54],
        [ 58,  49,  52],
        [ 56,  47,  50],
        [ 55,  46,  49]],

       [[ 93,  84,  87],
        [ 89,  80,  83],
        [ 85,  76,  79],
        [ 81,  72,  75]],

       [[104,  95,  98],
        [103,  94,  97],
        [102,  93,  96],
        [102,  93,  96]]], dtype=uint8)
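To make the idea of a rank 3 tensor concrete, here is a minimal NumPy sketch with a made-up 2 x 3 image. The array and its values are illustrative only (one pixel borrows values from the output above); nothing here comes from the real dataset.

```python
import numpy as np

# A tiny "image": 2 rows, 3 columns, 3 colour channels (R, G, B).
img = np.zeros((2, 3, 3), dtype=np.uint8)
img[0, 0] = [29, 20, 23]   # top-left pixel
img[1, 2] = [255, 0, 0]    # bottom-right pixel set to pure red

print(img.ndim)    # number of axes, i.e. the rank
print(img.shape)   # (height, width, channels)
print(img[0, 0])   # the three channel values of one pixel

# Slicing out one channel gives a rank 2 array: one value per pixel.
red_channel = img[:, :, 0]
print(red_channel.shape)
```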

3. Model

We’re going to use a pre-trained model, that is, a model created by someone else to solve a different problem. Instead of building a model from scratch to solve a similar problem, we’ll use a model trained on ImageNet (1.2 million images and 1000 classes) as a starting point. The model is a Convolutional Neural Network (CNN), a type of neural network behind state-of-the-art results in computer vision.

We will be using resnet34 as our pre-trained model. resnet34 is a version of the model that won the 2015 ImageNet competition.

Here’s how to train and evaluate a dogs vs cats model in 4 lines of code, and in under 21 seconds. Under the hood, the syntax below is a wrapper written by fast.ai. The library is updated regularly and keeps up with cutting-edge deep learning research, so it makes sure that best practices are always used. It is also super fast (e.g., 10–60 seconds depending on GPU) because it sits on top of PyTorch, a very flexible library written by Facebook.

arch = resnet34
data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz))
learn = ConvLearner.pretrained(arch, data, precompute=True)
learn.fit(0.01, 3)
100%|██████████| 360/360 [00:57<00:00,  6.24it/s]
100%|██████████| 32/32 [00:05<00:00, 6.04it/s]
epoch      trn_loss   val_loss   accuracy                     
0 0.045726 0.028603 0.989258
1 0.039685 0.026488 0.990234
2 0.041631 0.03259 0.990234
[0.032590486, 0.990234375]

The data object contains the training and validation data.

ImageClassifierData.from_paths reads in images and their labels given as sub-folder names:

  • path: a root path of the data (used for storing trained models, precomputed values, etc)
  • bs: batch size. Default 64.
  • tfms: transformations (for data augmentations). e.g. output of tfms_from_model. Default 'None'
  • trn_name: a name of the folder that contains training images. Default 'train'
  • val_name: a name of the folder that contains validation images. Default 'valid'
  • test_name: a name of the folder that contains test images. Default 'None'
  • num_workers: number of workers. Default '8'
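The batch size also explains the step counts in the progress bars of the training output above: one pass over a dataset takes ceil(number of images / bs) batches. The sketch below assumes the 25,000 labelled photos are split into 23,000 training and 2,000 validation images. That split is an assumption (it is not stated in this post), chosen because it is consistent with the 360 and 32 steps shown.

```python
import math

def batches_per_epoch(n_images, bs=64):
    """Number of batches needed to see every image once."""
    return math.ceil(n_images / bs)

# Assumed split of the 25,000 labelled photos (not stated in the post).
n_train, n_valid = 23_000, 2_000

print(batches_per_epoch(n_train))  # 360
print(batches_per_epoch(n_valid))  # 32
```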

The learn object contains the model.

ConvLearner.pretrained builds a learner that contains a pre-trained model:

  • f: arch. E.g., resnet34
  • data: previously defined data object
  • precompute: include/exclude precomputed activations. Default 'False'

learn.fit trains the model with a given learning rate and number of epochs. In this instance, it is going to do 3 epochs with a 0.01 learning rate, meaning it is going to look at each image three times in total.

trn_loss and val_loss are the values of the cross-entropy loss function.
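For intuition, cross-entropy is the average negative log-probability the model assigns to the correct class, so confident correct predictions give a loss near 0 and confident wrong ones are penalised heavily. The following is a minimal NumPy sketch with made-up probabilities; the function name cross_entropy is illustrative and this is not the library's own implementation.

```python
import numpy as np

def cross_entropy(log_preds, targets):
    """Mean negative log-probability assigned to the true class."""
    rows = np.arange(len(targets))
    return -log_preds[rows, targets].mean()

# Made-up predictions for three images: columns are pr(cat), pr(dog).
probs = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.6, 0.4]])
targets = np.array([0, 1, 1])  # 0 = cat, 1 = dog

loss = cross_entropy(np.log(probs), targets)
print(loss)
```

The third image is a dog that the model thinks is more likely a cat, and it contributes the largest share of the loss.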

How good is this model? Well, prior to this competition, the state of the art was 80% accuracy. But the competition resulted in a huge jump to 99.0% accuracy, with the author of a popular deep learning library winning the competition. Extraordinarily, less than 4 years later, we can now beat that result in seconds!

The model above can be used on any kind of picture, as long as it shows the sort of thing people normally take photos of. However, images such as pathology slides or CT scans won’t do well with this model as is. There are some minor things we need to do to make those work. This will be covered in a subsequent notebook. Stay tuned!

4. Analyzing results: looking at pictures

As well as looking at the overall metrics, it’s also a good idea to look at examples of some of the predictions:

  1. A few correct labels at random
  2. A few incorrect labels at random
  3. The most correct labels of each class (i.e., those with highest probability that are correct)
  4. The most incorrect labels of each class (i.e., those with highest probability that are incorrect)
  5. The most uncertain labels (i.e., those with probability closest to 0.5).

We will look at all of this shortly. But first, if we ever want to know about the data, we can look inside with a few of the following methods:

# Pull the labels (dependent variable) of the validation set
data.val_y
# Pull the data classes
data.classes
# Pull the y log predictions for the validation set
log_preds = learn.predict()
# Pull the first 10 predictions
log_preds[:10]
# From log probabilities to 0 or 1
preds = np.argmax(log_preds, axis=1)
# pr(dog), i.e., the anti-log of the dog column
probs = np.exp(log_preds[:,1])
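The last two steps above can be tried on a toy array without a trained model. In this sketch the log predictions are made up, with column 0 standing for cat and column 1 for dog, as assumed throughout this post.

```python
import numpy as np

# Made-up log-probabilities for three images: columns are cat, dog.
log_preds = np.log(np.array([[0.95, 0.05],
                             [0.30, 0.70],
                             [0.51, 0.49]]))

# The predicted class is the column with the larger log-probability.
preds = np.argmax(log_preds, axis=1)

# Exponentiating the dog column undoes the log and gives pr(dog).
probs = np.exp(log_preds[:, 1])

print(preds)
print(probs)
```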

Plotting functions

def rand_by_mask(mask): return np.random.choice(np.where(mask)[0], 4, replace=False)
def rand_by_correct(is_correct): return rand_by_mask((preds == data.val_y)==is_correct)

def plots(ims, figsize=(12,6), rows=1, titles=None):
    f = plt.figure(figsize=figsize)
    for i in range(len(ims)):
        sp = f.add_subplot(rows, len(ims)//rows, i+1)
        sp.axis('Off')
        if titles is not None: sp.set_title(titles[i], fontsize=16)
        plt.imshow(ims[i])

# Load the original image file for validation item idx
# (assumes the dataset exposes its file names as ds.fnames)
def load_img_id(ds, idx): return np.array(PIL.Image.open(PATH+ds.fnames[idx]))

def plot_val_with_title(idxs, title):
    imgs = [load_img_id(data.val_ds,x) for x in idxs]
    title_probs = [probs[x] for x in idxs]
    print(title)
    return plots(imgs, rows=1, titles=title_probs, figsize=(16,8))

def most_by_mask(mask, mult):
    idxs = np.where(mask)[0]
    return idxs[np.argsort(mult * probs[idxs])[:4]]

def most_by_correct(y, is_correct):
    mult = -1 if (y==1)==is_correct else 1
    return most_by_mask(((preds == data.val_y)==is_correct) & (data.val_y == y), mult)
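The sign trick in most_by_correct is easier to follow on toy data. The sketch below repeats the two selection helpers with made-up labels and probabilities; val_y stands in for data.val_y, and the extra k parameter is an addition for illustration only. Multiplying by -1 sorts in descending order of pr(dog), which is what we want for confident dogs, while +1 sorts ascending for confident cats.

```python
import numpy as np

# Made-up validation labels (0 = cat, 1 = dog) and predicted pr(dog).
val_y = np.array([0, 0, 1, 1, 1, 0])
probs = np.array([0.02, 0.40, 0.99, 0.60, 0.10, 0.90])
preds = (probs > 0.5).astype(int)

def most_by_mask(mask, mult, k=2):
    idxs = np.where(mask)[0]
    # argsort is ascending, so mult=-1 puts the highest pr(dog) first.
    return idxs[np.argsort(mult * probs[idxs])[:k]]

def most_by_correct(y, is_correct, k=2):
    mult = -1 if (y == 1) == is_correct else 1
    return most_by_mask(((preds == val_y) == is_correct) & (val_y == y), mult, k)

print(most_by_correct(1, True))   # correctly labelled dogs, most confident first
print(most_by_correct(0, True))   # correctly labelled cats, most confident first
print(most_by_correct(0, False))  # cats the model confidently called dogs
print(most_by_correct(1, False))  # dogs the model confidently called cats
```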

A few correct labels at random

Anything greater than 0.5 is dog; anything less than 0.5 is cat:

plot_val_with_title(rand_by_correct(True), "Correctly classified")

A few incorrect labels at random

Anything greater than 0.5 is dog; anything less than 0.5 is cat:

plot_val_with_title(rand_by_correct(False), "Incorrectly classified")

The most correct labels of each class

(i.e., those with highest probability that are correct)

plot_val_with_title(most_by_correct(0, True), "Most correct cats")
plot_val_with_title(most_by_correct(1, True), "Most correct dogs")

I like the grey ones

The most incorrect labels of each class

(i.e., those with highest probability that are incorrect)

plot_val_with_title(most_by_correct(0, False), "Most incorrect cats")
plot_val_with_title(most_by_correct(1, False), "Most incorrect dogs")

The most uncertain labels

(i.e., those with probability closest to 0.5)

# probabilities closest to 0.5
most_uncertain = np.argsort(np.abs(probs - 0.5))[:4]
plot_val_with_title(most_uncertain, "Most uncertain predictions")
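The same ranking can be checked on a handful of made-up probabilities. This sketch keeps the two most uncertain items rather than four, purely to keep the toy array short.

```python
import numpy as np

# Made-up pr(dog) values for five validation images.
probs = np.array([0.98, 0.52, 0.03, 0.45, 0.70])

# Distance from 0.5 measures confidence; the smallest distances come first.
most_uncertain = np.argsort(np.abs(probs - 0.5))[:2]

print(most_uncertain)
print(probs[most_uncertain])
```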

NB: The images above whose dimensions are off (e.g., the rectangular ones) are skewing the results. We take care of this using a technique called data augmentation. More on that in a later post.

Pro tip: if you want to make the model better, look at what it is doing well and fix the things it is doing badly. E.g., in another Jupyter notebook, try removing images that are just skewing the data, like cartoons, etc. If you figure out how to do this, let me know :)

See the GitHub link below for all of the above work.

Thanks for reading!
