@traceable()
class ImageSaver(Trace):
    """A trace that saves images to the disk.

    Images are written from both the `on_epoch_end` and `on_end` hooks, for every configured input key that is present
    in the `Data` passed to that hook.

    Args:
        inputs: Key(s) of images to be saved.
        save_dir: The directory into which to write the images. If None (the default), the current working directory
            at the time this Trace is constructed is used. The directory is created when images are first written if
            it does not already exist.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
    """
    def __init__(self,
                 inputs: Union[str, Sequence[str]],
                 save_dir: Union[None, str] = None,
                 mode: Union[None, str, Iterable[str]] = ("eval", "test")) -> None:
        super().__init__(inputs=inputs, mode=mode)
        # Resolve the default here rather than in the signature: a signature default of os.getcwd() is evaluated only
        # once, at import time, and would ignore any later change of the working directory.
        self.save_dir = os.getcwd() if save_dir is None else save_dir

    def on_epoch_end(self, data: Data) -> None:
        self._save_images(data)

    def on_end(self, data: Data) -> None:
        self._save_images(data)

    def _image_path(self, key: str, suffix: str = "") -> str:
        """Build the .png output path for `key` in the current system mode and epoch.

        Args:
            key: The data key being saved; used as the filename prefix.
            suffix: Extra text inserted before the ".png" extension (e.g. "_elem_3").

        Returns:
            `save_dir`/`{key}_{mode}_epoch_{epoch_idx}{suffix}.png`
        """
        file_name = "{}_{}_epoch_{}{}.png".format(key, self.system.mode, self.system.epoch_idx, suffix)
        return os.path.join(self.save_dir, file_name)

    def _save_images(self, data: Data) -> None:
        """Write every configured input found in `data` to `self.save_dir`.

        Dispatch by value type:
            * `Display`: saved via its own `show` method.
            * `Summary`, or a list/tuple made up entirely of `Summary` objects: plotted with `visualize_logs`.
            * Anything else: iterated, and each element is rendered with `ImageDisplay` into its own file (presumably
              a batch of individual images -- the element type is not checked here).

        Args:
            data: The data passed to the calling Trace hook; only keys listed in `self.inputs` are considered.
        """
        for key in self.inputs:
            if key not in data:
                continue
            imgs = data[key]
            # Make sure the target directory exists before writing. An empty save_dir means "relative to the current
            # directory" (os.path.join("", name) == name), and os.makedirs("") would raise, so skip it in that case.
            if self.save_dir:
                os.makedirs(self.save_dir, exist_ok=True)
            if isinstance(imgs, Display):
                im_path = self._image_path(key)
                imgs.show(save_path=im_path, verbose=False)
                print("FastEstimator-ImageSaver: saved image to {}".format(im_path))
            elif isinstance(imgs, Summary):
                im_path = self._image_path(key)
                visualize_logs([imgs], save_path=im_path, verbose=False)
                print("FastEstimator-ImageSaver: saved image to {}".format(im_path))
            elif isinstance(imgs, (list, tuple)) and all(isinstance(img, Summary) for img in imgs):
                im_path = self._image_path(key)
                visualize_logs(imgs, save_path=im_path, verbose=False)
                print("FastEstimator-ImageSaver: saved image to {}".format(im_path))
            else:
                # One file per element, distinguished by an "_elem_{idx}" suffix.
                for idx, img in enumerate(imgs):
                    f = ImageDisplay(image=img, title=key)
                    im_path = self._image_path(key, suffix="_elem_{}".format(idx))
                    f.show(save_path=im_path, verbose=False)
                    print("FastEstimator-ImageSaver: saved image to {}".format(im_path))