Bases: Dataset
A wrapper for datasets which allows operators to be applied to them in a pipeline.
This class should not be directly instantiated by the end user. The fe.Pipeline will automatically wrap datasets
within an Op dataset as needed.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset | The base dataset to wrap. | required |
| ops | List[NumpyOp] | A list of ops to be applied after the base `dataset` `__getitem__` is invoked. | required |
| mode | str | What mode the system is currently running in ('train', 'eval', 'test', or 'infer'). | required |
Source code in fastestimator\fastestimator\dataset\op_dataset.py
class OpDataset(Dataset):
    """A wrapper for datasets which allows operators to be applied to them in a pipeline.

    This class should not be directly instantiated by the end user. The fe.Pipeline will automatically wrap datasets
    within an Op dataset as needed.

    Args:
        dataset: The base dataset to wrap.
        ops: A list of ops to be applied after the base `dataset` `__getitem__` is invoked.
        mode: What mode the system is currently running in ('train', 'eval', 'test', or 'infer').
    """
    def __init__(self, dataset: Dataset, ops: List[NumpyOp], mode: str) -> None:
        self.dataset = dataset
        # A wrapped BatchDataset has its index maps reset at wrap time.
        if isinstance(self.dataset, BatchDataset):
            self.dataset.reset_index_maps()
        self.ops = ops
        self.mode = mode

    def __getitem__(self, index: int) -> Mapping[str, Any]:
        """Fetch a data instance at a specified index, and apply transformations to it.

        Args:
            index: Which datapoint to retrieve.

        Returns:
            The data dictionary from the specified index, with transformations applied. When the wrapped dataset is
            a BatchDataset, the per-sample dictionaries are merged into a single dictionary in which every key maps
            to a numpy array built from that key's value in each sample. The keys are taken from the first sample,
            so the batch is expected to be non-empty.
        """
        items = deepcopy(self.dataset[index])  # Deepcopy to prevent ops from overwriting values in datasets
        if isinstance(self.dataset, BatchDataset):
            # deepcopy preserves shared references, so a sample which appears several times in the batch is still a
            # single object here. Remember the ids already handled so the ops run exactly once per distinct object.
            # A set keeps the membership test O(1) instead of scanning a list for every sample.
            processed_ids = set()
            for item in items:
                if id(item) not in processed_ids:
                    forward_numpyop(self.ops, item, self.mode)
                    processed_ids.add(id(item))
            if self.dataset.pad_value is not None:
                pad_batch(items, self.dataset.pad_value)
            # Collapse the list of per-sample dictionaries into one dictionary of batched arrays.
            items = {key: np.array([item[key] for item in items]) for key in items[0]}
        else:
            forward_numpyop(self.ops, items, self.mode)
        return items

    def __len__(self) -> int:
        """Return the length reported by the wrapped dataset."""
        return len(self.dataset)
|