From fb69c2dbe11b56f746cc6ef2dcd5c43ba6395519 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Sun, 26 Nov 2023 16:41:25 -0800 Subject: [PATCH 1/2] Add `shorthand` argument to `stim.Circuit.has_flow` (#671) - Refactor a few initial C++ methods to use `std::string_view` instead of `const char *` or `const std::string &` --- doc/python_api_reference_vDev.md | 41 ++++-- doc/stim.pyi | 41 ++++-- glue/javascript/pauli_string.js.cc | 2 +- glue/python/src/stim/__init__.pyi | 41 ++++-- src/stim/arg_parse.cc | 18 ++- src/stim/arg_parse.h | 1 + src/stim/arg_parse.test.cc | 34 +++++ src/stim/circuit/circuit.pybind.cc | 160 ++++++++++++++++------- src/stim/circuit/circuit_pybind_test.py | 25 ++++ src/stim/circuit/stabilizer_flow.h | 2 +- src/stim/circuit/stabilizer_flow.inl | 93 ++++++++++--- src/stim/circuit/stabilizer_flow.test.cc | 95 ++++++++++++++ src/stim/stabilizers/pauli_string.h | 2 +- src/stim/stabilizers/pauli_string.inl | 13 +- 14 files changed, 461 insertions(+), 107 deletions(-) diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index e263f9d15..d210aea73 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -1889,9 +1889,10 @@ def get_final_qubit_coordinates( # (in class stim.Circuit) def has_flow( self, + shorthand: Optional[str] = None, *, - start: Optional[stim.PauliString] = None, - end: Optional[stim.PauliString] = None, + start: Union[None, str, stim.PauliString] = None, + end: Union[None, str, stim.PauliString] = None, measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None, unsigned: bool = False, ) -> bool: @@ -1904,15 +1905,27 @@ def has_flow( the CNOT flows implemented by the circuit involve these measurements. A flow like P -> Q means that the circuit transforms P into Q. - A flow like IDENTITY -> P means that the circuit prepares P. - A flow like P -> IDENTITY means that the circuit measures P. - A flow like IDENTITY -> IDENTITY means that the circuit contains a detector. - - Args: + A flow like 1 -> P means that the circuit prepares P. + A flow like P -> 1 means that the circuit measures P. + A flow like 1 -> 1 means that the circuit contains a detector. + + Args: + shorthand: Specifies the flow as a short string like "IX -> -YZ xor rec[1]". + The text must contain "->" to separate the input pauli string from the + output pauli string. Each pauli string should be a sequence of + characters from "_IXYZ" (or else just "1" to indicate the empty Pauli + string) optionally prefixed by "+" or "-". Measurements are included + by appending " xor rec[k]" for each measurement index k. Indexing uses + the python convention where non-negative indices index from the start + and negative indices index from the end. start: The input into the flow at the start of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). end: The output from the flow at the end of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). measurements: Defaults to None (empty). The indices of measurements to include in the flow. This should be a collection of integers and/or stim.GateTarget instances. Indexing uses the python convention where @@ -1936,6 +1949,16 @@ def has_flow( Examples: >>> import stim + >>> m = stim.Circuit('M 0') + >>> m.has_flow('Z -> Z') + True + >>> m.has_flow('X -> X') + False + >>> m.has_flow('Z -> I') + False + >>> m.has_flow('Z -> I xor rec[-1]') + True + >>> stim.Circuit(''' ... RY 0 ... ''').has_flow( diff --git a/doc/stim.pyi b/doc/stim.pyi index 291954718..28323d48e 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -1313,9 +1313,10 @@ class Circuit: """ def has_flow( self, + shorthand: Optional[str] = None, *, - start: Optional[stim.PauliString] = None, - end: Optional[stim.PauliString] = None, + start: Union[None, str, stim.PauliString] = None, + end: Union[None, str, stim.PauliString] = None, measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None, unsigned: bool = False, ) -> bool: @@ -1328,15 +1329,27 @@ class Circuit: the CNOT flows implemented by the circuit involve these measurements. A flow like P -> Q means that the circuit transforms P into Q. - A flow like IDENTITY -> P means that the circuit prepares P. - A flow like P -> IDENTITY means that the circuit measures P. - A flow like IDENTITY -> IDENTITY means that the circuit contains a detector. - - Args: + A flow like 1 -> P means that the circuit prepares P. + A flow like P -> 1 means that the circuit measures P. + A flow like 1 -> 1 means that the circuit contains a detector. + + Args: + shorthand: Specifies the flow as a short string like "IX -> -YZ xor rec[1]". + The text must contain "->" to separate the input pauli string from the + output pauli string. Each pauli string should be a sequence of + characters from "_IXYZ" (or else just "1" to indicate the empty Pauli + string) optionally prefixed by "+" or "-". Measurements are included + by appending " xor rec[k]" for each measurement index k. Indexing uses + the python convention where non-negative indices index from the start + and negative indices index from the end. start: The input into the flow at the start of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). end: The output from the flow at the end of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). measurements: Defaults to None (empty). The indices of measurements to include in the flow. This should be a collection of integers and/or stim.GateTarget instances. Indexing uses the python convention where @@ -1360,6 +1373,16 @@ class Circuit: Examples: >>> import stim + >>> m = stim.Circuit('M 0') + >>> m.has_flow('Z -> Z') + True + >>> m.has_flow('X -> X') + False + >>> m.has_flow('Z -> I') + False + >>> m.has_flow('Z -> I xor rec[-1]') + True + >>> stim.Circuit(''' ... RY 0 ... ''').has_flow( diff --git a/glue/javascript/pauli_string.js.cc b/glue/javascript/pauli_string.js.cc index f700eafe4..bc9bfdd20 100644 --- a/glue/javascript/pauli_string.js.cc +++ b/glue/javascript/pauli_string.js.cc @@ -14,7 +14,7 @@ ExposedPauliString::ExposedPauliString(const emscripten::val &arg) : pauli_strin if (arg.isNumber()) { pauli_string = PauliString(js_val_to_uint32_t(arg)); } else if (arg.isString()) { - pauli_string = PauliString::from_str(arg.as().data()); + pauli_string = PauliString::from_str(arg.as()); } else { throw std::invalid_argument("Expected an int or a string. Got " + t); } diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index 291954718..28323d48e 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -1313,9 +1313,10 @@ class Circuit: """ def has_flow( self, + shorthand: Optional[str] = None, *, - start: Optional[stim.PauliString] = None, - end: Optional[stim.PauliString] = None, + start: Union[None, str, stim.PauliString] = None, + end: Union[None, str, stim.PauliString] = None, measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None, unsigned: bool = False, ) -> bool: @@ -1328,15 +1329,27 @@ class Circuit: the CNOT flows implemented by the circuit involve these measurements. A flow like P -> Q means that the circuit transforms P into Q. - A flow like IDENTITY -> P means that the circuit prepares P. - A flow like P -> IDENTITY means that the circuit measures P. - A flow like IDENTITY -> IDENTITY means that the circuit contains a detector. - - Args: + A flow like 1 -> P means that the circuit prepares P. + A flow like P -> 1 means that the circuit measures P. + A flow like 1 -> 1 means that the circuit contains a detector. + + Args: + shorthand: Specifies the flow as a short string like "IX -> -YZ xor rec[1]". + The text must contain "->" to separate the input pauli string from the + output pauli string. Each pauli string should be a sequence of + characters from "_IXYZ" (or else just "1" to indicate the empty Pauli + string) optionally prefixed by "+" or "-". Measurements are included + by appending " xor rec[k]" for each measurement index k. Indexing uses + the python convention where non-negative indices index from the start + and negative indices index from the end. start: The input into the flow at the start of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). end: The output from the flow at the end of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). measurements: Defaults to None (empty). The indices of measurements to include in the flow. This should be a collection of integers and/or stim.GateTarget instances. Indexing uses the python convention where @@ -1360,6 +1373,16 @@ class Circuit: Examples: >>> import stim + >>> m = stim.Circuit('M 0') + >>> m.has_flow('Z -> Z') + True + >>> m.has_flow('X -> X') + False + >>> m.has_flow('Z -> I') + False + >>> m.has_flow('Z -> I xor rec[-1]') + True + >>> stim.Circuit(''' ... RY 0 ... ''').has_flow( diff --git a/src/stim/arg_parse.cc b/src/stim/arg_parse.cc index 3c7bbbe45..24e943da6 100644 --- a/src/stim/arg_parse.cc +++ b/src/stim/arg_parse.cc @@ -225,20 +225,20 @@ bool stim::find_bool_argument(const char *name, int argc, const char **argv) { throw std::invalid_argument(msg.str()); } -bool parse_int64(const char *data, int64_t *out) { - char c = *data; - if (c == 0) { +bool stim::parse_int64(std::string_view data, int64_t *out) { + if (data.empty()) { return false; } bool negate = false; - if (c == '-') { + if (data.starts_with("-")) { negate = true; - data++; - c = *data; + data = data.substr(1); + } else if (data.starts_with("+")) { + data = data.substr(1); } uint64_t accumulator = 0; - while (c) { + for (char c : data) { if (!(c >= '0' && c <= '9')) { return false; } @@ -248,8 +248,6 @@ bool parse_int64(const char *data, int64_t *out) { return false; // Overflow. } accumulator = next; - data++; - c = *data; } if (negate && accumulator == (uint64_t)INT64_MAX + uint64_t{1}) { @@ -423,7 +421,7 @@ uint64_t stim::parse_exact_uint64_t_from_string(const std::string &text) { if (end == c + text.size()) { // strtoull silently accepts spaces and negative signs and overflowing // values. The only guaranteed way I've found to ensure it actually - // worked is to recreate the string and check that it's the sam.e + // worked is to recreate the string and check that it's the same. std::stringstream ss; ss << v; if (ss.str() == text) { diff --git a/src/stim/arg_parse.h b/src/stim/arg_parse.h index c9d6c2484..68baf9d4d 100644 --- a/src/stim/arg_parse.h +++ b/src/stim/arg_parse.h @@ -264,6 +264,7 @@ std::vector split(char splitter, const std::string &text); double parse_exact_double_from_string(const std::string &text); uint64_t parse_exact_uint64_t_from_string(const std::string &text); +bool parse_int64(std::string_view data, int64_t *out); } // namespace stim diff --git a/src/stim/arg_parse.test.cc b/src/stim/arg_parse.test.cc index d4f685460..5d8a1b593 100644 --- a/src/stim/arg_parse.test.cc +++ b/src/stim/arg_parse.test.cc @@ -346,3 +346,37 @@ TEST(arg_parse, parse_exact_uint64_t_from_string) { ASSERT_EQ(parse_exact_uint64_t_from_string("2"), 2); ASSERT_EQ(parse_exact_uint64_t_from_string("18446744073709551615"), UINT64_MAX); } + +TEST(arg_parse, parse_int64) { + int64_t x = 0; + + ASSERT_TRUE(parse_int64("+0", &x)); + ASSERT_EQ(x, 0); + ASSERT_TRUE(parse_int64("-0", &x)); + ASSERT_EQ(x, 0); + ASSERT_TRUE(parse_int64("0", &x)); + ASSERT_EQ(x, 0); + ASSERT_TRUE(parse_int64("1", &x)); + ASSERT_EQ(x, 1); + ASSERT_TRUE(parse_int64("-1", &x)); + ASSERT_EQ(x, -1); + ASSERT_FALSE(parse_int64("i", &x)); + ASSERT_FALSE(parse_int64("1i", &x)); + ASSERT_FALSE(parse_int64("i1", &x)); + ASSERT_FALSE(parse_int64("1e2", &x)); + ASSERT_FALSE(parse_int64("12i1", &x)); + ASSERT_FALSE(parse_int64("12 ", &x)); + ASSERT_FALSE(parse_int64(" 12", &x)); + + ASSERT_TRUE(parse_int64("0123", &x)); + ASSERT_EQ(x, 123); + ASSERT_TRUE(parse_int64("-0123", &x)); + ASSERT_EQ(x, -123); + + ASSERT_FALSE(parse_int64("-9223372036854775809", &x)); + ASSERT_TRUE(parse_int64("-9223372036854775808", &x)); + ASSERT_EQ(x, INT64_MIN); + ASSERT_FALSE(parse_int64("9223372036854775808", &x)); + ASSERT_TRUE(parse_int64("9223372036854775807", &x)); + ASSERT_EQ(x, INT64_MAX); +} \ No newline at end of file diff --git a/src/stim/circuit/circuit.pybind.cc b/src/stim/circuit/circuit.pybind.cc index 4382ed526..47d3d0181 100644 --- a/src/stim/circuit/circuit.pybind.cc +++ b/src/stim/circuit/circuit.pybind.cc @@ -219,6 +219,90 @@ uint64_t obj_to_abs_detector_id(const pybind11::handle &obj, bool fail) { throw std::invalid_argument(ss.str()); } +PyPauliString arg_to_pauli_string(const pybind11::object &arg) { + if (arg.is_none()) { + return PyPauliString(PauliString(0)); + } else if (pybind11::isinstance(arg)) { + return pybind11::cast(arg); + } else if (pybind11::isinstance(arg)) { + return PyPauliString::from_text(pybind11::cast(arg).c_str()); + } else { + throw std::invalid_argument( + "Don't know how to get a stim.PauliString from " + pybind11::cast(pybind11::repr(arg))); + } +} + +void append_measurements_from_args( + uint64_t num_circuit_measurements, + const pybind11::object &arg_measurements, + std::vector &out_measurements) { + if (arg_measurements.is_none()) { + return; + } + for (const pybind11::handle &e : arg_measurements) { + if (pybind11::isinstance(e)) { + auto d = pybind11::cast(e); + if (d.is_measurement_record_target()) { + out_measurements.push_back(d); + continue; + } + } else { + try { + int64_t s = pybind11::cast(e); + if (s >= 0 && s < (int64_t)num_circuit_measurements) { + s -= num_circuit_measurements; + } + if (s < 0 && -s <= (int64_t)num_circuit_measurements) { + out_measurements.push_back(GateTarget::rec(s)); + continue; + } + } catch (const pybind11::cast_error &) { + } + } + throw std::invalid_argument( + "Each measurement must be an integer in `range(-circuit.num_measurements, " + "circuit.num_measurements)`, or a `stim.GateTarget`."); + } +} + +StabilizerFlow args_to_flow( + uint64_t num_circuit_measurements, + const pybind11::object &shorthand, + const pybind11::object &start, + const pybind11::object &end, + const pybind11::object &measurements) { + StabilizerFlow flow{ + .input = PauliString{0}, + .output = PauliString{0}, + .measurement_outputs = {}, + }; + if (!shorthand.is_none() && !start.is_none()) { + throw std::invalid_argument("Can't specify both `shorthand` and `start`."); + } + if (!shorthand.is_none() && !end.is_none()) { + throw std::invalid_argument("Can't specify both `shorthand` and `end`."); + } + + if (!shorthand.is_none()) { + flow = StabilizerFlow::from_str( + pybind11::cast(shorthand).c_str(), num_circuit_measurements); + } else { + PyPauliString in = arg_to_pauli_string(start); + PyPauliString out = arg_to_pauli_string(end); + if (in.imag != out.imag) { + throw std::invalid_argument( + "The requested flow '" + in.str() + " -> " + out.str() + + "' is anti-Hermitian (unbalanced imaginary signs). Stabilizer flows are always Hermitian."); + } + flow.input = std::move(in.value); + flow.output = std::move(out.value); + } + + append_measurements_from_args(num_circuit_measurements, measurements, flow.measurement_outputs); + + return flow; +} + std::set obj_to_abs_detector_id_set( const pybind11::object &obj, const std::function &get_num_detectors) { std::set filter; @@ -2221,48 +2305,13 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_ bool { - auto num_measurements = self.count_measurements(); - PauliString raw_start(0); - PauliString raw_end(0); - std::vector raw_measurements; - if (!start.is_none()) { - raw_start = pybind11::cast(start).value; - } - if (!end.is_none()) { - raw_end = pybind11::cast(end).value; - } - if (!measurements.is_none()) { - for (const pybind11::handle &e : measurements) { - if (pybind11::isinstance(e)) { - auto d = pybind11::cast(e); - if (d.is_measurement_record_target()) { - raw_measurements.push_back(d); - continue; - } - } else { - try { - int64_t s = pybind11::cast(e); - if (s >= 0 && s < (int64_t)num_measurements) { - s -= num_measurements; - } - if (s < 0 && -s <= (int64_t)num_measurements) { - raw_measurements.push_back(GateTarget::rec(s)); - continue; - } - } catch (const pybind11::cast_error &) { - } - } - throw std::invalid_argument( - "Each measurement must be an integer in `range(-circuit.num_measurements, " - "circuit.num_measurements)`, or a `stim.GateTarget`."); - } - } - StabilizerFlow flow{ - .input = raw_start, .output = raw_end, .measurement_outputs = raw_measurements}; + StabilizerFlow flow = + args_to_flow(self.count_measurements(), shorthand, start, end, measurements); if (unsigned_only) { return check_if_circuit_has_unsigned_stabilizer_flows(self, &flow)[0]; } else { @@ -2270,13 +2319,14 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_(256, rng, self, &flow)[0]; } }, + pybind11::arg("shorthand") = pybind11::none(), pybind11::kw_only(), pybind11::arg("start") = pybind11::none(), pybind11::arg("end") = pybind11::none(), pybind11::arg("measurements") = pybind11::none(), pybind11::arg("unsigned") = false, clean_doc_string(R"DOC( - @signature def has_flow(self, *, start: Optional[stim.PauliString] = None, end: Optional[stim.PauliString] = None, measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None, unsigned: bool = False) -> bool: + @signature def has_flow(self, shorthand: Optional[str] = None, *, start: Union[None, str, stim.PauliString] = None, end: Union[None, str, stim.PauliString] = None, measurements: Optional[Iterable[Union[int, stim.GateTarget]]] = None, unsigned: bool = False) -> bool: Determines if the circuit has a stabilizer flow or not. A circuit has a stabilizer flow P -> Q if it maps the instantaneous stabilizer @@ -2286,15 +2336,27 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_ Q means that the circuit transforms P into Q. - A flow like IDENTITY -> P means that the circuit prepares P. - A flow like P -> IDENTITY means that the circuit measures P. - A flow like IDENTITY -> IDENTITY means that the circuit contains a detector. + A flow like 1 -> P means that the circuit prepares P. + A flow like P -> 1 means that the circuit measures P. + A flow like 1 -> 1 means that the circuit contains a detector. Args: + shorthand: Specifies the flow as a short string like "IX -> -YZ xor rec[1]". + The text must contain "->" to separate the input pauli string from the + output pauli string. Each pauli string should be a sequence of + characters from "_IXYZ" (or else just "1" to indicate the empty Pauli + string) optionally prefixed by "+" or "-". Measurements are included + by appending " xor rec[k]" for each measurement index k. Indexing uses + the python convention where non-negative indices index from the start + and negative indices index from the end. start: The input into the flow at the start of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). end: The output from the flow at the end of the circuit. Defaults to None - (the identity Pauli string). + (the identity Pauli string). When specified, this should be a + `stim.PauliString`, or a `str` (which will be parsed using + `stim.PauliString.__init__`). measurements: Defaults to None (empty). The indices of measurements to include in the flow. This should be a collection of integers and/or stim.GateTarget instances. Indexing uses the python convention where @@ -2318,6 +2380,16 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_>> import stim + >>> m = stim.Circuit('M 0') + >>> m.has_flow('Z -> Z') + True + >>> m.has_flow('X -> X') + False + >>> m.has_flow('Z -> I') + False + >>> m.has_flow('Z -> I xor rec[-1]') + True + >>> stim.Circuit(''' ... RY 0 ... ''').has_flow( diff --git a/src/stim/circuit/circuit_pybind_test.py b/src/stim/circuit/circuit_pybind_test.py index a4cc62ed1..31fe6510f 100644 --- a/src/stim/circuit/circuit_pybind_test.py +++ b/src/stim/circuit/circuit_pybind_test.py @@ -1663,3 +1663,28 @@ def test_has_flow_lattice_surgery_without_feedback(): assert not c.has_flow(start=stim.PauliString("X_"), end=stim.PauliString("-YX")) assert not c.has_flow(start=stim.PauliString("X_"), end=stim.PauliString("XX"), unsigned=True) assert c.has_flow(start=stim.PauliString("X_"), end=stim.PauliString("-YX"), unsigned=True, measurements=[1]) + + +def test_has_flow_shorthands(): + c = stim.Circuit(""" + MZ 99 + MXX 1 99 + MZZ 0 99 + MX 99 + """) + + assert c.has_flow("X_ -> XX xor rec[1] xor rec[3]") + assert c.has_flow("Z_ -> Z_") + assert c.has_flow("_X -> _X") + assert c.has_flow("_Z -> ZZ", measurements=[0, 2]) + assert c.has_flow("_Z -> ZZ xor rec[0]", measurements=[2]) + + assert c.has_flow(start="X_", end="XX", measurements=[1, 3]) + assert not c.has_flow("Z_ -> -Z_") + assert not c.has_flow("-Z_ -> Z_") + assert not c.has_flow("Z_ -> X_") + assert c.has_flow("iX_ -> iXX xor rec[1] xor rec[3]") + assert not c.has_flow("-iX_ -> iXX xor rec[1] xor rec[3]") + assert c.has_flow("-iX_ -> -iXX xor rec[1] xor rec[3]") + with pytest.raises(ValueError): + c.has_flow("iX_ -> XX") diff --git a/src/stim/circuit/stabilizer_flow.h b/src/stim/circuit/stabilizer_flow.h index 66560da61..d3d53027b 100644 --- a/src/stim/circuit/stabilizer_flow.h +++ b/src/stim/circuit/stabilizer_flow.h @@ -31,7 +31,7 @@ struct StabilizerFlow { stim::PauliString output; std::vector measurement_outputs; - static StabilizerFlow from_str(const char *c); + static StabilizerFlow from_str(const char *text, uint64_t num_measurements_for_non_neg_recs = 0); bool operator==(const StabilizerFlow &other) const; bool operator!=(const StabilizerFlow &other) const; std::string str() const; diff --git a/src/stim/circuit/stabilizer_flow.inl b/src/stim/circuit/stabilizer_flow.inl index 8e7c627c3..6183f16d7 100644 --- a/src/stim/circuit/stabilizer_flow.inl +++ b/src/stim/circuit/stabilizer_flow.inl @@ -68,51 +68,110 @@ std::vector sample_if_circuit_has_stabilizer_flows( return result; } +inline bool parse_rec_allowing_non_negative(std::string_view rec, size_t num_measurements_for_non_neg, GateTarget *out) { + if (rec.size() < 6 || rec[0] != 'r' || rec[1] != 'e' || rec[2] != 'c' || rec[3] != '[' || rec.back() != ']') { + throw std::invalid_argument(""); // Caught and given a message below. + } + int64_t i = 0; + if (!parse_int64(rec.substr(4, rec.size() - 5), &i)) { + return false; + } + + if (i >= INT32_MIN && i < 0) { + *out = stim::GateTarget::rec((int32_t)i); + return true; + } + if (i >= 0 && (size_t)i < num_measurements_for_non_neg) { + *out = stim::GateTarget::rec((int32_t)i - (int32_t)num_measurements_for_non_neg); + return true; + } + return false; +} + template -StabilizerFlow StabilizerFlow::from_str(const char *text) { +PauliString parse_non_empty_pauli_string_allowing_i(std::string_view text, bool *imag_out) { + *imag_out = false; + if (text == "+1" || text == "1") { + return PauliString(0); + } + if (text == "-1") { + PauliString r(0); + r.sign = true; + return r; + } + if (text.empty()) { + throw std::invalid_argument("Got an ambiguously blank pauli string. Use '1' for the empty Pauli string."); + } + + bool negate = false; + if (text.starts_with('i')) { + *imag_out = true; + text = text.substr(1); + } else if (text.starts_with("-i")) { + negate = true; + *imag_out = true; + text = text.substr(2); + } else if (text.starts_with("+i")) { + *imag_out = true; + text = text.substr(2); + } + PauliString result = PauliString::from_str(text); + if (negate) { + result.sign ^= 1; + } + return result; +} + +template +StabilizerFlow StabilizerFlow::from_str(const char *text, uint64_t num_measurements_for_non_neg_recs) { try { auto parts = split('>', text); if (parts.size() != 2 || parts[0].empty() || parts[0].back() != '-') { - throw std::invalid_argument(""); + throw std::invalid_argument(""); // Caught and given a message below. } parts[0].pop_back(); while (!parts[0].empty() && parts[0].back() == ' ') { parts[0].pop_back(); } - PauliString input = parts[0] == "1" ? PauliString(0) - : parts[0] == "-1" ? PauliString::from_str("-") - : PauliString::from_str(parts[0].c_str()); + bool imag_inp = false; + bool imag_out = false; + PauliString inp = parse_non_empty_pauli_string_allowing_i(parts[0], &imag_inp); parts = split(' ', parts[1]); size_t k = 0; while (k < parts.size() && parts[k].empty()) { k += 1; } - PauliString output(0); + if (k >= parts.size()) { + throw std::invalid_argument(""); // Caught and given a message below. + } + PauliString out(0); std::vector measurements; - if (!parts[k].empty() && parts[k][0] != 'r') { - output = PauliString::from_str(parts[k].c_str()); + out = parse_non_empty_pauli_string_allowing_i(parts[k], &imag_out); } else { - auto t = stim::GateTarget::from_target_str(parts[k].c_str()); - if (!t.is_measurement_record_target()) { - throw std::invalid_argument(""); + GateTarget t; + if (!parse_rec_allowing_non_negative(parts[k], num_measurements_for_non_neg_recs, &t)) { + throw std::invalid_argument(""); // Caught and given a message below. } measurements.push_back(t); } k++; while (k < parts.size()) { if (parts[k] != "xor" || k + 1 == parts.size()) { - throw std::invalid_argument(""); + throw std::invalid_argument(""); // Caught and given a message below. } - auto t = stim::GateTarget::from_target_str(parts[k + 1].c_str()); - if (!t.is_measurement_record_target()) { - throw std::invalid_argument(""); + GateTarget rec; + if (!parse_rec_allowing_non_negative(parts[k + 1], num_measurements_for_non_neg_recs, &rec)) { + throw std::invalid_argument(""); // Caught and given a message below. } - measurements.push_back(t); + measurements.push_back(rec); k += 2; } - return StabilizerFlow{input, output, measurements}; + if (imag_inp != imag_out) { + throw std::invalid_argument("Anti-hermitian flows aren't allowed."); + } + return StabilizerFlow{inp, out, measurements}; } catch (const std::invalid_argument &ex) { throw std::invalid_argument("Invalid stabilizer flow text: '" + std::string(text) + "'."); } diff --git a/src/stim/circuit/stabilizer_flow.test.cc b/src/stim/circuit/stabilizer_flow.test.cc index 887249abb..e11511e84 100644 --- a/src/stim/circuit/stabilizer_flow.test.cc +++ b/src/stim/circuit/stabilizer_flow.test.cc @@ -22,6 +22,101 @@ using namespace stim; +TEST_EACH_WORD_SIZE_W(stabilizer_flow, from_str, { + ASSERT_THROW({ StabilizerFlow::from_str(""); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X>X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X-X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X > X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X - X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("->X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X->"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("rec[0] -> X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X -> rec[ -1]"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X -> X rec[-1]"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X -> X xor"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X -> rec[-1] xor X"); }, std::invalid_argument); + ASSERT_THROW({ StabilizerFlow::from_str("X -> rec[55]"); }, std::invalid_argument); + + ASSERT_EQ( + StabilizerFlow::from_str("1 -> 1"), + (StabilizerFlow{ + .input = PauliString::from_str(""), + .output = PauliString::from_str(""), + .measurement_outputs = {}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("i -> -i"), + (StabilizerFlow{ + .input = PauliString::from_str(""), + .output = PauliString::from_str("-"), + .measurement_outputs = {}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("iX -> -iY"), + (StabilizerFlow{ + .input = PauliString::from_str("X"), + .output = PauliString::from_str("-Y"), + .measurement_outputs = {}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("X->-Y"), + (StabilizerFlow{ + .input = PauliString::from_str("X"), + .output = PauliString::from_str("-Y"), + .measurement_outputs = {}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("X -> -Y"), + (StabilizerFlow{ + .input = PauliString::from_str("X"), + .output = PauliString::from_str("-Y"), + .measurement_outputs = {}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("-X -> Y"), + (StabilizerFlow{ + .input = PauliString::from_str("-X"), + .output = PauliString::from_str("Y"), + .measurement_outputs = {}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("XYZ -> -Z_Z"), + (StabilizerFlow{ + .input = PauliString::from_str("XYZ"), + .output = PauliString::from_str("-Z_Z"), + .measurement_outputs = {}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("XYZ -> Z_Y xor rec[-1]"), + (StabilizerFlow{ + .input = PauliString::from_str("XYZ"), + .output = PauliString::from_str("Z_Y"), + .measurement_outputs = {GateTarget::rec(-1)}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("XYZ -> rec[-1]"), + (StabilizerFlow{ + .input = PauliString::from_str("XYZ"), + .output = PauliString::from_str(""), + .measurement_outputs = {GateTarget::rec(-1)}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("XYZ -> Z_Y xor rec[-1] xor rec[-3]"), + (StabilizerFlow{ + .input = PauliString::from_str("XYZ"), + .output = PauliString::from_str("Z_Y"), + .measurement_outputs = {GateTarget::rec(-1), GateTarget::rec(-3)}, + })); + ASSERT_EQ( + StabilizerFlow::from_str("XYZ -> ZIY xor rec[55] xor rec[-3]", 100), + (StabilizerFlow{ + .input = PauliString::from_str("XYZ"), + .output = PauliString::from_str("Z_Y"), + .measurement_outputs = {GateTarget::rec(-45), GateTarget::rec(-3)}, + })); +}); + TEST_EACH_WORD_SIZE_W(stabilizer_flow, sample_if_circuit_has_stabilizer_flows, { auto rng = INDEPENDENT_TEST_RNG(); auto results = sample_if_circuit_has_stabilizer_flows( diff --git a/src/stim/stabilizers/pauli_string.h b/src/stim/stabilizers/pauli_string.h index 42afb0d33..b27009f09 100644 --- a/src/stim/stabilizers/pauli_string.h +++ b/src/stim/stabilizers/pauli_string.h @@ -85,7 +85,7 @@ struct PauliString { /// Factory method for creating a PauliString whose Pauli entries are returned by a function. static PauliString from_func(bool sign, size_t num_qubits, const std::function &func); /// Factory method for creating a PauliString by parsing a string (e.g. "-XIIYZ"). - static PauliString from_str(const char *text); + static PauliString from_str(std::string_view text); /// Factory method for creating a PauliString with uniformly random sign and Pauli entries. static PauliString random(size_t num_qubits, std::mt19937_64 &rng); diff --git a/src/stim/stabilizers/pauli_string.inl b/src/stim/stabilizers/pauli_string.inl index f3c5a1e0c..112bc28f3 100644 --- a/src/stim/stabilizers/pauli_string.inl +++ b/src/stim/stabilizers/pauli_string.inl @@ -37,7 +37,7 @@ PauliString::PauliString(size_t num_qubits) : num_qubits(num_qubits), sign(fa template PauliString::PauliString(const std::string &text) : num_qubits(0), sign(false), xs(0), zs(0) { - *this = std::move(PauliString::from_str(text.c_str())); + *this = std::move(PauliString::from_str(text)); } template @@ -104,12 +104,13 @@ PauliString PauliString::from_func(bool sign, size_t num_qubits, const std } template -PauliString PauliString::from_str(const char *text) { - auto sign = text[0] == '-'; - if (text[0] == '+' || text[0] == '-') { - text++; +PauliString PauliString::from_str(std::string_view text) { + bool is_negated = text.starts_with('-'); + bool is_prefixed = text.starts_with('+'); + if (is_prefixed || is_negated) { + text = text.substr(1); } - return PauliString::from_func(sign, strlen(text), [&](size_t i) { + return PauliString::from_func(is_negated, text.size(), [&](size_t i) { return text[i]; }); } From 97a9cd6c89daa18c20960bf4dfa929d7141088aa Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Mon, 27 Nov 2023 20:09:02 -0800 Subject: [PATCH 2/2] Add `sinter.FusionBlossomCompiledDecoder` (#673) --- .../src/sinter/_decoding_fusion_blossom.py | 46 +++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/glue/sample/src/sinter/_decoding_fusion_blossom.py b/glue/sample/src/sinter/_decoding_fusion_blossom.py index bcf45cf09..966b69a4f 100644 --- a/glue/sample/src/sinter/_decoding_fusion_blossom.py +++ b/glue/sample/src/sinter/_decoding_fusion_blossom.py @@ -1,19 +1,58 @@ import math import pathlib -from typing import Callable, List, TYPE_CHECKING -from typing import Tuple +from typing import Callable, List, TYPE_CHECKING, Tuple import numpy as np import stim -from sinter._decoding_decoder_class import Decoder +from sinter._decoding_decoder_class import Decoder, CompiledDecoder if TYPE_CHECKING: import fusion_blossom +class FusionBlossomCompiledDecoder(CompiledDecoder): + def __init__(self, solver: 'fusion_blossom.SolverSerial', fault_masks: 'np.ndarray', num_dets: int, num_obs: int): + self.solver = solver + self.fault_masks = fault_masks + self.num_dets = num_dets + self.num_obs = num_obs + + def decode_shots_bit_packed( + self, + *, + bit_packed_detection_event_data: 'np.ndarray', + ) -> 'np.ndarray': + num_shots = bit_packed_detection_event_data.shape[0] + predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8) + import fusion_blossom + for shot in range(num_shots): + dets_sparse = np.flatnonzero(np.unpackbits(bit_packed_detection_event_data[shot], count=self.num_dets, bitorder='little')) + syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse) + self.solver.solve(syndrome) + prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])) + predictions[shot] = np.packbits(prediction, bitorder='little') + self.solver.clear() + return predictions + + class FusionBlossomDecoder(Decoder): """Use fusion blossom to predict observables from detection events.""" + + def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder: + try: + import fusion_blossom + except ImportError as ex: + raise ImportError( + "The decoder 'fusion_blossom' isn't installed\n" + "To fix this, install the python package 'fusion_blossom' into your environment.\n" + "For example, if you are using pip, run `pip install fusion_blossom`.\n" + ) from ex + + solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(dem) + return FusionBlossomCompiledDecoder(solver, fault_masks, dem.num_detectors, dem.num_observables) + + def decode_via_files(self, *, num_shots: int, @@ -36,7 +75,6 @@ def decode_via_files(self, error_model = stim.DetectorErrorModel.from_file(dem_path) solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(error_model) num_det_bytes = math.ceil(num_dets / 8) - with open(dets_b8_in_path, 'rb') as dets_in_f: with open(obs_predictions_b8_out_path, 'wb') as obs_out_f: for _ in range(num_shots):