mix_loss

MixLoss

Bases: LossOp

Loss class to compute mixup and cutmix losses.

This class should be used in conjunction with MixUpBatch and CutMixBatch to perform mix-up training, which helps to reduce over-fitting, stabilize GAN training, and harden against adversarial attacks. See https://arxiv.org/abs/1710.09412 for details.
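
For orientation, here is a sketch of how MixLoss is typically wired into a training network. This is illustrative only: `my_model_fn` is a placeholder for a user-defined model function, the keys `"x"`, `"y"`, and `"lam"` are example names, and the exact op signatures (e.g. `MixUpBatch`'s) may vary between FastEstimator versions.

```python
# Illustrative sketch only -- op signatures may differ across versions.
import fastestimator as fe
from fastestimator.op.tensorop.augmentation import MixUpBatch
from fastestimator.op.tensorop.loss import CrossEntropy, MixLoss
from fastestimator.op.tensorop.model import ModelOp, UpdateOp

model = fe.build(model_fn=my_model_fn, optimizer_fn="adam")  # my_model_fn: user-defined

network = fe.Network(ops=[
    # Mix the inputs and emit the mixing coefficient under the "lam" key
    MixUpBatch(inputs="x", outputs=["x", "lam"], alpha=1.0, mode="train"),
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    # Wrap an ordinary single-output loss with MixLoss
    MixLoss(loss=CrossEntropy(inputs=("y_pred", "y"), outputs="ce"), lam="lam"),
    UpdateOp(model=model, loss_name="ce"),
])
```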

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss` | `LossOp` | A loss object which we use to calculate the underlying loss of MixLoss. This should be an object of type `fe.op.tensorop.loss.loss.LossOp`. | *required* |
| `lam` | `str` | The key of the lambda value generated by MixUpBatch or CutMixBatch. | *required* |
| `average_loss` | `bool` | Whether the final loss should be averaged or not. | `True` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the provided `loss` has multiple outputs. |
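
Concretely, given the mixing coefficient $\lambda$ read from the `lam` key and the wrapped loss $L$, the op evaluates $L$ twice, once against the original labels $y$ and once against the labels rolled by one position along the batch axis (the mixing partners chosen by MixUpBatch/CutMixBatch), and interpolates between the two:

$$\mathcal{L}_{\text{mix}} = \lambda\, L(\hat{y},\, y) + (1 - \lambda)\, L(\hat{y},\, \mathrm{roll}(y, 1))$$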

Source code in fastestimator/fastestimator/op/tensorop/loss/mix_loss.py
from typing import Any, Dict, List, TypeVar

import fastestimator as fe
from fastestimator.backend import roll
from fastestimator.op.tensorop.loss.loss import LossOp

Tensor = TypeVar('Tensor')


class MixLoss(LossOp):
    """Loss class to compute mixup and cutmix losses.

    This class should be used in conjunction with MixUpBatch and CutMixBatch to perform mix-up training, which helps to
    reduce over-fitting, stabilize GAN training, and harden against adversarial attacks. See
    https://arxiv.org/abs/1710.09412 for details.

    Args:
        loss: A loss object which we use to calculate the underlying loss of MixLoss. This should be an object of type
            fe.op.tensorop.loss.loss.LossOp.
        lam: The key of the lambda value generated by MixUpBatch or CutMixBatch.
        average_loss: Whether the final loss should be averaged or not.

    Raises:
        ValueError: If the provided `loss` has multiple outputs.
    """
    def __init__(self, loss: LossOp, lam: str, average_loss: bool = True):
        self.loss = loss
        # Force the wrapped loss to return per-element values so they can be
        # re-weighted by lam in forward(); averaging is applied afterwards by
        # this op according to average_loss.
        self.loss.average_loss = False
        if len(loss.outputs) != 1:
            raise ValueError("MixLoss only supports lossOps which have a single output.")
        super().__init__(inputs=[lam] + loss.inputs, outputs=loss.outputs, mode=loss.mode, average_loss=average_loss)
        self.out_list = False

    @property
    def pred_key_idx(self) -> int:
        # Offset by one relative to the wrapped loss, since `lam` was
        # prepended to the inputs in __init__.
        return self.loss.pred_key_idx + 1

    @property
    def true_key_idx(self) -> int:
        return self.loss.true_key_idx + 1

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Tensor:
        lam, *args = data
        loss1 = self.loss.forward(args, state)

        # Roll the true labels by one along the batch axis so each sample's
        # loss is also computed against the label of its mixing partner.
        args[self.loss.true_key_idx] = roll(args[self.loss.true_key_idx], shift=1, axis=0)
        loss2 = self.loss.forward(args, state)

        # Interpolate the two per-element losses with the mixing coefficient.
        loss = lam * loss1 + (1.0 - lam) * loss2

        if self.average_loss:
            loss = fe.backend.reduce_mean(loss)

        return loss
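
The shift-by-one roll in `forward` mirrors the way MixUpBatch/CutMixBatch pair each sample with its neighbor in the batch. The following standalone sketch (plain PyTorch, not FastEstimator code) reproduces the same arithmetic for a batch of three samples:

```python
# Standalone illustration of MixLoss.forward's arithmetic. Assumes each
# sample was mixed with its batch neighbor, which is what the shift-by-one
# roll implies.
import torch
import torch.nn.functional as F

y_pred = torch.tensor([[2.0, 0.5], [0.3, 1.7], [1.1, 0.9]])  # logits for 3 samples
y_true = torch.tensor([0, 1, 0])                             # original labels
lam = 0.7                                                    # mixing coefficient

loss1 = F.cross_entropy(y_pred, y_true, reduction='none')  # vs own labels
loss2 = F.cross_entropy(y_pred, torch.roll(y_true, shifts=1, dims=0),
                        reduction='none')                   # vs partner labels
mixed = lam * loss1 + (1.0 - lam) * loss2  # same interpolation as MixLoss
print(mixed.mean())  # average_loss=True reduces to a scalar like this
```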