
numpy_dataset

NumpyDataset

Bases: InMemoryDataset

A dataset constructed from a dictionary whose values are NumPy arrays or lists.

Parameters:

data (Dict[str, Union[np.ndarray, List]], required): A dictionary of data like {"key1": <numpy array>, "key2": [list]}. Every value must contain the same number of elements.
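Each entry in `data` is one field, and all entries must describe the same number of samples. For a NumPy array that count is the size of its first axis (`shape[0]`); for a list it is `len()`. The sketch below shows a well-formed input using plain NumPy. The helper `num_elements` is hypothetical and only mirrors how the constructor measures each value.

```python
from typing import List, Union

import numpy as np


def num_elements(value: Union[np.ndarray, List]) -> int:
    """Hypothetical helper: count samples the way NumpyDataset does."""
    if isinstance(value, np.ndarray):
        return value.shape[0]  # samples are laid out along the first axis
    if isinstance(value, list):
        return len(value)
    raise ValueError("Please ensure you are passing numpy array or list in the data dictionary.")


# Three samples: 28x28 images, integer labels, and string ids.
data = {
    "x": np.zeros((3, 28, 28), dtype=np.float32),
    "y": np.array([0, 1, 2]),
    "id": ["a", "b", "c"],
}

sizes = {key: num_elements(value) for key, value in data.items()}
print(sizes)  # {'x': 3, 'y': 3, 'id': 3}
```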

Raises:

AssertionError: If any of the NumPy arrays or lists have differing numbers of elements.
ValueError: If any dictionary value is not an instance of a NumPy array or list.
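The two failure modes can be reproduced without FastEstimator installed. `check_sizes` below is a hypothetical stand-alone copy of the constructor's validation loop. It returns the shared element count, or raises the same exception types listed above.

```python
from typing import Dict, List, Union

import numpy as np


def check_sizes(data: Dict[str, Union[np.ndarray, List]]) -> int:
    """Hypothetical re-implementation of NumpyDataset's input checks."""
    size = None
    for val in data.values():
        if isinstance(val, np.ndarray):
            current_size = val.shape[0]
        elif isinstance(val, list):
            current_size = len(val)
        else:
            # Tuples, scalars, and other containers are all rejected.
            raise ValueError("Please ensure you are passing numpy array or list in the data dictionary.")
        if size is None:
            size = current_size
        else:
            assert size == current_size, "All data arrays must have the same number of elements"
    return size


print(check_sizes({"x": np.ones((4, 2)), "y": [0, 1, 0, 1]}))  # 4

try:
    check_sizes({"x": np.ones((4, 2)), "y": [0, 1, 0]})  # 4 vs 3 elements
except AssertionError as err:
    print("AssertionError:", err)

try:
    check_sizes({"x": np.ones((4, 2)), "y": (0, 1, 0, 1)})  # tuple, not list
except ValueError as err:
    print("ValueError:", err)
```

Because the length check is a bare `assert`, Python skips it when run with the `-O` flag. Mismatched inputs would then only fail later, during indexing.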

Source code in fastestimator\fastestimator\dataset\numpy_dataset.py
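from typing import Dict, List, Union

import numpy as np

from fastestimator.dataset.dataset import InMemoryDataset
from fastestimator.util.traceability_util import traceable


@traceable()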
class NumpyDataset(InMemoryDataset):
    """A dataset constructed from a dictionary of Numpy data or list of data.

    Args:
        data: A dictionary of data like {"key1": <numpy array>, "key2": [list]}.
    Raises:
        AssertionError: If any of the Numpy arrays or lists have differing numbers of elements.
        ValueError: If any dictionary value is not an instance of a Numpy array or list.
    """
    def __init__(self, data: Dict[str, Union[np.ndarray, List]]) -> None:
        size = None
        # Every value must hold the same number of elements: the first axis
        # of a NumPy array, or the length of a list.
        for val in data.values():
            if isinstance(val, np.ndarray):
                current_size = val.shape[0]
            elif isinstance(val, list):
                current_size = len(val)
            else:
                raise ValueError("Please ensure you are passing numpy array or list in the data dictionary.")
            if size is not None:
                assert size == current_size, "All data arrays must have the same number of elements"
            else:
                size = current_size
        # Transpose the column-oriented dict into {index: {key: value}} samples.
        super().__init__({i: {k: v[i] for k, v in data.items()} for i in range(size)})
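The last line of `__init__` converts the column-oriented input into one dictionary per sample, keyed by integer index. The sketch below runs that same comprehension on a small example with plain NumPy. The `data` and `size` values are made up for illustration.

```python
import numpy as np

data = {
    "x": np.arange(6).reshape(3, 2),  # 3 samples, each a length-2 vector
    "y": ["cat", "dog", "bird"],
}
size = 3

# Same comprehension as the last line of NumpyDataset.__init__:
# column-oriented input -> {index: {key: value_for_that_index}}
samples = {i: {k: v[i] for k, v in data.items()} for i in range(size)}

print(samples[1]["x"])  # [2 3]
print(samples[1]["y"])  # dog
```

For array values, `v[i]` is NumPy basic indexing. Each per-sample entry of a multi-dimensional array is therefore a view into the original array, not a copy, while a 1-D array yields a NumPy scalar.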