@traceable()
class ModelOp(TensorOp):
"""This class performs forward passes of a neural network over batch data to generate predictions.
Args:
model: A model compiled by fe.build.
        inputs: String key(s) of the input data.
        outputs: String key(s) under which to store predictions.
mode: What mode(s) to execute this Op in. For example, "train", "eval", "test", or "infer". To execute
regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
like "!infer" or "!train".
ds_id: What dataset id(s) to execute this Op in. To execute regardless of ds_id, pass None. To execute in all
ds_ids except for a particular one, you can pass an argument like "!ds1".
trainable: Indicates whether the model should have its weights tracked for update.
intermediate_layers: One or more layers inside of the model from which you would also like to extract output.
This can be useful, for example, for visualizing feature extractor embeddings in conjunction with the
            TensorBoard trace. Layers can be selected by name (str) or index (int). If you are using PyTorch, you can
            look up this information for your model by calling `list(model.named_modules())`. For TensorFlow you can use
            `model.layers`. TensorFlow users should note that if you do not manually assign a name to a model layer,
            a name will be autogenerated for you (e.g. conv2d_2). This autogenerated name will change if you build a new
            model within the same Python session (for example, if you re-run a Jupyter notebook cell, the name could now
be conv2d_5). Any `intermediate_layers` you request will be appended in order to the end of the Op output,
so you must provide output key names for them within the `outputs` argument. Note that layer names may be
different between single-gpu and multi-gpu environments, though we attempt to prevent this.
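
    Example (a minimal sketch; `LeNet`, the "adam" optimizer, and the key/layer names are illustrative):
        model = fe.build(model_fn=LeNet, optimizer_fn="adam")
        op = ModelOp(model=model, inputs="x", outputs="y_pred")
        # To also capture an intermediate embedding, request an extra output key for it:
        op = ModelOp(model=model,
                     inputs="x",
                     outputs=["y_pred", "embedding"],
                     intermediate_layers="flatten")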
"""
def __init__(self,
model: Model,
inputs: Union[None, str, Iterable[str]] = None,
outputs: Union[None, str, Iterable[str]] = None,
mode: Union[None, str, Iterable[str]] = None,
ds_id: Union[None, str, Iterable[str]] = None,
trainable: bool = True,
intermediate_layers: Union[None, str, int, List[Union[str, int]]] = None):
super().__init__(inputs=inputs, outputs=outputs, mode=mode, ds_id=ds_id)
assert hasattr(model, "fe_compiled"), "must use fe.build to compile the model before use"
self.intermediate_outputs = [] # [{device: Tensor}]
intermediate_layers = to_list(intermediate_layers)
if intermediate_layers and get_num_devices() > 1:
warn("Layer names / ids may be different between single-gpu and multi-gpu environments")
for intermediate_layer in intermediate_layers:
storage = {}
if isinstance(model, tf.keras.Model):
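                # Flatten the (possibly nested) layer hierarchy so that nested sub-layers can also be
                # addressed by index or by name: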
layers = list(model._flatten_layers(include_self=False, recursive=True))
if isinstance(intermediate_layer, int):
intermediate_layer = layers[intermediate_layer]
else:
layers = {layer.name: layer for layer in layers}
intermediate_layer = layers[intermediate_layer]
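                # Monkey-patch the layer's `call` so that every forward pass also records its output
                # into `storage`, keyed by device. Conceptually (a sketch; `_capture_call_tf` is defined
                # elsewhere in this module), the wrapper invokes `fe_original_call` and stashes the
                # result before returning it.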
if not hasattr(intermediate_layer, 'fe_original_call'):
intermediate_layer.fe_original_call = intermediate_layer.call
intermediate_layer.call = partial(_capture_call_tf, fe_storage=storage, fe_layer=intermediate_layer)
elif isinstance(model, torch.nn.Module):
layers = model.named_modules()
if get_num_devices() > 1:
# Try to automatically adjust parameters for multi-gpu so that user doesn't need to change code
layers2 = list(model.named_modules()) # It's a generator, so don't corrupt the other copy
if isinstance(layers2[0][1], torch.nn.parallel.DataParallel):
parallel_prefix = "module."
if isinstance(intermediate_layer, str) and not intermediate_layer.startswith(parallel_prefix):
intermediate_layer = parallel_prefix + intermediate_layer
elif isinstance(intermediate_layer, int):
layers = layers2[1:]
if isinstance(intermediate_layer, int):
intermediate_layer = list(layers)[intermediate_layer][1]
else:
intermediate_layer = dict(layers)[intermediate_layer]
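                # Register a hook that runs after each forward pass; PyTorch invokes it as
                # hook(module, inputs, output), and `_capture_call_torch` (defined elsewhere in this
                # module) is presumed to store the output into `storage`, mirroring the TF path above.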
intermediate_layer.register_forward_hook(partial(_capture_call_torch, fe_storage=storage))
self.intermediate_outputs.append(storage)
self.model = model
self.trainable = trainable
self.epoch_spec = None
self.multi_inputs = False
self.device = ''
def build(self, framework: str, device: Optional[torch.device] = None) -> None:
self.device = device or '' # TF will just use empty string for device
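        # Determine whether the model's forward/call signature declares multiple arguments, so that
        # forward() below knows whether to unpack the incoming data list into positional args.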
if framework == "torch" and len(self.inputs) > 1:
if hasattr(self.model, "module"):
                # multi-gpu (DataParallel) models expose the wrapped model via their `module` attribute
self.multi_inputs = len(inspect.signature(self.model.module.forward).parameters.keys()) > 1
else:
self.multi_inputs = len(inspect.signature(self.model.forward).parameters.keys()) > 1
elif framework == "tf" and "keras.engine" not in str(type(self.model)):
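            # Subclassed (non-Functional/Sequential) TF models define a custom `call`, whose signature
            # determines whether multiple inputs are expected.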
            model_call_args = set(inspect.signature(self.model.call).parameters.keys())
            assert len({"training", "mask"} & model_call_args) == 0, "Cannot use 'training' or 'mask' as input args"
self.multi_inputs = len(model_call_args) > 1
def get_fe_models(self) -> Set[Model]:
return {self.model}
def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
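        # Run the network in training mode (enabling dropout, batch-norm updates, etc.) only during the
        # train loop, and only if this model's weights are meant to be updated.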
training = state['mode'] == "train" and self.trainable
if isinstance(self.model, torch.nn.Module) and self.epoch_spec != state['epoch']:
# Gather model input specs for the sake of TensorBoard and Traceability
self.model.fe_input_spec = FeInputSpec(data, self.model)
self.epoch_spec = state['epoch']
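        # Models with multi-argument signatures receive the data list unpacked as positional arguments.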
if self.multi_inputs:
data = feed_forward(self.model, *data, training=training)
else:
data = feed_forward(self.model, data, training=training)
intermediate_outputs = []
for output in self.intermediate_outputs:
intermediate_outputs.append(_unpack_output(output, self.device))
            output.clear()  # This only helps with PyTorch memory; TF tensors will remain until the next forward pass
if intermediate_outputs:
data = to_list(data) + intermediate_outputs
return data