Load saved weights for a given model.
This method can be used with TensorFlow models:

```python
m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="tmp", model_name="test")
fe.backend.load_model(m, weights_path="tmp/test.h5")
```

This method can be used with PyTorch models:

```python
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="tmp", model_name="test")
fe.backend.load_model(m, weights_path="tmp/test.pt")
```
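Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Union[Model, Module]` | A neural network instance to load. | *required* |
| `weights_path` | `str` | Path to the model weights. | *required* |
| `load_optimizer` | `bool` | Whether to also load the optimizer state. If True, it loads the optimizer file stored next to the weights with the same base name (`<weights name>_opt.pkl` for TensorFlow, `<weights name>_opt.pt` for PyTorch). | `False` |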
Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `model` is an unacceptable data type. |
Source code in fastestimator/fastestimator/backend/_load_model.py
def load_model(model: Union[tf.keras.Model, torch.nn.Module], weights_path: str, load_optimizer: bool = False):
    """Load saved weights for a given model.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="tmp", model_name="test")
    fe.backend.load_model(m, weights_path="tmp/test.h5")
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="tmp", model_name="test")
    fe.backend.load_model(m, weights_path="tmp/test.pt")
    ```

    Args:
        model: A neural network instance to load. It must have been created by `fe.build`.
        weights_path: Path to the `model` weights.
        load_optimizer: Whether to also restore the optimizer state. If True, the optimizer file is expected next to
            the weights and to share their base name: `<weights_path without extension>_opt.pkl` for TensorFlow
            models, or `<weights_path without extension>_opt.pt` for PyTorch models.

    Raises:
        ValueError: If `model` is an unacceptable data type, or if `weights_path` does not exist.
        AssertionError: If `model` was not built by `fe.build`, or if `load_optimizer` is True but the model has no
            current optimizer or the optimizer file cannot be found.
    """
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "model must be built by fe.build"
    # Fail fast with a readable message before handing the path to a framework-specific loader.
    if not os.path.exists(weights_path):
        raise ValueError("Weights path doesn't exist: {}".format(weights_path))
    if isinstance(model, tf.keras.Model):
        model.load_weights(weights_path)
        if load_optimizer:
            assert model.current_optimizer, "optimizer does not exist"
            optimizer_path = "{}_opt.pkl".format(os.path.splitext(weights_path)[0])
            assert os.path.exists(optimizer_path), "cannot find optimizer path: {}".format(optimizer_path)
            # NOTE(security): pickle.load can execute arbitrary code; only load optimizer files from trusted sources.
            with open(optimizer_path, 'rb') as f:
                state_dict = pickle.load(f)
            optimizer = model.current_optimizer
            optimizer.set_weights(state_dict['weights'])
            # A saved weight decay is restored only when the optimizer (directly, or through its
            # `inner_optimizer` attribute) is a decoupled weight-decay optimizer.
            weight_decay = None
            if isinstance(optimizer, tfa.optimizers.DecoupledWeightDecayExtension) or (
                    hasattr(optimizer, "inner_optimizer")
                    and isinstance(optimizer.inner_optimizer, tfa.optimizers.DecoupledWeightDecayExtension)):
                weight_decay = state_dict['weight_decay']
            set_lr(model, state_dict['lr'], weight_decay=weight_decay)
    elif isinstance(model, torch.nn.Module):
        # torch.nn.DataParallel keeps the wrapped network in `.module`, so the weights go there.
        target = model.module if isinstance(model, torch.nn.DataParallel) else model
        target.load_state_dict(preprocess_torch_weights(weights_path))
        if load_optimizer:
            assert model.current_optimizer, "optimizer does not exist"
            optimizer_path = "{}_opt.pt".format(os.path.splitext(weights_path)[0])
            assert os.path.exists(optimizer_path), "cannot find optimizer path: {}".format(optimizer_path)
            model.current_optimizer.load_state_dict(torch.load(optimizer_path))
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
|