Skip to content

cross_entropy

CrossEntropy

Bases: LossOp

Calculate Element-Wise CrossEntropy (binary, categorical or sparse categorical).

Parameters:

Name Type Description Default
inputs Union[None, str, Iterable[str]]

A tuple or list like: [<y_pred>, <y_true>].

None
outputs Union[None, str, Iterable[str]]

String key under which to store the computed loss value.

None
mode Union[None, str, Iterable[str]]

What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

'!infer'
from_logits bool

Whether y_pred is logits (without softmax).

False
average_loss bool

Whether to average the element-wise loss after the Loss Op.

True
form Optional[str]

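What form of cross entropy should be performed ('binary', 'categorical', 'sparse', or None). If None, the form is inferred automatically from the tensor shapes: 'categorical' when both y_pred and y_true are 2-D with more than one column, 'sparse' when only y_pred is, and 'binary' otherwise.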

None
Source code in fastestimator\fastestimator\op\tensorop\loss\cross_entropy.py
@traceable()
class CrossEntropy(LossOp):
    """Calculate Element-Wise CrossEntropy (binary, categorical or sparse categorical).

    Args:
        inputs: A tuple or list like: [<y_pred>, <y_true>].
        outputs: String key under which to store the computed loss value.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        from_logits: Whether y_pred is logits (without softmax).
        average_loss: Whether to average the element-wise loss after the Loss Op.
        form: What form of cross entropy should be performed ('binary', 'categorical', 'sparse', or None). If None,
            the form is inferred from the tensor shapes at each call: 'categorical' when both y_pred and y_true are
            2-D with more than one column, 'sparse' when only y_pred is, and 'binary' otherwise.

    Raises:
        ValueError: If `form` is not one of 'binary', 'categorical', 'sparse', or None.
    """
    def __init__(self,
                 inputs: Union[None, str, Iterable[str]] = None,
                 outputs: Union[None, str, Iterable[str]] = None,
                 mode: Union[None, str, Iterable[str]] = "!infer",
                 from_logits: bool = False,
                 average_loss: bool = True,
                 form: Optional[str] = None):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode, average_loss=average_loss)
        self.from_logits = from_logits
        self.form = form
        # Dispatch table from form name to the loss helper that implements it.
        self.cross_entropy_fn = {
            "binary": binary_crossentropy,
            "categorical": categorical_crossentropy,
            "sparse": sparse_categorical_crossentropy
        }
        # Fail fast at construction time. Without this check, an unknown `form` would only surface as a KeyError
        # from the dispatch lookup inside forward(), i.e. in the middle of training.
        if form is not None and form not in self.cross_entropy_fn:
            raise ValueError("form must be one of {} or None, but got {!r}".format(sorted(self.cross_entropy_fn),
                                                                                    form))

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Tensor:
        """Compute the cross-entropy loss between the prediction and the ground truth.

        Args:
            data: A pair [y_pred, y_true], in the order given by `inputs`.
            state: Unused by this Op.

        Returns:
            The loss tensor produced by the selected cross-entropy helper. `self.from_logits` and `self.average_loss`
            are forwarded to that helper, which decides whether the result is reduced.
        """
        y_pred, y_true = data
        form = self.form
        if form is None:
            # Shape-based inference. A 2-D y_pred with more than one column is treated as multi-class. Then a y_true
            # that is also 2-D with more than one column gives 'categorical' (presumably one-hot / probabilities);
            # any other y_true gives 'sparse' (presumably class indices). Everything else falls back to 'binary'.
            if len(y_pred.shape) == 2 and y_pred.shape[-1] > 1:
                if len(y_true.shape) == 2 and y_true.shape[-1] > 1:
                    form = "categorical"
                else:
                    form = "sparse"
            else:
                form = "binary"
        loss = self.cross_entropy_fn[form](y_pred, y_true, from_logits=self.from_logits, average_loss=self.average_loss)
        return loss