
TensorBoard

Bases: Trace

Output data for use in TensorBoard.

Note that if you plan to run a TensorBoard server simultaneously with training, you may want to use the --reload_multifile=true flag until their multi-writer use case is finished: https://github.com/tensorflow/tensorboard/issues/1063
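For example, assuming TensorBoard is installed and logs are written to the default 'logs' directory, the server might be launched as:

```shell
# Poll all event files rather than only the most recently written one,
# so logs from multiple concurrent writers stay up to date.
tensorboard --logdir logs --reload_multifile=true
```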

Parameters:

log_dir (str, default: 'logs')
    Path of the directory where the log files to be parsed by TensorBoard should be saved.

update_freq (Union[None, int, str], default: 100)
    'step', 'epoch', an integer, or strings like '10s', '15e'. When using 'step', writes the losses and metrics to TensorBoard after every step. The same applies to 'epoch'. If using an integer, say 1000, the trace will write the metrics and losses to TensorBoard every 1000 steps. You can also use strings like '8s' to indicate every 8 steps or '5e' to indicate every 5 epochs. Note that writing too frequently to TensorBoard can slow down your training. You can use None to disable updating, but this will make the trace mostly useless.

write_graph (bool, default: True)
    Whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph is set to True.

write_images (Union[None, str, List[str]], default: None)
    If a string or list of strings is provided, the corresponding keys will be written to TensorBoard as images.

weight_histogram_freq (Union[None, int, str], default: None)
    Frequency at which to compute weight histograms for the layers of the model. Same argument format as update_freq.

paint_weights (bool, default: False)
    If True, the system will attempt to visualize model weights as an image.

write_embeddings (Union[None, str, List[str]], default: None)
    If a string or list of strings is provided, the corresponding keys will be written to TensorBoard as embeddings.

embedding_labels (Union[None, str, List[str]], default: None)
    Keys corresponding to label information for the write_embeddings.

embedding_images (Union[None, str, List[str]], default: None)
    Keys corresponding to raw images to be associated with the write_embeddings.
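The accepted update_freq formats can be illustrated with a small standalone sketch of the parsing rules. This mirrors the `_parse_freq` helper shown in the source below, but is not part of the library's public API:

```python
import re
from collections import namedtuple

# is_step: True means the frequency counts steps, False means epochs.
Freq = namedtuple('Freq', ['is_step', 'freq'])

def parse_freq(freq):
    """Convert None, an int, or strings like '10s' / '5e' into a Freq tuple."""
    if freq is None:
        return Freq(False, 0)  # disabled
    if isinstance(freq, int):
        if freq < 1:
            raise ValueError(f"frequency must be a positive integer but got {freq}")
        return Freq(True, freq)  # bare integer counts steps
    if freq in {'step', 's'}:
        return Freq(True, 1)
    if freq in {'epoch', 'e'}:
        return Freq(False, 1)
    parts = re.match(r"^([0-9]+)([se])$", freq)
    if parts is None:
        raise ValueError(f"frequency must be formatted like <int><s|e> but got {freq}")
    return Freq(parts[2] == 's', int(parts[1]))

print(parse_freq('10s'))  # Freq(is_step=True, freq=10)
print(parse_freq('5e'))   # Freq(is_step=False, freq=5)
print(parse_freq(1000))   # Freq(is_step=True, freq=1000)
```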
Source code in fastestimator/fastestimator/trace/io/tensorboard.py
class TensorBoard(Trace):
    """Output data for use in TensorBoard.

    Note that if you plan to run a TensorBoard server simultaneously with training, you may want to use the
    --reload_multifile=true flag until their multi-writer use case is finished:
    https://github.com/tensorflow/tensorboard/issues/1063

    Args:
        log_dir: Path of the directory where the log files to be parsed by TensorBoard should be saved.
        update_freq: 'step', 'epoch', an integer, or strings like '10s', '15e'. When using 'step', writes the losses
            and metrics to TensorBoard after every step. The same applies to 'epoch'. If using an integer, say 1000,
            the trace will write the metrics and losses to TensorBoard every 1000 steps. You can also use strings
            like '8s' to indicate every 8 steps or '5e' to indicate every 5 epochs. Note that writing too frequently to
            TensorBoard can slow down your training. You can use None to disable updating, but this will make the trace
            mostly useless.
        write_graph: Whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph
            is set to True.
        write_images: If a string or list of strings is provided, the corresponding keys will be written to TensorBoard
            images.
        weight_histogram_freq: Frequency at which to compute weight histograms for the layers of the model. Same
            argument format as `update_freq`.
        paint_weights: If True the system will attempt to visualize model weights as an image.
        write_embeddings: If a string or list of strings is provided, the corresponding keys will be written to
            TensorBoard embeddings.
        embedding_labels: Keys corresponding to label information for the `write_embeddings`.
        embedding_images: Keys corresponding to raw images to be associated with the `write_embeddings`.
    """
    Freq = namedtuple('Freq', ['is_step', 'freq'])
    writer: _BaseWriter

    def __init__(self,
                 log_dir: str = 'logs',
                 update_freq: Union[None, int, str] = 100,
                 write_graph: bool = True,
                 write_images: Union[None, str, List[str]] = None,
                 weight_histogram_freq: Union[None, int, str] = None,
                 paint_weights: bool = False,
                 write_embeddings: Union[None, str, List[str]] = None,
                 embedding_labels: Union[None, str, List[str]] = None,
                 embedding_images: Union[None, str, List[str]] = None) -> None:
        super().__init__(inputs="*")
        self.root_log_dir = log_dir
        self.update_freq = self._parse_freq(update_freq)
        self.write_graph = write_graph
        self.painted_graphs = set()
        self.write_images = to_set(write_images)
        self.histogram_freq = self._parse_freq(weight_histogram_freq)
        if paint_weights and self.histogram_freq.freq == 0:
            # Freq is a namedtuple and therefore immutable, so build a new one rather than assigning to its fields
            self.histogram_freq = self.Freq(is_step=False, freq=1)
        self.paint_weights = paint_weights
        write_embeddings = to_list(write_embeddings)
        embedding_labels = to_list(embedding_labels)
        if embedding_labels:
            assert len(embedding_labels) == len(write_embeddings), \
                f"Expected {len(write_embeddings)} embedding_labels keys, but received {len(embedding_labels)}. Use " \
                "None to pad out the list if you have labels for only a subset of all embeddings."
        else:
            embedding_labels = [None for _ in range(len(write_embeddings))]
        embedding_images = to_list(embedding_images)
        if embedding_images:
            assert len(embedding_images) == len(write_embeddings), \
                f"Expected {len(write_embeddings)} embedding_images keys, but received {len(embedding_images)}. Use " \
                "None to pad out the list if you have images for only a subset of all embeddings."
        else:
            embedding_images = [None for _ in range(len(write_embeddings))]
        self.write_embeddings = [(feature, label, img_label) for feature,
                                 label,
                                 img_label in zip(write_embeddings, embedding_labels, embedding_images)]

    def _parse_freq(self, freq: Union[None, str, int]) -> Freq:
        """A helper function to convert string based frequency inputs into epochs or steps

        Args:
            freq: One of either None, "step", "epoch", "#s", "#e", or #, where # is an integer.

        Returns:
            A `Freq` object recording whether the trace should run on an epoch basis or a step basis, as well as the
            frequency with which it should run.
        """
        if freq is None:
            return self.Freq(False, 0)
        if isinstance(freq, int):
            if freq < 1:
                raise ValueError(f"Tensorboard frequency argument must be a positive integer but got {freq}")
            return self.Freq(True, freq)
        if isinstance(freq, str):
            if freq in {'step', 's'}:
                return self.Freq(True, 1)
            if freq in {'epoch', 'e'}:
                return self.Freq(False, 1)
            parts = re.match(r"^([0-9]+)([se])$", freq)
            if parts is None:
                raise ValueError(f"Tensorboard frequency argument must be formatted like <int><s|e> but got {freq}")
            freq = int(parts[1])
            if freq < 1:
                raise ValueError(f"Tensorboard frequency argument must be a positive integer but got {freq}")
            return self.Freq(parts[2] == 's', freq)
        else:
            raise ValueError(f"Unrecognized type passed as Tensorboard frequency: {type(freq)}")

    def on_begin(self, data: Data) -> None:
        print("FastEstimator-Tensorboard: writing logs to {}".format(
            os.path.abspath(os.path.join(self.root_log_dir, self.system.experiment_time))))
        self.writer = _TfWriter(self.root_log_dir, self.system.experiment_time, self.system.network) if isinstance(
            self.system.network, TFNetwork) else _TorchWriter(
                self.root_log_dir, self.system.experiment_time, self.system.network)
        if self.write_graph and self.system.global_step == 1:
            self.painted_graphs = set()

    def on_batch_end(self, data: Data) -> None:
        if self.write_graph and self.system.network.epoch_models.symmetric_difference(self.painted_graphs):
            self.writer.write_epoch_models(mode=self.system.mode, data=data)
            self.painted_graphs = self.system.network.epoch_models
        if self.system.mode != 'train':
            return
        if self.histogram_freq.freq and self.histogram_freq.is_step and \
                self.system.global_step % self.histogram_freq.freq == 0:
            self.writer.write_weights(mode=self.system.mode,
                                      models=self.system.network.models,
                                      step=self.system.global_step,
                                      visualize=self.paint_weights)
        if self.update_freq.freq and self.update_freq.is_step and self.system.global_step % self.update_freq.freq == 0:
            self.writer.write_scalars(mode=self.system.mode,
                                      step=self.system.global_step,
                                      scalars=filter(lambda x: is_number(x[1]), data.items()))
            self.writer.write_images(
                mode=self.system.mode,
                step=self.system.global_step,
                images=filter(lambda x: x[1] is not None, map(lambda y: (y, data.get(y)), self.write_images)))
            self.writer.write_embeddings(
                mode=self.system.mode,
                step=self.system.global_step,
                embeddings=filter(
                    lambda x: x[1] is not None,
                    map(lambda t: (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])), self.write_embeddings)))

    def on_epoch_end(self, data: Data) -> None:
        if self.system.mode == 'train' and self.histogram_freq.freq and not self.histogram_freq.is_step and \
                self.system.epoch_idx % self.histogram_freq.freq == 0:
            self.writer.write_weights(mode=self.system.mode,
                                      models=self.system.network.models,
                                      step=self.system.global_step,
                                      visualize=self.paint_weights)
        if self.update_freq.freq and (self.update_freq.is_step or self.system.epoch_idx % self.update_freq.freq == 0):
            self.writer.write_scalars(mode=self.system.mode,
                                      step=self.system.global_step,
                                      scalars=filter(lambda x: is_number(x[1]), data.items()))
            self.writer.write_images(
                mode=self.system.mode,
                step=self.system.global_step,
                images=filter(lambda x: x[1] is not None, map(lambda y: (y, data.get(y)), self.write_images)))
            self.writer.write_embeddings(
                mode=self.system.mode,
                step=self.system.global_step,
                embeddings=filter(
                    lambda x: x[1] is not None,
                    map(lambda t: (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])), self.write_embeddings)))

    def on_end(self, data: Data) -> None:
        self.writer.close()
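The embedding bookkeeping in `__init__` can be sketched in isolation: each write_embeddings key is zipped with its label and image keys, with missing lists padded by None. The function and key names below are hypothetical, chosen only to illustrate the pairing logic:

```python
def pair_embeddings(write_embeddings, embedding_labels=None, embedding_images=None):
    """Zip embedding keys with their label/image keys, padding absent lists with None."""
    write_embeddings = [write_embeddings] if isinstance(write_embeddings, str) else list(write_embeddings or [])

    def normalize(keys, name):
        # Accept a single string, a list, or None; enforce matching length when provided.
        keys = [keys] if isinstance(keys, str) else list(keys or [])
        if keys and len(keys) != len(write_embeddings):
            raise ValueError(f"Expected {len(write_embeddings)} {name} keys, but received {len(keys)}")
        return keys or [None] * len(write_embeddings)

    labels = normalize(embedding_labels, 'embedding_labels')
    images = normalize(embedding_images, 'embedding_images')
    return list(zip(write_embeddings, labels, images))

print(pair_embeddings(['z1', 'z2'], embedding_labels=['y1', 'y2']))
# [('z1', 'y1', None), ('z2', 'y2', None)]
```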