_watch
watch
Monitor the given `tensor` for later gradient computations.
This method can be used with TensorFlow tensors:
import fastestimator as fe
import tensorflow as tf
x = tf.ones((3, 28, 28, 1))
with tf.GradientTape(persistent=True) as tape:
    x = fe.backend.watch(x, tape=tape)  # the tape now records operations on x
This method can be used with PyTorch tensors:
import fastestimator as fe
import torch
x = torch.ones((3, 1, 28, 28))  # x.requires_grad == False
x = fe.backend.watch(x)  # x.requires_grad == True
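The difference between the two backends can be illustrated with a minimal, framework-free sketch. This is not FastEstimator's implementation: `StubTape`, `StubTFTensor`, `StubTorchTensor`, and `watch_sketch` are hypothetical stand-ins, chosen so the example runs without TensorFlow or PyTorch installed. Both `ValueError` cases are assumptions made for this sketch only.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class StubTFTensor:
    """Hypothetical stand-in for a TensorFlow tensor."""
    name: str


@dataclass
class StubTape:
    """Hypothetical stand-in for tf.GradientTape: remembers watched tensors."""
    watched: List[StubTFTensor] = field(default_factory=list)

    def watch(self, tensor: StubTFTensor) -> None:
        self.watched.append(tensor)


@dataclass
class StubTorchTensor:
    """Hypothetical stand-in for a PyTorch tensor."""
    name: str
    requires_grad: bool = False


def watch_sketch(tensor: Any, tape: Optional[StubTape] = None) -> Any:
    """Mark `tensor` for gradient tracking, dispatching on its (stub) backend."""
    if isinstance(tensor, StubTFTensor):
        # TensorFlow-style: the tape does the tracking; the tensor is unchanged.
        if tape is None:
            # Assumption of this sketch: a missing tape is reported explicitly.
            raise ValueError("a tape is required for TensorFlow-style tensors")
        tape.watch(tensor)
        return tensor
    if isinstance(tensor, StubTorchTensor):
        # PyTorch-style: tracking is a property of the tensor itself, so the
        # caller has to keep the returned object.
        if tensor.requires_grad:
            return tensor
        return StubTorchTensor(tensor.name, requires_grad=True)
    # Assumption of this sketch: unsupported input types are rejected.
    raise ValueError(f"unsupported tensor type: {type(tensor).__name__}")


tape = StubTape()
tf_x = watch_sketch(StubTFTensor("x"), tape=tape)  # tracked by the tape
torch_x = watch_sketch(StubTorchTensor("x"))       # tracked via its own flag
```

In this sketch the PyTorch-style branch returns a new object instead of mutating its input, which mirrors why the PyTorch example reassigns `x` while the TensorFlow example could ignore the return value.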
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`tensor` | `Tensor` | The tensor to be monitored. | required |
`tape` | `Optional[GradientTape]` | A TensorFlow GradientTape which will be used to record gradients (iff using TensorFlow for the backend). | `None` |
Returns:
Type | Description |
---|---|
`Tensor` | The `tensor` being monitored for gradient computations. The returned value is needed if using PyTorch as the backend. |
Raises:
Type | Description |
---|---|
ValueError
|
If |