Skip to content

Commit

Permalink
- Add sinter plot --preprocess_stats_func
Browse files Browse the repository at this point in the history
- Add `sinter.TaskStats.with_edits`
  • Loading branch information
Strilanc committed Jul 30, 2024
1 parent 1802b72 commit 1325e6b
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 5 deletions.
26 changes: 23 additions & 3 deletions glue/sample/src/sinter/_command/_main_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def parse_args(args: List[str]) -> Any:
'Examples:\n'
''' --filter_func "decoder=='pymatching'"\n'''
''' --filter_func "0.001 < metadata['p'] < 0.005"\n''')
parser.add_argument('--preprocess_stats_func',
type=str,
default=None,
help='An expression that operates on a `stats` value, returning a new list of stats to plot.\n'
'For example, this could double add a field to json_metadata or merge stats together.\n'
'Examples:\n'
''' --preprocess_stats_func "[stat for stat in stats if stat.errors > 0]\n'''
''' --preprocess_stats_func "[stat.with_edits(errors=stat.custom_counts['severe_errors']) for stat in stats]\n'''
''' --preprocess_stats_func "__import__('your_custom_module').your_custom_function(stats)"\n'''
)
parser.add_argument('--x_func',
type=str,
default="1",
Expand Down Expand Up @@ -174,7 +184,7 @@ def parse_args(args: List[str]) -> Any:
)
parser.add_argument('--plot_args_func',
type=str,
default='''{'marker': 'ov*sp^<>8P+xXhHDd|'[index % 18]}''',
default='''{}''',
help='A python expression used to customize the look of curves.\n'
'Values available to the python expression:\n'
' index: A unique integer identifying the curve.\n'
Expand Down Expand Up @@ -284,6 +294,10 @@ def parse_args(args: List[str]) -> Any:
a.failure_values_func = "1"
if a.failure_unit_name is None:
a.failure_unit_name = 'shot'
a.preprocess_stats_func = None if a.preprocess_stats_func is None else eval(compile(
f'lambda *, stats: {a.preprocess_stats_func}',
filename='preprocess_stats_func:command_line_arg',
mode='eval'))
a.x_func = eval(compile(
f'lambda *, stat, decoder, metadata, m, strong_id: {a.x_func}',
filename='x_func:command_line_arg',
Expand Down Expand Up @@ -491,6 +505,7 @@ def _plot_helper(
samples: Union[Iterable['sinter.TaskStats'], ExistingData],
group_func: Callable[['sinter.TaskStats'], Any],
filter_func: Callable[['sinter.TaskStats'], Any],
preprocess_stats_func: Optional[Callable],
failure_units_per_shot_func: Callable[['sinter.TaskStats'], Any],
failure_values_func: Callable[['sinter.TaskStats'], Any],
x_func: Callable[['sinter.TaskStats'], Any],
Expand Down Expand Up @@ -521,6 +536,12 @@ def _plot_helper(
for k, v in total.data.items()
if bool(filter_func(v))}

if preprocess_stats_func is not None:
processed_stats = preprocess_stats_func(stats=list(total.data.values()))
total.data = {}
for stat in processed_stats:
total.add_sample(stat)

if not plot_types:
if y_func is not None:
plot_types = ['custom_y']
Expand Down Expand Up @@ -553,7 +574,6 @@ def _plot_helper(
plotted_stats: List['sinter.TaskStats'] = [
stat
for stat in total.data.values()
if filter_func(stat)
]

def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]:
Expand Down Expand Up @@ -652,7 +672,6 @@ def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]:
x_func=x_func,
y_func=y_func,
group_func=group_func,
filter_func=filter_func,
plot_args_func=plot_args_func,
line_fits=None if not line_fits else (x_scale_name, y_scale_name),
point_label_func=point_label_func,
Expand Down Expand Up @@ -801,6 +820,7 @@ def main_plot(*, command_line_args: List[str]):
title=args.title,
subtitle=args.subtitle,
line_fits=args.line_fits,
preprocess_stats_func=args.preprocess_stats_func,
)
if args.out is not None:
fig.savefig(args.out)
Expand Down
24 changes: 24 additions & 0 deletions glue/sample/src/sinter/_data/_task_stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import dataclasses
from typing import Counter, List, Any
from typing import Optional
from typing import Union
from typing import overload

Expand Down Expand Up @@ -69,6 +70,29 @@ def __post_init__(self):
assert self.shots >= self.errors + self.discards
assert all(isinstance(k, str) and isinstance(v, int) for k, v in self.custom_counts.items())

def with_edits(
self,
*,
strong_id: Optional[str] = None,
decoder: Optional[str] = None,
json_metadata: Optional[Any] = None,
shots: Optional[int] = None,
errors: Optional[int] = None,
discards: Optional[int] = None,
seconds: Optional[float] = None,
custom_counts: Optional[Counter[str]] = None,
) -> 'TaskStats':
return TaskStats(
strong_id=self.strong_id if strong_id is None else strong_id,
decoder=self.decoder if decoder is None else decoder,
json_metadata=self.json_metadata if json_metadata is None else json_metadata,
shots=self.shots if shots is None else shots,
errors=self.errors if errors is None else errors,
discards=self.discards if discards is None else discards,
seconds=self.seconds if seconds is None else seconds,
custom_counts=self.custom_counts if custom_counts is None else custom_counts,
)

@overload
def __add__(self, other: AnonTaskStats) -> AnonTaskStats:
pass
Expand Down
33 changes: 33 additions & 0 deletions glue/sample/src/sinter/_data/_task_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,36 @@ def test_add():
seconds=52,
custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
)


def test_with_edits():
v = sinter.TaskStats(
decoder='pymatching',
json_metadata={'a': 2},
strong_id='abcdefDIFFERENT',
shots=270,
errors=34,
discards=43,
seconds=52,
custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
)
assert v.with_edits(json_metadata={'b': 3}) == sinter.TaskStats(
decoder='pymatching',
json_metadata={'b': 3},
strong_id='abcdefDIFFERENT',
shots=270,
errors=34,
discards=43,
seconds=52,
custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
)
assert v == sinter.TaskStats(strong_id='', json_metadata={}, decoder='').with_edits(
decoder='pymatching',
json_metadata={'a': 2},
strong_id='abcdefDIFFERENT',
shots=270,
errors=34,
discards=43,
seconds=52,
custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}),
)
2 changes: 0 additions & 2 deletions glue/sample/src/sinter/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,6 @@ def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict:
}

for k, group_key in enumerate(sorted(curve_groups.keys(), key=better_sorted_str_terms)):
this_group_stats = sorted(curve_groups[group_key], key=x_func)

group = curve_groups[group_key]
group = sorted(group, key=x_func)
color = colors[group_key.get('color', group_key)]
Expand Down

0 comments on commit 1325e6b

Please sign in to comment.