load_model

Load saved weights for a given model.

This method can be used with TensorFlow models:

m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="tmp", model_name="test")
fe.backend.load_model(m, weights_path="tmp/test.h5")

This method can be used with PyTorch models:

m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="tmp", model_name="test")
fe.backend.load_model(m, weights_path="tmp/test.pt")

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Union[tf.keras.Model, torch.nn.Module] | A neural network instance to load. | required |
| weights_path | str | Path to the model weights. | required |
| load_optimizer | bool | Whether to also load the optimizer state. If True, the optimizer file saved next to the weights (e.g. tmp/test_opt.pt for tmp/test.pt) is loaded (see the sketch below). | False |
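
For illustration, a minimal sketch of loading the optimizer state along with the weights. It assumes a matching optimizer file (here "tmp/test_opt.pt") was already written next to the weights when the model was saved; only the load_model call itself comes from this page.

```python
import fastestimator as fe

# Rebuild the model the same way it was built before saving.
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")

# Assumption: "tmp/test.pt" and "tmp/test_opt.pt" both exist on disk
# (the optimizer state must have been saved alongside the weights).
fe.backend.load_model(m, weights_path="tmp/test.pt", load_optimizer=True)
```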

Raises:

| Type | Description |
| --- | --- |
| ValueError | If model is an unacceptable data type. |
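
A minimal sketch of that error path. The NotAModel class is a made-up stand-in: it carries the fe_compiled flag so it passes the fe.build check, but because it is neither a tf.keras.Model nor a torch.nn.Module, load_model raises ValueError.

```python
import fastestimator as fe

class NotAModel:
    fe_compiled = True  # hypothetical class: passes the fe.build check, but is not a tf/torch model

try:
    fe.backend.load_model(NotAModel(), weights_path="tmp/test.pt")
except ValueError as err:
    print(err)  # Unrecognized model instance <class '__main__.NotAModel'>
```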

Source code in fastestimator/fastestimator/backend/load_model.py
import os
import pickle
from typing import Union

import tensorflow as tf
import torch


def load_model(model: Union[tf.keras.Model, torch.nn.Module], weights_path: str, load_optimizer: bool = False):
    """Load saved weights for a given model.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="tmp", model_name="test")
    fe.backend.load_model(m, weights_path="tmp/test.h5")
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="tmp", model_name="test")
    fe.backend.load_model(m, weights_path="tmp/test.pt")
    ```

    Args:
        model: A neural network instance to load.
        weights_path: Path to the `model` weights.
        load_optimizer: Whether to also load the optimizer state. If True, the <weights>_opt file saved next to the
            weights will be loaded.

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "model must be built by fe.build"
    if isinstance(model, tf.keras.Model):
        model.load_weights(weights_path)
        if load_optimizer:
            assert model.current_optimizer, "optimizer does not exist"
            # The TensorFlow optimizer state is stored as a pickled list of weight arrays
            # next to the weights file, e.g. "tmp/test.h5" -> "tmp/test_opt.pkl".
            optimizer_path = "{}_opt.pkl".format(os.path.splitext(weights_path)[0])
            assert os.path.exists(optimizer_path), "cannot find optimizer path: {}".format(optimizer_path)
            with open(optimizer_path, 'rb') as f:
                weight_values = pickle.load(f)
            model.current_optimizer.set_weights(weight_values)
    elif isinstance(model, torch.nn.Module):
        model.load_state_dict(torch.load(weights_path))
        if load_optimizer:
            assert model.current_optimizer, "optimizer does not exist"
            # The PyTorch optimizer state is stored as a state_dict next to the weights
            # file, e.g. "tmp/test.pt" -> "tmp/test_opt.pt".
            optimizer_path = "{}_opt.pt".format(os.path.splitext(weights_path)[0])
            assert os.path.exists(optimizer_path), "cannot find optimizer path: {}".format(optimizer_path)
            model.current_optimizer.load_state_dict(torch.load(optimizer_path))
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
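
To make the file-naming convention in the source concrete, a small sketch in plain Python (no FastEstimator calls) that mirrors how load_model derives the optimizer path from weights_path:

```python
import os

# Mirrors the derivation above: the optimizer file sits next to the weights,
# with an "_opt" suffix and a framework-dependent extension.
for weights_path, ext in [("tmp/test.h5", ".pkl"), ("tmp/test.pt", ".pt")]:
    optimizer_path = "{}_opt{}".format(os.path.splitext(weights_path)[0], ext)
    print(weights_path, "->", optimizer_path)
# tmp/test.h5 -> tmp/test_opt.pkl
# tmp/test.pt -> tmp/test_opt.pt
```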