Skip to content

Commit

Permalink
Merge branch 'main' of github.com:quantumlib/Stim into cpp2
Browse files Browse the repository at this point in the history
# Conflicts:
#	doc/python_api_reference_vDev.md
#	doc/stim.pyi
#	glue/python/src/stim/__init__.pyi
#	src/stim/circuit/circuit_pybind_test.py
  • Loading branch information
Strilanc committed Mar 12, 2024
2 parents 42adc06 + dc59667 commit cfa6ef3
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 12 deletions.
1 change: 1 addition & 0 deletions file_lists/test_files
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ src/stim/cmd/command_gen.test.cc
src/stim/cmd/command_m2d.test.cc
src/stim/cmd/command_sample.test.cc
src/stim/cmd/command_sample_dem.test.cc
src/stim/dem/dem_instruction.test.cc
src/stim/dem/detector_error_model.test.cc
src/stim/diagram/ascii_diagram.test.cc
src/stim/diagram/base64.test.cc
Expand Down
240 changes: 240 additions & 0 deletions src/stim/circuit/circuit.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,89 @@
using namespace stim;
using namespace stim_pybind;

std::set<DemTarget> py_dem_filter_to_dem_target_set(
const Circuit &circuit, const CircuitStats &stats, const pybind11::object &included_targets_filter) {
std::set<DemTarget> result;
auto add_all_dets = [&]() {
for (uint64_t k = 0; k < stats.num_detectors; k++) {
result.insert(DemTarget::relative_detector_id(k));
}
};
auto add_all_obs = [&]() {
for (uint64_t k = 0; k < stats.num_observables; k++) {
result.insert(DemTarget::observable_id(k));
}
};

bool has_coords = false;
std::map<uint64_t, std::vector<double>> cached_coords;
auto get_coords_cached = [&]() -> const std::map<uint64_t, std::vector<double>> & {
std::set<uint64_t> all_dets;
for (uint64_t k = 0; k < stats.num_detectors; k++) {
all_dets.insert(k);
}
if (!has_coords) {
cached_coords = circuit.get_detector_coordinates(all_dets);
has_coords = true;
}
return cached_coords;
};

if (included_targets_filter.is_none()) {
add_all_dets();
add_all_obs();
return result;
}
for (const auto &filter : included_targets_filter) {
bool fail = false;
if (pybind11::isinstance<ExposedDemTarget>(filter)) {
result.insert(pybind11::cast<ExposedDemTarget>(filter));
} else if (pybind11::isinstance<pybind11::str>(filter)) {
std::string s = pybind11::cast<std::string>(filter);
if (s == "D") {
add_all_dets();
} else if (s == "L") {
add_all_obs();
} else if (s.starts_with("D") || s.starts_with("L")) {
result.insert(DemTarget::from_text(s));
} else {
fail = true;
}
} else {
std::vector<double> prefix;
for (auto e : filter) {
if (pybind11::isinstance<pybind11::int_>(e) || pybind11::isinstance<pybind11::float_>(e)) {
prefix.push_back(pybind11::cast<double>(e));
} else {
fail = true;
break;
}
}
if (!fail) {
for (const auto &[target, coord] : get_coords_cached()) {
if (coord.size() >= prefix.size()) {
bool match = true;
for (size_t k = 0; k < prefix.size(); k++) {
match &= prefix[k] == coord[k];
}
if (match) {
result.insert(DemTarget::relative_detector_id(target));
}
}
}
}
}
if (fail) {
std::stringstream ss;
ss << "Don't know how to interpret '";
ss << pybind11::cast<std::string>(pybind11::repr(filter));
ss << "' as a dem target filter.";
throw std::invalid_argument(ss.str());
}
}
return result;
}

std::string circuit_repr(const Circuit &self) {
if (self.operations.empty()) {
return "stim.Circuit()";
Expand Down Expand Up @@ -2034,6 +2117,163 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_<Ci
)DOC")
.data());

c.def(
"detecting_regions",
[](const Circuit &self,
const pybind11::object &included_targets,
const pybind11::object &included_ticks,
bool ignore_anticommutation_errors) -> std::map<ExposedDemTarget, std::map<uint64_t, FlexPauliString>> {
auto stats = self.compute_stats();
auto included_target_set = py_dem_filter_to_dem_target_set(self, stats, included_targets);
std::set<uint64_t> included_tick_set;

if (included_ticks.is_none()) {
for (uint64_t k = 0; k < stats.num_ticks; k++) {
included_tick_set.insert(k);
}
} else {
for (const auto &t : included_ticks) {
included_tick_set.insert(pybind11::cast<uint64_t>(t));
}
}
auto result = circuit_to_detecting_regions(
self, included_target_set, included_tick_set, ignore_anticommutation_errors);
std::map<ExposedDemTarget, std::map<uint64_t, FlexPauliString>> exposed_result;
for (const auto &[k, v] : result) {
exposed_result.insert({ExposedDemTarget(k), std::move(v)});
}
return exposed_result;
},
pybind11::kw_only(),
pybind11::arg("targets") = pybind11::none(),
pybind11::arg("ticks") = pybind11::none(),
pybind11::arg("ignore_anticommutation_errors") = false,
clean_doc_string(R"DOC(
@signature def detecting_regions(self, *, targets: Optional[Iterable[stim.DemTarget | str | Iterable[float]]] = None, ticks: Optional[Iterable[int]] = None) -> Dict[stim.DemTarget, Dict[int, stim.PauliString]]:
Records where detectors and observables are sensitive to errors over time.
The result of this method is a nested dictionary, mapping detectors/observables
and ticks to Pauli sensitivities for that detector/observable at that time.
For example, if observable 2 has Z-type sensitivity on qubits 5 and 6 during
tick 3, then `result[stim.target_logical_observable_id(2)][3]` will be equal to
`stim.PauliString("Z5*Z6")`.
If you want sensitivities from more places in the circuit, besides just at the
TICK instructions, you can work around this by making a version of the circuit
with more TICKs.
Args:
targets: Defaults to everything (None).
When specified, this should be an iterable of filters where items
matching any one filter are included.
A variety of filters are supported:
stim.DemTarget: Includes the targeted detector or observable.
Iterable[float]: Coordinate prefix match. Includes detectors whose
coordinate data begins with the same floats.
"D": Includes all detectors.
"L": Includes all observables.
"D#" (e.g. "D5"): Includes the detector with the specified index.
"L#" (e.g. "L5"): Includes the observable with the specified index.
ticks: Defaults to everything (None).
When specified, this should be a list of integers corresponding to
the tick indices to report sensitivities for.
ignore_anticommutation_errors: Defaults to False.
When set to False, invalid detecting regions that anticommute with a
reset will cause the method to raise an exception. When set to True,
the offending component will simply be silently dropped. This can
result in broken detectors having apparently enormous detecting
regions.
Returns:
Nested dictionaries keyed first by a `stim.DemTarget` identifying the
detector or observable, then by the index of the tick, leading to a
PauliString with that target's error sensitivity at that tick.
Note you can use `stim.PauliString.pauli_indices` to quickly get to the
non-identity terms in the sensitivity.
Examples:
>>> import stim
>>> detecting_regions = stim.Circuit('''
... R 0
... TICK
... H 0
... TICK
... CX 0 1
... TICK
... MX 0 1
... DETECTOR rec[-1] rec[-2]
... ''').detecting_regions()
>>> for target, tick_regions in detecting_regions.items():
... print("target", target)
... for tick, sensitivity in tick_regions.items():
... print(" tick", tick, "=", sensitivity)
target D0
tick 0 = +Z_
tick 1 = +X_
tick 2 = +XX
>>> circuit = stim.Circuit.generated(
... "surface_code:rotated_memory_x",
... rounds=5,
... distance=4,
... )
>>> detecting_regions = circuit.detecting_regions(
... targets=["L0", (2, 4), stim.DemTarget.relative_detector_id(5)],
... ticks=range(5, 15),
... )
>>> for target, tick_regions in detecting_regions.items():
... print("target", target)
... for tick, sensitivity in tick_regions.items():
... print(" tick", tick, "=", sensitivity)
target D1
tick 5 = +____________________X______________________
tick 6 = +____________________Z______________________
target D5
tick 5 = +______X____________________________________
tick 6 = +______Z____________________________________
target D14
tick 5 = +__________X_X______XXX_____________________
tick 6 = +__________X_X______XZX_____________________
tick 7 = +__________X_X______XZX_____________________
tick 8 = +__________X_X______XXX_____________________
tick 9 = +__________XXX_____XXX______________________
tick 10 = +__________XXX_______X______________________
tick 11 = +__________X_________X______________________
tick 12 = +____________________X______________________
tick 13 = +____________________Z______________________
target D29
tick 7 = +____________________Z______________________
tick 8 = +____________________X______________________
tick 9 = +____________________XX_____________________
tick 10 = +___________________XXX_______X_____________
tick 11 = +____________X______XXXX______X_____________
tick 12 = +__________X_X______XXX_____________________
tick 13 = +__________X_X______XZX_____________________
tick 14 = +__________X_X______XZX_____________________
target D44
tick 14 = +____________________Z______________________
target L0
tick 5 = +_X________X________X________X______________
tick 6 = +_X________X________X________X______________
tick 7 = +_X________X________X________X______________
tick 8 = +_X________X________X________X______________
tick 9 = +_X________X_______XX________X______________
tick 10 = +_X________X________X________X______________
tick 11 = +_X________XX_______X________XX_____________
tick 12 = +_X________X________X________X______________
tick 13 = +_X________X________X________X______________
tick 14 = +_X________X________X________X______________
)DOC")
.data());

c.def(
"without_noise",
&Circuit::without_noise,
Expand Down
50 changes: 40 additions & 10 deletions src/stim/circuit/circuit_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,7 +1703,7 @@ def test_decomposed():
assert stim.Circuit("""
ISWAP 0 1 2 1
TICK
CPP X2*X1 !Z1*Z2
MPP X1*Z2*Y3
""").decomposed() == stim.Circuit("""
H 0
CX 0 1 1 0
Expand All @@ -1714,13 +1714,43 @@ def test_decomposed():
H 1
S 1 2
TICK
H 1 2
CX 2 1
H 2 2
CX 1 2
H 2
S 1 1
H 2
CX 2 1
H 1 2
H 1 3
S 3
H 3
S 3 3
CX 2 1 3 1
M 1
CX 2 1 3 1
H 3
S 3
H 3
S 3 3
H 1
""")


def test_detecting_regions():
assert stim.Circuit('''
R 0
TICK
H 0
TICK
CX 0 1
TICK
MX 0 1
DETECTOR rec[-1] rec[-2]
''').detecting_regions() == {stim.DemTarget.relative_detector_id(0): {
0: stim.PauliString("Z_"),
1: stim.PauliString("X_"),
2: stim.PauliString("XX"),
}}


def test_detecting_region_filters():
c = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3)
assert len(c.detecting_regions(targets=["D"])) == c.num_detectors
assert len(c.detecting_regions(targets=["L"])) == c.num_observables
assert len(c.detecting_regions()) == c.num_observables + c.num_detectors
assert len(c.detecting_regions(targets=["D0"])) == 1
assert len(c.detecting_regions(targets=["D0", "L0"])) == 2
assert len(c.detecting_regions(targets=[stim.target_relative_detector_id(0), "D0"])) == 1
27 changes: 25 additions & 2 deletions src/stim/dem/dem_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cmath>

#include "stim/arg_parse.h"
#include "stim/dem/detector_error_model.h"
#include "stim/simulators/error_analyzer.h"
#include "stim/str_util.h"
Expand All @@ -11,14 +12,17 @@ using namespace stim;
constexpr uint64_t OBSERVABLE_BIT = uint64_t{1} << 63;
constexpr uint64_t SEPARATOR_SYGIL = UINT64_MAX;

constexpr uint64_t MAX_OBS = 0xFFFFFFFF;
constexpr uint64_t MAX_DET = (uint64_t{1} << 62) - 1;

DemTarget DemTarget::observable_id(uint64_t id) {
if (id > 0xFFFFFFFF) {
if (id > MAX_OBS) {
throw std::invalid_argument("id > 0xFFFFFFFF");
}
return {OBSERVABLE_BIT | id};
}
DemTarget DemTarget::relative_detector_id(uint64_t id) {
if (id >= (uint64_t{1} << 62)) {
if (id > MAX_DET) {
throw std::invalid_argument("Relative detector id too large.");
}
return {id};
Expand Down Expand Up @@ -75,6 +79,25 @@ void DemTarget::shift_if_detector_id(int64_t offset) {
data = (uint64_t)((int64_t)data + offset);
}
}
DemTarget DemTarget::from_text(std::string_view text) {
if (!text.empty()) {
bool is_det = text[0] == 'D';
bool is_obs = text[0] == 'L';
if (is_det || is_obs) {
int64_t parsed = 0;
if (parse_int64(text.substr(1), &parsed)) {
if (parsed >= 0) {
if (is_det && parsed <= (int64_t)MAX_DET) {
return DemTarget::relative_detector_id(parsed);
} else if (is_obs && parsed <= (int64_t)MAX_OBS) {
return DemTarget::observable_id(parsed);
}
}
}
}
}
throw std::invalid_argument("Failed to parse as a stim.DemTarget: '" + std::string(text) + "'");
}

bool DemInstruction::operator<(const DemInstruction &other) const {
if (type != other.type) {
Expand Down
2 changes: 2 additions & 0 deletions src/stim/dem/dem_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ struct DemTarget {
bool operator!=(const DemTarget &other) const;
bool operator<(const DemTarget &other) const;
std::string str() const;

static DemTarget from_text(std::string_view text);
};

struct DetectorErrorModel;
Expand Down
Loading

0 comments on commit cfa6ef3

Please sign in to comment.