argmax
Compute the index of the maximum value along a given axis of a tensor.
This method can be used with NumPy data:
```python
import numpy as np
import fastestimator as fe

n = np.array([[2,7,5],[9,1,3],[4,8,2]])
b = fe.backend.argmax(n, axis=0)  # [1, 2, 0]
b = fe.backend.argmax(n, axis=1)  # [1, 0, 1]
```
This method can be used with TensorFlow tensors:
```python
import tensorflow as tf
import fastestimator as fe

t = tf.constant([[2,7,5],[9,1,3],[4,8,2]])
b = fe.backend.argmax(t, axis=0)  # [1, 2, 0]
b = fe.backend.argmax(t, axis=1)  # [1, 0, 1]
```
This method can be used with PyTorch tensors:
```python
import torch
import fastestimator as fe

p = torch.tensor([[2,7,5],[9,1,3],[4,8,2]])
b = fe.backend.argmax(p, axis=0)  # [1, 2, 0]
b = fe.backend.argmax(p, axis=1)  # [1, 0, 1]
```
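To make the `axis` convention in the examples above concrete, here is a plain-Python sketch for 2-D nested lists. `argmax_2d` is a hypothetical helper written only for illustration, not part of FastEstimator: `axis=0` compares values down each column (one index per column), while `axis=1` compares values across each row (one index per row).

```python
from typing import List


def argmax_2d(matrix: List[List[float]], axis: int = 0) -> List[int]:
    """Hypothetical helper: index of the maximum along `axis` (0 or 1) of a 2-D list."""
    if axis == 0:
        # axis=0: walk down each column; zip(*matrix) yields the columns as tuples
        columns = list(zip(*matrix))
        # max() keeps the first maximal index it encounters on ties
        return [max(range(len(col)), key=col.__getitem__) for col in columns]
    if axis == 1:
        # axis=1: walk across each row
        return [max(range(len(row)), key=row.__getitem__) for row in matrix]
    raise ValueError(f"axis must be 0 or 1 for a 2-D input, got {axis}")


m = [[2, 7, 5], [9, 1, 3], [4, 8, 2]]
print(argmax_2d(m, axis=0))  # [1, 2, 0]
print(argmax_2d(m, axis=1))  # [1, 0, 1]
```

The results agree with the values shown for NumPy, TensorFlow, and PyTorch inputs.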
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The input value: a NumPy array, TensorFlow tensor, or PyTorch tensor. | *required* |
| `axis` | `int` | Which axis to compute the index along. | `0` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The indices corresponding to the maximum values within `tensor` along `axis`. |
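For NumPy input, the shape of the returned indices can be checked with NumPy's own `np.argmax`, used here as a stand-in; that `fe.backend.argmax` returns the same shape is an assumption consistent with the 3×3 examples above. The reduced axis is dropped from the shape, and each entry is a position along that axis, so it can be used to gather the maxima back out.

```python
import numpy as np

# Deterministic random integers with shape (2, 3, 4)
rng = np.random.default_rng(0)
x = rng.integers(0, 100, size=(2, 3, 4))

# argmax along axis=1 removes that axis: (2, 3, 4) -> (2, 4)
idx = np.argmax(x, axis=1)
print(idx.shape)  # (2, 4)

# Re-insert the reduced axis so take_along_axis can gather the maxima
maxima = np.take_along_axis(x, np.expand_dims(idx, axis=1), axis=1).squeeze(axis=1)
print(np.array_equal(maxima, x.max(axis=1)))  # True
```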
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `tensor` is an unacceptable data type. |
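The `ValueError` listed above suggests the function checks the input type before choosing a backend. One way such a switch could be written is sketched below; `argmax_dispatch` is a hypothetical name, and this is an illustration under that assumption, not FastEstimator's actual implementation. It only uses `np.argmax`, `torch.argmax` (which names the axis `dim`), and `tf.argmax`.

```python
import sys
from typing import Any

import numpy as np


def argmax_dispatch(tensor: Any, axis: int = 0) -> Any:
    """Hypothetical backend switch for argmax (illustrative sketch only)."""
    # NumPy input: delegate to np.argmax
    if isinstance(tensor, np.ndarray):
        return np.argmax(tensor, axis=axis)
    # A torch.Tensor can only exist if torch is already imported, so look it up
    # in sys.modules instead of importing a heavy optional dependency
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(tensor, torch.Tensor):
        return torch.argmax(tensor, dim=axis)  # PyTorch calls the axis `dim`
    tf = sys.modules.get("tensorflow")
    if tf is not None and tf.is_tensor(tensor):
        return tf.argmax(tensor, axis=axis)
    # Anything else, e.g. a plain Python list, is rejected with ValueError
    raise ValueError(f"Unrecognized tensor type {type(tensor)}")


n = np.array([[2, 7, 5], [9, 1, 3], [4, 8, 2]])
print(argmax_dispatch(n, axis=0))  # [1 2 0]
```

Looking modules up in `sys.modules` keeps the sketch runnable when PyTorch or TensorFlow is not installed.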