Save `model` weights to a specific directory.
This method can be used with TensorFlow models:
```python
m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.h5' file inside /tmp directory
```
This method can be used with PyTorch models:
```python
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.pt' file inside /tmp directory
```
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A neural network instance to save. | *required* |
| `save_dir` | `str` | Directory into which to write the `model` weights. | *required* |
| `model_name` | `Optional[str]` | The name of the model (used for naming the weights file). If None, `model.model_name` will be used. | `None` |
| `save_optimizer` | `bool` | Whether to save the optimizer. If True, the optimizer will be saved in a separate file in the same folder. | `False` |
Returns:

| Type | Description |
|------|-------------|
| `str` | The saved model path. |
Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `model` is an unacceptable data type. |
Source code in fastestimator\fastestimator\backend\save_model.py
def save_model(model: Union[tf.keras.Model, torch.nn.Module],
               save_dir: str,
               model_name: Optional[str] = None,
               save_optimizer: bool = False) -> str:
    """Save `model` weights to a specific directory.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.h5' file inside /tmp directory
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="/tmp", model_name="test")  # Generates 'test.pt' file inside /tmp directory
    ```

    Args:
        model: A neural network instance to save. It must have been built by `fe.build`.
        save_dir: Directory into which to write the `model` weights. It is created if it does not exist.
        model_name: The name of the model (used for naming the weights file). If None, model.model_name will be used.
        save_optimizer: Whether to save optimizer. If True, optimizer will be saved in a separate file in the same
            folder ('<model_name>_opt.pkl' for TensorFlow models, '<model_name>_opt.pt' for PyTorch models).

    Returns:
        The saved model path.

    Raises:
        AssertionError: If `model` was not built by `fe.build`, or if `save_optimizer` is True but the model has no
            optimizer.
        ValueError: If `model` is an unacceptable data type.
    """
    # Explicit raises instead of `assert`: assertions are stripped under `python -O`, which would let invalid input
    # through. AssertionError (and the original messages) are kept so existing callers see the same exception.
    if not getattr(model, "fe_compiled", False):
        raise AssertionError("model must be built by fe.build")
    # Validate the optimizer before writing anything, so a failing call does not leave behind a weights file
    # without its matching optimizer file.
    if save_optimizer and not getattr(model, "current_optimizer", None):
        raise AssertionError("optimizer does not exist")
    if model_name is None:
        model_name = model.model_name
    save_dir = os.path.normpath(save_dir)
    os.makedirs(save_dir, exist_ok=True)
    if isinstance(model, tf.keras.Model):
        model_path = os.path.join(save_dir, "{}.h5".format(model_name))
        model.save_weights(model_path)
        if save_optimizer:
            # TF optimizer state is stored as its list of weights, pickled next to the model weights.
            optimizer_path = os.path.join(save_dir, "{}_opt.pkl".format(model_name))
            with open(optimizer_path, 'wb') as f:
                pickle.dump(model.current_optimizer.get_weights(), f)
    elif isinstance(model, torch.nn.Module):
        model_path = os.path.join(save_dir, "{}.pt".format(model_name))
        torch.save(model.state_dict(), model_path)
        if save_optimizer:
            optimizer_path = os.path.join(save_dir, "{}_opt.pt".format(model_name))
            torch.save(model.current_optimizer.state_dict(), optimizer_path)
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
    return model_path
|