@traceable()
class Repeat(NumpyOp):
    """Repeat a NumpyOp several times in a row.

    Args:
        op: A NumpyOp to be run one or more times in a row.
        repeat: How many times to repeat the `op`. This can also be a function, in which case the function's parameter
            names will be matched to keys in the data dictionary, and the `op` will be repeated until the function
            evaluates to False. The function evaluation will happen at the end of a forward call, so the `op` will
            always be evaluated at least once. Any function parameter that is neither an input nor an output of `op`
            is added to this Op's inputs (in the order the parameters are declared).

    Raises:
        ValueError: If `repeat` is an int less than 1, or is neither an int nor a callable.
    """
    def __init__(self, op: NumpyOp, repeat: Union[int, Callable[..., bool]] = 1) -> None:
        # Parameter names of the `repeat` predicate, in declaration order; empty when `repeat` is an int.
        self.repeat_inputs = []
        extra_reqs = []
        if isinstance(repeat, int):
            if repeat < 1:
                raise ValueError(f"Repeat requires repeat to be >= 1, but got {repeat}")
        elif callable(repeat):
            self.repeat_inputs.extend(inspect.signature(repeat).parameters.keys())
            # Keys the predicate needs that `op` neither consumes nor produces must be requested explicitly.
            # Use an ordered filter (not set difference) so self.inputs is deterministic across processes, and skip
            # names already in op.inputs so no key is listed twice.
            known_keys = set(op.inputs) | set(op.outputs)
            extra_reqs = [name for name in self.repeat_inputs if name not in known_keys]
        else:
            raise ValueError(f"Repeat requires repeat to be an int or a callable, but got {type(repeat)}")
        self.repeat = repeat
        super().__init__(inputs=op.inputs + extra_reqs, outputs=op.outputs, mode=op.mode)
        self.ops = [op]

    @property
    def op(self) -> NumpyOp:
        """The wrapped NumpyOp (the single element of `self.ops`)."""
        return self.ops[0]

    def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
        """Collect the state of each wrapped op.

        Returns:
            A dict with key 'ops' mapping to a list holding each wrapped op's `__getstate__()` result, or `{}` for an
            op without one.
        """
        # NOTE: on Python >= 3.11 `object` defines `__getstate__`, so the hasattr fallback only matters on older Pythons.
        return {'ops': [elem.__getstate__() if hasattr(elem, '__getstate__') else {} for elem in self.ops]}

    def forward(self, data: List[np.ndarray], state: Dict[str, Any]) -> List[np.ndarray]:
        """Run the wrapped op repeatedly on a single data point.

        Args:
            data: Values aligned positionally with `self.inputs`.
            state: Execution state; its "mode" entry is forwarded to `forward_numpyop`.

        Returns:
            The values for `self.outputs`, read from the working dictionary after the final repetition.
        """
        data_dict = dict(zip(self.inputs, data))
        if isinstance(self.repeat, int):
            for _ in range(self.repeat):
                forward_numpyop(self.ops, data_dict, state["mode"])
        else:
            # Do-while semantics: run once, then keep going while the predicate holds. forward_numpyop's return value
            # is ignored, so it presumably writes op outputs back into data_dict in place (the predicate and the
            # final return both read from it).
            forward_numpyop(self.ops, data_dict, state["mode"])
            while self.repeat(*[data_dict[var_name] for var_name in self.repeat_inputs]):
                forward_numpyop(self.ops, data_dict, state["mode"])
        return [data_dict[key] for key in self.outputs]