Skip to content

restore_wizard

RestoreWizard

Bases: Trace

A trace that can backup and load your entire training status.

System includes model weights, optimizer state, global step and epoch index.

Parameters:

Name Type Description Default
directory str

Directory to save and load system.

required
frequency int

Saving frequency in epoch(s).

1
Source code in fastestimator\fastestimator\trace\io\restore_wizard.py
class RestoreWizard(Trace):
    """A trace that can backup and load your entire training status.

    System includes model weights, optimizer state, global step and epoch index.

    Args:
        directory: Directory to save and load system.
        frequency: Saving frequency in epoch(s).
    """
    def __init__(self, directory: str, frequency: int = 1) -> None:
        super().__init__(mode="train")
        self.directory = directory
        self.frequency = frequency
        self.model_extension = {"tf": "h5", "torch": "pt"}
        self.optimizer_extension = {"tf": "pkl", "torch": "pt"}
        self.system_file = "system.json"

    def on_begin(self, data: Data) -> None:
        if not os.path.exists(self.directory) or not os.listdir(self.directory):
            print("FastEstimator-RestoreWizard: Backing up in {}".format(self.directory))
        else:
            self._scan_files()
            self._load_files()
            data.write_with_log("epoch", self.system.epoch_idx)
            print("FastEstimator-RestoreWizard: Restoring from {}, resume training".format(self.directory))

    def _load_files(self) -> None:
        """Restore from files.
        """
        system_path = os.path.join(self.directory, self.system_file)
        self.system.load_state(json_path=system_path)
        for model in self.system.network.models:
            if isinstance(model, tf.keras.Model):
                framework = "tf"
            elif isinstance(model, torch.nn.Module):
                framework = "torch"
            else:
                raise ValueError("Unknown model type {}".format(type(model)))
            weights_path = os.path.join(self.directory,
                                        "{}.{}".format(model.model_name, self.model_extension[framework]))
            load_model(model, weights_path=weights_path, load_optimizer=True)

    def _scan_files(self) -> None:
        """Scan necessary files to load.
        """
        system_path = os.path.join(self.directory, self.system_file)
        assert os.path.exists(system_path), "cannot find system file at {}".format(system_path)
        for model in self.system.network.models:
            if isinstance(model, tf.keras.Model):
                framework = "tf"
            elif isinstance(model, torch.nn.Module):
                framework = "torch"
            else:
                raise ValueError("Unknown model type {}".format(type(model)))
            weights_path = os.path.join(self.directory,
                                        "{}.{}".format(model.model_name, self.model_extension[framework]))
            assert os.path.exists(weights_path), "cannot find model weights file at {}".format(weights_path)
            optimizer_path = os.path.join(self.directory,
                                          "{}_opt.{}".format(model.model_name, self.optimizer_extension[framework]))
            assert os.path.exists(optimizer_path), "cannot find model optimizer file at {}".format(optimizer_path)

    def on_epoch_end(self, data: Data) -> None:
        if self.system.epoch_idx % self.frequency == 0:
            # Save all models and optimizer state
            for model in self.system.network.models:
                save_model(model, save_dir=self.directory, save_optimizer=True)
            # Save system state
            self.system.save_state(json_path=os.path.join(self.directory, self.system_file))
            print("FastEstimator-RestoreWizard: Saved milestones to {}".format(self.directory))