Skip to content

log_plot

plot_logs

A function which will plot experiment histories for comparison viewing / analysis.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `experiments` | `List[Summary]` | Experiment(s) to plot. | *required* |
| `smooth_factor` | `float` | A non-negative float representing the magnitude of Gaussian smoothing to apply (zero for none). | `0` |
| `pretty_names` | `bool` | Whether to modify the metric names in graph titles (True) or leave them alone (False). | `False` |
| `ignore_metrics` | `Optional[Set[str]]` | Any keys to ignore during plotting. | `None` |
| `include_metrics` | `Optional[Set[str]]` | A whitelist of keys to include during plotting. If None, all keys are included. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `FigureFE` | The handle of the generated plotly figure. |

Source code in fastestimator/fastestimator/summary/logs/log_plot.py
def plot_logs(experiments: List[Summary],
              smooth_factor: float = 0,
              ignore_metrics: Optional[Set[str]] = None,
              pretty_names: bool = False,
              include_metrics: Optional[Set[str]] = None) -> FigureFE:
    """A function which will plot experiment histories for comparison viewing / analysis.

    Every base metric (the part of a metric key before any '|') gets its own subplot, so 'acc|ds1' and 'acc|ds2' are
    drawn on the same 'acc' subplot and distinguished by marker symbol. Multi-point histories are drawn as lines,
    single-point histories as markers, and metric groups reporting ndim() == 1 are written as text annotations in
    subplots placed after all of the other subplots.

    Args:
        experiments: Experiment(s) to plot.
        smooth_factor: A non-negative float representing the magnitude of gaussian smoothing to apply (zero for none).
        ignore_metrics: Any keys to ignore during plotting. The 'epoch' key is always ignored.
        pretty_names: Whether to modify the metric names in graph titles (True) or leave them alone (False).
        include_metrics: A whitelist of keys to include during plotting. If None then all will be included. Keys are
            matched against the full metric name, so a per-dataset entry such as 'mcc|usps' must be listed explicitly.

    Returns:
        The handle of the plotly figure (wrapped in a FigureFE).
    """
    # Sort to keep same colors between multiple runs of visualization
    experiments = humansorted(to_list(experiments), lambda exp: exp.name)
    n_experiments = len(experiments)
    if n_experiments == 0:
        return FigureFE.from_figure(make_subplots())

    # Use `|` (which builds a new set) rather than `|=`: if to_set hands back the caller's own set, an in-place union
    # would silently add 'epoch' to the user's ignore_metrics argument.
    ignore_keys = to_set(ignore_metrics or set()) | {'epoch'}
    include_keys = to_set(include_metrics)
    # TODO: epoch should be indicated on the axis (top x axis?). Problem - different epochs per experiment.
    # TODO: figure out how ignore_metrics should interact with mode
    # TODO: when ds_id switches during training, prevent old id from connecting with new one (break every epoch?)
    ds_ids = set()
    metric_histories = defaultdict(_MetricGroup)  # metric: MetricGroup
    # Since python dicts remember insertion order, sort the history so that train mode is always plotted on bottom
    mode_rank = {'train': 0, 'eval': 1, 'test': 2, 'infer': 3}
    for idx, experiment in enumerate(experiments):
        history = experiment.history
        for mode, metrics in sorted(history.items(), key=lambda x: mode_rank.get(x[0], 4)):
            for metric, step_val in metrics.items():
                base_metric, ds_id, *_ = f'{metric}|'.split('|')  # Plot acc|ds1 and acc|ds2 on same acc graph
                if len(step_val) == 0:
                    continue  # Ignore empty metrics
                if metric in ignore_keys or base_metric in ignore_keys:
                    continue
                # Here we intentionally check against metric and not base_metric. If user wants to display per-ds they
                #  can specify that in their include list: --include mcc 'mcc|usps'
                if include_keys and metric not in include_keys:
                    continue
                metric_histories[base_metric].add(idx, mode, ds_id, step_val)
                ds_ids.add(ds_id)

    metric_list = sorted(metric_histories.keys())
    if len(metric_list) == 0:
        return FigureFE.from_figure(make_subplots())
    ds_ids = humansorted(ds_ids)  # Sort them to have consistent ordering (and thus symbols) between plot runs
    n_plots = len(metric_list)
    if len(ds_ids) > 9:  # 9 b/c the empty ds_id ('') is normally included and gets no extra marker
        warn("Plotting more than 8 different datasets isn't well supported. Symbols will be reused.")

    # Non-Shared legends aren't supported yet. If they get supported then maybe can have that feature here too.
    #  https://github.com/plotly/plotly.js/issues/5099
    #  https://github.com/plotly/plotly.js/issues/5098

    # map the metrics into an n x n grid, then remove any extra columns. Final grid will be n x m with m <= n
    n_rows = math.ceil(math.sqrt(n_plots))
    n_cols = math.ceil(n_plots / n_rows)
    metric_grid_location = {}
    nd1_metrics = []
    idx = 0
    for metric in metric_list:
        if metric_histories[metric].ndim() == 1:
            # Delay placement of the 1D plots until the end
            nd1_metrics.append(metric)
        else:
            metric_grid_location[metric] = (idx // n_cols, idx % n_cols)
            idx += 1
    for metric in nd1_metrics:
        metric_grid_location[metric] = (idx // n_cols, idx % n_cols)
        idx += 1
    # Raw metric keys in row-major subplot order. Every lookup below uses these keys; the (possibly prettified)
    # titles are display-only, because prettified names are not keys of metric_histories / metric_grid_location.
    ordered_metrics = [k for k, _ in sorted(metric_grid_location.items(), key=lambda e: e[1][0] * n_cols + e[1][1])]
    titles = [prettify_metric_name(m) for m in ordered_metrics] if pretty_names else list(ordered_metrics)

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=titles, shared_xaxes='all')
    fig.update_layout({
        'plot_bgcolor': '#FFF',
        'hovermode': 'closest',
        'margin': {
            't': 50
        },
        'modebar': {
            'add': ['hoverclosest', 'hovercompare'], 'remove': ['select2d', 'lasso2d']
        },
        'legend': {
            'tracegroupgap': 5, 'font': {
                'size': 11
            }
        }
    })

    # Set x-labels. Plotly names the first subplot's axes 'xaxis'/'yaxis' and later ones 'xaxis2', 'yaxis2', ...
    for idx, metric in enumerate(ordered_metrics, start=1):
        plotly_idx = idx if idx > 1 else ""
        x_axis_name = f'xaxis{plotly_idx}'
        y_axis_name = f'yaxis{plotly_idx}'
        if metric_histories[metric].ndim() > 1:
            fig['layout'][x_axis_name]['title'] = 'Steps'
            fig['layout'][x_axis_name]['showticklabels'] = True
            fig['layout'][x_axis_name]['linecolor'] = "#BCCCDC"
            fig['layout'][y_axis_name]['linecolor'] = "#BCCCDC"
        else:
            # Put blank data onto the axis to instantiate the domain
            row, col = metric_grid_location[metric]
            fig.add_annotation(text='', showarrow=False, row=row + 1, col=col + 1)
            # Hide the axis stuff
            fig['layout'][x_axis_name]['showgrid'] = False
            fig['layout'][x_axis_name]['zeroline'] = False
            fig['layout'][x_axis_name]['visible'] = False
            fig['layout'][y_axis_name]['showgrid'] = False
            fig['layout'][y_axis_name]['zeroline'] = False
            fig['layout'][y_axis_name]['visible'] = False

    # If there is only 1 experiment, we will use alternate colors based on mode
    color_offset = defaultdict(lambda: 0)
    n_colors = n_experiments
    if n_experiments == 1:
        n_colors = 4
        color_offset['eval'] = 1
        color_offset['test'] = 2
        color_offset['infer'] = 3
    colors = get_colors(n_colors=n_colors)
    alpha_colors = get_colors(n_colors=n_colors, alpha=0.3)

    # exp_id : {mode: {ds_id: {type: True}}}  -- whether a legend entry still needs to be shown for this trace kind
    add_label = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: True))))
    # {row: {col: (x, y)}}
    ax_text = defaultdict(lambda: defaultdict(lambda: (0.0, 0.9)))  # Where to put the text on a given axis
    # Set up ds_id markers. The empty ds_id will have no extra marker. After that there are 4 configurations of 3-arm
    # marker, followed by 'x', '+', '*', and pound. After that it will just repeat the symbol set.
    ds_id_markers = [None, 37, 38, 39, 40, 34, 33, 35, 36]  # https://plotly.com/python/marker-style/
    ds_id_markers = {k: v for k, v in zip(ds_ids, cycle(ds_id_markers))}
    # Plotly doesn't support z-order, so delay insertion until all the plots are figured out:
    # https://github.com/plotly/plotly.py/issues/2345
    z_order = defaultdict(list)  # {order: [(plotly element, row, col), ...]}

    # Figure out the legend ordering
    legend_order = []
    for exp_idx, experiment in enumerate(experiments):
        for metric, group in metric_histories.items():
            for mode in group.modes(exp_idx):
                for ds_id in group.ds_ids(exp_idx, mode):
                    ds_title = f"{ds_id} " if ds_id else ''
                    title = f"{experiment.name} ({ds_title}{mode})" if n_experiments > 1 else f"{ds_title}{mode}"
                    legend_order.append(title)
    legend_order.sort()
    legend_order = {legend: order for order, legend in enumerate(legend_order)}

    # Actually do the plotting
    for exp_idx, experiment in enumerate(experiments):
        for metric, group in metric_histories.items():
            row, col = metric_grid_location[metric]
            if group.ndim() == 1:
                # Single value: write it as text inside the (axis-less) subplot
                for mode in group.modes(exp_idx):
                    for ds_id in group.ds_ids(exp_idx, mode):
                        ds_title = f"{ds_id} " if ds_id else ''
                        prefix = f"{experiment.name} ({ds_title}{mode})" if n_experiments > 1 else f"{ds_title}{mode}"
                        plotly_idx = row * n_cols + col + 1 if row * n_cols + col + 1 > 1 else ''
                        fig.add_annotation(text=f"{prefix}: {group.get_val(exp_idx, mode, ds_id)}",
                                           font={'color': colors[exp_idx + color_offset[mode]]},
                                           showarrow=False,
                                           xref=f'x{plotly_idx} domain',
                                           xanchor='left',
                                           x=ax_text[row][col][0],
                                           yref=f'y{plotly_idx} domain',
                                           yanchor='top',
                                           y=ax_text[row][col][1],
                                           exclude_empty_subplots=False)
                        # Move down for the next line of text; wrap to a second column once we fall off the bottom
                        ax_text[row][col] = (ax_text[row][col][0], ax_text[row][col][1] - 0.1)
                        if ax_text[row][col][1] < 0:
                            ax_text[row][col] = (ax_text[row][col][0] + 0.5, 0.9)
            elif group.ndim() == 2:
                for mode, dsv in group[exp_idx].items():
                    color = colors[exp_idx + color_offset[mode]]
                    for ds_id, data in dsv.items():
                        ds_title = f"{ds_id} " if ds_id else ''
                        title = f"{experiment.name} ({ds_title}{mode})" if n_experiments > 1 else f"{ds_title}{mode}"
                        if data.shape[0] < 2:
                            # Only one point, so draw a marker (with error bars if the value carries bounds)
                            x = data[0][0]
                            y = data[0][1]
                            y_min = None
                            y_max = None
                            if isinstance(y, ValWithError):
                                y_min = y.y_min
                                y_max = y.y_max
                                y = y.y
                            marker_style = 'circle' if mode == 'train' else 'diamond' if mode == 'eval' \
                                else 'square' if mode == 'test' else 'hexagram'
                            limit_data = [(y_max, y_min)] if y_max is not None and y_min is not None else None
                            tip_text = "%{x}: (%{customdata[1]:.3f}, %{y:.3f}, %{customdata[0]:.3f})" if \
                                limit_data is not None else "%{x}: %{y:.3f}"
                            error_y = None if limit_data is None else {
                                'type': 'data', 'symmetric': False, 'array': [y_max - y], 'arrayminus': [y - y_min]
                            }
                            point_trace = go.Scatter(x=[x],
                                                     y=[y],
                                                     name=title,
                                                     legendgroup=title,
                                                     customdata=limit_data,
                                                     hovertemplate=tip_text,
                                                     mode='markers',
                                                     marker={
                                                         'color': color,
                                                         'size': 12,
                                                         'symbol': _symbol_mash(marker_style, ds_id_markers[ds_id]),
                                                         'line': {
                                                             'width': 1.5, 'color': 'White'
                                                         }
                                                     },
                                                     error_y=error_y,
                                                     showlegend=add_label[exp_idx][mode][ds_id]['patch'],
                                                     legendrank=legend_order[title])
                            z_order[2].append((point_trace, row, col))
                            add_label[exp_idx][mode][ds_id]['patch'] = False
                        else:
                            # We can draw a line
                            y = data[:, 1]
                            y_min = None
                            y_max = None
                            if isinstance(y[0], ValWithError):
                                y = np.stack([e.as_tuple() for e in y])
                                y_min = y[:, 0]
                                y_max = y[:, 2]
                                y = y[:, 1]
                                if smooth_factor != 0:
                                    y_min = gaussian_filter1d(y_min, sigma=smooth_factor)
                                    y_max = gaussian_filter1d(y_max, sigma=smooth_factor)
                            # TODO - for smoothed lines, plot original data in background but greyed out
                            if smooth_factor != 0:
                                y = gaussian_filter1d(y, sigma=smooth_factor)
                            x = data[:, 0]
                            linestyle = 'solid' if mode == 'train' else 'dash' if mode == 'eval' else 'dot' if \
                                mode == 'test' else 'dashdot'
                            limit_data = [(mx, mn) for mx, mn in zip(y_max, y_min)] \
                                if y_max is not None and y_min is not None else None
                            tip_text = "%{x}: (%{customdata[1]:.3f}, %{y:.3f}, %{customdata[0]:.3f})" if \
                                limit_data is not None else "%{x}: %{y:.3f}"
                            line_trace = go.Scatter(x=x,
                                                    y=y,
                                                    name=title,
                                                    legendgroup=title,
                                                    mode="lines+markers" if ds_id_markers[ds_id] else 'lines',
                                                    marker={
                                                        'color': color,
                                                        'size': 8,
                                                        'line': {
                                                            'width': 2, 'color': 'DarkSlateGrey'
                                                        },
                                                        'maxdisplayed': 10,
                                                        'symbol': ds_id_markers[ds_id]
                                                    },
                                                    line={
                                                        'dash': linestyle, 'color': color
                                                    },
                                                    customdata=limit_data,
                                                    hovertemplate=tip_text,
                                                    showlegend=add_label[exp_idx][mode][ds_id]['line'],
                                                    legendrank=legend_order[title])
                            z_order[1].append((line_trace, row, col))
                            add_label[exp_idx][mode][ds_id]['line'] = False
                            if limit_data is not None:
                                # Shaded error band: an invisible upper bound line, then the lower bound filled up
                                # to it ('tonexty'). Both sit below the main line in z-order.
                                z_order[0].append((go.Scatter(x=x,
                                                              y=y_max,
                                                              mode='lines',
                                                              line={'width': 0},
                                                              legendgroup=title,
                                                              showlegend=False,
                                                              hoverinfo='skip'),
                                                   row,
                                                   col))
                                z_order[0].append((go.Scatter(x=x,
                                                              y=y_min,
                                                              mode='lines',
                                                              line={'width': 0},
                                                              fillcolor=alpha_colors[exp_idx + color_offset[mode]],
                                                              fill='tonexty',
                                                              legendgroup=title,
                                                              showlegend=False,
                                                              hoverinfo='skip'),
                                                   row,
                                                   col))
            else:
                # Some kind of image or matrix. Not implemented yet.
                pass
    # Lower z values are added first so that they render underneath later traces
    for z in sorted(z_order):
        for trace, row, col in z_order[z]:
            fig.add_trace(trace, row=row + 1, col=col + 1)

    # If inside a jupyter notebook then force the height based on number of rows
    if in_notebook():
        fig.update_layout(height=280 * n_rows)

    return FigureFE.from_figure(fig)

visualize_logs

A function which will save or display experiment histories for comparison viewing / analysis.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `experiments` | `List[Summary]` | Experiment(s) to plot. | *required* |
| `save_path` | `Optional[str]` | The path where the figure should be saved, or None to display the figure on screen. | `None` |
| `smooth_factor` | `float` | A non-negative float representing the magnitude of Gaussian smoothing to apply (zero for none). | `0` |
| `pretty_names` | `bool` | Whether to modify the metric names in graph titles (True) or leave them alone (False). | `False` |
| `ignore_metrics` | `Optional[Set[str]]` | Any metrics to ignore during plotting. | `None` |
| `include_metrics` | `Optional[Set[str]]` | A whitelist of metric keys (None whitelists all keys). | `None` |
| `verbose` | `bool` | Whether to print out the save location. | `True` |
Source code in fastestimator/fastestimator/summary/logs/log_plot.py
def visualize_logs(experiments: List[Summary],
                   save_path: Optional[str] = None,
                   smooth_factor: float = 0,
                   pretty_names: bool = False,
                   ignore_metrics: Optional[Set[str]] = None,
                   include_metrics: Optional[Set[str]] = None,
                   verbose: bool = True) -> None:
    """A function which will save or display experiment histories for comparison viewing / analysis.

    The figure is built by `plot_logs` and then handed to the figure's `show` method together with `save_path`,
    `verbose`, and a fixed `scale=5` (presumably an export resolution multiplier; confirm against FigureFE.show).

    Args:
        experiments: Experiment(s) to plot.
        save_path: The path where the figure should be saved, or None to display the figure to the screen.
        smooth_factor: A non-negative float representing the magnitude of gaussian smoothing to apply (zero for none).
        pretty_names: Whether to modify the metric names in graph titles (True) or leave them alone (False).
        ignore_metrics: Any metrics to ignore during plotting.
        include_metrics: A whitelist of metric keys (None whitelists all keys).
        verbose: Whether to print out the save location.
    """
    fig = plot_logs(experiments,
                    smooth_factor=smooth_factor,
                    pretty_names=pretty_names,
                    ignore_metrics=ignore_metrics,
                    include_metrics=include_metrics)
    fig.show(save_path=save_path, verbose=verbose, scale=5)