tednmt

load_data

Load and return the neural machine translation dataset from TED talks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `root_dir` | `Optional[str]` | The path to store the downloaded data. When `root_dir` is not provided, the data will be saved into `fastestimator_data` under the user's home directory. | `None` |
| `translate_option` | `str` | Options for translation languages. Available options are: "az_to_en", "az_tr_to_en", "be_ru_to_en", "be_to_en", "es_to_pt", "fr_to_pt", "gl_pt_to_en", "gl_to_en", "he_to_pt", "it_to_pt", "pt_to_en", "ru_to_en", "ru_to_pt", and "tr_to_en". | `'az_to_en'` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[NumpyDataset, NumpyDataset, NumpyDataset]` | `(train_data, eval_data, test_data)` |

Source code in fastestimator/fastestimator/dataset/data/tednmt.py
def load_data(root_dir: Optional[str] = None,
              translate_option: str = "az_to_en") -> Tuple[NumpyDataset, NumpyDataset, NumpyDataset]:
    """Load and return the neural machine translation dataset from TED talks.

    Args:
        root_dir: The path to store the downloaded data. When `root_dir` is not provided, the data will be saved into
            `fastestimator_data` under the user's home directory.
        translate_option: Options for translation languages. Available options are: "az_to_en", "az_tr_to_en",
            "be_ru_to_en", "be_to_en", "es_to_pt", "fr_to_pt", "gl_pt_to_en", "gl_to_en", "he_to_pt", "it_to_pt",
            "pt_to_en", "ru_to_en", "ru_to_pt", and "tr_to_en".

    Returns:
        (train_data, eval_data, test_data)
    """
    # Set up path
    home = str(Path.home())
    if root_dir is None:
        root_dir = os.path.join(home, 'fastestimator_data', 'tednmt')
    else:
        root_dir = os.path.join(os.path.abspath(root_dir), 'tednmt')
    os.makedirs(root_dir, exist_ok=True)
    compressed_path = os.path.join(root_dir, 'qi18naacl-dataset.tar.gz')
    extracted_path = os.path.join(root_dir, 'datasets')
    if not os.path.exists(extracted_path):
        # Download
        if not os.path.exists(compressed_path):
            print("Downloading data to {}".format(compressed_path))
            wget.download('http://www.phontron.com/data/qi18naacl-dataset.tar.gz', compressed_path, bar=bar_custom)
        # Extract
        print("\nExtracting files ...")
        with tarfile.open(compressed_path) as f:
            f.extractall(root_dir)
    # process data
    data_path = os.path.join(extracted_path, translate_option)
    assert os.path.exists(data_path), "folder {} does not exist, please verify translation options".format(data_path)
    train_ds = _create_dataset(data_path=data_path, translate_option=translate_option, extension="train")
    eval_ds = _create_dataset(data_path=data_path, translate_option=translate_option, extension="dev")
    test_ds = _create_dataset(data_path=data_path, translate_option=translate_option, extension="test")
    return train_ds, eval_ds, test_ds
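
A minimal usage sketch for reference (the module path follows the source file location shown above; the first call downloads and extracts the archive, so it requires network access):

```python
from fastestimator.dataset.data import tednmt

# Download (on first use), extract, and load the Azerbaijani-to-English subset.
train_data, eval_data, test_data = tednmt.load_data(translate_option="az_to_en")

print(len(train_data))  # number of training sentence pairs
print(train_data[0])    # one example; the exact key names depend on _create_dataset
```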