In this article, we’ll try to replicate the approach the FastAI team used to win the Stanford DAWNBench competition, where they trained a model to 94% accuracy on the CIFAR-10 dataset in under 3 minutes.
The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images (5,000 per class) and 10,000 test images. Here are 10 random images from each class:
You can download the data here or by running the following commands:
```bash
mkdir -p data && cd data
wget http://files.fast.ai/data/cifar10.tgz
tar -xf cifar10.tgz
```
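To check that the archive extracted correctly, you can count the images in each class folder. This minimal sketch assumes the extracted data follows the one-folder-per-class layout that torchvision's `ImageFolder` expects (`data/cifar10/train/<class>/*.png`, and likewise for `test`); `count_images` and `IMAGE_EXTS` are hypothetical helpers written for this illustration, not part of any library.

```python
from pathlib import Path

# Hypothetical: file extensions treated as images (assumption: CIFAR-10 is stored as PNGs)
IMAGE_EXTS = {".png", ".jpg", ".jpeg"}

def count_images(split_dir):
    """Return {class_name: number_of_images} for a one-folder-per-class directory."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():  # skip stray files that sit next to the class folders
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_EXTS
            )
    return counts
```

Given the split described above, `count_images("data/cifar10/train")` should report 5,000 images for each of the 10 classes and `count_images("data/cifar10/test")` 1,000 each, assuming the layout matches.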
Once the data is downloaded, start the Jupyter notebook server using the `jupyter notebook` command and create a new notebook.
Let’s define a helper function to create data loaders with data augmentation:
```python
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from fastai.dataset import ModelData

def get_data(bs, num_workers):
    PATH = "data/cifar10/"
    trn_dir, val_dir = PATH + 'train', PATH + 'test'
    stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Data transforms (normalization & data augmentation)
    tfms = [tt.ToTensor(), tt.Normalize(*stats)]
    aug_tfms = tt.Compose([tt.RandomCrop(32, padding=4),
                           tt.RandomHorizontalFlip()] + tfms)

    # PyTorch datasets
    trn_ds = ImageFolder(trn_dir, aug_tfms)
    val_ds = ImageFolder(val_dir, tt.Compose(tfms))
    aug_ds = ImageFolder(val_dir, aug_tfms)

    # PyTorch data loaders
    trn_dl = DataLoader(trn_ds, batch_size=bs, shuffle=True,
                        num_workers=num_workers, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=bs, shuffle=False,
                        num_workers=num_workers, pin_memory=True)
    aug_dl = DataLoader(aug_ds, batch_size=bs, shuffle=False,
                        num_workers=num_workers, pin_memory=True)

    # FastAI model data
    data = ModelData(PATH, trn_dl, val_dl)
    data.aug_dl = aug_dl
    data.sz = 32
    return data
```
A few things to note about `get_data`:
- We use the test set (`data/cifar10/test`) as the validation dataset, to keep things simple.
- `stats` contains the channel-wise means and standard deviations for the entire dataset, and is used to normalize the data.
- `aug_ds` and `aug_dl` apply data augmentation to the validation images. They are used for test time augmentation (TTA).
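The transforms in `get_data` act on whole images, but per pixel they are simple. The pure-Python sketch below mimics them on an image stored as nested lists (3 channels x H x W, values in [0, 1]): zero-padding followed by a random crop, a random horizontal flip, and channel-wise normalization with the values from `stats`. The helper names are hypothetical and this is only an illustration; the actual pipeline uses torchvision's `RandomCrop`, `RandomHorizontalFlip`, `ToTensor` and `Normalize`.

```python
import random

# Channel-wise (R, G, B) means and standard deviations, copied from `stats` in get_data
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)

def random_crop(img, size, padding, rng=random):
    """Zero-pad each channel by `padding` pixels, then cut out a random size x size window."""
    h, w = len(img[0]), len(img[0][0])
    padded_w = w + 2 * padding
    padded = [
        [[0.0] * padded_w for _ in range(padding)]
        + [[0.0] * padding + row + [0.0] * padding for row in ch]
        + [[0.0] * padded_w for _ in range(padding)]
        for ch in img
    ]
    # The same offsets are used for every channel, so the channels stay aligned
    top = rng.randint(0, h + 2 * padding - size)
    left = rng.randint(0, padded_w - size)
    return [[row[left:left + size] for row in ch[top:top + size]] for ch in padded]

def random_hflip(img, p=0.5, rng=random):
    """Mirror every row left-to-right with probability p."""
    if rng.random() < p:
        return [[row[::-1] for row in ch] for ch in img]
    return img

def normalize(img):
    """Subtract the channel mean and divide by the channel standard deviation."""
    return [[[(v - MEAN[c]) / STD[c] for v in row] for row in ch]
            for c, ch in enumerate(img)]

def augment_and_normalize(img, rng=random):
    """Training-time pipeline: crop(32, padding=4) -> random flip -> normalize."""
    return normalize(random_hflip(random_crop(img, 32, 4, rng), rng=rng))
```

Padding before cropping means each training image is shifted by up to 4 pixels in each direction, so the network sees slightly different views of the same image every epoch. The validation pipeline (`val_ds`) skips the first two steps and only normalizes.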
We’ll use a model called WideResNet-22, inspired by the family of architectures introduced in the paper Wide Residual Networks. It has the following architecture:
A few notable aspects of the architecture:
Convolutional layers are denoted Conv(size, input_channels, output_channels, stride=1).
Let’s first implement a generic module class for creating the residual blocks:
```python
import torch.nn as nn
import torch.nn.functional as F

def conv_2d(ni, nf, stride=1, ks=3):
    return nn.Conv2d(in_channels=ni, out_channels=nf,
                     kernel_size=ks, stride=stride,
                     padding=ks//2, bias=False)

def bn_relu_conv(ni, nf):
    return nn.Sequential(nn.BatchNorm2d(ni),
                         nn.ReLU(inplace=True),
                         conv_2d(ni, nf))

class BasicBlock(nn.Module):
    def __init__(self, ni, nf, stride=1):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, stride)
        self.conv2 = bn_relu_conv(nf, nf)
        self.shortcut = lambda x: x
        if ni != nf:
            self.shortcut = conv_2d(ni, nf, stride, 1)

    def forward(self, x):
        x = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x)
        x = self.conv1(x)
        x = self.conv2(x) * 0.2
        return x.add_(r)
```
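The stride and padding choices in `BasicBlock` are what keep the residual branch and the shortcut the same shape, so that `x.add_(r)` works. The sketch below applies the standard convolution output-size formula, floor((size + 2*padding - ks) / stride) + 1, to both paths; `conv_out_size` and `basic_block_shapes` are hypothetical helpers written for this illustration.

```python
def conv_out_size(size, ks, stride, padding):
    """Spatial output size of a square convolution: floor((size + 2p - ks) / s) + 1."""
    return (size + 2 * padding - ks) // stride + 1

def basic_block_shapes(ni, nf, size, stride):
    """Return ((channels, size) of the residual branch, (channels, size) of the shortcut)."""
    # Residual branch: conv1 is 3x3 with the block's stride, conv2 is 3x3 with stride 1;
    # conv_2d always uses padding = ks // 2, i.e. 1 for a 3x3 kernel
    s = conv_out_size(size, ks=3, stride=stride, padding=1)
    s = conv_out_size(s, ks=3, stride=1, padding=1)
    branch = (nf, s)
    if ni != nf:
        # Projection shortcut: 1x1 conv (padding 0) with the same stride as conv1
        shortcut = (nf, conv_out_size(size, ks=1, stride=stride, padding=0))
    else:
        # Identity shortcut: the input passes through unchanged
        shortcut = (ni, size)
    return branch, shortcut

print(basic_block_shapes(96, 192, 32, 2))
```

Note that an identity shortcut combined with stride 2 would give mismatched shapes (`basic_block_shapes(96, 96, 32, 2)` returns different sizes for the two paths). The `WideResNet` class defined next avoids this case because it only downsamples in the first block of a group, where the channel count also changes and the 1x1 projection shortcut is used.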
Next, let’s define a generic `WideResNet` class which will allow us to create a network with `n_groups` groups of `N` blocks per group and a factor `k` which can be used to adjust the width of the network, i.e. the number of channels. It also adds the pooling and linear layers at the end.
```python
def make_group(N, ni, nf, stride):
    start = BasicBlock(ni, nf, stride)
    rest = [BasicBlock(nf, nf) for j in range(1, N)]
    return [start] + rest

class Flatten(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)

class WideResNet(nn.Module):
    def __init__(self, n_groups, N, n_classes, k=1, n_start=16):
        super().__init__()
        # Increase channels to n_start using conv layer
        layers = [conv_2d(3, n_start)]
        n_channels = [n_start]

        # Add groups of BasicBlock (increase channels & downsample)
        for i in range(n_groups):
            n_channels.append(n_start*(2**i)*k)
            stride = 2 if i > 0 else 1
            layers += make_group(N, n_channels[i], n_channels[i+1], stride)

        # Pool, flatten & add linear layer for classification
        # (n_channels[-1] is the number of channels output by the last group)
        layers += [nn.BatchNorm2d(n_channels[-1]),
                   nn.ReLU(inplace=True),
                   nn.AdaptiveAvgPool2d(1),
                   Flatten(),
                   nn.Linear(n_channels[-1], n_classes)]

        self.features = nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)

def wrn_22():
    return WideResNet(n_groups=3, N=3, n_classes=10, k=6)
```
The `wrn_22` helper at the end creates WideResNet-22, which has 3 groups, 3 residual blocks per group and `k=6`. It’s always a good idea to define flexible and generic models, so that you can easily experiment with deeper or wider networks.
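To see what `k` and `N` actually change, the sketch below traces the channel widths and feature-map sizes produced by the loop in `WideResNet.__init__`, and counts trainable parameters (convolution weights, BatchNorm weights and biases, and the final linear layer). It is a hand-written mirror of the code above, assumed to match it layer for layer; `wrn_layout`, `wrn_params` and `wrn_depth` are hypothetical helpers. `wrn_depth` follows the WRN-n-k naming used in the Wide Residual Networks paper, where N blocks per group corresponds to a depth of 6N + 4 (so N=3 gives WideResNet-22).

```python
def wrn_depth(N):
    """WRN-n-k naming: N residual blocks per group corresponds to depth n = 6N + 4."""
    return 6 * N + 4

def wrn_layout(n_groups=3, k=6, n_start=16, in_size=32):
    """(in_channels, out_channels, stride, output_size) for each group, mirroring WideResNet."""
    n_channels = [n_start]
    size = in_size
    layout = []
    for i in range(n_groups):
        n_channels.append(n_start * (2 ** i) * k)
        stride = 2 if i > 0 else 1
        # Only the first 3x3 conv (padding 1) of each group changes the spatial size
        size = (size + 2 * 1 - 3) // stride + 1
        layout.append((n_channels[i], n_channels[i + 1], stride, size))
    return layout

def wrn_params(n_groups=3, N=3, k=6, n_start=16, n_classes=10):
    """Trainable parameters: conv weights (no bias), BatchNorm weight+bias, linear weight+bias."""
    def block(ni, nf):
        p = 2 * ni + 9 * ni * nf      # bn + conv1 (3x3)
        p += 2 * nf + 9 * nf * nf     # conv2 = bn_relu_conv: bn + 3x3 conv
        if ni != nf:
            p += ni * nf              # 1x1 projection shortcut
        return p

    total = 9 * 3 * n_start           # first 3x3 conv: 3 RGB channels -> n_start
    for ni, nf, _, _ in wrn_layout(n_groups, k, n_start):
        total += block(ni, nf) + (N - 1) * block(nf, nf)
    c = n_start * 2 ** (n_groups - 1) * k        # width of the last group
    total += 2 * c + c * n_classes + n_classes   # final BatchNorm + linear layer
    return total

for k in (1, 6):
    print(f"k={k}:", wrn_layout(k=k), f"{wrn_params(k=k):,} parameters")
```

With PyTorch installed, you can cross-check the count against the real model with `sum(p.numel() for p in wrn_22().parameters())`.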
Let’s define a couple of helper functions for instantiating the model and evaluating the results:
```python
from fastai.conv_learner import ConvLearner, num_cpus, accuracy
from fastai.metrics import accuracy_np

def get_learner(arch, bs):
    """Create a FastAI learner using the given model"""
    data = get_data(bs, num_cpus())
    learn = ConvLearner.from_model_data(arch.cuda(), data)
    learn.crit = nn.CrossEntropyLoss()
    learn.metrics = [accuracy]
    return learn

def get_TTA_accuracy(learn):
    """Calculate accuracy with Test Time Augmentation (TTA)"""
    # preds[0]: predictions on the original images; preds[1:]: predictions on augmented copies
    preds, targs = learn.TTA()
    preds = 0.6 * preds[0] + 0.4 * preds[1:].sum(0)
    return accuracy_np(preds, targs)
```
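`get_TTA_accuracy` gives the prediction on the original image a weight of 0.6, adds the predictions on each augmented copy with a weight of 0.4, and then takes the arg-max. The pure-Python sketch below reproduces that arithmetic on a toy example: 2 images, 3 classes and 2 augmented copies, with scores invented for illustration. `tta_combine` and `accuracy_from_scores` are hypothetical helpers, not fastai functions.

```python
def tta_combine(preds, w_orig=0.6, w_aug=0.4):
    """preds[0]: scores on original images; preds[1:]: scores on augmented copies.
    Each entry is a list of rows (one per image) of per-class scores."""
    orig, augs = preds[0], preds[1:]
    return [[w_orig * orig[i][j] + w_aug * sum(a[i][j] for a in augs)
             for j in range(len(orig[i]))]
            for i in range(len(orig))]

def accuracy_from_scores(scores, targets):
    """Fraction of rows whose highest-scoring class equals the target."""
    predicted = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(p == t for p, t in zip(predicted, targets)) / len(targets)

# Toy scores: [original, augmented copy 1, augmented copy 2]
preds = [
    [[0.5, 0.4, 0.1], [0.1, 0.2, 0.7]],
    [[0.1, 0.8, 0.1], [0.1, 0.1, 0.8]],
    [[0.2, 0.7, 0.1], [0.3, 0.1, 0.6]],
]
targets = [1, 2]
print(accuracy_from_scores(preds[0], targets), accuracy_from_scores(tta_combine(preds), targets))
```

In this constructed example the un-augmented scores misclassify the first image, while the combined scores get both right. TTA does not always help this much on real data.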
Finally, let’s train the model using the 1cycle policy, which gradually increases the learning rate and decreases the momentum until about halfway through the cycle, and then does the opposite. Here’s what it looks like:
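As a rough illustration of the schedule, here is a simplified, triangular version of the 1cycle policy in pure Python: over the first half of training the learning rate rises linearly from `lr_max / div` to `lr_max` while the momentum falls from `mom_max` to `mom_min`, and both reverse over the second half. `one_cycle` and its default values are hypothetical and chosen for illustration; they are not fastai's scheduler or the hyperparameters behind the results reported here.

```python
def one_cycle(step, total_steps, lr_max, div=10.0, mom_max=0.95, mom_min=0.85):
    """Return (learning_rate, momentum) at `step` (0..total_steps) of a triangular 1cycle."""
    half = total_steps / 2
    lr_min = lr_max / div
    if step <= half:
        t = step / half                   # ramps 0 -> 1 during the first half
    else:
        t = (total_steps - step) / half   # ramps 1 -> 0 during the second half
    lr = lr_min + t * (lr_max - lr_min)        # learning rate goes up, then down
    mom = mom_max - t * (mom_max - mom_min)    # momentum moves in the opposite direction
    return lr, mom

for step in (0, 25, 50, 75, 100):
    print(step, one_cycle(step, 100, lr_max=1.0))
```

Leslie Smith's original description of 1cycle also ends with a short phase that anneals the learning rate well below its starting value; this sketch omits that phase.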
On a machine with a 6-core Intel i5 CPU and an NVIDIA GTX 1080 Ti GPU, training takes about 15 minutes. You might see slightly different results depending on your hardware. Here’s a plot of the loss, learning rate and momentum over time:
And that’s it! Feel free to play around with the network architecture, learning rate, cycle length and other factors to try and get a better result in a shorter time. You can find the entire code for this post in .