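Get the channels, height, and width of the tensor.

The input must have 3 or 4 dimensions; with 4 dimensions the leading (batch) axis is ignored. NumPy arrays and TensorFlow tensors are treated as channel-last ((N,) H, W, C), while PyTorch tensors are treated as channel-first ((N,) C, H, W). The result is therefore always ordered (channels, height, width).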
This method can be used with NumPy data:
```python
import numpy as np
import fastestimator as fe
n = np.random.random((2, 12, 12, 3))
b = fe.backend.get_image_dims(n)  # (3, 12, 12)
```
This method can be used with TensorFlow tensors:
```python
import tensorflow as tf
import fastestimator as fe
t = tf.random.uniform((2, 12, 12, 3))
b = fe.backend.get_image_dims(t)  # (3, 12, 12)
```
This method can be used with PyTorch tensors:
```python
import torch
import fastestimator as fe
p = torch.rand((2, 3, 12, 12))
b = fe.backend.get_image_dims(p)  # (3, 12, 12)
```
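The layout rule behind these examples depends only on the shape. Below is a minimal, framework-free sketch of it. It assumes a plain shape tuple and a hypothetical `channels_last` flag: True for the NumPy/TensorFlow case, False for the PyTorch case. `image_dims_from_shape` is an illustrative name, not part of FastEstimator. Unlike the library function, which checks the rank with an `assert`, this sketch raises `ValueError` for a bad rank.

```python
from typing import Sequence, Tuple


def image_dims_from_shape(shape: Sequence[int], channels_last: bool) -> Tuple[int, int, int]:
    """Hypothetical helper: return (channels, height, width) for a 3D or 4D image shape.

    Mirrors the negative indexing used by get_image_dims, so a leading batch axis is ignored.
    """
    if len(shape) not in (3, 4):
        raise ValueError(f"Number of dimensions must be 3 or 4, got {len(shape)}")
    if channels_last:
        # (N,) H, W, C: channels are the last axis, height and width precede them
        return shape[-1], shape[-3], shape[-2]
    # (N,) C, H, W: channels come first among the last three axes
    return shape[-3], shape[-2], shape[-1]


print(image_dims_from_shape((2, 12, 12, 3), channels_last=True))   # (3, 12, 12)
print(image_dims_from_shape((2, 3, 12, 12), channels_last=False))  # (3, 12, 12)
print(image_dims_from_shape((32, 64, 1), channels_last=True))      # (1, 32, 64)
```

Both 4D example shapes collapse to the same triple. The framework only decides which end of the shape the channel axis sits on.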
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tensor | Tensor | The input tensor. | required |
Returns:

| Type | Description |
| --- | --- |
| Tuple[int, int, int] | Channels, height and width of the tensor. |
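The result is ordered (channels, height, width) for every backend, so callers can unpack it the same way regardless of framework. The sketch below uses a hypothetical `to_layout_shape` helper (not a FastEstimator API) and plain tuples, so it runs without NumPy, TensorFlow, or PyTorch. It shows one possible use: rebuilding an image shape in either layout from the returned triple.

```python
from typing import Optional, Tuple


def to_layout_shape(dims: Tuple[int, int, int], channels_last: bool,
                    batch: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical helper: turn (channels, height, width) into an image shape."""
    channels, height, width = dims  # identical unpacking for NumPy, TF, and PyTorch results
    shape = (height, width, channels) if channels_last else (channels, height, width)
    # Optionally prepend a batch axis
    return shape if batch is None else (batch,) + shape


dims = (3, 12, 12)  # e.g. the value returned for any of the examples above
print(to_layout_shape(dims, channels_last=True, batch=2))   # (2, 12, 12, 3)
print(to_layout_shape(dims, channels_last=False, batch=2))  # (2, 3, 12, 12)
```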
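Raises:

| Type | Description |
| --- | --- |
| AssertionError | If tensor does not have 3 or 4 dimensions. |
| ValueError | If tensor has 3 or 4 dimensions but is not a NumPy array, TensorFlow tensor, or PyTorch tensor. |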
Source code in fastestimator\fastestimator\backend\get_image_dims.py
````python
from typing import Tuple, TypeVar

import numpy as np
import tensorflow as tf
import torch

Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor, np.ndarray)


def get_image_dims(tensor: Tensor) -> Tuple[int, int, int]:
    """Get the `tensor` channels, height, and width.

    This method can be used with NumPy data:
    ```python
    n = np.random.random((2, 12, 12, 3))
    b = fe.backend.get_image_dims(n)  # (3, 12, 12)
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.random.uniform((2, 12, 12, 3))
    b = fe.backend.get_image_dims(t)  # (3, 12, 12)
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.rand((2, 3, 12, 12))
    b = fe.backend.get_image_dims(p)  # (3, 12, 12)
    ```

    Args:
        tensor: The input tensor, with 3 or 4 dimensions.

    Returns:
        Channels, height and width of the `tensor`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    assert len(tensor.shape) == 3 or len(tensor.shape) == 4, "Number of dimensions of input must be either 3 or 4"
    if tf.is_tensor(tensor) or isinstance(tensor, np.ndarray):
        # NumPy / TensorFlow: channel-last layout, (N,) H, W, C
        return tensor.shape[-1], tensor.shape[-3], tensor.shape[-2]
    elif isinstance(tensor, torch.Tensor):
        # PyTorch: channel-first layout, (N,) C, H, W
        return tensor.shape[-3], tensor.shape[-2], tensor.shape[-1]
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
````