update_model

Update model weights based on a given loss.

This method can be used with TensorFlow models:

```python
import fastestimator as fe
import tensorflow as tf

m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
x = tf.ones((3,28,28,1))  # (batch, height, width, channels)
y = tf.constant((1, 0, 1))
with tf.GradientTape(persistent=True) as tape:
    pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
    fe.backend.update_model(m, loss=loss, tape=tape)
```

This method can be used with PyTorch models:

```python
import fastestimator as fe
import torch

m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
x = torch.ones((3,1,28,28))  # (batch, channels, height, width)
y = torch.tensor((1, 0, 1))
pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
fe.backend.update_model(m, loss=loss)
```
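In either case, a full training step is just these calls repeated over batches. The sketch below is a minimal, illustrative loop: the random inputs and labels are stand-ins for a real data pipeline, not part of the FastEstimator API.

```python
import fastestimator as fe
import torch

m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
for step in range(10):
    # Random stand-ins for a real batch of images and labels (illustration only)
    x = torch.rand((3, 1, 28, 28))
    y = torch.randint(low=0, high=2, size=(3,))
    pred = fe.backend.feed_forward(m, x)
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)
    fe.backend.update_model(m, loss=loss)  # compute gradients and apply one optimizer step
```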

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Union[tf.keras.Model, torch.nn.Module] | A neural network instance to update. | required |
| loss | Union[tf.Tensor, torch.Tensor] | A loss value to compute gradients from. | required |
| tape | Optional[tf.GradientTape] | A TensorFlow GradientTape which was recording when the loss was computed (iff using TensorFlow). | None |
| retain_graph | bool | Whether to keep the model graph in memory (applicable only for PyTorch). | True |
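The retain_graph default of True is the conservative choice, since it keeps the backward graph alive in case gradients need to flow through it again (for example, when more than one model is updated from the same loss). If a step performs only a single update, passing retain_graph=False lets PyTorch release the intermediate buffers immediately. A minimal sketch, reusing the PyTorch setup from the example above:

```python
import fastestimator as fe
import torch

m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
x = torch.ones((3, 1, 28, 28))
y = torch.tensor((1, 0, 1))
pred = fe.backend.feed_forward(m, x)
loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)
# Only one backward pass is needed for this graph, so its buffers can be freed right away
fe.backend.update_model(m, loss=loss, retain_graph=False)
```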

Raises:

| Type | Description |
| --- | --- |
| ValueError | If model is an unacceptable data type. |
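Passing anything other than a tf.keras.Model or torch.nn.Module produces this error. A quick, self-contained illustration:

```python
import fastestimator as fe
import torch

loss = torch.tensor(2.3)
try:
    fe.backend.update_model("not a model", loss=loss)  # a string is not an accepted model type
except ValueError as err:
    print(err)  # Unrecognized model instance <class 'str'>
```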

Source code in fastestimator/fastestimator/backend/update_model.py
def update_model(model: Union[tf.keras.Model, torch.nn.Module],
                 loss: Union[tf.Tensor, torch.Tensor],
                 tape: Optional[tf.GradientTape] = None,
                 retain_graph: bool = True):
    """Update `model` weights based on a given `loss`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
    x = tf.ones((3,28,28,1))  # (batch, height, width, channels)
    y = tf.constant((1, 0, 1))
    with tf.GradientTape(persistent=True) as tape:
        pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
        loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
        fe.backend.update_model(m, loss=loss, tape=tape)
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
    x = torch.ones((3,1,28,28))  # (batch, channels, height, width)
    y = torch.tensor((1, 0, 1))
    pred = fe.backend.feed_forward(m, x)  # [[~0.5, ~0.5], [~0.5, ~0.5], [~0.5, ~0.5]]
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
    fe.backend.update_model(m, loss=loss)
    ```

    Args:
        model: A neural network instance to update.
        loss: A loss value to compute gradients from.
        tape: A TensorFlow GradientTape which was recording when the `loss` was computed (iff using TensorFlow).
        retain_graph: Whether to keep the model graph in memory (applicable only for PyTorch).

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
    loss = reduce_mean(loss)
    if isinstance(model, tf.keras.Model):
        strategy = tf.distribute.get_strategy()
        if isinstance(strategy, tf.distribute.MirroredStrategy):
            loss = loss / strategy.num_replicas_in_sync
        gradients = get_gradient(loss, model.trainable_variables, tape=tape)
        with tape.stop_recording():
            model.current_optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    elif isinstance(model, torch.nn.Module):
        gradients = get_gradient(loss, model.parameters(), retain_graph=retain_graph)
        for gradient, parameter in zip(gradients, model.parameters()):
            parameter.grad = gradient
        model.current_optimizer.step()
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
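For intuition, the PyTorch branch above is roughly equivalent to the plain torch snippet below. The model, optimizer, and loss here are hypothetical stand-ins: fe.build attaches the optimizer to the model as current_optimizer, which this sketch replaces with an ordinary torch.optim.Adam.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for a model built by fe.build and its attached optimizer
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

x = torch.ones((3, 4))
y = torch.tensor((1, 0, 1))
per_sample = nn.functional.cross_entropy(model(x), y, reduction="none")
loss = per_sample.mean()  # mirrors the reduce_mean(loss) call in update_model

# Equivalent of get_gradient(...) followed by the manual .grad assignment and optimizer step
gradients = torch.autograd.grad(loss, list(model.parameters()), retain_graph=True)
for gradient, parameter in zip(gradients, model.parameters()):
    parameter.grad = gradient
optimizer.step()
```

The TensorFlow branch follows the same pattern with tape-based gradients and apply_gradients, additionally dividing the loss by the number of replicas when running under a MirroredStrategy.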