Skip to content

dir_dataset

DirDataset

Bases: InMemoryDataset

A dataset that reads files from a folder hierarchy such as `root/data.file`.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `root_dir` | `str` | The path to the directory containing data. | *required* |
| `data_key` | `str` | The key to assign to the data values in the data dictionary. | `'x'` |
| `file_extension` | `Optional[str]` | If provided, only files ending with `file_extension` will be included. | `None` |
| `recursive_search` | `bool` | Whether to search within subdirectories for files. | `True` |
Source code in fastestimator\fastestimator\dataset\dir_dataset.py
@traceable()
class DirDataset(InMemoryDataset):
    """A dataset which reads files from a folder hierarchy like root/data.file.

    Every file found beneath `root_dir` becomes one sample of the form `{data_key: <file path>}`. Files whose names
    begin with "." are skipped. Samples are ordered deterministically: subdirectories are visited in sorted
    (lexicographic) order and the files within each directory are taken in sorted order. The same directory tree
    therefore always yields the same index-to-file mapping, regardless of the filesystem's native listing order.

    Args:
        root_dir: The path to the directory containing data.
        data_key: What key to assign to the data values in the data dictionary.
        file_extension: If provided then only files ending with the file_extension will be included.
        recursive_search: Whether to search within subdirectories for files. If False, only files located directly
            inside `root_dir` are included.

    Raises:
        AssertionError: If `root_dir` is not an existing directory.
    """
    data: Dict[int, Dict[str, str]]

    def __init__(self,
                 root_dir: str,
                 data_key: str = "x",
                 file_extension: Optional[str] = None,
                 recursive_search: bool = True) -> None:
        root_dir = os.path.normpath(root_dir)
        if not os.path.isdir(root_dir):
            # AssertionError (rather than e.g. NotADirectoryError) is kept so existing callers that catch it still work.
            raise AssertionError("Provided path is not a directory")
        paths = []
        for root, dirs, files in os.walk(root_dir):
            # os.walk lists entries in arbitrary, filesystem-dependent order. Sorting `dirs` in place (permitted because
            # os.walk runs top-down by default) fixes the order in which subdirectories are visited, and sorting
            # `files` fixes the order within each directory, so sample indices are reproducible across machines.
            dirs.sort()
            for file_name in sorted(files):
                if file_name.startswith("."):
                    continue  # skip hidden files
                if file_extension is not None and not file_name.endswith(file_extension):
                    continue
                paths.append(os.path.join(root, file_name))
            if not recursive_search:
                break  # the first directory yielded by os.walk is root_dir itself
        super().__init__({idx: {data_key: path} for idx, path in enumerate(paths)})