@traceable()
class TensorBoard(Trace):
"""Output data for use in TensorBoard.
Note that if you plan to run a TensorBoard server simultaneously with training, you may want to consider using the
--reload_multifile=true flag until the multi-writer use case is supported:
https://github.com/tensorflow/tensorboard/issues/1063
Args:
log_dir: Path of the directory where the log files to be parsed by TensorBoard should be saved.
update_freq: 'batch', 'epoch', an integer, or strings like '10s' or '15e'. When using 'batch', the losses and
metrics are written to TensorBoard after every batch; 'epoch' does the same after every epoch. If an integer is
given, say 1000, the trace will write the metrics and losses to TensorBoard every 1000 steps. You can also use
strings like '8s' to indicate every 8 steps or '5e' to indicate every 5 epochs (see the example below). Note
that writing too frequently to TensorBoard can slow down your training. You can use None to disable updating,
but this will make the trace mostly useless.
write_graph: Whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph
is set to True.
write_images: If a string or list of strings is provided, the corresponding keys will be written to TensorBoard
images.
weight_histogram_freq: Frequency at which to write weight histograms for the layers of the model. Same argument
format as `update_freq`.
paint_weights: If True, the system will attempt to visualize model weights as images.
embedding_freq: Frequency at which to write the `write_embeddings` to TensorBoard. Same argument format as
`update_freq`.
write_embeddings: If a string or list of strings is provided, the corresponding keys will be written to
TensorBoard embeddings.
embedding_labels: Keys corresponding to label information for the `write_embeddings`.
embedding_images: Keys corresponding to raw images to be associated with the `write_embeddings`.
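Example:
    An illustrative sketch of attaching this trace to an estimator; the `pipeline` and `network` objects
    below are assumed to already exist and are not part of this trace:
        import fastestimator as fe
        tb = TensorBoard(log_dir="logs", update_freq="100s", weight_histogram_freq="epoch")
        estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=2, traces=[tb])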
"""
writer: _BaseWriter
# TODO - support for per-instance tracking
def __init__(self,
log_dir: str = 'logs',
update_freq: Union[None, int, str] = 100,
write_graph: bool = True,
write_images: Union[None, str, List[str]] = None,
weight_histogram_freq: Union[None, int, str] = None,
paint_weights: bool = False,
embedding_freq: Union[None, int, str] = 'epoch',
write_embeddings: Union[None, str, List[str]] = None,
embedding_labels: Union[None, str, List[str]] = None,
embedding_images: Union[None, str, List[str]] = None) -> None:
super().__init__(inputs=["*"] + to_list(write_images) + to_list(write_embeddings) + to_list(embedding_labels) +
to_list(embedding_images))
self.root_log_dir = log_dir
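# parse_freq converts arguments such as 100, '8s', or '5e' into an object exposing `freq` and `is_step`,
# which the step/epoch checks below rely on.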
self.update_freq = parse_freq(update_freq)
self.write_graph = write_graph
self.painted_graphs = set()
self.write_images = to_set(write_images)
self.histogram_freq = parse_freq(weight_histogram_freq)
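# Painting weights requires histogram data, so if no histogram frequency was requested then default to
# computing weight histograms once per epoch.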
if paint_weights and self.histogram_freq.freq == 0:
self.histogram_freq.is_step = False
self.histogram_freq.freq = 1
self.paint_weights = paint_weights
if write_embeddings is None and embedding_labels is None and embedding_images is None:
# Speed up if-check short-circuiting later
embedding_freq = None
self.embedding_freq = parse_freq(embedding_freq)
write_embeddings = to_list(write_embeddings)
embedding_labels = to_list(embedding_labels)
if embedding_labels:
assert len(embedding_labels) == len(write_embeddings), \
f"Expected {len(write_embeddings)} embedding_labels keys, but recieved {len(embedding_labels)}. Use \
None to pad out the list if you have labels for only a subset of all embeddings."
else:
embedding_labels = [None for _ in range(len(write_embeddings))]
embedding_images = to_list(embedding_images)
if embedding_images:
assert len(embedding_images) == len(write_embeddings), \
f"Expected {len(write_embeddings)} embedding_images keys, but recieved {len(embedding_images)}. Use \
None to pad out the list if you have labels for only a subset of all embeddings."
else:
embedding_images = [None for _ in range(len(write_embeddings))]
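# Pair each embedding feature key with its (possibly None) label key and image key.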
self.write_embeddings = list(zip(write_embeddings, embedding_labels, embedding_images))
self.collected_embeddings = defaultdict(list)
def on_begin(self, data: Data) -> None:
print("FastEstimator-Tensorboard: writing logs to {}".format(
os.path.abspath(os.path.join(self.root_log_dir, self.system.experiment_time))))
self.writer = _TfWriter(self.root_log_dir, self.system.experiment_time, self.system.network) if isinstance(
self.system.network, TFNetwork) else _TorchWriter(
self.root_log_dir, self.system.experiment_time, self.system.network)
if self.write_graph and self.system.global_step == 1:
self.painted_graphs = set()
def on_batch_end(self, data: Data) -> None:
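# Re-draw the model graph(s) whenever the set of models in the current execution context differs from
# those already painted.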
if self.write_graph and self.system.network.ctx_models.symmetric_difference(self.painted_graphs):
self.writer.write_epoch_models(mode=self.system.mode, epoch=self.system.epoch_idx)
self.painted_graphs = self.system.network.ctx_models
# Collect embeddings if present in batch but viewing per epoch. Don't aggregate during training though
if self.system.mode != 'train' and self.embedding_freq.freq and not self.embedding_freq.is_step and \
self.system.epoch_idx % self.embedding_freq.freq == 0:
for elem in self.write_embeddings:
name, lbl, img = elem
if name in data:
self.collected_embeddings[name].append((data.get(name), data.get(lbl), data.get(img)))
# Handle embeddings if viewing per step
if self.embedding_freq.freq and self.embedding_freq.is_step and \
self.system.global_step % self.embedding_freq.freq == 0:
self.writer.write_embeddings(
mode=self.system.mode,
step=self.system.global_step,
embeddings=filter(
lambda x: x[1] is not None,
map(lambda t: (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])), self.write_embeddings)))
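# Step-wise scalars, images, and weight histograms are only written during training.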
if self.system.mode != 'train':
return
if self.histogram_freq.freq and self.histogram_freq.is_step and \
self.system.global_step % self.histogram_freq.freq == 0:
self.writer.write_weights(mode=self.system.mode,
models=self.system.network.models,
step=self.system.global_step,
visualize=self.paint_weights)
if self.update_freq.freq and self.update_freq.is_step and self.system.global_step % self.update_freq.freq == 0:
self.writer.write_scalars(mode=self.system.mode,
step=self.system.global_step,
scalars=filter(lambda x: is_number(x[1]), data.items()))
self.writer.write_images(
mode=self.system.mode,
step=self.system.global_step,
images=filter(lambda x: x[1] is not None, map(lambda y: (y, data.get(y)), self.write_images)))
def on_epoch_end(self, data: Data) -> None:
if self.system.mode == 'train' and self.histogram_freq.freq and not self.histogram_freq.is_step and \
self.system.epoch_idx % self.histogram_freq.freq == 0:
self.writer.write_weights(mode=self.system.mode,
models=self.system.network.models,
step=self.system.global_step,
visualize=self.paint_weights)
# Write out any embeddings which were aggregated over batches
for name, val_list in self.collected_embeddings.items():
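# If any batch was missing a given component (feature, label, or image), drop that component entirely
# rather than writing a partial embedding.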
embeddings = None if any(x[0] is None for x in val_list) else concat([x[0] for x in val_list])
labels = None if any(x[1] is None for x in val_list) else concat([x[1] for x in val_list])
imgs = None if any(x[2] is None for x in val_list) else concat([x[2] for x in val_list])
self.writer.write_embeddings(mode=self.system.mode,
step=self.system.global_step,
embeddings=[(name, embeddings, labels, imgs)])
self.collected_embeddings.clear()
# Get any embeddings which were generated externally on epoch end
if self.embedding_freq.freq and (self.embedding_freq.is_step
or self.system.epoch_idx % self.embedding_freq.freq == 0):
self.writer.write_embeddings(
mode=self.system.mode,
step=self.system.global_step,
embeddings=filter(
lambda x: x[1] is not None,
map(lambda t: (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])), self.write_embeddings)))
if self.update_freq.freq and (self.update_freq.is_step or self.system.epoch_idx % self.update_freq.freq == 0):
self.writer.write_scalars(mode=self.system.mode,
step=self.system.global_step,
scalars=filter(lambda x: is_number(x[1]), data.items()))
self.writer.write_images(
mode=self.system.mode,
step=self.system.global_step,
images=filter(lambda x: x[1] is not None, map(lambda y: (y, data.get(y)), self.write_images)))
def on_end(self, data: Data) -> None:
self.writer.close()