Bases: Trace
Save model weights based on epoch frequency during training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A model instance compiled with `fe.build`. | required |
| `save_dir` | `str` | Folder path into which to save the `model`. | required |
| `frequency` | `int` | Model saving frequency in epoch(s). | `1` |
Source code in fastestimator\fastestimator\trace\io\model_saver.py
class ModelSaver(Trace):
    """Save model weights based on epoch frequency during training.

    This trace is constructed with mode="train". At the end of each epoch whose
    index is a multiple of `frequency`, the model is saved into `save_dir` under
    the name "<model.model_name>_epoch_<epoch_idx>".

    Args:
        model: A model instance compiled with fe.build.
        save_dir: Folder path into which to save the `model`. If falsy (e.g. None),
            no model is saved, which makes smoke tests easier.
        frequency: Model saving frequency in epoch(s). Must be a positive integer.

    Raises:
        ValueError: If `frequency` is less than 1.
    """
    def __init__(self, model: Union[tf.keras.Model, torch.nn.Module], save_dir: str, frequency: int = 1) -> None:
        # Validate up front: frequency == 0 would otherwise raise ZeroDivisionError in
        # on_epoch_end, and only after a full epoch of training had already been spent.
        if frequency < 1:
            raise ValueError("frequency must be a positive integer, but got {}".format(frequency))
        super().__init__(mode="train")
        self.model = model
        self.save_dir = save_dir
        self.frequency = frequency

    def on_epoch_end(self, data: Data) -> None:
        """Save the model if saving is enabled and the current epoch is a multiple of `frequency`.

        Args:
            data: Data passed in by the training loop; not used by this trace.
        """
        # No model will be saved when save_dir is None (or otherwise falsy), which makes smoke test easier.
        if self.save_dir and self.system.epoch_idx % self.frequency == 0:
            model_name = "{}_epoch_{}".format(self.model.model_name, self.system.epoch_idx)
            model_path = save_model(self.model, self.save_dir, model_name)
            print("FastEstimator-ModelSaver: Saved model to {}".format(model_path))