Skip to content

update

UpdateOp

Bases: TensorOp

This class performs updates to a model's weights based on the loss.

Parameters:

Name Type Description Default
model Union[tf.keras.Model, torch.nn.Module]

Model instance compiled by fe.build.

required
loss_name str

The name of the loss.

required
mode Union[None, str, Iterable[str]]

What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

'train'
defer bool

Whether to defer the actual application of the update until the end of the step. This can be necessary in PyTorch when trying to update multiple models which depend on one another (ex. certain GANs). By default, all UpdateOps which appear contiguously as the last ops of a Network will be deferred. We hope that you will never need to worry about this flag, but it's here for you if you need it.

False
Source code in fastestimator\fastestimator\op\tensorop\model\update.py
@traceable()
class UpdateOp(TensorOp):
    """Update a model's weights from a loss value.

    The op consumes the loss stored under `loss_name` and hands it, together with the model, to `update_model`. It
    produces no outputs.

    Args:
        model: Model instance compiled by fe.build.
        loss_name: The name of the loss to read as this op's input.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        defer: Whether to defer the actual application of the update until the end of the step. This can be necessary
            in PyTorch when trying to update multiple models which depend on one another (ex. certain GANs). By default,
            all UpdateOps which appear contiguously as the last ops of a Network will be deferred. Most users should
            never need to set this flag.
    """
    def __init__(self,
                 model: Union[tf.keras.Model, torch.nn.Module],
                 loss_name: str,
                 mode: Union[None, str, Iterable[str]] = "train",
                 defer: bool = False):
        super().__init__(inputs=loss_name, outputs=None, mode=mode)
        self.model = model
        self.retain_graph = False
        # Truthy only for a Keras model that reports entries in `model.losses`; in that case `forward` adds the sum of
        # those entries to the loss before updating.
        self.weight_decay = isinstance(self.model, tf.keras.Model) and self.model.losses
        self.defer = defer
        # Record on the model itself every loss name that is used to update it.
        if hasattr(self.model, "loss_name"):
            self.model.loss_name.add(loss_name)
        else:
            self.model.loss_name = {loss_name}

    def get_fe_models(self) -> Set[Model]:
        """Return a set holding the single model that this op updates."""
        models = {self.model}
        return models

    def get_fe_loss_keys(self) -> Set[str]:
        """Return this op's input key(s) as a set."""
        loss_keys = set(self.inputs)
        return loss_keys

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        """Read, and optionally overwrite, the `retain_graph` flag.

        Args:
            retain: When not None, the new value to store in `self.retain_graph`. When None, the flag is left as-is.

        Returns:
            The value of `self.retain_graph` after any assignment.
        """
        if retain is None:
            return self.retain_graph
        self.retain_graph = retain
        return self.retain_graph

    def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> None:
        """Pass the loss on to `update_model`, unless the step is a warmup step.

        Args:
            data: The loss value(s) read from this op's input key.
            state: Execution state. The keys "warmup", "tape", "scaler", and "deferred" are read from it.
        """
        if state["warmup"]:
            # Nothing is updated while `state["warmup"]` is truthy.
            return
        loss = data
        if self.weight_decay:
            # Fold the summed `model.losses` into the loss being optimized.
            loss = loss + tf.reduce_sum(self.model.losses)
        update_model(self.model,
                     loss,
                     tape=state['tape'],
                     retain_graph=self.retain_graph,
                     scaler=state["scaler"],
                     defer=self.defer,
                     deferred=state["deferred"])