Whether you are new to deep learning or a seasoned veteran, setting up an environment for training a neural network can sometimes be painful. Wouldn't it be awesome if training a neural network were as simple as loading a webpage, making a few clicks, and then running inference with it right away?
In this tutorial, I will show you how to build a model with the in-browser framework TensorFlow.js, using data collected from your webcam, and train it right in your browser. To make the model useful, we will turn your webcam into a controller for the legendary game Pong.
Let’s play the game first
Instructions for serving the web application locally on your computer:
- Download the dist.zip and extract it to your local machine.
- Install an HTTP server. I recommend installing http-server globally with npm:
npm install -g http-server
You may ask what npm is. It is the package manager for Node.js, similar to pip for Python, and it ships with the Node.js installer.
- From the directory containing the dist folder, run the following command to serve the web app on your local machine at a port, say 1234:
http-server dist --cors -p 1234 -s
- Point a browser to http://localhost:1234. I have tested it on Chrome and Firefox.
- When the page finishes loading, start by collecting training images for three moves: left, middle and right. One tip: keep the training samples balanced, with around 20 samples for each move.
- Click “TRAIN” to start the training process; the loss is displayed as it trains.
- When the loss stops changing, training is over. Now click “PLAY” to start the game.
- If you want to start over, click “RESET”.
After you get tired of beating the computer with your head or hand, let’s look at how the game is built. This tutorial uses two models. The first is a pre-trained convolutional network exported from Keras; it is responsible for extracting features from webcam images. The second model is built and trained in your browser to make game-control predictions from those image features. It is a regression model that predicts values between -1 and 1 to control the speed of the player’s paddle. This is essentially a transfer learning task. For more on transfer learning, refer to my previous series, Gentle guide to setup Keras deep learning framework and build a travel recommendation engine. Without further ado, let’s dive in. You can download the source code from my GitHub repo webcam-pong.
Export pre-trained model to tfjs
You can skip this section if you just want to learn the web application part.
Let’s start by exporting a pre-trained convolutional network to the TensorFlow.js (tfjs) format. I picked DenseNet trained on the ImageNet dataset for this tutorial, but you can go with other models like MobileNet. Try to avoid large, deep convolutional networks like ResNets and VGGs: even though they might provide slightly higher accuracy, they are a poor fit for edge devices, such as our case of running in a browser.
The first step is to save the pre-trained DenseNet Keras model to a .h5 file in a Python script.
Then we run the conversion script to convert the .h5 file to tfjs files optimized for browser caching. Before continuing, install the tensorflowjs Python package, which provides the conversion script, through pip3:
pip3 install tensorflowjs
We can now generate the tfjs files by running:
cd ./tfjs-densenet
tensorflowjs_converter --input_format keras ./model.h5 ./model
You will see a folder named model with several files inside. The model.json file defines the model structure and the paths to the weight files. The pre-trained model is now ready to be served by the web app. For example, you can rename the model folder to serveDenseNet and copy it into the folder your web app serves; the model can then be loaded by its URL.
When loading, the URL can be built from window.location.origin, which is the web app’s origin; if you serve it locally on port 1234, it is http://localhost:1234. Loading with await lets the web app fetch the model asynchronously without freezing the main user interface.
Also note that the model we loaded is an image classification model with layers at the top that we don’t need; we only want its feature extraction part. The solution is to locate the topmost convolutional layer and truncate the model at that layer.
Generate training data from a webcam
To prepare the training data for the regression model, we will grab some images from a webcam and extract their features with the pre-trained model within the web app. To simplify the user interface for acquiring training data, we only label images with one of the three values [-1, 0, 1].
Each image acquired through the webcam is fed into the pre-trained DenseNet to extract features, which are saved as a training sample. After passing through the feature extractor model, a 224 by 224 color image has its dimensions reduced to an image feature tensor of shape [7, 7, 1024]. The shape depends on the pre-trained model you choose and can be obtained by reading the outputShape property of the layer we picked in the previous section, like this:
const modelLayerShape = layer.outputShape.slice(1); // drop the leading batch dimension (null)
There are two reasons to use the extracted image features as training data instead of the raw images. First, it saves memory when storing the training data. Second, it reduces training time, since the feature extraction model does not have to run again during training.
Each image is captured from the webcam, has its features extracted, and is aggregated with the previous samples. Note that all image features are saved as tensors, which means that if your model runs on the browser’s WebGL backend, there is a limit on how many training samples it can safely hold in GPU memory at a time. So, depending on your hardware, don’t expect to train your model with thousands or even hundreds of image samples.
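To put numbers on the memory argument, here is the element-count arithmetic for one sample. It assumes both the webcam frame and the features are held as 4-byte tensors (tf.browser.fromPixels returns int32 and the extracted features are float32), uses the [7, 7, 1024] DenseNet shape from the text, and estimates the cost of about 20 samples per move. The helper names are hypothetical.

```javascript
// Count the elements in a tensor of the given shape.
const numElements = (shape) => shape.reduce((product, dim) => product * dim, 1);

// int32 and float32 tensors both use 4 bytes per element.
const BYTES_PER_ELEMENT = 4;

const rawImageShape = [224, 224, 3]; // one resized webcam frame
const featureShape = [7, 7, 1024];   // DenseNet feature tensor, batch dimension dropped

const rawElements = numElements(rawImageShape);    // 150528
const featureElements = numElements(featureShape); // 50176

// Approximate tensor memory needed to hold `numSamples` stored feature tensors.
function featureMemoryBytes(numSamples, shape = featureShape) {
  return numSamples * numElements(shape) * BYTES_PER_ELEMENT;
}

console.log(`raw/feature element ratio: ${rawElements / featureElements}`); // 3
console.log(`60 samples: ${(featureMemoryBytes(60) / 2 ** 20).toFixed(1)} MiB`); // 11.5 MiB
```

Each stored sample is three times smaller than the frame it came from, and because DenseNet runs once per captured image rather than once per sample per epoch, training only touches the small regression model.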
Build and train the neural network
Building and training your neural network without uploading anything to a cloud service protects your privacy, since the data never leave your device, and watching it happen in your browser makes it even cooler.
The regression model takes the image features as input, flattens them into a vector, then passes that vector through two fully connected layers to generate one floating-point number that controls the game. The last fully connected layer has no activation function (a linear output), so it can produce any real number; training against targets of -1, 0 and 1 keeps its predictions roughly between -1 and 1, although nothing in the layer strictly enforces that range. The loss function we minimize during training is mean squared error. For more on this choice, read my post How to choose Last-layer activation and loss function.
The model is built, compiled and fitted with calls that look quite similar to the Keras workflow.
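For reference, the mean squared error over a batch of N samples, with targets y_i in {-1, 0, 1} and predictions ŷ_i, is written out below, followed by a worked example on three made-up predictions.

```latex
\mathcal{L}_{\mathrm{MSE}} = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - \hat{y}_i\right)^2

% Example: y = (-1, 0, 1), \hat{y} = (-0.8, 0.1, 0.9)
\mathcal{L}_{\mathrm{MSE}} = \frac{(-0.2)^2 + (-0.1)^2 + (0.1)^2}{3}
                           = \frac{0.04 + 0.01 + 0.01}{3} = 0.02
```

Because the error is squared, a prediction that is off by 0.2 costs four times as much as one that is off by 0.1.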
Turning webcam into a Pong controller
As you might expect, predicting with an image resembles Keras syntax as well. The image is first converted to image features, which are then passed to the trained regression network; it outputs a controller value that should fall between -1 and 1.
Once the model is trained and the game is up and running, the predicted value is passed to pong.updatePlayerSpeed(value), which moves the player’s paddle left or right at a variable speed. You can start and stop the game by calling:
- pong.startGameplay(), which is called when you click the Play button.
- pong.stopGameplay(), which is called when you click the Reset button.
The aggressiveness of the paddle movement can be adjusted by calling pong.updateMultiplier(multiplier); the current multiplier value is set to 12 in the Pong class constructor.
Conclusion and further thought
In this tutorial, you have learned how to train a neural network in a browser with TensorFlow.js and turn your webcam into a Pong controller that recognizes your moves. Feel free to check out my source code, experiment with it and modify it, for example by changing the activation functions or the loss function, or by swapping in another pre-trained model, and see how it turns out. The beauty of training a neural network in a browser with instant feedback is that it lets us try new ideas and get results faster for our prototypes, and it makes deep learning more accessible to the general public. Check out the full source on my GitHub repo webcam-pong.
Tony607/webcam-pong: Transfer Learning to play Pong via the Webcam (github.com)
Originally published at www.dlology.com.
