def plot_logs(experiments: List[Summary],
              smooth_factor: float = 0,
              ignore_metrics: Optional[Set[str]] = None,
              pretty_names: bool = False,
              include_metrics: Optional[Set[str]] = None) -> FigureFE:
    """A function which will plot experiment histories for comparison viewing / analysis.

    Every metric gets its own subplot. Metrics with more than one dimension are drawn as lines (or as a single marker
    when only one point is available), while 1D metrics are rendered as text annotations in the trailing subplots.

    Args:
        experiments: Experiment(s) to plot.
        smooth_factor: A non-negative float representing the magnitude of gaussian smoothing to apply (zero for none).
        ignore_metrics: Any keys to ignore during plotting. The 'epoch' key is always ignored. The provided set is not
            modified.
        pretty_names: Whether to modify the metric names in graph titles (True) or leave them alone (False).
        include_metrics: A whitelist of keys to include during plotting. If None then all will be included.

    Returns:
        A FigureFE wrapping the generated figure. If there are no experiments, or nothing is left to plot after
        filtering, then an empty figure is returned.
    """
    # Sort to keep same colors between multiple runs of visualization
    experiments = humansorted(to_list(experiments), lambda exp: exp.name)
    n_experiments = len(experiments)
    if n_experiments == 0:
        return FigureFE.from_figure(make_subplots())

    # Build a new set with `|` (rather than an in-place `|=`) so that the caller's `ignore_metrics` is never mutated,
    # even if to_set hands back the very same object it was given.
    ignore_keys = to_set(ignore_metrics or set()) | {'epoch'}
    include_keys = to_set(include_metrics)
    # TODO: epoch should be indicated on the axis (top x axis?). Problem - different epochs per experiment.
    # TODO: figure out how ignore_metrics should interact with mode
    # TODO: when ds_id switches during training, prevent old id from connecting with new one (break every epoch?)

    # Since python dicts remember insertion order, sort the history so that train mode is always plotted on bottom
    mode_priority = {'train': 0, 'eval': 1, 'test': 2, 'infer': 3}
    ds_ids = set()
    metric_histories = defaultdict(_MetricGroup)  # metric: MetricGroup
    for exp_idx, experiment in enumerate(experiments):
        for mode, metrics in sorted(experiment.history.items(), key=lambda kv: mode_priority.get(kv[0], 4)):
            for metric, step_val in metrics.items():
                base_metric, ds_id, *_ = f'{metric}|'.split('|')  # Plot acc|ds1 and acc|ds2 on same acc graph
                if len(step_val) == 0:
                    continue  # Ignore empty metrics
                if metric in ignore_keys or base_metric in ignore_keys:
                    continue
                # Here we intentionally check against metric and not base_metric. If user wants to display per-ds they
                # can specify that in their include list: --include mcc 'mcc|usps'
                if include_keys and metric not in include_keys:
                    continue
                metric_histories[base_metric].add(exp_idx, mode, ds_id, step_val)
                ds_ids.add(ds_id)

    metric_list = sorted(metric_histories)
    if not metric_list:
        return FigureFE.from_figure(make_subplots())
    ds_ids = humansorted(ds_ids)  # Sort them to have consistent ordering (and thus symbols) between plot runs
    n_plots = len(metric_list)
    if len(ds_ids) > 9:  # 9 b/c the marker set below has a 'no marker' entry in addition to its 8 symbols
        warn("Plotting more than 8 different datasets isn't well supported. Symbols will be reused.")

    # Non-Shared legends aren't supported yet. If they get supported then maybe can have that feature here too.
    #  https://github.com/plotly/plotly.js/issues/5099
    #  https://github.com/plotly/plotly.js/issues/5098

    # map the metrics into an n x n grid, then remove any extra columns. Final grid will be n x m with m <= n
    n_rows = math.ceil(math.sqrt(n_plots))
    n_cols = math.ceil(n_plots / n_rows)
    metric_grid_location = {}
    nd1_metrics = []
    idx = 0
    for metric in metric_list:
        if metric_histories[metric].ndim() == 1:
            # Delay placement of the 1D plots until the end
            nd1_metrics.append(metric)
        else:
            metric_grid_location[metric] = (idx // n_cols, idx % n_cols)
            idx += 1
    for metric in nd1_metrics:
        metric_grid_location[metric] = (idx // n_cols, idx % n_cols)
        idx += 1

    # Raw metric keys in subplot order. These must stay separate from the (possibly prettified) display titles,
    # because every later lookup into metric_histories / metric_grid_location needs the raw key. Looking up a
    # prettified name in the metric_histories defaultdict would silently create a bogus empty group.
    ordered_metrics = sorted(metric_grid_location,
                             key=lambda m: metric_grid_location[m][0] * n_cols + metric_grid_location[m][1])
    if pretty_names:
        titles = [prettify_metric_name(metric) for metric in ordered_metrics]
    else:
        titles = list(ordered_metrics)

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=titles, shared_xaxes='all')
    fig.update_layout({
        'plot_bgcolor': '#FFF',
        'hovermode': 'closest',
        'margin': {
            't': 50
        },
        'modebar': {
            'add': ['hoverclosest', 'hovercompare'], 'remove': ['select2d', 'lasso2d']
        },
        'legend': {
            'tracegroupgap': 5, 'font': {
                'size': 11
            }
        }
    })

    # Set x-labels
    for idx, metric in enumerate(ordered_metrics, start=1):
        plotly_idx = idx if idx > 1 else ""  # The first axis pair has no numeric suffix
        x_axis_name = f'xaxis{plotly_idx}'
        y_axis_name = f'yaxis{plotly_idx}'
        if metric_histories[metric].ndim() > 1:
            fig['layout'][x_axis_name]['title'] = 'Steps'
            fig['layout'][x_axis_name]['showticklabels'] = True
            fig['layout'][x_axis_name]['linecolor'] = "#BCCCDC"
            fig['layout'][y_axis_name]['linecolor'] = "#BCCCDC"
        else:
            # Put blank data onto the axis to instantiate the domain
            row, col = metric_grid_location[metric]
            fig.add_annotation(text='', showarrow=False, row=row + 1, col=col + 1)
            # Hide the axis stuff
            fig['layout'][x_axis_name]['showgrid'] = False
            fig['layout'][x_axis_name]['zeroline'] = False
            fig['layout'][x_axis_name]['visible'] = False
            fig['layout'][y_axis_name]['showgrid'] = False
            fig['layout'][y_axis_name]['zeroline'] = False
            fig['layout'][y_axis_name]['visible'] = False

    # If there is only 1 experiment, we will use alternate colors based on mode
    color_offset = defaultdict(lambda: 0)
    n_colors = n_experiments
    if n_experiments == 1:
        n_colors = 4
        color_offset['eval'] = 1
        color_offset['test'] = 2
        color_offset['infer'] = 3
    colors = get_colors(n_colors=n_colors)
    alpha_colors = get_colors(n_colors=n_colors, alpha=0.3)

    # exp_id : {mode: {ds_id: {type: True}}}
    add_label = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: True))))
    # {row: {col: (x, y)}}
    ax_text = defaultdict(lambda: defaultdict(lambda: (0.0, 0.9)))  # Where to put the text on a given axis
    # Set up ds_id markers. The empty ds_id will have no extra marker. After that there are 4 configurations of 3-arm
    # marker, followed by 'x', '+', '*', and pound. After that it will just repeat the symbol set.
    ds_id_markers = [None, 37, 38, 39, 40, 34, 33, 35, 36]  # https://plotly.com/python/marker-style/
    ds_id_markers = dict(zip(ds_ids, cycle(ds_id_markers)))
    # Plotly doesn't support z-order, so delay insertion until all the plots are figured out:
    # https://github.com/plotly/plotly.py/issues/2345
    z_order = defaultdict(list)  # {order: [(plotly element, row, col), ...]}

    # Per-mode styling for single-point markers and for lines. Unlisted modes use the fallback given at lookup time.
    marker_styles = {'train': 'circle', 'eval': 'diamond', 'test': 'square'}
    line_styles = {'train': 'solid', 'eval': 'dash', 'test': 'dot'}
    tip_with_limits = "%{x}: (%{customdata[1]:.3f}, %{y:.3f}, %{customdata[0]:.3f})"
    tip_plain = "%{x}: %{y:.3f}"

    def _trace_title(exp: Summary, mode: str, ds_id: str) -> str:
        # The label shared by the legend entry, the legend group, and the 1D text annotations
        ds_title = f"{ds_id} " if ds_id else ''
        return f"{exp.name} ({ds_title}{mode})" if n_experiments > 1 else f"{ds_title}{mode}"

    # Figure out the legend ordering
    legend_order = []
    for exp_idx, experiment in enumerate(experiments):
        for metric, group in metric_histories.items():
            for mode in group.modes(exp_idx):
                for ds_id in group.ds_ids(exp_idx, mode):
                    legend_order.append(_trace_title(experiment, mode, ds_id))
    legend_order.sort()
    legend_order = {legend: order for order, legend in enumerate(legend_order)}

    # Actually do the plotting
    for exp_idx, experiment in enumerate(experiments):
        for metric, group in metric_histories.items():
            row, col = metric_grid_location[metric]
            if group.ndim() == 1:
                # Single value
                subplot_number = row * n_cols + col + 1
                plotly_idx = subplot_number if subplot_number > 1 else ''
                for mode in group.modes(exp_idx):
                    for ds_id in group.ds_ids(exp_idx, mode):
                        prefix = _trace_title(experiment, mode, ds_id)
                        text_x, text_y = ax_text[row][col]
                        fig.add_annotation(text=f"{prefix}: {group.get_val(exp_idx, mode, ds_id)}",
                                           font={'color': colors[exp_idx + color_offset[mode]]},
                                           showarrow=False,
                                           xref=f'x{plotly_idx} domain',
                                           xanchor='left',
                                           x=text_x,
                                           yref=f'y{plotly_idx} domain',
                                           yanchor='top',
                                           y=text_y,
                                           exclude_empty_subplots=False)
                        # Move the next annotation down one slot, wrapping into a new column at the bottom
                        text_y -= 0.1
                        if text_y < 0:
                            text_x, text_y = text_x + 0.5, 0.9
                        ax_text[row][col] = (text_x, text_y)
            elif group.ndim() == 2:
                for mode, dsv in group[exp_idx].items():
                    color = colors[exp_idx + color_offset[mode]]
                    for ds_id, data in dsv.items():
                        title = _trace_title(experiment, mode, ds_id)
                        if data.shape[0] < 2:
                            # Only a single point is available, so draw a marker rather than a line
                            x = data[0][0]
                            y = data[0][1]
                            y_min = None
                            y_max = None
                            if isinstance(y, ValWithError):
                                y_min = y.y_min
                                y_max = y.y_max
                                y = y.y
                            marker_style = marker_styles.get(mode, 'hexagram')
                            has_limits = y_max is not None and y_min is not None
                            limit_data = [(y_max, y_min)] if has_limits else None
                            tip_text = tip_with_limits if has_limits else tip_plain
                            error_y = None
                            if has_limits:
                                error_y = {
                                    'type': 'data',
                                    'symmetric': False,
                                    'array': [y_max - y],
                                    'arrayminus': [y - y_min]
                                }
                            z_order[2].append((go.Scatter(
                                x=[x],
                                y=[y],
                                name=title,
                                legendgroup=title,
                                customdata=limit_data,
                                hovertemplate=tip_text,
                                mode='markers',
                                marker={
                                    'color': color,
                                    'size': 12,
                                    'symbol': _symbol_mash(marker_style, ds_id_markers[ds_id]),
                                    'line': {
                                        'width': 1.5, 'color': 'White'
                                    }
                                },
                                error_y=error_y,
                                showlegend=add_label[exp_idx][mode][ds_id]['patch'],
                                legendrank=legend_order[title]),
                                               row,
                                               col))
                            add_label[exp_idx][mode][ds_id]['patch'] = False
                        else:
                            # We can draw a line
                            y = data[:, 1]
                            y_min = None
                            y_max = None
                            if isinstance(y[0], ValWithError):
                                # Columns of the stacked tuples are read as (min, value, max)
                                y = np.stack([e.as_tuple() for e in y])
                                y_min = y[:, 0]
                                y_max = y[:, 2]
                                y = y[:, 1]
                                if smooth_factor != 0:
                                    y_min = gaussian_filter1d(y_min, sigma=smooth_factor)
                                    y_max = gaussian_filter1d(y_max, sigma=smooth_factor)
                            # TODO - for smoothed lines, plot original data in background but greyed out
                            if smooth_factor != 0:
                                y = gaussian_filter1d(y, sigma=smooth_factor)
                            x = data[:, 0]
                            linestyle = line_styles.get(mode, 'dashdot')
                            has_limits = y_max is not None and y_min is not None
                            limit_data = list(zip(y_max, y_min)) if has_limits else None
                            tip_text = tip_with_limits if has_limits else tip_plain
                            z_order[1].append((go.Scatter(
                                x=x,
                                y=y,
                                name=title,
                                legendgroup=title,
                                mode="lines+markers" if ds_id_markers[ds_id] else 'lines',
                                marker={
                                    'color': color,
                                    'size': 8,
                                    'line': {
                                        'width': 2, 'color': 'DarkSlateGrey'
                                    },
                                    'maxdisplayed': 10,
                                    'symbol': ds_id_markers[ds_id]
                                },
                                line={
                                    'dash': linestyle, 'color': color
                                },
                                customdata=limit_data,
                                hovertemplate=tip_text,
                                showlegend=add_label[exp_idx][mode][ds_id]['line'],
                                legendrank=legend_order[title]),
                                               row,
                                               col))
                            add_label[exp_idx][mode][ds_id]['line'] = False
                            if limit_data is not None:
                                # Error band: an invisible upper bound followed by a lower bound that fills up to it.
                                # These two traces must stay adjacent for the 'tonexty' fill to pair them.
                                z_order[0].append((go.Scatter(x=x,
                                                              y=y_max,
                                                              mode='lines',
                                                              line={'width': 0},
                                                              legendgroup=title,
                                                              showlegend=False,
                                                              hoverinfo='skip'),
                                                   row,
                                                   col))
                                z_order[0].append((go.Scatter(x=x,
                                                              y=y_min,
                                                              mode='lines',
                                                              line={'width': 0},
                                                              fillcolor=alpha_colors[exp_idx + color_offset[mode]],
                                                              fill='tonexty',
                                                              legendgroup=title,
                                                              showlegend=False,
                                                              hoverinfo='skip'),
                                                   row,
                                                   col))
            else:
                # Some kind of image or matrix. Not implemented yet.
                pass

    # Insert the traces from the lowest z layer to the highest (error bands, then lines, then single-point markers)
    for z in sorted(z_order):
        for trace, row, col in z_order[z]:
            fig.add_trace(trace, row=row + 1, col=col + 1)

    # If inside a jupyter notebook then force the height based on number of rows
    if in_notebook():
        fig.update_layout(height=280 * n_rows)

    return FigureFE.from_figure(fig)