update
UpdateOp
Bases: TensorOp
This class performs updates to a model's weights based on the loss.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Union[tf.keras.Model, torch.nn.Module] | Model instance compiled by fe.build. | required |
loss_name | str | The name of the loss. | required |
mode | Union[None, str, Iterable[str]] | What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | 'train' |
defer | bool | Whether to defer the actual application of the update until the end of the step. This can be necessary in PyTorch when trying to update multiple models which depend on one another (e.g., certain GANs). By default, all UpdateOps which appear contiguously as the last ops of a Network will be deferred. We hope that you will never need to worry about this flag, but it's here for you if you need it. | False |
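
The `mode` values described above can be read as selecting a set of execution modes. The following is a minimal sketch of that reading in plain Python, not FastEstimator's own implementation: `resolve_modes` and `ALL_MODES` are hypothetical names introduced only for illustration, and rejecting a mix of plain and "!"-negated entries is an assumption of this sketch rather than documented behavior.

```python
from typing import Iterable, Set, Union

# The four mode names listed in the parameter description.
ALL_MODES = {"train", "eval", "test", "infer"}


def resolve_modes(mode: Union[None, str, Iterable[str]]) -> Set[str]:
    """Expand a `mode` argument into the set of modes it selects (hypothetical helper)."""
    if mode is None:
        # None means "execute regardless of mode".
        return set(ALL_MODES)
    requested = {mode} if isinstance(mode, str) else set(mode)
    excluded = {m[1:] for m in requested if m.startswith("!")}
    included = {m for m in requested if not m.startswith("!")}
    unknown = (excluded | included) - ALL_MODES
    if unknown:
        raise ValueError(f"Unknown mode(s): {sorted(unknown)}")
    if excluded and included:
        # Assumption of this sketch: mixing the two styles is treated as an error.
        raise ValueError("Cannot mix plain and '!'-negated modes")
    # "!infer" selects every mode except "infer"; otherwise use the listed modes.
    return ALL_MODES - excluded if excluded else included


print(sorted(resolve_modes("train")))   # the default for this Op
print(sorted(resolve_modes("!infer")))  # every mode except "infer"
print(sorted(resolve_modes(None)))      # all modes
```

Under this reading, the default `mode='train'` means the weight update runs only during training, which matches the usual expectation that evaluation and inference leave the model unchanged.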