


Compute categorical crossentropy.

Note that if any y_pred value is exactly 0, the output will be NaN. If from_logits is False, each prediction vector in y_pred (taken along the class axis) should sum to 1; otherwise the TensorFlow and PyTorch backends will produce different numerical values.
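The note above follows from the crossentropy formula, loss = -sum over classes of y_true * log(y_pred). The NumPy sketch below, using a hypothetical helper `manual_cce` rather than the FastEstimator function, reproduces the example values that follow, shows how a single 0 produces NaN (log(0) = -inf, and 0 * -inf = NaN), and shows that a row not summing to 1 gives different answers depending on whether it is renormalized first. The original note does not say how each backend treats such rows; the renormalized variant is only an assumed illustration of how two reasonable implementations can disagree.

```python
import numpy as np


def manual_cce(y_pred, y_true):
    """Hypothetical helper: per-sample -sum(y_true * log(y_pred)) over the last axis."""
    # Silence the divide-by-zero / invalid warnings so the NaN case is visible.
    with np.errstate(divide="ignore", invalid="ignore"):
        return -np.sum(y_true * np.log(y_pred), axis=-1)


true = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype=np.float32)
pred = np.array([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]], dtype=np.float32)
per_sample = manual_cce(pred, true)
print(per_sample, per_sample.mean())  # ≈ [0.223, 0.105, 0.357] and 0.228

# A probability of exactly 0 gives log(0) = -inf; for classes whose label is 0
# this becomes 0 * -inf = NaN, so the whole sum is NaN.
zero_pred = np.array([[0.0, 1.0, 0.0]], dtype=np.float32)
zero_label = np.array([[0.0, 1.0, 0.0]], dtype=np.float32)
print(manual_cce(zero_pred, zero_label))  # [nan]

# A row summing to 0.9 instead of 1: the raw formula and a renormalized
# variant disagree (about 0.511 vs 0.405).
unnormalized = np.array([[0.2, 0.6, 0.1]], dtype=np.float32)
label = np.array([[0.0, 1.0, 0.0]], dtype=np.float32)
raw = manual_cce(unnormalized, label)
renormalized = manual_cce(unnormalized / unnormalized.sum(axis=-1, keepdims=True), label)
print(raw, renormalized)
```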

This method can be used with TensorFlow tensors:

import tensorflow as tf
import fastestimator as fe

true = tf.constant([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
pred = tf.constant([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]])
# Class 1 is weighted 2.0, class 2 is weighted 3.0, any other class falls back to default_value=1.0
weights = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(tf.constant([1, 2]), tf.constant([2.0, 3.0])), default_value=1.0)
b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true)  # 0.228
b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False)  # [0.223, 0.105, 0.357]
b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False, class_weights=weights)
# [0.446, 0.105, 1.070]

This method can be used with PyTorch tensors:

import torch
import fastestimator as fe

true = torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
pred = torch.tensor([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]])
weights = {1: 2.0, 2: 3.0}  # classes not listed keep a weight of 1.0
b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true)  # 0.228
b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False)  # [0.223, 0.105, 0.357]
b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False, class_weights=weights)
# [0.446, 0.105, 1.070]
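To make the class_weights numbers above concrete, here is a NumPy sketch that mirrors the PyTorch branch of the source listing: compute the per-sample loss, recover each sample's class with argmax over its one-hot label, and multiply by that class's weight, falling back to 1.0 for classes not in the mapping. The helper `weighted_cce` is hypothetical and not part of FastEstimator.

```python
import numpy as np


def weighted_cce(y_pred, y_true, class_weights, average_loss=True):
    """Hypothetical helper: class-weighted categorical crossentropy on (Batch, C) arrays."""
    ce = -np.sum(y_true * np.log(y_pred), axis=-1)  # per-sample loss
    y_class = np.argmax(y_true, axis=-1)            # true class index of each sample
    sample_weights = np.ones_like(ce)               # unlisted classes keep weight 1.0
    for key, weight in class_weights.items():
        sample_weights[y_class == key] = weight
    ce = ce * sample_weights
    return ce.mean() if average_loss else ce


true = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype=np.float64)
pred = np.array([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]])
weights = {1: 2.0, 2: 3.0}

weighted = weighted_cce(pred, true, weights, average_loss=False)
print(np.round(weighted, 3))  # ≈ [0.446, 0.105, 1.070]
```

The second sample belongs to class 0, which is absent from the mapping, so its loss is unchanged; the third is 3 × (-ln 0.7) ≈ 1.070.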


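Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| y_pred | Tensor | Prediction with a shape like (Batch, ..., C) for TensorFlow and (Batch, C, ...) for PyTorch. dtype: float32 or float16. | required |
| y_true | Tensor | One-hot encoded ground truth class labels with the same shape as y_pred. dtype: int, float32, or float16. | required |
| from_logits | bool | Whether y_pred contains logits (unnormalized scores). If True, a softmax is applied to the prediction. | False |
| average_loss | bool | Whether to average the element-wise loss. | True |
| class_weights | Optional[Weight_Dict] | Mapping from class index to a loss weight: a tf.lookup.StaticHashTable for TensorFlow or a dict for PyTorch, as in the examples above. Each sample's loss is multiplied by the weight of its true class; with the PyTorch backend, classes absent from the dict keep a weight of 1.0. Useful for giving more attention to samples from an under-represented class. | None |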
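With from_logits=True, a softmax over the class axis is applied before the crossentropy. The NumPy sketch below shows that step using the log-softmax form, which subtracts the row maximum so that large logits do not overflow. The helpers `log_softmax` and `cce_from_logits` are hypothetical, and whether either backend uses exactly this formulation internally is an assumption. `axis=-1` matches the TensorFlow layout; for the PyTorch layout you would pass `axis=1`.

```python
import numpy as np


def log_softmax(logits, axis=-1):
    # Shift by the per-row maximum so exp() cannot overflow.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def cce_from_logits(logits, y_true, axis=-1):
    """Hypothetical helper: softmax over `axis`, then crossentropy, computed in log space."""
    return -np.sum(y_true * log_softmax(logits, axis=axis), axis=axis)


label = np.array([[0.0, 1.0, 0.0]])
# log(p) is a valid set of logits for p, because softmax(log(p)) == p when p sums to 1.
logits = np.log(np.array([[0.1, 0.8, 0.1]]))
print(cce_from_logits(logits, label))  # ≈ [0.223], i.e. -ln(0.8)

# Extreme logits stay finite in log space, whereas a naive exp(1000) would overflow.
big = np.array([[1000.0, 0.0, -1000.0]])
print(cce_from_logits(big, label))  # ≈ [1000.]
```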



Returns:

| Type | Description |
|---|---|
| Tensor | The categorical crossentropy between y_pred and y_true. A scalar if average_loss is True, else a tensor with the shape (Batch), or (Batch, ...) when y_pred has extra dimensions. |
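The shapes listed for y_pred imply that only the class axis is summed away: TensorFlow inputs keep it last, PyTorch inputs keep it at position 1, and whatever dimensions remain form the unaveraged loss. The NumPy sketch below checks this on a hypothetical batch of 2 images of 4x5 pixels with 3 classes; it only mimics the reduction with plain arrays, does not call either backend, and assumes average_loss=True averages over every remaining element.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data in TensorFlow's channel-last layout (Batch, H, W, C):
# Dirichlet samples are positive and sum to 1 along the class axis.
probs_tf = rng.dirichlet(np.ones(3), size=(2, 4, 5))
labels_tf = np.eye(3)[rng.integers(0, 3, size=(2, 4, 5))]  # one-hot labels

# The same data in PyTorch's channel-first layout (Batch, C, H, W).
probs_torch = np.moveaxis(probs_tf, -1, 1)
labels_torch = np.moveaxis(labels_tf, -1, 1)

ce_tf = -np.sum(labels_tf * np.log(probs_tf), axis=-1)          # sum over the last axis
ce_torch = -np.sum(labels_torch * np.log(probs_torch), axis=1)  # sum over axis 1

print(ce_tf.shape, ce_torch.shape)   # (2, 4, 5) (2, 4, 5)
print(np.allclose(ce_tf, ce_torch))  # True
print(ce_tf.mean())                  # one scalar, analogous to average_loss=True
```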


Raises:

| Type | Description |
|---|---|
| AssertionError | If y_true or y_pred is not a tf.Tensor or torch.Tensor. |

Source code in fastestimator/fastestimator/backend/
def categorical_crossentropy(y_pred: Tensor,
                             y_true: Tensor,
                             from_logits: bool = False,
                             average_loss: bool = True,
                             class_weights: Optional[Weight_Dict] = None) -> Tensor:
    """Compute categorical crossentropy.

    Note that if any of the `y_pred` values are exactly 0, this will result in a NaN output. If `from_logits` is
    False, then each entry of `y_pred` should sum to 1. If they don't sum to 1 then tf and torch backends will
    result in different numerical values.

    This method can be used with TensorFlow tensors:
    true = tf.constant([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
    pred = tf.constant([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]])
    weights = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(tf.constant([1, 2]), tf.constant([2.0, 3.0])), default_value=1.0)
    b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true)  # 0.228
    b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False)  # [0.223, 0.105, 0.357]
    b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False, class_weights=weights)
    # [0.446, 0.105, 1.070]

    This method can be used with PyTorch tensors:
    true = torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
    pred = torch.tensor([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]])
    weights = {1: 2.0, 2: 3.0}
    b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true)  # 0.228
    b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False)  # [0.223, 0.105, 0.357]
    b = fe.backend.categorical_crossentropy(y_pred=pred, y_true=true, average_loss=False, class_weights=weights)
    # [0.446, 0.105, 1.070]

    Args:
        y_pred: Prediction with a shape like (Batch, ..., C) for tensorflow and (Batch, C, ...) for PyTorch. dtype:
            float32 or float16.
        y_true: Ground truth class labels with a shape like `y_pred`. dtype: int or float32 or float16.
        from_logits: Whether y_pred is from logits. If True, a softmax will be applied to the prediction.
        average_loss: Whether to average the element-wise loss.
        class_weights: Mapping of class indices to a weight for weighting the loss function. Useful when you need to pay
            more attention to samples from an under-represented class.

    Returns:
        The categorical crossentropy between `y_pred` and `y_true`. A scalar if `average_loss` is True, else a
        tensor with the shape (Batch).

    Raises:
        AssertionError: If `y_true` or `y_pred` are unacceptable data types.
    """
    assert isinstance(y_pred, (tf.Tensor, torch.Tensor)), "only support tf.Tensor or torch.Tensor as y_pred"
    assert isinstance(y_true, (tf.Tensor, torch.Tensor)), "only support tf.Tensor or torch.Tensor as y_true"
    if tf.is_tensor(y_pred):
        ce = tf.losses.categorical_crossentropy(y_pred=y_pred, y_true=y_true, from_logits=from_logits)
        if class_weights is not None:
            sample_weights = class_weights.lookup(tf.math.argmax(y_true, axis=-1, output_type=class_weights.key_dtype))
            ce = ce * sample_weights
    else:
        y_true = y_true.to(torch.float)  # cast integer one-hot labels to float before computing the loss
        ce = _categorical_crossentropy_torch(y_pred=y_pred, y_true=y_true, from_logits=from_logits)
        if class_weights is not None:
            y_class = torch.argmax(y_true, dim=1)
            sample_weights = torch.ones_like(y_class, dtype=torch.float)
            for key in class_weights.keys():
                sample_weights[y_class == key] = class_weights[key]
            ce = ce * sample_weights.reshape(ce.shape)

    if average_loss:
        ce = reduce_mean(ce)
    return ce