Deploying machine learning models outside of a Python environment used to be difficult. When the target platform is the browser, the de facto standard for serving predictions has been an API call to a server-side inference engine. For many reasons, server-side inference APIs are a suboptimal solution, and machine learning models are increasingly being deployed natively. TensorFlow has done a good job of supporting this movement by providing cross-platform APIs; however, many of us do not want to be married to a single ecosystem.
In comes the Open Neural Network Exchange (ONNX) project which, since being picked up by Microsoft, has been seeing massive development efforts and is approaching a stable state. It's now easier than ever to deploy machine-learning models trained using your framework of choice, on your platform of choice, with hardware acceleration out of the box.
In April this year, onnxruntime-web was introduced (see this Pull Request). onnxruntime-web runs the onnxruntime inference engine compiled to WebAssembly (wasm) format - it's about time WebAssembly started to flex its muscles. Especially when paired with WebGL, we suddenly have GPU-powered machine learning in the browser. Pretty cool.
In this tutorial we will dive into onnxruntime-web by deploying a pre-trained PyTorch model to the browser. We will be using AlexNet as our deployment target. AlexNet has been trained as an image classifier on the ImageNet dataset, so we will be building an image classifier - nothing better than re-inventing the wheel. At the end of this tutorial, we will have built a bundled web app that can be run as a standalone static web page, or integrated into your JavaScript framework of choice.
Jump to code → onnxruntime-web-tutorial
You will need a trained machine-learning model exported as an ONNX binary protobuf file. There are many ways to achieve this using a number of different deep-learning frameworks. For the sake of this tutorial, I will be using the exported model from the AlexNet example in the PyTorch documentation; the Python code snippet below will help you generate your own model. You can also follow the documentation to export your own PyTorch model. If you're coming from TensorFlow, this tutorial will help you with exporting your model to ONNX. Lastly, ONNX doesn't just pride itself on cross-platform deployment, but also on allowing exports from all major deep-learning frameworks. Those of you using another deep-learning framework should be able to find support for exporting to ONNX in the docs of your framework.
import torch
import torchvision

dummy_input = torch.randn(1, 3, 224, 224)
model = torchvision.models.alexnet(pretrained=True)

input_names = ["input1"]
output_names = ["output1"]

torch.onnx.export(
    model,
    dummy_input,
    "alexnet.onnx",
    verbose=True,
    input_names=input_names,
    output_names=output_names
)
Running this script creates a file, alexnet.onnx, a binary protobuf file which contains both the network structure and parameters of the model you exported (in this case, AlexNet).
ONNX Runtime Web is a JavaScript library for running ONNX models on the browser and on Node.js. ONNX Runtime Web has adopted WebAssembly and WebGL technologies for providing an optimized ONNX model inference runtime for both CPUs and GPUs.
The official package is hosted on npm under the name onnxruntime-web. When using a bundler or working server-side, this package can be installed using npm install. However, it's also possible to deliver the code via a CDN using a script tag. The bundling process is a bit involved, so we will start with the script tag approach and come back to using the npm package later.
Let's start with the core application logic: model inference. onnxruntime exposes a runtime object called an InferenceSession with a method .run() which is used to initiate the forward pass with the desired inputs. Both the InferenceSession.create() factory method and the accompanying .run() method return a Promise, so we will run the entire process inside an async context. Before implementing any browser elements, we will check that our model runs with a dummy input tensor, remembering the input and output names and sizes that we defined earlier when exporting the model.
async function run() {
  try {
    // create a new session and load the AlexNet model.
    const session = await ort.InferenceSession.create('./alexnet.onnx');

    // prepare dummy input data
    const dims = [1, 3, 224, 224];
    const size = dims[0] * dims[1] * dims[2] * dims[3];
    const inputData = Float32Array.from({ length: size }, () => Math.random());

    // prepare feeds. use model input names as keys.
    const feeds = { input1: new ort.Tensor('float32', inputData, dims) };

    // feed inputs and run
    const results = await session.run(feeds);
    console.log(results.output1.data);
  } catch (e) {
    console.log(e);
  }
}

run();
We then implement a simple HTML template, index.html, which should load both the pre-compiled onnxruntime-web package and main.js, containing our code.
<!DOCTYPE html>
<html>
  <head>
    <title>ONNX Runtime Web - Tutorial</title>
  </head>
  <body>
    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
    <script src="main.js"></script>
  </body>
</html>
To run this, we can use light-server. If you haven't started an npm project by now, please do so by running npm init in your current working directory. Once you've completed the setup, install light-server (npm install light-server) and serve the static HTML page using npx light-server -s . -p 8080.
You’re now running a machine learning model natively in the browser! To check that everything is running fine go to your web console and make sure that the output tensor is logged (AlexNet is bulky so it's normal that inference takes a few seconds).
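If inference on the default backend feels slow, one thing to try is requesting a specific backend through the executionProviders session option, for example "webgl" for the GPU or "wasm" for the CPU. The sketch below is a hypothetical helper (createSessionWithFallback is not part of the library) that tries each backend in order and returns the first session that loads. It takes the ort object as a parameter so it can be used with either the CDN global or the npm package. Not every model runs on every backend, so treat the default order as an assumption to tune.

```javascript
// Hypothetical helper: try each execution provider in order and return the
// first session that loads. `ort` is the onnxruntime-web module or CDN global.
async function createSessionWithFallback(ort, modelPath, providers = ["webgl", "wasm"]) {
  let lastError = new Error("no execution providers given");
  for (const provider of providers) {
    try {
      const session = await ort.InferenceSession.create(modelPath, {
        executionProviders: [provider],
      });
      return { session, provider };
    } catch (e) {
      lastError = e; // this backend could not load the model, try the next one
    }
  }
  throw lastError;
}
```

Inside an async function this could be called as const { session, provider } = await createSessionWithFallback(ort, './alexnet.onnx'); logging provider shows which backend was actually used.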
Next we will use webpack to bundle our dependencies, as would be the case if we wanted to deploy the model in a JavaScript app powered by frameworks like React or Vue. Usually bundling is a relatively simple procedure; however, onnxruntime-web requires a slightly more involved webpack configuration - this is because WebAssembly is used to provide the natively assembled runtime.
Browser support is the classic pitfall, especially when working with cutting-edge web technology. If your intended users are not using one of the four major browsers (Chrome, Edge, Firefox, Safari), you might want to hold off on integrating WebAssembly components. More information on WebAssembly support and the roadmap can be found here.
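A quick runtime check can complement the compatibility tables. The sketch below is one possible way to detect WebAssembly support before loading the model; supportsWebAssembly is a hypothetical helper name, and what to do when it returns false (for example, falling back to a server-side API) is left open.

```javascript
// Returns true when the current environment can validate WebAssembly modules.
function supportsWebAssembly() {
  try {
    if (typeof WebAssembly !== "object" || typeof WebAssembly.validate !== "function") {
      return false;
    }
    // Smallest valid module: the "\0asm" magic number followed by version 1.
    const emptyModule = new Uint8Array([0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]);
    return WebAssembly.validate(emptyModule);
  } catch (e) {
    return false;
  }
}
```

The check is cheap, so it can run once at startup before the InferenceSession is created.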
The following steps are based on the examples provided by the official ONNX Runtime documentation. We're assuming you've already started an npm project.
1. Install the dependencies.
npm install onnxruntime-web && npm install -D webpack webpack-cli copy-webpack-plugin
2. Instead of loading the onnxruntime-web module via a CDN, we should update main.js to require the package at the top of the script.
const ort = require('onnxruntime-web');
3. Save the following configuration as webpack.config.js.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
const path = require('path');
const CopyPlugin = require("copy-webpack-plugin");

module.exports = () => {
  return {
    target: ['web'],
    entry: path.resolve(__dirname, 'main.js'),
    output: {
      path: path.resolve(__dirname, 'dist'),
      filename: 'bundle.min.js',
      library: {
        type: 'umd'
      }
    },
    plugins: [new CopyPlugin({
      // Use copy plugin to copy *.wasm to output folder.
      patterns: [{ from: 'node_modules/onnxruntime-web/dist/*.wasm', to: '[name][ext]' }]
    })],
    mode: 'production'
  }
};
4. Run npx webpack to compile the bundle.
5. Finally, before reloading the server, we need to update index.html.
   - Remove the ort.min.js script tag to stop loading the compiled package from the CDN.
   - Load bundle.min.js (which contains all our dependencies bundled and minified by webpack) instead of main.js.
index.html should now look something like this.
<!DOCTYPE html>
<html>
  <head>
    <title>ONNX Runtime Web Tutorial</title>
  </head>
  <body>
    <script src="bundle.min.js"></script>
  </body>
</html>
To make building and launching the server easier, you could define build and serve scripts in package.json.
"scripts": {
  "build": "npx webpack",
  "serve": "npm run build && npx light-server -s . -p 8080"
}
Let's put this model to work and implement the image classification pipeline.
We will need some utility functions to load, resize, and display the image - the canvas object is perfect for this. In addition, image classification systems typically have lots of magic built into the pre-processing pipeline. This is quite trivial to implement in Python using libraries like numpy; unfortunately, this is not the case with JavaScript. It follows that we will have to implement our pre-processing from scratch to transform the image data into the correct input format.
We will need some HTML elements to interact with and display the data.
<label for="file-in"><h2>What am I?</h2></label>
<input type="file" id="file-in" name="file-in">
<img id="input-image" class="input-image">
<img id="scaled-image" class="scaled-image">
<h3 id="target"></h3>
We want to load an image from file and display it. Back in main.js, we will get the file input element from the DOM and use FileReader to read the data into memory. Following this, the image data will be passed to handleImage, which will draw the image using the 2D canvas context.
const canvas = document.createElement("canvas"),
  ctx = canvas.getContext("2d");

document.getElementById("file-in").onchange = function (evt) {
  // fall back to the legacy event source for very old browsers
  const target = evt.target || window.event.srcElement,
    files = target.files;

  if (window.FileReader && files && files.length) {
    const fileReader = new FileReader();
    fileReader.onload = () => onLoadImage(fileReader);
    fileReader.readAsDataURL(files[0]);
  }
};

function onLoadImage(fileReader) {
  const img = document.getElementById("input-image");
  img.onload = () => handleImage(img);
  img.src = fileReader.result;
}

function handleImage(img) {
  ctx.drawImage(img, 0, 0);
}
Now that we can load and display an image, we want to move to extracting and processing the data. Remember that our model takes in a tensor of shape [1, 3, 224, 224]. This means we will have to resize the image to support any input image, and perhaps also transpose the dimensions depending on how we extract the image data.
To resize and extract image data, we will use the canvas context again. Let's define a function processImage that does this. processImage has the necessary elements in scope to immediately draw the scaled image, so we will also do that here.
function processImage(img, width) {
  const canvas = document.createElement("canvas"),
    ctx = canvas.getContext("2d");

  // resize image
  canvas.width = width;
  canvas.height = canvas.width * (img.height / img.width);

  // draw scaled image
  ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
  document.getElementById("scaled-image").src = canvas.toDataURL();

  // return data
  return ctx.getImageData(0, 0, width, width).data;
}
We can now add a line to the function handleImage which calls processImage.
const resizedImageData = processImage(img, targetWidth);
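One caveat with processImage as written: it keeps the aspect ratio, so for a non-square image the width x width region read by getImageData either cuts off the bottom of a portrait image or includes empty transparent rows below a landscape one. A possible remedy is to crop the largest centred square before scaling. The sketch below only computes the crop rectangle; centerCropRect is a hypothetical helper, and the drawImage call in the comment is an assumption about how it might be wired in.

```javascript
// Hypothetical helper: largest centred square inside a srcWidth x srcHeight image.
function centerCropRect(srcWidth, srcHeight) {
  const size = Math.min(srcWidth, srcHeight);
  return {
    sx: Math.floor((srcWidth - size) / 2),  // left edge of the square
    sy: Math.floor((srcHeight - size) / 2), // top edge of the square
    size,
  };
}

// Possible use inside processImage (browser only), with a square canvas:
//   canvas.width = canvas.height = width;
//   const { sx, sy, size } = centerCropRect(img.naturalWidth, img.naturalHeight);
//   ctx.drawImage(img, sx, sy, size, size, 0, 0, width, width);
```

The nine-argument form of drawImage takes the source rectangle first and the destination rectangle second, so cropping and scaling happen in a single call.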
Finally, let's implement a function called imageDataToTensor which applies the transforms needed to get the image data ready to be used as input to the model. imageDataToTensor should apply three transforms:
1. Filter out the alpha channel: our input tensor should contain 3 channels corresponding to the RGB channels.
2. ctx.getImageData returns data in the shape [224, 224, 3] (once the alpha channel is dropped), so we need to transpose the data to the shape [3, 224, 224].
3. ctx.getImageData returns a Uint8ClampedArray with integer values ranging from 0 to 255; we need to convert the values to float32 and store them in a Float32Array to construct our tensor input.
function imageDataToTensor(data, dims) {
  // Split the interleaved RGBA data into separate R, G, and B channels.
  const [R, G, B] = [[], [], []];
  for (let i = 0; i < data.length; i += 4) {
    R.push(data[i]);
    G.push(data[i + 1]);
    B.push(data[i + 2]);
    // data[i + 3] is skipped, which filters out the alpha channel
  }
  // Concatenating the channels is equivalent to transposing
  // [224, 224, 3] -> [3, 224, 224].
  const transposedData = R.concat(G).concat(B);

  // Convert to float32, scaling the values from [0, 255] to [0, 1].
  const float32Data = new Float32Array(transposedData.length);
  for (let i = 0; i < transposedData.length; i++) {
    float32Data[i] = transposedData[i] / 255.0;
  }

  const inputTensor = new ort.Tensor("float32", float32Data, dims);
  return inputTensor;
}
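To make the layout change easier to verify, here is a minimal sketch of the same three transforms written as a single pass that fills the Float32Array directly, without the intermediate R, G, and B arrays. rgbaToPlanarFloat32 is a hypothetical name, and the function stops short of constructing the ort.Tensor so it can be tried on a tiny hand-written image.

```javascript
// Convert interleaved RGBA bytes (height x width x 4) into a planar
// Float32Array laid out as [3, height, width], scaled to the range [0, 1].
function rgbaToPlanarFloat32(data, width, height) {
  const pixels = width * height;
  if (data.length !== pixels * 4) {
    throw new Error(`expected ${pixels * 4} values, got ${data.length}`);
  }
  const out = new Float32Array(3 * pixels);
  for (let p = 0; p < pixels; p++) {
    out[p] = data[4 * p] / 255;                  // red plane
    out[pixels + p] = data[4 * p + 1] / 255;     // green plane
    out[2 * pixels + p] = data[4 * p + 2] / 255; // blue plane
    // data[4 * p + 3] is the alpha value and is dropped
  }
  return out;
}

// A 2 x 2 image: red, green, blue, and white pixels.
const tiny = Uint8ClampedArray.from([
  255, 0, 0, 255,   0, 255, 0, 255,
  0, 0, 255, 255,   255, 255, 255, 0,
]);
console.log(Array.from(rgbaToPlanarFloat32(tiny, 2, 2)));
```

The first four values of the result form the red plane, the next four the green plane, and the last four the blue plane, which is the [3, height, width] ordering the model expects.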
Almost there, let’s wrap up some loose ends to get the full inference pipeline up and running.
1. Wire the pre-processing pipeline into handleImage.
// Input shape expected by the model. DIMS is not defined in the earlier
// snippets; this value is assumed from the shape used in the export step.
const DIMS = [1, 3, 224, 224];

function handleImage(img, targetWidth = DIMS[3]) {
  ctx.drawImage(img, 0, 0);
  const resizedImageData = processImage(img, targetWidth);
  const inputTensor = imageDataToTensor(resizedImageData, DIMS);
  run(inputTensor);
}
2. Implement an argMax function to find the index of the highest output score.
function argMax(arr) {
  let max = arr[0];
  let maxIndex = 0;
  for (let i = 1; i < arr.length; i++) {
    if (arr[i] > max) {
      maxIndex = i;
      max = arr[i];
    }
  }
  return [max, maxIndex];
}
3. run() needs to be re-factored to accept a tensor input. We also need to use the max index to actually retrieve the results from a list of ImageNet classes. I've pre-converted this list to JSON and we will load it into our script using require - you can find the JSON file in the code repository linked at the start and end of the tutorial.
const classes = require("./imagenet_classes.json").data;

async function run(inputTensor) {
  try {
    const session = await ort.InferenceSession.create('./alexnet.onnx');
    const feeds = { input1: inputTensor };
    const results = await session.run(feeds);
    const [maxValue, maxIndex] = argMax(results.output1.data);
    // look the element up explicitly instead of relying on the global id
    document.getElementById("target").textContent = classes[maxIndex];
  } catch (e) {
    console.error(e); // non-fatal error handling
  }
}
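argMax only reports the single best class, and the raw output values are unnormalised scores rather than probabilities (this assumes the exported model has no softmax layer at the end, which is the case for the torchvision AlexNet used here). A small sketch of how the scores could be turned into probabilities and a top-k list follows; softmax and topK are hypothetical helpers, not part of onnxruntime-web.

```javascript
// Numerically stable softmax: subtract the maximum before exponentiating.
function softmax(scores) {
  let max = -Infinity;
  for (const s of scores) {
    if (s > max) max = s;
  }
  const exps = Array.from(scores, (s) => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Indices and probabilities of the k highest-scoring classes, best first.
function topK(scores, k = 5) {
  return softmax(scores)
    .map((probability, index) => ({ index, probability }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, k);
}
```

In run(), something like topK(results.output1.data, 5).map(({ index }) => classes[index]) would give the five most likely labels instead of just one.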
That’s it! All that’s left is to re-build our bundle, serve the app, and start classifying some images.
As you test the app, you will notice that prediction quality is not as good as it could be. This is primarily because the current image processing pipeline is still rather rudimentary and can be improved in a number of ways: for example, we could implement improved resizing, center-cropping, and/or normalization. Maybe food for a next tutorial, or I’ll just leave it up to you to explore!
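As a starting point for the normalization step, here is a minimal sketch. It assumes the input is already a planar [3, height, width] Float32Array scaled to [0, 1], and it uses the per-channel mean and standard deviation that torchvision documents for its ImageNet-pretrained models. normalizeImageNet is a hypothetical helper name.

```javascript
// Per-channel (R, G, B) statistics documented for torchvision's
// ImageNet-pretrained models.
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];

// Normalise a planar [3, height, width] Float32Array with values in [0, 1].
function normalizeImageNet(chw, mean = IMAGENET_MEAN, std = IMAGENET_STD) {
  const plane = chw.length / 3;
  if (!Number.isInteger(plane)) {
    throw new Error("input length must be a multiple of 3");
  }
  const out = new Float32Array(chw.length);
  for (let c = 0; c < 3; c++) {
    for (let i = 0; i < plane; i++) {
      out[c * plane + i] = (chw[c * plane + i] - mean[c]) / std[c];
    }
  }
  return out;
}
```

Inside imageDataToTensor, the normalised array could be passed to new ort.Tensor in place of float32Data; whether this alone closes the quality gap depends on the resizing and cropping steps as well.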
That’s it, we’ve built a web app with a machine-learning model running natively in the browser! You can find the full code (including styles and layout) in this code repository on GitHub. I appreciate any and all feedback so feel free to share any Issues or Stars.
Thank you for reading!
Also published on: https://rekoil.io/blog/onnxruntime-web-tutorial