op_dataset

OpDataset

Bases: Dataset

A wrapper for datasets which allows operators to be applied to them in a pipeline.

This class should not be directly instantiated by the end user. fe.Pipeline will automatically wrap datasets within an OpDataset as needed.
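
For context, here is a minimal sketch of how an OpDataset is normally reached indirectly through fe.Pipeline (NumpyDataset and Minmax are illustrative choices from elsewhere in FastEstimator, not part of this module):

import numpy as np
import fastestimator as fe
from fastestimator.dataset.numpy_dataset import NumpyDataset
from fastestimator.op.numpyop.univariate import Minmax

# An ordinary in-memory dataset of random images and labels
train_data = NumpyDataset({"x": np.random.rand(100, 28, 28).astype("float32"),
                           "y": np.random.randint(10, size=100)})
# Pipeline wraps `train_data` in an OpDataset internally, so Minmax runs each
# time a sample is fetched
pipeline = fe.Pipeline(train_data=train_data, batch_size=32,
                       ops=[Minmax(inputs="x", outputs="x")])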

Parameters:

    dataset (Dataset): The base dataset to wrap. (required)
    ops (List[NumpyOp]): A list of ops to be applied after the base dataset __getitem__ is invoked. (required)
    mode (str): What mode the system is currently running in ('train', 'eval', 'test', or 'infer'). (required)
    output_keys (Optional[Set[str]], default None): What keys can be produced from the pipeline. If None, all keys will be considered.
    deep_remainder (bool, default True): Whether data which is not modified by Ops should be deep copied. This argument is used to help with RAM management, but end users can almost certainly ignore it.
    shuffle (bool, default True): Whether to shuffle batched datasets every epoch.
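
Although fe.Pipeline normally performs this wrapping automatically, a direct construction makes the parameters above concrete (purely illustrative, reusing the train_data and Minmax from the sketch above):

from fastestimator.dataset.op_dataset import OpDataset

op_ds = OpDataset(dataset=train_data,
                  ops=[Minmax(inputs="x", outputs="x")],
                  mode="train",
                  output_keys={"x", "y"},  # keys outside this set are dropped
                  deep_remainder=True,     # deep-copy data left untouched by the ops
                  shuffle=True)
sample = op_ds[0]  # fetches train_data[0], then applies Minmax to it
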
Source code in fastestimator/fastestimator/dataset/op_dataset.py
@traceable()
class OpDataset(Dataset):
    """A wrapper for datasets which allows operators to be applied to them in a pipeline.

    This class should not be directly instantiated by the end user. The fe.Pipeline will automatically wrap datasets
    within an OpDataset as needed.

    Args:
        dataset: The base dataset to wrap.
        ops: A list of ops to be applied after the base `dataset` `__getitem__` is invoked.
        mode: What mode the system is currently running in ('train', 'eval', 'test', or 'infer').
        output_keys: What keys can be produced from the pipeline. If None, all keys will be considered.
        deep_remainder: Whether data which is not modified by Ops should be deep copied or not. This argument is used to
            help with RAM management, but end users can almost certainly ignore it.
        shuffle: Whether to shuffle batched datasets every epoch.
    """
    def __init__(self,
                 dataset: Dataset,
                 ops: List[NumpyOp],
                 mode: str,
                 output_keys: Optional[Set[str]] = None,
                 deep_remainder: bool = True,
                 shuffle: bool = True) -> None:
        self.dataset = dataset
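        # If the wrapped dataset keeps internal index maps (ex. BatchDataset), re-shuffle them now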
        if hasattr(self.dataset, "reset_index_maps") and shuffle:
            self.dataset.reset_index_maps()
        self.ops = ops
        self.mode = mode
        self.output_keys = output_keys
        self.deep_remainder = deep_remainder

    def __getitem__(self, index: int) -> Mapping[str, Any]:
        """Fetch a data instance at a specified index, and apply transformations to it.

        Args:
            index: Which datapoint to retrieve.

        Returns:
            The data dictionary from the specified index, with transformations applied.
        """
        item = self.dataset[index]
        if isinstance(item, list):
            # BatchDataset may randomly sample the same elements multiple times, so need to avoid reprocessing
            unique_samples = {}  # id: idx
            results = []
            for idx, data in enumerate(item):
                data_id = id(data)
                if data_id not in unique_samples:
                    data = _DelayedDeepDict(data)
                    forward_numpyop(self.ops, data, {'mode': self.mode})
                    data.finalize(retain=self.output_keys, deep_remainder=self.deep_remainder)
                    results.append(data)
                    unique_samples[data_id] = idx
                else:
                    results.append(results[unique_samples[data_id]])
            if hasattr(self.dataset, "pad_value") and self.dataset.pad_value is not None:
                pad_batch(results, self.dataset.pad_value)
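            # Collate the per-sample dicts into a single dict of batched numpy arrays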
            results = {key: np.array([result[key] for result in results]) for key in results[0]}
        else:
            results = _DelayedDeepDict(item)
            forward_numpyop(self.ops, results, {'mode': self.mode})
            results.finalize(retain=self.output_keys, deep_remainder=self.deep_remainder)
        return results

    def __len__(self):
        return len(self.dataset)
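
To make the de-duplication step in __getitem__ more concrete, here is a self-contained sketch of the same id()-based caching pattern (process_batch and transform are hypothetical names, not part of FastEstimator):

def process_batch(item, transform):
    unique_samples = {}  # id(data) -> index of its processed result
    results = []
    for idx, data in enumerate(item):
        data_id = id(data)
        if data_id not in unique_samples:
            results.append(transform(data))
            unique_samples[data_id] = idx
        else:
            # The same underlying object was sampled again: reuse the earlier result
            results.append(results[unique_samples[data_id]])
    return results

a = {"x": 1}
batch = [a, {"x": 2}, a]  # BatchDataset may sample the same element twice
out = process_batch(batch, lambda d: {k: v * 10 for k, v in d.items()})
assert out[0] is out[2]  # the duplicate was transformed only once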