def per_ds(clz: type(Trace)):
    """Class decorator which wraps a trace class so that it can also be driven per dataset.

    The generated class inherits from both `clz` and `PerDSTrace`. Every instance of it owns a second, independently
    constructed `clz` instance (stored as `fe_per_ds_trace`) which is fed the dataset-level events, with the event
    data wrapped in `DSData` together with the current dataset id.

    Args:
        clz: The trace class to be wrapped.

    Returns:
        A class carrying the name, qualified name, module, and docstring of `clz`. Instantiating it produces a plain
        `clz` instance (not the dataset-aware variant) whenever any output of that instance contains the '|'
        character, or the instance has a 'per_ds' attribute which is exactly False.
    """
    def _make_plain(*args, **kwargs):
        # Allocate and initialize an ordinary (non-ds-aware) instance of the wrapped class.
        plain = clz.__new__(clz)
        plain.__init__(*args, **kwargs)
        return plain

    class PerDS(clz, PerDSTrace):
        @functools.wraps(clz.__new__)
        def __new__(cls, *args, **kwargs):
            # The decision between a plain object and the PerDS variant can only be made once the constructor
            # arguments are known, so a throw-away plain instance is built and inspected first.
            probe = _make_plain(*args, **kwargs)
            # Outputs which already contain '|' would clash with the "<output>|<ds_id>" naming scheme, and an
            # explicit per_ds=False on the instance opts out. In both cases the plain instance is handed back;
            # since it is not an instance of `cls`, Python will not invoke __init__ on it a second time.
            if any('|' in name for name in probe.outputs) or (hasattr(probe, 'per_ds') and probe.per_ds is False):
                return probe
            # Otherwise allocate the dataset-aware variant; __init__ below will be run on it by Python.
            return super().__new__(cls)

        @functools.wraps(clz.__init__)
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Second copy of the wrapped trace, built from the same arguments, which receives the per-dataset events.
            self.fe_per_ds_trace = _make_plain(*args, **kwargs)

        def get_outputs(self, ds_ids: Union[None, str, List[str]]) -> List[str]:
            """Get the output keys of this trace, including the dataset-specific ones.

            Args:
                ds_ids: The dataset id(s) to generate keys for (normalized through `to_list`).

            Returns:
                The regular outputs, followed by one "<output>|<ds_id>" entry for every output / dataset id pair
                (grouped by output).
            """
            ids = to_list(ds_ids)
            base_keys = list(self.outputs)
            per_ds_keys = [f"{key}|{ds}" for key in self.outputs for ds in ids]
            return base_keys + per_ds_keys

        # Training-level and batch-level events go to this object first and then to the per-dataset copy.

        def on_begin(self, data: Data) -> None:
            super().on_begin(data)
            self.fe_per_ds_trace.on_begin(data)

        def on_ds_begin(self, data: Data) -> None:
            # The start of a dataset is delivered to the per-dataset copy through its on_epoch_begin hook, and only
            # when a dataset id is currently set.
            current_ds = self.system.ds_id
            if current_ds != '':
                self.fe_per_ds_trace.on_epoch_begin(DSData(current_ds, data))

        def on_batch_begin(self, data: Data) -> None:
            super().on_batch_begin(data)
            current_ds = self.system.ds_id
            if current_ds != '':
                self.fe_per_ds_trace.on_batch_begin(DSData(current_ds, data))

        def on_batch_end(self, data: Data) -> None:
            super().on_batch_end(data)
            current_ds = self.system.ds_id
            if current_ds != '':
                self.fe_per_ds_trace.on_batch_end(DSData(current_ds, data))

        def on_ds_end(self, data: Data) -> None:
            # The end of a dataset is delivered to the per-dataset copy through its on_epoch_end hook, and only
            # when a dataset id is currently set.
            current_ds = self.system.ds_id
            if current_ds != '':
                self.fe_per_ds_trace.on_epoch_end(DSData(current_ds, data))

        def on_end(self, data: Data) -> None:
            super().on_end(data)
            self.fe_per_ds_trace.on_end(data)

    # Copy the identifying metadata (including the docstring) of the original class onto the generated class.
    for attr in ('__name__', '__qualname__', '__module__', '__doc__'):
        setattr(PerDS, attr, getattr(clz, attr))
    return PerDS