Bases: InMemoryDataset
A dataset constructed from a dictionary of Numpy data or list of data.
Parameters:
| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `Dict[str, Union[ndarray, List]]` | A dictionary of data like `{"key1": <numpy array>, "key2": [list]}`. | *required* |
Raises:
AssertionError: If any of the Numpy arrays or lists have differing numbers of elements.
ValueError: If any dictionary value is not an instance of a Numpy array or list.
Source code in fastestimator/fastestimator/dataset/numpy_dataset.py
@traceable()
class NumpyDataset(InMemoryDataset):
    """A dataset constructed from a dictionary of Numpy data or list of data.

    Sample ``i`` of the resulting dataset is the dictionary ``{key: value[i] for key, value in data.items()}``, i.e. it
    takes the i-th element (along the first axis, for Numpy arrays) of every entry in `data`.

    Args:
        data: A dictionary of data like {"key1": <numpy array>, "key2": [list]}.

    Raises:
        AssertionError: If any of the Numpy arrays or lists have differing numbers of elements.
        ValueError: If any dictionary value is not an instance of a Numpy array or list, or is a 0-dimensional Numpy
            array (which has no first axis to index into).
    """
    def __init__(self, data: Dict[str, Union[np.ndarray, List]]) -> None:
        size = None
        for val in data.values():
            if isinstance(val, np.ndarray):
                if val.ndim == 0:
                    # A 0-d array has shape (), so val.shape[0] would otherwise fail with an opaque IndexError.
                    raise ValueError("Numpy arrays in the data dictionary must have at least one dimension.")
                current_size = val.shape[0]
            elif isinstance(val, list):
                current_size = len(val)
            else:
                raise ValueError("Please ensure you are passing numpy array or list in the data dictionary.")
            if size is None:
                size = current_size
            elif size != current_size:
                # Raised explicitly rather than via `assert` so the check is not stripped under `python -O`,
                # while keeping the documented AssertionError type for callers.
                raise AssertionError("All data arrays must have the same number of elements")
        # `size` is None when `data` has no keys and 0 when every value is empty; both produce an empty dataset.
        super().__init__({i: {k: v[i] for k, v in data.items()} for i in range(size)} if size else {})
|