Skip to content

dice

Dice

Bases: Trace

Dice score for binary classification between y_true and y_predicted.

Parameters:

Name Type Description Default
true_key str

The key of the ground truth mask.

required
pred_key str

The key of the prediction values.

required
threshold float

The threshold for binarizing the prediction.

0.5
channel_average bool

Whether to average the dice score channel-wise.

False
mode Union[None, str, Iterable[str]]

What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

('eval', 'test')
ds_id Union[None, str, Iterable[str]]

What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1".

None
output_name str

What to call the output from this trace (for example in the logger output).

'Dice'
per_ds bool

Whether to automatically compute this metric individually for every ds_id it runs on, in addition to computing an aggregate across all ds_ids on which it runs. This is automatically False if output_name contains a "|" character.

True
Source code in fastestimator\fastestimator\trace\metric\dice.py
@per_ds
@traceable()
class Dice(Trace):
    """Dice score between a binary ground-truth mask and thresholded predictions.

    On every batch the predictions are binarized against `threshold` and scored with `dice_score`. The resulting
    per-instance scores are logged and also collected, so that their mean can be reported when the epoch ends.

    Args:
        true_key: The key of the ground truth mask.
        pred_key: The key of the prediction values.
        threshold: Predictions strictly greater than this value become 1; all others become 0.
        channel_average: Whether to average the dice score channel-wise (forwarded to `dice_score`).
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        output_name: What to call the output from this trace (for example in the logger output).
        per_ds: Whether to automatically compute this metric individually for every ds_id it runs on, in addition to
            computing an aggregate across all ds_ids on which it runs. This is automatically False if `output_name`
            contains a "|" character.
    """

    def __init__(self,
                 true_key: str,
                 pred_key: str,
                 threshold: float = 0.5,
                 channel_average: bool = False,
                 mode: Union[None, str, Iterable[str]] = ("eval", "test"),
                 ds_id: Union[None, str, Iterable[str]] = None,
                 output_name: str = "Dice",
                 per_ds: bool = True) -> None:
        super().__init__(inputs=(true_key, pred_key), outputs=output_name, mode=mode, ds_id=ds_id)
        self.threshold = threshold
        # Small constant handed to `dice_score` as its `epsilon` argument.
        self.smooth = 1e-8
        self.channel_average = channel_average
        # Per-instance dice scores accumulated over the current epoch.
        self.dice = []
        self.per_ds = per_ds

    @property
    def true_key(self) -> str:
        """Key under which the ground-truth mask is found in the batch data."""
        gt_key, _ = self.inputs[0], None
        return gt_key

    @property
    def pred_key(self) -> str:
        """Key under which the prediction values are found in the batch data."""
        prediction_key = self.inputs[1]
        return prediction_key

    def on_epoch_begin(self, data: Data) -> None:
        # Start every epoch with a fresh accumulator (rebinding, not clearing in place).
        self.dice = []

    def on_batch_end(self, data: Data) -> None:
        ground_truth = to_number(data[self.true_key])
        raw_pred = to_number(data[self.pred_key])

        # Threshold into {0, 1} while keeping the prediction array's original dtype.
        binary_pred = (raw_pred > self.threshold).astype(raw_pred.dtype)

        batch_scores = dice_score(y_pred=binary_pred,
                                  y_true=ground_truth,
                                  channel_average=self.channel_average,
                                  epsilon=self.smooth)

        data.write_per_instance_log(self.outputs[0], batch_scores)
        self.dice.extend(batch_scores)

    def on_epoch_end(self, data: Data) -> None:
        epoch_dice = np.mean(self.dice)
        data.write_with_log(self.outputs[0], epoch_dice)