@traceable()
class Traceability(Trace):
"""Automatically generate summary reports of the training.
Args:
save_path: Where to save the output files. Note that this will generate a new folder with the given name, into
which the report and corresponding graphics assets will be written.
extra_objects: Any extra objects which are not part of the Estimator, but which you want to capture in the
summary report. One example could be an extra pipeline which performs pre-processing.
Raises:
OSError: If graphviz is not installed.
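Example:
A minimal usage sketch (assumes pre-built `pipeline` and `network` objects; per on_begin below, the
experiment name passed to fit() is required and becomes the report name):
```python
est = fe.Estimator(pipeline=pipeline, network=network, epochs=2,
                   traces=Traceability(save_path="outputs/report"))
est.fit("experiment_1")
```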
"""
def __init__(self, save_path: str, extra_objects: Any = None):
# Verify that graphviz is available on this machine
try:
pydot.Dot.create(pydot.Dot())
except OSError:
raise OSError(
"Traceability requires that graphviz be installed. See www.graphviz.org/download for more information.")
# Verify that the system locale is functioning correctly
try:
locale.getlocale()
except ValueError:
raise OSError("Your system locale is not configured correctly. On mac this can be resolved by adding \
'export LC_ALL=en_US.UTF-8' and 'export LANG=en_US.UTF-8' to your ~/.bash_profile")
super().__init__(inputs="*", mode="!infer") # Claim wildcard inputs to get this trace sorted last
# Report assets will get saved into a folder for portability
path = os.path.normpath(save_path)
path = os.path.abspath(path)
root_dir = os.path.dirname(path)
report = os.path.basename(path) or 'report'
report = report.split('.')[0]
self.save_dir = os.path.join(root_dir, report)
self.resource_dir = os.path.join(self.save_dir, 'resources')
self.report_name = None # This will be set later by the experiment name
os.makedirs(self.save_dir, exist_ok=True)
os.makedirs(self.resource_dir, exist_ok=True)
# Other member variables
self.config_tables = []
# Extra objects will automatically get included in the report since this Trace is @traceable, so we don't need
# to do anything with them. Referencing here to stop IDEs from flagging the argument as unused and removing it.
to_list(extra_objects)
self.doc = Document()
self.log_splicer = None
def on_begin(self, data: Data) -> None:
exp_name = self.system.summary.name
if not exp_name:
raise RuntimeError("Traceability reports require an experiment name to be provided in estimator.fit()")
# Convert the experiment name to a report name (useful for saving multiple experiments into the same directory)
report_name = "".join('_' if c == ' ' else c for c in exp_name
if c.isalnum() or c in (' ', '_')).rstrip().lower()
report_name = re.sub('_{2,}', '_', report_name)
self.report_name = report_name or 'report'
# Send experiment logs into a file
log_path = os.path.join(self.resource_dir, f"{report_name}.txt")
if self.system.mode != 'test':
# See if there's a RestoreWizard
restore = False
for trace in self.system.traces:
if isinstance(trace, RestoreWizard):
restore = trace.should_restore()
if not restore:
# If not running in test mode, we need to remove any old log file since it would get appended to
with contextlib.suppress(FileNotFoundError):
os.remove(log_path)
self.log_splicer = LogSplicer(log_path)
self.log_splicer.__enter__()
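# The splicer is exited later in on_end() once the report has been written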
# Get the initialization summary information for the experiment
self.config_tables = self.system.summary.system_config
models = self.system.network.models
n_floats = len(self.config_tables) + len(models)
self.doc = self._init_document_geometry()
# Keep tables/figures in their sections
self.doc.packages.append(Package(name='placeins', options=['section']))
self.doc.preamble.append(NoEscape(r'\usetikzlibrary{positioning}'))
# Raise \maxdeadcycles and pre-allocate float registers so LaTeX doesn't fail with "Too many unprocessed
# floats" when the report contains more tables/figures than it can place in a single pass
self.doc.preamble.append(NoEscape(r'\maxdeadcycles=' + str(2 * n_floats + 10)))
self.doc.preamble.append(NoEscape(r'\extrafloats{' + str(n_floats + 10) + '}'))
# Manipulate booktab tables so that their horizontal lines don't break
self.doc.preamble.append(NoEscape(r'\aboverulesep=0ex'))
self.doc.preamble.append(NoEscape(r'\belowrulesep=0ex'))
self.doc.preamble.append(NoEscape(r'\renewcommand{\arraystretch}{1.2}'))
self._write_title()
self._write_toc()
def on_end(self, data: Data) -> None:
self._write_body_content()
# Need to move the tikz dependency after the xcolor package
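# (dumps_packages() is called here for its side effect: it propagates package dependencies from child
# containers up into self.doc.packages so that the set can be re-ordered below)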
self.doc.dumps_packages()
packages = self.doc.packages
tikz = Package(name='tikz')
packages.discard(tikz)
packages.add(tikz)
if shutil.which("latexmk") is None and shutil.which("pdflatex") is None:
# No LaTeX Compiler is available
self.doc.generate_tex(os.path.join(self.save_dir, self.report_name))
suffix = '.tex'
else:
# Force a double-compile since some compilers will struggle with TOC generation
self.doc.generate_pdf(os.path.join(self.save_dir, self.report_name), clean_tex=False, clean=False)
self.doc.generate_pdf(os.path.join(self.save_dir, self.report_name), clean_tex=False)
suffix = '.pdf'
print("FastEstimator-Traceability: Report written to {}{}".format(os.path.join(self.save_dir, self.report_name),
suffix))
self.log_splicer.__exit__()
def _write_title(self) -> None:
"""Write the title content of the file. Override if you want to build on top of base traceability report.
"""
self.doc.preamble.append(Command('title', self.system.summary.name))
self.doc.preamble.append(Command('author', f"FastEstimator {fe.__version__}"))
self.doc.preamble.append(Command('date', NoEscape(r'\today')))
self.doc.append(NoEscape(r'\maketitle'))
def _write_toc(self) -> None:
"""Write the table of contents. Override if you want to build on top of base traceability report.
"""
self.doc.append(NoEscape(r'\tableofcontents'))
self.doc.append(NoEscape(r'\newpage'))
def _write_body_content(self) -> None:
"""Write the main content of the file. Override if you want to build on top of base traceability report.
"""
self._document_training_graphs()
self.doc.append(NoEscape(r'\newpage'))
self._document_fe_graph()
self.doc.append(NoEscape(r'\newpage'))
self._document_init_params()
self._document_models()
self._document_sys_config()
self.doc.append(NoEscape(r'\newpage'))
def _document_training_graphs(self) -> None:
"""Add training graphs to the traceability document.
"""
with self.doc.create(Section("Training Graphs")):
log_path = os.path.join(self.resource_dir, f'{self.report_name}_logs.png')
visualize_logs(experiments=[self.system.summary],
save_path=log_path,
verbose=False,
ignore_metrics={'num_device', 'logging_interval'})
with self.doc.create(Figure(position='h!')) as plot:
plot.add_image(os.path.relpath(log_path, start=self.save_dir),
width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
for idx, graph in enumerate(self.system.custom_graphs.values()):
graph_path = os.path.join(self.resource_dir, f'{self.report_name}_custom_graph_{idx}.png')
visualize_logs(experiments=graph, save_path=graph_path, verbose=False)
with self.doc.create(Figure(position='h!')) as plot:
plot.add_image(os.path.relpath(graph_path, start=self.save_dir),
width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
def _document_fe_graph(self) -> None:
"""Add FE execution graphs into the traceability document.
"""
with self.doc.create(Section("FastEstimator Architecture")):
for mode in self.system.pipeline.data.keys():
scheduled_items = self.system.pipeline.get_scheduled_items(
mode) + self.system.network.get_scheduled_items(mode) + self.system.traces
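# Signature epochs are the epochs at which the scheduled items could change behavior, so one diagram per
# signature epoch is sufficient to capture every distinct topology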
signature_epochs = get_signature_epochs(scheduled_items, total_epochs=self.system.epoch_idx, mode=mode)
epochs_with_data = self.system.pipeline.get_epochs_with_data(total_epochs=self.system.epoch_idx,
mode=mode)
if set(signature_epochs) & epochs_with_data:
self.doc.append(NoEscape(r'\FloatBarrier'))
with self.doc.create(Subsection(mode.capitalize())):
for epoch in signature_epochs:
if epoch not in epochs_with_data:
continue
self.doc.append(NoEscape(r'\FloatBarrier'))
with self.doc.create(
Subsubsection(f"Epoch {epoch}",
label=Label(Marker(name=f"{mode}{epoch}", prefix="ssubsec")))):
ds_ids = self.system.pipeline.get_ds_ids(epoch=epoch, mode=mode)
for ds_id in ds_ids:
with NonContext() if ds_id == '' else self.doc.create(
Paragraph(f"Dataset {ds_id}",
label=Label(Marker(name=f"{mode}{epoch}{ds_id}",
prefix="para")))):
diagram = self._draw_diagram(mode, epoch, ds_id)
ltx = d2t.dot2tex(diagram.to_string(), figonly=True)
args = Arguments(**{'max width': r'\textwidth, max height=0.9\textheight'})
args.escape = False
with self.doc.create(Center()):
with self.doc.create(AdjustBox(arguments=args)) as box:
box.append(NoEscape(ltx))
def _document_init_params(self) -> None:
"""Add initialization parameters to the traceability document.
"""
from fastestimator.estimator import Estimator # Avoid circular import
with self.doc.create(Section("Parameters")):
model_ids = {
FEID(id(model))
for model in self.system.network.models if isinstance(model, (tf.keras.Model, torch.nn.Module))
}
# Locate the datasets in order to provide extra details about them later in the summary
datasets = {}
for mode in ['train', 'eval', 'test']:
objs = to_list(self.system.pipeline.data.get(mode, None))
idx = 0
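# Walk the list breadth-first, appending any Scheduler's values so that scheduled datasets are also captured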
while idx < len(objs):
obj = objs[idx]
if obj:
feid = FEID(id(obj))
if feid not in datasets:
datasets[feid] = ({mode}, obj)
else:
datasets[feid][0].add(mode)
if isinstance(obj, Scheduler):
objs.extend(obj.get_all_values())
idx += 1
# Parse the config tables
start = 0
start = self._loop_tables(start,
classes=(Estimator, BaseNetwork, Pipeline),
name="Base Classes",
model_ids=model_ids,
datasets=datasets)
start = self._loop_tables(start,
classes=Scheduler,
name="Schedulers",
model_ids=model_ids,
datasets=datasets)
start = self._loop_tables(start, classes=Trace, name="Traces", model_ids=model_ids, datasets=datasets)
start = self._loop_tables(start, classes=Op, name="Operators", model_ids=model_ids, datasets=datasets)
start = self._loop_tables(start, classes=Slicer, name="Slicers", model_ids=model_ids, datasets=datasets)
start = self._loop_tables(start,
classes=(Dataset, tf.data.Dataset),
name="Datasets",
model_ids=model_ids,
datasets=datasets)
start = self._loop_tables(start,
classes=(tf.keras.Model, torch.nn.Module),
name="Models",
model_ids=model_ids,
datasets=datasets)
start = self._loop_tables(start,
classes=types.FunctionType,
name="Functions",
model_ids=model_ids,
datasets=datasets)
start = self._loop_tables(start,
classes=(np.ndarray, tf.Tensor, tf.Variable, torch.Tensor),
name="Tensors",
model_ids=model_ids,
datasets=datasets)
self._loop_tables(start, classes=Any, name="Miscellaneous", model_ids=model_ids, datasets=datasets)
def _loop_tables(self,
start: int,
classes: Union[type, Tuple[type, ...]],
name: str,
model_ids: Set[FEID],
datasets: Dict[FEID, Tuple[Set[str], Any]]) -> int:
"""Iterate through tables grouping them into subsections.
Args:
start: What index to start searching from.
classes: What classes are acceptable for this subsection.
name: What to call this subsection.
model_ids: The ids of any known models.
datasets: A mapping like {ID: ({modes}, dataset)}. Useful for augmenting the displayed information.
Returns:
The new start index after consuming as many consecutive matching tables as possible from the list.
"""
stop = start
while stop < len(self.config_tables):
if classes == Any or issubclass(self.config_tables[stop].type, classes):
stop += 1
else:
break
if stop > start:
self.doc.append(NoEscape(r'\FloatBarrier'))
with self.doc.create(Subsection(name)):
self._write_tables(self.config_tables[start:stop], model_ids, datasets)
return stop
def _write_tables(self,
tables: List[FeSummaryTable],
model_ids: Set[FEID],
datasets: Dict[FEID, Tuple[Set[str], Any]]) -> None:
"""Insert a LaTeX representation of a list of tables into the current doc.
Args:
tables: The tables to write into the doc.
model_ids: The ids of any known models.
datasets: A mapping like {ID: ({modes}, dataset)}. Useful for augmenting the displayed information.
"""
for tbl in tables:
name_override = None
toc_ref = None
extra_rows = None
if tbl.fe_id in model_ids:
# Link to a later detailed model description
name_override = Hyperref(Marker(name=str(tbl.name), prefix="subsec"),
text=NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}'))
if tbl.fe_id in datasets:
modes, dataset = datasets[tbl.fe_id]
title = ", ".join([s.capitalize() for s in modes])
name_override = bold(f'{tbl.name} ({title})')
# Enhance the dataset summary
if isinstance(dataset, FEDataset):
extra_rows = list(dataset.summary().__getstate__().items())
for idx, (key, val) in enumerate(extra_rows):
key = f"{prettify_metric_name(key)}:"
if isinstance(val, dict) and val:
if isinstance(list(val.values())[0], (int, float, str, bool, type(None))):
val = jsonpickle.dumps(val, unpicklable=False)
else:
subtable = Tabularx('l|X', width_argument=NoEscape(r'\linewidth'))
for k, v in val.items():
if hasattr(v, '__getstate__'):
v = jsonpickle.dumps(v, unpicklable=False)
subtable.add_row((k, v))
# To nest a Tabularx it has to be wrapped in braces
subtable = ContainerList(data=[NoEscape("{"), subtable, NoEscape("}")])
val = subtable
extra_rows[idx] = (key, val)
tbl.render_table(self.doc, name_override=name_override, toc_ref=toc_ref, extra_rows=extra_rows)
def _document_models(self) -> None:
"""Add model summaries to the traceability document.
"""
with self.doc.create(Section("Models")):
for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
continue
self.doc.append(NoEscape(r'\FloatBarrier'))
with self.doc.create(Subsection(f"{model.model_name.capitalize()}", label=model.model_name)):
if isinstance(model, tf.keras.Model):
# Text Summary
summary = []
model.summary(line_length=92, print_fn=lambda x: summary.append(x))
summary = "\n".join(summary)
self.doc.append(Verbatim(summary))
with self.doc.create(Center()):
self.doc.append(HrefFEID(FEID(id(model)), model.model_name))
# Visual Summary
# noinspection PyBroadException
try:
file_path = os.path.join(self.resource_dir,
"{}_{}.pdf".format(self.report_name, model.model_name))
dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
# LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
# 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
# set the limit lower (100 inches) to leave some wiggle room.
dot.set('size', '100')
dot.write(file_path, format='pdf')
except Exception:
file_path = None
warn(f"Model {model.model_name} could not be visualized by Traceability")
elif isinstance(model, torch.nn.Module):
if hasattr(model, 'fe_input_spec'):
# Text Summary
# noinspection PyUnresolvedReferences
inputs = model.fe_input_spec.get_dummy_input()
with Suppressor():
self.doc.append(
Verbatim(
str(
pms(model.module if self.system.num_devices > 1 else model,
input_data=inputs,
col_names=("output_size", "num_params", "trainable"),
col_width=20,
row_settings=["ascii_only"],
verbose=0))))
with self.doc.create(Center()):
self.doc.append(HrefFEID(FEID(id(model)), model.model_name))
# Visual Summary
# noinspection PyBroadException
try:
model.to(inputs.device)
graph = draw_graph(
model.module if isinstance(model, torch.nn.parallel.DataParallel) else model,
input_data=inputs,
device=inputs.device,
graph_dir='TB',
expand_nested=True,
depth=7).visual_graph
# LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
# than 226 inches. However, the 'size' parameter doesn't account for the whole node
# height, so set the limit lower (100 inches) to leave some wiggle room.
graph.attr(size="100,100")
graph.attr(margin='0')
file_path = graph.render(filename="{}_{}".format(self.report_name, model.model_name),
directory=self.resource_dir,
format='pdf',
cleanup=True)
except Exception:
file_path = None
warn("Model {} could not be visualized by Traceability".format(model.model_name))
else:
file_path = None
self.doc.append("This model was not used by the Network during training.")
else:
file_path = None
self.doc.append(f"Model format: {type(model)} not recognized.")
if file_path:
with self.doc.create(Figure(position='ht!')) as fig:
fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
fig.add_image(os.path.relpath(file_path, start=self.save_dir),
width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))
def _document_sys_config(self) -> None:
"""Add a system config summary to the traceability document.
"""
with self.doc.create(Section("System Configuration")):
with self.doc.create(Itemize()) as itemize:
itemize.add_item(escape_latex(f"FastEstimator {fe.__version__}"))
itemize.add_item(escape_latex(f"Python {platform.python_version()}"))
itemize.add_item(escape_latex(f"OS: {sys.platform}"))
cpu = get_cpu_info()
itemize.add_item(f"CPU Used: {cpu_count()} Threads")
with self.doc.create(Itemize()) as subitem:
subitem.add_item(f"{cpu['brand_raw']} ({cpu['count']} Threads)")
itemize.add_item(f"GPU(s) Used: {get_num_gpus()}")
gpus = get_gpu_info()
if gpus:
with self.doc.create(Itemize()) as subitem:
for gpu in gpus:
subitem.add_item(gpu)
if fe.fe_deterministic_seed is not None:
itemize.add_item(escape_latex(f"Deterministic Seed: {fe.fe_deterministic_seed}"))
with self.doc.create(LongTable('|lr|', pos=['h!'], booktabs=True)) as tabular:
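# LongTable requires its header and footer rows to be declared up front, before any data rows are added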
tabular.add_row((bold("Module"), bold("Version")))
tabular.add_hline()
tabular.end_table_header()
tabular.add_hline()
tabular.add_row((MultiColumn(2, align='r', data='Continued on Next Page'), ))
tabular.add_hline()
tabular.end_table_footer()
tabular.end_table_last_footer()
color = True
for name, module in humansorted(sys.modules.items(), key=lambda x: x[0]):
if "." in name:
continue # Skip sub-packages
if name.startswith("_"):
continue # Skip private packages
if isinstance(module, Base):
continue # Skip fake packages we mocked
if hasattr(module, '__version__'):
tabular.add_row((escape_latex(name), escape_latex(str(module.__version__))),
color='black!5' if color else 'white')
color = not color
elif hasattr(module, 'VERSION'):
tabular.add_row((escape_latex(name), escape_latex(str(module.VERSION))),
color='black!5' if color else 'white')
color = not color
def _draw_diagram(self, mode: str, epoch: int, ds_id: str) -> pydot.Dot:
"""Draw a summary diagram of the FastEstimator Ops / Traces.
Args:
mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
epoch: The epoch to summarize.
ds_id: The ds_id to summarize.
Returns:
A pydot digraph representing the execution flow.
"""
ds = self.system.pipeline.data[mode][ds_id]
if isinstance(ds, Scheduler):
ds = ds.get_current_value(epoch)
pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch, ds_id=ds_id) if isinstance(
ds, Dataset) else []
net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch, ds_id=ds_id)
net_slicers = get_current_items(self.system.network.slicers, run_modes=mode, epoch=epoch, ds_id=ds_id)
net_post = get_current_items(self.system.network.postprocessing, run_modes=mode, epoch=epoch, ds_id=ds_id)
traces = sort_traces(get_current_items(self.system.traces, run_modes=mode, epoch=epoch, ds_id=ds_id),
ds_ids=self.system.pipeline.get_ds_ids(epoch=epoch, mode=mode))
diagram = pydot.Dot(compound='true') # Compound lets you draw edges which terminate at sub-graphs
diagram.set('rankdir', 'TB')
diagram.set('dpi', 300)
diagram.set_node_defaults(shape='box')
# Make the dataset the first of the pipeline ops
pipe_ops.insert(0, ds)
label_last_seen = DefaultKeyDict(lambda k: str(id(ds))) # Where was this key last generated
batch_size = ""
if isinstance(ds, Dataset):
if hasattr(ds, "fe_batch") and ds.fe_batch:
batch_size = ds.fe_batch
else:
batch_size = self.system.pipeline.batch_size
if isinstance(batch_size, Scheduler):
batch_size = batch_size.get_current_value(epoch)
if isinstance(batch_size, dict):
batch_size = batch_size[mode]
if batch_size is not None:
batch_size = f" (Batch Size: {batch_size})"
self._draw_subgraph(diagram, diagram, label_last_seen, f'Pipeline{batch_size}', pipe_ops, ds_id)
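# Slicers are drawn before the network ops and their corresponding un-slice steps after them, mirroring
# the actual execution order within a step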
self._draw_subgraph(diagram,
diagram,
label_last_seen,
'Network',
net_slicers + net_ops + [_UnslicerWrapper(slicer) for slicer in net_slicers] + net_post,
ds_id)
self._draw_subgraph(diagram, diagram, label_last_seen, 'Traces', traces, ds_id)
return diagram
def _draw_subgraph(self,
progenitor: pydot.Dot,
diagram: Union[pydot.Dot, pydot.Cluster],
label_last_seen: DefaultKeyDict[str, str],
subgraph_name: str,
subgraph_ops: List[Union[Op, Trace, Any]],
ds_id: Optional[str]) -> None:
"""Draw a subgraph of ops into an existing `diagram`.
Args:
progenitor: The very top level diagram onto which Edges should be written.
diagram: The diagram into which to add new Nodes.
label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
subgraph_name: The name to be associated with this subgraph.
subgraph_ops: The ops to be wrapped in this subgraph.
ds_id: The ds_id to be associated with this subgraph.
"""
subgraph = pydot.Cluster(style='dashed', graph_name=subgraph_name, color='black')
subgraph.set('label', subgraph_name)
subgraph.set('labeljust', 'l')
for idx, op in enumerate(subgraph_ops):
node_id = str(id(op))
self._add_node(progenitor, subgraph, op, label_last_seen, ds_id)
if isinstance(op, Trace) and idx > 0:
# Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
progenitor.add_edge(pydot.Edge(src=str(id(subgraph_ops[idx - 1])), dst=node_id, style='invis'))
diagram.add_subgraph(subgraph)
def _add_node(self,
progenitor: pydot.Dot,
diagram: Union[pydot.Dot, pydot.Cluster],
op: Union[Op, Trace, Any],
label_last_seen: DefaultKeyDict[str, str],
ds_id: Optional[str],
edges: bool = True) -> None:
"""Draw a node onto a diagram based on a given op.
Args:
progenitor: The very top level diagram onto which Edges should be written.
diagram: The diagram to be appended to.
op: The op (or trace) to be visualized.
label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
ds_id: The ds_id under which the node is currently running.
edges: Whether to write Edges to/from this Node.
"""
node_id = str(id(op))
if isinstance(op, (Sometimes, SometimesT)) and op.op:
wrapper = pydot.Cluster(style='dotted', color='red', graph_name=str(id(op)))
wrapper.set('label', f'Sometimes ({op.prob}):')
wrapper.set('labeljust', 'l')
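# Collect any inputs consumed by the Sometimes wrapper itself (beyond those of the wrapped op) so that
# edges for them can be drawn into the cluster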
edge_srcs = defaultdict(lambda: [])
if op.extra_inputs:
for inp in op.extra_inputs:
if inp == '*':
continue
edge_srcs[label_last_seen[inp]].append(inp)
self._add_node(progenitor, wrapper, op.op, label_last_seen, ds_id)
diagram.add_subgraph(wrapper)
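# Graphviz edges must terminate on a concrete node, so pick one inside the cluster and use lhead to clip
# the edge at the cluster boundary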
dst_id = self._get_all_nodes(wrapper)[0].get_name()
for src, labels in edge_srcs.items():
progenitor.add_edge(
pydot.Edge(src=src, dst=dst_id, lhead=wrapper.get_name(), label=f" {', '.join(labels)} "))
elif isinstance(op, (OneOf, OneOfT)) and op.ops:
wrapper = pydot.Cluster(style='dotted', color='darkorchid4', graph_name=str(id(op)))
wrapper.set('label', 'One Of:')
wrapper.set('labeljust', 'l')
self._add_node(progenitor, wrapper, op.ops[0], label_last_seen, ds_id, edges=True)
for sub_op in op.ops[1:]:
self._add_node(progenitor, wrapper, sub_op, label_last_seen, ds_id, edges=False)
diagram.add_subgraph(wrapper)
elif isinstance(op, (Fuse, FuseT)) and op.ops:
self._draw_subgraph(progenitor, diagram, label_last_seen, 'Fuse:', op.ops, ds_id)
elif isinstance(op, (Repeat, RepeatT)) and op.op:
wrapper = pydot.Cluster(style='dotted', color='darkgreen', graph_name=str(id(op)))
wrapper.set('label', 'Repeat:')
wrapper.set('labeljust', 'l')
wrapper.add_node(
pydot.Node(node_id,
label=f'{op.repeat if isinstance(op.repeat, int) else "?"}',
shape='doublecircle',
width=0.1))
# dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
progenitor.add_edge(pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300'))
self._add_node(progenitor, wrapper, op.op, label_last_seen, ds_id)
# Add repeat edges
edge_srcs = defaultdict(lambda: [])
for out in op.outputs:
if out in op.inputs and out not in op.repeat_inputs:
edge_srcs[label_last_seen[out]].append(out)
for inp in op.repeat_inputs:
edge_srcs[label_last_seen[inp]].append(inp)
for src, labels in edge_srcs.items():
progenitor.add_edge(pydot.Edge(src=src, dst=node_id, constraint=False, label=f" {', '.join(labels)} "))
diagram.add_subgraph(wrapper)
else:
if isinstance(op, ModelOp):
label = f"{op.__class__.__name__} ({FEID(id(op))}): {op.model.model_name}"
model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'),
text=NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) +
NoEscape('}')).dumps()
texlbl = f"{HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()}: {model_ref}"
elif isinstance(op, Batch):
label = f"{op.__class__.__name__} ({FEID(id(op))})"
texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__, color='purple').dumps()
if op.batch_size is not None:
diagram.set_label(f"Pipeline (Batch Size: {op.batch_size})")
label_last_seen.factory = functools.partial(self._delayed_edge,
progenitor=progenitor,
old_source=label_last_seen.factory(''),
new_source=str(id(op)))
elif isinstance(op, Slicer):
label = f"{op.__class__.__name__} ({FEID(id(op))})"
texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__, color='purple').dumps()
if op.minibatch_size:
diagram.set_label(f"Network (Slices Per Step: {op.minibatch_size})")
elif isinstance(op, _UnslicerWrapper):
# The corresponding Slicer is already in the graph earlier
label = None
texlbl = None
else:
label = f"{op.__class__.__name__} ({FEID(id(op))})"
texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()
if label is not None:
diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
if isinstance(op, (Op, Trace, Slicer, _UnslicerWrapper)) and edges:
# Need the instance check since subgraph_ops might contain a tf dataset or torch data loader
self._add_edge(progenitor, op, label_last_seen, ds_id)
@staticmethod
def _delayed_edge(key: str, progenitor: pydot.Dot, old_source: str, new_source: str) -> str:
"""Draw a specific edge between two nodes, modifying the old label if applicable.
Args:
key: The key associated with the edge.
progenitor: The parent cluster.
old_source: The edge source.
new_source: The edge sink (destination).
Returns:
The `new_source`.
"""
edge = progenitor.get_edge(old_source, new_source)
if edge:
edge = edge[0]
label = f"{edge.get_label()}, {key}"
edge.set_label(label)
else:
progenitor.add_edge(pydot.Edge(src=old_source, dst=new_source, label=f" {key}"))
return new_source
def _add_edge(self,
progenitor: pydot.Dot,
op: Union[Trace, Op, Slicer, _UnslicerWrapper],
label_last_seen: Dict[str, str],
ds_id: Optional[str]):
"""Draw edges into a given Node.
Args:
progenitor: The very top level diagram onto which Edges should be written.
op: The op (or trace) to be visualized.
label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
ds_id: The ds_id under which the node is currently running.
"""
node_id = str(id(op.slicer)) if isinstance(op, _UnslicerWrapper) else str(id(op))
edge_srcs = defaultdict(lambda: [])
global_ds_ids = {key for vals in self.system.pipeline.data.values() for key in vals.keys() if key is not None}
for inp in label_last_seen.keys() if isinstance(op, Batch) else op.slice_inputs if isinstance(
op, Slicer) else op.slicer.unslice_inputs if isinstance(op, _UnslicerWrapper) else op.inputs:
if inp == '*':
continue
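# Inputs may be ds_id-qualified like "key|ds_id"; appending '|' before splitting guarantees that the
# unpack succeeds even for unqualified keys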
_, candidate_id, *_ = f"{inp}|".split('|')
if candidate_id in global_ds_ids and candidate_id != ds_id:
continue # Skip inputs which will be provided in other ds_id plots
edge_srcs[label_last_seen[inp]].append(inp)
for src, labels in edge_srcs.items():
progenitor.add_edge(pydot.Edge(src=src, dst=node_id, label=f" {', '.join(labels)} "))
outputs = op.get_outputs(ds_ids=ds_id) if isinstance(op, Trace) else op.slice_inputs if isinstance(
op, Slicer) else op.slicer.unslice_inputs if isinstance(op, _UnslicerWrapper) else op.outputs
for out in label_last_seen.keys() if isinstance(op, Batch) else outputs:
label_last_seen[out] = node_id
@staticmethod
def _get_all_nodes(diagram: Union[pydot.Dot, pydot.Cluster]) -> List[pydot.Node]:
"""Recursively search through a `diagram` looking for Nodes.
Args:
diagram: The diagram to be inspected.
Returns:
All of the Nodes available within this diagram and its child diagrams.
"""
nodes = diagram.get_nodes()
for subgraph in diagram.get_subgraphs():
nodes.extend(Traceability._get_all_nodes(subgraph))
return nodes
@staticmethod
def _init_document_geometry() -> Document:
"""Init geometry setting of the document.
Return:
Initialized Document object.
"""
return Document(geometry_options=['lmargin=2cm', 'rmargin=2cm', 'bmargin=2cm'])