Skip to content

Commit

Permalink
regen api docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Jul 27, 2024
1 parent 44206e2 commit c2557a9
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 42 deletions.
23 changes: 22 additions & 1 deletion dev/gen_sinter_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ def main():
```
'''.strip())

replace_rules = []
for package in ['stim', 'sinter']:
p = __import__(package)
for name in dir(p):
x = getattr(p, name)
if isinstance(x, type) and '_' in str(x) and 'class' in str(x):
desired_name = f'{package}.{name}'
bad_name = str(x).split("'")[1]
lonely_name = desired_name.split(".")[-1]
replace_rules.append((bad_name, desired_name))
for q in ['"', "'"]:
replace_rules.append(('ForwardRef(' + q + lonely_name + q + ')', desired_name))
replace_rules.append(('ForwardRef(' + q + desired_name + q + ')', desired_name))
replace_rules.append((q + desired_name + q, desired_name))
replace_rules.append((q + lonely_name + q, desired_name))
replace_rules.append(('ForwardRef(' + desired_name + ')', desired_name))
replace_rules.append(('ForwardRef(' + lonely_name + ')', desired_name))

for obj in objects:
print()
print(f'<a name="{obj.full_name}"></a>')
Expand All @@ -58,7 +76,10 @@ def main():
print(f'# (in class {".".join(obj.full_name.split(".")[:-1])})')
else:
print(f'# (at top-level in the sinter module)')
print('\n'.join(obj.lines))
for line in obj.lines:
for a, b in replace_rules:
line = line.replace(a, b)
print(line)
print("```")


Expand Down
13 changes: 1 addition & 12 deletions dev/util_gen_stub_file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import sys
import types
from typing import Any
from typing import Optional, Iterator, List
Expand All @@ -9,6 +8,7 @@

keep = {
"__add__",
"__radd__",
"__eq__",
"__call__",
"__ge__",
Expand Down Expand Up @@ -224,17 +224,6 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Opt
text += '@abc.abstractmethod\n'
sig_name = f'{term_name}{inspect.signature(obj)}'
text += "\n".join(splay_signature(f"def {sig_name}:"))
text = text.replace('''ForwardRef('sinter.TaskStats')''', 'sinter.TaskStats')
text = text.replace('''ForwardRef('sinter.Task')''', 'sinter.Task')
text = text.replace('''ForwardRef('sinter.Progress')''', 'sinter.Progress')
text = text.replace('''ForwardRef('sinter.Decoder')''', 'sinter.Decoder')
text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats")
text = text.replace('sinter._decoding_decoder_class.CompiledDecoder', 'sinter.CompiledDecoder')
text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats")
text = text.replace("'stim.Circuit'", "stim.Circuit")
text = text.replace("'stim.DetectorErrorModel'", "stim.DetectorErrorModel")
text = text.replace("'sinter.CollectionOptions'", "sinter.CollectionOptions")
text = text.replace("'sinter.Fit'", 'sinter.Fit')

# Replace default value lambdas with their source.
if 'lambda' in str(text):
Expand Down
108 changes: 98 additions & 10 deletions doc/sinter_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@ API references for stable versions are kept on the [stim github wiki](https://gi
- [`sinter.CollectionOptions.combine`](#sinter.CollectionOptions.combine)
- [`sinter.CompiledDecoder`](#sinter.CompiledDecoder)
- [`sinter.CompiledDecoder.decode_shots_bit_packed`](#sinter.CompiledDecoder.decode_shots_bit_packed)
- [`sinter.CompiledSampler`](#sinter.CompiledSampler)
- [`sinter.CompiledSampler.handles_throttling`](#sinter.CompiledSampler.handles_throttling)
- [`sinter.CompiledSampler.sample`](#sinter.CompiledSampler.sample)
- [`sinter.Decoder`](#sinter.Decoder)
- [`sinter.Decoder.compile_decoder_for_dem`](#sinter.Decoder.compile_decoder_for_dem)
- [`sinter.Decoder.decode_via_files`](#sinter.Decoder.decode_via_files)
- [`sinter.Fit`](#sinter.Fit)
- [`sinter.Progress`](#sinter.Progress)
- [`sinter.Sampler`](#sinter.Sampler)
- [`sinter.Sampler.compiled_sampler_for_task`](#sinter.Sampler.compiled_sampler_for_task)
- [`sinter.Task`](#sinter.Task)
- [`sinter.Task.__init__`](#sinter.Task.__init__)
- [`sinter.Task.strong_id`](#sinter.Task.strong_id)
Expand Down Expand Up @@ -257,6 +262,50 @@ def decode_shots_bit_packed(
"""
```

<a name="sinter.CompiledSampler"></a>
```python
# sinter.CompiledSampler

# (at top-level in the sinter module)
class CompiledSampler(metaclass=abc.ABCMeta):
"""A sampler that has been configured for efficiently sampling some task.
"""
```

<a name="sinter.CompiledSampler.handles_throttling"></a>
```python
# sinter.CompiledSampler.handles_throttling

# (in class sinter.CompiledSampler)
def handles_throttling(
self,
) -> bool:
"""Return True to disable sinter wrapping samplers with throttling.
By default, sinter will wrap samplers so that they initially only do
a small number of shots then slowly ramp up. Sometimes this behavior
is not desired (e.g. in unit tests). Override this method to return True
to disable it.
"""
```

<a name="sinter.CompiledSampler.sample"></a>
```python
# sinter.CompiledSampler.sample

# (in class sinter.CompiledSampler)
@abc.abstractmethod
def sample(
self,
shots: int,
) -> sinter.AnonTaskStats:
"""Perform the given number of samples, and return statistics.
This method is permitted to perform fewer shots than specified, but must
indicate this in its returned statistics.
"""
```

<a name="sinter.Decoder"></a>
```python
# sinter.Decoder
Expand Down Expand Up @@ -385,9 +434,9 @@ class Fit:
of the best fit's square error, or whose likelihood was within some
maximum Bayes factor of the max likelihood hypothesis.
"""
low: float
best: float
high: float
low: Optional[float]
best: Optional[float]
high: Optional[float]
```

<a name="sinter.Progress"></a>
Expand All @@ -409,10 +458,45 @@ class Progress:
collection status, such as the number of tasks left and the
estimated time to completion for each task.
"""
new_stats: Tuple[sinter._task_stats.TaskStats, ...]
new_stats: Tuple[sinter.TaskStats, ...]
status_message: str
```

<a name="sinter.Sampler"></a>
```python
# sinter.Sampler

# (at top-level in the sinter module)
class Sampler(metaclass=abc.ABCMeta):
"""A strategy for producing stats from tasks.
Call `sampler.compiled_sampler_for_task(task)` to get a compiled sampler for
a task, then call `compiled_sampler.sample(shots)` to collect statistics.
A sampler differs from a `sinter.Decoder` because the sampler is responsible
for the full sampling process (e.g. simulating the circuit), whereas a
decoder can do nothing except predict observable flips from detection event
data. This prevents the decoders from cheating, but makes them less flexible
overall. A sampler can do things like use simulators other than stim, or
really anything at all as long as it ends with returning statistics about
shot counts, error counts, and etc.
"""
```

<a name="sinter.Sampler.compiled_sampler_for_task"></a>
```python
# sinter.Sampler.compiled_sampler_for_task

# (in class sinter.Sampler)
@abc.abstractmethod
def compiled_sampler_for_task(
self,
task: sinter.Task,
) -> sinter.CompiledSampler:
"""Creates, configures, and returns an object for sampling the task.
"""
```

<a name="sinter.Task"></a>
```python
# sinter.Task
Expand Down Expand Up @@ -475,9 +559,9 @@ class Task:
def __init__(
self,
*,
circuit: Optional[ForwardRef(stim.Circuit)] = None,
circuit: Optional[stim.Circuit] = None,
decoder: Optional[str] = None,
detector_error_model: Optional[ForwardRef(stim.DetectorErrorModel)] = None,
detector_error_model: Optional[stim.DetectorErrorModel] = None,
postselection_mask: Optional[np.ndarray] = None,
postselected_observables_mask: Optional[np.ndarray] = None,
json_metadata: Any = None,
Expand Down Expand Up @@ -699,7 +783,7 @@ class TaskStats:
# (in class sinter.TaskStats)
def to_anon_stats(
self,
) -> sinter._anon_task_stats.AnonTaskStats:
) -> sinter.AnonTaskStats:
"""Returns a `sinter.AnonTaskStats` with the same statistics.
Examples:
Expand Down Expand Up @@ -1124,7 +1208,7 @@ def iter_collect(
num_workers: int,
tasks: Union[Iterator[sinter.Task], Iterable[sinter.Task]],
hint_num_tasks: Optional[int] = None,
additional_existing_data: Optional[sinter._existing_data.ExistingData] = None,
additional_existing_data: Union[NoneType, Dict[str, sinter.TaskStats], Iterable[sinter.TaskStats]] = None,
max_shots: Optional[int] = None,
max_errors: Optional[int] = None,
decoders: Optional[Iterable[str]] = None,
Expand Down Expand Up @@ -1337,6 +1421,7 @@ def plot_discard_rate(
filter_func: Callable[[sinter.TaskStats], Any] = lambda _: True,
plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
highlight_max_likelihood_factor: Optional[float] = 1000.0,
point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None,
) -> None:
"""Plots discard rates in curves with uncertainty highlights.
Expand Down Expand Up @@ -1370,6 +1455,7 @@ def plot_discard_rate(
highlight_max_likelihood_factor: Controls how wide the uncertainty highlight region around curves is.
Must be 1 or larger. Hypothesis probabilities at most that many times as unlikely as the max likelihood
hypothesis will be highlighted.
point_label_func: Optional. Specifies text to draw next to data points.
"""
```

Expand All @@ -1390,6 +1476,7 @@ def plot_error_rate(
plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(),
highlight_max_likelihood_factor: Optional[float] = 1000.0,
line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None,
point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None,
) -> None:
"""Plots error rates in curves with uncertainty highlights.
Expand Down Expand Up @@ -1430,6 +1517,7 @@ def plot_error_rate(
line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line
fit to every curve. The scales determine how to transform the coordinates before
performing the fit, and can be set to 'linear', 'sqrt', or 'log'.
point_label_func: Optional. Specifies text to draw next to data points.
"""
```

Expand Down Expand Up @@ -1712,11 +1800,11 @@ def read_stats_from_csv_files(

# (at top-level in the sinter module)
def shot_error_rate_to_piece_error_rate(
shot_error_rate: Union[float, ForwardRef(sinter.Fit)],
shot_error_rate: Union[float, sinter.Fit],
*,
pieces: float,
values: float = 1,
) -> Union[float, ForwardRef(sinter.Fit)]:
) -> Union[float, sinter.Fit]:
"""Convert from total error rate to per-piece error rate.
Args:
Expand Down
17 changes: 15 additions & 2 deletions glue/sample/src/sinter/_collection/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def iter_collect(*,
tasks: Union[Iterator['sinter.Task'],
Iterable['sinter.Task']],
hint_num_tasks: Optional[int] = None,
additional_existing_data: Optional[ExistingData] = None,
additional_existing_data: Union[None, dict[str, 'TaskStats'], Iterable['TaskStats']] = None,
max_shots: Optional[int] = None,
max_errors: Optional[int] = None,
decoders: Optional[Iterable[str]] = None,
Expand Down Expand Up @@ -152,6 +152,19 @@ def iter_collect(*,
>>> print(total_shots)
200
"""
existing_data: dict[str, TaskStats]
if isinstance(additional_existing_data, ExistingData):
existing_data = additional_existing_data.data
elif isinstance(additional_existing_data, dict):
existing_data = additional_existing_data
elif additional_existing_data is None:
existing_data = {}
else:
acc = ExistingData()
for stat in additional_existing_data:
acc.add_sample(stat)
existing_data = acc.data

if isinstance(decoders, str):
decoders = [decoders]

Expand Down Expand Up @@ -192,7 +205,7 @@ def log_progress(e: Optional[TaskStats]):
start_batch_size=start_batch_size,
max_batch_size=max_batch_size,
),
existing_data={} if additional_existing_data is None else additional_existing_data.data,
existing_data=existing_data,
count_observable_error_combos=count_observable_error_combos,
count_detection_events=count_detection_events,
custom_error_count_key=custom_error_count_key,
Expand Down
5 changes: 2 additions & 3 deletions glue/sample/src/sinter/_data/_anon_task_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __repr__(self) -> str:
terms.append(f'custom_counts={self.custom_counts!r}')
return f'sinter.AnonTaskStats({", ".join(terms)})'

def __add__(self, other: Union['AnonTaskStats', 'TaskStats']) -> 'AnonTaskStats':
def __add__(self, other: 'AnonTaskStats') -> 'AnonTaskStats':
"""Returns the sum of the statistics from both anonymous stats.
Adds the shots, the errors, the discards, and the seconds.
Expand All @@ -77,8 +77,7 @@ def __add__(self, other: Union['AnonTaskStats', 'TaskStats']) -> 'AnonTaskStats'
>>> a + b
sinter.AnonTaskStats(shots=1100, errors=220)
"""
from sinter._data._task_stats import TaskStats
if isinstance(other, (AnonTaskStats, TaskStats)):
if isinstance(other, AnonTaskStats):
return AnonTaskStats(
shots=self.shots + other.shots,
errors=self.errors + other.errors,
Expand Down
42 changes: 28 additions & 14 deletions glue/sample/src/sinter/_data/_task_stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
import dataclasses
from typing import Counter, List, Any
from typing import Union
from typing import overload

from sinter._data._anon_task_stats import AnonTaskStats
from sinter._data._csv_out import csv_line
Expand Down Expand Up @@ -67,21 +69,33 @@ def __post_init__(self):
assert self.shots >= self.errors + self.discards
assert all(isinstance(k, str) and isinstance(v, int) for k, v in self.custom_counts.items())

@overload
def __add__(self, other: AnonTaskStats) -> AnonTaskStats:
pass
@overload
def __add__(self, other: 'TaskStats') -> 'TaskStats':
if self.strong_id != other.strong_id:
raise ValueError(f'{self.strong_id=} != {other.strong_id=}')
total = self.to_anon_stats() + other.to_anon_stats()

return TaskStats(
decoder=self.decoder,
strong_id=self.strong_id,
json_metadata=self.json_metadata,
shots=total.shots,
errors=total.errors,
discards=total.discards,
seconds=total.seconds,
custom_counts=total.custom_counts,
)
pass
def __add__(self, other: Union[AnonTaskStats, 'TaskStats']) -> Union[AnonTaskStats, 'TaskStats']:
if isinstance(other, AnonTaskStats):
return self.to_anon_stats() + other

if isinstance(other, TaskStats):
if self.strong_id != other.strong_id:
raise ValueError(f'{self.strong_id=} != {other.strong_id=}')
total = self.to_anon_stats() + other.to_anon_stats()
return TaskStats(
decoder=self.decoder,
strong_id=self.strong_id,
json_metadata=self.json_metadata,
shots=total.shots,
errors=total.errors,
discards=total.discards,
seconds=total.seconds,
custom_counts=total.custom_counts,
)

return NotImplemented
__radd__ = __add__

def to_anon_stats(self) -> AnonTaskStats:
"""Returns a `sinter.AnonTaskStats` with the same statistics.
Expand Down

0 comments on commit c2557a9

Please sign in to comment.