This class performs updates to a model's weights based on the loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | Model instance compiled by `fe.build`. | *required* |
| `loss_name` | `str` | The input loss key. | *required* |
| `gradients` | `Optional[str]` | An optional key containing model gradients. These will be applied directly to the model weights during an update. If not provided, gradients will be computed from the specified `loss_name`, which automatically handles any desired mixed-precision scaling. This argument shouldn't be used if mixed-precision training is enabled. | `None` |
| `mode` | `Union[None, str, Iterable[str]]` | What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass `None`. To execute in all modes except for a particular one, pass an argument like "!infer" or "!train". | `'train'` |
| `ds_id` | `Union[None, str, Iterable[str]]` | What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass `None`. To execute in all ds_ids except for a particular one, pass an argument like "!ds1". | `None` |
| `merge_grad` | `int` | The number of gradient accumulation steps before a model update. Ex: if `merge_grad` = 3, then for every three Op calls only the third one updates the model; the first two calls only accumulate their gradients (see the usage sketch after this table). The default value is 1, which updates the model at every step. | `1` |
| `defer` | `bool` | Whether to defer the actual application of the update until the end of the step. This can be necessary in PyTorch when trying to update multiple models which depend on one another (ex. certain GANs). By default, all UpdateOps which appear contiguously as the last ops of a Network will be deferred. We hope that you will never need to worry about this flag, but it's here for you if you need it. | `False` |
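For orientation, here is a minimal sketch of how an `UpdateOp` is typically wired into a `Network`. The model function, the key names ("x", "y", "ce"), and the `merge_grad` value are illustrative choices, not requirements of this class:

```python
import fastestimator as fe
from fastestimator.architecture.pytorch import LeNet
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp

model = fe.build(model_fn=LeNet, optimizer_fn="adam")
network = fe.Network(ops=[
    ModelOp(model=model, inputs="x", outputs="y_pred"),   # forward pass
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),   # writes the loss key "ce"
    UpdateOp(model=model, loss_name="ce", merge_grad=4),  # accumulate 4 steps, then update
])
```

With `merge_grad=4`, the first three invocations only add their gradients into a running sum; the fourth averages that sum and applies it to the weights.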
Raises:

- `ValueError`: When the model is mixed-precision and `gradients` is provided.
- `ValueError`: When the network framework is not one of "tf" or "torch".
- `ValueError`: When `merge_grad` is larger than 1 in a multi-GPU configuration.
- `RuntimeError`: If attempting to modify a PyTorch model which relied on gradients within a different PyTorch model which has in turn already undergone a non-deferred update (see the deferred-update sketch below).
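The `RuntimeError` above typically surfaces in PyTorch when one model's backward pass needs activations from another model whose weights were already updated earlier in the same step. A hedged sketch of the two-model case (the model names, loss Ops, and keys below are all illustrative placeholders):

```python
# Generator/discriminator wiring where defer matters. If the discriminator's
# update were applied immediately, the generator's backward pass could hit the
# "modified by an inplace operation" RuntimeError described above.
network = fe.Network(ops=[
    ModelOp(model=generator, inputs="z", outputs="x_fake"),
    ModelOp(model=discriminator, inputs="x_fake", outputs="d_fake"),
    DLoss(inputs="d_fake", outputs="d_loss"),  # hypothetical loss ops
    GLoss(inputs="d_fake", outputs="g_loss"),
    UpdateOp(model=discriminator, loss_name="d_loss", defer=True),  # apply at end of step
    UpdateOp(model=generator, loss_name="g_loss"),
])
# Note: UpdateOps which appear contiguously as the last ops of a Network are
# auto-deferred anyway, so defer=True is shown explicitly only for illustration.
```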
Source code in `fastestimator/fastestimator/op/tensorop/model/update.py`
```python
@traceable()
class UpdateOp(TensorOp):
    """This class performs updates to a model's weights based on the loss.

    Args:
        model: Model instance compiled by fe.build.
        loss_name: The input loss key.
        gradients: An optional key containing model gradients. These will be directly applied to the model weights
            during an update. If not provided, gradients will be computed based on the specified loss_name, which will
            automatically handle any desired mixed-precision scaling. This argument shouldn't be used if
            mixed-precision training is enabled.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an
            argument like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        merge_grad: The gradient accumulation times before model update. Ex: if `merge_grad` = 3, for every three Op
            calls only the third one updates the model. The first two calls only accumulate their gradients. The
            default value is 1, which updates the model at every step.
        defer: Whether to defer the actual application of the update until the end of the step. This can be necessary
            in PyTorch when trying to update multiple models which depend on one another (ex. certain GANs). By
            default, all UpdateOps which appear contiguously as the last ops of a Network will be deferred. We hope
            that you will never need to worry about this flag, but it's here for you if you need it.

    Raises:
        ValueError: When the model is mixed-precision and `gradients` is provided.
        ValueError: When the network framework is not one of "tf" or "torch".
        ValueError: When `merge_grad` is larger than 1 in a multi-GPU configuration.
        RuntimeError: If attempting to modify a PyTorch model which relied on gradients within a different PyTorch
            model which has in turn already undergone a non-deferred update.
    """
    def __init__(self,
                 model: Union[tf.keras.Model, torch.nn.Module],
                 loss_name: str,
                 gradients: Optional[str] = None,
                 mode: Union[None, str, Iterable[str]] = "train",
                 ds_id: Union[None, str, Iterable[str]] = None,
                 merge_grad: int = 1,
                 defer: bool = False):
        self.extra_loss = isinstance(model, tf.keras.Model) and model.losses
        if gradients is None:
            super().__init__(inputs=loss_name, outputs=None, mode=mode, ds_id=ds_id)
        else:
            if model.mixed_precision:
                raise ValueError("Mixed precision training cannot take input gradients, because the gradients need to "
                                 "be computed in this module")
            if self.extra_loss:
                print("FastEstimator-Warn: Extra model losses are detected and they will be ignored since the "
                      "gradients are not computed in UpdateOp class.")
            super().__init__(inputs=gradients, outputs=None, mode=mode, ds_id=ds_id)
        if torch.cuda.device_count() > 1 and merge_grad > 1:
            raise ValueError("Currently FastEstimator doesn't support merge_grad feature in multi-GPU configuration "
                             "and thus 'merge_grad' cannot be larger than 1")
        if not hasattr(model, "loss_name"):
            model.loss_name = {loss_name}
        else:
            model.loss_name.add(loss_name)
        self.model = model
        self.retain_graph = False
        self.defer = defer
        self.gradients = gradients
        self.loss_name = loss_name
        self.merge_grad = merge_grad
        self.framework = None

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        if framework not in ["tf", "torch"]:
            raise ValueError(f"Unrecognized framework {framework}")
        self.framework = framework
        if self.merge_grad > 1:
            if framework == "tf":
                self.step = tf.Variable(0, trainable=False, dtype=tf.int64)
                self.grad_sum = [tf.Variable(tf.zeros_like(x), trainable=False)
                                 for x in self.model.trainable_variables]
            else:  # framework == "torch"
                self.step = torch.tensor(0, dtype=torch.int64).to(device)
                self.grad_sum = [torch.zeros_like(x).to(device) for x in self.model.parameters() if x.requires_grad]

    def get_fe_models(self) -> Set[Model]:
        return {self.model}

    def get_fe_loss_keys(self) -> Set[str]:
        return to_set(self.loss_name)

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        if retain is not None:
            self.retain_graph = retain
        return self.retain_graph

    def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> None:
        if state["warmup"]:
            return
        if self.gradients is None:  # data is loss
            loss = self._loss_preprocess(data)
            gradients = self._get_gradient(loss, state["tape"])
        else:  # data is gradients
            gradients = data
        gradients = self._gradient_postprocess(gradients)
        if self.merge_grad > 1:
            self._merge_grad_update(gradients, deferred=state["deferred"])
        else:
            update_model(model=self.model, gradients=gradients, defer=self.defer, deferred=state["deferred"])

    def _loss_preprocess(self, loss: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
        """Loss preprocess for multi-GPU and mixed-precision training.

        Args:
            loss: Unprocessed loss.

        Returns:
            Processed loss.
        """
        if self.extra_loss:
            loss = loss + tf.reduce_sum(self.model.losses)
        loss = reduce_mean(loss)
        if self.framework == "tf":
            # scale up loss for mixed precision training to avoid underflow
            if self.model.mixed_precision:
                loss = self.model.current_optimizer.get_scaled_loss(loss)
            # for multi-gpu training, the gradient will be combined by sum, normalize the loss
            strategy = tf.distribute.get_strategy()
            if isinstance(strategy, tf.distribute.MirroredStrategy):
                loss = loss / strategy.num_replicas_in_sync
        else:  # self.framework == "torch"
            if self.model.current_optimizer.scaler is not None:
                # scale up loss for mixed precision training to avoid underflow
                loss = self.model.current_optimizer.scaler.scale(loss)
        return loss

    def _get_gradient(self, loss: Union[Tensor, List[Tensor]],
                      tape: Optional[tf.GradientTape] = None) -> Union[Tensor, List[Tensor]]:
        """Get gradient from loss with respect to self.model.

        Args:
            loss: Input loss.
            tape: A TensorFlow GradientTape which was recording when the `loss` was computed (iff using TensorFlow).

        Returns:
            Computed gradients.
        """
        if self.framework == "tf":
            gradients = get_gradient(loss, self.model.trainable_variables, tape=tape)
        else:  # self.framework == "torch"
            trainable_params = [p for p in self.model.parameters() if p.requires_grad]
            try:
                gradients = get_gradient(loss, trainable_params, retain_graph=self.retain_graph)
            except RuntimeError as err:
                if err.args and isinstance(err.args[0], str) and err.args[0].startswith(
                        'one of the variables needed for gradient computation has been modified by an inplace '
                        'operation'):
                    raise RuntimeError(
                        "When computing gradients for '{}', some variables it relied on during the forward pass had"
                        " been updated. Consider setting defer=True in earlier UpdateOps related to models which "
                        "interact with this one.".format(self.model.model_name))
                raise err
        return gradients

    def _gradient_postprocess(self, gradients: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]:
        """Gradient postprocess for multi-GPU and mixed-precision training.

        Args:
            gradients: Unprocessed gradients.

        Returns:
            Processed gradients.
        """
        if self.framework == "tf":
            if self.gradients is not None:  # when user provides gradients
                strategy = tf.distribute.get_strategy()
                # for multi-gpu training, the gradient will be combined by sum, normalize the gradient
                if isinstance(strategy, tf.distribute.MirroredStrategy):
                    gradients = [gs / strategy.num_replicas_in_sync for gs in gradients]
            if self.model.mixed_precision:
                # scale down gradient to balance scale-up loss
                gradients = self.model.current_optimizer.get_unscaled_gradients(gradients)
        return gradients

    def _merge_grad_update(self, gradients: Union[Tensor, List[Tensor]],
                           deferred: Optional[Dict[str, List[Callable[[], None]]]] = None) -> None:
        """Accumulate gradients and update the model at certain frequency of invocation.

        Args:
            gradients: Input gradients.
            deferred: A dictionary in which model update functions are stored.
        """
        # add current gradient to the cumulative gradient
        for gs, g in zip(self.grad_sum, gradients):
            self._assign_add(gs, g)
        self._assign_add(self.step, 1)
        if self.step % self.merge_grad == 0:
            average_grad = [gs / self.merge_grad for gs in self.grad_sum]
            update_model(model=self.model, gradients=average_grad, defer=self.defer, deferred=deferred)
            for gs in self.grad_sum:
                self._assign_add(gs, -gs)  # zero the gradient in place

    def _assign_add(self, a: Tensor, b: Tensor) -> None:
        """In-place addition for both TensorFlow and PyTorch. `a` = `a` + `b`

        Args:
            a: A tensor where in-place addition happens.
            b: Amount to be added.
        """
        if self.framework == "tf":
            a.assign_add(b)
        else:  # self.framework == "torch"
            a += b
```
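To make the `_merge_grad_update` arithmetic concrete, here is a standalone PyTorch sketch of the same accumulate-then-average pattern, independent of FastEstimator (all names and values are illustrative):

```python
import torch

w = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([w], lr=0.1)
grad_sum = torch.zeros_like(w)  # running sum, like self.grad_sum

for step, x in enumerate([torch.tensor([1., 2.]),
                          torch.tensor([3., 4.]),
                          torch.tensor([5., 6.])], start=1):
    loss = (w * x).sum()
    grad, = torch.autograd.grad(loss, [w])
    grad_sum += grad               # accumulate, like _assign_add
    if step % 3 == 0:              # merge_grad == 3
        w.grad = grad_sum / 3      # average over the accumulation window
        opt.step()                 # single optimizer update for 3 micro-batches
        opt.zero_grad()
        grad_sum.zero_()           # zero the running sum in place
```

Averaging the accumulated gradients makes the update equivalent to the gradient of the mean loss over the `merge_grad` micro-batches, which is why the feature can stand in for a larger batch size when memory is tight.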