Skip to content

model_saver

ModelSaver

Bases: Trace

Save model weights based on epoch frequency during training.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | Union[tf.keras.Model, torch.nn.Module] | A model instance compiled with fe.build. | required |
| save_dir | str | Folder path into which to save the model. | required |
| frequency | int | Model saving frequency in epoch(s). | 1 |
Source code in fastestimator\fastestimator\trace\io\model_saver.py
class ModelSaver(Trace):
    """Save model weights based on epoch frequency during training.

    Args:
        model: A model instance compiled with fe.build.
        save_dir: Folder path into which to save the `model`. If None (or empty), no model is saved, which makes
            smoke tests easier.
        frequency: Model saving frequency in epoch(s). Must be a positive integer; the model is saved at the end of
            every epoch whose index is a multiple of this value.

    Raises:
        ValueError: If `frequency` is less than 1.
    """
    def __init__(self,
                 model: Union[tf.keras.Model, torch.nn.Module],
                 save_dir: Union[str, None],
                 frequency: int = 1) -> None:
        # Validate eagerly: a frequency of 0 would otherwise only surface as a ZeroDivisionError in on_epoch_end,
        # i.e. after an entire training epoch has already been spent.
        if frequency < 1:
            raise ValueError("frequency must be a positive integer, but got {}".format(frequency))
        # This trace only runs during training.
        super().__init__(mode="train")
        self.model = model
        self.save_dir = save_dir
        self.frequency = frequency

    def on_epoch_end(self, data: Data) -> None:
        """Save the model when saving is enabled and the current epoch index is a multiple of `frequency`.

        The saved file name is "<model.model_name>_epoch_<epoch_idx>", and the path returned by `save_model` is
        printed.

        Args:
            data: The data object passed to the trace (not used by this method).
        """
        # No model will be saved when save_dir is None, which makes smoke test easier.
        if not self.save_dir:
            return
        epoch_idx = self.system.epoch_idx
        if epoch_idx % self.frequency != 0:
            return
        model_name = f"{self.model.model_name}_epoch_{epoch_idx}"
        model_path = save_model(self.model, self.save_dir, model_name)
        print("FastEstimator-ModelSaver: Saved model to {}".format(model_path))