get_image_dims

Get the channels, height, and width of a tensor.

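This function can be used with NumPy arrays, which are read as channels-last (..., height, width, channels):

import fastestimator as fe
import numpy as np

n = np.random.random((2, 12, 12, 3))
b = fe.backend.get_image_dims(n)  # (3, 12, 12)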

This function can also be used with TensorFlow tensors, which are read as channels-last:

import tensorflow as tf

t = tf.random.uniform((2, 12, 12, 3))
b = fe.backend.get_image_dims(t)  # (3, 12, 12)

This function can also be used with PyTorch tensors, which are read as channels-first (..., channels, height, width):

import torch

p = torch.rand((2, 3, 12, 12))
b = fe.backend.get_image_dims(p)  # (3, 12, 12)
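The three frameworks differ only in which axes are read. The sketch below reproduces that indexing on plain shape tuples, assuming channels-last (..., H, W, C) for NumPy/TensorFlow and channels-first (..., C, H, W) for PyTorch. `image_dims_from_shape` is a hypothetical helper, not part of FastEstimator, and it raises `ValueError` for a bad rank where the library uses an assertion.

```python
from typing import Sequence, Tuple


def image_dims_from_shape(shape: Sequence[int], channels_first: bool) -> Tuple[int, int, int]:
    """Hypothetical helper: return (channels, height, width) from a 3D or 4D shape."""
    if len(shape) not in (3, 4):
        raise ValueError("Number of dimensions of input must be either 3 or 4")
    if channels_first:
        # PyTorch-style layout: (..., C, H, W)
        return shape[-3], shape[-2], shape[-1]
    # NumPy/TensorFlow-style layout: (..., H, W, C)
    return shape[-1], shape[-3], shape[-2]


print(image_dims_from_shape((2, 12, 12, 3), channels_first=False))  # (3, 12, 12)
print(image_dims_from_shape((2, 3, 12, 12), channels_first=True))   # (3, 12, 12)
print(image_dims_from_shape((32, 48, 3), channels_first=False))     # (3, 32, 48)
```

Because every index counts from the end, an optional leading batch dimension is ignored, which is why both unbatched (3D) and batched (4D) inputs are accepted.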

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `tensor` | Tensor | The input tensor, with 3 or 4 dimensions. | required |

Returns:

| Type | Description |
|------|-------------|
| Tuple[int, int, int] | The (channels, height, width) of `tensor`. |

Raises:

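| Type | Description |
|------|-------------|
| ValueError | If `tensor` is not a NumPy array, TensorFlow tensor, or PyTorch tensor. |
| AssertionError | If `tensor` does not have 3 or 4 dimensions. |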

Source code in fastestimator\fastestimator\backend\get_image_dims.py
from typing import Tuple, TypeVar

import numpy as np
import tensorflow as tf
import torch

Tensor = TypeVar('Tensor', tf.Tensor, torch.Tensor, np.ndarray)


def get_image_dims(tensor: Tensor) -> Tuple[int, int, int]:
    """Get the channels, height, and width of `tensor`.

    This method can be used with NumPy data:
    ```python
    n = np.random.random((2, 12, 12, 3))
    b = fe.backend.get_image_dims(n)  # (3, 12, 12)
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.random.uniform((2, 12, 12, 3))
    b = fe.backend.get_image_dims(t)  # (3, 12, 12)
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.rand((2, 3, 12, 12))
    b = fe.backend.get_image_dims(p)  # (3, 12, 12)
    ```

    Args:
        tensor: The input tensor, with 3 or 4 dimensions.

    Returns:
        A tuple of (channels, height, width) for `tensor`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
        AssertionError: If `tensor` does not have 3 or 4 dimensions.
    """
    # Check the type first so unsupported inputs raise ValueError rather than AttributeError on `.shape`.
    if not (tf.is_tensor(tensor) or isinstance(tensor, (np.ndarray, torch.Tensor))):
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
    assert len(tensor.shape) in (3, 4), "Number of dimensions of input must be either 3 or 4"
    if isinstance(tensor, torch.Tensor):
        # PyTorch tensors are channels-first: (..., C, H, W)
        return tensor.shape[-3], tensor.shape[-2], tensor.shape[-1]
    # NumPy arrays and TensorFlow tensors are channels-last: (..., H, W, C)
    return tensor.shape[-1], tensor.shape[-3], tensor.shape[-2]
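NumPy and TensorFlow inputs are read as channels-last and PyTorch inputs as channels-first, so image data moved between frameworks must be permuted for `get_image_dims` to report the same dimensions. The sketch below applies the usual NHWC-to-NCHW axis permutations to shape tuples only. With real data, the arrays themselves would be permuted, for example with `np.transpose(x, (0, 3, 1, 2))` or `x.permute(0, 3, 1, 2)` on a `torch.Tensor`. `permute_shape` and the two constant names are hypothetical.

```python
from typing import Sequence, Tuple

# Axis orders for a batched image: N=batch, H=height, W=width, C=channels.
NHWC_TO_NCHW = (0, 3, 1, 2)  # channels-last -> channels-first
NCHW_TO_NHWC = (0, 2, 3, 1)  # channels-first -> channels-last


def permute_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: reorder a shape the way np.transpose(x, perm) would."""
    return tuple(shape[axis] for axis in perm)


nhwc = (2, 12, 12, 3)  # shape of `n` and `t` in the examples above
nchw = permute_shape(nhwc, NHWC_TO_NCHW)
print(nchw)  # (2, 3, 12, 12), the shape of `p`

# Reading each layout with its own convention yields the same (C, H, W).
print((nhwc[-1], nhwc[-3], nhwc[-2]))  # (3, 12, 12)
print((nchw[-3], nchw[-2], nchw[-1]))  # (3, 12, 12)
print(permute_shape(nchw, NCHW_TO_NHWC) == nhwc)  # True
```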