Bases: LossOp
Loss class to compute mixup and cutmix losses.
This class should be used in conjunction with MixUpBatch and CutMixBatch to perform mix-up training, which helps to
reduce over-fitting, stabilize GAN training, and harden against adversarial attacks. See
https://arxiv.org/abs/1710.09412 for details.
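Concretely, the op reads the mixing coefficient λ from the `lam` key and interpolates two evaluations of the wrapped loss: one against the original labels, and one against the labels of each sample's mixing partner (the batch rolled by one position), as implemented in the `forward` method in the source below:

$$
\mathcal{L}_{\text{mix}} = \lambda \, \mathcal{L}(\hat{y},\, y) + (1 - \lambda) \, \mathcal{L}\big(\hat{y},\, \operatorname{roll}(y, 1)\big)
$$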
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss` | `LossOp` | A loss object which we use to calculate the underlying loss of MixLoss. This should be an object of type `fe.op.tensorop.loss.loss.LossOp`. | *required* |
| `lam` | `str` | The key of the lambda value generated by MixUpBatch or CutMixBatch. | *required* |
| `average_loss` | `bool` | Whether the final loss should be averaged or not. | `True` |
Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the provided `loss` has multiple outputs. |
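As a usage sketch, MixLoss is typically wired into a Network together with MixUpBatch. The tiny model, the `alpha` value, and the exact import locations below are illustrative assumptions rather than guarantees of this page, and the mixing ops are restricted to train mode since `"lam"` only exists when MixUpBatch runs:

```python
import tensorflow as tf

import fastestimator as fe
from fastestimator.op.tensorop.augmentation import MixUpBatch
from fastestimator.op.tensorop.loss import CrossEntropy, MixLoss
from fastestimator.op.tensorop.model import ModelOp, UpdateOp

# A deliberately tiny model, purely for illustration.
model = fe.build(model_fn=lambda: tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
    tf.keras.layers.Dense(10, activation="softmax"),
]), optimizer_fn="adam")

network = fe.Network(ops=[
    # Blend each image with its neighbor and write the coefficient under "lam".
    MixUpBatch(inputs="x", outputs=["x", "lam"], alpha=1.0, mode="train"),
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    # MixLoss reads the same "lam" key and interpolates the wrapped CrossEntropy.
    MixLoss(loss=CrossEntropy(inputs=("y_pred", "y"), outputs="ce", mode="train"), lam="lam"),
    UpdateOp(model=model, loss_name="ce"),
])
```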
Source code in `fastestimator/fastestimator/op/tensorop/loss/mix_loss.py`
```python
from typing import Any, Dict, List, TypeVar

import tensorflow as tf
import torch

import fastestimator as fe
from fastestimator.backend import roll
from fastestimator.op.tensorop.loss.loss import LossOp

Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor)


class MixLoss(LossOp):
    """Loss class to compute mixup and cutmix losses.

    This class should be used in conjunction with MixUpBatch and CutMixBatch to perform mix-up training, which helps
    to reduce over-fitting, stabilize GAN training, and harden against adversarial attacks. See
    https://arxiv.org/abs/1710.09412 for details.

    Args:
        loss: A loss object which we use to calculate the underlying loss of MixLoss. This should be an object of
            type fe.op.tensorop.loss.loss.LossOp.
        lam: The key of the lambda value generated by MixUpBatch or CutMixBatch.
        average_loss: Whether the final loss should be averaged or not.

    Raises:
        ValueError: If the provided `loss` has multiple outputs.
    """
    def __init__(self, loss: LossOp, lam: str, average_loss: bool = True):
        self.loss = loss
        self.loss.average_loss = False  # defer averaging so the two component losses can be mixed per-sample
        if len(loss.outputs) != 1:
            raise ValueError("MixLoss only supports lossOps which have a single output.")
        super().__init__(inputs=[lam] + loss.inputs, outputs=loss.outputs, mode=loss.mode, average_loss=average_loss)
        self.out_list = False

    @property
    def pred_key_idx(self) -> int:
        # Shifted by 1 relative to the wrapped loss since `lam` is prepended to the inputs.
        return self.loss.pred_key_idx + 1

    @property
    def true_key_idx(self) -> int:
        # Shifted by 1 relative to the wrapped loss since `lam` is prepended to the inputs.
        return self.loss.true_key_idx + 1

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> Tensor:
        lam, *args = data
        # Loss against the original labels.
        loss1 = self.loss.forward(args, state)
        # Loss against the labels of each sample's mixing partner (the batch rolled by one position).
        args[self.loss.true_key_idx] = roll(args[self.loss.true_key_idx], shift=1, axis=0)
        loss2 = self.loss.forward(args, state)
        loss = lam * loss1 + (1.0 - lam) * loss2
        if self.average_loss:
            loss = fe.backend.reduce_mean(loss)
        return loss
```
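To see the forward arithmetic in isolation, the same per-sample interpolation can be replicated in plain PyTorch (a standalone sketch with hypothetical values for `lam` and the batch; FastEstimator's CrossEntropy is swapped for `torch.nn.functional.cross_entropy`):

```python
import torch
import torch.nn.functional as F

lam = 0.7                                  # hypothetical coefficient from MixUpBatch
y_pred = torch.randn(4, 10)                # logits for a batch of 4 samples
y_true = torch.tensor([0, 1, 2, 3])

# Per-sample loss against the original labels ...
loss1 = F.cross_entropy(y_pred, y_true, reduction="none")
# ... and against each sample's mixing partner (labels rolled by one position).
loss2 = F.cross_entropy(y_pred, torch.roll(y_true, shifts=1, dims=0), reduction="none")

mixed = lam * loss1 + (1.0 - lam) * loss2  # what MixLoss.forward computes
print(mixed.mean())                        # the reduction applied when average_loss=True
```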