@traceable()
class OneOf(TensorOp):
    """Perform one of several possible TensorOps.

    Args:
        *tensor_ops: A list of ops to choose between with uniform probability.
    """
    def __init__(self, *tensor_ops: TensorOp) -> None:
        inputs = tensor_ops[0].inputs
        outputs = tensor_ops[0].outputs
        mode = tensor_ops[0].mode
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)
        self.in_list = tensor_ops[0].in_list
        self.out_list = tensor_ops[0].out_list
        # Every wrapped op must be interchangeable with the first one, since any of them may be selected.
        for op in tensor_ops[1:]:
            assert inputs == op.inputs, "All ops within a OneOf must share the same inputs"
            assert self.in_list == op.in_list, "All ops within OneOf must share the same input configuration"
            assert outputs == op.outputs, "All ops within a OneOf must share the same outputs"
            assert self.out_list == op.out_list, "All ops within OneOf must share the same output configuration"
            assert mode == op.mode, "All ops within a OneOf must share the same mode"
        self.ops = tensor_ops
        # Populated by `build` once the backend framework is known.
        self.prob_fn = None
        self.invoke_fn = None

    def build(self, framework: str) -> None:
        """Prepare the sampling distribution and op dispatcher for a specific framework.

        Args:
            framework: Which backend to build for, either 'tf' or 'torch'.

        Raises:
            ValueError: If `framework` is not 'tf' or 'torch'.
        """
        if framework == 'tf':
            self.prob_fn = tfp.distributions.Uniform(low=0, high=len(self.ops))
            # Bind each op as a default argument. A bare `lambda: op.forward(...)` closes over the
            # comprehension variable late, so every branch would invoke the LAST op in the list.
            self.invoke_fn = lambda idx, data, state: tf.switch_case(
                idx, [lambda op=op: op.forward(data, state) for op in self.ops])
        elif framework == 'torch':
            self.prob_fn = torch.distributions.uniform.Uniform(low=0, high=len(self.ops))
            self.invoke_fn = lambda idx, data, state: self.ops[idx].forward(data, state)
        else:
            raise ValueError("unrecognized framework: {}".format(framework))

    def get_fe_loss_keys(self) -> Set[str]:
        """Return the union of the loss keys reported by every wrapped op."""
        return set.union(*[op.get_fe_loss_keys() for op in self.ops])

    def get_fe_models(self) -> Set[Model]:
        """Return the union of the models reported by every wrapped op."""
        return set.union(*[op.get_fe_models() for op in self.ops])

    def fe_retain_graph(self, retain: Optional[bool] = None) -> Optional[bool]:
        """Forward `retain` to every wrapped op and combine their responses.

        Args:
            retain: The value passed through unchanged to each wrapped op's `fe_retain_graph`.

        Returns:
            The first truthy response from any wrapped op, otherwise the last op's (falsy) response.
        """
        resp = None
        for op in self.ops:
            # Call every op unconditionally. `resp or op.fe_retain_graph(retain)` would short-circuit and
            # never reach the remaining ops once one of them returned a truthy value.
            op_resp = op.fe_retain_graph(retain)
            resp = resp or op_resp
        return resp

    def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
        """Return the state of each wrapped op, using {} for ops without a `__getstate__`."""
        return {'ops': [elem.__getstate__() if hasattr(elem, '__getstate__') else {} for elem in self.ops]}

    def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
        """Execute a randomly selected op from the wrapped `tensor_ops`.

        Args:
            data: The information to be passed to one of the wrapped operators.
            state: Information about the current execution context, for example {"mode": "train"}.

        Returns:
            The `data` after application of one of the available TensorOps.
        """
        # Draw from the uniform distribution set up in `build` and truncate the float sample to an int32 index.
        idx = cast(self.prob_fn.sample(), dtype='int32')
        return self.invoke_fn(idx, data, state)
def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
    """Execute a randomly selected op from the wrapped `tensor_ops`.

    Args:
        data: The information to be passed to one of the wrapped operators.
        state: Information about the current execution context, for example {"mode": "train"}.

    Returns:
        The `data` after application of one of the available TensorOps.
    """
    # Draw a float sample from `self.prob_fn` and truncate it to an int32 index of the op to invoke.
    idx = cast(self.prob_fn.sample(), dtype='int32')
    return self.invoke_fn(idx, data, state)