Get the `tensor` channels, height, and width.
This method can be used with NumPy data:
```python
n = np.random.random((2, 12, 12, 3))
b = fe.backend.get_image_dims(n)  # (3, 12, 12)
```
This method can be used with TensorFlow tensors:
```python
t = tf.random.uniform((2, 12, 12, 3))
b = fe.backend.get_image_dims(t)  # (3, 12, 12)
```
This method can be used with PyTorch tensors:
```python
p = torch.rand((2, 3, 12, 12))
b = fe.backend.get_image_dims(p)  # (3, 12, 12)
```
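The three examples differ only in where the channel axis sits: NumPy and TensorFlow inputs are read as channels-last `(..., height, width, channels)`, while PyTorch inputs are read as channels-first `(..., channels, height, width)`. Because every index counts from the end of the shape, an optional leading batch axis does not change the result. A minimal pure-Python sketch of that index arithmetic on a plain shape tuple follows; `image_dims_from_shape` and its `channels_last` flag are hypothetical names for illustration, not part of FastEstimator.

```python
from typing import Sequence, Tuple


def image_dims_from_shape(shape: Sequence[int], channels_last: bool) -> Tuple[int, int, int]:
    """Hypothetical helper: return (channels, height, width) from a 3D or 4D shape."""
    # Same rank restriction as get_image_dims (which uses an assert rather than ValueError)
    if len(shape) not in (3, 4):
        raise ValueError(f"Expected 3 or 4 dimensions, got {len(shape)}")
    if channels_last:
        # NumPy / TensorFlow layout: (..., height, width, channels)
        return shape[-1], shape[-3], shape[-2]
    # PyTorch layout: (..., channels, height, width)
    return shape[-3], shape[-2], shape[-1]


print(image_dims_from_shape((2, 12, 12, 3), channels_last=True))   # (3, 12, 12)
print(image_dims_from_shape((2, 3, 12, 12), channels_last=False))  # (3, 12, 12)
print(image_dims_from_shape((12, 12, 3), channels_last=True))      # (3, 12, 12), no batch axis
```

The output order is always channels, height, width, whichever layout the input uses.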
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The input tensor. | required |
Returns:

| Type | Description |
|---|---|
| `Tuple[int, int, int]` | Channels, height and width of the `tensor`. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `tensor` is an unacceptable data type. |
Source code in fastestimator/fastestimator/backend/_get_image_dims.py
````python
from typing import Tuple, TypeVar

import numpy as np
import tensorflow as tf
import torch

Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor, np.ndarray)


def get_image_dims(tensor: Tensor) -> Tuple[int, int, int]:
    """Get the `tensor` channels, height, and width.

    This method can be used with NumPy data:
    ```python
    n = np.random.random((2, 12, 12, 3))
    b = fe.backend.get_image_dims(n)  # (3, 12, 12)
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.random.uniform((2, 12, 12, 3))
    b = fe.backend.get_image_dims(t)  # (3, 12, 12)
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.rand((2, 3, 12, 12))
    b = fe.backend.get_image_dims(p)  # (3, 12, 12)
    ```

    Args:
        tensor: The input tensor.

    Returns:
        Channels, height and width of the `tensor`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    assert len(tensor.shape) == 3 or len(tensor.shape) == 4, \
        f"Number of dimensions of input must be either 3 or 4, but found {len(tensor.shape)} (shape: {tensor.shape})"
    if tf.is_tensor(tensor):
        # TensorFlow images are channels-last: (..., height, width, channels)
        shape = tf.shape(tensor)
        channels, height, width = shape[-1], shape[-3], shape[-2]
        if hasattr(channels, 'numpy'):
            # Running in eager mode, so can convert to integer
            channels, height, width = channels.numpy().item(), height.numpy().item(), width.numpy().item()
        return channels, height, width
    elif isinstance(tensor, np.ndarray):
        # NumPy images are also read as channels-last
        return tensor.shape[-1], tensor.shape[-3], tensor.shape[-2]
    elif isinstance(tensor, torch.Tensor):
        # PyTorch images are channels-first: (..., channels, height, width)
        return tensor.shape[-3], tensor.shape[-2], tensor.shape[-1]
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
````
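The source reads `tensor.shape` in its rank assertion before it inspects the type, so the documented `ValueError` is not the only possible failure. The pure-Python sketch here reproduces only that validation order, with the three framework branches removed so it runs without NumPy, TensorFlow, or PyTorch; `FakeArray`, `get_image_dims_flow`, and `exception_name` are hypothetical names used for illustration. On this reading of the code, a plain list fails with `AttributeError`, an input of the wrong rank fails with `AssertionError`, and only a rank-3 or rank-4 object of an unsupported type reaches the `ValueError` branch.

```python
from typing import Tuple


class FakeArray:
    """Hypothetical stand-in: has a .shape but is not a NumPy, TensorFlow, or PyTorch type."""

    def __init__(self, shape: Tuple[int, ...]):
        self.shape = shape


def get_image_dims_flow(tensor) -> Tuple[int, int, int]:
    """Sketch of get_image_dims' validation order, with the framework branches omitted."""
    # Step 1: rank check; it reads .shape, so objects without one raise AttributeError here
    assert len(tensor.shape) == 3 or len(tensor.shape) == 4, \
        f"Number of dimensions of input must be either 3 or 4, but found {len(tensor.shape)}"
    # Step 2: type dispatch; with no supported branch left, every input that gets here is rejected
    raise ValueError("Unrecognized tensor type {}".format(type(tensor)))


def exception_name(obj) -> str:
    """Return the name of the exception the sketched flow raises for obj."""
    try:
        get_image_dims_flow(obj)
    except Exception as err:
        return type(err).__name__
    return "none"


print(exception_name([[1, 2], [3, 4]]))        # AttributeError
print(exception_name(FakeArray((12, 12))))     # AssertionError
print(exception_name(FakeArray((12, 12, 3))))  # ValueError
```

Because the rank check is an `assert` statement, it is skipped when Python runs with the `-O` flag; in that mode a wrong-rank input of a supported type would not be rejected up front.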