How many times to repeat the `op`. This can also be a function that returns a boolean, in which case the function's input argument names will be matched to keys in the data dictionary, and the `op` will be repeated until the function evaluates to False. The function is only evaluated after the `op` has run, so the `op` will always be executed at least once.

Default: `1`
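The looping rules above can be sketched in plain Python. This is a minimal illustration, not FastEstimator code: it assumes a hypothetical op modeled as a function that updates a data dictionary in place, and `run_repeat` and `increment` are illustrative names. It shows the two modes: a fixed integer count, and a predicate whose parameter names are looked up in the data dictionary and checked only after each pass.

```python
import inspect
from typing import Any, Callable, Dict, List, Union

# Hypothetical stand-in for a TensorOp: a function that reads and updates a data dict in place.
Op = Callable[[Dict[str, Any]], None]


def run_repeat(op: Op, repeat: Union[int, Callable[..., bool]], data: Dict[str, Any]) -> Dict[str, Any]:
    """Mimic the documented looping rules of Repeat on plain Python values."""
    if isinstance(repeat, int):
        # Fixed number of passes.
        for _ in range(repeat):
            op(data)
        return data
    # The predicate's parameter names are used as keys into the data dictionary.
    arg_names: List[str] = list(inspect.signature(repeat).parameters.keys())
    op(data)  # Always run once before the predicate is consulted.
    while repeat(*[data[name] for name in arg_names]):
        op(data)
    return data


def increment(data: Dict[str, Any]) -> None:
    # Hypothetical op: add one to the value stored under "x".
    data["x"] = data["x"] + 1


print(run_repeat(increment, 3, {"x": 0}))  # {'x': 3}
print(run_repeat(increment, lambda x: x < 10, {"x": 0}))  # {'x': 10}
print(run_repeat(increment, lambda x: x < 0, {"x": 5}))  # {'x': 6}, the op still ran once
```

The last call shows why the op always runs at least once: the predicate is already False on the initial data, yet one pass happens before it is checked.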
Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `repeat` or `op` are invalid, for example an integer `repeat` less than 1. |
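The constructor in the source listing does two things with `repeat`: it rejects integers below 1, and for a callable it adds any predicate argument names that the wrapped op does not itself produce to the wrapper's required inputs. A minimal plain-Python sketch of that bookkeeping follows; `resolve_inputs` is a hypothetical helper, and unlike the original (which builds the extra keys from an unordered `set`) it sorts them so the result is deterministic.

```python
import inspect
from typing import Callable, List, Union


def resolve_inputs(op_inputs: List[str], op_outputs: List[str],
                   repeat: Union[int, Callable[..., bool]]) -> List[str]:
    """Return the keys a Repeat-like wrapper must read, validating `repeat` on the way."""
    if isinstance(repeat, int):
        if repeat < 1:
            raise ValueError(f"Repeat requires repeat to be >= 1, but got {repeat}")
        return list(op_inputs)
    # Predicate arguments that the op does not produce must come from the incoming data.
    predicate_args = inspect.signature(repeat).parameters.keys()
    extra = sorted(set(predicate_args) - set(op_outputs))
    return list(op_inputs) + extra


print(resolve_inputs(["x"], ["x"], 2))  # ['x']
print(resolve_inputs(["x"], ["x"], lambda x, limit: x < limit))  # ['x', 'limit']
```

Here `limit` is not an output of the op, so it must already be present in the data dictionary; `x` is produced by the op and therefore adds nothing.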
Source code in fastestimator\fastestimator\op\tensorop\meta\repeat.py
```python
@traceable()
class Repeat(TensorOp):
    """Repeat a TensorOp several times in a row.

    Args:
        op: A TensorOp to be run one or more times in a row.
        repeat: How many times to repeat the `op`. This can also be a function return, in which case the function input
            names will be matched to keys in the data dictionary, and the `op` will be repeated until the function
            evaluates to False. The function evaluation will happen at the end of a forward call, so the `op` will
            always be evaluated at least once.

    Raises:
        ValueError: If `repeat` or `op` are invalid.
    """
    def __init__(self, op: TensorOp, repeat: Union[int, Callable[..., bool]] = 1) -> None:
        self.repeat_inputs = []
        extra_reqs = []
        if isinstance(repeat, int):
            if repeat < 1:
                raise ValueError(f"Repeat requires repeat to be >= 1, but got {repeat}")
        else:
            self.repeat_inputs.extend(inspect.signature(repeat).parameters.keys())
            extra_reqs = list(set(self.repeat_inputs) - set(op.outputs))
        self.repeat = repeat
        super().__init__(inputs=op.inputs + extra_reqs, outputs=op.outputs, mode=op.mode)
        self.ops = [op]
        self.retain_graph = None
        self.while_fn = None

    @property
    def op(self) -> TensorOp:
        return self.ops[0]

    def build(self, framework: str) -> None:
        self.op.build(framework)
        if framework == 'tf':
            self.while_fn = self._tf_while
        else:
            self.while_fn = self._torch_while

    def get_fe_models(self) -> Set[Model]:
        return self.op.get_fe_models()

    def get_fe_loss_keys(self) -> Set[str]:
        return self.op.get_fe_loss_keys()

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        if retain is not None:
            self.retain_graph = retain
        return self.op.fe_retain_graph(retain)

    def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
        return {'ops': [elem.__getstate__() if hasattr(elem, '__getstate__') else {} for elem in self.ops]}

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
        # Set retain to true since might loop over a gradient aware op
        self.op.fe_retain_graph(True)
        data = {key: elem for key, elem in zip(self.inputs, data)}
        if isinstance(self.repeat, int):
            for i in range(self.repeat - 1):
                # Perform n-1 rounds with all ops having retain_graph == True
                BaseNetwork._forward_batch(data, state, self.ops)
            # Let retain be whatever it was meant to be for the final sequence
            self.op.fe_retain_graph(self.retain_graph)
            # Final round of ops
            BaseNetwork._forward_batch(data, state, self.ops)
        else:
            BaseNetwork._forward_batch(data, state, self.ops)
            data = self.while_fn(data, state)
            # TODO - Find some magic way to invoke this at the right moment
            self.op.fe_retain_graph(self.retain_graph)
        return [data[key] for key in self.outputs]

    def _torch_while(self, data: Dict[str, Tensor], state: Dict[str, Any]) -> Dict[str, Tensor]:
        """A helper function to invoke a while loop.

        Args:
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            A reference to the updated data dictionary.
        """
        while self.repeat(*[data[var_name] for var_name in self.repeat_inputs]):
            BaseNetwork._forward_batch(data, state, self.ops)
        return data

    def _tf_while(self, data: Dict[str, Tensor], state: Dict[str, Any]) -> Dict[str, Tensor]:
        """A helper function to invoke a while loop.

        Args:
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            A reference to the updated data dictionary.
        """
        args = ([data[var_name] for var_name in self.repeat_inputs], data, state)
        args = tf.while_loop(self._tf_cond, self._tf_body, args)
        return args[1]

    def _tf_cond(self, cnd: List[Tensor], data: Dict[str, Tensor], state: Dict[str, Any]) -> bool:
        """A helper function determine whether to keep invoking the while method.

        Note that `data` and `state` are unused here, but required since tf.while_loop needs the cond and body to have the
        same input argument signatures.

        Args:
            cnd: A list of arguments to be passed to the condition function.
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            Whether to continue looping.
        """
        return self.repeat(*cnd)

    def _tf_body(self, cnd: List[Tensor], data: Dict[str, Tensor],
                 state: Dict[str, Any]) -> Tuple[List[Tensor], Dict[str, Tensor], Dict[str, Any]]:
        """A helper function to execute the body of a while method.

        Note that `cnd` is unused here, but required since tf.while_loop needs the cond and body to have the same input
        argument signatures.

        Args:
            cnd: A list of arguments to be passed to the condition function.
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            The updated `cnd` values, along with the modified data and state dictionaries.
        """
        BaseNetwork._forward_batch(data, state, self.ops)
        return [data[var_name] for var_name in self.repeat_inputs], data, state
```
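In the integer branch of `forward`, the first `repeat - 1` passes run with `retain_graph` forced to True, because a later pass may still need the computation graph, and only the final pass uses the value configured on the wrapper. The plain-Python sketch below mirrors that schedule. `RecordingOp` and `forward_int` are hypothetical names, and `RecordingOp` only records the flag in effect on each call rather than computing gradients. As in the listing, passing `None` as the final setting leaves the flag at True, since `None` is treated as "no change".

```python
from typing import List, Optional


class RecordingOp:
    """Hypothetical op that records the retain_graph flag in effect on each call."""

    def __init__(self) -> None:
        self.retain_graph: Optional[bool] = None
        self.calls: List[Optional[bool]] = []

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        # None means "query only": the stored flag is left unchanged.
        if retain is not None:
            self.retain_graph = retain
        return self.retain_graph

    def __call__(self) -> None:
        self.calls.append(self.retain_graph)


def forward_int(op: RecordingOp, repeat: int, final_retain: Optional[bool]) -> None:
    """Mirror the retain_graph schedule of the integer branch of Repeat.forward."""
    op.fe_retain_graph(True)  # Earlier passes must keep the graph alive.
    for _ in range(repeat - 1):
        op()
    op.fe_retain_graph(final_retain)  # Restore the configured setting for the last pass.
    op()


op = RecordingOp()
forward_int(op, 3, final_retain=False)
print(op.calls)  # [True, True, False]
```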