From 7390e7014a4569b8bb76fd472f841ae61a7fb6bb Mon Sep 17 00:00:00 2001 From: Matt McEwen Date: Tue, 13 Aug 2024 21:13:41 +0000 Subject: [PATCH] docstring for sinter plot group_func dict api, add curve sorting, fix fill_between color --- glue/sample/src/sinter/_plotting.py | 45 +++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/glue/sample/src/sinter/_plotting.py b/glue/sample/src/sinter/_plotting.py index fda78459..b0336ccc 100644 --- a/glue/sample/src/sinter/_plotting.py +++ b/glue/sample/src/sinter/_plotting.py @@ -252,11 +252,21 @@ def plot_discard_rate( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -337,11 +347,21 @@ def plot_error_rate( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -435,12 +455,22 @@ def plot_custom( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. point_label_func: Optional. Specifies text to draw next to data points. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -488,7 +518,12 @@ def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict: for i, k in enumerate(sorted({g.get('linestyle', None) for g in curve_groups.keys()}, key=better_sorted_str_terms)) } - for k, group_key in enumerate(sorted(curve_groups.keys(), key=better_sorted_str_terms)): + def sort_key(a: Any) -> Any: + if isinstance(a, _FrozenDict): + return a.get('sort', better_sorted_str_terms(a)) + return better_sorted_str_terms(a) + + for k, group_key in enumerate(sorted(curve_groups.keys(), key=sort_key)): group = curve_groups[group_key] group = sorted(group, key=x_func) color = colors[group_key.get('color', group_key)] @@ -547,7 +582,7 @@ def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict: if lbl: ax.annotate(lbl, (x, y)) if len(xs_low_high) > 1: - ax.fill_between(xs_low_high, ys_low, ys_high, color=color, alpha=0.2, zorder=-100) + ax.fill_between(xs_low_high, ys_low, ys_high, color=args['color'], alpha=0.2, zorder=-100) elif len(xs_low_high) == 1: l, = ys_low h, = ys_high