Skip to content



Load and return the ciFAIR10 dataset.

This is the CIFAR-10 dataset, but with test-set images that duplicate training images removed and replaced. See https://arxiv.org/pdf/1902.00423.pdf or https://cvjena.github.io/cifair/ for details. Cite the paper if you use the dataset.


| Name | Type | Description | Default |
|---|---|---|---|
| root_dir | str | The path to store the downloaded data. When `root_dir` is not provided, the data will be saved into `fastestimator_data` under the user's home directory. | None |
| image_key | str | The key for image. | "x" |
| label_key | str | The key for label. | "y" |



Type Description
Tuple[NumpyDataset, NumpyDataset]

(train_data, test_data)

Source code in fastestimator/fastestimator/dataset/data/cifair10.py
def load_data(root_dir: str = None, image_key: str = "x", label_key: str = "y") -> Tuple[NumpyDataset, NumpyDataset]:
    """Load and return the ciFAIR10 dataset.

    This is the cifar10 dataset but with test set duplicates removed and replaced. Cite the ciFAIR paper if you use
    the dataset.

    The archive is downloaded from Google Drive and unpacked on the first call; later calls reuse the extracted
    `ciFAIR-10` folder if it already exists under the data directory.

    Args:
        root_dir: The path to store the downloaded data. When `root_dir` is not provided, the data will be saved into
            `fastestimator_data` under the user's home directory. In both cases a `ciFAIR10` sub-folder is used.
        image_key: The key for image.
        label_key: The key for label.

    Returns:
        (train_data, test_data)
    """
    # Resolve the data directory. The two branches are mutually exclusive: either the default location under the
    # user's home directory, or a `ciFAIR10` sub-folder of the caller-supplied path.
    if root_dir is None:
        root_dir = os.path.join(str(Path.home()), 'fastestimator_data', 'ciFAIR10')
    else:
        root_dir = os.path.join(os.path.abspath(root_dir), 'ciFAIR10')
    os.makedirs(root_dir, exist_ok=True)

    # The download target must be a file inside root_dir (not the directory itself), and it needs an archive
    # extension because shutil.unpack_archive infers the format from the file name.
    # NOTE(review): assumes the hosted file is a zip archive - confirm against the Google Drive asset.
    image_compressed_path = os.path.join(root_dir, 'ciFAIR10.zip')
    image_extracted_path = os.path.join(root_dir, 'ciFAIR-10')

    if not os.path.exists(image_extracted_path):
        print("Downloading data to {}".format(root_dir))
        download_file_from_google_drive('1dqTgqMVvgx_FZNAC7TqzoA0hYX1ttOUq', image_compressed_path)

        print("Extracting data to {}".format(root_dir))
        shutil.unpack_archive(image_compressed_path, root_dir)

    num_train_samples = 50000
    num_train_batches = 5
    batch_size = num_train_samples // num_train_batches

    # Pre-allocate the training arrays in channel-first layout and fill them one batch file at a time.
    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples, ), dtype='uint8')

    for batch_idx in range(1, num_train_batches + 1):
        fpath = os.path.join(image_extracted_path, f'data_batch_{batch_idx}')
        start, stop = (batch_idx - 1) * batch_size, batch_idx * batch_size
        x_train[start:stop, :, :, :], y_train[start:stop] = _load_batch(fpath)

    fpath = os.path.join(image_extracted_path, 'test_batch')
    x_test, y_test = _load_batch(fpath)

    y_train = np.array(y_train, dtype=np.uint8)
    y_test = np.array(y_test, dtype=np.uint8)

    # Move the channel axis last: (N, 3, 32, 32) -> (N, 32, 32, 3).
    x_train = x_train.transpose((0, 2, 3, 1))
    x_test = x_test.transpose((0, 2, 3, 1))

    # Keep the test images in the same dtype as the training images.
    x_test = x_test.astype(x_train.dtype)

    train_data = NumpyDataset({image_key: x_train, label_key: y_train})
    test_data = NumpyDataset({image_key: x_test, label_key: y_test})
    return train_data, test_data