Automatically generate summary reports of the training.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| save_path | str | Where to save the output files. Note that this will generate a new folder with the given name, into which the report and corresponding graphics assets will be written. | required |
| extra_objects | Any | Any extra objects which are not part of the Estimator, but which you want to capture in the summary report. One example could be an extra pipeline which performs pre-processing. | None |
Raises:

| Type | Description |
|---|---|
| OSError | If graphviz is not installed, or if the system locale is not configured correctly. |
Source code in fastestimator/fastestimator/trace/io/traceability.py
@traceable()
class Traceability(Trace):
    """Automatically generate summary reports of the training.

    Args:
        save_path: Where to save the output files. Note that this will generate a new folder with the given name, into
            which the report and corresponding graphics assets will be written.
        extra_objects: Any extra objects which are not part of the Estimator, but which you want to capture in the
            summary report. One example could be an extra pipeline which performs pre-processing.

    Raises:
        OSError: If graphviz is not installed, or if the system locale is not configured correctly.
    """
    def __init__(self, save_path: str, extra_objects: Any = None):
        # Verify that graphviz is available on this machine
        try:
            pydot.Dot.create(pydot.Dot())
        except OSError as err:
            raise OSError(
                "Traceability requires that graphviz be installed. See www.graphviz.org/download for more information."
            ) from err
        # Verify that the system locale is functioning correctly
        try:
            locale.getlocale()
        except ValueError as err:
            raise OSError("Your system locale is not configured correctly. On mac this can be resolved by adding "
                          "'export LC_ALL=en_US.UTF-8' and 'export LANG=en_US.UTF-8' to your ~/.bash_profile") from err
        super().__init__(inputs="*", mode="!infer")  # Claim wildcard inputs to get this trace sorted last
        # Report assets will get saved into a folder for portability
        path = os.path.normpath(save_path)
        path = os.path.abspath(path)
        root_dir = os.path.dirname(path)
        report = os.path.basename(path) or 'report'
        report = report.split('.')[0]
        self.save_dir = os.path.join(root_dir, report)
        self.resource_dir = os.path.join(self.save_dir, 'resources')
        self.report_name = None  # This will be set later by the experiment name
        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(self.resource_dir, exist_ok=True)
        # Other member variables
        self.config_tables = []
        # Extra objects will automatically get included in the report since this Trace is @traceable, so we don't need
        # to do anything with them. Referencing here to stop IDEs from flagging the argument as unused and removing it.
        to_list(extra_objects)
        self.doc = Document()
        self.log_splicer = None

    def on_begin(self, data: Data) -> None:
        """Derive the report name, start capturing logs, and set up the LaTeX document preamble.

        Args:
            data: Unused by this trace.

        Raises:
            RuntimeError: If no experiment name was provided.
        """
        exp_name = self.system.summary.name
        if not exp_name:
            raise RuntimeError("Traceability reports require an experiment name to be provided in estimator.fit()")
        # Convert the experiment name to a report name (useful for saving multiple experiments into same directory)
        report_name = "".join('_' if c == ' ' else c for c in exp_name
                              if c.isalnum() or c in (' ', '_')).rstrip().lower()
        report_name = re.sub('_{2,}', '_', report_name)
        self.report_name = report_name or 'report'
        # Send experiment logs into a file. Use the final report name (including the 'report' fallback) so that the
        # log file never ends up being called just '.txt' when the experiment name contains no usable characters.
        log_path = os.path.join(self.resource_dir, f"{self.report_name}.txt")
        if self.system.mode != 'test':
            # See if there's a RestoreWizard
            restore = False
            for trace in self.system.traces:
                if isinstance(trace, RestoreWizard):
                    restore = trace.should_restore()
            if not restore:
                # If not running in test mode, we need to remove any old log file since it would get appended to
                with contextlib.suppress(FileNotFoundError):
                    os.remove(log_path)
        self.log_splicer = LogSplicer(log_path)
        self.log_splicer.__enter__()
        # Get the initialization summary information for the experiment
        self.config_tables = self.system.summary.system_config
        models = self.system.network.models
        n_floats = len(self.config_tables) + len(models)

        self.doc = self._init_document_geometry()
        # Keep tables/figures in their sections
        self.doc.packages.append(Package(name='placeins', options=['section']))
        self.doc.preamble.append(NoEscape(r'\usetikzlibrary{positioning}'))

        # Fix an issue with too many tables for LaTeX to render
        self.doc.preamble.append(NoEscape(r'\maxdeadcycles=' + str(2 * n_floats + 10)))
        self.doc.preamble.append(NoEscape(r'\extrafloats{' + str(n_floats + 10) + '}'))

        # Manipulate booktab tables so that their horizontal lines don't break
        self.doc.preamble.append(NoEscape(r'\aboverulesep=0ex'))
        self.doc.preamble.append(NoEscape(r'\belowrulesep=0ex'))
        self.doc.preamble.append(NoEscape(r'\renewcommand{\arraystretch}{1.2}'))

        self._write_title()
        self._write_toc()

    def on_end(self, data: Data) -> None:
        """Write the report body, then emit the report as a PDF (or a .tex file if no LaTeX compiler is found).

        Args:
            data: Unused by this trace.
        """
        try:
            self._write_body_content()

            # Need to move the tikz dependency after the xcolor package
            self.doc.dumps_packages()
            packages = self.doc.packages
            tikz = Package(name='tikz')
            packages.discard(tikz)
            packages.add(tikz)

            report_path = os.path.join(self.save_dir, self.report_name)
            if shutil.which("latexmk") is None and shutil.which("pdflatex") is None:
                # No LaTeX Compiler is available
                self.doc.generate_tex(report_path)
                suffix = '.tex'
            else:
                # Force a double-compile since some compilers will struggle with TOC generation
                self.doc.generate_pdf(report_path, clean_tex=False, clean=False)
                self.doc.generate_pdf(report_path, clean_tex=False)
                suffix = '.pdf'
            print("FastEstimator-Traceability: Report written to {}{}".format(report_path, suffix))
        finally:
            # Always release the log splicer, even if report generation failed part way through
            if self.log_splicer is not None:
                self.log_splicer.__exit__()

    def _write_title(self) -> None:
        """Write the title content of the file. Override if you want to build on top of base traceability report.
        """
        self.doc.preamble.append(Command('title', self.system.summary.name))
        self.doc.preamble.append(Command('author', f"FastEstimator {fe.__version__}"))
        self.doc.preamble.append(Command('date', NoEscape(r'\today')))
        self.doc.append(NoEscape(r'\maketitle'))

    def _write_toc(self) -> None:
        """Write the table of contents. Override if you want to build on top of base traceability report.
        """
        self.doc.append(NoEscape(r'\tableofcontents'))
        self.doc.append(NoEscape(r'\newpage'))

    def _write_body_content(self) -> None:
        """Write the main content of the file. Override if you want to build on top of base traceability report.
        """
        self._document_training_graphs()
        self.doc.append(NoEscape(r'\newpage'))
        self._document_fe_graph()
        self.doc.append(NoEscape(r'\newpage'))
        self._document_init_params()
        self._document_models()
        self._document_sys_config()
        self.doc.append(NoEscape(r'\newpage'))

    def _document_training_graphs(self) -> None:
        """Add training graphs to the traceability document.
        """
        image_width = NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio')
        with self.doc.create(Section("Training Graphs")):
            log_path = os.path.join(self.resource_dir, f'{self.report_name}_logs.png')
            visualize_logs(experiments=[self.system.summary],
                           save_path=log_path,
                           verbose=False,
                           ignore_metrics={'num_device', 'logging_interval'})
            with self.doc.create(Figure(position='h!')) as plot:
                plot.add_image(os.path.relpath(log_path, start=self.save_dir), width=image_width)
            for idx, graph in enumerate(self.system.custom_graphs.values()):
                graph_path = os.path.join(self.resource_dir, f'{self.report_name}_custom_graph_{idx}.png')
                visualize_logs(experiments=graph, save_path=graph_path, verbose=False)
                with self.doc.create(Figure(position='h!')) as plot:
                    plot.add_image(os.path.relpath(graph_path, start=self.save_dir), width=image_width)

    def _render_diagram(self, diagram: pydot.Dot) -> None:
        """Convert a pydot `diagram` to LaTeX via dot2tex and append it, centered and size-limited, to the document.

        Args:
            diagram: The diagram to be rendered into the document.
        """
        ltx = d2t.dot2tex(diagram.to_string(), figonly=True)
        args = Arguments(**{'max width': r'\textwidth, max height=0.9\textheight'})
        args.escape = False
        with self.doc.create(Center()):
            with self.doc.create(AdjustBox(arguments=args)) as box:
                box.append(NoEscape(ltx))

    def _document_fe_graph(self) -> None:
        """Add FE execution graphs into the traceability document.
        """
        with self.doc.create(Section("FastEstimator Architecture")):
            for mode in self.system.pipeline.data.keys():
                scheduled_items = self.system.pipeline.get_scheduled_items(
                    mode) + self.system.network.get_scheduled_items(mode) + self.system.traces
                signature_epochs = get_signature_epochs(scheduled_items,
                                                        total_epochs=self.system.epoch_idx,
                                                        mode=mode)
                epochs_with_data = self.system.pipeline.get_epochs_with_data(total_epochs=self.system.epoch_idx,
                                                                             mode=mode)
                if not set(signature_epochs) & epochs_with_data:
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(mode.capitalize())):
                    for epoch in signature_epochs:
                        if epoch not in epochs_with_data:
                            continue
                        self.doc.append(NoEscape(r'\FloatBarrier'))
                        with self.doc.create(
                                Subsubsection(f"Epoch {epoch}",
                                              label=Label(Marker(name=f"{mode}{epoch}", prefix="ssubsec")))):
                            ds_ids = self.system.pipeline.get_ds_ids(epoch=epoch, mode=mode)
                            for ds_id in ds_ids:
                                # Only named datasets get their own paragraph heading
                                with NonContext() if ds_id == '' else self.doc.create(
                                        Paragraph(f"Dataset {ds_id}",
                                                  label=Label(Marker(name=f"{mode}{epoch}{ds_id}", prefix="para")))):
                                    self._render_diagram(self._draw_diagram(mode, epoch, ds_id))
            # infer graph
            self.doc.append(NoEscape(r'\FloatBarrier'))
            with self.doc.create(Subsection('Infer')):
                with NonContext():
                    self._render_diagram(self._draw_infer_diagram())

    def _document_init_params(self) -> None:
        """Add initialization parameters to the traceability document.
        """
        from fastestimator.estimator import Estimator  # Avoid circular import
        with self.doc.create(Section("Parameters")):
            model_ids = {
                FEID(id(model))
                for model in self.system.network.models if isinstance(model, (tf.keras.Model, torch.nn.Module))
            }
            # Locate the datasets in order to provide extra details about them later in the summary
            datasets = {}
            for mode in ['train', 'eval', 'test']:
                objs = to_list(self.system.pipeline.data.get(mode, None))
                idx = 0
                # Index-based loop since Scheduler contents get appended onto the list while it is being traversed
                while idx < len(objs):
                    obj = objs[idx]
                    if obj:
                        feid = FEID(id(obj))
                        if feid not in datasets:
                            datasets[feid] = ({mode}, obj)
                        else:
                            datasets[feid][0].add(mode)
                    if isinstance(obj, Scheduler):
                        objs.extend(obj.get_all_values())
                    idx += 1
            # Parse the config tables. Each group consumes the consecutive run of tables matching its classes.
            table_groups = (
                ((Estimator, BaseNetwork, Pipeline), "Base Classes"),
                (Scheduler, "Schedulers"),
                (Trace, "Traces"),
                (Op, "Operators"),
                (Slicer, "Slicers"),
                ((Dataset, tf.data.Dataset), "Datasets"),
                ((tf.keras.Model, torch.nn.Module), "Models"),
                (types.FunctionType, "Functions"),
                ((np.ndarray, tf.Tensor, tf.Variable, torch.Tensor), "Tensors"),
                (Any, "Miscellaneous"),
            )
            start = 0
            for classes, name in table_groups:
                start = self._loop_tables(start, classes=classes, name=name, model_ids=model_ids, datasets=datasets)
            self.get_parameter_summary()

    def get_parameter_summary(self) -> None:
        """Gather key experiment parameters and write them into the document as a 'Parameter Summary' subsection.

        Each group of parameters is retrieved on a best-effort basis: a failure is printed and the group's name is
        reported, but the remaining parameters are still written.
        """
        parameters = {
            'batch_size': self.system.pipeline.batch_size,
            'epochs': self.system.total_epochs,
            'train_steps_per_epoch': self.system.train_steps_per_epoch,
            'eval_steps_per_epoch': self.system.eval_steps_per_epoch,
        }
        parameter_retrieval_errors = []
        try:
            parameters["no_of_model_parameters"] = {
                model.model_name.lower(): get_model_parameters(model)
                for model in self.system.network.models if isinstance(model, (tf.keras.Model, torch.nn.Module))
            }
        except Exception as e:
            print(e)
            parameter_retrieval_errors.append('no_of_model_parameters')
        try:
            parameters["lr"] = {
                model.model_name.lower(): fe.backend.get_lr(model=model)
                for model in self.system.network.models if isinstance(model, (tf.keras.Model, torch.nn.Module))
            }
        except Exception as e:
            print(e)
            parameter_retrieval_errors.append('lr')
        try:
            parameters['optimizers'] = {
                model.model_name.lower(): get_optimizer_name(model)
                for model in self.system.network.models
            }
        except Exception as e:
            print(e)
            parameter_retrieval_errors.append('optimizers')
        try:
            # Note: no trailing comma here, otherwise the dictionary would get wrapped inside of a 1-tuple
            parameters['lr_scheduler'] = {
                lr_schedule.model.model_name.lower(): inspect.getsource(lr_schedule.lr_fn)
                for lr_schedule in self.system.traces if isinstance(lr_schedule, LRScheduler)
            }
        except Exception as e:
            print(e)
            parameter_retrieval_errors.append('lr_scheduler')
        if parameter_retrieval_errors:
            print("Couldn't retrieve the following parameters due to errors: " + ', '.join(parameter_retrieval_errors))
        parameter_summary_table = SummaryTable('Parameter Summary', kwargs=parameters)
        self.doc.append(NoEscape(r'\FloatBarrier'))
        with self.doc.create(Subsection('Parameter Summary')):
            self._write_tables([parameter_summary_table], set(), {})

    def _loop_tables(self,
                     start: int,
                     classes: Union[type, Tuple[type, ...]],
                     name: str,
                     model_ids: Set[FEID],
                     datasets: Dict[FEID, Tuple[Set[str], Any]]) -> int:
        """Iterate through tables grouping them into subsections.

        Args:
            start: What index to start searching from.
            classes: What classes are acceptable for this subsection.
            name: What to call this subsection.
            model_ids: The ids of any known models.
            datasets: A mapping like {ID: ({modes}, dataset)}. Useful for augmenting the displayed information.

        Returns:
            The new start index after traversing as many spaces as possible along the list of tables.
        """
        stop = start
        while stop < len(self.config_tables):
            # `Any` acts as a catch-all which accepts every remaining table
            if classes is Any or issubclass(self.config_tables[stop].type, classes):
                stop += 1
            else:
                break
        if stop > start:
            self.doc.append(NoEscape(r'\FloatBarrier'))
            with self.doc.create(Subsection(name)):
                self._write_tables(self.config_tables[start:stop], model_ids, datasets)
        return stop

    def _write_tables(self,
                      tables: Union[List[FeSummaryTable], List[SummaryTable]],
                      model_ids: Set[FEID],
                      datasets: Dict[FEID, Tuple[Set[str], Any]]) -> None:
        """Insert a LaTeX representation of a list of tables into the current doc.

        Args:
            tables: The tables to write into the doc.
            model_ids: The ids of any known models.
            datasets: A mapping like {ID: ({modes}, dataset)}. Useful for augmenting the displayed information.
        """
        for tbl in tables:
            name_override = None
            toc_ref = None
            extra_rows = None
            if isinstance(tbl, FeSummaryTable):
                if tbl.fe_id in model_ids:
                    # Link to a later detailed model description
                    name_override = Hyperref(Marker(name=str(tbl.name), prefix="subsec"),
                                             text=NoEscape(r'\textcolor{blue}{') + bold(tbl.name) + NoEscape('}'))
                if tbl.fe_id in datasets:
                    modes, dataset = datasets[tbl.fe_id]
                    # Sort the modes since set iteration order is not stable from one run to the next
                    title = ", ".join(s.capitalize() for s in sorted(modes))
                    name_override = bold(f'{tbl.name} ({title})')
                    # Enhance the dataset summary
                    if isinstance(dataset, FEDataset):
                        extra_rows = list(dataset.summary().__getstate__().items())
                        for idx, (key, val) in enumerate(extra_rows):
                            key = f"{prettify_metric_name(key)}:"
                            if isinstance(val, dict) and val:
                                if isinstance(next(iter(val.values())), (int, float, str, bool, type(None))):
                                    val = jsonpickle.dumps(val, unpicklable=False)
                                else:
                                    subtable = Tabularx('l|X', width_argument=NoEscape(r'\linewidth'))
                                    for k, v in val.items():
                                        if hasattr(v, '__getstate__'):
                                            v = jsonpickle.dumps(v, unpicklable=False)
                                        subtable.add_row((k, v))
                                    # To nest TabularX, have to wrap it in brackets
                                    val = ContainerList(data=[NoEscape("{"), subtable, NoEscape("}")])
                            extra_rows[idx] = (key, val)
            tbl.render_table(self.doc, name_override=name_override, toc_ref=toc_ref, extra_rows=extra_rows)

    def _document_models(self) -> None:
        """Add model summaries to the traceability document.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models, key=lambda m: m.model_name):
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name.capitalize()}", label=model.model_name)):
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary = []
                        model.summary(line_length=92, print_fn=lambda x: summary.append(x))
                        summary = "\n".join(summary)
                        self.doc.append(Verbatim(summary))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            file_path = os.path.join(self.resource_dir,
                                                     "{}_{}.pdf".format(self.report_name, model.model_name))
                            dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
                            # 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
                            # set the limit lower (100 inches) to leave some wiggle room.
                            dot.set('size', '100')
                            dot.write(file_path, format='pdf')
                        except Exception:
                            file_path = None
                            warn(f"Model {model.model_name} could not be visualized by Traceability")
                    elif isinstance(model, torch.nn.Module):
                        if hasattr(model, 'fe_input_spec'):
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            try:
                                model.to(inputs.device)
                            except Exception:
                                file_path = None
                                warn("Model {} could not be visualized by Traceability".format(model.model_name))
                            # Summaries are generated from the underlying module when the model is DataParallel
                            base_model = model.module if isinstance(model, torch.nn.parallel.DataParallel) else model
                            with Suppressor():
                                self.doc.append(
                                    Verbatim(
                                        str(
                                            pms(base_model,
                                                input_data=inputs,
                                                col_names=("output_size", "num_params", "trainable"),
                                                col_width=20,
                                                row_settings=["ascii_only"],
                                                verbose=0))))
                            with self.doc.create(Center()):
                                self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                            # Visual Summary
                            # noinspection PyBroadException
                            try:
                                graph = draw_graph(base_model,
                                                   input_data=inputs,
                                                   device=inputs.device,
                                                   graph_dir='TB',
                                                   expand_nested=True,
                                                   depth=7).visual_graph
                                # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                                # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                                # height, so set the limit lower (100 inches) to leave some wiggle room.
                                graph.attr(size="100,100")
                                graph.attr(margin='0')
                                file_path = graph.render(filename="{}_{}".format(self.report_name, model.model_name),
                                                         directory=self.resource_dir,
                                                         format='pdf',
                                                         cleanup=True)
                            except Exception:
                                file_path = None
                                warn("Model {} could not be visualized by Traceability".format(model.model_name))
                        else:
                            file_path = None
                            self.doc.append("This model was not used by the Network during training.")
                    else:
                        file_path = None
                        self.doc.append(f"Model format: {type(model)} not recognized.")

                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                            fig.add_image(os.path.relpath(file_path, start=self.save_dir),
                                          width=NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio'))
                            fig.add_caption(NoEscape(HrefFEID(FEID(id(model)), model.model_name).dumps()))

    def _document_sys_config(self) -> None:
        """Add a system config summary to the traceability document.
        """
        with self.doc.create(Section("System Configuration")):
            with self.doc.create(Itemize()) as itemize:
                itemize.add_item(escape_latex(f"FastEstimator {fe.__version__}"))
                itemize.add_item(escape_latex(f"Python {platform.python_version()}"))
                itemize.add_item(escape_latex(f"OS: {sys.platform}"))
                cpu = get_cpu_info()
                itemize.add_item(f"CPU Used: {cpu_count()} Threads")
                with self.doc.create(Itemize()) as subitem:
                    subitem.add_item(f"{cpu['brand_raw']} ({cpu['count']} Threads)")
                itemize.add_item(f"GPU(s) Used: {get_num_gpus()}")
                gpus = get_gpu_info()
                if gpus:
                    with self.doc.create(Itemize()) as subitem:
                        for gpu in gpus:
                            subitem.add_item(gpu)
                if fe.fe_deterministic_seed is not None:
                    itemize.add_item(escape_latex(f"Deterministic Seed: {fe.fe_deterministic_seed}"))
            with self.doc.create(LongTable('|lr|', pos=['h!'], booktabs=True)) as tabular:
                tabular.add_row((bold("Module"), bold("Version")))
                tabular.add_hline()
                tabular.end_table_header()
                tabular.add_hline()
                tabular.add_row((MultiColumn(2, align='r', data='Continued on Next Page'), ))
                tabular.add_hline()
                tabular.end_table_footer()
                tabular.end_table_last_footer()
                color = True  # Alternate the row shading
                for name, module in humansorted(sys.modules.items(), key=lambda x: x[0]):
                    if "." in name:
                        continue  # Skip sub-packages
                    if name.startswith("_"):
                        continue  # Skip private packages
                    if isinstance(module, Base):
                        continue  # Skip fake packages we mocked
                    # Prefer `__version__`, falling back to `VERSION`. Modules with neither are not listed.
                    for version_attr in ('__version__', 'VERSION'):
                        if hasattr(module, version_attr):
                            tabular.add_row((escape_latex(name), escape_latex(str(getattr(module, version_attr)))),
                                            color='black!5' if color else 'white')
                            color = not color
                            break

    def _draw_infer_diagram(self) -> pydot.Dot:
        """Draw a summary diagram of the FastEstimator Ops

        Returns:
            A pydot digraph representing the execution flow.
        """
        pipe_ops = get_current_items(self.system.pipeline.ops, run_modes='infer')
        net_ops = get_current_items(self.system.network.ops, run_modes='infer')
        net_slicers = get_current_items(self.system.network.slicers, run_modes='infer')
        net_post = get_current_items(self.system.network.postprocessing, run_modes='infer')
        diagram = pydot.Dot(compound='true')  # Compound lets you draw edges which terminate at sub-graphs
        diagram.set('rankdir', 'TB')
        diagram.set('dpi', 300)
        diagram.set_node_defaults(shape='box')

        def get_in_out(ops):
            """Walk groups of ops in order, tracking which keys are read before being written and which are written.

            Returns a tuple of (input keys, output keys), each in first-seen order.
            """
            inputs = {}
            outputs = {}
            for op in ops:
                for i in op:
                    # Dictionaries are used as ordered sets here
                    op_input = {op_in: None for op_in in i.inputs}
                    op_output = {op_in: None for op_in in i.outputs}
                    for op_in in op_input:
                        if op_in in outputs:
                            del outputs[op_in]
                        else:
                            inputs[op_in] = ''
                    for op_out in op_output:
                        outputs[op_out] = ''
            return list(inputs.keys()), list(outputs.keys())

        # There is no dataset in infer mode, so a stand-in data op supplies the keys which the ops require as input
        input_keys, _ = get_in_out([pipe_ops, net_ops, net_slicers, net_post])
        dataop = DataOp(outputs=input_keys)
        label_last_seen = DefaultKeyDict(lambda k: str(id(dataop)))  # Where was this key last generated
        self._draw_data_node(diagram, dataop, label_last_seen)
        self._draw_subgraph(diagram, diagram, label_last_seen, 'Pipeline', pipe_ops, None)
        self._draw_subgraph(
            diagram,
            diagram,
            label_last_seen,
            'Network',
            net_slicers + net_ops + [_UnslicerWrapper(slicer) for slicer in net_slicers] + net_post,
            None)
        return diagram

    def _draw_data_node(
        self,
        diagram: pydot.Dot,
        dataop: Op,
        label_last_seen: DefaultKeyDict[str, str],
    ):
        """Draw the "Inference Data" node for a data op into an existing `diagram`.

        Args:
            diagram: The diagram into which to add new node.
            dataop: The data op to be represented by the new node.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
        """
        diagram.add_node(pydot.Node(str(id(dataop)), label="Inference Data", texlbl="Inference Data"))
        self._add_edge(diagram, dataop, label_last_seen, None)

    def _draw_diagram(self, mode: str, epoch: int, ds_id: str) -> pydot.Dot:
        """Draw a summary diagram of the FastEstimator Ops / Traces.

        Args:
            mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
            epoch: The epoch to summarize.
            ds_id: The ds_id to summarize.

        Returns:
            A pydot digraph representing the execution flow.
        """
        ds = self.system.pipeline.data[mode][ds_id]
        if isinstance(ds, Scheduler):
            ds = ds.get_current_value(epoch)
        pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch,
                                     ds_id=ds_id) if isinstance(ds, Dataset) else []
        net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch, ds_id=ds_id)
        net_slicers = get_current_items(self.system.network.slicers, run_modes=mode, epoch=epoch, ds_id=ds_id)
        net_post = get_current_items(self.system.network.postprocessing, run_modes=mode, epoch=epoch, ds_id=ds_id)
        traces = sort_traces(get_current_items(self.system.traces, run_modes=mode, epoch=epoch, ds_id=ds_id),
                             ds_ids=self.system.pipeline.get_ds_ids(epoch=epoch, mode=mode))
        diagram = pydot.Dot(compound='true')  # Compound lets you draw edges which terminate at sub-graphs
        diagram.set('rankdir', 'TB')
        diagram.set('dpi', 300)
        diagram.set_node_defaults(shape='box')

        # Make the dataset the first of the pipeline ops
        pipe_ops.insert(0, ds)
        label_last_seen = DefaultKeyDict(lambda k: str(id(ds)))  # Where was this key last generated

        # The suffix stays empty when no batch size is known, so that the cluster is never named 'PipelineNone'
        batch_label = ""
        if isinstance(ds, Dataset):
            if hasattr(ds, "fe_batch") and ds.fe_batch:
                batch_size = ds.fe_batch
            else:
                batch_size = self.system.pipeline.batch_size
                if isinstance(batch_size, Scheduler):
                    batch_size = batch_size.get_current_value(epoch)
                if isinstance(batch_size, dict):
                    batch_size = batch_size[mode]
            if batch_size is not None:
                batch_label = f" (Batch Size: {batch_size})"
        self._draw_subgraph(diagram, diagram, label_last_seen, f'Pipeline{batch_label}', pipe_ops, ds_id)
        self._draw_subgraph(
            diagram,
            diagram,
            label_last_seen,
            'Network',
            net_slicers + net_ops + [_UnslicerWrapper(slicer) for slicer in net_slicers] + net_post,
            ds_id)
        self._draw_subgraph(diagram, diagram, label_last_seen, 'Traces', traces, ds_id)
        return diagram

    def _draw_subgraph(self,
                       progenitor: pydot.Dot,
                       diagram: Union[pydot.Dot, pydot.Cluster],
                       label_last_seen: DefaultKeyDict[str, str],
                       subgraph_name: str,
                       subgraph_ops: List[Union[Op, Trace, Any]],
                       ds_id: Optional[str]) -> None:
        """Draw a subgraph of ops into an existing `diagram`.

        Args:
            progenitor: The very top level diagram onto which Edges should be written.
            diagram: The diagram into which to add new Nodes.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            subgraph_name: The name to be associated with this subgraph.
            subgraph_ops: The ops to be wrapped in this subgraph.
            ds_id: The ds_id to be associated with this subgraph.
        """
        subgraph = pydot.Cluster(style='dashed', graph_name=subgraph_name, color='black')
        subgraph.set('label', subgraph_name)
        subgraph.set('labeljust', 'l')
        for idx, op in enumerate(subgraph_ops):
            node_id = str(id(op))
            self._add_node(progenitor, subgraph, op, label_last_seen, ds_id)
            if isinstance(op, Trace) and idx > 0:
                # Invisibly connect traces in order so that they aren't all just squashed horizontally into the image
                progenitor.add_edge(pydot.Edge(src=str(id(subgraph_ops[idx - 1])), dst=node_id, style='invis'))
        diagram.add_subgraph(subgraph)

    def _add_node(self,
                  progenitor: pydot.Dot,
                  diagram: Union[pydot.Dot, pydot.Cluster],
                  op: Union[Op, Trace, Any],
                  label_last_seen: DefaultKeyDict[str, str],
                  ds_id: Optional[str],
                  edges: bool = True) -> None:
        """Draw a node onto a diagram based on a given op.

        Args:
            progenitor: The very top level diagram onto which Edges should be written.
            diagram: The diagram to be appended to.
            op: The op (or trace) to be visualized.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            ds_id: The ds_id under which the node is currently running.
            edges: Whether to write Edges to/from this Node.
        """
        node_id = str(id(op))
        if isinstance(op, (Sometimes, SometimesT)) and op.op:
            wrapper = pydot.Cluster(style='dotted', color='red', graph_name=str(id(op)))
            wrapper.set('label', f'Sometimes ({op.prob}):')
            wrapper.set('labeljust', 'l')
            # Look up the sources of the extra inputs before the wrapped op is drawn
            edge_srcs = defaultdict(list)
            if op.extra_inputs:
                for inp in op.extra_inputs:
                    if inp == '*':
                        continue
                    edge_srcs[label_last_seen[inp]].append(inp)
            self._add_node(progenitor, wrapper, op.op, label_last_seen, ds_id)
            diagram.add_subgraph(wrapper)
            dst_id = self._get_all_nodes(wrapper)[0].get_name()
            for src, labels in edge_srcs.items():
                progenitor.add_edge(
                    pydot.Edge(src=src, dst=dst_id, lhead=wrapper.get_name(), label=f" {', '.join(labels)} "))
        elif isinstance(op, (OneOf, OneOfT)) and op.ops:
            wrapper = pydot.Cluster(style='dotted', color='darkorchid4', graph_name=str(id(op)))
            wrapper.set('label', 'One Of:')
            wrapper.set('labeljust', 'l')
            # Only the first candidate op gets edges drawn for it
            self._add_node(progenitor, wrapper, op.ops[0], label_last_seen, ds_id, edges=True)
            for sub_op in op.ops[1:]:
                self._add_node(progenitor, wrapper, sub_op, label_last_seen, ds_id, edges=False)
            diagram.add_subgraph(wrapper)
        elif isinstance(op, (Fuse, FuseT)) and op.ops:
            self._draw_subgraph(progenitor, diagram, label_last_seen, 'Fuse:', op.ops, ds_id)
        elif isinstance(op, (Repeat, RepeatT)) and op.op:
            wrapper = pydot.Cluster(style='dotted', color='darkgreen', graph_name=str(id(op)))
            wrapper.set('label', 'Repeat:')
            wrapper.set('labeljust', 'l')
            wrapper.add_node(
                pydot.Node(node_id,
                           label=f'{op.repeat if isinstance(op.repeat, int) else "?"}',
                           shape='doublecircle',
                           width=0.1))
            # dot2tex doesn't seem to handle edge color conversion correctly, so have to set hex color
            progenitor.add_edge(pydot.Edge(src=node_id + ":ne", dst=node_id + ":w", color='#006300'))
            self._add_node(progenitor, wrapper, op.op, label_last_seen, ds_id)
            # Add repeat edges
            edge_srcs = defaultdict(list)
            for out in op.outputs:
                if out in op.inputs and out not in op.repeat_inputs:
                    edge_srcs[label_last_seen[out]].append(out)
            for inp in op.repeat_inputs:
                edge_srcs[label_last_seen[inp]].append(inp)
            for src, labels in edge_srcs.items():
                progenitor.add_edge(pydot.Edge(src=src, dst=node_id, constraint=False, label=f" {', '.join(labels)} "))
            diagram.add_subgraph(wrapper)
        else:
            if isinstance(op, ModelOp):
                label = f"{op.__class__.__name__} ({FEID(id(op))}): {op.model.model_name}"
                model_ref = Hyperref(Marker(name=str(op.model.model_name), prefix='subsec'),
                                     text=NoEscape(r'\textcolor{blue}{') + bold(op.model.model_name) +
                                     NoEscape('}')).dumps()
                texlbl = f"{HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()}: {model_ref}"
            elif isinstance(op, Batch):
                label = f"{op.__class__.__name__} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__, color='purple').dumps()
                if op.batch_size is not None:
                    diagram.set_label(f"Pipeline (Batch Size: {op.batch_size})")
                # From now on, a key without a recorded source is attributed to the Batch node, and an edge from the
                # previous default source to the Batch node is drawn (or relabeled) for it.
                label_last_seen.factory = functools.partial(self._delayed_edge,
                                                            progenitor=progenitor,
                                                            old_source=label_last_seen.factory(''),
                                                            new_source=str(id(op)))
            elif isinstance(op, Slicer):
                label = f"{op.__class__.__name__} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__, color='purple').dumps()
                if op.minibatch_size:
                    diagram.set_label(f"Network (Slices Per Step: {op.minibatch_size})")
            elif isinstance(op, _UnslicerWrapper):
                # The corresponding Slicer is already in the graph earlier
                label = None
                texlbl = None
            else:
                label = f"{op.__class__.__name__} ({FEID(id(op))})"
                texlbl = HrefFEID(FEID(id(op)), name=op.__class__.__name__).dumps()
            if label is not None:
                diagram.add_node(pydot.Node(node_id, label=label, texlbl=texlbl))
            if isinstance(op, (Op, Trace, Slicer, _UnslicerWrapper)) and edges:
                # Need the instance check since subgraph_ops might contain a tf dataset or torch data loader
                self._add_edge(progenitor, op, label_last_seen, ds_id)

    @staticmethod
    def _delayed_edge(key: str, progenitor: pydot.Dot, old_source: str, new_source: str) -> str:
        """Draw a specific edge between two nodes, modifying the old label if applicable.

        Args:
            key: The key associated with the edge.
            progenitor: The parent cluster.
            old_source: The edge source.
            new_source: The edge sink.

        Returns:
            The `new_source`.
        """
        edge = progenitor.get_edge(old_source, new_source)
        if edge:
            # The edge already exists, so just extend its label with the new key
            edge = edge[0]
            edge.set_label(f"{edge.get_label()}, {key}")
        else:
            progenitor.add_edge(pydot.Edge(src=old_source, dst=new_source, label=f" {key}"))
        return new_source

    def _add_edge(self,
                  progenitor: pydot.Dot,
                  op: Union[Trace, Op, Slicer, _UnslicerWrapper],
                  label_last_seen: Dict[str, str],
                  ds_id: Optional[str]):
        """Draw edges into a given Node.

        Args:
            progenitor: The very top level diagram onto which Edges should be written.
            op: The op (or trace) to be visualized.
            label_last_seen: A mapping of {data_dict_key: node_id} indicating the last node which generated the key.
            ds_id: The ds_id under which the node is currently running.
        """
        # Edges for an unslicer are attached to the node of its corresponding slicer
        node_id = str(id(op.slicer)) if isinstance(op, _UnslicerWrapper) else str(id(op))
        edge_srcs = defaultdict(list)
        global_ds_ids = {key for vals in self.system.pipeline.data.values() for key in vals.keys() if key is not None}

        # Figure out which keys flow into this node
        if isinstance(op, Batch):
            # A Batch op takes every key which has been seen so far
            inputs = list(label_last_seen.keys())
        elif isinstance(op, Slicer):
            inputs = op.slice_inputs
        elif isinstance(op, _UnslicerWrapper):
            inputs = op.slicer.unslice_inputs
        else:
            inputs = op.inputs
        for inp in inputs:
            if inp == '*':
                continue
            # Keys may look like 'key|ds_id'. The appended '|' guarantees that the unpacking always succeeds.
            _, candidate_id, *_ = f"{inp}|".split('|')
            if candidate_id in global_ds_ids and candidate_id != ds_id and ds_id is not None:
                continue  # Skip inputs which will be provided in other ds_id plots
            edge_srcs[label_last_seen[inp]].append(inp)
        for src, labels in edge_srcs.items():
            progenitor.add_edge(pydot.Edge(src=src, dst=node_id, label=f" {', '.join(labels)} "))

        # Record this node as the most recent source of each key which it generates
        if isinstance(op, Batch):
            outputs = list(label_last_seen.keys())
        elif isinstance(op, Trace):
            outputs = op.get_outputs(ds_ids=ds_id)
        elif isinstance(op, Slicer):
            outputs = op.slice_inputs
        elif isinstance(op, _UnslicerWrapper):
            outputs = op.slicer.unslice_inputs
        else:
            outputs = op.outputs
        for out in outputs:
            label_last_seen[out] = node_id

    @staticmethod
    def _get_all_nodes(diagram: Union[pydot.Dot, pydot.Cluster]) -> List[pydot.Node]:
        """Recursively search through a `diagram` looking for Nodes.

        Args:
            diagram: The diagram to be inspected.

        Returns:
            All of the Nodes available within this diagram and its child diagrams.
        """
        nodes = diagram.get_nodes()
        for subgraph in diagram.get_subgraphs():
            nodes.extend(Traceability._get_all_nodes(subgraph))
        return nodes

    @staticmethod
    def _init_document_geometry() -> Document:
        """Init geometry setting of the document.

        Return:
            Initialized Document object.
        """
        return Document(geometry_options=['lmargin=2cm', 'rmargin=2cm', 'bmargin=2cm'])