_update_model
update_model
Update `model` weights based on the given `gradients`.
This method can be used with TensorFlow models:
import fastestimator as fe
import tensorflow as tf

m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
x = tf.ones((3, 28, 28, 1))  # (batch, height, width, channels)
y = tf.constant((1, 0, 1))
with tf.GradientTape(persistent=True) as tape:
    # The forward pass and the loss must be computed inside the tape so they are recorded.
    pred = fe.backend.feed_forward(m, x)  # shape (3, 10); each probability is roughly 0.1 before training
    loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
    gradients = fe.backend.get_gradient(target=loss, sources=m.trainable_variables, tape=tape)
    fe.backend.update_model(m, gradients=gradients)
This method can be used with PyTorch models:
import fastestimator as fe
import torch

m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
x = torch.ones((3, 1, 28, 28))  # (batch, channels, height, width)
y = torch.tensor((1, 0, 1))
pred = fe.backend.feed_forward(m, x)  # shape (3, 10); each probability is roughly 0.1 before training
loss = fe.backend.sparse_categorical_crossentropy(y_pred=pred, y_true=y)  # ~2.3
# Only parameters that require gradients are passed as sources.
gradients = fe.backend.get_gradient(target=loss,
                                    sources=[p for p in m.parameters() if p.requires_grad])
fe.backend.update_model(m, gradients=gradients)
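As a rough intuition for what "updating weights from gradients" means, the sketch below performs one plain gradient-descent step on a list of scalar weights. It is only an illustration: `sgd_step` and the learning rate are hypothetical names chosen here, and the examples above use the Adam optimizer, whose update rule is more involved than this.

```python
from typing import List


def sgd_step(weights: List[float], gradients: List[float], lr: float = 0.1) -> List[float]:
    """Return new weights after one plain gradient-descent step (hypothetical helper)."""
    if len(weights) != len(gradients):
        raise ValueError("weights and gradients must have the same length")
    # Each weight moves against its gradient, scaled by the learning rate.
    return [w - lr * g for w, g in zip(weights, gradients)]


weights = [1.0, -2.0]
gradients = [0.5, -1.0]
new_weights = sgd_step(weights, gradients, lr=0.1)
```

In the real function the optimizer attached to the model decides how the gradients are turned into weight changes.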
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Union[tf.keras.Model, torch.nn.Module] | A neural network instance to update. | required |
gradients | List[Union[tf.Tensor, torch.Tensor]] | A list of gradient tensors used to update the model. | required |
defer | bool | If True, then the model update function will be stored into the `deferred` dictionary rather than applied immediately. | False |
deferred | Optional[Dict[str, List[Callable[[], None]]]] | A dictionary in which model update functions are stored. | None |
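The `defer` and `deferred` parameters describe a store-now, run-later pattern: the update is wrapped in a zero-argument callable and appended to a dictionary of lists. The sketch below shows that pattern in plain Python. It is not FastEstimator's implementation; `apply_or_defer`, `run_deferred`, and the use of a model name as the dictionary key are assumptions made for illustration.

```python
from typing import Callable, Dict, List, Optional


def apply_or_defer(name: str,
                   update_fn: Callable[[], None],
                   defer: bool = False,
                   deferred: Optional[Dict[str, List[Callable[[], None]]]] = None) -> None:
    """Run update_fn now, or queue it under `name` when defer is True (hypothetical helper)."""
    if defer:
        if deferred is None:
            raise ValueError("a `deferred` dictionary is required when defer=True")
        deferred.setdefault(name, []).append(update_fn)
    else:
        update_fn()


def run_deferred(deferred: Dict[str, List[Callable[[], None]]]) -> None:
    """Run every queued update in insertion order, then empty the queue (hypothetical helper)."""
    for update_fns in deferred.values():
        for update_fn in update_fns:
            update_fn()
    deferred.clear()


# A stand-in for a model weight and its update step.
state = {"w": 1.0}


def step() -> None:
    state["w"] -= 0.25


deferred: Dict[str, List[Callable[[], None]]] = {}
apply_or_defer("model_a", step, defer=True, deferred=deferred)
after_defer = state["w"]   # the update has been queued, not applied
run_deferred(deferred)
after_flush = state["w"]   # the queued update has now run
```

Deferring in this way lets a caller collect several updates and apply them together at a point of its choosing.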
Raises:
Type | Description |
---|---|
ValueError | If |
AssertionError | If |
AssertionError | If PyTorch |