From edf0d4ae5f92d30e229253c12f4ac28153d4c804 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 9 Apr 2024 00:46:26 -0700 Subject: [PATCH 1/2] Add `stim.DetectorErrorModel.to_simple_error_lists` --- doc/python_api_reference_vDev.md | 63 ++++++++++++ doc/stim.pyi | 55 +++++++++++ file_lists/source_files_no_main | 1 + file_lists/test_files | 1 + glue/python/src/stim/__init__.pyi | 55 +++++++++++ src/stim.h | 1 + src/stim/dem/detector_error_model.pybind.cc | 97 +++++++++++++++++++ .../dem/detector_error_model_pybind_test.py | 39 ++++++++ src/stim/util_top/dem_to_matrix.cc | 25 +++++ src/stim/util_top/dem_to_matrix.h | 43 ++++++++ src/stim/util_top/dem_to_matrix.test.cc | 58 +++++++++++ 11 files changed, 438 insertions(+) create mode 100644 src/stim/util_top/dem_to_matrix.cc create mode 100644 src/stim/util_top/dem_to_matrix.h create mode 100644 src/stim/util_top/dem_to_matrix.test.cc diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index 8439360b6..a9659701f 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -174,6 +174,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`stim.DetectorErrorModel.rounded`](#stim.DetectorErrorModel.rounded) - [`stim.DetectorErrorModel.shortest_graphlike_error`](#stim.DetectorErrorModel.shortest_graphlike_error) - [`stim.DetectorErrorModel.to_file`](#stim.DetectorErrorModel.to_file) + - [`stim.DetectorErrorModel.to_simple_error_lists`](#stim.DetectorErrorModel.to_simple_error_lists) - [`stim.ExplainedError`](#stim.ExplainedError) - [`stim.ExplainedError.__init__`](#stim.ExplainedError.__init__) - [`stim.ExplainedError.circuit_error_locations`](#stim.ExplainedError.circuit_error_locations) @@ -6581,6 +6582,68 @@ def to_file( """ ``` + +```python +# stim.DetectorErrorModel.to_simple_error_lists + +# (in class stim.DetectorErrorModel) +def to_simple_error_lists( + self, +) -> Tuple['np.ndarray[float]', List[Set[int]], List[Set[int]]]: + """Simplifies the model into lists of chances, detection sets, and obs flips sets. + + Note that this summary doesn't include information about loops, coordinates, or + suggested decompositions. Also note that this method will merge all errors with + identical symptoms, by Bernoulli summing their probabilities, and will sort the + errors. For example, the following two detector error models will give identical + results: + + # dem 1 + error(0.01) D1 D2 ^ D3 D4 + error(0.125) D1 + error(0.25) D1 + detector(2, 3, 5) D1 + + vs + + # dem 2 + error(0.3125) D1 + error(0.01) D1 D2 D3 D4 + + Returns: + A tuple of sequences (probabilities, dets, obs). + + `probabilities` will be a numpy array with dtype=np.float64 and + shape=(num_errors,). + + `dets` will be a list of detection event sets. Each detection + event set is a python set containing the indices of detectors + flipped by the error. + + `obs` will be a list of observable flip sets. Each observable + flip set is a python set containing the indices of detectors + flipped by the error. + + These sequences are paired up by index. They will have the same length, and + the combination of probabilities[k], dets[k], and obs[k] form a single + error. + + Examples: + >>> import stim + >>> dem = stim.DetectorErrorModel(''' + ... error(0.25) D0 D1 + ... error(0.125) D0 D2 L1 L5 + ... ''') + >>> probs, dets, obs = dem.to_simple_error_lists() + >>> probs + array([0.25 , 0.125]) + >>> dets + [{0, 1}, {0, 2}] + >>> obs + [set(), {1, 5}] + """ +``` + ```python # stim.ExplainedError diff --git a/doc/stim.pyi b/doc/stim.pyi index d08243a74..217cb7f72 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -5069,6 +5069,61 @@ class DetectorErrorModel: >>> contents 'error(0.25) D2 D3\n' """ + def to_simple_error_lists( + self, + ) -> Tuple['np.ndarray[float]', List[Set[int]], List[Set[int]]]: + """Simplifies the model into lists of chances, detection sets, and obs flips sets. + + Note that this summary doesn't include information about loops, coordinates, or + suggested decompositions. Also note that this method will merge all errors with + identical symptoms, by Bernoulli summing their probabilities, and will sort the + errors. For example, the following two detector error models will give identical + results: + + # dem 1 + error(0.01) D1 D2 ^ D3 D4 + error(0.125) D1 + error(0.25) D1 + detector(2, 3, 5) D1 + + vs + + # dem 2 + error(0.3125) D1 + error(0.01) D1 D2 D3 D4 + + Returns: + A tuple of sequences (probabilities, dets, obs). + + `probabilities` will be a numpy array with dtype=np.float64 and + shape=(num_errors,). + + `dets` will be a list of detection event sets. Each detection + event set is a python set containing the indices of detectors + flipped by the error. + + `obs` will be a list of observable flip sets. Each observable + flip set is a python set containing the indices of detectors + flipped by the error. + + These sequences are paired up by index. They will have the same length, and + the combination of probabilities[k], dets[k], and obs[k] form a single + error. + + Examples: + >>> import stim + >>> dem = stim.DetectorErrorModel(''' + ... error(0.25) D0 D1 + ... error(0.125) D0 D2 L1 L5 + ... ''') + >>> probs, dets, obs = dem.to_simple_error_lists() + >>> probs + array([0.25 , 0.125]) + >>> dets + [{0, 1}, {0, 2}] + >>> obs + [set(), {1, 5}] + """ class ExplainedError: """Describes the location of an error mechanism from a stim circuit. """ diff --git a/file_lists/source_files_no_main b/file_lists/source_files_no_main index 2c508d733..92531b88a 100644 --- a/file_lists/source_files_no_main +++ b/file_lists/source_files_no_main @@ -91,6 +91,7 @@ src/stim/util_bot/error_decomp.cc src/stim/util_top/circuit_inverse_unitary.cc src/stim/util_top/circuit_to_detecting_regions.cc src/stim/util_top/circuit_vs_amplitudes.cc +src/stim/util_top/dem_to_matrix.cc src/stim/util_top/export_crumble_url.cc src/stim/util_top/export_qasm.cc src/stim/util_top/export_quirk_url.cc diff --git a/file_lists/test_files b/file_lists/test_files index 5c986278b..ef8d28c8e 100644 --- a/file_lists/test_files +++ b/file_lists/test_files @@ -86,6 +86,7 @@ src/stim/util_top/circuit_inverse_unitary.test.cc src/stim/util_top/circuit_to_detecting_regions.test.cc src/stim/util_top/circuit_vs_amplitudes.test.cc src/stim/util_top/circuit_vs_tableau.test.cc +src/stim/util_top/dem_to_matrix.test.cc src/stim/util_top/export_crumble_url.test.cc src/stim/util_top/export_qasm.test.cc src/stim/util_top/export_quirk_url.test.cc diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index d08243a74..217cb7f72 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -5069,6 +5069,61 @@ class DetectorErrorModel: >>> contents 'error(0.25) D2 D3\n' """ + def to_simple_error_lists( + self, + ) -> Tuple['np.ndarray[float]', List[Set[int]], List[Set[int]]]: + """Simplifies the model into lists of chances, detection sets, and obs flips sets. + + Note that this summary doesn't include information about loops, coordinates, or + suggested decompositions. Also note that this method will merge all errors with + identical symptoms, by Bernoulli summing their probabilities, and will sort the + errors. For example, the following two detector error models will give identical + results: + + # dem 1 + error(0.01) D1 D2 ^ D3 D4 + error(0.125) D1 + error(0.25) D1 + detector(2, 3, 5) D1 + + vs + + # dem 2 + error(0.3125) D1 + error(0.01) D1 D2 D3 D4 + + Returns: + A tuple of sequences (probabilities, dets, obs). + + `probabilities` will be a numpy array with dtype=np.float64 and + shape=(num_errors,). + + `dets` will be a list of detection event sets. Each detection + event set is a python set containing the indices of detectors + flipped by the error. + + `obs` will be a list of observable flip sets. Each observable + flip set is a python set containing the indices of detectors + flipped by the error. + + These sequences are paired up by index. They will have the same length, and + the combination of probabilities[k], dets[k], and obs[k] form a single + error. + + Examples: + >>> import stim + >>> dem = stim.DetectorErrorModel(''' + ... error(0.25) D0 D1 + ... error(0.125) D0 D2 L1 L5 + ... ''') + >>> probs, dets, obs = dem.to_simple_error_lists() + >>> probs + array([0.25 , 0.125]) + >>> dets + [{0, 1}, {0, 2}] + >>> obs + [set(), {1, 5}] + """ class ExplainedError: """Describes the location of an error mechanism from a stim circuit. """ diff --git a/src/stim.h b/src/stim.h index 090c6d585..bc69d9793 100644 --- a/src/stim.h +++ b/src/stim.h @@ -112,6 +112,7 @@ #include "stim/util_top/circuit_to_detecting_regions.h" #include "stim/util_top/circuit_vs_amplitudes.h" #include "stim/util_top/circuit_vs_tableau.h" +#include "stim/util_top/dem_to_matrix.h" #include "stim/util_top/export_crumble_url.h" #include "stim/util_top/export_qasm.h" #include "stim/util_top/export_quirk_url.h" diff --git a/src/stim/dem/detector_error_model.pybind.cc b/src/stim/dem/detector_error_model.pybind.cc index 2eb642efe..e57718149 100644 --- a/src/stim/dem/detector_error_model.pybind.cc +++ b/src/stim/dem/detector_error_model.pybind.cc @@ -23,10 +23,13 @@ #include "stim/dem/detector_error_model_target.pybind.h" #include "stim/io/raii_file.h" #include "stim/py/base.pybind.h" +#include "stim/py/numpy.pybind.h" #include "stim/search/search.h" #include "stim/simulators/dem_sampler.h" +#include "stim/util_top/dem_to_matrix.h" using namespace stim; +using namespace stim_pybind; std::string stim_pybind::detector_error_model_repr(const DetectorErrorModel &self) { if (self.instructions.empty()) { @@ -1196,4 +1199,98 @@ void stim_pybind::pybind_detector_error_model_methods( ... print(diagram, file=f) )DOC") .data()); + + c.def( + "to_simple_error_lists", + [](const DetectorErrorModel &self) { + auto map = dem_to_map(self); + size_t num_errors = map.size(); + + pybind11::list dets_list; + pybind11::list obs_list; + double *probabilities = new double[num_errors]; + + size_t row = 0; + for (const auto &kv : map) { + probabilities[row] = kv.second; + pybind11::set dets; + pybind11::set obs; + for (DemTarget t : kv.first) { + if (t.is_relative_detector_id()) { + dets.add(t.val()); + } else { + obs.add(t.val()); + } + } + dets_list.append(dets); + obs_list.append(obs); + row++; + } + + auto prob_array = pybind11::array_t( + {(pybind11::ssize_t)num_errors}, + {(pybind11::ssize_t)sizeof(double)}, + probabilities, + pybind11::capsule(probabilities, [](void *f) { + delete[] reinterpret_cast(f); + }) + ); + + return pybind11::make_tuple(prob_array, dets_list, obs_list); + }, + clean_doc_string(R"DOC( + @signature def to_simple_error_lists(self) -> Tuple['np.ndarray[float]', List[Set[int]], List[Set[int]]]: + Simplifies the model into lists of chances, detection sets, and obs flips sets. + + Note that this summary doesn't include information about loops, coordinates, or + suggested decompositions. Also note that this method will merge all errors with + identical symptoms, by Bernoulli summing their probabilities, and will sort the + errors. For example, the following two detector error models will give identical + results: + + # dem 1 + error(0.01) D1 D2 ^ D3 D4 + error(0.125) D1 + error(0.25) D1 + detector(2, 3, 5) D1 + + vs + + # dem 2 + error(0.3125) D1 + error(0.01) D1 D2 D3 D4 + + Returns: + A tuple of sequences (probabilities, dets, obs). + + `probabilities` will be a numpy array with dtype=np.float64 and + shape=(num_errors,). + + `dets` will be a list of detection event sets. Each detection + event set is a python set containing the indices of detectors + flipped by the error. + + `obs` will be a list of observable flip sets. Each observable + flip set is a python set containing the indices of detectors + flipped by the error. + + These sequences are paired up by index. They will have the same length, and + the combination of probabilities[k], dets[k], and obs[k] form a single + error. + + Examples: + >>> import stim + >>> dem = stim.DetectorErrorModel(''' + ... error(0.25) D0 D1 + ... error(0.125) D0 D2 L1 L5 + ... ''') + >>> probs, dets, obs = dem.to_simple_error_lists() + >>> probs + array([0.25 , 0.125]) + >>> dets + [{0, 1}, {0, 2}] + >>> obs + [set(), {1, 5}] + )DOC") + .data()); } diff --git a/src/stim/dem/detector_error_model_pybind_test.py b/src/stim/dem/detector_error_model_pybind_test.py index cd479aa92..6c8892bd4 100644 --- a/src/stim/dem/detector_error_model_pybind_test.py +++ b/src/stim/dem/detector_error_model_pybind_test.py @@ -14,6 +14,7 @@ import pathlib import tempfile +import numpy as np import pytest import stim @@ -536,3 +537,41 @@ def test_shortest_graphlike_error_remnant(): assert len(d.shortest_graphlike_error(ignore_ungraphlike_errors=True)) == 8 assert len(c.shortest_graphlike_error()) == 8 assert len(d.shortest_graphlike_error()) == 8 + + +def test_to_simple_error_lists(): + dem = stim.DetectorErrorModel(""" + error(0.125) D0 D2 L1 L5 + error(0.25) D0 D1 + """) + probs, dets, obs = dem.to_simple_error_lists() + np.testing.assert_array_equal(probs, np.array([0.25, 0.125], dtype=np.float64)) + assert dets == [ + {0, 1}, + {0, 2}, + ] + assert obs == [ + set(), + {1, 5}, + ] + + +def test_to_simple_error_lists_simplify(): + dem = stim.DetectorErrorModel(""" + error(0.125) D0 D2 L1 L5 + error(0.25) D0 ^ D1 + error(0) D0 D1 + error(0) D2 + """) + probs, dets, obs = dem.to_simple_error_lists() + np.testing.assert_array_equal(probs, np.array([0.25, 0.125, 0], dtype=np.float64)) + assert dets == [ + {0, 1}, + {0, 2}, + {2}, + ] + assert obs == [ + set(), + {1, 5}, + set(), + ] diff --git a/src/stim/util_top/dem_to_matrix.cc b/src/stim/util_top/dem_to_matrix.cc new file mode 100644 index 000000000..f057ad593 --- /dev/null +++ b/src/stim/util_top/dem_to_matrix.cc @@ -0,0 +1,25 @@ +#include "stim/util_top/dem_to_matrix.h" + +using namespace stim; + +std::map, double> stim::dem_to_map(const DetectorErrorModel &dem) { + std::map, double> result; + SparseXorVec buf; + dem.iter_flatten_error_instructions([&](DemInstruction instruction) { + if (instruction.type != DemInstructionType::DEM_ERROR) { + return; + } + buf.sorted_items.clear(); + for (DemTarget t: instruction.target_data) { + if (t.is_observable_id() || t.is_relative_detector_id()) { + buf.sorted_items.push_back(t); + } + } + size_t kept = xor_sort(buf); + buf.sorted_items.resize(kept); + auto q = instruction.arg_data[0]; + auto &p = result[buf]; + p = p * (1 - q) + q * (1 - p); + }); + return result; +} diff --git a/src/stim/util_top/dem_to_matrix.h b/src/stim/util_top/dem_to_matrix.h new file mode 100644 index 000000000..581a5a135 --- /dev/null +++ b/src/stim/util_top/dem_to_matrix.h @@ -0,0 +1,43 @@ +#ifndef _STIM_UTIL_TOP_DEM_TO_MATRIX_H +#define _STIM_UTIL_TOP_DEM_TO_MATRIX_H + +#include "stim/mem/sparse_xor_vec.h" +#include "stim/dem/detector_error_model.h" + +namespace stim { + +std::map, double> dem_to_map(const DetectorErrorModel &dem); + +/// Sorts the given items, and cancels out duplicates. +/// If an item appears an even number of times in the span, it is removed. +/// If it appears an odd number of times, exactly one instance is kept. +/// +/// Args: +/// target: The span of items to xor-sort. +/// +/// Returns: +/// The number of kept items. Kept items are moved to the start of the range. +template +size_t xor_sort(std::span target) { + if (target.empty()) { + return 0; + } + std::sort(target.begin(), target.end()); + size_t kept = 0; + for (size_t k = 0; k < target.size(); k++) { + if (k + 1 < target.size() && target[k] == target[k + 1]) { + k += 1; + continue; + } + if (kept < k) { + target[kept] = std::move(target[k]); + } + kept++; + } + + return kept; +} + +} // namespace stim + +#endif diff --git a/src/stim/util_top/dem_to_matrix.test.cc b/src/stim/util_top/dem_to_matrix.test.cc new file mode 100644 index 000000000..c0f836795 --- /dev/null +++ b/src/stim/util_top/dem_to_matrix.test.cc @@ -0,0 +1,58 @@ +#include "stim/util_top/dem_to_matrix.h" + +#include "gtest/gtest.h" + +using namespace stim; + +TEST(dem_to_matrix, xor_sort) { + std::vector vals; + size_t v; + + vals = {}; + v = xor_sort(vals); + ASSERT_EQ(v, 0); + ASSERT_EQ(vals, (std::vector{})); + + vals = {5}; + v = xor_sort(vals); + ASSERT_EQ(v, 1); + ASSERT_EQ(vals, (std::vector{5})); + + vals = {5, 6}; + v = xor_sort(vals); + ASSERT_EQ(v, 2); + ASSERT_EQ(vals, (std::vector{5, 6})); + + vals = {6, 5}; + v = xor_sort(vals); + ASSERT_EQ(v, 2); + ASSERT_EQ(vals, (std::vector{5, 6})); + + vals = {5, 5}; + v = xor_sort(vals); + ASSERT_EQ(v, 0); + + vals = {5, 5, 5}; + v = xor_sort(vals); + ASSERT_EQ(v, 1); + while (vals.size() > v) { + vals.pop_back(); + } + ASSERT_EQ(vals, (std::vector{5})); + + vals = {5, 6, 5, 6, 5}; + v = xor_sort(vals); + ASSERT_EQ(v, 1); + while (vals.size() > v) { + vals.pop_back(); + } + ASSERT_EQ(vals, (std::vector{5})); + + vals = {5, 6, 5, 6, 5, 2, 3, 5}; + v = xor_sort(vals); + ASSERT_EQ(v, 2); + while (vals.size() > v) { + vals.pop_back(); + } + ASSERT_EQ(vals, (std::vector{2, 3})); +} From 05f6037fdcf6d75b76672005fd839910cbc353ea Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Tue, 9 Apr 2024 00:56:19 -0700 Subject: [PATCH 2/2] sort --- src/stim/stabilizers/pauli_string_ref.inl | 18 +++++++++--------- src/stim/util_top/dem_to_matrix.cc | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/stim/stabilizers/pauli_string_ref.inl b/src/stim/stabilizers/pauli_string_ref.inl index 5d8a95994..943b67633 100644 --- a/src/stim/stabilizers/pauli_string_ref.inl +++ b/src/stim/stabilizers/pauli_string_ref.inl @@ -883,9 +883,9 @@ template void PauliStringRef::do_ZCX(const CircuitInstruction &inst) { const auto &targets = inst.targets; assert((targets.size() & 1) == 0); - for (size_t k = 0; k < inst.targets.size(); k += 2) { - size_t k2 = reverse_order ? inst.targets.size() - 2 - k : k; - size_t q1 = inst.targets[k2].data, q2 = inst.targets[k2 + 1].data; + for (size_t k = 0; k < targets.size(); k += 2) { + size_t k2 = reverse_order ? targets.size() - 2 - k : k; + size_t q1 = targets[k2].data, q2 = targets[k2 + 1].data; do_single_cx(inst, q1, q2); } } @@ -895,9 +895,9 @@ template void PauliStringRef::do_ZCY(const CircuitInstruction &inst) { const auto &targets = inst.targets; assert((targets.size() & 1) == 0); - for (size_t k = 0; k < inst.targets.size(); k += 2) { - size_t k2 = reverse_order ? inst.targets.size() - 2 - k : k; - size_t q1 = inst.targets[k2].data, q2 = inst.targets[k2 + 1].data; + for (size_t k = 0; k < targets.size(); k += 2) { + size_t k2 = reverse_order ? targets.size() - 2 - k : k; + size_t q1 = targets[k2].data, q2 = targets[k2 + 1].data; do_single_cy(inst, q1, q2); } } @@ -916,9 +916,9 @@ template void PauliStringRef::do_SWAP(const CircuitInstruction &inst) { const auto &targets = inst.targets; assert((targets.size() & 1) == 0); - for (size_t k = 0; k < inst.targets.size(); k += 2) { - size_t k2 = reverse_order ? inst.targets.size() - 2 - k : k; - size_t q1 = inst.targets[k2].data, q2 = inst.targets[k2 + 1].data; + for (size_t k = 0; k < targets.size(); k += 2) { + size_t k2 = reverse_order ? targets.size() - 2 - k : k; + size_t q1 = targets[k2].data, q2 = targets[k2 + 1].data; zs[q1].swap_with(zs[q2]); xs[q1].swap_with(xs[q2]); } diff --git a/src/stim/util_top/dem_to_matrix.cc b/src/stim/util_top/dem_to_matrix.cc index f057ad593..6544968b6 100644 --- a/src/stim/util_top/dem_to_matrix.cc +++ b/src/stim/util_top/dem_to_matrix.cc @@ -15,7 +15,7 @@ std::map, double> stim::dem_to_map(const DetectorErrorMo buf.sorted_items.push_back(t); } } - size_t kept = xor_sort(buf); + size_t kept = xor_sort(buf.sorted_items); buf.sorted_items.resize(kept); auto q = instruction.arg_data[0]; auto &p = result[buf];