from typing import Any, List, Mapping, Optional, Set, Union

from torch.utils.data import Dataset

from fastestimator.op.numpyop.numpyop import NumpyOp, forward_numpyop
from fastestimator.util.data import FilteredData
from fastestimator.util.traceability_util import traceable

# Note: the module paths above follow the FastEstimator 1.x layout. The private `_DelayedDeepDict` helper used below
# is defined alongside this class in the same module (it defers deep copies of untouched values until `finalize`).


@traceable()
class OpDataset(Dataset):
    """A wrapper for datasets which allows operators to be applied to them in a pipeline.

    This class should not be directly instantiated by the end user. The fe.Pipeline will automatically wrap datasets
    within an OpDataset as needed.

    Args:
        dataset: The base dataset to wrap.
        ops: A list of ops to be applied after the base `dataset` `__getitem__` is invoked.
        mode: The mode in which the system is currently running ('train', 'eval', 'test', or 'infer').
        output_keys: Which keys the pipeline may produce. If None or empty, all keys will be considered.
        deep_remainder: Whether data which is not modified by the Ops should be deep copied. This argument helps with
            RAM management, but end users can almost certainly ignore it.
    """
    def __init__(self,
                 dataset: Dataset,
                 ops: List[NumpyOp],
                 mode: str,
                 output_keys: Optional[Set[str]] = None,
                 deep_remainder: bool = True) -> None:
        # Track whether this dataset returns batches or not (useful for pipeline and traceability)
        if not hasattr(dataset, "fe_batch"):
            sample_item = dataset[0]
            dataset.fe_batch = len(sample_item) if isinstance(sample_item, list) else 0
        self.dataset = dataset
        self.fe_batch = dataset.fe_batch
        # Forward optional hooks from the wrapped dataset so that downstream components can still find them
        if hasattr(dataset, "fe_reset_ds"):
            self.fe_reset_ds = dataset.fe_reset_ds
        if hasattr(dataset, "fe_batch_indices"):
            self.fe_batch_indices = dataset.fe_batch_indices
        self.ops = ops
        self.mode = mode
        self.output_keys = output_keys
        self.deep_remainder = deep_remainder
    def __getitem__(self, index: int) -> Union[Mapping[str, Any], List[Mapping[str, Any]], FilteredData]:
        """Fetch a data instance at a specified index, and apply transformations to it.

        Args:
            index: Which datapoint to retrieve.

        Returns:
            The data dictionary from the specified index, with transformations applied, OR a FilteredData marker
            indicating that this index should be thrown out.
        """
        item = self.dataset[index]
        if isinstance(item, list):
            # BatchDataset may randomly sample the same element multiple times, so avoid re-processing duplicates
            unique_samples = {}  # id: idx
            results = []
            for idx, data in enumerate(item):
                data_id = id(data)
                if data_id not in unique_samples:
                    data = _DelayedDeepDict(data)
                    filter_data = forward_numpyop(self.ops, data, {'mode': self.mode})
                    if filter_data:
                        # An op requested that this sample be dropped
                        results.append(filter_data)
                    else:
                        data.finalize(retain=self.output_keys, deep_remainder=self.deep_remainder)
                        results.append(data)
                    unique_samples[data_id] = idx
                else:
                    # Re-use the already-processed result for a repeated sample
                    results.append(results[unique_samples[data_id]])
        else:
            results = _DelayedDeepDict(item)
            filter_data = forward_numpyop(self.ops, results, {'mode': self.mode})
            if filter_data:
                return filter_data
            results.finalize(retain=self.output_keys, deep_remainder=self.deep_remainder)
        return results
def __len__(self):
return len(self.dataset)
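

# A minimal usage sketch, assuming FastEstimator is installed: it shows how an OpDataset applies ops on top of a
# wrapped dataset's `__getitem__`. Normally `fe.Pipeline` constructs the OpDataset itself; the `AddOne` op and the
# in-memory dataset below are hypothetical examples for illustration, not part of the library.
if __name__ == "__main__":
    import numpy as np

    from fastestimator.dataset.numpy_dataset import NumpyDataset

    class AddOne(NumpyOp):
        """A toy op which reads its `inputs` key, adds one, and writes the result to its `outputs` key."""
        def forward(self, data, state):
            return data + 1

    ds = NumpyDataset({"x": np.arange(4)})
    op_ds = OpDataset(ds, ops=[AddOne(inputs="x", outputs="y")], mode="train")
    print(op_ds[0])  # roughly {'x': 0, 'y': 1}; the exact container is an internal dict wrapper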