This class defines a test case that the TestReport trace will use to perform auto-testing.
Parameters:
Name
Type
Description
Default
description
str
A test description.
required
criteria
Callable[..., Union[bool, np.ndarray]]
A function to perform the test. For an aggregate test, criteria needs to return True when the test
passes and False when it fails. For a per-instance test, criteria needs to return a boolean np.ndarray,
where entries show corresponding test results (True if the test of that data instance passes; False if it
fails).
required
aggregate
bool
If True, this test is aggregate type and its criteria function will be examined at epoch_end. If
False, this test is per-instance type and its criteria function will be examined at batch_end.
True
fail_threshold
int
Threshold of failure instance number to judge the per-instance test as failed or passed. If
the failure number is above this value, then the test fails; otherwise it passes. It can only be set when
aggregate is equal to False.
0
Raises:
Type
Description
ValueError
If the user sets `fail_threshold` for an aggregate test.
Source code in fastestimator/fastestimator/trace/io/test_report.py
@traceable()
class TestCase:
    """This class defines the test case that the TestReport trace will take to perform auto-testing.

    Args:
        description: A test description.
        criteria: A function to perform the test. For an aggregate test, `criteria` needs to return True when the test
            passes and False when it fails. For a per-instance test, `criteria` needs to return a boolean np.ndarray,
            where entries show corresponding test results (True if the test of that data instance passes; False if it
            fails).
        aggregate: If True, this test is aggregate type and its `criteria` function will be examined at epoch_end. If
            False, this test is per-instance type and its `criteria` function will be examined at batch_end.
        fail_threshold: Threshold of failure instance number to judge the per-instance test as failed or passed. If
            the failure number is above this value, then the test fails; otherwise it passes. It can only be set when
            `aggregate` is equal to False.

    Raises:
        ValueError: If user set `fail_threshold` for an aggregate test.
    """
    def __init__(self,
                 description: str,
                 criteria: Callable[..., Union[bool, np.ndarray]],
                 aggregate: bool = True,
                 fail_threshold: int = 0) -> None:
        self.description = description
        self.criteria = criteria
        # Parameter names of the criteria function; these keys are later looked up in the Data dict at test time.
        self.criteria_inputs = inspect.signature(criteria).parameters.keys()
        self.aggregate = aggregate
        if self.aggregate:
            # fail_threshold has no meaning for an aggregate test, so any truthy value is a user error.
            # (Fixed grammar of the error message: "an aggregate", not "a aggregate".)
            if fail_threshold:
                raise ValueError("fail_threshold cannot be set in an aggregate test")
        else:
            self.fail_threshold = fail_threshold
        self.result = None  # bool result (aggregate) or list of per-batch ndarrays (per-instance)
        self.input_val = None  # serialized criteria inputs, recorded for aggregate tests only
        self.fail_id = []  # IDs of failing data instances (per-instance tests with a data_id key)
        self.init_result()

    def init_result(self) -> None:
        """Reset the test result so the case can be re-evaluated from a clean state."""
        if self.aggregate:
            self.result = None
            self.input_val = None
        else:
            self.result = []
            self.fail_id = []
@traceable()
class TestReport(Trace):
    """Automate testing and report generation.

    This trace will evaluate all its `test_cases` during test mode and generate a PDF report and a JSON test result.

    Args:
        test_cases: The test(s) to be run.
        save_path: Where to save the outputs.
        test_title: The title of the test, or None to use the experiment name.
        data_id: Data instance ID key. If provided, then per-instance tests will include failing instance IDs.
    """
    def __init__(self,
                 test_cases: Union[TestCase, List[TestCase]],
                 save_path: str,
                 test_title: Optional[str] = None,
                 data_id: Optional[str] = None) -> None:
        self.check_pdf_dependency()
        self.test_title = test_title
        self.report_name = None

        # Split the cases by type and collect every key the criteria functions (and optional ID key) will need.
        self.instance_cases = []
        self.aggregate_cases = []
        self.data_id = data_id
        all_inputs = to_set(self.data_id)
        for case in to_list(test_cases):
            all_inputs.update(case.criteria_inputs)
            if case.aggregate:
                self.aggregate_cases.append(case)
            else:
                self.instance_cases.append(case)

        # Derive a report directory from save_path (file extension, if any, is dropped).
        path = os.path.normpath(save_path)
        path = os.path.abspath(path)
        root_dir = os.path.dirname(path)
        report = os.path.basename(path) or 'report'
        report = report.split('.')[0]
        self.save_dir = os.path.join(root_dir, report)
        self.resource_dir = os.path.join(self.save_dir, "resources")
        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(self.resource_dir, exist_ok=True)

        self.json_summary = {}
        # PDF document related
        self.doc = None
        self.test_id = None

        super().__init__(inputs=all_inputs, mode="test")

    def on_begin(self, data: Data) -> None:
        self._sanitize_report_name()
        self._initialize_json_summary()
        for case in self.instance_cases + self.aggregate_cases:
            case.init_result()

    def on_batch_end(self, data: Data) -> None:
        for case in self.instance_cases:
            result = case.criteria(*[data[var_name] for var_name in case.criteria_inputs])
            # Collapsed the original's two identical TypeError branches into one short-circuiting check.
            if not isinstance(result, np.ndarray) or result.dtype != np.dtype("bool"):
                raise TypeError(f"In test with description '{case.description}': "
                                "Criteria return of per-instance test needs to be ndarray with dtype bool.")
            result = result.reshape(-1)
            case.result.append(result)
            if self.data_id:
                data_id = to_number(data[self.data_id]).reshape((-1, ))
                if data_id.size != result.size:
                    raise ValueError(f"In test with description '{case.description}': "
                                     "Array size of criteria return doesn't match ID array size. Size of criteria "
                                     "return should be equal to the batch_size such that each entry represents the "
                                     "test result of its corresponding data instance.")
                case.fail_id.append(data_id[result == False])  # noqa: E712 (elementwise ndarray comparison)

    def on_epoch_end(self, data: Data) -> None:
        for case in self.aggregate_cases:
            result = case.criteria(*[data[var_name] for var_name in case.criteria_inputs])
            if not isinstance(result, (bool, np.bool_)):
                raise TypeError(f"In test with description '{case.description}': "
                                "Criteria return of aggregate-case test needs to be a bool.")
            # Reuse the already-computed result (the original invoked `criteria` a second time here, which is
            # wasteful and incorrect if `criteria` has side effects or is non-deterministic).
            case.result = result
            case.input_val = {var_name: self._to_serializable(data[var_name]) for var_name in case.criteria_inputs}

    def on_end(self, data: Data) -> None:
        for case in self.instance_cases:
            case_dict = {"test_type": "per-instance", "description": case.description}
            result = np.hstack(case.result)
            fail_num = np.sum(result == False)  # noqa: E712 (elementwise ndarray comparison)
            case_dict["passed"] = self._to_serializable(fail_num <= case.fail_threshold)
            case_dict["fail_threshold"] = case.fail_threshold
            case_dict["fail_number"] = self._to_serializable(fail_num)
            if self.data_id:
                fail_id = np.hstack(case.fail_id)
                case_dict["fail_id"] = self._to_serializable(fail_id)
            self.json_summary["tests"].append(case_dict)

        for case in self.aggregate_cases:
            case_dict = {
                "test_type": "aggregate",
                "description": case.description,
                "passed": self._to_serializable(case.result),
                "inputs": case.input_val
            }
            self.json_summary["tests"].append(case_dict)

        # "execution_time(s)" held the start timestamp until now; convert it to an elapsed duration.
        self.json_summary["execution_time(s)"] = time() - self.json_summary["execution_time(s)"]

        self._dump_json()
        self._init_document()
        self._write_body_content()
        self._dump_pdf()

    def _initialize_json_summary(self) -> None:
        """Initialize json summary."""
        self.json_summary = {
            "title": self.test_title, "timestamp": str(datetime.now()), "execution_time(s)": time(), "tests": []
        }

    def _sanitize_report_name(self) -> None:
        """Sanitize report name and make it class attribute.

        Raises:
            RuntimeError: If a test title was not provided and the user did not set an experiment name.
        """
        exp_name = self.system.summary.name or self.test_title
        if not exp_name:
            raise RuntimeError("TestReport requires an experiment name to be provided in estimator.fit(), or a title")
        # Convert the experiment name to a report name (useful for saving multiple experiments into same directory)
        report_name = "".join('_' if c == ' ' else c for c in exp_name
                              if c.isalnum() or c in (' ', '_')).rstrip("_").lower()
        self.report_name = re.sub('_{2,}', '_', report_name) + "_TestReport"
        if self.test_title is None:
            self.test_title = exp_name

    def _init_document(self) -> None:
        """Initialize latex document."""
        self.doc = self._init_document_geometry()
        self.doc.packages.append(Package(name='placeins', options=['section']))
        self.doc.packages.append(Package(name='float'))
        self.doc.packages.append(Package(name='hyperref', options='hidelinks'))
        self.doc.preamble.append(NoEscape(r'\aboverulesep=0ex'))
        self.doc.preamble.append(NoEscape(r'\belowrulesep=0ex'))
        self.doc.preamble.append(NoEscape(r'\renewcommand{\arraystretch}{1.2}'))
        # new column type for tabularx
        self.doc.preamble.append(NoEscape(r'\newcolumntype{Y}{>{\centering\arraybackslash}X}'))
        self._write_title()
        self._write_toc()

    def _write_title(self) -> None:
        """Write the title content of the file. Override if you want to build on top of base traceability report.
        """
        self.doc.preamble.append(Command('title', self.json_summary["title"]))
        self.doc.preamble.append(Command('author', f"FastEstimator {fe.__version__}"))
        self.doc.preamble.append(Command('date', NoEscape(r'\today')))
        self.doc.append(NoEscape(r'\maketitle'))

    def _write_toc(self) -> None:
        """Write the table of contents. Override if you want to build on top of base traceability report.
        """
        self.doc.append(NoEscape(r'\tableofcontents'))
        self.doc.append(NoEscape(r'\newpage'))

    def _write_body_content(self) -> None:
        """Write the main content of the file. Override if you want to build on top of base traceability report.
        """
        self._document_test_result()

    def _document_test_result(self) -> None:
        """Document test results including test summary, passed tests, and failed tests."""
        self.test_id = 1
        instance_pass_tests, aggregate_pass_tests, instance_fail_tests, aggregate_fail_tests = [], [], [], []
        for test in self.json_summary["tests"]:
            if test["test_type"] == "per-instance" and test["passed"]:
                instance_pass_tests.append(test)
            elif test["test_type"] == "per-instance" and not test["passed"]:
                instance_fail_tests.append(test)
            elif test["test_type"] == "aggregate" and test["passed"]:
                aggregate_pass_tests.append(test)
            elif test["test_type"] == "aggregate" and not test["passed"]:
                aggregate_fail_tests.append(test)

        with self.doc.create(Section("Test Summary")):
            with self.doc.create(Itemize()) as itemize:
                itemize.add_item(
                    escape_latex("Execution time: {:.2f} seconds".format(self.json_summary['execution_time(s)'])))
            with self.doc.create(Table(position='H')) as table:
                table.append(NoEscape(r'\refstepcounter{table}'))
                self._document_summary_table(pass_num=len(instance_pass_tests) + len(aggregate_pass_tests),
                                             fail_num=len(instance_fail_tests) + len(aggregate_fail_tests))

        if instance_fail_tests or aggregate_fail_tests:
            with self.doc.create(Section("Failed Tests")):
                if len(aggregate_fail_tests) > 0:
                    with self.doc.create(Subsection("Failed Aggregate Tests")):
                        self._document_aggregate_table(tests=aggregate_fail_tests)
                if len(instance_fail_tests) > 0:
                    with self.doc.create(Subsection("Failed Per-Instance Tests")):
                        self._document_instance_table(tests=instance_fail_tests, with_id=bool(self.data_id))

        if instance_pass_tests or aggregate_pass_tests:
            with self.doc.create(Section("Passed Tests")):
                if aggregate_pass_tests:
                    with self.doc.create(Subsection("Passed Aggregate Tests")):
                        self._document_aggregate_table(tests=aggregate_pass_tests)
                if instance_pass_tests:
                    with self.doc.create(Subsection("Passed Per-Instance Tests")):
                        self._document_instance_table(tests=instance_pass_tests, with_id=bool(self.data_id))

        self.doc.append(NoEscape(r'\newpage'))  # For QMS report

    def _document_summary_table(self, pass_num: int, fail_num: int) -> None:
        """Document a summary table.

        Args:
            pass_num: Total number of passed tests.
            fail_num: Total number of failed tests.
        """
        with self.doc.create(Tabularx('|Y|Y|Y|', booktabs=True)) as tabular:
            package = Package('seqsplit')
            if package not in tabular.packages:
                tabular.packages.append(package)
            # add table heading
            tabular.add_row(("Total Tests", "Total Passed ", "Total Failed"), strict=False)
            tabular.add_hline()
            tabular.add_row((pass_num + fail_num, pass_num, fail_num), strict=False)

    def _document_instance_table(self, tests: List[Dict[str, Any]], with_id: bool):
        """Document a result table of per-instance tests.

        Args:
            tests: List of corresponding test dictionary to make a table.
            with_id: Whether the test information includes data ID.
        """
        if with_id:
            table_spec = '|c|p{5cm}|c|c|p{5cm}|'
            column_num = 5
        else:
            table_spec = '|c|p{10cm}|c|c|'
            column_num = 4

        with self.doc.create(LongTable(table_spec, pos=['h!'], booktabs=True)) as tabular:
            package = Package('seqsplit')
            if package not in tabular.packages:
                tabular.packages.append(package)

            # add table heading
            row_cells = [
                MultiColumn(size=1, align='|c|', data="Test ID"),
                MultiColumn(size=1, align='c|', data="Test Description"),
                MultiColumn(size=1, align='c|', data="Pass Threshold"),
                MultiColumn(size=1, align='c|', data="Failure Count")
            ]
            if with_id:
                row_cells.append(MultiColumn(size=1, align='c|', data="Failure Data Instance ID"))
            tabular.add_row(row_cells)

            # add table header and footer
            tabular.add_hline()
            tabular.end_table_header()
            tabular.add_hline()
            tabular.add_row((MultiColumn(column_num, align='r', data='Continued on Next Page'), ))
            tabular.add_hline()
            tabular.end_table_footer()
            tabular.end_table_last_footer()

            for idx, test in enumerate(tests):
                if idx > 0:
                    tabular.add_hline()
                des_data = [WrapText(data=x, threshold=27) for x in test["description"].split(" ")]
                row_cells = [
                    self.test_id,
                    IterJoin(data=des_data, token=" "),
                    NoEscape(r'$\le $' + str(test["fail_threshold"])),
                    test["fail_number"]
                ]
                if with_id:
                    id_data = [WrapText(data=x, threshold=27) for x in test["fail_id"]]
                    row_cells.append(IterJoin(data=id_data, token=", "))
                tabular.add_row(row_cells)
                self.test_id += 1

    def _document_aggregate_table(self, tests: List[Dict[str, Any]]) -> None:
        """Document a result table of aggregate tests.

        Args:
            tests: List of corresponding test dictionary to make a table.
        """
        with self.doc.create(LongTable('|c|p{8cm}|p{7.3cm}|', booktabs=True)) as tabular:
            package = Package('seqsplit')
            if package not in tabular.packages:
                tabular.packages.append(package)
            # add table heading
            tabular.add_row((MultiColumn(size=1, align='|c|', data="Test ID"),
                             MultiColumn(size=1, align='c|', data="Test Description"),
                             MultiColumn(size=1, align='c|', data="Input Value")))
            # add table header and footer
            tabular.add_hline()
            tabular.end_table_header()
            tabular.add_hline()
            tabular.add_row((MultiColumn(3, align='r', data='Continued on Next Page'), ))
            tabular.add_hline()
            tabular.end_table_footer()
            tabular.end_table_last_footer()

            for idx, test in enumerate(tests):
                if idx > 0:
                    tabular.add_hline()
                inp_data = [f"{arg}={self.sanitize_value(value)}" for arg, value in test["inputs"].items()]
                inp_data = [WrapText(data=x, threshold=27) for x in inp_data]
                des_data = [WrapText(data=x, threshold=27) for x in test["description"].split(" ")]
                row_cells = [
                    self.test_id,
                    IterJoin(data=des_data, token=" "),
                    IterJoin(data=inp_data, token=escape_latex(", \n")),
                ]
                tabular.add_row(row_cells)
                self.test_id += 1

    def _dump_pdf(self) -> None:
        """Dump PDF summary report."""
        if shutil.which("latexmk") is None and shutil.which("pdflatex") is None:
            # No LaTeX Compiler is available
            self.doc.generate_tex(os.path.join(self.save_dir, self.report_name))
            suffix = '.tex'
        else:
            # Force a double-compile since some compilers will struggle with TOC generation
            self.doc.generate_pdf(os.path.join(self.save_dir, self.report_name), clean_tex=False, clean=False)
            self.doc.generate_pdf(os.path.join(self.save_dir, self.report_name), clean_tex=False)
            suffix = '.pdf'
        print("FastEstimator-TestReport: Report written to {}{}".format(os.path.join(self.save_dir, self.report_name),
                                                                        suffix))

    def _dump_json(self) -> None:
        """Dump JSON file."""
        json_path = os.path.join(self.resource_dir, self.report_name + ".json")
        with open(json_path, 'w') as fp:
            json.dump(self.json_summary, fp, indent=4)

    @staticmethod
    def _to_serializable(obj: Any) -> Union[float, int, list]:
        """Convert to JSON serializable type.

        Args:
            obj: Any object that needs to be converted.

        Return:
            JSON serializable object that essentially is equivalent to input obj.
        """
        if isinstance(obj, np.ndarray):
            if obj.size > 0:
                shape = obj.shape
                obj = obj.reshape((-1, ))
                obj = np.vectorize(TestReport._element_to_serializable)(obj)
                obj = obj.reshape(shape)
            obj = obj.tolist()
        else:
            obj = TestReport._element_to_serializable(obj)
        return obj

    @staticmethod
    def _element_to_serializable(obj: Any) -> Any:
        """Convert to JSON serializable type.

        This function can handle any object type except ndarray.

        Args:
            obj: Any object except ndarray that needs to be converted.

        Return:
            JSON serializable object that essentially is equivalent to input obj.
        """
        if isinstance(obj, bytes):
            obj = obj.decode('utf-8')
        elif isinstance(obj, np.generic):
            obj = obj.item()
        return obj

    @staticmethod
    def check_pdf_dependency() -> None:
        """Check dependency of PDF-generating packages.

        Raises:
            OSError: Some required package has not been installed.
        """
        # Verify that the system locale is functioning correctly
        try:
            locale.getlocale()
        except ValueError:
            raise OSError("Your system locale is not configured correctly. On mac this can be resolved by adding "
                          "'export LC_ALL=en_US.UTF-8' and 'export LANG=en_US.UTF-8' to your ~/.bash_profile")

    @staticmethod
    def sanitize_value(value: Union[int, float]) -> str:
        """Sanitize input value for a better report display.

        Args:
            value: Value to be sanitized.

        Returns:
            Sanitized string of `value`.
        """
        if 1000 > value >= 0.001:
            return f"{value:.3f}"
        else:
            return f"{value:.3e}"

    @staticmethod
    def _init_document_geometry() -> Document:
        """Init geometry setting of the document.

        Return:
            Initialized Document object.
        """
        return Document(geometry_options=['lmargin=2cm', 'rmargin=2cm', 'bmargin=2cm'])
@staticmethoddefcheck_pdf_dependency()->None:"""Check dependency of PDF-generating packages. Raises: OSError: Some required package has not been installed. """# Verify that the system locale is functioning correctlytry:locale.getlocale()exceptValueError:raiseOSError("Your system locale is not configured correctly. On mac this can be resolved by adding \ 'export LC_ALL=en_US.UTF-8' and 'export LANG=en_US.UTF-8' to your ~/.bash_profile")
@staticmethoddefsanitize_value(value:Union[int,float])->str:"""Sanitize input value for a better report display. Args: value: Value to be sanitized. Returns: Sanitized string of `value`. """if1000>value>=0.001:returnf"{value:.3f}"else:returnf"{value:.3e}"