Skip to content

repeat

Repeat

Bases: NumpyOp

Repeat a NumpyOp several times in a row.

Parameters:

Name Type Description Default
op NumpyOp

A NumpyOp to be run one or more times in a row.

required
repeat Union[int, Callable[..., bool]]

How many times to repeat the op. This can also be a function, in which case the function's input names will be matched to keys in the data dictionary, and the op will be repeated until the function returns a falsy value. The function is evaluated at the end of each forward call, so the op will always be evaluated at least once.

1

Raises:

Type Description
ValueError

If repeat or op is invalid.

Source code in fastestimator/fastestimator/op/numpyop/meta/repeat.py
@traceable()
class Repeat(NumpyOp):
    """Repeat a NumpyOp several times in a row.

    Args:
        op: A NumpyOp to be run one or more times in a row.
        repeat: How many times to repeat the `op`. This can also be a function, in which case the function's input
            names will be matched to keys in the data dictionary, and the `op` will be repeated until the function
            returns a falsy value. The function is evaluated at the end of each forward call, so the `op` will
            always be evaluated at least once.

    Raises:
        ValueError: If `op` is not a NumpyOp, or if `repeat` is an int smaller than 1 or is neither an int nor a
            callable.
    """
    def __init__(self, op: NumpyOp, repeat: Union[int, Callable[..., bool]] = 1) -> None:
        if not isinstance(op, NumpyOp):
            raise ValueError(f"Repeat requires op to be a NumpyOp, but got {type(op).__name__}")
        self.repeat_inputs = []
        extra_reqs = []
        if isinstance(repeat, int):
            if repeat < 1:
                raise ValueError(f"Repeat requires repeat to be >= 1, but got {repeat}")
        elif callable(repeat):
            # The repeat function's parameter names are the data keys it will be called with.
            self.repeat_inputs.extend(inspect.signature(repeat).parameters.keys())
            # Keys the repeat function reads that the op neither consumes nor produces must be requested as extra
            # inputs. Filtering in signature order (instead of a set difference) keeps `self.inputs` deterministic
            # across interpreter runs and avoids listing a key the op already consumes a second time.
            extra_reqs = [key for key in self.repeat_inputs if key not in op.inputs and key not in op.outputs]
        else:
            raise ValueError(f"Repeat requires repeat to be an int or a callable, but got {type(repeat).__name__}")
        self.repeat = repeat
        super().__init__(inputs=op.inputs + extra_reqs, outputs=op.outputs, mode=op.mode)
        self.ops = [op]

    @property
    def op(self) -> NumpyOp:
        """The wrapped NumpyOp (stored as the single element of `self.ops`)."""
        return self.ops[0]

    def __getstate__(self) -> Dict[str, List[Dict[Any, Any]]]:
        """Return a state dict holding each wrapped op's own `__getstate__()` result (or `{}` if it has none)."""
        return {'ops': [elem.__getstate__() if hasattr(elem, '__getstate__') else {} for elem in self.ops]}

    def forward(self, data: List[np.ndarray], state: Dict[str, Any]) -> List[np.ndarray]:
        """Run the wrapped op repeatedly on the given data.

        Args:
            data: Values corresponding to `self.inputs`, in the same order.
            state: Information about the current execution context; `state["mode"]` is passed to the wrapped op.

        Returns:
            Values corresponding to `self.outputs`, taken after the final repetition.
        """
        batch = dict(zip(self.inputs, data))
        mode = state["mode"]
        if isinstance(self.repeat, int):
            for _ in range(self.repeat):
                forward_numpyop(self.ops, batch, mode)
        else:
            # Do-while semantics: always run once, then re-check the condition against the updated data.
            forward_numpyop(self.ops, batch, mode)
            while self.repeat(*[batch[key] for key in self.repeat_inputs]):
                forward_numpyop(self.ops, batch, mode)
        return [batch[key] for key in self.outputs]