update_model
Update `model` weights based on a given `loss`.
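A weight update computes the gradient of the loss with respect to each trainable weight and lets the optimizer move the weights against that gradient. The sketch below performs one such step by hand for a single-weight model `y_hat = w * x` under mean squared error. It uses plain gradient descent for simplicity, although the examples below use Adam. `mse_loss`, `mse_grad`, `sgd_step`, and `learning_rate` are hypothetical names, and this is an illustration of the idea, not FastEstimator's implementation.

```python
def mse_loss(w: float, xs: list, ys: list) -> float:
    """Mean squared error of the hypothetical model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grad(w: float, xs: list, ys: list) -> float:
    """Analytic derivative d(loss)/dw = mean(2 * (w * x - y) * x)."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def sgd_step(w: float, xs: list, ys: list, learning_rate: float = 0.1) -> float:
    """One hypothetical update: move w against the gradient of the loss."""
    return w - learning_rate * mse_grad(w, xs, ys)


xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # data generated by y = 2x
w = 0.0
before = mse_loss(w, xs, ys)
w = sgd_step(w, xs, ys)
after = mse_loss(w, xs, ys)
print(before, after)
```

Running it prints a smaller loss after the step than before it.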
This method can be used with TensorFlow models:
```python
import fastestimator as fe
import tensorflow as tf

m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
x = tf.ones((3, 28, 28, 1))  # (batch, height, width, channels)
y = tf.constant((1, 0, 1))
with tf.GradientTape(persistent=True) as tape:
    pred = fe.backend.feed_forward(m, x)  # 3x10 softmax outputs, each ~0.1
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
# Update outside the context so the gradient computation itself is not recorded
fe.backend.update_model(m, loss=loss, tape=tape)
```
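TensorFlow computes gradients from the operations a `tf.GradientTape` recorded while the loss was being computed, which is why the TensorFlow path takes a `tape`. A non-persistent tape can compute gradients only once; `persistent=True`, as in the example, lifts that limit. The toy tape below mimics this record-then-differentiate idea in plain Python for scalar `+` and `*` only. `ToyTape` and `Var` are hypothetical names, not TensorFlow APIs, and variables are attached to the tape explicitly instead of through a `with` block to keep the sketch short.

```python
class ToyTape:
    """Toy stand-in for a gradient tape: records scalar ops, differentiates later."""

    def __init__(self, persistent: bool = False):
        self.persistent = persistent
        self.ops = []  # (output, [(input, local derivative), ...]) in execution order
        self.used = False

    def gradient(self, target, sources):
        if self.used and not self.persistent:
            raise RuntimeError("a non-persistent tape can only compute gradients once")
        self.used = True
        grads = {target: 1.0}
        # Replay the recorded ops in reverse order, applying the chain rule.
        for out, inputs in reversed(self.ops):
            upstream = grads.get(out, 0.0)
            for inp, local in inputs:
                grads[inp] = grads.get(inp, 0.0) + upstream * local
        return [grads.get(s, 0.0) for s in sources]


class Var:
    """Scalar whose + and * operations are recorded on a ToyTape."""

    def __init__(self, value: float, tape: ToyTape):
        self.value = value
        self.tape = tape

    def _record(self, value: float, inputs):
        out = Var(value, self.tape)
        self.tape.ops.append((out, inputs))
        return out

    def __add__(self, other: "Var") -> "Var":
        return self._record(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other: "Var") -> "Var":
        return self._record(self.value * other.value, [(self, other.value), (other, self.value)])


tape = ToyTape(persistent=True)
w, x = Var(3.0, tape), Var(2.0, tape)
loss = w * x + w                  # d(loss)/dw = x + 1 = 3, d(loss)/dx = w = 3
(dw,) = tape.gradient(loss, [w])
(dx,) = tape.gradient(loss, [x])  # a second call works only because persistent=True
print(dw, dx)                     # 3.0 3.0
```

Because ops are appended in execution order, walking them in reverse is already a valid order for the chain rule, so no explicit graph sort is needed.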
This method can be used with PyTorch models:
```python
import fastestimator as fe
import torch

m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
x = torch.ones((3, 1, 28, 28))  # (batch, channels, height, width)
y = torch.tensor((1, 0, 1))
pred = fe.backend.feed_forward(m, x)  # 3x10 softmax outputs, each ~0.1
loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
fe.backend.update_model(m, loss=loss)
```
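PyTorch builds its autograd graph while the forward pass runs, so the loss tensor itself knows how to backpropagate and no tape is passed. After `backward()`, PyTorch frees the graph's saved buffers unless `retain_graph=True` is given, and a second backward through the same graph then raises an error; this is the behavior the `retain_graph` parameter documented below refers to. The toy `Value` class here is a hypothetical scalar sketch, not PyTorch's autograd, that mimics both behaviors for `+` and `*`.

```python
class Value:
    """Scalar that remembers how it was produced (a define-by-run graph node)."""

    def __init__(self, data: float, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents   # nodes this value was computed from (empty for leaves)
        self._backward_fn = None  # pushes self.grad into the parents' grads

    def __add__(self, other: "Value") -> "Value":
        out = Value(self.data + other.data, (self, other))

        def backward_fn():
            self.grad += out.grad
            other.grad += out.grad

        out._backward_fn = backward_fn
        return out

    def __mul__(self, other: "Value") -> "Value":
        out = Value(self.data * other.data, (self, other))

        def backward_fn():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward_fn = backward_fn
        return out

    def backward(self, retain_graph: bool = False) -> None:
        # Topologically order every node reachable from this one.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        if any(n._parents and n._backward_fn is None for n in order):
            raise RuntimeError("graph already freed; pass retain_graph=True to the first backward()")
        for node in order:
            if node._parents:
                node.grad = 0.0  # reset intermediate grads; leaf grads accumulate
        self.grad = 1.0
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()
        if not retain_graph:
            for node in order:
                node._backward_fn = None  # mimic freeing the graph after backward


w, x = Value(3.0), Value(2.0)
loss = w * x + w
loss.backward(retain_graph=True)
loss.backward()        # allowed because the first call retained the graph
print(w.grad, x.grad)  # 6.0 6.0 (leaf gradients accumulate across calls)
```

A third `loss.backward()` would raise, because the second call used the default `retain_graph=False` and freed the graph.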
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Union[tf.keras.Model, torch.nn.Module]` | A neural network instance to update. | required |
| `loss` | `Union[tf.Tensor, torch.Tensor]` | A loss value to compute gradients from. | required |
| `tape` | `Optional[tf.GradientTape]` | A TensorFlow GradientTape which was recording when the `loss` was computed (needed only for TensorFlow models). | `None` |
| `retain_graph` | `bool` | Whether to keep the model graph in memory (applicable only for PyTorch). | `True` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If |
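Read together, the two tables describe a small dispatch contract: `tape` matters only for TensorFlow models and `retain_graph` only for PyTorch models. The hypothetical `update_model_sketch` below illustrates that contract with framework-free stand-in classes (`TFLikeModel`, `TorchLikeModel`) so it runs without TensorFlow or PyTorch; it returns a description of the step instead of touching any weights. Raising `ValueError` for a missing tape and for unknown model types are assumptions of this sketch, not a reconstruction of FastEstimator's source.

```python
from typing import Optional


class TFLikeModel:
    """Hypothetical stand-in for a tf.keras.Model."""


class TorchLikeModel:
    """Hypothetical stand-in for a torch.nn.Module."""


def update_model_sketch(model, loss: float, tape: Optional[object] = None,
                        retain_graph: bool = True) -> str:
    """Hypothetical dispatcher mirroring the documented parameters."""
    if isinstance(model, TFLikeModel):
        # TensorFlow-style models need the tape that recorded the loss (assumed check).
        if tape is None:
            raise ValueError("a GradientTape is required to update a TensorFlow model")
        return "tensorflow: gradients from tape, then optimizer step"
    if isinstance(model, TorchLikeModel):
        # PyTorch-style models ignore tape; retain_graph controls graph retention.
        return f"pytorch: backward(retain_graph={retain_graph}), then optimizer step"
    # Assumption of this sketch: any other model type is rejected.
    raise ValueError(f"unsupported model type: {type(model).__name__}")


print(update_model_sketch(TorchLikeModel(), 2.3, retain_graph=False))
```

Dispatching on the model's type keeps a single entry point for both frameworks while letting each branch ignore the parameter that does not apply to it.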