Skip to content

Commit

Permalink
docstring for sinter plot group_func dict api, add curve sorting, fix…
Browse files Browse the repository at this point in the history
… fill_between color
  • Loading branch information
m-mcewen committed Aug 13, 2024
1 parent 98b47e8 commit 7390e70
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions glue/sample/src/sinter/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,21 @@ def plot_discard_rate(
group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
The statistics are grouped into curves based on whether or not they get the same result
out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
If the result of the function is a dictionary, then optional keys in the dictionary will
also control the plotting of each curve. Available keys are:
'label': the label added to the legend for the curve
'color': the color used for plotting the curve
'marker': the marker used for the curve
'linestyle': the linestyle used for the curve
'sort': the order in which the curves will be plotted and added to the legend
e.g. if two curves (with different resulting dictionaries from group_func) share the same
value for key 'marker', they will be plotted with the same marker.
Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
filter_func: Optional. When specified, some curves will not be plotted.
The statistics are filtered and only plotted if filter_func(stat) returns True.
For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
where the saved metadata indicates the basis was 'x'.
plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to
plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
`plot` and `fill_between` used to do the actual plotting. For example, this can be used
to specify markers and colors. Takes the index of the curve in sorted order and also a
curve_id (these will be 0 and None respectively if group_func is not specified). For example,
Expand Down Expand Up @@ -337,11 +347,21 @@ def plot_error_rate(
group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
The statistics are grouped into curves based on whether or not they get the same result
out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
If the result of the function is a dictionary, then optional keys in the dictionary will
also control the plotting of each curve. Available keys are:
'label': the label added to the legend for the curve
'color': the color used for plotting the curve
'marker': the marker used for the curve
'linestyle': the linestyle used for the curve
'sort': the order in which the curves will be plotted and added to the legend
e.g. if two curves (with different resulting dictionaries from group_func) share the same
value for key 'marker', they will be plotted with the same marker.
Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
filter_func: Optional. When specified, some curves will not be plotted.
The statistics are filtered and only plotted if filter_func(stat) returns True.
For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
where the saved metadata indicates the basis was 'x'.
plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to
plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
`plot` and `fill_between` used to do the actual plotting. For example, this can be used
to specify markers and colors. Takes the index of the curve in sorted order and also a
curve_id (these will be 0 and None respectively if group_func is not specified). For example,
Expand Down Expand Up @@ -435,12 +455,22 @@ def plot_custom(
group_func: Optional. When specified, multiple curves will be plotted instead of one curve.
The statistics are grouped into curves based on whether or not they get the same result
out of this function. For example, this could be `group_func=lambda stat: stat.decoder`.
If the result of the function is a dictionary, then optional keys in the dictionary will
also control the plotting of each curve. Available keys are:
'label': the label added to the legend for the curve
'color': the color used for plotting the curve
'marker': the marker used for the curve
'linestyle': the linestyle used for the curve
'sort': the order in which the curves will be plotted and added to the legend
e.g. if two curves (with different resulting dictionaries from group_func) share the same
value for key 'marker', they will be plotted with the same marker.
Colors, markers and linestyles are assigned in order, sorted by the values for those keys.
point_label_func: Optional. Specifies text to draw next to data points.
filter_func: Optional. When specified, some curves will not be plotted.
The statistics are filtered and only plotted if filter_func(stat) returns True.
For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats
where the saved metadata indicates the basis was 'x'.
plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to
plot_args_func: Optional. Specifies additional arguments to give the underlying calls to
`plot` and `fill_between` used to do the actual plotting. For example, this can be used
to specify markers and colors. Takes the index of the curve in sorted order and also a
curve_id (these will be 0 and None respectively if group_func is not specified). For example,
Expand Down Expand Up @@ -488,7 +518,12 @@ def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict:
for i, k in enumerate(sorted({g.get('linestyle', None) for g in curve_groups.keys()}, key=better_sorted_str_terms))
}

for k, group_key in enumerate(sorted(curve_groups.keys(), key=better_sorted_str_terms)):
def sort_key(a: Any) -> Any:
if isinstance(a, _FrozenDict):
return a.get('sort', better_sorted_str_terms(a))
return better_sorted_str_terms(a)

for k, group_key in enumerate(sorted(curve_groups.keys(), key=sort_key)):
group = curve_groups[group_key]
group = sorted(group, key=x_func)
color = colors[group_key.get('color', group_key)]
Expand Down Expand Up @@ -547,7 +582,7 @@ def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict:
if lbl:
ax.annotate(lbl, (x, y))
if len(xs_low_high) > 1:
ax.fill_between(xs_low_high, ys_low, ys_high, color=color, alpha=0.2, zorder=-100)
ax.fill_between(xs_low_high, ys_low, ys_high, color=args['color'], alpha=0.2, zorder=-100)
elif len(xs_low_high) == 1:
l, = ys_low
h, = ys_high
Expand Down

0 comments on commit 7390e70

Please sign in to comment.