How to Use Class and Sample Weights in Keras Training

Written by tensorflow | Published 2025/09/29

TL;DR: When training deep learning models, imbalanced datasets can bias predictions. Keras lets you adjust the importance of training data using class weights (per-class importance) or sample weights (per-instance control). This prevents rare classes from being ignored and improves model fairness without resampling. The guide walks through practical NumPy and tf.data examples, showing how to apply both techniques in single- and multi-output models.

Content Overview

  • Using sample weighting and class weighting
  • Class weights
  • Sample weights
  • Passing data to multi-input, multi-output models
  • Using callbacks
  • Many built-in callbacks are available
  • Writing your own callback
  • Checkpointing models
  • Using learning rate schedules
  • Passing a schedule to an optimizer
  • Using callbacks to implement a dynamic learning rate schedule
  • Visualizing loss and metrics during training
  • Using the TensorBoard callback

Using sample weighting and class weighting

With the default settings, the weight a sample carries in training is decided by its frequency in the dataset. There are two methods to weight the data independently of sample frequency:

  • Class weights
  • Sample weights

Class weights

This is set by passing a dictionary to the class_weight argument of Model.fit(). This dictionary maps class indices to the weight that should be used for samples belonging to that class.

This can be used to balance classes without resampling, or to train a model that gives more importance to a particular class.

For instance, if class "0" is half as represented as class "1" in your data, you could use Model.fit(..., class_weight={0: 1., 1: 0.5}).
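The examples that follow assume the MNIST data and the get_compiled_model() helper defined earlier in the original guide. Here is a minimal sketch of that setup (the layer sizes are illustrative):

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load MNIST and flatten the images into (784,) float vectors.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255

def get_compiled_model():
    # A small dense classifier compiled for sparse integer labels.
    inputs = keras.Input(shape=(784,))
    x = layers.Dense(64, activation="relu")(inputs)
    x = layers.Dense(64, activation="relu")(x)
    outputs = layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model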

Here's a NumPy example where we use class weights or sample weights to give more importance to the correct classification of class #5 (which is the digit "5" in the MNIST dataset).

import numpy as np

class_weight = {
    0: 1.0,
    1: 1.0,
    2: 1.0,
    3: 1.0,
    4: 1.0,
    # Set weight "2" for class "5",
    # making this class 2x more important
    5: 2.0,
    6: 1.0,
    7: 1.0,
    8: 1.0,
    9: 1.0,
}

print("Fit with class weight")
model = get_compiled_model()
model.fit(x_train, y_train, class_weight=class_weight, batch_size=64, epochs=1)

Fit with class weight
782/782 [==============================] - 3s 2ms/step - loss: 0.3721 - sparse_categorical_accuracy: 0.9007
<keras.src.callbacks.History at 0x7fd5a032de80>
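In practice, rather than writing the dictionary by hand, you can derive class weights from label frequencies. One common heuristic (not part of the original guide) is inverse-frequency weighting, similar to scikit-learn's "balanced" mode; the variable names here are illustrative:

import numpy as np

# Weight each class by total / (num_classes * count), so rare classes
# get proportionally larger weights and common classes smaller ones.
counts = np.bincount(y_train)
class_weight = {
    i: len(y_train) / (len(counts) * count) for i, count in enumerate(counts)
}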

Sample weights

For fine-grained control, or if you are not building a classifier, you can use "sample weights".

  • When training from NumPy data: Pass the sample_weight argument to Model.fit().
  • When training from tf.data or any other sort of iterator: Yield (input_batch, label_batch, sample_weight_batch) tuples.

A "sample weights" array is an array of numbers that specify how much weight each sample in a batch should have in computing the total loss. It is commonly used in imbalanced classification problems (the idea being to give more weight to rarely-seen classes).

When the weights used are ones and zeros, the array can be used as a mask for the loss function (entirely discarding the contribution of certain samples to the total loss).
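For instance, here is a quick illustrative sketch (mask_weight is a made-up name) that uses zero weights to exclude every sample of class "9" from the loss entirely:

import numpy as np

# A weight of 0.0 removes a sample's contribution to the total loss.
mask_weight = np.ones(shape=(len(y_train),))
mask_weight[y_train == 9] = 0.0

The example below instead up-weights class "5" by 2x: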

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

print("Fit with sample weight")
model = get_compiled_model()
model.fit(x_train, y_train, sample_weight=sample_weight, batch_size=64, epochs=1)

Fit with sample weight
782/782 [==============================] - 2s 2ms/step - loss: 0.3753 - sparse_categorical_accuracy: 0.9019
<keras.src.callbacks.History at 0x7fd5a01eafa0>

Here's a matching Dataset example:

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))

# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model = get_compiled_model()
model.fit(train_dataset, epochs=1)

782/782 [==============================] - 2s 2ms/step - loss: 0.3794 - sparse_categorical_accuracy: 0.9023
<keras.src.callbacks.History at 0x7fd5a00a0f40>

Passing data to multi-input, multi-output models

In the previous examples, we were considering a model with a single input (a tensor of shape (784,)) and a single output (a prediction tensor of shape (10,)). But what about models that have multiple inputs or outputs?

Consider the following model, which has an image input of shape (32, 32, 3) (that's (height, width, channels)) and a time series input of shape (None, 10) (that's (timesteps, features)). Our model will have two outputs computed from the combination of these inputs: a "score" (of shape (1,)) and a probability distribution over five classes (of shape (5,)).

image_input = keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = keras.Input(shape=(None, 10), name="ts_input")

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

score_output = layers.Dense(1, name="score_output")(x)
class_output = layers.Dense(5, name="class_output")(x)

model = keras.Model(
    inputs=[image_input, timeseries_input], outputs=[score_output, class_output]
)

Let's plot this model, so you can clearly see what we're doing here (note that the shapes shown in the plot are batch shapes, rather than per-sample shapes).

keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)

At compilation time, we can assign different losses to different outputs by passing the loss functions as a list:

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[keras.losses.MeanSquaredError(), keras.losses.CategoricalCrossentropy()],
)
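Because the outputs were given names, the losses can equivalently be passed as a dict keyed by output name, and per-output loss_weights can scale each output's contribution to the total loss. Data can then be passed to fit() as dicts as well. Here is a sketch consistent with the Keras API; the dummy data is purely illustrative:

import numpy as np

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        "class_output": keras.losses.CategoricalCrossentropy(),
    },
    # Weight the "score" loss 2x relative to the "class" loss.
    loss_weights={"score_output": 2.0, "class_output": 1.0},
)

# Dummy NumPy data matching the two named inputs and two named outputs.
img_data = np.random.random_sample(size=(100, 32, 32, 3))
ts_data = np.random.random_sample(size=(100, 20, 10))
score_targets = np.random.random_sample(size=(100, 1))
class_targets = np.random.random_sample(size=(100, 5))

model.fit(
    {"img_input": img_data, "ts_input": ts_data},
    {"score_output": score_targets, "class_output": class_targets},
    batch_size=32,
    epochs=1,
)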

Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.

