def plot_logs(experiments: List[Summary],
              smooth_factor: float = 0,
              share_legend: bool = True,
              ignore_metrics: Optional[Set[str]] = None,
              pretty_names: bool = False,
              include_metrics: Optional[Set[str]] = None) -> plt.Figure:
    """A function which will plot experiment histories for comparison viewing / analysis.

    Metrics named like `acc|ds1` and `acc|ds2` are drawn on the same `acc` graph, distinguished by a per-dataset
    marker. Metrics holding a single value (`ndim() == 1`) are written as text rather than plotted, and their graphs
    are placed after all of the line-plot graphs.

    Args:
        experiments: Experiment(s) to plot.
        smooth_factor: A non-negative float representing the magnitude of gaussian smoothing to apply (zero for none).
        share_legend: Whether to have one legend across all graphs (True) or one legend per graph (False).
        ignore_metrics: Any keys to ignore during plotting. The 'epoch' key is always ignored. The caller's collection
            is not modified.
        pretty_names: Whether to modify the metric names in graph titles (True) or leave them alone (False).
        include_metrics: A whitelist of keys to include during plotting. If None then all will be included.

    Returns:
        The handle of the pyplot figure. If there is nothing to plot, an empty single-axis figure is returned.
    """
    # Sort to keep same colors between multiple runs of visualization
    experiments = humansorted(to_list(experiments), lambda exp: exp.name)
    n_experiments = len(experiments)
    if n_experiments == 0:
        # NOTE: plt.subplots(111) would create 111 rows of axes; an empty result only needs a single axis
        return plt.subplots()[0]
    # Use `|` (not `|=`) so a new set is built and the caller's ignore_metrics object is never mutated
    ignore_keys = to_set(ignore_metrics or set()) | {'epoch'}
    include_keys = to_set(include_metrics)
    # TODO: epoch should be indicated on the axis (top x axis?). Problem - different epochs per experiment.
    # TODO: figure out how ignore_metrics should interact with mode
    # Plotting order of the modes; any mode not listed here is plotted last
    mode_order = {'train': 0, 'eval': 1, 'test': 2, 'infer': 3}
    ds_ids = set()
    metric_histories = defaultdict(_MetricGroup)  # metric: MetricGroup
    for exp_idx, experiment in enumerate(experiments):
        # Since python dicts remember insertion order, sort the history so that train mode is always plotted on bottom
        for mode, metrics in sorted(experiment.history.items(), key=lambda item: mode_order.get(item[0], 4)):
            for metric, step_val in metrics.items():
                base_metric, ds_id, *_ = f'{metric}|'.split('|')  # Plot acc|ds1 and acc|ds2 on same acc graph
                if len(step_val) == 0:
                    continue  # Ignore empty metrics
                if metric in ignore_keys or base_metric in ignore_keys:
                    continue
                if include_keys and metric not in include_keys:
                    continue
                metric_histories[base_metric].add(exp_idx, mode, ds_id, step_val)
                ds_ids.add(ds_id)
    metric_list = sorted(metric_histories.keys())
    if not metric_list:
        return plt.subplots()[0]
    ds_ids = humansorted(ds_ids)  # Sort them to have consistent ordering (and thus symbols) between plot runs
    n_ds_ids = len(ds_ids)  # Each ds_id will have its own set of legend entries, so need to count them
    # If sharing legend and there is more than 1 plot, then dedicate subplot(s) for the legend
    share_legend = share_legend and (len(metric_list) > 1)
    n_legends = math.ceil(n_experiments * n_ds_ids / 4)
    n_plots = len(metric_list) + (share_legend * n_legends)
    # map the metrics into an n x n grid, then remove any extra columns. Final grid will be n x m with m <= n
    n_rows = math.ceil(math.sqrt(n_plots))
    n_cols = math.ceil(n_plots / n_rows)
    metric_grid_location = {}
    nd1_metrics = []
    idx = 0
    for metric in metric_list:
        if metric_histories[metric].ndim() == 1:
            # Delay placement of the 1D plots until the end
            nd1_metrics.append(metric)
        else:
            metric_grid_location[metric] = divmod(idx, n_cols)
            idx += 1
    for metric in nd1_metrics:
        metric_grid_location[metric] = divmod(idx, n_cols)
        idx += 1
    sns.set_context('paper')
    # squeeze=False guarantees a 2D (n_rows x n_cols) array of axes for every grid shape, including the 2x1 grid that
    # arises from exactly 2 metrics without a shared legend (a squeezed 1D result cannot be indexed as axs[r][c])
    fig, axs = plt.subplots(n_rows, n_cols, sharex='all', figsize=(4 * n_cols, 2.8 * n_rows), squeeze=False)
    for metric, (row, col) in metric_grid_location.items():
        axis = axs[row][col]
        if metric_histories[metric].ndim() == 1:
            axis.grid(linestyle='')
        else:
            axis.grid(linestyle='--')
            axis.ticklabel_format(axis='y', style='sci', scilimits=(-2, 3))
        axis.set_title(prettify_metric_name(metric) if pretty_names else metric, fontweight='bold')
        for spine in ('top', 'right', 'bottom', 'left'):
            axis.spines[spine].set_visible(False)
        axis.tick_params(bottom=False, left=False)
        axis.xaxis.set_major_formatter(EngFormatter(sep=""))  # Convert 10000 steps to 10k steps
    # some of the later rows/columns might be unused or reserved for legends, so disable them
    last_row_idx = math.ceil(len(metric_list) / n_cols) - 1
    last_column_idx = len(metric_list) - last_row_idx * n_cols - 1
    for c in range(n_cols):
        if c <= last_column_idx:
            axs[last_row_idx][c].set_xlabel('Steps')
            axs[last_row_idx][c].xaxis.set_tick_params(which='both', labelbottom=True)
        else:
            axs[last_row_idx][c].axis('off')
            # The graph above this empty cell becomes the bottom of its column, so it needs the x labels. If every
            # metric fits in the first row there is no graph above (index -1 would wrap around to the last row).
            if last_row_idx > 0:
                axs[last_row_idx - 1][c].set_xlabel('Steps')
                axs[last_row_idx - 1][c].xaxis.set_tick_params(which='both', labelbottom=True)
        for r in range(last_row_idx + 1, n_rows):
            axs[r][c].axis('off')
    # the 1D metrics don't need x axis, so move them up, starting with the last in case multiple rows of them
    for metric in reversed(nd1_metrics):
        row, col = metric_grid_location[metric]
        axs[row][col].axis('off')
        if row > 0:
            axs[row - 1][col].set_xlabel('Steps')
            axs[row - 1][col].xaxis.set_tick_params(which='both', labelbottom=True)
    colors = sns.hls_palette(n_colors=n_experiments, s=0.95) if n_experiments > 10 else sns.color_palette("colorblind")
    color_offset = defaultdict(lambda: 0)
    # If there is only 1 experiment, we will use alternate colors based on mode
    if n_experiments == 1:
        color_offset['eval'] = 1
        color_offset['test'] = 2
        color_offset['infer'] = 3
    handles = []
    labels = []
    # exp_id : {mode: {ds_id: {type: True}}}
    has_label = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: False))))
    ax_text = defaultdict(lambda: (0.0, 0.9))  # Where to put the text on a given axis
    # Set up ds_id markers. The first ds_id in sorted order (the empty ds_id when present) gets no extra marker. After
    # that there are 4 configurations of 3-arm marker, followed by asterisks with growing numbers of arms (from 4).
    marker_choices = ['', "1", "2", "3", "4"] + [(ticks, 2, 0) for ticks in range(4, n_ds_ids - 1)]
    ds_id_markers = dict(zip(ds_ids, marker_choices))
    # Per-mode styling; modes not listed fall back to the .get() defaults used below
    line_styles = {'train': 'solid', 'eval': 'dashed', 'test': 'dotted'}  # other modes: 'dashdot'
    point_markers = {'train': 'o', 'eval': 'D', 'test': 's'}  # other modes: 'd'
    # Actually do the plotting
    for exp_idx, experiment in enumerate(experiments):
        for metric, group in metric_histories.items():
            row, col = metric_grid_location[metric]
            axis = axs[row][col]
            if group.ndim() == 1:
                # Single value: write it as text, stacking entries downwards and wrapping into a second column
                ax_id = id(axis)
                for mode in group.modes(exp_idx):
                    for ds_id in group.ds_ids(exp_idx, mode):
                        ds_title = f"{ds_id} " if ds_id else ''
                        prefix = f"{experiment.name} ({ds_title}{mode})" if n_experiments > 1 else f"{ds_title}{mode}"
                        text_x, text_y = ax_text[ax_id]
                        axis.text(text_x,
                                  text_y,
                                  f"{prefix}: {group.get_val(exp_idx, mode, ds_id)}",
                                  color=colors[exp_idx + color_offset[mode]],
                                  transform=axis.transAxes)
                        text_y -= 0.1
                        if text_y < 0:
                            text_x, text_y = text_x + 0.5, 0.9
                        ax_text[ax_id] = (text_x, text_y)
            elif group.ndim() == 2:
                for mode, dsv in group[exp_idx].items():
                    color = colors[exp_idx + color_offset[mode]]
                    for ds_id, data in dsv.items():
                        ds_title = f"{ds_id} " if ds_id else ''
                        title = f"{experiment.name} ({ds_title}{mode})" if n_experiments > 1 else f"{ds_title}{mode}"
                        if data.shape[0] < 2:
                            # This particular mode only has a single data point, so draw a shape instead of a line
                            x, y = data[0][0], data[0][1]
                            style = MarkerStyle(marker=point_markers.get(mode, 'd'), fillstyle='full')
                            if isinstance(y, ValWithError):
                                # We've got error bars. Plotting requires positive values for error
                                y_err = [[max(1e-9, y.y - y.y_min)], [max(1e-9, y.y_max - y.y)]]
                                axis.errorbar(x=x,
                                              y=y.y,
                                              yerr=y_err,
                                              ecolor=color,
                                              elinewidth=1.5,
                                              capsize=4.0,
                                              capthick=1.5,
                                              zorder=3)  # zorder to put markers on top of line segments
                                y = y.y
                            handle = axis.scatter(x,
                                                  y,
                                                  s=45,
                                                  c=[color],
                                                  label=title,
                                                  marker=style,
                                                  linewidth=1.0,
                                                  edgecolors='black',
                                                  zorder=4)  # zorder to put markers on top of line segments
                            if ds_id:
                                # Overlay the dataset id marker on top of the normal scatter plot marker
                                overlay = axis.scatter(x,
                                                       y,
                                                       s=45,
                                                       c='white',
                                                       label=title,
                                                       marker=ds_id_markers[ds_id],
                                                       linewidth=1.1,
                                                       zorder=5)  # zorder to put markers on top of line segments
                                handle = (handle, overlay)
                            if not has_label[exp_idx][mode][ds_id]['patch']:
                                labels.append(title)
                                handles.append(handle)
                                has_label[exp_idx][mode][ds_id]['patch'] = True
                        else:
                            # We can draw a line
                            y = data[:, 1]
                            y_min = None
                            y_max = None
                            if isinstance(y[0], ValWithError):
                                # Split (min, value, max) triples into a shaded band plus the central line
                                y = np.stack(y)
                                y_min = y[:, 0]
                                y_max = y[:, 2]
                                y = y[:, 1]
                                if smooth_factor != 0:
                                    y_min = gaussian_filter1d(y_min, sigma=smooth_factor)
                                    y_max = gaussian_filter1d(y_max, sigma=smooth_factor)
                            if smooth_factor != 0:
                                y = gaussian_filter1d(y, sigma=smooth_factor)
                            x = data[:, 0]
                            ln = axis.plot(x,
                                           y,
                                           color=color,
                                           label=title,
                                           linewidth=1.5,
                                           linestyle=line_styles.get(mode, 'dashdot'),
                                           marker=ds_id_markers[ds_id],
                                           markersize=7,
                                           markeredgewidth=1.5,
                                           markeredgecolor='black',
                                           markevery=0.1)
                            if not has_label[exp_idx][mode][ds_id]['line']:
                                labels.append(title)
                                handles.append(ln[0])
                                has_label[exp_idx][mode][ds_id]['line'] = True
                            if y_max is not None and y_min is not None:
                                axis.fill_between(x.astype(np.float32),
                                                  y_max,
                                                  y_min,
                                                  facecolor=color,
                                                  alpha=0.3,
                                                  zorder=-1)
            else:
                # Some kind of image or matrix. Not implemented yet.
                pass
    fig.tight_layout()
    if labels:
        if share_legend:
            # Sort the labels, keeping each handle paired with its label
            pairs = sorted(zip(labels, handles), key=lambda pair: pair[0])
            labels = [label for label, _ in pairs]
            handles = [handle for _, handle in pairs]
            # Split the labels over multiple legends if there are too many to fit in one axis
            elems_per_legend = math.ceil(len(labels) / n_legends)
            legend_size = 'large' if elems_per_legend <= 6 else 'medium' if elems_per_legend <= 8 else 'small'
            start = 0
            # Legends fill the cells after the last metric graph, in row-major order
            for r in range(last_row_idx, n_rows):
                for c in range(last_column_idx + 1 if r == last_row_idx else 0, n_cols):
                    if len(handles) <= start:
                        break
                    axs[r][c].legend(handles[start:start + elems_per_legend],
                                     labels[start:start + elems_per_legend],
                                     loc='center',
                                     fontsize=legend_size)
                    start += elems_per_legend
        else:
            for r in range(n_rows):
                for c in range(n_cols):
                    if r == last_row_idx and c > last_column_idx:
                        break
                    axis = axs[r][c]
                    # Text-only (1D) graphs and unused cells have no labeled artists; skip them to avoid empty legends
                    if axis.get_legend_handles_labels()[0]:
                        axis.legend(loc='best', fontsize='small')
    return fig