Skip to content

model

ModelOp

Bases: TensorOp

This class performs forward passes of a neural network over batch data to generate predictions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A model compiled by `fe.build`. | *required* |
| `inputs` | `Union[None, str, Iterable[str]]` | String key(s) of the input data to feed into the model. | `None` |
| `outputs` | `Union[None, str, Iterable[str]]` | String key(s) under which to store predictions. | `None` |
| `mode` | `Union[None, str, Iterable[str]]` | What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | `None` |
| `trainable` | `bool` | Indicates whether the model should have its weights tracked for update. | `True` |
Source code in fastestimator\fastestimator\op\tensorop\model\model.py
@traceable()
class ModelOp(TensorOp):
    """This class performs forward passes of a neural network over batch data to generate predictions.

    Args:
        model: A model compiled by fe.build.
        inputs: String key(s) of the input data to feed into the model.
        outputs: String key(s) under which to store predictions.
        mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        trainable: Indicates whether the model should have its weights tracked for update.

    Raises:
        AssertionError: If `model` was not compiled with fe.build (i.e. it has no `fe_compiled` attribute).
    """
    def __init__(self,
                 model: Union[tf.keras.Model, torch.nn.Module],
                 inputs: Union[None, str, Iterable[str]] = None,
                 outputs: Union[None, str, Iterable[str]] = None,
                 mode: Union[None, str, Iterable[str]] = None,
                 trainable: bool = True):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)
        # Use an explicit raise rather than an `assert` statement so that this validation is not stripped when
        # Python runs with -O. AssertionError is kept as the exception type so existing handlers still work.
        if not hasattr(model, "fe_compiled"):
            raise AssertionError("must use fe.build to compile the model before use")
        self.model = model
        self.trainable = trainable
        # Epoch for which `fe_input_spec` was last recorded on a torch model; None until the first forward pass.
        self.epoch_spec = None

    def get_fe_models(self) -> Set[Model]:
        """Return the set of models used by this Op.

        Returns:
            A single-element set containing the wrapped model.
        """
        return {self.model}

    def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
        """Run the wrapped model on `data` via `feed_forward`.

        Args:
            data: The input tensor(s) gathered from the `inputs` key(s).
            state: Execution-context information; this method reads its 'mode' and 'epoch' entries.

        Returns:
            The result of `feed_forward(self.model, data, training=...)`, i.e. the predictions to be stored under
            `outputs`.
        """
        # The training flag is only set in "train" mode, and only when this op is marked trainable.
        training = state['mode'] == "train" and self.trainable
        if isinstance(self.model, torch.nn.Module) and self.epoch_spec != state['epoch']:
            # Gather model input specs for the sake of TensorBoard and Traceability. This only happens on the first
            # forward call after the epoch value changes, so the cost is not paid on every batch.
            self.model.fe_input_spec = FeInputSpec(data, self.model)
            self.epoch_spec = state['epoch']
        return feed_forward(self.model, data, training=training)