Repeat takes an Op and runs it multiple times in a row. It can be set to repeat for a fixed (static) number of
times, or to repeat until a given input function evaluates to False (dynamic).
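Static example, taken from the class docstring below (`AddOne` is a placeholder op used throughout these examples):

```python
ops=[
    LambdaOp(fn=lambda: 0, outputs="z"),
    Repeat(AddOne(inputs="z", outputs="z"), repeat=5)
]
```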
Dynamic example:

```python
ops=[
    LambdaOp(fn=lambda: 0, outputs="z"),
    Repeat(AddOne(inputs="z", outputs="z"), repeat=lambda z: z < 6.5)
]
```
Note: The argument name of the `repeat` lambda ('z' here) must match a key used by the ops passed to the
Repeat Op, since the function's input names are looked up in the data dictionary.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op` | `TensorOp` | A TensorOp to be run one or more times in a row. | *required* |
| `repeat` | `Union[int, Callable[..., bool]]` | How many times to repeat the `op`. This can also be a function, in which case the function's input names will be matched to keys in the data dictionary, and the `op` will be repeated until the function evaluates to False. The function evaluation happens at the end of a forward call, so the `op` will always be evaluated at least once. | `1` |
| `max_iter` | `Optional[int]` | A limit to how many iterations will be run (or None for no limit). | `None` |
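Since a callable `repeat` could loop forever, `max_iter` provides a safety cap. A minimal sketch combining the two, re-using the placeholder `AddOne` op from the examples above:

```python
ops=[
    LambdaOp(fn=lambda: 0, outputs="z"),
    # Loop until z >= 6.5, but run the op at most 100 times in total
    Repeat(AddOne(inputs="z", outputs="z"), repeat=lambda z: z < 6.5, max_iter=100)
]
```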
Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `repeat`, `op`, or `max_iter` are invalid. |
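Concretely, each of the following constructions fails the validation performed in `__init__` (visible in the source below), again using the placeholder `AddOne` op:

```python
Repeat(AddOne(inputs="z", outputs="z"), repeat=0)               # repeat must be >= 1
Repeat(AddOne(inputs="z", outputs="z"), repeat=5, max_iter=10)  # max_iter may not be combined with an integer repeat
Repeat(AddOne(inputs="z", outputs="z"), repeat=lambda z: z < 6.5, max_iter=0)  # max_iter must be >= 1
```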
Source code in fastestimator/fastestimator/op/tensorop/meta/repeat.py
```python
@traceable()
class Repeat(TensorOp):
    """Repeat a TensorOp several times in a row.

    Repeat takes an Op and runs it multiple times in a row. It can be set to repeat for a fixed (static) number of
    times, or to repeat until a given input function evaluates to False (dynamic).

    Static example:
        ops = [
            LambdaOp(fn=lambda: 0, outputs="z"),
            Repeat(AddOne(inputs="z", outputs="z"), repeat=5)
        ]

    Dynamic example:
        ops = [
            LambdaOp(fn=lambda: 0, outputs="z"),
            Repeat(AddOne(inputs="z", outputs="z"), repeat=lambda z: z < 6.5)
        ]

    Note: Here the argument ('z') of the lambda function used as the repeat callable is a key used by the ops
    passed to the Repeat Op.

    Args:
        op: A TensorOp to be run one or more times in a row.
        repeat: How many times to repeat the `op`. This can also be a function, in which case the function input
            names will be matched to keys in the data dictionary, and the `op` will be repeated until the function
            evaluates to False. The function evaluation will happen at the end of a forward call, so the `op` will
            always be evaluated at least once.
        max_iter: A limit to how many iterations will be run (or None for no limit).

    Raises:
        ValueError: If `repeat`, `op`, or `max_iter` are invalid.
    """
    def __init__(self,
                 op: TensorOp,
                 repeat: Union[int, Callable[..., bool]] = 1,
                 max_iter: Optional[int] = None) -> None:
        self.repeat_inputs = []
        extra_reqs = []
        if max_iter is None:
            self.max_iter = max_iter
        else:
            if max_iter < 1:
                raise ValueError(f"Repeat requires max_iter to be >=1, but got {max_iter}")
            self.max_iter = max_iter - 1  # -1 b/c the first invocation happens outside the while loop
        if isinstance(repeat, int):
            if repeat < 1:
                raise ValueError(f"Repeat requires repeat to be >= 1, but got {repeat}")
            if max_iter:
                raise ValueError("Do not set max_iter when repeat is an integer")
        else:
            self.repeat_inputs.extend(inspect.signature(repeat).parameters.keys())
            extra_reqs = list(set(self.repeat_inputs) - set(op.outputs))
        self.repeat = repeat
        super().__init__(inputs=op.inputs + extra_reqs, outputs=op.outputs, mode=op.mode, ds_id=op.ds_id)
        self.ops = [op]
        self.retain_graph = None
        self.while_fn = None

    @property
    def op(self) -> TensorOp:
        return self.ops[0]

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        self.op.build(framework, device)
        # Below the while function is chosen based on framework
        if framework == 'tf':
            # For tensorflow the while function is decided based on the object type of 'self.repeat'.
            if isinstance(self.repeat, int):
                self.while_fn = self._tf_while_int
            else:
                self.while_fn = self._tf_while
        else:
            self.while_fn = self._torch_while

    def get_fe_models(self) -> Set[Model]:
        return self.op.get_fe_models()

    def get_fe_loss_keys(self) -> Set[str]:
        return self.op.get_fe_loss_keys()

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        if retain is not None:
            self.retain_graph = retain
        return self.op.fe_retain_graph(retain)

    def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
        return {'ops': [elem.__getstate__() if hasattr(elem, '__getstate__') else {} for elem in self.ops]}

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
        # Set retain to true since we might loop over a gradient aware op
        self.op.fe_retain_graph(True)
        data = {key: elem for key, elem in zip(self.inputs, data)}
        if isinstance(self.repeat, int):
            data = self.while_fn(data, state)
        else:
            BaseNetwork._forward_batch(data, state, self.ops)
            data = self.while_fn(data, state)
        # TODO - Find some magic way to invoke this at the right moment
        self.op.fe_retain_graph(self.retain_graph)
        return [data[key] for key in self.outputs]

    def _torch_while(self, data: Dict[str, Tensor], state: Dict[str, Any]) -> Dict[str, Tensor]:
        """A helper function to invoke a loop.

        Args:
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            A reference to the updated data dictionary.
        """
        if isinstance(self.repeat, int):
            for _ in range(self.repeat - 1):
                # Perform n-1 rounds with all ops having retain_graph == True
                BaseNetwork._forward_batch(data, state, self.ops)
            # Let retain be whatever it was meant to be for the final sequence
            self.op.fe_retain_graph(self.retain_graph)
            # Final round of ops to ensure accurate graph building in case we don't retain the graph
            BaseNetwork._forward_batch(data, state, self.ops)
        else:
            i = 0
            while self.repeat(*[data[var_name] for var_name in self.repeat_inputs]):
                if self.max_iter and i >= self.max_iter:
                    break
                BaseNetwork._forward_batch(data, state, self.ops)
                i += 1
        return data

    def _tf_while_int(self, data: Dict[str, Tensor], state: Dict[str, Any]) -> Dict[str, Tensor]:
        """A helper function to invoke a while loop in case self.repeat is an integer.

        Experiments were conducted to compare the performance of tf.while_loop() with tf.range(), where tf.range()
        outperformed tf.while_loop() in most scenarios. But it was found that tensors cannot be overwritten inside
        the scope of tf.range() and hence the RepeatOp failed on a few Ops (eg: Ops which were updating the inputs).
        Creating a copy of the tensor in every iteration of tf.range() resolved this issue, but also dissolved all
        the advantages of tf.range().

        Args:
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            A reference to the updated data dictionary.
        """
        if self.repeat == 1:
            # Let retain be whatever it was meant to be for the final sequence
            # This is done right before the only forward pass to ensure accurate graph building in case
            # we don't retain the graph
            self.op.fe_retain_graph(self.retain_graph)
            # Final round of ops
            BaseNetwork._forward_batch(data, state, self.ops)
        elif self.repeat == 2:
            BaseNetwork._forward_batch(data, state, self.ops)
            # Let retain be whatever it was meant to be for the final sequence
            # This is done right before the last forward pass to ensure accurate graph building in case
            # we don't retain the graph
            self.op.fe_retain_graph(self.retain_graph)
            # Final round of ops
            BaseNetwork._forward_batch(data, state, self.ops)
        else:
            # Run a forward pass to ensure that data dictionary structure doesn't change during while loop execution
            BaseNetwork._forward_batch(data, state, self.ops)
            args = (tf.constant(1), data)
            # Use functools.partial since state may contain objects which cannot be cast to tensors (ex. gradient tape)
            args = tf.while_loop(self._tf_cond,
                                 functools.partial(self._tf_body, state=state),
                                 args,
                                 maximum_iterations=self.max_iter)
            # Let retain be whatever it was meant to be for the final sequence
            # This is done right before the last forward pass to ensure accurate graph building in case
            # we don't retain the graph
            self.op.fe_retain_graph(self.retain_graph)
            data = args[1]
            # Final round of ops
            BaseNetwork._forward_batch(data, state, self.ops)
        return data

    def _tf_while(self, data: Dict[str, Tensor], state: Dict[str, Any]) -> Dict[str, Tensor]:
        """A helper function to invoke a while loop in case self.repeat is a callable function.

        Args:
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            A reference to the updated data dictionary.
        """
        args = ([data[var_name] for var_name in self.repeat_inputs], data)
        # Use functools.partial since state may contain objects which cannot be cast to tensors (ex. gradient tape)
        args = tf.while_loop(self._tf_cond,
                             functools.partial(self._tf_body, state=state),
                             args,
                             maximum_iterations=self.max_iter,
                             parallel_iterations=1)
        return args[1]

    def _tf_cond(self, cnd: Union[List[Tensor], Tensor], data: Dict[str, Tensor]) -> bool:
        """A helper function to determine whether to keep invoking the while method.

        Note that `data` and `state` are unused here, but required since tf.while_loop needs the cond and body to
        have the same input argument signatures.

        Args:
            cnd: A list of arguments to be passed to the condition function.
            data: A data dictionary to be used during looping.

        Returns:
            Whether to continue looping.
        """
        if isinstance(self.repeat, int):
            # In this case we have 2 Forward calls for tf (one before and one after the while loop
            # (For accurate Tf while loop functioning))
            return tf.less(cnd, self.repeat - 1)
        return self.repeat(*cnd)

    def _tf_body(self, cnd: Union[List[Tensor], Tensor], data: Dict[str, Tensor],
                 state: Dict[str, Any]) -> Tuple[Union[List[Tensor], Tensor], Dict[str, Tensor]]:
        """A helper function to execute the body of a while method.

        Note that `cnd` is unused here, but required since tf.while_loop needs the cond and body to have the same
        input argument signatures.

        Args:
            cnd: A list of arguments to be passed to the condition function.
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            The updated `cnd` values, along with the modified data and state dictionaries.
        """
        # Run a round of ops
        BaseNetwork._forward_batch(data, state, self.ops)
        if isinstance(self.repeat, int):
            # Updating the while condition
            return tf.add(cnd, 1), data
        return [data[var_name] for var_name in self.repeat_inputs], data
```
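The `AddOne` op used throughout is a placeholder, not a built-in. A minimal sketch of how it could be defined and wired into a network, assuming the standard `fastestimator` import paths:

```python
import fastestimator as fe
from fastestimator.op.tensorop import LambdaOp, TensorOp
from fastestimator.op.tensorop.meta import Repeat

class AddOne(TensorOp):
    """A trivial op which reads its input key and writes the value back incremented by one."""
    def forward(self, data, state):
        return data + 1

network = fe.Network(ops=[
    LambdaOp(fn=lambda: 0, outputs="z"),
    # 'z' in the lambda signature is matched by name against the data dictionary
    Repeat(AddOne(inputs="z", outputs="z"), repeat=lambda z: z < 6.5)
])
```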