update_model

Update model weights based on a given loss.

This method can be used with TensorFlow models:

```python
m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
x = tf.ones((3,28,28,1))  # (batch, height, width, channels)
y = tf.constant((1, 0, 1))
with tf.GradientTape(persistent=True) as tape:
    pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
    fe.backend.update_model(m, loss=loss, tape=tape)
```
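
Note that the tape is created with `persistent=True`: `update_model` calls `tape.gradient` while the tape is still recording, which TensorFlow only permits on persistent tapes.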

This method can be used with PyTorch models:

```python
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
x = torch.ones((3,1,28,28))  # (batch, channels, height, width)
y = torch.tensor((1, 0, 1))
pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
fe.backend.update_model(m, loss=loss)
```
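
When PyTorch mixed precision training is enabled, a `GradScaler` can be passed via the `scaler` argument. Below is a minimal sketch, assuming the `m`, `x`, and `y` from the PyTorch example above and a CUDA device; the `autocast`/`GradScaler` calls follow standard `torch.cuda.amp` usage and are not specific to FastEstimator:

```python
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    pred = fe.backend.feed_forward(m, x)  # forward pass runs in mixed precision
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)
fe.backend.update_model(m, loss=loss, scaler=scaler)  # loss is scaled internally via scaler.scale()
```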

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A neural network instance to update. | required |
| `loss` | `Union[tf.Tensor, torch.Tensor]` | A loss value to compute gradients from. | required |
| `tape` | `Optional[tf.GradientTape]` | A TensorFlow GradientTape which was recording when the `loss` was computed (iff using TensorFlow). | `None` |
| `retain_graph` | `bool` | Whether to keep the model graph in memory (applicable only for PyTorch). | `True` |
| `scaler` | `Optional[torch.cuda.amp.GradScaler]` | A PyTorch GradScaler used to scale the `loss` when mixed precision training is enabled. | `None` |
| `defer` | `bool` | If True, the model update function is stored in the `deferred` dictionary rather than applied immediately. | `False` |
| `deferred` | `Optional[Dict[str, List[Callable[[], None]]]]` | A dictionary in which model update functions are stored. | `None` |
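
The `defer` / `deferred` pair decouples gradient accumulation from the optimizer step. Below is a minimal sketch, assuming the PyTorch `m`, `x`, and `y` from the example above; inside a FastEstimator `Network` the `deferred` dictionary is normally managed for you, so building and draining it by hand here is purely illustrative:

```python
deferred = {}
pred = fe.backend.feed_forward(m, x)
loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)
fe.backend.update_model(m, loss=loss, defer=True, deferred=deferred)
# Gradients have been accumulated on the parameters, but no optimizer step has run yet
for update_fn in deferred.pop(m.model_name, []):
    update_fn()  # apply the stored update now
```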

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `model` is an unacceptable data type. |
| `RuntimeError` | If attempting to update a PyTorch model whose gradients depend on a second PyTorch model which has already undergone a non-deferred update. |
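
The `RuntimeError` typically surfaces when two PyTorch models share one computation graph and the earlier model is updated in place before the later one computes its gradients. A hedged sketch of the failure mode and its fix, using two hypothetical `Linear` models as stand-ins for any pair where one model's loss flows through the other:

```python
import torch
import fastestimator as fe

a = fe.build(lambda: torch.nn.Linear(4, 4), optimizer_fn="adam")
b = fe.build(lambda: torch.nn.Linear(4, 1), optimizer_fn="adam")
x = torch.ones((2, 4))
out = fe.backend.feed_forward(b, fe.backend.feed_forward(a, x))  # loss flows through both models

deferred = {}
# An immediate update of `b` would modify weights that the backward pass for `a`
# still needs, raising the RuntimeError above. Deferring b's step avoids this:
fe.backend.update_model(b, loss=out, defer=True, deferred=deferred)
fe.backend.update_model(a, loss=out, deferred=deferred)
for update_fn in deferred.pop(b.model_name, []):
    update_fn()  # now apply b's postponed optimizer step
```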

Source code in fastestimator/fastestimator/backend/update_model.py
def update_model(model: Union[tf.keras.Model, torch.nn.Module],
                 loss: Union[tf.Tensor, torch.Tensor],
                 tape: Optional[tf.GradientTape] = None,
                 retain_graph: bool = True,
                 scaler: Optional[torch.cuda.amp.GradScaler] = None,
                 defer: bool = False,
                 deferred: Optional[Dict[str, List[Callable[[], None]]]] = None) -> None:
    """Update `model` weights based on a given `loss`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
    x = tf.ones((3,28,28,1))  # (batch, height, width, channels)
    y = tf.constant((1, 0, 1))
    with tf.GradientTape(persistent=True) as tape:
        pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
        loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
        fe.backend.update_model(m, loss=loss, tape=tape)
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
    x = torch.ones((3,1,28,28))  # (batch, channels, height, width)
    y = torch.tensor((1, 0, 1))
    pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
    fe.backend.update_model(m, loss=loss)
    ```

    Args:
        model: A neural network instance to update.
        loss: A loss value to compute gradients from.
        tape: A TensorFlow GradientTape which was recording when the `loss` was computed (iff using TensorFlow).
        retain_graph: Whether to keep the model graph in memory (applicable only for PyTorch).
        scaler: A PyTorch loss scaler that scales loss when PyTorch mixed precision is used.
        defer: If True, then the model update function will be stored into the `deferred` dictionary rather than
            applied immediately.
        deferred: A dictionary in which model update functions are stored.

    Raises:
        ValueError: If `model` is an unacceptable data type.
        RuntimeError: If attempting to modify a PyTorch model which relied on gradients within a different PyTorch model
            which has in turn already undergone a non-deferred update.
    """
    loss = reduce_mean(loss)
    if isinstance(model, tf.keras.Model):
        # scale up loss for mixed precision training to avoid underflow
        if isinstance(model.current_optimizer, mixed_precision.LossScaleOptimizer):
            loss = model.current_optimizer.get_scaled_loss(loss)
        # for multi-gpu training, the gradient will be combined by sum, normalize the loss
        strategy = tf.distribute.get_strategy()
        if isinstance(strategy, tf.distribute.MirroredStrategy):
            loss = loss / strategy.num_replicas_in_sync
        gradients = get_gradient(loss, model.trainable_variables, tape=tape)
        with tape.stop_recording():
            # scale down gradient to balance scale-up loss
            if isinstance(model.current_optimizer, mixed_precision.LossScaleOptimizer):
                gradients = model.current_optimizer.get_unscaled_gradients(gradients)
            if defer:
                deferred.setdefault(model.model_name, []).append(
                    lambda: model.current_optimizer.apply_gradients(zip(gradients, model.trainable_variables)))
            else:
                model.current_optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    elif isinstance(model, torch.nn.Module):
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        # scale up loss for mixed precision training to avoid underflow
        if scaler is not None:
            loss = scaler.scale(loss)
        try:
            gradients = get_gradient(loss, trainable_params, retain_graph=retain_graph)
        except RuntimeError as err:
            if err.args and isinstance(err.args[0], str) and err.args[0].startswith(
                    'one of the variables needed for gradient computation has been modified by an inplace operation'):
                raise RuntimeError(
                    "When computing gradients for '{}', some variables it relied on during the forward pass had already"
                    " been updated. Consider setting defer=True in earlier UpdateOps related to models which interact "
                    "with this one.".format(model.model_name))
            raise err
        for gradient, parameter in zip(gradients, trainable_params):
            if parameter.grad is not None:
                parameter.grad += gradient
            else:
                parameter.grad = gradient.clone()
        if defer:
            # Only need to call once per model since gradients are getting accumulated
            deferred[model.model_name] = [lambda: _torch_step(model.current_optimizer, scaler)]
        else:
            _torch_step(model.current_optimizer, scaler)
            if deferred is not None:
                deferred.pop(model.model_name, None)  # Don't need those deferred steps anymore
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))