
op_dataset

OpDataset

Bases: Dataset

A wrapper for datasets which allows operators to be applied to them in a pipeline.

This class should not be directly instantiated by the end user. fe.Pipeline will automatically wrap datasets within an OpDataset as needed.
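
For context, a minimal sketch of the usual route: the dataset and NumpyOps are handed to fe.Pipeline, which performs the OpDataset wrapping internally. The dataset contents below are illustrative toy data.

import numpy as np
import fastestimator as fe
from fastestimator.op.numpyop.univariate import Minmax

# An in-memory dataset with a single feature key (toy data for illustration)
train_data = fe.dataset.NumpyDataset({"x": np.random.rand(100, 28, 28).astype("float32")})

# fe.Pipeline wraps train_data in an OpDataset behind the scenes so that
# Minmax runs on every sample fetched during training
pipeline = fe.Pipeline(train_data=train_data, ops=[Minmax(inputs="x", outputs="x")], batch_size=32)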

Parameters:

    dataset (Dataset): The base dataset to wrap. Required.
    ops (List[NumpyOp]): A list of ops to be applied after the base dataset's __getitem__ is invoked. Required.
    mode (str): What mode the system is currently running in ('train', 'eval', 'test', or 'infer'). Required.
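
Although end users normally rely on fe.Pipeline, constructing an OpDataset by hand can be handy for previewing what the ops will produce. A minimal sketch, assuming NumpyDataset and Minmax from the same FastEstimator version; the data values are illustrative.

import numpy as np
from fastestimator.dataset.numpy_dataset import NumpyDataset
from fastestimator.dataset.op_dataset import OpDataset
from fastestimator.op.numpyop.univariate import Minmax

# Toy dataset: four 28x28 "images" with values in [0, 255]
data = NumpyDataset({"x": np.random.randint(0, 256, size=(4, 28, 28)).astype("float32")})

# Wrap it the way fe.Pipeline would, applying Minmax after each __getitem__
op_ds = OpDataset(dataset=data, ops=[Minmax(inputs="x", outputs="x")], mode="train")

sample = op_ds[0]
print(sample["x"].min(), sample["x"].max())  # values re-scaled into [0, 1]
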
Source code in fastestimator/fastestimator/dataset/op_dataset.py
class OpDataset(Dataset):
    """A wrapper for datasets which allows operators to be applied to them in a pipeline.

    This class should not be directly instantiated by the end user. The fe.Pipeline will automatically wrap datasets
    within an Op dataset as needed.

    Args:
        dataset: The base dataset to wrap.
        ops: A list of ops to be applied after the base `dataset` `__getitem__` is invoked.
        mode: What mode the system is currently running in ('train', 'eval', 'test', or 'infer').
    """
    def __init__(self, dataset: Dataset, ops: List[NumpyOp], mode: str) -> None:
        self.dataset = dataset
        if isinstance(self.dataset, BatchDataset):
            self.dataset.reset_index_maps()
        self.ops = ops
        self.mode = mode

    def __getitem__(self, index: int) -> Mapping[str, Any]:
        """Fetch a data instance at a specified index, and apply transformations to it.

        Args:
            index: Which datapoint to retrieve.

        Returns:
            The data dictionary from the specified index, with transformations applied.
        """
        items = deepcopy(self.dataset[index])  # Deepcopy to prevent ops from overwriting values in datasets
        if isinstance(self.dataset, BatchDataset):
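            # A BatchDataset returns a list of item dictionaries, and the same dictionary object
            # may appear more than once; tracking object ids ensures the ops run only once per unique item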
            unique_list = []
            for item in items:
                if id(item) not in unique_list:
                    forward_numpyop(self.ops, item, self.mode)
                    unique_list.append(id(item))
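            # If a pad_value was configured, pad each item up to the batch's largest shape so entries can be stacked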
            if self.dataset.pad_value is not None:
                pad_batch(items, self.dataset.pad_value)
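            # Collate the list of per-instance dictionaries into a single dictionary of batched arrays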
            items = {key: np.array([item[key] for item in items]) for key in items[0]}
        else:
            forward_numpyop(self.ops, items, self.mode)
        return items

    def __len__(self):
        return len(self.dataset)
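
Because __getitem__ deep-copies the fetched item before running the ops, transformations cannot corrupt the wrapped dataset. A minimal sketch demonstrating this, using a hypothetical AddOne op written just for the demo; it assumes the single-input form of NumpyOp.forward, where data arrives as a bare array.

import numpy as np
from fastestimator.dataset.numpy_dataset import NumpyDataset
from fastestimator.dataset.op_dataset import OpDataset
from fastestimator.op.numpyop import NumpyOp

class AddOne(NumpyOp):
    """A hypothetical op, defined only for this demo, which returns its input plus one."""
    def forward(self, data, state):
        return data + 1

raw = NumpyDataset({"x": np.zeros((2, 3), dtype="float32")})
op_ds = OpDataset(dataset=raw, ops=[AddOne(inputs="x", outputs="x")], mode="train")

print(op_ds[0]["x"].max())  # 1.0 -> the op was applied to the fetched copy
print(raw[0]["x"].max())    # 0.0 -> the deepcopy left the base dataset untouched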