diff --git a/file_lists/test_files b/file_lists/test_files index b08819aff..3b0f6ea4d 100644 --- a/file_lists/test_files +++ b/file_lists/test_files @@ -13,6 +13,7 @@ src/stim/cmd/command_gen.test.cc src/stim/cmd/command_m2d.test.cc src/stim/cmd/command_sample.test.cc src/stim/cmd/command_sample_dem.test.cc +src/stim/dem/dem_instruction.test.cc src/stim/dem/detector_error_model.test.cc src/stim/diagram/ascii_diagram.test.cc src/stim/diagram/base64.test.cc diff --git a/src/stim/circuit/circuit.pybind.cc b/src/stim/circuit/circuit.pybind.cc index bbca2c6cb..7864c0102 100644 --- a/src/stim/circuit/circuit.pybind.cc +++ b/src/stim/circuit/circuit.pybind.cc @@ -46,6 +46,89 @@ using namespace stim; using namespace stim_pybind; +std::set py_dem_filter_to_dem_target_set( + const Circuit &circuit, const CircuitStats &stats, const pybind11::object &included_targets_filter) { + std::set result; + auto add_all_dets = [&]() { + for (uint64_t k = 0; k < stats.num_detectors; k++) { + result.insert(DemTarget::relative_detector_id(k)); + } + }; + auto add_all_obs = [&]() { + for (uint64_t k = 0; k < stats.num_observables; k++) { + result.insert(DemTarget::observable_id(k)); + } + }; + + bool has_coords = false; + std::map> cached_coords; + auto get_coords_cached = [&]() -> const std::map> & { + std::set all_dets; + for (uint64_t k = 0; k < stats.num_detectors; k++) { + all_dets.insert(k); + } + if (!has_coords) { + cached_coords = circuit.get_detector_coordinates(all_dets); + has_coords = true; + } + return cached_coords; + }; + + if (included_targets_filter.is_none()) { + add_all_dets(); + add_all_obs(); + return result; + } + for (const auto &filter : included_targets_filter) { + bool fail = false; + if (pybind11::isinstance(filter)) { + result.insert(pybind11::cast(filter)); + } else if (pybind11::isinstance(filter)) { + std::string s = pybind11::cast(filter); + if (s == "D") { + add_all_dets(); + } else if (s == "L") { + add_all_obs(); + } else if (s.starts_with("D") || s.starts_with("L")) { + result.insert(DemTarget::from_text(s)); + } else { + fail = true; + } + } else { + std::vector prefix; + for (auto e : filter) { + if (pybind11::isinstance(e) || pybind11::isinstance(e)) { + prefix.push_back(pybind11::cast(e)); + } else { + fail = true; + break; + } + } + if (!fail) { + for (const auto &[target, coord] : get_coords_cached()) { + if (coord.size() >= prefix.size()) { + bool match = true; + for (size_t k = 0; k < prefix.size(); k++) { + match &= prefix[k] == coord[k]; + } + if (match) { + result.insert(DemTarget::relative_detector_id(target)); + } + } + } + } + } + if (fail) { + std::stringstream ss; + ss << "Don't know how to interpret '"; + ss << pybind11::cast(pybind11::repr(filter)); + ss << "' as a dem target filter."; + throw std::invalid_argument(ss.str()); + } + } + return result; +} + std::string circuit_repr(const Circuit &self) { if (self.operations.empty()) { return "stim.Circuit()"; @@ -2034,6 +2117,163 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_ std::map> { + auto stats = self.compute_stats(); + auto included_target_set = py_dem_filter_to_dem_target_set(self, stats, included_targets); + std::set included_tick_set; + + if (included_ticks.is_none()) { + for (uint64_t k = 0; k < stats.num_ticks; k++) { + included_tick_set.insert(k); + } + } else { + for (const auto &t : included_ticks) { + included_tick_set.insert(pybind11::cast(t)); + } + } + auto result = circuit_to_detecting_regions( + self, included_target_set, included_tick_set, ignore_anticommutation_errors); + std::map> exposed_result; + for (const auto &[k, v] : result) { + exposed_result.insert({ExposedDemTarget(k), std::move(v)}); + } + return exposed_result; + }, + pybind11::kw_only(), + pybind11::arg("targets") = pybind11::none(), + pybind11::arg("ticks") = pybind11::none(), + pybind11::arg("ignore_anticommutation_errors") = false, + clean_doc_string(R"DOC( + @signature def detecting_regions(self, *, targets: Optional[Iterable[stim.DemTarget | str | Iterable[float]]] = None, ticks: Optional[Iterable[int]] = None) -> Dict[stim.DemTarget, Dict[int, stim.PauliString]]: + Records where detectors and observables are sensitive to errors over time. + + The result of this method is a nested dictionary, mapping detectors/observables + and ticks to Pauli sensitivities for that detector/observable at that time. + + For example, if observable 2 has Z-type sensitivity on qubits 5 and 6 during + tick 3, then `result[stim.target_logical_observable_id(2)][3]` will be equal to + `stim.PauliString("Z5*Z6")`. + + If you want sensitivities from more places in the circuit, besides just at the + TICK instructions, you can work around this by making a version of the circuit + with more TICKs. + + Args: + targets: Defaults to everything (None). + + When specified, this should be an iterable of filters where items + matching any one filter are included. + + A variety of filters are supported: + stim.DemTarget: Includes the targeted detector or observable. + Iterable[float]: Coordinate prefix match. Includes detectors whose + coordinate data begins with the same floats. + "D": Includes all detectors. + "L": Includes all observables. + "D#" (e.g. "D5"): Includes the detector with the specified index. + "L#" (e.g. "L5"): Includes the observable with the specified index. + + ticks: Defaults to everything (None). + When specified, this should be a list of integers corresponding to + the tick indices to report sensitivities for. + + ignore_anticommutation_errors: Defaults to False. + When set to False, invalid detecting regions that anticommute with a + reset will cause the method to raise an exception. When set to True, + the offending component will simply be silently dropped. This can + result in broken detectors having apparently enormous detecting + regions. + + Returns: + Nested dictionaries keyed first by a `stim.DemTarget` identifying the + detector or observable, then by the index of the tick, leading to a + PauliString with that target's error sensitivity at that tick. + + Note you can use `stim.PauliString.pauli_indices` to quickly get to the + non-identity terms in the sensitivity. + + Examples: + >>> import stim + + >>> detecting_regions = stim.Circuit(''' + ... R 0 + ... TICK + ... H 0 + ... TICK + ... CX 0 1 + ... TICK + ... MX 0 1 + ... DETECTOR rec[-1] rec[-2] + ... ''').detecting_regions() + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D0 + tick 0 = +Z_ + tick 1 = +X_ + tick 2 = +XX + + >>> circuit = stim.Circuit.generated( + ... "surface_code:rotated_memory_x", + ... rounds=5, + ... distance=4, + ... ) + + >>> detecting_regions = circuit.detecting_regions( + ... targets=["L0", (2, 4), stim.DemTarget.relative_detector_id(5)], + ... ticks=range(5, 15), + ... ) + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D1 + tick 5 = +____________________X______________________ + tick 6 = +____________________Z______________________ + target D5 + tick 5 = +______X____________________________________ + tick 6 = +______Z____________________________________ + target D14 + tick 5 = +__________X_X______XXX_____________________ + tick 6 = +__________X_X______XZX_____________________ + tick 7 = +__________X_X______XZX_____________________ + tick 8 = +__________X_X______XXX_____________________ + tick 9 = +__________XXX_____XXX______________________ + tick 10 = +__________XXX_______X______________________ + tick 11 = +__________X_________X______________________ + tick 12 = +____________________X______________________ + tick 13 = +____________________Z______________________ + target D29 + tick 7 = +____________________Z______________________ + tick 8 = +____________________X______________________ + tick 9 = +____________________XX_____________________ + tick 10 = +___________________XXX_______X_____________ + tick 11 = +____________X______XXXX______X_____________ + tick 12 = +__________X_X______XXX_____________________ + tick 13 = +__________X_X______XZX_____________________ + tick 14 = +__________X_X______XZX_____________________ + target D44 + tick 14 = +____________________Z______________________ + target L0 + tick 5 = +_X________X________X________X______________ + tick 6 = +_X________X________X________X______________ + tick 7 = +_X________X________X________X______________ + tick 8 = +_X________X________X________X______________ + tick 9 = +_X________X_______XX________X______________ + tick 10 = +_X________X________X________X______________ + tick 11 = +_X________XX_______X________XX_____________ + tick 12 = +_X________X________X________X______________ + tick 13 = +_X________X________X________X______________ + tick 14 = +_X________X________X________X______________ + )DOC") + .data()); + c.def( "without_noise", &Circuit::without_noise, diff --git a/src/stim/circuit/circuit_pybind_test.py b/src/stim/circuit/circuit_pybind_test.py index 5cd9d280a..e6dbc3cec 100644 --- a/src/stim/circuit/circuit_pybind_test.py +++ b/src/stim/circuit/circuit_pybind_test.py @@ -1703,7 +1703,7 @@ def test_decomposed(): assert stim.Circuit(""" ISWAP 0 1 2 1 TICK - CPP X2*X1 !Z1*Z2 + MPP X1*Z2*Y3 """).decomposed() == stim.Circuit(""" H 0 CX 0 1 1 0 @@ -1714,13 +1714,43 @@ def test_decomposed(): H 1 S 1 2 TICK - H 1 2 - CX 2 1 - H 2 2 - CX 1 2 - H 2 - S 1 1 - H 2 - CX 2 1 - H 1 2 + H 1 3 + S 3 + H 3 + S 3 3 + CX 2 1 3 1 + M 1 + CX 2 1 3 1 + H 3 + S 3 + H 3 + S 3 3 + H 1 """) + + +def test_detecting_regions(): + assert stim.Circuit(''' + R 0 + TICK + H 0 + TICK + CX 0 1 + TICK + MX 0 1 + DETECTOR rec[-1] rec[-2] + ''').detecting_regions() == {stim.DemTarget.relative_detector_id(0): { + 0: stim.PauliString("Z_"), + 1: stim.PauliString("X_"), + 2: stim.PauliString("XX"), + }} + + +def test_detecting_region_filters(): + c = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3) + assert len(c.detecting_regions(targets=["D"])) == c.num_detectors + assert len(c.detecting_regions(targets=["L"])) == c.num_observables + assert len(c.detecting_regions()) == c.num_observables + c.num_detectors + assert len(c.detecting_regions(targets=["D0"])) == 1 + assert len(c.detecting_regions(targets=["D0", "L0"])) == 2 + assert len(c.detecting_regions(targets=[stim.target_relative_detector_id(0), "D0"])) == 1 diff --git a/src/stim/dem/dem_instruction.cc b/src/stim/dem/dem_instruction.cc index b240ff14c..6ddff59d7 100644 --- a/src/stim/dem/dem_instruction.cc +++ b/src/stim/dem/dem_instruction.cc @@ -2,6 +2,7 @@ #include +#include "stim/arg_parse.h" #include "stim/dem/detector_error_model.h" #include "stim/simulators/error_analyzer.h" #include "stim/str_util.h" @@ -11,14 +12,17 @@ using namespace stim; constexpr uint64_t OBSERVABLE_BIT = uint64_t{1} << 63; constexpr uint64_t SEPARATOR_SYGIL = UINT64_MAX; +constexpr uint64_t MAX_OBS = 0xFFFFFFFF; +constexpr uint64_t MAX_DET = (uint64_t{1} << 62) - 1; + DemTarget DemTarget::observable_id(uint64_t id) { - if (id > 0xFFFFFFFF) { + if (id > MAX_OBS) { throw std::invalid_argument("id > 0xFFFFFFFF"); } return {OBSERVABLE_BIT | id}; } DemTarget DemTarget::relative_detector_id(uint64_t id) { - if (id >= (uint64_t{1} << 62)) { + if (id > MAX_DET) { throw std::invalid_argument("Relative detector id too large."); } return {id}; @@ -75,6 +79,25 @@ void DemTarget::shift_if_detector_id(int64_t offset) { data = (uint64_t)((int64_t)data + offset); } } +DemTarget DemTarget::from_text(std::string_view text) { + if (!text.empty()) { + bool is_det = text[0] == 'D'; + bool is_obs = text[0] == 'L'; + if (is_det || is_obs) { + int64_t parsed = 0; + if (parse_int64(text.substr(1), &parsed)) { + if (parsed >= 0) { + if (is_det && parsed <= (int64_t)MAX_DET) { + return DemTarget::relative_detector_id(parsed); + } else if (is_obs && parsed <= (int64_t)MAX_OBS) { + return DemTarget::observable_id(parsed); + } + } + } + } + } + throw std::invalid_argument("Failed to parse as a stim.DemTarget: '" + std::string(text) + "'"); +} bool DemInstruction::operator<(const DemInstruction &other) const { if (type != other.type) { diff --git a/src/stim/dem/dem_instruction.h b/src/stim/dem/dem_instruction.h index cba775605..d75c4f349 100644 --- a/src/stim/dem/dem_instruction.h +++ b/src/stim/dem/dem_instruction.h @@ -37,6 +37,8 @@ struct DemTarget { bool operator!=(const DemTarget &other) const; bool operator<(const DemTarget &other) const; std::string str() const; + + static DemTarget from_text(std::string_view text); }; struct DetectorErrorModel; diff --git a/src/stim/dem/dem_instruction.test.cc b/src/stim/dem/dem_instruction.test.cc new file mode 100644 index 000000000..53865e465 --- /dev/null +++ b/src/stim/dem/dem_instruction.test.cc @@ -0,0 +1,31 @@ +#include "stim/dem/dem_instruction.h" + +#include "gtest/gtest.h" + +using namespace stim; + +TEST(dem_instruction, from_str) { + ASSERT_EQ(DemTarget::from_text("D5"), DemTarget::relative_detector_id(5)); + ASSERT_EQ(DemTarget::from_text("D0"), DemTarget::relative_detector_id(0)); + ASSERT_EQ(DemTarget::from_text("D4611686018427387903"), DemTarget::relative_detector_id(4611686018427387903)); + + ASSERT_EQ(DemTarget::from_text("L5"), DemTarget::observable_id(5)); + ASSERT_EQ(DemTarget::from_text("L0"), DemTarget::observable_id(0)); + ASSERT_EQ(DemTarget::from_text("L4294967295"), DemTarget::observable_id(4294967295)); + + ASSERT_THROW({ DemTarget::from_text("D4611686018427387904"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("L4294967296"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("L-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("L-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("D-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("Da"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("Da "); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text(" Da"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("X"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text(""); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("0"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("'"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text(" "); }, std::invalid_argument); +} diff --git a/src/stim/stabilizers/conversions.cc b/src/stim/stabilizers/conversions.cc index caf7eba86..44c18119f 100644 --- a/src/stim/stabilizers/conversions.cc +++ b/src/stim/stabilizers/conversions.cc @@ -1,5 +1,7 @@ #include "stim/stabilizers/conversions.h" +#include "stim/simulators/sparse_rev_frame_tracker.h" + using namespace stim; void stim::independent_to_disjoint_xyz_errors( @@ -137,3 +139,44 @@ double stim::independent_per_channel_probability_to_depolarize2_probability(doub q *= q; return 15.0 / 16.0 * (1.0 - q); } + +std::map> stim::circuit_to_detecting_regions( + const Circuit &circuit, + std::set included_targets, + std::set included_ticks, + bool ignore_anticommutation_errors) { + CircuitStats stats = circuit.compute_stats(); + uint64_t tick_index = stats.num_ticks; + SparseUnsignedRevFrameTracker tracker( + stats.num_qubits, stats.num_measurements, stats.num_detectors, !ignore_anticommutation_errors); + std::map> result; + circuit.for_each_operation_reverse([&](const CircuitInstruction &inst) { + if (inst.gate_type == GateType::TICK) { + tick_index -= 1; + if (included_ticks.contains(tick_index)) { + for (size_t q = 0; q < stats.num_qubits; q++) { + for (auto target : tracker.xs[q]) { + if (included_targets.contains(target)) { + auto &m = result[target]; + if (!m.contains(tick_index)) { + m.insert({tick_index, FlexPauliString(stats.num_qubits)}); + } + m.at(tick_index).value.xs[q] ^= 1; + } + } + for (auto target : tracker.zs[q]) { + if (included_targets.contains(target)) { + auto &m = result[target]; + if (!m.contains(tick_index)) { + m.insert({tick_index, FlexPauliString(stats.num_qubits)}); + } + m.at(tick_index).value.zs[q] ^= 1; + } + } + } + } + } + tracker.undo_gate(inst); + }); + return result; +} diff --git a/src/stim/stabilizers/conversions.h b/src/stim/stabilizers/conversions.h index a1fb9fc4d..4349693fd 100644 --- a/src/stim/stabilizers/conversions.h +++ b/src/stim/stabilizers/conversions.h @@ -18,6 +18,8 @@ #define _STIM_STABILIZERS_CONVERSIONS_H #include "stim/circuit/circuit.h" +#include "stim/dem/dem_instruction.h" +#include "stim/stabilizers/flex_pauli_string.h" #include "stim/stabilizers/tableau.h" namespace stim { @@ -179,6 +181,12 @@ double depolarize2_probability_to_independent_per_channel_probability(double p); double independent_per_channel_probability_to_depolarize1_probability(double p); double independent_per_channel_probability_to_depolarize2_probability(double p); +std::map> circuit_to_detecting_regions( + const Circuit &circuit, + std::set included_targets, + std::set included_ticks, + bool ignore_anticommutation_errors); + } // namespace stim #include "stim/stabilizers/conversions.inl"