Skip to content

Commit

Permalink
Merge branch 'main' into invgen
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc authored Nov 30, 2023
2 parents 17ee69c + 97a9cd6 commit b34e45c
Show file tree
Hide file tree
Showing 15 changed files with 503 additions and 111 deletions.
41 changes: 32 additions & 9 deletions doc/python_api_reference_vDev.md
Original file line number Diff line number Diff line change
Expand Up @@ -1889,9 +1889,10 @@ def get_final_qubit_coordinates(
# (in class stim.Circuit)
def has_flow(
self,
shorthand: Optional[str] = None,
*,
start: Optional[stim.PauliString] = None,
end: Optional[stim.PauliString] = None,
start: Union[None, str, stim.PauliString] = None,
end: Union[None, str, stim.PauliString] = None,
measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None,
unsigned: bool = False,
) -> bool:
Expand All @@ -1904,15 +1905,27 @@ def has_flow(
the CNOT flows implemented by the circuit involve these measurements.
A flow like P -> Q means that the circuit transforms P into Q.
A flow like IDENTITY -> P means that the circuit prepares P.
A flow like P -> IDENTITY means that the circuit measures P.
A flow like IDENTITY -> IDENTITY means that the circuit contains a detector.
Args:
A flow like 1 -> P means that the circuit prepares P.
A flow like P -> 1 means that the circuit measures P.
A flow like 1 -> 1 means that the circuit contains a detector.
Args:
shorthand: Specifies the flow as a short string like "IX -> -YZ xor rec[1]".
The text must contain "->" to separate the input pauli string from the
output pauli string. Each pauli string should be a sequence of
characters from "_IXYZ" (or else just "1" to indicate the empty Pauli
string) optionally prefixed by "+" or "-". Measurements are included
by appending " xor rec[k]" for each measurement index k. Indexing uses
the python convention where non-negative indices index from the start
and negative indices index from the end.
start: The input into the flow at the start of the circuit. Defaults to None
(the identity Pauli string).
(the identity Pauli string). When specified, this should be a
`stim.PauliString`, or a `str` (which will be parsed using
`stim.PauliString.__init__`).
end: The output from the flow at the end of the circuit. Defaults to None
(the identity Pauli string).
(the identity Pauli string). When specified, this should be a
`stim.PauliString`, or a `str` (which will be parsed using
`stim.PauliString.__init__`).
measurements: Defaults to None (empty). The indices of measurements to
include in the flow. This should be a collection of integers and/or
stim.GateTarget instances. Indexing uses the python convention where
Expand All @@ -1936,6 +1949,16 @@ def has_flow(
Examples:
>>> import stim
>>> m = stim.Circuit('M 0')
>>> m.has_flow('Z -> Z')
True
>>> m.has_flow('X -> X')
False
>>> m.has_flow('Z -> I')
False
>>> m.has_flow('Z -> I xor rec[-1]')
True
>>> stim.Circuit('''
... RY 0
... ''').has_flow(
Expand Down
41 changes: 32 additions & 9 deletions doc/stim.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1313,9 +1313,10 @@ class Circuit:
"""
def has_flow(
self,
shorthand: Optional[str] = None,
*,
start: Optional[stim.PauliString] = None,
end: Optional[stim.PauliString] = None,
start: Union[None, str, stim.PauliString] = None,
end: Union[None, str, stim.PauliString] = None,
measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None,
unsigned: bool = False,
) -> bool:
Expand All @@ -1328,15 +1329,27 @@ class Circuit:
the CNOT flows implemented by the circuit involve these measurements.
A flow like P -> Q means that the circuit transforms P into Q.
A flow like IDENTITY -> P means that the circuit prepares P.
A flow like P -> IDENTITY means that the circuit measures P.
A flow like IDENTITY -> IDENTITY means that the circuit contains a detector.
Args:
A flow like 1 -> P means that the circuit prepares P.
A flow like P -> 1 means that the circuit measures P.
A flow like 1 -> 1 means that the circuit contains a detector.
Args:
shorthand: Specifies the flow as a short string like "IX -> -YZ xor rec[1]".
The text must contain "->" to separate the input pauli string from the
output pauli string. Each pauli string should be a sequence of
characters from "_IXYZ" (or else just "1" to indicate the empty Pauli
string) optionally prefixed by "+" or "-". Measurements are included
by appending " xor rec[k]" for each measurement index k. Indexing uses
the python convention where non-negative indices index from the start
and negative indices index from the end.
start: The input into the flow at the start of the circuit. Defaults to None
(the identity Pauli string).
(the identity Pauli string). When specified, this should be a
`stim.PauliString`, or a `str` (which will be parsed using
`stim.PauliString.__init__`).
end: The output from the flow at the end of the circuit. Defaults to None
(the identity Pauli string).
(the identity Pauli string). When specified, this should be a
`stim.PauliString`, or a `str` (which will be parsed using
`stim.PauliString.__init__`).
measurements: Defaults to None (empty). The indices of measurements to
include in the flow. This should be a collection of integers and/or
stim.GateTarget instances. Indexing uses the python convention where
Expand All @@ -1360,6 +1373,16 @@ class Circuit:
Examples:
>>> import stim
>>> m = stim.Circuit('M 0')
>>> m.has_flow('Z -> Z')
True
>>> m.has_flow('X -> X')
False
>>> m.has_flow('Z -> I')
False
>>> m.has_flow('Z -> I xor rec[-1]')
True
>>> stim.Circuit('''
... RY 0
... ''').has_flow(
Expand Down
2 changes: 1 addition & 1 deletion glue/javascript/pauli_string.js.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ExposedPauliString::ExposedPauliString(const emscripten::val &arg) : pauli_strin
if (arg.isNumber()) {
pauli_string = PauliString<stim::MAX_BITWORD_WIDTH>(js_val_to_uint32_t(arg));
} else if (arg.isString()) {
pauli_string = PauliString<stim::MAX_BITWORD_WIDTH>::from_str(arg.as<std::string>().data());
pauli_string = PauliString<stim::MAX_BITWORD_WIDTH>::from_str(arg.as<std::string>());
} else {
throw std::invalid_argument("Expected an int or a string. Got " + t);
}
Expand Down
41 changes: 32 additions & 9 deletions glue/python/src/stim/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1313,9 +1313,10 @@ class Circuit:
"""
def has_flow(
self,
shorthand: Optional[str] = None,
*,
start: Optional[stim.PauliString] = None,
end: Optional[stim.PauliString] = None,
start: Union[None, str, stim.PauliString] = None,
end: Union[None, str, stim.PauliString] = None,
measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None,
unsigned: bool = False,
) -> bool:
Expand All @@ -1328,15 +1329,27 @@ class Circuit:
the CNOT flows implemented by the circuit involve these measurements.
A flow like P -> Q means that the circuit transforms P into Q.
A flow like IDENTITY -> P means that the circuit prepares P.
A flow like P -> IDENTITY means that the circuit measures P.
A flow like IDENTITY -> IDENTITY means that the circuit contains a detector.
Args:
A flow like 1 -> P means that the circuit prepares P.
A flow like P -> 1 means that the circuit measures P.
A flow like 1 -> 1 means that the circuit contains a detector.
Args:
shorthand: Specifies the flow as a short string like "IX -> -YZ xor rec[1]".
The text must contain "->" to separate the input pauli string from the
output pauli string. Each pauli string should be a sequence of
characters from "_IXYZ" (or else just "1" to indicate the empty Pauli
string) optionally prefixed by "+" or "-". Measurements are included
by appending " xor rec[k]" for each measurement index k. Indexing uses
the python convention where non-negative indices index from the start
and negative indices index from the end.
start: The input into the flow at the start of the circuit. Defaults to None
(the identity Pauli string).
(the identity Pauli string). When specified, this should be a
`stim.PauliString`, or a `str` (which will be parsed using
`stim.PauliString.__init__`).
end: The output from the flow at the end of the circuit. Defaults to None
(the identity Pauli string).
(the identity Pauli string). When specified, this should be a
`stim.PauliString`, or a `str` (which will be parsed using
`stim.PauliString.__init__`).
measurements: Defaults to None (empty). The indices of measurements to
include in the flow. This should be a collection of integers and/or
stim.GateTarget instances. Indexing uses the python convention where
Expand All @@ -1360,6 +1373,16 @@ class Circuit:
Examples:
>>> import stim
>>> m = stim.Circuit('M 0')
>>> m.has_flow('Z -> Z')
True
>>> m.has_flow('X -> X')
False
>>> m.has_flow('Z -> I')
False
>>> m.has_flow('Z -> I xor rec[-1]')
True
>>> stim.Circuit('''
... RY 0
... ''').has_flow(
Expand Down
46 changes: 42 additions & 4 deletions glue/sample/src/sinter/_decoding_fusion_blossom.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,58 @@
import math
import pathlib
from typing import Callable, List, TYPE_CHECKING
from typing import Tuple
from typing import Callable, List, TYPE_CHECKING, Tuple

import numpy as np
import stim

from sinter._decoding_decoder_class import Decoder
from sinter._decoding_decoder_class import Decoder, CompiledDecoder

if TYPE_CHECKING:
import fusion_blossom


class FusionBlossomCompiledDecoder(CompiledDecoder):
def __init__(self, solver: 'fusion_blossom.SolverSerial', fault_masks: 'np.ndarray', num_dets: int, num_obs: int):
self.solver = solver
self.fault_masks = fault_masks
self.num_dets = num_dets
self.num_obs = num_obs

def decode_shots_bit_packed(
self,
*,
bit_packed_detection_event_data: 'np.ndarray',
) -> 'np.ndarray':
num_shots = bit_packed_detection_event_data.shape[0]
predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8)
import fusion_blossom
for shot in range(num_shots):
dets_sparse = np.flatnonzero(np.unpackbits(bit_packed_detection_event_data[shot], count=self.num_dets, bitorder='little'))
syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
self.solver.solve(syndrome)
prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
predictions[shot] = np.packbits(prediction, bitorder='little')
self.solver.clear()
return predictions


class FusionBlossomDecoder(Decoder):
"""Use fusion blossom to predict observables from detection events."""

def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder:
try:
import fusion_blossom
except ImportError as ex:
raise ImportError(
"The decoder 'fusion_blossom' isn't installed\n"
"To fix this, install the python package 'fusion_blossom' into your environment.\n"
"For example, if you are using pip, run `pip install fusion_blossom`.\n"
) from ex

solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(dem)
return FusionBlossomCompiledDecoder(solver, fault_masks, dem.num_detectors, dem.num_observables)


def decode_via_files(self,
*,
num_shots: int,
Expand All @@ -36,7 +75,6 @@ def decode_via_files(self,
error_model = stim.DetectorErrorModel.from_file(dem_path)
solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(error_model)
num_det_bytes = math.ceil(num_dets / 8)

with open(dets_b8_in_path, 'rb') as dets_in_f:
with open(obs_predictions_b8_out_path, 'wb') as obs_out_f:
for _ in range(num_shots):
Expand Down
18 changes: 8 additions & 10 deletions src/stim/arg_parse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,20 +225,20 @@ bool stim::find_bool_argument(const char *name, int argc, const char **argv) {
throw std::invalid_argument(msg.str());
}

bool parse_int64(const char *data, int64_t *out) {
char c = *data;
if (c == 0) {
bool stim::parse_int64(std::string_view data, int64_t *out) {
if (data.empty()) {
return false;
}
bool negate = false;
if (c == '-') {
if (data.starts_with("-")) {
negate = true;
data++;
c = *data;
data = data.substr(1);
} else if (data.starts_with("+")) {
data = data.substr(1);
}

uint64_t accumulator = 0;
while (c) {
for (char c : data) {
if (!(c >= '0' && c <= '9')) {
return false;
}
Expand All @@ -248,8 +248,6 @@ bool parse_int64(const char *data, int64_t *out) {
return false; // Overflow.
}
accumulator = next;
data++;
c = *data;
}

if (negate && accumulator == (uint64_t)INT64_MAX + uint64_t{1}) {
Expand Down Expand Up @@ -423,7 +421,7 @@ uint64_t stim::parse_exact_uint64_t_from_string(const std::string &text) {
if (end == c + text.size()) {
// strtoull silently accepts spaces and negative signs and overflowing
// values. The only guaranteed way I've found to ensure it actually
// worked is to recreate the string and check that it's the sam.e
// worked is to recreate the string and check that it's the same.
std::stringstream ss;
ss << v;
if (ss.str() == text) {
Expand Down
1 change: 1 addition & 0 deletions src/stim/arg_parse.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ std::vector<std::string> split(char splitter, const std::string &text);

double parse_exact_double_from_string(const std::string &text);
uint64_t parse_exact_uint64_t_from_string(const std::string &text);
bool parse_int64(std::string_view data, int64_t *out);

} // namespace stim

Expand Down
34 changes: 34 additions & 0 deletions src/stim/arg_parse.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,37 @@ TEST(arg_parse, parse_exact_uint64_t_from_string) {
ASSERT_EQ(parse_exact_uint64_t_from_string("2"), 2);
ASSERT_EQ(parse_exact_uint64_t_from_string("18446744073709551615"), UINT64_MAX);
}

TEST(arg_parse, parse_int64) {
int64_t x = 0;

ASSERT_TRUE(parse_int64("+0", &x));
ASSERT_EQ(x, 0);
ASSERT_TRUE(parse_int64("-0", &x));
ASSERT_EQ(x, 0);
ASSERT_TRUE(parse_int64("0", &x));
ASSERT_EQ(x, 0);
ASSERT_TRUE(parse_int64("1", &x));
ASSERT_EQ(x, 1);
ASSERT_TRUE(parse_int64("-1", &x));
ASSERT_EQ(x, -1);
ASSERT_FALSE(parse_int64("i", &x));
ASSERT_FALSE(parse_int64("1i", &x));
ASSERT_FALSE(parse_int64("i1", &x));
ASSERT_FALSE(parse_int64("1e2", &x));
ASSERT_FALSE(parse_int64("12i1", &x));
ASSERT_FALSE(parse_int64("12 ", &x));
ASSERT_FALSE(parse_int64(" 12", &x));

ASSERT_TRUE(parse_int64("0123", &x));
ASSERT_EQ(x, 123);
ASSERT_TRUE(parse_int64("-0123", &x));
ASSERT_EQ(x, -123);

ASSERT_FALSE(parse_int64("-9223372036854775809", &x));
ASSERT_TRUE(parse_int64("-9223372036854775808", &x));
ASSERT_EQ(x, INT64_MIN);
ASSERT_FALSE(parse_int64("9223372036854775808", &x));
ASSERT_TRUE(parse_int64("9223372036854775807", &x));
ASSERT_EQ(x, INT64_MAX);
}
Loading

0 comments on commit b34e45c

Please sign in to comment.