Skip to content

save_model

save_model

Save model weights to a specific directory.

This method can be used with TensorFlow models:

m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.h5' file inside /tmp directory

This method can be used with PyTorch models:

m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.pt' file inside /tmp directory

Parameters:

Name Type Description Default
model Union[tf.keras.Model, torch.nn.Module]

A neural network instance to save.

required
save_dir str

Directory into which to write the model weights.

required
model_name Optional[str]

The name of the model (used for naming the weights file). If None, model.model_name will be used.

None
save_optimizer bool

Whether to save the optimizer. If True, the optimizer is saved to a separate file in the same directory.

False

Returns:

Type Description

The saved model path.

Raises:

Type Description
ValueError

If model is neither a tf.keras.Model nor a torch.nn.Module.

Source code in fastestimator\fastestimator\backend\save_model.py
def save_model(model: Union[tf.keras.Model, torch.nn.Module],
               save_dir: str,
               model_name: Optional[str] = None,
               save_optimizer: bool = False) -> str:
    """Save `model` weights to a specific directory.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.h5' file inside /tmp directory
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.pt' file inside /tmp directory
    ```

    All argument validation happens before anything is written, so a rejected call does not create `save_dir` or
    leave a weights file without its optimizer file.

    Args:
        model: A neural network instance to save.
        save_dir: Directory into which to write the `model` weights. It is created if it does not already exist.
        model_name: The name of the model (used for naming the weights file). If None, model.model_name will be used.
        save_optimizer: Whether to save the optimizer. If True, the optimizer is saved to a separate file in the same
            directory ('<model_name>_opt.pkl' for TensorFlow, '<model_name>_opt.pt' for PyTorch).

    Returns:
        The saved model path.

    Raises:
        AssertionError: If `model` was not built by fe.build, or if `save_optimizer` is True but the model has no
            current optimizer.
        ValueError: If `model` is an unacceptable data type.
    """
    # Raised explicitly (rather than via `assert`) so the check still runs under `python -O`, while keeping the
    # exception type that existing callers may already handle.
    if not getattr(model, "fe_compiled", False):
        raise AssertionError("model must be built by fe.build")

    # Resolve the backend before touching the filesystem so unsupported models leave no directory behind.
    if isinstance(model, tf.keras.Model):
        is_tf_model = True
    elif isinstance(model, torch.nn.Module):
        is_tf_model = False
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))

    # Check the optimizer up front so a failure cannot leave a weights file without its optimizer file.
    if save_optimizer and not model.current_optimizer:
        raise AssertionError("optimizer does not exist")

    if model_name is None:
        model_name = model.model_name
    save_dir = os.path.normpath(save_dir)
    os.makedirs(save_dir, exist_ok=True)

    if is_tf_model:
        model_path = os.path.join(save_dir, "{}.h5".format(model_name))
        model.save_weights(model_path)
        if save_optimizer:
            optimizer_path = os.path.join(save_dir, "{}_opt.pkl".format(model_name))
            with open(optimizer_path, 'wb') as f:
                pickle.dump(model.current_optimizer.get_weights(), f)
    else:
        model_path = os.path.join(save_dir, "{}.pt".format(model_name))
        torch.save(model.state_dict(), model_path)
        if save_optimizer:
            optimizer_path = os.path.join(save_dir, "{}_opt.pt".format(model_name))
            torch.save(model.current_optimizer.state_dict(), optimizer_path)
    return model_path