Skip to content

image_viewer

ImageViewer

Bases: Trace

A trace that interrupts your training in order to display images on the screen.

This class is useful primarily for Jupyter Notebook, or for debugging purposes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `Union[str, Sequence[str]]` | Key(s) of images to be displayed. | *required* |
| `mode` | `Union[None, str, Iterable[str]]` | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | `('eval', 'test')` |
| `width` | `int` | The width in inches of the figure. | `12` |
| `height` | `int` | The height in inches of the figure. | `6` |
Source code in fastestimator\fastestimator\trace\io\image_viewer.py
@traceable()
class ImageViewer(Trace):
    """A trace that interrupts your training in order to display images on the screen.

    This class is useful primarily for Jupyter Notebook, or for debugging purposes.

    Args:
        inputs: Key(s) of images to be displayed.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        width: The width in inches of the figure.
        height: The height in inches of the figure.
    """
    def __init__(
        self,
        inputs: Union[str, Sequence[str]],
        mode: Union[None, str, Iterable[str]] = ("eval", "test"),
        width: int = 12,
        height: int = 6) -> None:
        super().__init__(inputs=inputs, mode=mode)
        # Remember the requested size instead of writing it into the process-wide matplotlib rcParams. A global write
        # would leak into every other figure in the session. With several ImageViewers, the last one constructed
        # would also silently dictate the size for all of them.
        self.figsize = [width, height]

    def on_epoch_end(self, data: Data) -> None:
        """Epoch-end hook: display any requested images found in `data`.

        Args:
            data: Data possibly containing images to render.
        """
        self._display_images(data)

    def on_end(self, data: Data) -> None:
        """End hook: display any requested images found in `data`.

        Args:
            data: Data possibly containing images to render.
        """
        self._display_images(data)

    def _display_images(self, data: Data) -> None:
        """A method to render images to the screen.

        Each key in `self.inputs` that is present in `data` is rendered according to its type:
        - `ImgData` is painted to a numpy image and shown with matplotlib.
        - A `Summary`, or a non-empty list/tuple made up entirely of `Summary` objects, is passed to `visualize_logs`.
        - Anything else is iterated, and each element is shown via `show_image` with a "<key>_<index>" title.
        Keys missing from `data` are skipped.

        Args:
            data: Data possibly containing images to render.
        """
        # Apply this trace's figure size only while it is drawing. The user's rcParams are restored afterwards.
        with plt.rc_context({'figure.figsize': self.figsize}):
            for key in self.inputs:
                if key not in data:
                    continue
                imgs = data[key]
                if isinstance(imgs, ImgData):
                    # NOTE(review): paint_numpy appears to return a batch; the first entry is the painted figure.
                    fig = imgs.paint_numpy(dpi=96)
                    plt.imshow(fig[0])
                    plt.axis('off')
                    plt.tight_layout()
                    plt.show()
                elif isinstance(imgs, Summary):
                    visualize_logs([imgs])
                elif isinstance(imgs, (list, tuple)) and imgs and all(isinstance(img, Summary) for img in imgs):
                    # The emptiness check matters: all() over an empty sequence is True, which would otherwise hand
                    # an empty list of experiments to visualize_logs.
                    visualize_logs(imgs)
                else:
                    for idx, img in enumerate(imgs):
                        show_image(img, title="{}_{}".format(key, idx))
                        plt.show()