Subscribe to Hacker Noon's best tech stories, delivered at noon
Visit Noonification: https://noonification.com (promoted)
# Install a CUDA-enabled jaxlib 0.1.42 wheel matching this machine: the first
# sed rewrites `nvcc -V` output ("... release X.Y, ...") into the "cudaXY"
# directory of the wheel URL, the second rewrites `python3 -V` ("Python X.Y...")
# into the "cpXY" wheel tag. `jax` itself is installed/upgraded alongside it.
pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
# Flax installed straight from the `dev-setup` branch of the GitHub repository.
pip install -q git+https://github.com/google/flax.git@dev-setup
# TensorFlow and TensorFlow Datasets, used below for the tf.data input pipeline.
# NOTE(review): both are unpinned; versions contemporary with jaxlib 0.1.42 may
# be required for the tf.compat/tf.estimator calls below -- confirm.
pip install tensorflow
pip install tensorflow_datasets
# Unpack the image archive into the current directory (assumes
# simpsons_faces.zip has already been downloaded here).
unzip simpsons_faces.zip
from jax.lib import xla_bridge
import jax
import flax
import numpy as onp
import jax.numpy as jnp
import csv
import tensorflow as tf
import tensorflow_datasets as tfds
# Sanity check: print the platform name of the XLA backend JAX will run on,
# e.g. to confirm that the CUDA-enabled jaxlib installed above sees a GPU.
print(xla_bridge.get_backend().platform)
def train(num_epochs=50):
    """Train ``CNN`` with momentum SGD and print eval metrics after each epoch.

    Uses the module-level ``create_dataset``, ``CNN``, ``train_step`` and
    ``eval`` helpers, which must be defined before this function is called.

    Args:
        num_epochs: Number of passes over the training dataset. Defaults to 50,
            the value that was previously hard-coded.

    Returns:
        The final optimizer; ``optimizer.target`` is the trained model.
    """
    train_ds = create_dataset(tf.estimator.ModeKeys.TRAIN)
    test_ds = create_dataset(tf.estimator.ModeKeys.EVAL)

    # Materialise the whole evaluation split in memory as one big batch.
    # Chunks of 1000 are concatenated so this works for any split size;
    # batch(1000) + get_single_element only worked while the split held at
    # most 1000 examples and raised as soon as a second batch existed.
    # Elements are indexed as (image, label) -- assumes create_dataset yields
    # unbatched (image, label) pairs for EVAL; TODO confirm.
    eval_chunks = list(tfds.as_numpy(test_ds.batch(1000)))
    test_ds = {
        'image': onp.concatenate([chunk[0] for chunk in eval_chunks]).astype(jnp.float32),
        'label': onp.concatenate([chunk[1] for chunk in eval_chunks]).astype(jnp.int32),
    }

    # Initialise parameters from a dummy input of shape (1, 160, 120, 3);
    # presumably this matches the image size create_dataset produces -- confirm.
    _, initial_params = CNN.init_by_shape(
        jax.random.PRNGKey(0), [((1, 160, 120, 3), jnp.float32)])
    model = flax.nn.Model(CNN, initial_params)
    optimizer = flax.optim.Momentum(
        learning_rate=0.01, beta=0.9, weight_decay=0.0005).create(model)

    for epoch in range(num_epochs):
        for batch in tfds.as_numpy(train_ds):
            optimizer = train_step(optimizer, batch)
        metrics = eval(optimizer.target, test_ds)
        # %-formatting (not an f-string format spec) so 0-d JAX arrays are
        # converted through __float__ exactly as before.
        print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
              % (epoch + 1, metrics['loss'], metrics['accuracy'] * 100))
    return optimizer
class CNN(flax.nn.Module):
    """Five conv/ReLU/avg-pool stages followed by a small fully connected head.

    The last Dense layer has 4 outputs and is passed through a softmax, so
    ``apply`` returns per-class probabilities rather than raw logits.
    """

    def apply(self, x):
        # Convolutional feature extractor. Each stage is a 3x3 convolution,
        # a ReLU, and a 2x2 / stride-2 average pool that downsamples both
        # spatial dimensions by a factor of two. Submodules are created in
        # the same order as an unrolled version would create them.
        for conv_features in (128, 128, 64, 32, 16):
            x = flax.nn.Conv(x, features=conv_features, kernel_size=(3, 3))
            x = flax.nn.relu(x)
            x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten every axis except the leading (batch) one.
        x = x.reshape((x.shape[0], -1))

        # Dense head: two hidden ReLU layers, then a 4-way softmax.
        for dense_features in (256, 64):
            x = flax.nn.Dense(x, features=dense_features)
            x = flax.nn.relu(x)
        x = flax.nn.Dense(x, features=4)
        return flax.nn.softmax(x)
def compute_metrics(logits, labels):
    """Summarise a batch of predictions as mean loss and accuracy.

    ``logits`` holds one row of per-class model outputs per example and
    ``labels`` the matching integer class indices. Returns a dict with the
    mean ``cross_entropy_loss`` under ``'loss'`` and the fraction of rows
    whose argmax equals the label under ``'accuracy'``.
    """
    predicted_class = jnp.argmax(logits, -1)
    per_example_loss = cross_entropy_loss(logits, labels)
    return {
        'loss': jnp.mean(per_example_loss),
        'accuracy': jnp.mean(predicted_class == labels),
    }
We use `@jax.vmap` as a function decorator for our loss function. This vectorizes our code so that it runs efficiently on whole batches.
@jax.vmap
def cross_entropy_loss(logits, label):
    """Per-example negative log-likelihood, vectorised over the batch axis.

    ``jax.vmap`` maps this single-example function over axis 0 of both
    arguments. Callers pass a batch of probability rows ``logits`` of shape
    ``(batch, num_classes)`` and integer ``label`` of shape ``(batch,)``, and
    get back a ``(batch,)`` array of losses ``-log(logits[i, label[i]])``.

    Note: despite its name, ``logits`` is expected to already hold
    probabilities (the CNN in this file ends in a softmax), which is why the
    loss is simply the negative log of the selected entry.
    """
    # Exactly one vmap: a second, stacked vmap would also try to map over the
    # already-scalar ``label`` and raise, since it has no axis 0 left.
    return -jnp.log(logits[label])
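How does `cross_entropy_loss` work? `@jax.vmap` takes both arrays, `logits` and `label`, and performs our `cross_entropy_loss` on each pair, thus allowing the parallel calculation of a whole batch. The cross-entropy formula for a single example is:

$$\mathcal{L}(p, y) = -\log p_y$$

where $p$ is the vector of class probabilities produced by the network's final softmax and $y$ is the integer class label.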
We use `@jax.jit` to speed up our function. This works very similarly to TensorFlow. Please keep in mind that `batch[0]` is our image data and `batch[1]` is our label.
@jax.jit
def train_step(optimizer, batch):
    """Apply one gradient-descent update for a single training batch.

    ``batch[0]`` is the image batch and ``batch[1]`` the integer labels. The
    loss is the batch mean of ``cross_entropy_loss`` on the model's output.
    Returns the updated optimizer, whose ``target`` is the new model.
    """
    images = batch[0]
    labels = batch[1]

    def batch_loss(current_model):
        # Differentiated w.r.t. the model passed in (optimizer.target).
        predictions = current_model(images)
        return jnp.mean(cross_entropy_loss(predictions, labels))

    gradients = jax.grad(batch_loss)(optimizer.target)
    return optimizer.apply_gradient(gradients)
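`jax.grad()` takes our `loss_fn` together with the current model, `optimizer.target`, and calculates the gradient of the loss with respect to the model's parameters. After the calculation we apply the gradient with `optimizer.apply_gradient`, just like in TensorFlow.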
@jax.jit
def eval(model, eval_ds):
    """Run ``model`` on an evaluation set and return its metrics.

    ``eval_ds`` is a dict holding ``'image'`` and ``'label'`` arrays; the
    result is the ``{'loss', 'accuracy'}`` dict from ``compute_metrics``.

    NOTE: this shadows the builtin ``eval``; the name is kept because
    ``train()`` calls it under this name.
    """
    images, labels = eval_ds['image'], eval_ds['label']
    predictions = model(images)
    return compute_metrics(predictions, labels)
The current set of optimizers is unfortunately limited to Adam and SGD with momentum. I especially liked how straightforward this framework is to use, as well as its high flexibility.