Note that Sometimes should not be used to wrap an op whose output key(s) do not already exist in the data
dictionary. Doing so causes problems when later ops or traces try to reference an output key that Sometimes
declined to generate. If you want to create a default value for a new key, simply use a LambdaOp before invoking
the Sometimes.
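**Parameters:**

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tensor_op` | `TensorOp` | The operator to be performed. | *required* |
| `prob` | `float` | The probability of execution, which should be in the range [0, 1] (0 never executes the op, 1 always does). | `0.5` |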
Source code in fastestimator/fastestimator/op/tensorop/meta/sometimes.py
@traceable()
class Sometimes(TensorOp):
    """Perform a TensorOp with a given probability.

    Note that Sometimes should not be used to wrap an op whose output key(s) do not already exist in the data
    dictionary. This would result in a problem when future ops / traces attempt to reference the output key, but
    Sometimes declined to generate it. If you want to create a default value for a new key, simply use a LambdaOp before
    invoking the Sometimes.

    Args:
        tensor_op: The operator to be performed.
        prob: The probability of execution, which should be in the range: [0, 1]. The wrapped op runs whenever `prob`
            is greater than a sample drawn from Uniform(0, 1), so 0 never runs it and 1 always does.
    """
    def __init__(self, tensor_op: TensorOp, prob: float = 0.5) -> None:
        # We're going to try to collect any missing output keys from the data dictionary so that they don't get
        # overridden when Sometimes chooses not to execute.
        inps = set(tensor_op.inputs)
        # Output keys the wrapped op does not already read, de-duplicated and kept in their original output order.
        # A plain set difference yields the same keys, but in an order that depends on string-hash randomization and
        # can therefore change from one process to the next. (Used by traceability.)
        self.extra_inputs = list(dict.fromkeys(key for key in tensor_op.outputs if key not in inps))
        # forward() slices `data` at this index to recover only the wrapped op's own inputs.
        self.inp_idx = len(tensor_op.inputs)
        super().__init__(inputs=tensor_op.inputs + self.extra_inputs,
                         outputs=tensor_op.outputs,
                         mode=tensor_op.mode,
                         ds_id=tensor_op.ds_id)
        # Note that in_list and out_list will always be true
        self.op = tensor_op
        self.prob = prob
        self.prob_fn = None

    def build(self, framework: str, device: Optional[torch.device] = None) -> None:
        """Build the wrapped op and create the uniform sampler used to decide whether to run it.

        Args:
            framework: Which backend to build for, either 'tf' or 'torch'.
            device: Passed through unchanged to the wrapped op's `build`.

        Raises:
            ValueError: If `framework` is neither 'tf' nor 'torch'.
        """
        self.op.build(framework, device)
        if framework == 'tf':
            self.prob_fn = tfp.distributions.Uniform()
        elif framework == 'torch':
            self.prob_fn = torch.distributions.uniform.Uniform(low=0, high=1)
        else:
            raise ValueError("unrecognized framework: {}".format(framework))

    def get_fe_loss_keys(self) -> Set[str]:
        """Return the wrapped op's loss keys (pure delegation)."""
        return self.op.get_fe_loss_keys()

    def get_fe_models(self) -> Set[Model]:
        """Return the wrapped op's models (pure delegation)."""
        return self.op.get_fe_models()

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        """Forward `retain` to the wrapped op's `fe_retain_graph` and return its result."""
        return self.op.fe_retain_graph(retain)

    def __getstate__(self) -> Dict[str, Dict[Any, Any]]:
        """Return {'op': <wrapped op's state>}, using {} when the wrapped op has no `__getstate__`."""
        return {'op': self.op.__getstate__() if hasattr(self.op, '__getstate__') else {}}

    def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
        """Execute the wrapped operator a certain fraction of the time.

        Args:
            data: The values for `self.inputs`: the wrapped op's inputs followed by any extra (output-only) keys.
            state: Information about the current execution context, for example {"mode": "train"}.

        Returns:
            A list aligned with `self.outputs`: either the wrapped operator's results, or (when it is skipped) the
            existing values of the output keys taken from `data`.
        """
        if self.prob > self.prob_fn.sample():
            data = data[:self.inp_idx]  # Cut off the unnecessary inputs
            if not self.op.in_list:
                data = data[0]
            data = self.op.forward(data, state)
            if not self.op.out_list:
                data = [data]
        else:
            # Skipped: hand back the current value of each output key. __init__ added every output key that the
            # wrapped op doesn't read as an extra input, so each lookup below should find its key.
            data = [data[self.inputs.index(out)] for out in self.outputs]
        return data
def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
    """Randomly decide whether to run the wrapped operator on this batch.

    Args:
        data: Values for `self.inputs`; the first `self.inp_idx` entries belong to the wrapped operator and the
            remainder are output-only keys that were appended by the constructor.
        state: Information about the current execution context, for example {"mode": "train"}.

    Returns:
        A list aligned with `self.outputs`: the wrapped operator's results when it runs, otherwise the values those
        output keys already had in `data`.
    """
    if self.prob > self.prob_fn.sample():
        own_inputs = data[:self.inp_idx]
        # Unwrap single inputs / wrap single outputs to match the wrapped op's calling convention.
        result = self.op.forward(own_inputs if self.op.in_list else own_inputs[0], state)
        return result if self.op.out_list else [result]
    # Not executed: echo back the current value of every output key so downstream values are left untouched.
    return [data[self.inputs.index(key)] for key in self.outputs]