Estimator is the highest level class within FastEstimator. It is the class which is invoked to actually train
(estimator.fit) or test (estimator.test) models. It wraps Pipeline, Network, and Trace objects together and
defines the whole optimization process.
Parameters:
Name
Type
Description
Default
pipeline
Pipeline
An fe.Pipeline object that defines the data processing workflow.
required
network
BaseNetwork
An fe.Network object that contains models and other training graph definitions.
required
epochs
int
The number of epochs to run.
required
max_train_steps_per_epoch
Optional[int]
Training will complete after n steps even if loader is not yet exhausted. If None,
all data will be used.
None
max_eval_steps_per_epoch
Optional[int]
Evaluation will complete after n steps even if loader is not yet exhausted. If None,
all data will be used.
None
traces
Union[None, Trace, Scheduler[Trace], Iterable[Union[Trace, Scheduler[Trace]]]]
What Traces to run during training. If None, only the system's default Traces will be included.
None
log_steps
Optional[int]
Frequency (in steps) for printing log messages. 0 to disable all step-based printing (though epoch
information will still print). None to completely disable printing.
100
monitor_names
Union[None, str, Iterable[str]]
Additional keys from the data dictionary to be written into the logs.
None
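For orientation, the sketch below shows how an Estimator is typically assembled from a Pipeline and a Network, along the lines of the standard MNIST apphub example. The dataset, architecture, op, and trace imports are assumptions about the surrounding FastEstimator API (exact module paths may differ between releases) rather than part of the Estimator signature documented here.

import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.dataset.data import mnist
from fastestimator.op.numpyop.univariate import ExpandDims, Minmax
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric import Accuracy

# Data processing workflow (illustrative MNIST setup)
train_data, eval_data = mnist.load_data()
pipeline = fe.Pipeline(train_data=train_data,
                       eval_data=eval_data,
                       batch_size=32,
                       ops=[ExpandDims(inputs="x", outputs="x"), Minmax(inputs="x", outputs="x")])

# Models and other training graph definitions
model = fe.build(model_fn=LeNet, optimizer_fn="adam")
network = fe.Network(ops=[
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
    UpdateOp(model=model, loss_name="ce")
])

# The Estimator ties everything together; the arguments map onto the parameters above
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=2,
                         traces=[Accuracy(true_key="y", pred_key="y_pred"),
                                 BestModelSaver(model=model, save_dir="/tmp/models")],  # avoids the "no ModelSaver" warning
                         log_steps=100,                   # print a log line every 100 steps
                         max_train_steps_per_epoch=None)  # use all training data each epoch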
Source code in fastestimator/fastestimator/estimator.py
class Estimator:
    """One class to rule them all.

    Estimator is the highest level class within FastEstimator. It is the class which is invoked to actually train
    (estimator.fit) or test (estimator.test) models. It wraps `Pipeline`, `Network`, `Trace` objects together and
    defines the whole optimization process.

    Args:
        pipeline: An fe.Pipeline object that defines the data processing workflow.
        network: An fe.Network object that contains models and other training graph definitions.
        epochs: The number of epochs to run.
        max_train_steps_per_epoch: Training will complete after n steps even if loader is not yet exhausted. If None,
            all data will be used.
        max_eval_steps_per_epoch: Evaluation will complete after n steps even if loader is not yet exhausted. If None,
            all data will be used.
        traces: What Traces to run during training. If None, only the system's default Traces will be included.
        log_steps: Frequency (in steps) for printing log messages. 0 to disable all step-based printing (though epoch
            information will still print). None to completely disable printing.
        monitor_names: Additional keys from the data dictionary to be written into the logs.
    """
    pipeline: Pipeline
    traces: List[Union[Trace, Scheduler[Trace]]]
    monitor_names: Set[str]

    def __init__(self,
                 pipeline: Pipeline,
                 network: BaseNetwork,
                 epochs: int,
                 max_train_steps_per_epoch: Optional[int] = None,
                 max_eval_steps_per_epoch: Optional[int] = None,
                 traces: Union[None, Trace, Scheduler[Trace], Iterable[Union[Trace, Scheduler[Trace]]]] = None,
                 log_steps: Optional[int] = 100,
                 monitor_names: Union[None, str, Iterable[str]] = None):
        self.pipeline = pipeline
        self.network = network
        self.traces = to_list(traces)
        self.traces_in_use = None
        assert log_steps is None or log_steps >= 0, \
            "log_steps must be None or positive (or 0 to disable only train logging)"
        self.monitor_names = to_set(monitor_names) | self.network.get_loss_keys()
        self.system = System(network=network,
                             log_steps=log_steps,
                             total_epochs=epochs,
                             max_train_steps_per_epoch=max_train_steps_per_epoch,
                             max_eval_steps_per_epoch=max_eval_steps_per_epoch)

    def fit(self, summary: Optional[str] = None, warmup: bool = True) -> Optional[Summary]:
        """Train the network for the number of epochs specified by the estimator's constructor.

        Args:
            summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned
                as a summary object at the end of training.
            warmup: Whether to perform warmup before training begins. The warmup procedure will test one step at every
                epoch where schedulers cause the execution graph to change. This can take some time up front, but can
                also save significant heartache on epoch 300 when the training unexpectedly fails due to a tensor size
                mismatch.

        Returns:
            A summary object containing the training history for this session iff a `summary` name was provided.
        """
        draw()
        self.system.reset(summary)
        self._prepare_traces(run_modes={"train", "eval"})
        if warmup:
            self._warmup()
        self._start(run_modes={"train", "eval"})
        return self.system.summary or None

    def _prepare_traces(self, run_modes: Set[str]) -> None:
        """Prepare information about the traces for training.

        Add default traces into the traces_in_use list, also prints a warning if no model saver trace is detected.

        Args:
            run_modes: The current execution modes.
        """
        self.traces_in_use = [trace for trace in self.traces]
        if self.system.log_steps is not None:
            self.traces_in_use.append(Logger())
        if "train" in run_modes:
            self.traces_in_use.insert(0, TrainEssential(monitor_names=self.monitor_names))
            no_save_warning = True
            for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
                if isinstance(trace, (ModelSaver, BestModelSaver)):
                    no_save_warning = False
            if no_save_warning:
                print("FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.")
        if "eval" in run_modes and "eval" in self.pipeline.get_modes():
            self.traces_in_use.insert(1, EvalEssential(monitor_names=self.monitor_names))
        # insert system instance to trace
        for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
            trace.system = self.system

    def test(self, summary: Optional[str] = None) -> Optional[Summary]:
        """Run the pipeline / network in test mode for one epoch.

        Args:
            summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned
                as a summary object at the end of training. If None, the default value will be whatever `summary` name
                was most recently provided to this Estimator's .fit() or .test() methods.

        Returns:
            A summary object containing the training history for this session iff the `summary` name is not None
            (after considering the default behavior above).
        """
        self.system.reset_for_test(summary)
        self._prepare_traces(run_modes={"test"})
        self._start(run_modes={"test"})
        return self.system.summary or None

    @staticmethod
    def _sort_traces(traces: List[Trace], available_outputs: Optional[Set[str]] = None) -> List[Trace]:
        """Sort traces to attempt to resolve any dependency issues.

        This is essentially a topological sort, but it doesn't seem worthwhile to convert the data into a graph
        representation in order to get the slightly better asymptotic runtime complexity.

        Args:
            traces: A list of traces (not inside schedulers) to be sorted.
            available_outputs: What output keys are already available for the traces to use. If None are provided, the
                sorting algorithm will assume that any keys not generated by traces are being provided by the system.
                This results in a less rigorous sorting.

        Returns:
            The sorted list of `traces`.

        Raises:
            AssertionError: If Traces have circular dependencies or require input keys which are not available.
        """
        sorted_traces = []
        trace_outputs = {output for trace in traces for output in trace.outputs}
        if available_outputs is None:
            # Assume that anything not generated by a Trace is provided by the system
            available_outputs = {inp for trace in traces for inp in trace.inputs} - trace_outputs
            weak_sort = True
        else:
            available_outputs = to_set(available_outputs)
            weak_sort = False
        end_traces = deque()
        intermediate_traces = deque()
        intermediate_outputs = set()
        trace_deque = deque(traces)
        while trace_deque:
            trace = trace_deque.popleft()
            ins = set(trace.inputs)
            outs = set(trace.outputs)
            if not ins or isinstance(trace, (TrainEssential, EvalEssential)):
                sorted_traces.append(trace)
                available_outputs |= outs
            elif "*" in ins:
                if outs:
                    end_traces.appendleft(trace)
                else:
                    end_traces.append(trace)
            elif ins <= available_outputs or (weak_sort
                                              and (ins - outs - available_outputs).isdisjoint(trace_outputs)):
                sorted_traces.append(trace)
                available_outputs |= outs
            else:
                intermediate_traces.append(trace)
                intermediate_outputs |= outs

        already_seen = set()
        while intermediate_traces:
            trace = intermediate_traces.popleft()
            ins = set(trace.inputs)
            outs = set(trace.outputs)
            already_seen.add(trace)
            if ins <= available_outputs or (weak_sort and (ins - outs - available_outputs).isdisjoint(trace_outputs)):
                sorted_traces.append(trace)
                available_outputs |= outs
                already_seen.clear()
            elif ins <= (available_outputs | intermediate_outputs):
                intermediate_traces.append(trace)
            else:
                raise AssertionError("The {} trace has unsatisfiable inputs: {}".format(
                    type(trace).__name__, ", ".join(ins - (available_outputs | intermediate_outputs))))
            if intermediate_traces and len(already_seen) == len(intermediate_traces):
                raise AssertionError("Dependency cycle detected amongst traces: {}".format(
                    ", ".join([type(tr).__name__ for tr in already_seen])))
        sorted_traces.extend(list(end_traces))
        return sorted_traces

    def _warmup(self) -> None:
        """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

        Traces are not executed in the warmup since they are likely to contain state variables which could become
        corrupted by running extra steps.
        """
        all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
        self._sort_traces(all_traces)
        monitor_names = self.monitor_names
        for mode in self.pipeline.get_modes() - {"test"}:
            scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
                mode) + self.get_scheduled_items(mode)
            signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
            epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
            for epoch in signature_epochs:
                if epoch not in epochs_with_data:
                    continue
                # key checking
                loader = self._configure_loader(self.pipeline.get_loader(mode, epoch))
                with Suppressor():
                    if isinstance(loader, tf.data.Dataset):
                        batch = list(loader.take(1))[0]
                    else:
                        batch = next(iter(loader))
                batch = self._configure_tensor(loader, batch)
                assert isinstance(batch, dict), "please make sure data output format is dictionary"
                pipeline_output_keys = to_set(batch.keys())
                network_output_keys = self.network.get_all_output_keys(mode, epoch)
                trace_input_keys = set()
                trace_output_keys = {"*"}
                traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch)
                for idx, trace in enumerate(traces):
                    if idx > 0:  # ignore TrainEssential and EvalEssential's inputs for unmet requirement checking
                        trace_input_keys.update(trace.inputs)
                    trace_output_keys.update(trace.outputs)
                monitor_names = monitor_names - (pipeline_output_keys | network_output_keys)
                unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys
                                                         | trace_output_keys)
                assert not unmet_requirements, \
                    "found missing key(s) during epoch {} mode {}: {}".format(epoch, mode, unmet_requirements)
                self._sort_traces(traces, available_outputs=pipeline_output_keys | network_output_keys)
                trace_input_keys.update(traces[0].inputs)
                self.network.load_epoch(mode, epoch, output_keys=trace_input_keys, warmup=True)
                self.network.run_step(batch)
                self.network.unload_epoch()
        assert not monitor_names, "found missing key(s): {}".format(monitor_names)

    def get_scheduled_items(self, mode: str) -> List[Any]:
        """Get a list of items considered for scheduling.

        Args:
            mode: Current execution mode.

        Returns:
            List of schedulable items in estimator.
        """
        return self.traces_in_use

    def _start(self, run_modes: Set[str]) -> None:
        """The outer training loop.

        This method invokes the trace on_begin method, runs the necessary 'train' and 'eval' epochs, and then invokes
        the trace on_end method.

        Args:
            run_modes: The current execution modes.
        """
        all_traces = get_current_items(self.traces_in_use, run_modes=run_modes)
        self._sort_traces(all_traces)
        self._run_traces_on_begin(traces=all_traces)
        try:
            if "train" in run_modes or "eval" in run_modes:
                for self.system.epoch_idx in range(self.system.epoch_idx + 1, self.system.total_epochs + 1):
                    if "train" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = "train"
                        self._run_epoch()
                    if "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = "eval"
                        self._run_epoch()
            else:
                self._run_epoch()
        except EarlyStop:
            pass  # On early stopping we still want to run the final traces and return results
        self._run_traces_on_end(traces=all_traces)

    def _run_epoch(self) -> None:
        """A method to perform an epoch of activity.

        This method requires that the current mode and epoch already be specified within the self.system object.
        """
        traces = get_current_items(self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx)
        trace_input_keys = set()
        for trace in traces:
            trace_input_keys.update(trace.inputs)
        loader = self._configure_loader(self.pipeline.get_loader(self.system.mode, self.system.epoch_idx))
        iterator = iter(loader)
        self.network.load_epoch(mode=self.system.mode, epoch=self.system.epoch_idx, output_keys=trace_input_keys)
        self.system.batch_idx = None
        with Suppressor():
            batch = next(iterator)
        traces = self._sort_traces(
            traces,
            available_outputs=to_set(batch.keys())
            | self.network.get_all_output_keys(self.system.mode, self.system.epoch_idx))
        self._run_traces_on_epoch_begin(traces=traces)
        while True:
            try:
                if self.system.mode == "train":
                    self.system.update_global_step()
                self.system.update_batch_idx()
                self._run_traces_on_batch_begin(traces=traces)
                batch = self._configure_tensor(loader, batch)
                batch, prediction = self.network.run_step(batch)
                self._run_traces_on_batch_end(batch, prediction, traces=traces)
                if isinstance(loader, DataLoader) and (
                    (self.system.batch_idx == self.system.max_train_steps_per_epoch
                     and self.system.mode == "train") or
                    (self.system.batch_idx == self.system.max_eval_steps_per_epoch
                     and self.system.mode == "eval")):
                    raise StopIteration
                with Suppressor():
                    batch = next(iterator)
            except StopIteration:
                break
        self._run_traces_on_epoch_end(traces=traces)
        self.network.unload_epoch()

    def _configure_loader(self, loader: Union[DataLoader, tf.data.Dataset]) -> Union[DataLoader, tf.data.Dataset]:
        """A method to configure a given dataloader for use with this Estimator's Network.

        This method will ensure that the `loader` returns the correct data type (tf.Tensor or torch.Tensor) depending
        on the requirements of the Network. It also handles issues with multi-gpu data sharding.

        Args:
            loader: A data loader to be modified.

        Returns:
            The potentially modified dataloader to be used for training.
        """
        new_loader = loader
        if isinstance(new_loader, DataLoader) and isinstance(self.network, TFNetwork):
            add_batch = True
            if hasattr(loader.dataset, "dataset") and isinstance(loader.dataset.dataset, BatchDataset):
                add_batch = False
            batch = to_tensor(loader.dataset[0], target_type="tf")
            data_type = to_type(batch)
            data_shape = to_shape(batch, add_batch=add_batch, exact_shape=False)
            new_loader = tf.data.Dataset.from_generator(lambda: loader, data_type, output_shapes=data_shape)
            new_loader = new_loader.prefetch(1)
        if isinstance(new_loader, tf.data.Dataset):
            if self.system.max_train_steps_per_epoch and self.system.mode == "train":
                new_loader = new_loader.take(self.system.max_train_steps_per_epoch)
            if self.system.max_eval_steps_per_epoch and self.system.mode == "eval":
                new_loader = new_loader.take(self.system.max_eval_steps_per_epoch)
            if isinstance(tf.distribute.get_strategy(), tf.distribute.MirroredStrategy) and not isinstance(
                    new_loader, DistributedDataset):
                new_loader = tf.distribute.get_strategy().experimental_distribute_dataset(new_loader)
        return new_loader

    def _configure_tensor(self, loader: Union[DataLoader, tf.data.Dataset], batch: Dict[str, Any]) -> Dict[str, Any]:
        """A function to convert a batch of tf.Tensors to torch.Tensors if required.

        Returns:
            Either the original `batch`, or the `batch` converted to torch.Tensors if required.
        """
        if isinstance(loader, tf.data.Dataset) and isinstance(self.network, TorchNetwork):
            batch = to_tensor(batch, target_type="torch")
        return batch

    def _run_traces_on_begin(self, traces: Iterable[Trace]) -> None:
        """Invoke the on_begin methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        for trace in traces:
            trace.on_begin(data)
        self._check_early_exit()

    def _run_traces_on_epoch_begin(self, traces: Iterable[Trace]) -> None:
        """Invoke the on_epoch_begin methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        for trace in traces:
            trace.on_epoch_begin(data)
        self._check_early_exit()

    def _run_traces_on_batch_begin(self, traces: Iterable[Trace]) -> None:
        """Invoke the on_batch_begin methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        for trace in traces:
            trace.on_batch_begin(data)
        self._check_early_exit()

    def _run_traces_on_batch_end(self, batch: Dict[str, Any], prediction: Dict[str, Any],
                                 traces: Iterable[Trace]) -> None:
        """Invoke the on_batch_end methods of given traces.

        Args:
            batch: The batch data which was provided by the pipeline.
            prediction: The prediction data which was generated by the network.
            traces: List of traces.
        """
        data = Data(ChainMap(prediction, batch))
        for trace in traces:
            trace.on_batch_end(data)
        self._check_early_exit()

    def _run_traces_on_epoch_end(self, traces: Iterable[Trace]) -> None:
        """Invoke the on_epoch_end methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        for trace in traces:
            trace.on_epoch_end(data)
        self._check_early_exit()

    @staticmethod
    def _run_traces_on_end(traces: Iterable[Trace]) -> None:
        """Invoke the on_end methods of given traces.

        Args:
            traces: List of traces.
        """
        data = Data()
        for trace in traces:
            trace.on_end(data)

    def _check_early_exit(self) -> None:
        """Determine whether training should be prematurely aborted.

        Raises:
            EarlyStop: If the system.stop_training flag has been set to True.
        """
        if self.system.stop_training:
            raise EarlyStop
Train the network for the number of epochs specified by the estimator's constructor.
Parameters:
Name
Type
Description
Default
summary
Optional[str]
A name for the experiment. If provided, the log history will be recorded in-memory and returned as
a summary object at the end of training.
None
warmup
bool
Whether to perform warmup before training begins. The warmup procedure will test one step at every
epoch where schedulers cause the execution graph to change. This can take some time up front, but can
also save significant heartache on epoch 300 when the training unexpectedly fails due to a tensor size
mismatch.
True
Returns:
Type
Description
Optional[Summary]
A summary object containing the training history for this session iff a summary name was provided.
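A short usage sketch, continuing the Estimator built above (the experiment name is illustrative, and the visualization utility mentioned in the comments is assumed from recent FastEstimator releases):

# Train, recording the log history in-memory under the name "mnist_run"
summary = estimator.fit(summary="mnist_run", warmup=True)
print(summary.name)  # "mnist_run"

# The returned Summary holds the logged metrics per mode and key, and can be plotted,
# e.g. with fe.summary.logs.visualize_logs([summary]).
# If no summary name is given, fit() returns None and history is only printed to stdout.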
Source code in fastestimator/fastestimator/estimator.py
def fit(self, summary: Optional[str] = None, warmup: bool = True) -> Optional[Summary]:
    """Train the network for the number of epochs specified by the estimator's constructor.

    Args:
        summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned as
            a summary object at the end of training.
        warmup: Whether to perform warmup before training begins. The warmup procedure will test one step at every
            epoch where schedulers cause the execution graph to change. This can take some time up front, but can
            also save significant heartache on epoch 300 when the training unexpectedly fails due to a tensor size
            mismatch.

    Returns:
        A summary object containing the training history for this session iff a `summary` name was provided.
    """
    draw()
    self.system.reset(summary)
    self._prepare_traces(run_modes={"train", "eval"})
    if warmup:
        self._warmup()
    self._start(run_modes={"train", "eval"})
    return self.system.summary or None
Get a list of items considered for scheduling.
Parameters:
Name
Type
Description
Default
mode
str
Current execution mode.
required
Returns:
Type
Description
List[Any]
List of schedulable items in estimator.
Source code in fastestimator/fastestimator/estimator.py
def get_scheduled_items(self, mode: str) -> List[Any]:
    """Get a list of items considered for scheduling.

    Args:
        mode: Current execution mode.

    Returns:
        List of schedulable items in estimator.
    """
    return self.traces_in_use
Run the pipeline / network in test mode for one epoch.
Parameters:
Name
Type
Description
Default
summary
Optional[str]
A name for the experiment. If provided, the log history will be recorded in-memory and returned as
a summary object at the end of training. If None, the default value will be whatever summary name was
most recently provided to this Estimator's .fit() or .test() methods.
None
Returns:
Type
Description
Optional[Summary]
A summary object containing the training history for this session iff the summary name is not None (after
considering the default behavior above).
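A brief sketch of test() following a fit() call, assuming the Pipeline above was also constructed with test_data:

# Evaluate on the test set. With no summary name given, the name most recently passed
# to fit() ("mnist_run" above) is reused for the returned Summary.
test_summary = estimator.test()

# Or record the test run under its own experiment name:
final_summary = estimator.test(summary="mnist_final_test")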
Source code in fastestimator/fastestimator/estimator.py
def test(self, summary: Optional[str] = None) -> Optional[Summary]:
    """Run the pipeline / network in test mode for one epoch.

    Args:
        summary: A name for the experiment. If provided, the log history will be recorded in-memory and returned as
            a summary object at the end of training. If None, the default value will be whatever `summary` name was
            most recently provided to this Estimator's .fit() or .test() methods.

    Returns:
        A summary object containing the training history for this session iff the `summary` name is not None (after
        considering the default behavior above).
    """
    self.system.reset_for_test(summary)
    self._prepare_traces(run_modes={"test"})
    self._start(run_modes={"test"})
    return self.system.summary or None