From 1325e6b2650089eae1d4a7f6f4adebec401e6ff5 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Mon, 29 Jul 2024 21:47:37 -0700 Subject: [PATCH] - Add `sinter plot --preprocess_stats_func` - Add `sinter.TaskStats.with_edits` --- glue/sample/src/sinter/_command/_main_plot.py | 26 +++++++++++++-- glue/sample/src/sinter/_data/_task_stats.py | 24 ++++++++++++++ .../src/sinter/_data/_task_stats_test.py | 33 +++++++++++++++++++ glue/sample/src/sinter/_plotting.py | 2 -- 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/glue/sample/src/sinter/_command/_main_plot.py b/glue/sample/src/sinter/_command/_main_plot.py index 0205ee831..b808b1dbf 100644 --- a/glue/sample/src/sinter/_command/_main_plot.py +++ b/glue/sample/src/sinter/_command/_main_plot.py @@ -32,6 +32,16 @@ def parse_args(args: List[str]) -> Any: 'Examples:\n' ''' --filter_func "decoder=='pymatching'"\n''' ''' --filter_func "0.001 < metadata['p'] < 0.005"\n''') + parser.add_argument('--preprocess_stats_func', + type=str, + default=None, + help='An expression that operates on a `stats` value, returning a new list of stats to plot.\n' + 'For example, this could double add a field to json_metadata or merge stats together.\n' + 'Examples:\n' + ''' --preprocess_stats_func "[stat for stat in stats if stat.errors > 0]\n''' + ''' --preprocess_stats_func "[stat.with_edits(errors=stat.custom_counts['severe_errors']) for stat in stats]\n''' + ''' --preprocess_stats_func "__import__('your_custom_module').your_custom_function(stats)"\n''' + ) parser.add_argument('--x_func', type=str, default="1", @@ -174,7 +184,7 @@ def parse_args(args: List[str]) -> Any: ) parser.add_argument('--plot_args_func', type=str, - default='''{'marker': 'ov*sp^<>8P+xXhHDd|'[index % 18]}''', + default='''{}''', help='A python expression used to customize the look of curves.\n' 'Values available to the python expression:\n' ' index: A unique integer identifying the curve.\n' @@ -284,6 +294,10 @@ def parse_args(args: List[str]) -> Any: a.failure_values_func = "1" if a.failure_unit_name is None: a.failure_unit_name = 'shot' + a.preprocess_stats_func = None if a.preprocess_stats_func is None else eval(compile( + f'lambda *, stats: {a.preprocess_stats_func}', + filename='preprocess_stats_func:command_line_arg', + mode='eval')) a.x_func = eval(compile( f'lambda *, stat, decoder, metadata, m, strong_id: {a.x_func}', filename='x_func:command_line_arg', @@ -491,6 +505,7 @@ def _plot_helper( samples: Union[Iterable['sinter.TaskStats'], ExistingData], group_func: Callable[['sinter.TaskStats'], Any], filter_func: Callable[['sinter.TaskStats'], Any], + preprocess_stats_func: Optional[Callable], failure_units_per_shot_func: Callable[['sinter.TaskStats'], Any], failure_values_func: Callable[['sinter.TaskStats'], Any], x_func: Callable[['sinter.TaskStats'], Any], @@ -521,6 +536,12 @@ def _plot_helper( for k, v in total.data.items() if bool(filter_func(v))} + if preprocess_stats_func is not None: + processed_stats = preprocess_stats_func(stats=list(total.data.values())) + total.data = {} + for stat in processed_stats: + total.add_sample(stat) + if not plot_types: if y_func is not None: plot_types = ['custom_y'] @@ -553,7 +574,6 @@ def _plot_helper( plotted_stats: List['sinter.TaskStats'] = [ stat for stat in total.data.values() - if filter_func(stat) ] def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: @@ -652,7 +672,6 @@ def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: x_func=x_func, y_func=y_func, group_func=group_func, - filter_func=filter_func, plot_args_func=plot_args_func, line_fits=None if not line_fits else (x_scale_name, y_scale_name), point_label_func=point_label_func, @@ -801,6 +820,7 @@ def main_plot(*, command_line_args: List[str]): title=args.title, subtitle=args.subtitle, line_fits=args.line_fits, + preprocess_stats_func=args.preprocess_stats_func, ) if args.out is not None: fig.savefig(args.out) diff --git a/glue/sample/src/sinter/_data/_task_stats.py b/glue/sample/src/sinter/_data/_task_stats.py index f86e9ba70..48cf681f4 100644 --- a/glue/sample/src/sinter/_data/_task_stats.py +++ b/glue/sample/src/sinter/_data/_task_stats.py @@ -1,6 +1,7 @@ import collections import dataclasses from typing import Counter, List, Any +from typing import Optional from typing import Union from typing import overload @@ -69,6 +70,29 @@ def __post_init__(self): assert self.shots >= self.errors + self.discards assert all(isinstance(k, str) and isinstance(v, int) for k, v in self.custom_counts.items()) + def with_edits( + self, + *, + strong_id: Optional[str] = None, + decoder: Optional[str] = None, + json_metadata: Optional[Any] = None, + shots: Optional[int] = None, + errors: Optional[int] = None, + discards: Optional[int] = None, + seconds: Optional[float] = None, + custom_counts: Optional[Counter[str]] = None, + ) -> 'TaskStats': + return TaskStats( + strong_id=self.strong_id if strong_id is None else strong_id, + decoder=self.decoder if decoder is None else decoder, + json_metadata=self.json_metadata if json_metadata is None else json_metadata, + shots=self.shots if shots is None else shots, + errors=self.errors if errors is None else errors, + discards=self.discards if discards is None else discards, + seconds=self.seconds if seconds is None else seconds, + custom_counts=self.custom_counts if custom_counts is None else custom_counts, + ) + @overload def __add__(self, other: AnonTaskStats) -> AnonTaskStats: pass diff --git a/glue/sample/src/sinter/_data/_task_stats_test.py b/glue/sample/src/sinter/_data/_task_stats_test.py index 0847b003f..d6bbe5c80 100644 --- a/glue/sample/src/sinter/_data/_task_stats_test.py +++ b/glue/sample/src/sinter/_data/_task_stats_test.py @@ -87,3 +87,36 @@ def test_add(): seconds=52, custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), ) + + +def test_with_edits(): + v = sinter.TaskStats( + decoder='pymatching', + json_metadata={'a': 2}, + strong_id='abcdefDIFFERENT', + shots=270, + errors=34, + discards=43, + seconds=52, + custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), + ) + assert v.with_edits(json_metadata={'b': 3}) == sinter.TaskStats( + decoder='pymatching', + json_metadata={'b': 3}, + strong_id='abcdefDIFFERENT', + shots=270, + errors=34, + discards=43, + seconds=52, + custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), + ) + assert v == sinter.TaskStats(strong_id='', json_metadata={}, decoder='').with_edits( + decoder='pymatching', + json_metadata={'a': 2}, + strong_id='abcdefDIFFERENT', + shots=270, + errors=34, + discards=43, + seconds=52, + custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), + ) diff --git a/glue/sample/src/sinter/_plotting.py b/glue/sample/src/sinter/_plotting.py index 1755f7e29..b23ae320c 100644 --- a/glue/sample/src/sinter/_plotting.py +++ b/glue/sample/src/sinter/_plotting.py @@ -478,8 +478,6 @@ def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict: } for k, group_key in enumerate(sorted(curve_groups.keys(), key=better_sorted_str_terms)): - this_group_stats = sorted(curve_groups[group_key], key=x_func) - group = curve_groups[group_key] group = sorted(group, key=x_func) color = colors[group_key.get('color', group_key)]