Bases: Trace
End training if a NaN value is detected.
By default (monitor_names=None) it will monitor all loss values at the end of each batch. If one or more inputs are
specified, it will only monitor those values. Inputs may be loss keys and/or the keys corresponding to the outputs
of other traces (e.g. accuracy). A usage sketch follows the parameter table below.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `monitor_names` | `Union[None, str, Iterable[str]]` | Key(s) to monitor for NaN values. If None, all loss values will be monitored. `"*"` will monitor all trace output keys and losses. | `None` |
| `mode` | `Union[None, str, Set[str]]` | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | `None` |
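Below is a minimal usage sketch. The loss key "ce", the Accuracy trace keys, and the `pipeline`/`network` objects are illustrative assumptions rather than fixed names; substitute whatever keys and objects your own setup defines.

```python
import fastestimator as fe
from fastestimator.trace.adapt import TerminateOnNaN
from fastestimator.trace.metric import Accuracy

# Default behavior: monitor every loss key at the end of each batch.
nan_guard = TerminateOnNaN()

# Monitor only a specific (hypothetical) cross-entropy loss key "ce".
nan_guard_ce = TerminateOnNaN(monitor_names="ce")

# Monitor all losses plus every trace output (such as the "accuracy" key written by
# an Accuracy trace), but only during training.
nan_guard_all = TerminateOnNaN(monitor_names="*", mode="train")

# Attach the trace to an Estimator alongside any metric traces. The `pipeline` and
# `network` objects are assumed to be defined elsewhere.
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=2,
                         traces=[Accuracy(true_key="y", pred_key="y_pred"), nan_guard_all])
estimator.fit()
```

If a NaN appears in any monitored key, the trace sets the system's stop_training flag and prints which key contained the NaN, so the run halts instead of continuing with corrupted values.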
Source code in fastestimator/fastestimator/trace/adapt/terminate_on_nan.py
```python
@traceable()
class TerminateOnNaN(Trace):
    """End Training if a NaN value is detected.

    By default (monitor_names=None) it will monitor all loss values at the end of each batch. If one or more inputs
    are specified, it will only monitor those values. Inputs may be loss keys and/or the keys corresponding to the
    outputs of other traces (ex. accuracy).

    Args:
        monitor_names: key(s) to monitor for NaN values. If None, all loss values will be monitored. "*" will monitor
            all trace output keys and losses.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an
            argument like "!infer" or "!train".
    """
    def __init__(self, monitor_names: Union[None, str, Iterable[str]] = None,
                 mode: Union[None, str, Set[str]] = None) -> None:
        super().__init__(inputs=monitor_names, mode=mode)
        self.monitor_keys = {}
        self.in_list = True

    def on_epoch_begin(self, data: Data) -> None:
        if not self.inputs:
            self.monitor_keys = self.system.network.get_loss_keys()
        elif "*" in self.inputs:
            self.monitor_keys = self.system.network.get_loss_keys()
            for trace in get_current_items(self.system.traces, run_modes=self.system.mode,
                                           epoch=self.system.epoch_idx):
                self.monitor_keys.update(trace.outputs)
        else:
            self.monitor_keys = self.inputs

    def on_batch_end(self, data: Data) -> None:
        for key in self.monitor_keys:
            if key in data:
                if check_nan(data[key]):
                    self.system.stop_training = True
                    print("FastEstimator-TerminateOnNaN: NaN Detected in: {}".format(key))

    def on_epoch_end(self, data: Data) -> None:
        for key in self.monitor_keys:
            if key in data:
                if check_nan(data[key]):
                    self.system.stop_training = True
                    print("FastEstimator-TerminateOnNaN: NaN Detected in: {}".format(key))
```
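The detection step in `on_batch_end` and `on_epoch_end` reduces to scanning each monitored value for NaNs and raising the stop flag on the first hit. The standalone sketch below mirrors that loop, with `numpy.isnan` standing in for FastEstimator's internal `check_nan` helper; the function names here are illustrative, not part of the library's API.

```python
import numpy as np

def has_nan(value) -> bool:
    # Stand-in for check_nan: flag any NaN inside a scalar or array-like value.
    return bool(np.isnan(np.asarray(value, dtype=np.float64)).any())

def scan(data: dict, monitor_keys: set) -> bool:
    # Mirror of the on_batch_end / on_epoch_end loop: report the first monitored
    # key containing a NaN. The real trace would also set system.stop_training = True.
    for key in monitor_keys:
        if key in data and has_nan(data[key]):
            print("FastEstimator-TerminateOnNaN: NaN Detected in: {}".format(key))
            return True
    return False

print(scan({"ce": float("nan"), "accuracy": 0.92}, monitor_keys={"ce", "accuracy"}))  # True
print(scan({"ce": 0.41, "accuracy": 0.92}, monitor_keys={"ce", "accuracy"}))          # False
```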