From fb8d9fa6ab70d6d5363d195e9633833bb4ae1237 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Mon, 19 Feb 2024 16:24:58 +0100 Subject: [PATCH] Cpp check nodes (#420) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add logical node property * remove unused and untested functions * rename boundary nodes as logical nodes (if associated with a logical) * fix issues * linting * add graph conversion * rename old file * Create pymatching_decoder.py (#415) Creates a matching object which can be used to decode a counts string via the method process. * linting * remove old matching code and tests * remove old matching code and tests * remove old imports * fix fault id issue * standardize stim code circuit * get rid of measured_logicals * remove measuremed_logicals * remove css decoding graph * remove css decoding graph * add graph conversions * make matching work for non-stim * add matching tests * convert to edge graph only when needed * black * update doc string * correct logical nodes * 2 qubit paulis for cx gates * extend beyond cx * allow default check nodes to be used * rename pymatching decoder * add cpp check_nodes * lint * fix building * correct path * add is_cluster_neutral * fix issues * fix is_cluster_neutral * remove legacy methods * add detail to error messages * fix extra logicals * fix minimals * add detail to message" " * thing * initialize frustration * initialize frustration * initialize frustration * limit loop * limit loop * limit loop * limit loop * limit loop * fun stuff * Revert "limit loop" This reverts commit 8036e3c9ee7f9b0d9411599e35a48e70d1172b19. * revert to 8036e3c * undo reformatting * add heavy-hex decoder test * add heavy-hex decoder test * add heavy-hex decoder test * add heavy-hex decoder test * add heavy-hex decoder test * linting --------- Co-authored-by: Bence Hetényi <55080156+hetenyib@users.noreply.github.com> --- CMakeLists.txt | 4 + setup.py | 7 +- src/qiskit_qec/CMakeLists.txt | 3 +- src/qiskit_qec/circuits/CMakeLists.txt | 28 + src/qiskit_qec/circuits/arctools.h | 26 + .../circuits/bindings/circuits_bindings.cpp | 12 + src/qiskit_qec/circuits/code_circuit.py | 12 +- src/qiskit_qec/circuits/css_code.py | 12 +- src/qiskit_qec/circuits/extensions.py | 42 ++ src/qiskit_qec/circuits/intern/arctools.cpp | 221 ++++++ src/qiskit_qec/circuits/repetition_code.py | 341 +++------- src/qiskit_qec/circuits/stim_code_circuit.py | 8 +- src/qiskit_qec/circuits/surface_code.py | 49 +- src/qiskit_qec/decoders/__init__.py | 7 +- src/qiskit_qec/decoders/base_matcher.py | 27 - .../decoders/circuit_matching_decoder.py | 639 ------------------ src/qiskit_qec/decoders/decoding_graph.py | 364 ++++------ src/qiskit_qec/decoders/hdrg_decoders.py | 97 ++- src/qiskit_qec/decoders/hhc_decoder.py | 203 ------ src/qiskit_qec/decoders/pymatching_decoder.py | 113 ++++ src/qiskit_qec/decoders/pymatching_matcher.py | 65 -- src/qiskit_qec/decoders/repetition_decoder.py | 54 -- src/qiskit_qec/decoders/rustworkx_matcher.py | 186 ----- src/qiskit_qec/decoders/temp_graph_util.py | 42 -- src/qiskit_qec/decoders/three_bit_decoder.py | 58 -- .../utils/decoding_graph_attributes.py | 74 +- src/qiskit_qec/utils/stim_tools.py | 11 +- test/code_circuits/test_rep_codes.py | 311 ++++++++- test/code_circuits/test_surface_codes.py | 36 +- test/heavy_hex_codes/__init__.py | 11 - test/heavy_hex_codes/test_heavy_hex_code.py | 172 ----- .../heavy_hex_codes/test_heavy_hex_decoder.py | 322 --------- test/matching/__init__.py | 11 - test/matching/test_circuitmatcher.py | 355 ---------- test/matching/test_matching.py | 111 +++ test/matching/test_pymatchingmatcher.py | 55 -- test/matching/test_repetitionmatcher.py | 111 --- test/matching/test_retworkxmatcher.py | 81 --- 38 files changed, 1207 insertions(+), 3074 deletions(-) create mode 100644 src/qiskit_qec/circuits/CMakeLists.txt create mode 100644 src/qiskit_qec/circuits/arctools.h create mode 100644 src/qiskit_qec/circuits/bindings/circuits_bindings.cpp create mode 100644 src/qiskit_qec/circuits/extensions.py create mode 100644 src/qiskit_qec/circuits/intern/arctools.cpp delete mode 100644 src/qiskit_qec/decoders/base_matcher.py delete mode 100644 src/qiskit_qec/decoders/circuit_matching_decoder.py delete mode 100644 src/qiskit_qec/decoders/hhc_decoder.py create mode 100644 src/qiskit_qec/decoders/pymatching_decoder.py delete mode 100644 src/qiskit_qec/decoders/pymatching_matcher.py delete mode 100644 src/qiskit_qec/decoders/repetition_decoder.py delete mode 100644 src/qiskit_qec/decoders/rustworkx_matcher.py delete mode 100644 src/qiskit_qec/decoders/three_bit_decoder.py delete mode 100644 test/heavy_hex_codes/__init__.py delete mode 100644 test/heavy_hex_codes/test_heavy_hex_code.py delete mode 100644 test/heavy_hex_codes/test_heavy_hex_decoder.py delete mode 100644 test/matching/__init__.py delete mode 100644 test/matching/test_circuitmatcher.py create mode 100644 test/matching/test_matching.py delete mode 100644 test/matching/test_pymatchingmatcher.py delete mode 100644 test/matching/test_repetitionmatcher.py delete mode 100644 test/matching/test_retworkxmatcher.py diff --git a/CMakeLists.txt b/CMakeLists.txt index bdee1c7a..b0d295a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,5 +60,9 @@ set(ANALYSIS_BINDINGS_SRC "src/qiskit_qec/analysis/bindings/analysis_bindings.cp pybind11_add_module(_c_analysis ${ANALYSIS_BINDINGS_SRC}) target_link_libraries(_c_analysis PRIVATE libanalysis) +set(CIRCUITS_BINDINGS_SRC "src/qiskit_qec/circuits/bindings/circuits_bindings.cpp") +pybind11_add_module(_c_circuits ${CIRCUITS_BINDINGS_SRC}) +target_link_libraries(_c_circuits PRIVATE libcircuits) + diff --git a/setup.py b/setup.py index 9014389a..8e8944ef 100644 --- a/setup.py +++ b/setup.py @@ -177,13 +177,18 @@ def build_extension(self, ext: CMakeExtension) -> None: python_requires=">=3.8", include_package_data=True, install_requires=(REQUIREMENTS,), - ext_modules=[CMakeExtension("qiskit_qec.analysis._c_analysis")], + ext_modules=[ + CMakeExtension("qiskit_qec.circuits._c_circuits"), + CMakeExtension("qiskit_qec.analysis._c_analysis"), + ], packages=find_packages( where="src", exclude=[ "test*", "src/qiskit_qec/analysis/bindings*", "src/qiskit_qec/analysis/intern*", + "src/qiskit_qec/circuits/bindings*", + "src/qiskit_qec/circuits/intern*", "src/qiskit_qec/codes/codebase/data*", ], ), diff --git a/src/qiskit_qec/CMakeLists.txt b/src/qiskit_qec/CMakeLists.txt index f8db0eca..566d8589 100644 --- a/src/qiskit_qec/CMakeLists.txt +++ b/src/qiskit_qec/CMakeLists.txt @@ -1 +1,2 @@ -add_subdirectory(analysis) \ No newline at end of file +add_subdirectory(analysis) +add_subdirectory(circuits) \ No newline at end of file diff --git a/src/qiskit_qec/circuits/CMakeLists.txt b/src/qiskit_qec/circuits/CMakeLists.txt new file mode 100644 index 00000000..3981da88 --- /dev/null +++ b/src/qiskit_qec/circuits/CMakeLists.txt @@ -0,0 +1,28 @@ +# Code circuit library +cmake_minimum_required(VERSION 3.12) +project(Circuits) + +set(CIRCUITS_SRC + intern/arctools.cpp + + arctools.h +) + +add_library(libcircuits + STATIC + ${CIRCUITS_SRC} +) + +target_include_directories(libcircuits + PUBLIC + "${CMAKE_CURRENT_SOURCE_DIR}" +) + +if (NOT (MSVC)) + target_compile_options(libcircuits PRIVATE -fno-strict-aliasing -fPIC ${ARCH_OPT}) +else () + target_compile_options(libcircuits PRIVATE -fPIC ${ARCH_OPT}) +endif () + +# Set prefix to "" since lib is already in project name +set_target_properties(libcircuits PROPERTIES PREFIX "") diff --git a/src/qiskit_qec/circuits/arctools.h b/src/qiskit_qec/circuits/arctools.h new file mode 100644 index 00000000..9835abad --- /dev/null +++ b/src/qiskit_qec/circuits/arctools.h @@ -0,0 +1,26 @@ +#ifndef __ArcTools__ +#define __ArcTools__ + +#include +#include +#include +#include + +std::vector check_nodes( + std::vector> nodes, bool ignore_extra_logicals, bool minimal, + std::map, std::set> cycle_dict, + std::vector> link_graph, + std::map> link_neighbors, + std::vector z_logicals + ); + +bool is_cluster_neutral( + std::vector> nodes, bool ignore_extra_logicals, bool minimal, + std::map, std::set> cycle_dict, + std::vector> link_graph, + std::map> link_neighbors, + std::vector z_logicals, + bool linear + ); + +#endif \ No newline at end of file diff --git a/src/qiskit_qec/circuits/bindings/circuits_bindings.cpp b/src/qiskit_qec/circuits/bindings/circuits_bindings.cpp new file mode 100644 index 00000000..1cdf4eca --- /dev/null +++ b/src/qiskit_qec/circuits/bindings/circuits_bindings.cpp @@ -0,0 +1,12 @@ +#include "arctools.h" +#include +#include + +namespace py = pybind11; + +PYBIND11_MODULE(_c_circuits, module) +{ + module.doc() = "qiskit-qec code circuit extensions"; + module.def("_c_check_nodes", &check_nodes, "check_nodes in C++"); + module.def("_c_is_cluster_neutral", &is_cluster_neutral, "is_cluster_neutral in C++"); +} \ No newline at end of file diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py index 714b27fc..d6bb28b3 100644 --- a/src/qiskit_qec/circuits/code_circuit.py +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -53,15 +53,7 @@ def string2nodes(self, string, **kwargs): pass @abstractmethod - def measured_logicals(self): - """ - Returns a list of logical operators, each expressed as a list of qubits for which - the parity of the final readouts corresponds to the raw logical readout. - """ - pass - - @abstractmethod - def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): + def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -69,7 +61,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are + ignore_extra_logical (bool): If `True`, undeeded logical nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. diff --git a/src/qiskit_qec/circuits/css_code.py b/src/qiskit_qec/circuits/css_code.py index 3f94ba75..67a8f285 100644 --- a/src/qiskit_qec/circuits/css_code.py +++ b/src/qiskit_qec/circuits/css_code.py @@ -183,16 +183,6 @@ def _get_code_properties(self): self.z_stabilizers = self.code.z_stabilizers self.logical_x = self.code.logical_x self.logical_z = self.code.logical_z - # for the unionfind decoder - self.css_x_logical = self.logical_x - self.css_z_logical = self.logical_z - - def measured_logicals(self): - if self.basis == "x": - measured_logicals = self.logical_x - else: - measured_logicals = self.logical_z - return measured_logicals def _prepare_initial_state(self, qc, qregs, state): if state[0] == "1": @@ -289,7 +279,7 @@ def string2raw_logicals(self, string): log_outs = string2logical_meas(string, self.logicals, self.circuit["0"].clbits) return log_outs - def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): + def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): raise NotImplementedError def is_cluster_neutral(self, atypical_nodes): diff --git a/src/qiskit_qec/circuits/extensions.py b/src/qiskit_qec/circuits/extensions.py new file mode 100644 index 00000000..b8c2cdc8 --- /dev/null +++ b/src/qiskit_qec/circuits/extensions.py @@ -0,0 +1,42 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2021. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=unused-import + +"""Code circuit extensions""" + +import logging # for logging! + +logger = logging.getLogger(__name__) + +# Load extensions if available and set appriate indicator flags + +try: + from qiskit_qec.analysis._c_circuits import _c_check_nodes + + C_CHECK_NODES = True +except ImportError as import_error: + logger.exception( # pylint: disable=logging-fstring-interpolation + f"from qiskit_qec.analysis._c_circuits import _c_check_nodes \ + failed, raising {import_error}" + ) + C_CHECK_NODES = False + +try: + from qiskit_qec.analysis._c_circuits import _c_is_cluster_neutral + + C_IS_CLUSTER_NEUTRAL = True +except ImportError as import_error: + logger.exception( # pylint: disable=logging-fstring-interpolation + f"from qiskit_qec.analysis._c_circuits import _c_is_cluster_neutral \ + failed, raising {import_error}" + ) + C_IS_CLUSTER_NEUTRAL = False diff --git a/src/qiskit_qec/circuits/intern/arctools.cpp b/src/qiskit_qec/circuits/intern/arctools.cpp new file mode 100644 index 00000000..af7c08b6 --- /dev/null +++ b/src/qiskit_qec/circuits/intern/arctools.cpp @@ -0,0 +1,221 @@ +#include "arctools.h" +#include + +bool is_cluster_neutral( + std::vector> nodes, bool ignore_extra_logicals, bool minimal, + std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, std::vector z_logicals, + bool linear + ) { + if (linear) { + return nodes.size()%2==0; + } else { + std::vector output = check_nodes(nodes, ignore_extra_logicals, minimal, cycle_dict, link_graph, link_neighbors, z_logicals); + return (output[0]==1) and (output.size()==2); + } + +}; + +std::vector check_nodes( + std::vector> nodes, bool ignore_extra_logicals, bool minimal, + std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, std::vector z_logicals + ) { + + // output[0] is neutral (as int), output[1] is num_errors, rest is list of given logicals + std::vector output; + + // we convert to flat nodes, which are a std::tuple with (q0, q1, boundary) + // if we have an even number of corresponding nodes, they cancel + std::map, int> node_counts; + for (auto & node : nodes) { + node_counts[std::make_tuple(std::get<0>(node), std::get<1>(node), std::get<3>(node))] = 0; + } + for (auto & node : nodes) { + node_counts[std::make_tuple(std::get<0>(node), std::get<1>(node), std::get<3>(node))] ++; + } + // make a std::vector of the net flat nodes + std::vector> flat_nodes; + for (auto & node_count : node_counts) { + if (node_count.second % 2 == 1) { + flat_nodes.push_back(node_count.first); + } + } + // see what logicals and bulk nodes are given + std::set given_logicals; + std::set> bulk_nodes; + for (auto & node : flat_nodes) { + if (std::get<2>(node)) { + given_logicals.insert(std::get<0>(node)); + } else { + bulk_nodes.insert(node); + } + } + + if (bulk_nodes.size()==0){ + // without bulk nodes, neutral only if no logical nodes are given (unless this is ignored) + int neutral = (ignore_extra_logicals || given_logicals.size()==0); + int num_errors = 0; + // compile the output + output.push_back(neutral); + output.push_back(num_errors); + // no flipped logicals need to be added + } else { + std::map parities; + // check how many times the bulk nodes turn up in each cycle + for (int c = 0; c < cycle_dict.size(); c++){ + parities[c] = 0; + } + for (auto & node: bulk_nodes) { + for (auto & c: cycle_dict[(std::make_tuple(std::get<0>(node), std::get<1>(node)))]){ + parities[c]++; + } + } + // we get frustration if any of these is odd + bool frust = false; + for (auto & parity: parities){ + frust = frust || (parity.second % 2 == 1); + } + if (frust) { + // if it's frustrated, it's not neutral + output.push_back(0); + // number of errors not counted + output.push_back(-1); + // no flipped logicals need to be added + } else { + // now we must bicolor the qubits of the link graph, such that node edges connect unlike edges + + // first make a list of the qubits that definitely need to be covered + // (those in the bulk nodes) and see how often each comes up in the nodes + std::set node_qubits; + std::map nq_nums; + for (auto & node: bulk_nodes){ + std::vector qs = { + std::get<0>(node), + std::get<1>(node) + }; + for (auto & q: qs) { + if (node_qubits.insert(q).second){ + nq_nums[q] = 1; + } else { + nq_nums[q] += 1; + } + } + } + // find the most mentioned qubit + int root; + int max_num = 0; + for (auto & nq_num: nq_nums){ + if (nq_num.second > max_num){ + root = nq_num.first; + max_num = nq_num.second; + } + } + // start colouring with the most mentioned qubit + std::map color; + color[root] = 0; + std::vector newly_colored = {root}; + std::set colored = {root}; + // stop once all node qubits are coloured and one color has stopped growing + bool converged = false; + node_qubits.erase(root); + std::map num_nodes = { + {0, 1}, + {1, 0} + }; + std::map last_num_nodes = num_nodes; + while (not converged){ + // for each newly coloured qubit + std::vector very_newly_colored; + for (auto & n: newly_colored){ + // loop through all the neighbours + for (auto & nn: link_neighbors[n]){ + // if they haven't yet been coloured + if (colored.find(nn) == colored.end()){ + // if this pair don't correspond to a bulk node, the new one is the same colour + if ((bulk_nodes.find(std::make_tuple(n,nn,false)) == bulk_nodes.end()) and (bulk_nodes.find(std::make_tuple(nn,n,false)) == bulk_nodes.end())){ + color[nn] = color[n]; + // otherwise, it's the opposite color + } else { + color[nn] = not color[n]; + } + very_newly_colored.push_back(nn); + colored.insert(nn); + num_nodes[color[nn]]++; + node_qubits.erase(nn); + } + } + } + converged = (node_qubits.size() == 0) and ((num_nodes[0] == last_num_nodes[0]) or (num_nodes[1] == last_num_nodes[1])); + newly_colored = very_newly_colored; + if (not converged){ + last_num_nodes = num_nodes; + } + } + // see which colour has converged + bool conv_color = (num_nodes[1] == last_num_nodes[1]); + // calculate the number of nodes for the other + num_nodes[not conv_color] = link_neighbors.size() - num_nodes[conv_color]; + // see which colour has the fewer qubits + int min_color = (num_nodes[1] <= num_nodes[0]); + // list the colours with the max error one first + // (unless we do min only) + std::vector cs; + cs.push_back(min_color); + if (not minimal){ + cs.push_back((min_color+1)%2); + } + // determine which flipped logicals correspond to which colour + std::vector> color_logicals = {{}, {}}; + for (auto & q: z_logicals){ + if (color.find(q) == color.end()){ + color[q] = (conv_color+1)%2; + } + color_logicals[color[q]].insert(q); + } + // see if we can find a color for which we have no extra logicals + // and see what additional logicals are required + std::set flipped_logicals; + std::set flipped_ng_logicals; + std::vector extra_logicals; + bool done = false; + int j = 0; + while (not done){ + flipped_logicals = {}; + flipped_ng_logicals = {}; + // see which logicals for this colour have been flipped + for (auto & q: color_logicals[cs[j]]){ + flipped_logicals.insert(q); + // and which of those were not given + if (given_logicals.find(q) == given_logicals.end()) { + flipped_ng_logicals.insert(q); + } + } + // see which extra logicals are given + extra_logicals = {}; + if (not ignore_extra_logicals) { + for (auto & q: given_logicals){ + if ((flipped_logicals.find(q) == flipped_logicals.end())) { + extra_logicals.push_back(q); + } + } + } + // if we have no extra logicals, or we've run out of colours, we move on + // otherwise we try the next colour + done = (extra_logicals.size()==0) || (j+1)==cs.size(); + if (not done){ + j++; + } + } + + // construct output + output.push_back(extra_logicals.size()==0); // neutral + output.push_back(num_nodes[cs[j]]); // num_errors + for (auto & q: flipped_ng_logicals){ + output.push_back(q); + } + + } + + } + + return output; +}; \ No newline at end of file diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index f507b1da..a11cc366 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -31,6 +31,8 @@ from qiskit_qec.circuits.code_circuit import CodeCircuit from qiskit_qec.utils import DecodingGraphEdge, DecodingGraphNode +from qiskit_qec.utils.decoding_graph_attributes import _nodes2cpp +from qiskit_qec.circuits._c_circuits import _c_check_nodes, _c_is_cluster_neutral def _separate_string(string): @@ -107,33 +109,9 @@ def __init__( self.syndrome_measurement(final=True) self.readout() - gauge_ops = [[j, j + 1] for j in range(self.d - 1)] - measured_logical = [[0]] - flip_logical = list(range(self.d)) - boundary = [[0], [self.d - 1]] - - if xbasis: - self.css_x_gauge_ops = gauge_ops - self.css_x_stabilizer_ops = gauge_ops - self.css_x_logical = measured_logical - self.css_x_boundary = boundary - self.css_z_gauge_ops = [] - self.css_z_stabilizer_ops = [] - self.css_z_logical = flip_logical - self.css_z_boundary = [] - self.basis = "x" - else: - self.css_x_gauge_ops = [] - self.css_x_stabilizer_ops = [] - self.css_x_logical = flip_logical - self.css_x_boundary = [] - self.css_z_gauge_ops = gauge_ops - self.css_z_stabilizer_ops = gauge_ops - self.css_z_logical = measured_logical - self.css_z_boundary = boundary - self.basis = "z" - self.round_schedule = self.basis - self.blocks = T + self.gauge_ops = [[j, j + 1] for j in range(self.d - 1)] + self.measured_logical = [[0], [self.d - 1]] + self.basis = "x" self.resets = resets self.delay = delay @@ -243,9 +221,6 @@ def readout(self): self.circuit[log].add_register(self.code_bit) self.circuit[log].measure(self.code_qubit, self.code_bit) - def measured_logicals(self): - return [[0]] - def _process_string(self, string): # logical readout taken from measured_log = string[0] + " " + string[self.d - 1] @@ -316,20 +291,18 @@ def string2nodes(self, string, **kwargs): logical = "0" string = self._process_string(string) - # [ , , ,...] + # [ , , ,...] separated_string = _separate_string(string) nodes = [] - # boundary nodes + # logical nodes boundary = separated_string[0] # [, ] for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: - i = [0, -1][bqec_index] - if self.basis == "z": - bqubits = [self.css_x_logical[i]] - else: - bqubits = [self.css_z_logical[i]] - bnode = DecodingGraphNode(is_boundary=True, qubits=bqubits, index=bqec_index) + bqubits = self.measured_logical[bqec_index] + bnode = DecodingGraphNode( + is_logical=True, is_boundary=True, qubits=bqubits, index=bqec_index + ) nodes.append(bnode) # bulk nodes @@ -338,10 +311,7 @@ def string2nodes(self, string, **kwargs): elements = separated_string[syn_type][syn_round] for qec_index, element in enumerate(elements[::-1]): if element == "1": - if self.basis == "z": - qubits = self.css_z_gauge_ops[qec_index] - else: - qubits = self.css_x_gauge_ops[qec_index] + qubits = self.gauge_ops[qec_index] node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index) nodes.append(node) return nodes @@ -354,9 +324,9 @@ def string2raw_logicals(self, string): Returns: list: Raw values for logical operators that correspond to nodes. """ - return string.split(" ", maxsplit=1)[0][-1] + return [string.split(" ", maxsplit=1)[0][-1]] - def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): + def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -364,7 +334,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are + ignore_extra_logical (bool): If `True`, undeeded boundary nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. @@ -380,14 +350,14 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): # see which qubits for logical zs are given and collect bulk nodes given_logicals = [] for node in nodes: - if node.is_boundary: + if node.is_logical: given_logicals += node.qubits given_logicals = set(given_logicals) # bicolour code qubits according to the domain walls walls = [] for node in nodes: - if not node.is_boundary: + if not node.is_logical: walls.append(node.qubits[1]) walls.sort() c = 0 @@ -427,7 +397,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): # if unneeded logical zs are given, cluster is not neutral # (unless this is ignored) - if (not ignore_extra_boundary) and given_logicals.difference(flipped_logicals): + if (not ignore_extra_logical) and given_logicals.difference(flipped_logicals): neutral = False # otherwise, report only needed logicals that aren't given else: @@ -437,11 +407,10 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): flipped_logical_nodes = [] for flipped_logical in flipped_logicals: qubits = [flipped_logical] - if self.basis == "z": - elem = self.css_z_boundary.index(qubits) - else: - elem = self.css_x_boundary.index(qubits) - node = DecodingGraphNode(is_boundary=True, qubits=qubits, index=elem) + elem = self.measured_logical.index(qubits) + node = DecodingGraphNode( + is_logical=True, is_boundary=True, qubits=qubits, index=elem + ) flipped_logical_nodes.append(node) if neutral and not flipped_logical_nodes: @@ -609,6 +578,8 @@ def __init__( self._syndrome_measurement(final=True) self._readout() + self._cpp_link_graph, self._cpp_link_neighbors = self._links2cpp() + def _get_link_graph(self, max_dist=1): graph = rx.PyGraph() for link in self.links: @@ -825,20 +796,6 @@ def _preparation(self): z_logicals = [min(self.code_index.keys())] self.z_logicals = z_logicals - # set css attributes for decoder - gauge_ops = [[link[0], link[2]] for link in self.links] - measured_logical = [[self.z_logicals[0]]] - flip_logical = list(range(self.d)) - boundary = [[logical] for logical in self.z_logicals] - self.css_x_gauge_ops = [] - self.css_x_stabilizer_ops = [] - self.css_x_logical = flip_logical - self.css_x_boundary = [] - self.css_z_gauge_ops = gauge_ops - self.css_z_stabilizer_ops = gauge_ops - self.css_z_logical = measured_logical - self.css_z_boundary = boundary - def _get_202(self, t): """ Returns the position within a 202 sequence for the current measurement round: @@ -973,9 +930,6 @@ def _readout(self): qc.add_register(self.code_bit) qc.measure(self.code_qubit, self.code_bit) - def measured_logicals(self): - return [[self.z_logicals[0]]] - def _process_string(self, string): # logical readout taken from assigned qubits measured_log = "" @@ -1101,8 +1055,8 @@ def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]: if (syn_type == 0 and (all_logicals or element != self.logical)) or ( syn_type != 0 and element == "1" ): - is_boundary = syn_type == 0 - if is_boundary: + is_logical = syn_type == 0 + if is_logical: elem_num = syn_round syn_round = 0 code_qubits = [self.z_logicals[elem_num]] @@ -1115,8 +1069,8 @@ def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]: if not tau: tau = 0 node = DecodingGraphNode( - is_boundary=is_boundary, - time=syn_round if not is_boundary else None, + is_logical=is_logical, + time=syn_round if not is_logical else None, qubits=code_qubits, index=elem_num, ) @@ -1152,14 +1106,32 @@ def flatten_nodes(nodes: List[DecodingGraphNode]): nodes_per_link[link_qubit] = 1 flat_nodes = [] for node in nodes: - if nodes_per_link[node.properties["link qubit"]] % 2: + if node.is_logical or node.is_boundary: + flat_nodes.append(node) + elif nodes_per_link[node.properties["link qubit"]] % 2: flat_node = deepcopy(node) flat_node.time = None if flat_node not in flat_nodes: flat_nodes.append(flat_node) + return flat_nodes - def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): + def _links2cpp(self): + """ + Convert data about the link graph to the form required by C++ functions. + """ + nodes = self.link_graph.nodes() + link_graph = [] + for edge in self.link_graph.edge_list(): + link_graph.append((nodes[edge[0]], nodes[edge[1]])) + link_neighbors = {} + for n, node in enumerate(self.link_graph.nodes()): + link_neighbors[node] = [] + for j in self.link_graph.neighbors(n): + link_neighbors[node].append(nodes[j]) + return link_graph, link_neighbors + + def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -1167,7 +1139,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are + ignore_extra_logical (bool): If `True`, undeeded boundary nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. @@ -1180,177 +1152,28 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): num_errors (int): Minimum number of errors required to create nodes. """ - nodes = self.flatten_nodes(nodes) - - # see which qubits for logical zs are given and collect bulk nodes - given_logicals = [] - bulk_nodes = [] - for node in nodes: - if node.is_boundary: - given_logicals += node.qubits - else: - bulk_nodes.append(node) - given_logicals = set(given_logicals) - - # see whether the bulk nodes are neutral - if bulk_nodes: + nodes = _nodes2cpp(nodes) - # check whether there are frustrated plaquettes - parities = {} - for node in bulk_nodes: - for c in self.cycle_dict[tuple(node.qubits)]: - if c in parities: - parities[c] += 1 - else: - parities[c] = 1 - for c, parity in parities.items(): - parities[c] = parity % 2 - frust = any(parities.values()) - - # if frust==True, no colouring is possible, so no need to do it - if frust: - # neutral if not frustrated - neutral = not frust - # empty because ignored - flipped_logical_nodes = [] - # None because not counted - num_errors = None - else: - # bicolor the nodes of the link graph, such that node edges connect unlike edges - link_qubits = set(node.properties["link qubit"] for node in nodes) - link_graph = self.link_graph - # all the qubits around cycles of the node edges have to be covered - ns_to_do = set() - for edge in [tuple(node.qubits) for node in bulk_nodes]: - for c in self.cycle_dict[edge]: - ns_to_do = ns_to_do.union(self.cycles[c]) - # if this gives us qubits to start with, we start with one - if ns_to_do: - n = ns_to_do.pop() - else: - # otherwise we commit to covering them all - ns_to_do = set(range(len(link_graph.nodes()))) - n = ns_to_do.pop() - node_color = {n: 0} - recently_colored = node_color.copy() - base_neutral = True - # count the number of nodes for each colour throughout - num_nodes = [1, 0] - last_num = [None, None] - fully_converged = False - last_converged = False - while base_neutral and not fully_converged: - # go through all nodes coloured in the last pass - newly_colored = {} - for n, c in recently_colored.items(): - # look at all the code qubits that are neighbours - incident_es = link_graph.incident_edges(n) - for e in incident_es: - edge = link_graph.edges()[e] - n0, n1 = link_graph.edge_list()[e] - if n0 == n: - nn = n1 - else: - nn = n0 - # see if the edge corresponds to one of the given nodes - dc = edge["link qubit"] in link_qubits - # if the neighbour is not yet coloured, colour it - # different color if edge is given node, same otherwise - if nn not in node_color: - new_c = (c + dc) % 2 - if nn not in newly_colored: - newly_colored[nn] = new_c - num_nodes[new_c] += 1 - # if it is coloured, check the colour is correct - else: - base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2) - for nn, c in newly_colored.items(): - node_color[nn] = c - if nn in ns_to_do: - ns_to_do.remove(nn) - recently_colored = newly_colored.copy() - # process is converged once one colour has stoppped growing - # once ns_to_do is empty - converged = (not ns_to_do) and ( - (num_nodes[0] == last_num[0] != 0) or (num_nodes[1] == last_num[1] != 0) - ) - fully_converged = converged and last_converged - if not fully_converged: - last_num = num_nodes.copy() - last_converged = converged - # see how many qubits are in the converged colour, and determine the min colour - for c in range(2): - if num_nodes[c] == last_num[c]: - conv_color = c - if num_nodes[conv_color] <= self.d / 2: - min_color = conv_color - else: - min_color = (conv_color + 1) % 2 - # calculate the number of nodes for the other - num_nodes[(conv_color + 1) % 2] = link_graph.num_nodes() - num_nodes[conv_color] - # get the set of min nodes - min_ns = set() - for n, c in node_color.items(): - if c == min_color: - min_ns.add(n) - - # see which qubits for logical zs are needed - flipped_logicals_all = [[], []] - if base_neutral: - for qubit in self.z_logicals: - n = link_graph.nodes().index(qubit) - dc = not n in min_ns - flipped_logicals_all[(min_color + dc) % 2].append(qubit) - for j in range(2): - flipped_logicals_all[j] = set(flipped_logicals_all[j]) - - # list the colours with the max error one first - # (unless we do min only) - cs = [] - if not minimal: - cs.append((min_color + 1) % 2) - cs.append(min_color) - - # see what happens for both colours - # if neutral for maximal, it's neutral - # otherwise, it is whatever it is for the minimal - for c in cs: - neutral = base_neutral - num_errors = num_nodes[c] - flipped_logicals = flipped_logicals_all[c] - - # if unneeded logical zs are given, cluster is not neutral - # (unless this is ignored) - if (not ignore_extra_boundary) and given_logicals.difference(flipped_logicals): - neutral = False - flipped_logicals = set() - # otherwise, report only needed logicals that aren't given - else: - flipped_logicals = flipped_logicals.difference(given_logicals) - - flipped_logical_nodes = [] - for flipped_logical in flipped_logicals: - node = DecodingGraphNode( - is_boundary=True, - qubits=[flipped_logical], - index=self.z_logicals.index(flipped_logical), - ) - flipped_logical_nodes.append(node) - - if neutral and not flipped_logical_nodes: - break - - else: - # without bulk nodes, neutral only if no boundary nodes are given - neutral = not bool(given_logicals) - # and no flipped logicals - flipped_logical_nodes = [] - num_errors = 0 + cpp_output = _c_check_nodes( + nodes, + ignore_extra_logical, + minimal, + self.cycle_dict, + self._cpp_link_graph, + self._cpp_link_neighbors, + self.z_logicals, + ) - # if unneeded logical zs are given, cluster is not neutral - # (unless this is ignored) - if (not ignore_extra_boundary) and given_logicals: - neutral = False + neutral = bool(cpp_output[0]) + num_errors = cpp_output[1] + flipped_logical_nodes = [] + for flipped_logical in cpp_output[2::]: + node = DecodingGraphNode( + is_logical=True, + qubits=[flipped_logical], + index=self.z_logicals.index(flipped_logical), + ) + flipped_logical_nodes.append(node) return neutral, flipped_logical_nodes, num_errors @@ -1362,11 +1185,17 @@ def is_cluster_neutral(self, atypical_nodes: dict): Args: atypical_nodes: dictionary in the form of the return value of string2nodes """ - if self._linear: - return not bool(len(atypical_nodes) % 2) - else: - neutral, logicals, _ = self.check_nodes(atypical_nodes) - return neutral and not logicals + nodes = _nodes2cpp(atypical_nodes) + return _c_is_cluster_neutral( + nodes, + False, + False, + self.cycle_dict, + self._cpp_link_graph, + self._cpp_link_neighbors, + self.z_logicals, + self._linear, + ) def transpile(self, backend, echo=("X", "X"), echo_num=(2, 0)): """ @@ -1452,7 +1281,7 @@ def _make_syndrome_graph(self): ) nodes: List[DecodingGraphNode] = [] for node in self.string2nodes(string, all_logicals=True): - if not node.is_boundary: + if not node.is_logical: for t in range(self.T + 1): new_node = deepcopy(node) new_node.time = t @@ -1471,7 +1300,7 @@ def _make_syndrome_graph(self): dt = abs((node1.time or 0) - (node0.time or 0)) adj = set(node0.qubits).intersection(set(node1.qubits)) if adj: - if (node0.is_boundary ^ node1.is_boundary) or dt <= 1: + if (node0.is_logical ^ node1.is_logical) or dt <= 1: edges.append((n0, n1)) elif not self.resets: if node0.qubits == node1.qubits and dt == 2: @@ -1485,7 +1314,7 @@ def _make_syndrome_graph(self): source = nodes[n0] target = nodes[n1] qubits = [] - if not (source.is_boundary and target.is_boundary): + if not (source.is_logical and target.is_logical): qubits = list(set(source.qubits).intersection(target.qubits)) if source.time != target.time and len(qubits) > 1: qubits = [] @@ -1553,9 +1382,9 @@ def get_error_coords( else: error_probs = {} for n0, n1 in graph.edge_list(): - if nodes[n0].is_boundary: + if nodes[n0].is_logical: edge = (n1, n1) - elif nodes[n1].is_boundary: + elif nodes[n1].is_logical: edge = (n0, n0) else: edge = (n0, n1) diff --git a/src/qiskit_qec/circuits/stim_code_circuit.py b/src/qiskit_qec/circuits/stim_code_circuit.py index 888fc9d3..3325e5d8 100644 --- a/src/qiskit_qec/circuits/stim_code_circuit.py +++ b/src/qiskit_qec/circuits/stim_code_circuit.py @@ -191,7 +191,8 @@ def _helper(stim_circuit: StimCircuit, reps: int): self.decomp_stim_circuit = self.decompose_stim_circuit(self.stim_circuit) _helper(self.decomp_stim_circuit, 1) - self.circuit = self.qc + self.circuit = {"": self.qc} + self.base = "" # if a set of measurement comparisons is deterministically 1 in the absence of errors, # the set of syndromes is compared to that @@ -586,8 +587,5 @@ def _make_syndrome_graph(self): graph, hyperedges = detector_error_model_to_rx_graph(e, detectors=self.detectors) return graph, hyperedges - def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): - raise NotImplementedError - - def is_cluster_neutral(self, atypical_nodes): + def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): raise NotImplementedError diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 4775c800..6a7f2ba1 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -64,17 +64,15 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True): self._logicals["z"].append(list(range(self.d))) self._logicals["z"].append([self.d**2 - 1 - j for j in range(self.d)]) - # set info needed for css codes - self.css_x_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.xplaqs] - self.css_x_stabilizer_ops = self.css_x_gauge_ops - self.css_x_logical = [self._logicals["x"][0]] - self.css_x_boundary = [self._logicals["x"][0] + self._logicals["x"][1]] - self.css_z_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.zplaqs] - self.css_z_stabilizer_ops = self.css_z_gauge_ops - self.css_z_logical = [self._logicals["z"][0]] - self.css_z_boundary = [self._logicals["z"][0] + self._logicals["z"][1]] - self.round_schedule = self.basis - self.blocks = T + # set gauge and stabilizer info + self.x_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.xplaqs] + self.x_stabilizer_ops = self.x_gauge_ops + self.x_logical = [self._logicals["x"][0]] + self.x_boundary = [self._logicals["x"][0] + self._logicals["x"][1]] + self.z_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.zplaqs] + self.z_stabilizer_ops = self.z_gauge_ops + self.z_logical = [self._logicals["z"][0]] + self.z_boundary = [self._logicals["z"][0] + self._logicals["z"][1]] # quantum registers self._num_xy = int((d**2 - 1) / 2) @@ -316,13 +314,6 @@ def _string2changes(self, string): return syndrome_changes - def measured_logicals(self): - if self.basis == "x": - measured_logicals = self.css_x_logical - else: - measured_logicals = self.css_z_logical - return measured_logicals - def string2raw_logicals(self, string): """ Extracts raw logicals from output string. @@ -403,7 +394,7 @@ def string2nodes(self, string, **kwargs): for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: node = DecodingGraphNode( - is_boundary=True, + is_logical=True, qubits=self._logicals[self.basis][-bqec_index - 1], index=1 - bqec_index, ) @@ -416,14 +407,14 @@ def string2nodes(self, string, **kwargs): for qec_index, element in enumerate(elements[::-1]): if element == "1": if self.basis == "x": - qubits = self.css_x_stabilizer_ops[qec_index] + qubits = self.x_stabilizer_ops[qec_index] else: - qubits = self.css_z_stabilizer_ops[qec_index] + qubits = self.z_stabilizer_ops[qec_index] node = DecodingGraphNode(time=syn_round, qubits=qubits, index=qec_index) nodes.append(node) return nodes - def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): + def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -431,7 +422,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are + ignore_extra_logical (bool): If `True`, undeeded logical nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. @@ -444,9 +435,9 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): num_errors (int): Minimum number of errors required to create nodes. """ - bulk_nodes = [node for node in nodes if not node.is_boundary] - boundary_nodes = [node for node in nodes if node.is_boundary] - given_logicals = set(node.index for node in boundary_nodes) + bulk_nodes = [node for node in nodes if not node.is_logical] + logical_nodes = [node for node in nodes if node.is_logical] + given_logicals = set(node.index for node in logical_nodes) if self.basis == "z": coords = self._zplaq_coords @@ -454,7 +445,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): coords = self._xplaq_coords if (len(bulk_nodes) % 2) == 0: - if (len(boundary_nodes) % 2) == 0 or ignore_extra_boundary: + if (len(logical_nodes) % 2) == 0 or ignore_extra_logical: neutral = True flipped_logicals = set() # estimate num_errors from size @@ -490,7 +481,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): # if unneeded logical zs are given, cluster is not neutral # (unless this is ignored) - if (not ignore_extra_boundary) and given_logicals.difference(flipped_logicals): + if (not ignore_extra_logical) and given_logicals.difference(flipped_logicals): neutral = False # otherwise, report only needed logicals that aren't given else: @@ -501,7 +492,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): flipped_logical_nodes = [] for elem in flipped_logicals: node = DecodingGraphNode( - is_boundary=True, qubits=self._logicals[self.basis][elem], index=elem + is_logical=True, qubits=self._logicals[self.basis][elem], index=elem ) flipped_logical_nodes.append(node) diff --git a/src/qiskit_qec/decoders/__init__.py b/src/qiskit_qec/decoders/__init__.py index 71262f42..254d3ccb 100644 --- a/src/qiskit_qec/decoders/__init__.py +++ b/src/qiskit_qec/decoders/__init__.py @@ -24,14 +24,9 @@ :toctree: ../stubs/ DecodingGraph - CircuitModelMatchingDecoder - RepetitionDecoder - ThreeBitDecoder UnionFindDecoder """ from .decoding_graph import DecodingGraph -from .circuit_matching_decoder import CircuitModelMatchingDecoder -from .repetition_decoder import RepetitionDecoder -from .three_bit_decoder import ThreeBitDecoder +from .pymatching_decoder import PyMatchingDecoder from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder diff --git a/src/qiskit_qec/decoders/base_matcher.py b/src/qiskit_qec/decoders/base_matcher.py deleted file mode 100644 index 75bc8811..00000000 --- a/src/qiskit_qec/decoders/base_matcher.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Base matching object.""" -from abc import ABC, abstractmethod -from typing import List, Tuple, Dict, Set -import rustworkx as rx - - -class BaseMatcher(ABC): - """Matching subroutine base class.""" - - def __init__(self): - """Create the base matcher.""" - pass - - @abstractmethod - def preprocess(self, graph: rx.PyGraph): - """Do any preprocessing using the graph data.""" - raise NotImplementedError("Not implemented.") - - @abstractmethod - def find_errors( - self, - graph: rx.PyGraph, - idxmap: Dict[Tuple[int, List[int]], int], - highlighted: List[Tuple[int, Tuple[int]]], - ) -> Tuple[Set[int], Set[Tuple[int, Tuple[int]]]]: - """Process a set of highlighted vertices and return error locations.""" - raise NotImplementedError("Not implemented.") diff --git a/src/qiskit_qec/decoders/circuit_matching_decoder.py b/src/qiskit_qec/decoders/circuit_matching_decoder.py deleted file mode 100644 index f3b812a0..00000000 --- a/src/qiskit_qec/decoders/circuit_matching_decoder.py +++ /dev/null @@ -1,639 +0,0 @@ -"""Abstract object for matching decoders for CSS codes and circuit noise.""" - -import logging -from abc import ABC, abstractmethod -from copy import copy -from math import log -from typing import Dict, List, Tuple -from sympy import Poly, Symbol, symbols - -import rustworkx as rx -from qiskit import QuantumCircuit -from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.decoders.decoding_graph import CSSDecodingGraph, DecodingGraph -from qiskit_qec.utils import DecodingGraphEdge -from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher -from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher -from qiskit_qec.decoders.temp_code_util import temp_gauge_products, temp_syndrome -from qiskit_qec.exceptions import QiskitQECError -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel - - -class CircuitModelMatchingDecoder(ABC): - """Matching decoder for circuit noise.""" - - METHOD_RETWORKX: str = "rustworkx" - METHOD_PYMATCHING: str = "pymatching" - AVAILABLE_METHODS = {METHOD_RETWORKX, METHOD_PYMATCHING} - - def __init__( - self, - n: int, - css_x_gauge_ops: List[Tuple[int]], - css_x_stabilizer_ops: List[Tuple[int]], - css_x_boundary: List[int], - css_z_gauge_ops: List[Tuple[int]], - css_z_stabilizer_ops: List[Tuple[int]], - css_z_boundary: List[int], - circuit: QuantumCircuit, - model: PauliNoiseModel, - basis: str, - round_schedule: str, - blocks: int, - method: str, - uniform: bool, - decoding_graph: DecodingGraph = None, - annotate: bool = False, - ): - """Create a matching decoder. - - Specialized to (subsystem) CSS codes encoding one logical qubit, - and quantum circuits that prepare and measure in the Z or X basis - and do repeated Pauli measurements. - - n : block size of quantum code - css_x_gauge_ops : list of supports of X gauge operators - css_x_stabilizer_ops : list of supports of X stabilizers - css_x_boundary : list of qubits along the X-type boundary - css_z_gauge_ops : list of supports of Z gauge operators - css_z_stabilizer_ops : list of supports of Z stabilizers - css_x_boundary : list of qubits along the Z-type boundary - circuit : entire quantum circuit to build the decoding graph - model : noise for operations in the circuit - basis : initializaton and measurement basis ("x" or "z") - round_schedule : gauge measurements in each block - blocks : number of measurement blocks - method : matching implementation - uniform : use same edge weight everywhere? - annotate : for rustworkx method, compute self.matcher.annotated_graph - """ - self.n = n - self.css_x_gauge_ops = css_x_gauge_ops - self.css_x_stabilizer_ops = css_x_stabilizer_ops - self.css_x_boundary = css_x_boundary - self.css_z_gauge_ops = css_z_gauge_ops - self.css_z_stabilizer_ops = css_z_stabilizer_ops - self.css_z_boundary = css_z_boundary - self.model = model - self.blocks = blocks - - if self.blocks < 1: - raise QiskitQECError("expected positive integer for blocks") - self.round_schedule = round_schedule - if set(self.round_schedule) > set("xyz"): - raise QiskitQECError("expected round schedule of 'x', 'y', 'z' chars") - self.basis = basis - if not self.basis in ("x", "z"): - raise QiskitQECError("expected basis to be 'x' or 'z'") - - self.uniform = uniform - - if method not in self.AVAILABLE_METHODS: - raise QiskitQECError("fmethod {method} is not supported.") - self.method = method - self.matcher = None - if self.method == self.METHOD_PYMATCHING: - self.matcher = PyMatchingMatcher() - else: - self.matcher = RustworkxMatcher(annotate) - - self.z_gauge_products = temp_gauge_products(self.css_z_stabilizer_ops, self.css_z_gauge_ops) - self.x_gauge_products = temp_gauge_products(self.css_x_stabilizer_ops, self.css_x_gauge_ops) - - if decoding_graph: - ( - self.idxmap, - self.node_layers, - self.graph, - self.layer_types, - ) = self._process_graph(decoding_graph.graph, blocks, round_schedule, basis) - else: - dg = CSSDecodingGraph( - css_x_gauge_ops, - css_x_stabilizer_ops, - css_x_boundary, - css_z_gauge_ops, - css_z_stabilizer_ops, - css_z_boundary, - blocks, - round_schedule, - basis, - ) - - ( - self.idxmap, - self.node_layers, - self.graph, - self.layer_types, - ) = (dg.idxmap, dg.node_layers, dg.graph, dg.layer_types) - - logging.debug("layer_types = %s", self.layer_types) - - self.ridxmap = {v: k for k, v in self.idxmap.items()} - - self.circuit = circuit - self.event_map = {} - self.parameters = model.get_operations() - self.edge_weight_polynomials = {} - self.symbols = None - if not self.uniform: - fe = FaultEnumerator(circuit, order=1, method="propagator", model=self.model) - self.event_map = self._enumerate_events( - self.css_x_gauge_ops, - self.css_x_stabilizer_ops, - self.css_x_boundary, - self.x_gauge_products, - self.css_z_gauge_ops, - self.css_z_stabilizer_ops, - self.css_z_boundary, - self.z_gauge_products, - self.blocks, - self.round_schedule, - self.basis, - self.layer_types, - fe, - ) - logging.debug("event_map = %s", self.event_map) - self.symbols, self.edge_weight_polynomials = self._edge_weight_polynomials( - self.model, self.event_map - ) - logging.debug("symbols = %s", self.symbols) - logging.debug("edge_weight_polynomials = %s", self.edge_weight_polynomials) - self.graph = self._revise_decoding_graph( - self.idxmap, self.graph, self.edge_weight_polynomials - ) - - @staticmethod - def _process_graph( - graph: DecodingGraph, blocks: int, round_schedule: str, basis: str - ) -> Tuple[Dict[Tuple[int, List[int]], int], List[List[int]], DecodingGraph, List[str]]: - """Process a decoding graph to add required attributes.""" - - # symmetrize hook errors - for j, edge in enumerate(graph.edges()): - n0, n1 = graph.edge_list()[j] - source = graph.nodes()[n0] - target = graph.nodes()[n1] - if source.time != target.time: - if source.is_boundary == target.is_boundary == False: - new_source = copy(source) - new_source.time = target.time - nn0 = graph.nodes().index(new_source) - new_target = copy(target) - new_target.time = source.time - nn1 = graph.nodes().index(new_target) - graph.add_edge(nn0, nn1, edge) - - edges_to_remove = [] - for j, edge in enumerate(graph.edges()): - n0, n1 = graph.edge_list()[j] - source = graph.nodes()[n0] - target = graph.nodes()[n1] - - # add the required attributes - # highlighted', 'measurement_error','qubit_id' and 'error_probability' - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = int(source.time != target.time) - - # make it so times of boundary/boundary nodes agree - if source.is_boundary and not target.is_boundary: - if source.time != target.time: - new_source = copy(source) - new_source.time = target.time - n = graph.add_node(new_source) - edge.properties["measurement_error"] = 0 - edges_to_remove.append((n0, n1)) - graph.add_edge(n, n1, edge) - - # remove old boundary/boundary nodes - for n0, n1 in edges_to_remove: - graph.remove_edge(n0, n1) - - for n0, source in enumerate(graph.nodes()): - for n1, target in enumerate(graph.nodes()): - # add weightless nodes connecting different boundaries - if source.time == target.time: - if source.is_boundary and target.is_boundary: - if source.qubits != target.qubits: - edge = DecodingGraphEdge( - weight=0, - qubits=list(set(source.qubits).intersection((set(target.qubits)))), - ) - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = 0 - if (n0, n1) not in graph.edge_list(): - graph.add_edge(n0, n1, edge) - - # connect one of the boundaries at different times - if target.time == (source.time or 0) + 1: - if source.qubits == target.qubits == [0]: - edge = DecodingGraphEdge(weight=0, qubits=[]) - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = 0 - if (n0, n1) not in graph.edge_list(): - graph.add_edge(n0, n1, edge) - - # symmetrize edges - for j, edge in enumerate(graph.edges()): - n0, n1 = graph.edge_list()[j] - if (n1, n0) not in graph.edge_list(): - graph.add_edge(n1, n0, edge) - - idxmap = {} - for n, node in enumerate(graph.nodes()): - idxmap[node.time, tuple(node.qubits)] = n - - node_layers = [] - for node in graph.nodes(): - time = node.time or 0 - if len(node_layers) < time + 1: - node_layers += [[]] * (time + 1 - len(node_layers)) - node_layers[time].append(node.qubits) - - # create a list of decoding graph layer types - # the entries are 'g' for gauge and 's' for stabilizer - layer_types = [] - last_step = basis - for _ in range(blocks): - for step in round_schedule: - if basis == "z" and step == "z" and last_step == "z": - layer_types.append("g") - elif basis == "z" and step == "z" and last_step == "x": - layer_types.append("s") - elif basis == "x" and step == "x" and last_step == "x": - layer_types.append("g") - elif basis == "x" and step == "x" and last_step == "z": - layer_types.append("s") - last_step = step - if last_step == basis: - layer_types.append("g") - else: - layer_types.append("s") - - return idxmap, node_layers, graph, layer_types - - @staticmethod - def _revise_decoding_graph( - idxmap: Dict[Tuple[int, List[int]], int], - graph: rx.PyGraph, - edge_weight_polynomials: Dict[Tuple[int, Tuple[int]], Dict[Tuple[int, Tuple[int]], Poly]], - ) -> rx.PyGraph: - """Add edge weight polynomials to the decoding graph g and prune it. - - Update attribute "weight_poly" on decoding graph edges contained in - edge_weight_polynomials. Remove all other edges from the decoding graph - that have non-zero weight. - """ - for s1, sub in edge_weight_polynomials.items(): - for s2, wpoly in sub.items(): - if s1 not in idxmap: - raise QiskitQECError(f"vertex {s1} not in decoding graph") - if s2 not in idxmap: - raise QiskitQECError(f"vertex {s2} not in decoding graph") - if not graph.has_edge(idxmap[s1], idxmap[s2]): - # TODO: new edges may be needed for hooks, but raise exception for now - raise QiskitQECError("edge {s1} - {s2} not in decoding graph") - data = graph.get_edge_data(idxmap[s1], idxmap[s2]) - data.properties["weight_poly"] = wpoly - remove_list = [] - for source, target in graph.edge_list(): - edge_data = graph.get_edge_data(source, target) - if "weight_poly" not in edge_data.properties and edge_data.weight != 0: - # Remove the edge - remove_list.append((source, target)) - logging.info("remove edge (%d, %d)", source, target) - graph.remove_edges_from(remove_list) - return graph - - def update_edge_weights(self, model: PauliNoiseModel): - """Evaluate the numerical edge weights and update graph data. - - For each edge in the decoding graph that has a "weight_poly" - property, evaluate the polynomial at the given model parameters - and set the corresponding "weight" property. Once this is done, - recompute the shortest paths between pairs of vertices - in the decoding graph. - - model is a PauliNoiseModel whose error probabilities have been - previously assigned. The probabilities are then assigned to - the variables in self.symbols. - - Updates properties of matcher. - - Args: - model: moise model - """ - parameter_values = [model.get_error_probability(name) for name in self.parameters] - if not self.uniform: - if len(parameter_values) != len(self.parameters): - raise QiskitQECError("wrong number of error rate parameters") - symbol_list = [self.symbols[s] for s in self.parameters] - assignment = dict(zip(symbol_list, parameter_values)) - logging.info("update_edge_weights %s", str(assignment)) - # P(chain) = \prod_i (1-p_i)^{1-l(i)}*p_i^{l(i)} - # \propto \prod_i ((1-p_i)/p_i)^{l(i)} - # -log P(chain) \propto \sum_i -log[((1-p_i)/p_i)^{l(i)}] - # p_i is the probability that edge i carries an error - # l(i) is 1 if the link belongs to the chain and 0 otherwise - for source, target in self.graph.edge_list(): - edge_data = self.graph.get_edge_data(source, target).properties - if "weight_poly" in edge_data: - logging.info( - "update_edge_weights (%d, %d) %s", - source, - target, - str(edge_data["weight_poly"]), - ) - restriction = {x: assignment[x] for x in edge_data["weight_poly"].gens} - p = edge_data["weight_poly"].eval(restriction).evalf() - # if approximate edge flip probability is large, saturate at 1/2 - p = min(p, 0.5) - edge_data["weight"] = log((1 - p) / p) - self.matcher.preprocess(self.graph) - - def _enumerate_events( - self, - css_x_gauge_ops: List[Tuple[int]], - css_x_stabilizer_ops: List[Tuple[int]], - css_x_boundary: List[int], - x_gauge_products: List[int], - css_z_gauge_ops: List[Tuple[int]], - css_z_stabilizer_ops: List[Tuple[int]], - css_z_boundary: List[int], - z_gauge_products: List[int], - blocks: int, - round_schedule: str, - basis: str, - layer_types: List[str], - fault_enumerator: FaultEnumerator, - ) -> Dict[Tuple[int, Tuple[int]], Dict[Tuple[int, Tuple[int]], Dict[List[str], int]]]: - """Enumerate fault events in the input circuit. - - Use the code definition to identify highlighted edges - in a decoding graph and return a dict containing the events - that highlight each edge. - - The basis input value 'x' or 'z' informs whether to - look at the Z error or X error syndrome, respectively. - - fault_enumerator is a FaultEnumerator object. - - Return a dict containing the total number of events of each - type: event_map[v0][v1][name][pauli] contains the number of - events where a gate "name" fails with error "pauli" and - the edge between v0 and v1 is highlighted. - - Args: - css_x_gauge_ops: x gauge ops - css_x_stabilizer_ops: x stabilizer ops - css_x_boundary: x boundary - x_gauge_products: x gauge products - css_z_gauge_ops: z gauge ops - css_z_stabilizer_ops: z stabilizer ops - css_z_boundary: z boundary - z_gauge_products: z gauge products - blocks: blocks - round_schedule: - basis: basis - layer_types: layer types - fault_enumerator: fault enumerator - - Returns: - Events map - """ - event_map = {} - for event in fault_enumerator.generate(): - # Unpack the event data - # Select the first element since order = 1 - ctr = event[0] # event counter - comb = event[1][0] # combination of faulty operations - pauli = event[2][0] # Pauli error string - outcome = event[3] # result of simulation - logging.debug("event %d %s %s %s", ctr, comb, pauli, outcome) - - ( - x_gauge_outcomes, - z_gauge_outcomes, - final_outcomes, - ) = self._partition_outcomes(blocks, round_schedule, outcome) - - # Compute the highlighted vertices - # Note that this only depends on the stabilizers at each - # time and does not require an explicit decoding graph - gauge_outcomes, highlighted = self._highlighted_vertices( - css_x_gauge_ops, - css_x_stabilizer_ops, - css_x_boundary, - x_gauge_products, - css_z_gauge_ops, - css_z_stabilizer_ops, - css_z_boundary, - z_gauge_products, - basis, - layer_types, - x_gauge_outcomes, - z_gauge_outcomes, - final_outcomes, - ) - logging.debug("gauge_outcomes %s", gauge_outcomes) - logging.debug("highlighted %s", highlighted) - # Examine the highlighted vertices to find the edge of the - # decoding graph that corresponds with this fault event - if len(highlighted) > 2: - raise QiskitQECError("too many highlighted vertices for a " + "single fault event") - if len(highlighted) == 1: # _highlighted_vertices highlights the boundary - raise QiskitQECError("only one highlighted vertex for a " + "single fault event") - if len(highlighted) == 2: - v0 = highlighted[0] - v1 = highlighted[1] - if basis == "z": - boundary = css_z_boundary - elif basis == "x": - boundary = css_x_boundary - # Is the special boundary vertex highlighted? - if v1 == (0, tuple(boundary[0])): - # Replace it with an adjacent vertex - for b in boundary: - assert len(b) == 1 # Assume each b has one element - isect = set(b).intersection(set(v0[1])) - if len(isect) > 0: - v1 = (v0[0], tuple(b)) - break - submap1 = event_map.setdefault(v0, {}) - submap2 = submap1.setdefault(v1, {}) - submap3 = submap2.setdefault(comb, {}) - eventcount = submap3.setdefault(pauli, 0) - submap3[pauli] = eventcount + 1 - return event_map - - @abstractmethod - def _partition_outcomes( - self, blocks: int, round_schedule: str, outcome: List[int] - ) -> Tuple[List[List[int]], List[List[int]], List[int]]: - """Process the raw outcome and return results. - - blocks = number of repetition of round_schedule - round_schedule = string of z and x characters - outcome = list of 0, 1 outcomes - - Return lists x_gauge_outcomes, z_gauge_outcomes, final_outcomes. - """ - raise NotImplementedError("Not implemented.") - - def _edge_weight_polynomials( - self, - model: PauliNoiseModel, - event_map: Dict[Tuple[int, Tuple[int]], Dict[Tuple[int, Tuple[int]], Dict[List[str], int]]], - ) -> Tuple[ - Dict[str, Symbol], - Dict[Tuple[int, Tuple[int]], Dict[Tuple[int, Tuple[int]], Poly]], - ]: - """Compute edge weight polynomials given the error events. - - event_map is the output of _enumerate_events - """ - symbs = {n: symbols(n) for n in model.get_operations()} - edge_weight_expressions = {} - for n1, submap1 in event_map.items(): - for n2 in submap1.keys(): - # check the names in the event map and warn if symbs - # does not contain one of the names - for name in event_map[n1][n2].keys(): - if name not in symbs: - logging.warning("%s in event_map but not in model", name) - # construct a linear approximation to the edge probability - # using the weights from the noise model - expr = 0 - for name in self.model.get_operations(): - if name in event_map[n1][n2]: - for pauli, count in event_map[n1][n2][name].items(): - expr += count * model.get_pauli_weight(name, pauli) * symbs[name] - map1 = edge_weight_expressions.setdefault(n1, {}) - map1[n2] = Poly(expr) - return symbs, edge_weight_expressions - - def process(self, outcomes: List[int]) -> List[int]: - """Process a set of outcomes and return corrected final outcomes. - - Be sure to have called update_edge_weights for the - noise parameters. - - The result is a list of code.n integers that are 0 or 1. - These are the corrected values of the final transversal - measurement in the basis given by self.basis. - """ - logging.debug("process: outcomes = %s", outcomes) - - x_gauge_outcomes, z_gauge_outcomes, final_outcomes = self._partition_outcomes( - self.blocks, self.round_schedule, outcomes - ) - - gauge_outcomes, highlighted = self._highlighted_vertices( - self.css_x_gauge_ops, - self.css_x_stabilizer_ops, - self.css_x_boundary, - self.x_gauge_products, - self.css_z_gauge_ops, - self.css_z_stabilizer_ops, - self.css_z_boundary, - self.z_gauge_products, - self.basis, - self.layer_types, - x_gauge_outcomes, - z_gauge_outcomes, - final_outcomes, - ) - logging.info("process: gauge_outcomes = %s", gauge_outcomes) - logging.info("process: final_outcomes = %s", final_outcomes) - logging.info("process: highlighted = %s", highlighted) - - qubit_errors, _ = self.matcher.find_errors(self.graph, self.idxmap, highlighted) - - corrected_outcomes = copy(final_outcomes) - for i in qubit_errors: - if i != -1: - corrected_outcomes[i] = (corrected_outcomes[i] + 1) % 2 - logging.info("process: corrected_outcomes = %s", corrected_outcomes) - if self.basis == "z": - test = temp_syndrome(corrected_outcomes, self.css_z_stabilizer_ops) - elif self.basis == "x": - test = temp_syndrome(corrected_outcomes, self.css_x_stabilizer_ops) - logging.debug("process: test syndrome = %s", test) - if sum(test) != 0: - raise QiskitQECError("decoder failure: syndrome should be trivial!") - return corrected_outcomes - - @staticmethod - def _highlighted_vertices( - css_x_gauge_ops: List[Tuple[int]], - css_x_stabilizer_ops: List[Tuple[int]], - css_x_boundary: List[int], - x_gauge_products: List[int], - css_z_gauge_ops: List[Tuple[int]], - css_z_stabilizer_ops: List[Tuple[int]], - css_z_boundary: List[int], - z_gauge_products: List[int], - basis: str, - layer_types: List[str], - x_gauge_outcomes: List[List[int]], - z_gauge_outcomes: List[List[int]], - final_outcomes: List[int], - ) -> Tuple[List[List[int]], List[Tuple[int, Tuple[int]]]]: - """Identify highlighted vertices in the decoding graph for an outcome. - - Gauge operator measurement outcomes are lists of integers 0, 1. - """ - if basis == "z": - gauge_outcomes = z_gauge_outcomes - gauges = css_z_gauge_ops - stabilizers = css_z_stabilizer_ops - boundary = css_z_boundary - gauge_products = z_gauge_products - elif basis == "x": - gauge_outcomes = x_gauge_outcomes - gauges = css_x_gauge_ops - stabilizers = css_x_stabilizer_ops - boundary = css_x_boundary - gauge_products = x_gauge_products - final_gauges = [] - for supp in gauges: - parity = 0 - for i in supp: - if i != -1: # supp can contain -1 if no qubit at that site - parity += final_outcomes[i] - final_gauges.append(parity % 2) - gauge_outcomes.append(final_gauges) - - highlighted = [] - # Now iterate over the layers and look at appropriate - # syndrome differences - for i, ltype in enumerate(layer_types): - if ltype == "g": - # compare current and past gauge measurements - # if a bit differs, the vertex (i, g) is highlighted - for j, gauge_op in enumerate(gauges): - if (i == 0 and gauge_outcomes[i][j] == 1) or ( - i > 0 and gauge_outcomes[i][j] != gauge_outcomes[i - 1][j] - ): - highlighted.append((i, tuple(gauge_op))) - elif ltype == "s": - # compare current and past stabilizer measurements - # if a bit differs, the vertex (i, s) is highlighted - for j, stab_op in enumerate(stabilizers): - outcome = 0 - prior_outcome = 0 - for k in gauge_products[j]: - outcome += gauge_outcomes[i][k] - if i > 0: - prior_outcome += gauge_outcomes[i - 1][k] - outcome %= 2 - prior_outcome %= 2 - if outcome != prior_outcome: - highlighted.append((i, tuple(stab_op))) - logging.debug("|highlighted| = %d", len(highlighted)) - # If the total number of highlighted vertices is odd, - # add a single special highlighted vertex at the boundary - if len(highlighted) % 2 == 1: - highlighted.append((0, tuple(boundary[0]))) - logging.debug("highlighted = %s", highlighted) - return gauge_outcomes, highlighted diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index 68c2293f..69ef9599 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -18,7 +18,6 @@ Graph used as the basis of decoders. """ import itertools -import logging import copy from typing import List, Tuple, Union @@ -44,7 +43,7 @@ class DecodingGraph: METHOD_NAIVE: str = "naive" AVAILABLE_METHODS = {METHOD_SPITZ, METHOD_NAIVE} - def __init__(self, code, brute=False, graph=None): + def __init__(self, code, brute=False, graph=None, hyperedges=None): """ Args: code (CodeCircuit): The QEC code circuit object for which this decoding @@ -58,13 +57,14 @@ def __init__(self, code, brute=False, graph=None): if graph: self.graph = graph + self.hyperedges = hyperedges else: self._make_syndrome_graph() - self._logical_nodes = [] + self.logical_nodes = [] for node in self.graph.nodes(): - if node.is_boundary: - self._logical_nodes.append(node) + if node.is_logical: + self.logical_nodes.append(node) self.update_attributes() @@ -142,15 +142,15 @@ def _make_syndrome_graph(self): n0 = graph.nodes().index(source) n1 = graph.nodes().index(target) qubits = [] - if not (source.is_boundary and target.is_boundary): + if not (source.is_logical and target.is_logical): qubits = list(set(source.qubits).intersection(target.qubits)) if not qubits: continue if ( source.time != target.time and len(qubits) > 1 - and not source.is_boundary - and not target.is_boundary + and not source.is_logical + and not target.is_logical ): qubits = [] edge = DecodingGraphEdge(qubits, 1) @@ -232,9 +232,9 @@ def get_error_probs( boundary = [] error_probs = {} for n0, n1 in self.graph.edge_list(): - if self.graph[n0].is_boundary: + if self.graph[n0].is_logical: boundary.append(n1) - elif self.graph[n1].is_boundary: + elif self.graph[n1].is_logical: boundary.append(n0) else: if (1 - 2 * av_xor[n0, n1]) != 0: @@ -289,9 +289,9 @@ def get_error_probs( else: ratio = np.nan p = ratio / (1 + ratio) - if self.graph[n0].is_boundary and not self.graph[n1].is_boundary: + if self.graph[n0].is_logical and not self.graph[n1].is_logical: edge = (n1, n1) - elif not self.graph[n0].is_boundary and self.graph[n1].is_boundary: + elif not self.graph[n0].is_logical and self.graph[n1].is_logical: edge = (n0, n0) else: edge = (n0, n1) @@ -363,7 +363,7 @@ def make_error_graph(self, data: Union[str, List], all_logicals=True): nodes = self.code.string2nodes(data, all_logicals=all_logicals) else: if all_logicals: - nodes = list(set(data).union(set(self._logical_nodes))) + nodes = list(set(data).union(set(self.logical_nodes))) else: nodes = data for node in nodes: @@ -441,242 +441,100 @@ def clean_measurements(self, nodes: List): unpaired_ns = ns.difference(paired_ns) return [self.graph.nodes()[n] for n in unpaired_ns] - -class CSSDecodingGraph: - """ - Class to construct the decoding graph required for the CircuitModelMatchingDecoder - for a generic CSS code. - """ - - def __init__( - self, - css_x_gauge_ops: List[Tuple[int]], - css_x_stabilizer_ops: List[Tuple[int]], - css_x_boundary: List[Tuple[int]], - css_z_gauge_ops: List[Tuple[int]], - css_z_stabilizer_ops: List[Tuple[int]], - css_z_boundary: List[Tuple[int]], - blocks: int, - round_schedule: str, - basis: str, - ): - self.css_x_gauge_ops = css_x_gauge_ops - self.css_x_stabilizer_ops = css_x_stabilizer_ops - self.css_x_boundary = css_x_boundary - self.css_z_gauge_ops = css_z_gauge_ops - self.css_z_stabilizer_ops = css_z_stabilizer_ops - self.css_z_boundary = css_z_boundary - self.blocks = blocks - self.round_schedule = round_schedule - self.basis = basis - - self.layer_types = self._layer_types(self.blocks, self.round_schedule, self.basis) - - self._decoding_graph() - - @staticmethod - def _layer_types(blocks: int, round_schedule: str, basis: str) -> List[str]: - """Return a list of decoding graph layer types. - - The entries are 'g' for gauge and 's' for stabilizer. + def get_edge_graph(self): """ - layer_types = [] - last_step = basis - for _ in range(blocks): - for step in round_schedule: - if basis == "z" and step == "z" and last_step == "z": - layer_types.append("g") - elif basis == "z" and step == "z" and last_step == "x": - layer_types.append("s") - elif basis == "x" and step == "x" and last_step == "x": - layer_types.append("g") - elif basis == "x" and step == "x" and last_step == "z": - layer_types.append("s") - last_step = step - if last_step == basis: - layer_types.append("g") - else: - layer_types.append("s") - return layer_types - - def _decoding_graph(self): - """Construct the decoding graph for the given basis. + Returns a copy of the graph that uses edges to store information + about the effects of errors on logical operators. This is done + via the `'fault_ids'` of the edges. No logical nodes are present + in such a graph. - This method sets edge weights all to 1 and is based on - computing intersections of operator supports. + Returns: + edge_graph (rx.PyGraph): The edge graph. + """ - Returns a tuple (idxmap, node_layers, G) - where idxmap is a dict - mapping tuples (t, qubit_set) to integer vertex indices in the - decoding graph G. The list node_layers contains lists of nodes - for each time step. + nodes = self.graph.nodes() + # get a list of boundary nodes + bns = [] + for n, node in enumerate(nodes): + if node.is_logical: + bns.append(n) + # find pairs of bulk edges that have overlap with a boundary + bedge = {} + # and their edges connecting to the boundary, that we'll discard + spares = set() + for edge, (n0, n1) in zip(self.graph.edges(), self.graph.edge_list()): + if not nodes[n0].is_logical and not nodes[n1].is_logical: + for n2 in bns: + adj = set(edge.qubits).intersection(set(nodes[n2].qubits)) + if adj: + if (n0, n1) not in bedge: + bedge[n0, n1] = {nodes[n2].index} + else: + bedge[n0, n1].add(nodes[n2].index) + for n in (n0, n1): + spares.add((n, n2)) + spares.add((n2, n)) + # find bulk-boundary pairs not covered by the above + for (n0, n1) in self.graph.edge_list(): + n2 = None + for n in (n0, n1): + if nodes[n].is_logical: + n2 = n + if n2 is not None: + if (n0, n1) not in spares: + adj = set(nodes[n2].qubits) + for n in (n0, n1): + adj = adj.intersection(set(nodes[n].qubits)) + if (n0, n1) not in bedge: + bedge[n0, n1] = {nodes[n2].index} + else: + bedge[n0, n1].add(nodes[n2].index) + # make a new graph with fault_ids on boundary edges, and ignoring the spare edges + edge_graph = rx.PyGraph(multigraph=False) + for node in nodes: + edge_graph.add_node(copy.copy(node)) + for edge, (n0, n1) in zip(self.graph.edges(), self.graph.edge_list()): + if (n0, n1) in bedge: + edge.fault_ids = bedge[n0, n1] + edge_graph.add_edge(n0, n1, edge) + elif (n0, n1) not in spares and (n1, n0) not in spares: + edge.fault_ids = set() + edge_graph.add_edge(n0, n1, edge) + # turn logical nodes into boundary nodes + for node in edge_graph.nodes(): + if node.is_logical: + node.is_boundary = True + node.is_logical = False + return edge_graph + + def get_node_graph(self): """ - graph = rx.PyGraph(multigraph=False) - gauges = [] - stabilizers = [] - boundary = [] - if self.basis == "z": - gauges = self.css_z_gauge_ops - stabilizers = self.css_z_stabilizer_ops - boundary = self.css_z_boundary - elif self.basis == "x": - gauges = self.css_x_gauge_ops - stabilizers = self.css_x_stabilizer_ops - boundary = self.css_x_boundary - - # Construct the decoding graph - idx = 0 # vertex index counter - idxmap = {} # map from vertex data (t, qubits) to vertex index - node_layers = [] - for time, layer in enumerate(self.layer_types): - # Add vertices at time t - node_layer = [] - if layer == "g": - all_z = gauges - elif layer == "s": - all_z = stabilizers - for index, supp in enumerate(all_z): - node = DecodingGraphNode(time=time, qubits=supp, index=index) - node.properties["highlighted"] = True - graph.add_node(node) - logging.debug("node %d t=%d %s", idx, time, supp) - idxmap[(time, tuple(supp))] = idx - node_layer.append(idx) - idx += 1 - for index, supp in enumerate(boundary): - # Add optional is_boundary property for pymatching - node = DecodingGraphNode(is_boundary=True, qubits=supp, index=index) - node.properties["highlighted"] = False - graph.add_node(node) - logging.debug("boundary %d t=%d %s", idx, time, supp) - idxmap[(time, tuple(supp))] = idx - node_layer.append(idx) - idx += 1 - node_layers.append(node_layer) - if layer == "g": - all_z = gauges + boundary - elif layer == "s": - all_z = stabilizers + boundary - # Add space-like edges at time t - # The qubit sets of any pair of vertices at time - # t can intersect on multiple qubits. - # If they intersect, we add an edge and label it by - # one of the common qubits. This makes an assumption - # that the intersection operator is equivalent to a single - # qubit operator modulo the gauge group. - # Space-like edges do not correspond to syndrome errors, so the - # syndrome property is an empty list. - for i, op_g in enumerate(all_z): - for j in range(i + 1, len(all_z)): - op_h = all_z[j] - com = list(set(op_g).intersection(set(op_h))) - if -1 in com: - com.remove(-1) - if len(com) > 0: - # Include properties for use with pymatching: - # qubit_id is an integer or set of integers - # weight is a floating point number - # error_probability is a floating point number - edge = DecodingGraphEdge(qubits=[com[0]], weight=1) - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = 0 - graph.add_edge( - idxmap[(time, tuple(op_g))], idxmap[(time, tuple(op_h))], edge - ) - logging.debug("spacelike t=%d (%s, %s)", time, op_g, op_h) - logging.debug( - " qubits %s", - [com[0]], - ) + Returns a copy of the graph that uses logical nodes to store information + about the effects of errors on logical operators. No non-trivial `'fault_ids'` + are present in such a graph. - # Add boundary space-like edges - for i in range(len(boundary) - 1): - bound_g = boundary[i] - bound_h = boundary[i + 1] - # Include properties for use with pymatching: - # qubit_id is an integer or set of integers - # weight is a floating point number - # error_probability is a floating point number - edge = DecodingGraphEdge(qubits=[], weight=0) - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = 0 - graph.add_edge(idxmap[(time, tuple(bound_g))], idxmap[(time, tuple(bound_h))], edge) - logging.debug("spacelike boundary t=%d (%s, %s)", time, bound_g, bound_h) - - # Add (space)time-like edges from t to t-1 - # By construction, the qubit sets of pairs of vertices at graph and T - # at times t-1 and t respectively - # either (a) contain each other (graph subset T or T subset graph) and - # |graph|,|T|>1, - # (b) intersect on one or more qubits, or (c) are disjoint. - # In case (a), we add an edge that corresponds to a syndrome bit - # error at time t-1. - # In case (b), we add an edge that corresponds to a spacetime hook - # error, i.e. a syndrome bit error at time t-1 - # together with an error on one of the common qubits. Again - # this makes an assumption that all such errors are equivalent. - # In case (c), we do not add an edge. - # Important: some space-like hooks are not accounted for. - # They can have longer paths between non-intersecting operators. - # We will account for these in _revise_decoding_graph if needed. - if time > 0: - current_sets = gauges - prior_sets = gauges - if self.layer_types[time] == "s": - current_sets = stabilizers - if self.layer_types[time - 1] == "s": - prior_sets = stabilizers - for op_g in current_sets: - for op_h in prior_sets: - com = list(set(op_g).intersection(set(op_h))) - if -1 in com: - com.remove(-1) - if len(com) > 0: # not Case (c) - # Include properties for use with pymatching: - # qubit_id is an integer or set of integers - # weight is a floating point number - # error_probability is a floating point number - # Case (a) - if set(com) == set(op_h) or set(com) == set(op_g): - edge = DecodingGraphEdge(qubits=[], weight=1) - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = 1 - graph.add_edge( - idxmap[(time - 1, tuple(op_h))], - idxmap[(time, tuple(op_g))], - edge, - ) - logging.debug("timelike t=%d (%s, %s)", time, op_g, op_h) - else: # Case (b) - edge = DecodingGraphEdge(qubits=[com[0]], weight=1) - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = 1 - graph.add_edge( - idxmap[(time - 1, tuple(op_h))], - idxmap[(time, tuple(op_g))], - edge, - ) - logging.debug("spacetime hook t=%d (%s, %s)", time, op_g, op_h) - logging.debug(" qubits %s", [com[0]]) - # Add a single time-like edge between boundary vertices at - # time t-1 and t - edge = DecodingGraphEdge(qubits=[], weight=0) - edge.properties["highlighted"] = False - edge.properties["measurement_error"] = 0 - graph.add_edge( - idxmap[(time - 1, tuple(boundary[0]))], idxmap[(time, tuple(boundary[0]))], edge - ) - logging.debug("boundarylink t=%d", time) - - self.idxmap = idxmap - self.node_layers = node_layers - self.graph = graph + Returns: + node_graph (rx.PyGraph): The node graph. + """ + node_graph = self.graph.copy() + for edge, (n0, n1) in zip(self.graph.edges(), self.graph.edge_list()): + if edge.fault_ids: + # is the edge has fault ids, make corresponding logical nodes + # and connect them to these edges + for index in edge.fault_ids: + node2 = DecodingGraphNode(is_logical=True, index=index) + n2 = node_graph.add_node(node2) + node_graph.add_edge(n0, n2, copy.copy(edge)) + node_graph.add_edge(n1, n2, copy.copy(edge)) + for edge in self.graph.edges(): + edge.fault_ids = set() + return node_graph def make_syndrome_graph_from_aer(code, shots=1): """ - Generates a graph and list of hyperedges for a given code by inserting single qubit - Paulis into the base circuit for that code. Also supplied information regarding which + Generates a graph and list of hyperedges for a given code by inserting Pauli errors + around the gates of the base circuit for that code. Also supplied information regarding which edges where generated by which Pauli insertions. Args: @@ -708,7 +566,7 @@ def make_syndrome_graph_from_aer(code, shots=1): for j in range(depth): gate = qc.data[j][0].name qubits = qc.data[j][1] - if gate not in ["measure", "reset", "barrier"]: + if gate not in ["measure", "reset", "barrier"] and len(qubits) != 2: for error in ["x", "y", "z"]: for qubit in qubits: temp_qc = copy.deepcopy(blank_qc) @@ -720,6 +578,28 @@ def make_syndrome_graph_from_aer(code, shots=1): getattr(temp_qc, error)(qubit) temp_qc.data += qc.data[j : depth + 1] error_circuit[temp_qc_name] = temp_qc + elif len(qubits) == 2: + qregs = [] + for qubit in qubits: + for qreg in qc.qregs: + if qubit in qreg: + qregs.append(qreg) + break + for pauli_0 in ["id", "x", "y", "z"]: + for pauli_1 in ["id", "x", "y", "z"]: + if not pauli_0 == pauli_1 == "id": + temp_qc = copy.deepcopy(blank_qc) + temp_qc_name = ( + j, + (qc.qregs.index(qregs[0]), qc.qregs.index(qregs[1])), + (qregs[0].index(qubits[0]), qregs[1].index(qubits[1])), + pauli_0 + "," + pauli_1, + ) + temp_qc.data = qc.data[0:j] + getattr(temp_qc, pauli_0)(qubits[0]) + getattr(temp_qc, pauli_1)(qubits[1]) + temp_qc.data += qc.data[j : depth + 1] + error_circuit[temp_qc_name] = temp_qc elif gate == "measure": pre_error = "x" for post_error in ["id", "x"]: @@ -761,7 +641,7 @@ def make_syndrome_graph_from_aer(code, shots=1): if target != source or (len(nodes) == 1): n0 = graph.nodes().index(source) n1 = graph.nodes().index(target) - if not (source.is_boundary and target.is_boundary): + if not (source.is_logical and target.is_logical): qubits = list(set(source.qubits).intersection(target.qubits)) if source.time != target.time and len(qubits) > 1: qubits = [] diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 4d035f54..4fa96b68 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -23,9 +23,8 @@ from rustworkx import PyGraph, connected_components, distance_matrix -from qiskit_qec.circuits.repetition_code import ArcCircuit from qiskit_qec.decoders.decoding_graph import DecodingGraph -from qiskit_qec.utils import DecodingGraphEdge, DecodingGraphNode +from qiskit_qec.utils import DecodingGraphEdge class ClusteringDecoder(ABC): @@ -40,8 +39,6 @@ def __init__( ): self.code = code_circuit - self.measured_logicals = self.code.measured_logicals() - if hasattr(self.code, "code_index"): self.code_index = self.code.code_index else: @@ -64,41 +61,36 @@ def get_corrections(self, string, clusters): Returns: corrected_logicals (list): A list of integers that are 0 or 1. These are the corrected values of the final transversal - measurement, corresponding to the logical operators of - self.measured_logicals. + measurement, in the same form as given by the code's `string2raw_logicals`. """ # get the list of bulk nodes for each cluster cluster_nodes = {c: [] for c in clusters.values()} for n, c in clusters.items(): node = self.decoding_graph.graph[n] - if not node.is_boundary: + if not node.is_logical: cluster_nodes[c].append(node) # get the list of required logicals for each cluster cluster_logicals = {} for c, nodes in cluster_nodes.items(): _, logical_nodes, _ = self.code.check_nodes(nodes, minimal=True) - z_logicals = [node.qubits[0] for node in logical_nodes] - cluster_logicals[c] = z_logicals + log_indexes = [node.index for node in logical_nodes] + cluster_logicals[c] = log_indexes # get the net effect on each logical - net_z_logicals = {z_logical[0]: 0 for z_logical in self.measured_logicals} - for c, z_logicals in cluster_logicals.items(): - for z_logical in self.measured_logicals: - if z_logical[0] in z_logicals: - net_z_logicals[z_logical[0]] += 1 - for z_logical, num in net_z_logicals.items(): - net_z_logicals[z_logical] = num % 2 + net_logicals = {node.index: 0 for node in self.decoding_graph.logical_nodes} + for c, log_indexes in cluster_logicals.items(): + for log_index in log_indexes: + net_logicals[log_index] += 1 + for log_index, num in net_logicals.items(): + net_logicals[log_index] = num % 2 - corrected_z_logicals = [] - string = string.split(" ")[0] - for z_logical in self.measured_logicals: - raw_logical = int(string[-1 - self.code_index[z_logical[0]]]) - corrected_logical = (raw_logical + net_z_logicals[z_logical[0]]) % 2 - corrected_z_logicals.append(corrected_logical) + corrected_logicals = self.code.string2raw_logicals(string) + for log_index, log_value in enumerate(corrected_logicals): + corrected_logicals[log_index] = (net_logicals[log_index] + int(log_value)) % 2 - return corrected_z_logicals + return corrected_logicals class BravyiHaahDecoder(ClusteringDecoder): @@ -143,7 +135,7 @@ def _cluster(self, ns, dist_max): # check the neutrality of each connected component con_nodes = [cg[n] for n in con_comp] neutral, logicals, num_errors = self.code.check_nodes( - con_nodes, ignore_extra_boundary=True + con_nodes, ignore_extra_logical=True ) # it's fully neutral if no extra logicals are needed @@ -163,15 +155,6 @@ def _cluster(self, ns, dist_max): return clusters, con_comp_dict - def _get_boundary_nodes(self): - boundary_nodes = [] - for element, z_logical in enumerate(self.measured_logicals): - node = DecodingGraphNode(is_boundary=True, qubits=z_logical, index=element) - if isinstance(self.code, ArcCircuit): - node.properties["link qubit"] = None - boundary_nodes.append(node) - return boundary_nodes - def cluster(self, nodes): """ @@ -183,10 +166,10 @@ def cluster(self, nodes): value. """ - # get indices for nodes and boundary nodes + # get indices for nodes and logical nodes dg = self.decoding_graph.graph - ns = set(self.decoding_graph.node_index(node) for node in nodes) - bns = set(self.decoding_graph.node_index(node) for node in self._get_boundary_nodes()) + ns = set(dg.nodes().index(node) for node in nodes) + lns = set(dg.nodes().index(node) for node in self.decoding_graph.logical_nodes) dist_max = 0 final_clusters = {} @@ -194,8 +177,8 @@ def cluster(self, nodes): clusterss = [] while ns and dist_max <= self.code.d: dist_max += 1 - # add boundary nodes to unpaired nodes - ns = set(ns).union(bns) + # add logical nodes to unpaired nodes + ns = set(ns).union(lns) # cluster nodes and contract decoding graph given the current distance clusters, con_comp = self._cluster(ns, dist_max) @@ -205,7 +188,7 @@ def cluster(self, nodes): if c is not None: final_clusters[n] = c else: - if not dg[n].is_boundary: + if not dg[n].is_logical: ns.append(n) con_comps.append(con_comp) clusterss.append(clusters) @@ -225,8 +208,7 @@ def process(self, string, predecoder=None): Returns: corrected_logicals (list): A list of integers that are 0 or 1. These are the corrected values of the final transversal - measurement, corresponding to the logical operators of - self.measured_logicals. + measurement, in the same form as given by the code's `string2raw_logicals`. """ # turn string into nodes and cluster @@ -302,8 +284,8 @@ class UnionFindDecoder(ClusteringDecoder): """ Decoder based on growing clusters around syndrome errors to "convert" them into erasure errors, which can be corrected easily, - by the peeling decoder in case of the surface code, or by checking for - interference with the boundary in case of an abritrary ARC. + by the peeling decoder for compatible codes or by the standard HDRG + method in general. TODO: Add weights to edges of graph according to Huang et al (see. arXiv:2004.04693, section III) @@ -328,10 +310,8 @@ def process(self, string: str, predecoder=None): a list of nodes. Used to do preprocessing on the nodes corresponding to the input string. Returns: - corrected_z_logicals (list): A list of integers that are 0 or 1. - These are the corrected values of the final transversal - measurement, corresponding to the logical operators of - self.z_logicals. + corrected_logicals (list): A list of integers that are 0 or 1. + These are the corrected values of the final logical measurement. """ if self.use_peeling: @@ -345,7 +325,10 @@ def process(self, string: str, predecoder=None): clusters = self._clusters4peeling # determine the net logical z - net_z_logicals = {tuple(z_logical): 0 for z_logical in self.measured_logicals} + measured_logicals = {} + for node in self.decoding_graph.logical_nodes: + measured_logicals[node.index] = node.qubits + net_z_logicals = {tuple(z_logical): 0 for z_logical in measured_logicals.values()} for cluster_nodes, _ in clusters: erasure = self.graph.subgraph(cluster_nodes) flipped_qubits = self.peeling(erasure) @@ -359,7 +342,7 @@ def process(self, string: str, predecoder=None): # apply this to the raw readout corrected_z_logicals = [] raw_logicals = self.code.string2raw_logicals(string) - for j, z_logical in enumerate(self.measured_logicals): + for j, z_logical in measured_logicals.items(): raw_logical = int(raw_logicals[j]) corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2 corrected_z_logicals.append(corrected_logical) @@ -399,8 +382,10 @@ def cluster(self, nodes: List): for node_index in node_indices: self._create_new_cluster(node_index) - while self.odd_cluster_roots: + j = 0 + while self.odd_cluster_roots and j < 2 * self.code.d * (self.code.T + 1): self._grow_and_merge_clusters() + j += 1 # compile info into standard clusters dict clusters = {} @@ -441,7 +426,7 @@ def find(self, u: int) -> int: def _create_new_cluster(self, node_index): node = self.graph[node_index] - if not node.is_boundary: + if not node.is_logical: self.odd_cluster_roots.insert(0, node_index) boundary_edges = [] for edge_index, neighbour, data in self.neighbouring_edges(node_index): @@ -449,8 +434,8 @@ def _create_new_cluster(self, node_index): self.clusters[node_index] = UnionFindDecoderCluster( boundary=boundary_edges, fully_grown_edges=set(), - atypical_nodes=set([node_index]) if not node.is_boundary else set([]), - boundary_nodes=set([node_index]) if node.is_boundary else set([]), + atypical_nodes=set([node_index]) if not node.is_logical else set([]), + boundary_nodes=set([node_index]) if node.is_logical else set([]), nodes=set([node_index]), size=1, ) @@ -493,7 +478,7 @@ def _grow_clusters(self) -> List[FusionEntry]: fully_grown_edges=set(), atypical_nodes=set(), boundary_nodes=set([edge.neighbour_vertex]) - if self.graph[edge.neighbour_vertex].is_boundary + if self.graph[edge.neighbour_vertex].is_logical else set([]), nodes=set([edge.neighbour_vertex]), size=1, @@ -577,8 +562,6 @@ def peeling(self, erasure: PyGraph) -> List[int]: going backwards through the edges of the tree computing the error based on the syndrome. Based on arXiv:1703.01517. - TODO: Extract to a separate decoder. - Args: erasure (PyGraph): subgraph of the syndrome graph that represents the erasure. @@ -590,7 +573,7 @@ def peeling(self, erasure: PyGraph) -> List[int]: # Construct spanning forest # Pick starting vertex for vertex in erasure.node_indices(): - if erasure[vertex].is_boundary and erasure[vertex].properties["syndrome"]: + if erasure[vertex].is_logical and erasure[vertex].properties["syndrome"]: tree.vertices[vertex] = [] break diff --git a/src/qiskit_qec/decoders/hhc_decoder.py b/src/qiskit_qec/decoders/hhc_decoder.py deleted file mode 100644 index 00ba4789..00000000 --- a/src/qiskit_qec/decoders/hhc_decoder.py +++ /dev/null @@ -1,203 +0,0 @@ -"""Object to construct decoders and decode the HHC.""" - -from typing import List, Tuple -import logging - -from qiskit import QuantumCircuit - -from qiskit_qec.decoders.decoding_graph import DecodingGraph -from qiskit_qec.decoders.circuit_matching_decoder import CircuitModelMatchingDecoder -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel -from qiskit_qec.decoders.temp_code_util import temp_syndrome - - -class HHCDecoder(CircuitModelMatchingDecoder): - """Decoder for the heavy-hexagon compass code.""" - - def __init__( - self, - n: int, - css_x_gauge_ops: List[Tuple[int]], - css_x_stabilizer_ops: List[Tuple[int]], - css_x_boundary: List[int], - css_z_gauge_ops: List[Tuple[int]], - css_z_stabilizer_ops: List[Tuple[int]], - css_z_boundary: List[int], - circuit: QuantumCircuit, - model: PauliNoiseModel, - basis: str, - round_schedule: str, - blocks: int, - method: str, - uniform: bool, - decoding_graph: DecodingGraph = None, - annotate: bool = False, - ): - """Create a decoder object.""" - # Sum the total number of bits per round - self.bits_per_round = 0 - self.round_schedule = round_schedule - if not isinstance(self.round_schedule, str): - raise Exception("expected round_schedule to be a string") - if set(self.round_schedule) > set("xz"): - raise Exception("expected round schedule of 'x', 'z' chars") - for step in self.round_schedule: - if step == "z": - # need to include 2 flags per z-gauge outcome - self.bits_per_round += 3 * len(css_z_gauge_ops) - elif step == "x": - self.bits_per_round += len(css_x_gauge_ops) - super().__init__( - n, - css_x_gauge_ops, - css_x_stabilizer_ops, - css_x_boundary, - css_z_gauge_ops, - css_z_stabilizer_ops, - css_z_boundary, - circuit, - model, - basis, - round_schedule, - blocks, - method, - uniform, - decoding_graph, - annotate, - ) - - def _partition_outcomes(self, blocks: int, round_schedule: str, outcome: List[int]): - """Process the raw outcome and return results. - - blocks = number of blocks - round_schedule = string of z and x characters - outcome = list of 0, 1 outcomes - - Return lists x_gauge_outcomes, z_gauge_outcomes, final_outcomes. - """ - # partition the outcome list by outcome type - x_gauge_outcomes = [] - z_gauge_outcomes = [] - left_flag_outcomes = [] - right_flag_outcomes = [] - final_outcomes = [] - for r in range(blocks): - bits_into_round = 0 - for rs in round_schedule: - if rs == "z": - z_gauge_outcomes.append( - outcome[ - r * self.bits_per_round - + bits_into_round : r * self.bits_per_round - + bits_into_round - + len(self.css_z_gauge_ops) - ] - ) - bits_into_round += len(self.css_z_gauge_ops) - left_flag_outcomes.append( - outcome[ - r * self.bits_per_round - + bits_into_round : r * self.bits_per_round - + bits_into_round - + len(self.css_z_gauge_ops) - ] - ) - bits_into_round += len(self.css_z_gauge_ops) - right_flag_outcomes.append( - outcome[ - r * self.bits_per_round - + bits_into_round : r * self.bits_per_round - + bits_into_round - + len(self.css_z_gauge_ops) - ] - ) - bits_into_round += len(self.css_z_gauge_ops) - if rs == "x": - x_gauge_outcomes.append( - outcome[ - r * self.bits_per_round - + bits_into_round : r * self.bits_per_round - + bits_into_round - + len(self.css_x_gauge_ops) - ] - ) - bits_into_round += len(self.css_x_gauge_ops) - final_outcomes = outcome[-self.n :] - # Process the flags - logging.debug("left_flag_outcomes = %s", left_flag_outcomes) - logging.debug("right_flag_outcomes = %s", right_flag_outcomes) - logging.debug("x_gauge_outcomes (before deflag) = %s", x_gauge_outcomes) - x_gauge_outcomes, final_outcomes = self._process_flags( - self.blocks, - self.round_schedule, - x_gauge_outcomes, - left_flag_outcomes, - right_flag_outcomes, - final_outcomes, - ) - logging.debug("x_gauge_outcomes (after deflag) = %s", x_gauge_outcomes) - return x_gauge_outcomes, z_gauge_outcomes, final_outcomes - - def _process_flags( - self, - blocks: int, - round_schedule: str, - x_gauge_outcomes: List[Tuple[int]], - left_flag_outcomes: List[Tuple[int]], - right_flag_outcomes: List[Tuple[int]], - final_outcomes: List[int], - ): - """Process the flag data for a set of outcomes. - - The outcomes are 0-1 lists. - - For each Z gauge measurement, we look at the left and right flag - outcomes. If only the left flag is raised, we apply (in software) - a Z to the qubit in the upper left corner of the plaquette, after - that round of Z gauge measurements. If only the right flag is - raised, we instead apply Z to the qubit in the lower right corner. - These errors are propagate and change the X gauge outcomes - and the final outcomes if measured in the X basis. - - We return a new list of X gauge outcomes and final outcomes. - """ - frame = self.n * [0] # Z error frame - zidx = 0 # index z measurement cycles - xidx = 0 # index x measurement cycles - for _ in range(blocks): - for rs in round_schedule: - if rs == "z": - # Examine the left/right flags and update frame - for j, zg in enumerate(self.css_z_gauge_ops): - # Only consider weight 4 operators - if len(zg) == 4: - if ( - left_flag_outcomes[zidx][j] == 1 - and right_flag_outcomes[zidx][j] == 0 - ): - # upper left qubit - qubit = zg[0] - frame[qubit] ^= 1 - logging.debug("frame, cycle %d -> qubit %d", zidx, qubit) - if ( - left_flag_outcomes[zidx][j] == 0 - and right_flag_outcomes[zidx][j] == 1 - ): - # lower right qubit - qubit = zg[3] - frame[qubit] ^= 1 - logging.debug("frame, cycle %d -> qubit %d", zidx, qubit) - zidx += 1 - if rs == "x": - # Update the X gauge syndromes - syn = temp_syndrome(frame, self.css_x_gauge_ops) - logging.debug("frame syndrome, cycle %d -> %s", xidx, syn) - block = x_gauge_outcomes[xidx] - block = list(u ^ v for u, v in zip(block, syn)) - x_gauge_outcomes[xidx] = block - logging.debug("x gauge update, cycle %d -> %s", xidx, block) - xidx += 1 - # Update the final outcomes if X basis measurement - if self.basis == "x": - final_outcomes = list(u ^ v for u, v in zip(final_outcomes, frame)) - return x_gauge_outcomes, final_outcomes diff --git a/src/qiskit_qec/decoders/pymatching_decoder.py b/src/qiskit_qec/decoders/pymatching_decoder.py new file mode 100644 index 00000000..1e1f24ab --- /dev/null +++ b/src/qiskit_qec/decoders/pymatching_decoder.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name, disable=no-name-in-module, disable=no-member + +"""PyMatching""" +from typing import List, Union +from pymatching import Matching +from qiskit_qec.decoders.decoding_graph import DecodingGraph +from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge + + +class PyMatchingDecoder: + """ + Matching decoder using PyMatching. + """ + + def __init__( + self, + code_circuit, + decoding_graph: DecodingGraph = None, + ): + """Setting up the matching object""" + self.code = code_circuit + if decoding_graph: + self.decoding_graph = decoding_graph + else: + self.decoding_graph = DecodingGraph(self.code) + if self.decoding_graph.logical_nodes: + self.graph = self.decoding_graph.get_edge_graph() + else: + self.graph = self.decoding_graph.graph + self.matcher = self._matching() + self.indexer = None + super().__init__() + + def _matching(self) -> Matching: + return Matching(self.graph) + + def logical_flips(self, syndrome: Union[List[DecodingGraphNode], List[int]]) -> List[int]: + """ + Args: + syndrome: Either a list of DecodingGraphNode objects returnes by string2nodes, + or a list of binaries indicating which node is highlighted, e.g., + the output of a stim detector sampler + Returns: list of binaries indicating which logical is flipped + """ + syndrome_as_nodes = True + for elem in syndrome: + syndrome_as_nodes = syndrome_as_nodes and isinstance(elem, DecodingGraphNode) + if syndrome_as_nodes: + syndrome = self.nodes_to_detections(syndrome) + return self.matcher.decode(syndrome) + + def process(self, string: str) -> List[int]: + """ + Converts qiskit counts string into a list of flipped logicals + Args: counts string + Returns: list of corrected logicals (0 or 1) + """ + nodes = self.code.string2nodes(string) + raw_logicals = self.code.string2raw_logicals(string) + + logical_flips = self.logical_flips(nodes) + + corrected_logicals = [ + (int(raw) + flip) % 2 for raw, flip in zip(raw_logicals, logical_flips) + ] + + return corrected_logicals + + def matched_edges( + self, syndrome: Union[List[DecodingGraphNode], List[int]] + ) -> List[DecodingGraphEdge]: + """ + Args: + syndrome: Either a list of DecodingGraphNode objects returnes by string2nodes, + or a list of binaries indicating which node is highlighted. + Returns: list of DecodingGraphEdge-s included in the matching + """ + if isinstance(syndrome[0], DecodingGraphNode): + syndrome = self.nodes_to_detections(syndrome) + edge_dets = list(self.graph.edge_list()) + edges = self.graph.edges() + matched_det_pairs = self.matcher.decode_to_edges_array(syndrome) + det_pairs = [] + for pair in matched_det_pairs: + if pair[1] == -1: + pair[-1] = pair[-1] + len(self.graph.nodes()) + pair.sort() + det_pairs.append(tuple(pair)) + mached_edges = [edges[edge_dets.index(det_pair)] for det_pair in det_pairs] + return mached_edges + + def nodes_to_detections(self, syndrome_nodes: List[DecodingGraphNode]) -> List[int]: + """Converts nodes to detector indices to be used by pymatching.Matching.decode""" + graph_nodes = self.graph.nodes() + detections = [0] * len(graph_nodes) + for i, node in enumerate(graph_nodes): + if node in syndrome_nodes: + detections[i] = 1 + return detections diff --git a/src/qiskit_qec/decoders/pymatching_matcher.py b/src/qiskit_qec/decoders/pymatching_matcher.py deleted file mode 100644 index 197b0e65..00000000 --- a/src/qiskit_qec/decoders/pymatching_matcher.py +++ /dev/null @@ -1,65 +0,0 @@ -"""PyMatching matching object.""" - -from typing import List, Tuple, Dict, Set -import logging - -import rustworkx as rx -from pymatching import Matching - -from qiskit_qec.exceptions import QiskitQECError -from qiskit_qec.utils.indexer import Indexer -from qiskit_qec.decoders.base_matcher import BaseMatcher -from qiskit_qec.decoders.temp_graph_util import ret2net - - -class PyMatchingMatcher(BaseMatcher): - """Matching subroutines using PyMatching. - - The input rustworkx graph is expected to have the following properties: - edge["weight"] : real edge weight - edge["qubits"] : list of qubit ids associated to edge - vertex["is_boundary"] : bool, true if boundary node - """ - - def __init__(self): - """Create the matcher.""" - self.pymatching = None - self.indexer = None - super().__init__() - - def preprocess(self, graph: rx.PyGraph): - """Create the pymatching object. - Add qubit_id properties to the graph. - """ - self.indexer = Indexer() - nxgraph = ret2net(graph) - for edge in nxgraph.edges(data=True): - if edge[2]["qubits"]: - qset = set() - for q in edge[2]["qubits"]: - qset.add(self.indexer[q]) - edge[2]["qubit_id"] = tuple(qset)[0] if len(qset) == 1 else tuple(qset) - else: - edge[2]["qubit_id"] = -1 - self.pymatching = Matching(nxgraph) - - def find_errors( - self, - graph: rx.PyGraph, - idxmap: Dict[Tuple[int, List[int]], int], - highlighted: List[Tuple[int, Tuple[int]]], - ) -> Tuple[Set[int], Set[Tuple[int, Tuple[int]]]]: - """Process a set of highlighted vertices and return error locations.""" - syndrome = [0] * len(idxmap) - for vertex in highlighted: - syndrome[idxmap[vertex]] = 1 - try: - correction = self.pymatching.decode(syndrome) - except AttributeError as attrib_error: - raise QiskitQECError("Did you call preprocess?") from attrib_error - qubit_errors = [] - for i, corr in enumerate(correction): - if corr == 1: - qubit_errors.append(self.indexer.rlookup(i)) - logging.info("qubit_errors = %s", qubit_errors) - return set(qubit_errors), set() diff --git a/src/qiskit_qec/decoders/repetition_decoder.py b/src/qiskit_qec/decoders/repetition_decoder.py deleted file mode 100644 index 708b3e21..00000000 --- a/src/qiskit_qec/decoders/repetition_decoder.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Matching Decoder for Repetition Codes.""" -from typing import Tuple, List - -from qiskit_qec.decoders.circuit_matching_decoder import CircuitModelMatchingDecoder -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel -from qiskit_qec.decoders.decoding_graph import DecodingGraph - - -class RepetitionDecoder(CircuitModelMatchingDecoder): - """Instance of CircuitModelMatchingDecoder for use with - circuits from RepetitionCodeCircuit. - - Args: - code_circuit: The QEC code circuit object for which this decoder - will be used. - model: Noise model used to generate syndrome graph. - uniform: Whether to use uniform weights for the syndrome graph. - logical: Logical value for the circuit to be used. - """ - - def __init__( - self, - code_circuit, - model: PauliNoiseModel, - method: str, - uniform: bool, - logical: str, - ): - """Constructor.""" - self.code_circuit = code_circuit - dg = DecodingGraph(code_circuit) - super().__init__( - code_circuit.n, - code_circuit.css_x_gauge_ops, - code_circuit.css_x_stabilizer_ops, - code_circuit.css_x_boundary, - code_circuit.css_z_gauge_ops, - code_circuit.css_z_stabilizer_ops, - code_circuit.css_z_boundary, - code_circuit.circuit[logical], - model, - code_circuit.basis, - code_circuit.round_schedule, - code_circuit.blocks, - method, - uniform, - dg, - ) - - def _partition_outcomes( - self, blocks: int, round_schedule: str, outcome: List[int] - ) -> Tuple[List[List[int]], List[List[int]], List[int]]: - """Extract measurement outcomes.""" - return self.code_circuit.partition_outcomes(round_schedule, outcome) diff --git a/src/qiskit_qec/decoders/rustworkx_matcher.py b/src/qiskit_qec/decoders/rustworkx_matcher.py deleted file mode 100644 index 55fd65d4..00000000 --- a/src/qiskit_qec/decoders/rustworkx_matcher.py +++ /dev/null @@ -1,186 +0,0 @@ -"""rustworkx matching object.""" - -import logging -from copy import deepcopy -from typing import Dict, List, Set, Tuple - -import rustworkx as rx -from qiskit_qec.decoders.base_matcher import BaseMatcher -from qiskit_qec.utils import DecodingGraphEdge - - -class RustworkxMatcher(BaseMatcher): - """Matching subroutines using rustworkx. - - The input rustworkx graph is expected to have decoding_graph.Node as the type of the node payload - and decoding_graph.Edge as the type of the edge payload. - - Additionally the edges are expected to have the following properties: - - edge.properties["measurement_error"] (bool): Whether or not the error - corresponds to a measurement error. - - The annotated graph will also have "highlighted" properties on edges and vertices. - """ - - def __init__(self, annotate: bool = False): - """Create the matcher.""" - self.length = {} - self.path = {} - self.annotate = annotate - self.annotated_graph = None - super().__init__() - - def preprocess(self, graph: rx.PyGraph): - """Compute shortest paths between vertex pairs in decoding graph. - - Updates sets self.length and self.path. - """ - - # edge_cost_fn = lambda edge: edge["weight"] - def edge_cost_fn(edge: DecodingGraphEdge): - return edge.weight - - length = rx.all_pairs_dijkstra_path_lengths(graph, edge_cost_fn) - self.length = {s: dict(length[s]) for s in length} - path = rx.all_pairs_dijkstra_shortest_paths(graph, edge_cost_fn) - self.path = {s: {t: list(path[s][t]) for t in path[s]} for s in path} - - def find_errors( - self, - graph: rx.PyGraph, - idxmap: Dict[Tuple[int, List[int]], int], - highlighted: List[Tuple[int, Tuple[int]]], - ) -> Tuple[Set[int], Set[Tuple[int, Tuple[int]]]]: - """Process a set of highlighted vertices and return error locations. - - Be sure to have called recompute_paths if needed. - """ - matching = self._compute_matching(idxmap, highlighted) - logging.info("process: matching = %s", matching) - qubit_errors, measurement_errors = self._compute_error_correction( - graph, idxmap, matching, highlighted - ) - logging.info("process: qubit_errors = %s", qubit_errors) - logging.debug("process: measurement_errors = %s", measurement_errors) - return qubit_errors, measurement_errors - - def _compute_matching( - self, - idxmap: Dict[Tuple[int, List[int]], int], - highlighted: List[Tuple[int, Tuple[int]]], - ) -> Set[Tuple[int, int]]: - """Compute a min. weight perfect matching of highlighted vertices. - - highlighted is a list of highlighted vertices given as tuples - (t, qubit_set). - Return the matching. - """ - gm = rx.PyGraph(multigraph=False) # matching graph - idx = 0 # vertex index in matching graph - midxmap = {} # map from (t, qubit_tuple) to vertex index - for v in highlighted: - gm.add_node({"dvertex": v}) - midxmap[v] = idx - idx += 1 - for i, high_i in enumerate(highlighted): - for j in range(i + 1, len(highlighted)): - vi = midxmap[high_i] - vj = midxmap[highlighted[j]] - vip = idxmap[high_i] - vjp = idxmap[highlighted[j]] - gm.add_edge(vi, vj, {"weight": -self.length[vip][vjp]}) - - def weight_fn(edge): - return int(edge["weight"]) - - matching = rx.max_weight_matching(gm, max_cardinality=True, weight_fn=weight_fn) - return matching - - @staticmethod - def _error_chain_from_vertex_path( - graph: rx.PyGraph, vertex_path: List[int] - ) -> Tuple[Set[int], Set[Tuple[int, Tuple[int]]]]: - """Return a chain of qubit and measurement errors from a vertex path. - - Examine the edges along the path to extract the error chain. - Store error chains as sets and merge using symmetric difference. - The vertex_path is a list of rustworkx node indices. - """ - qubit_errors = set([]) - measurement_errors = set([]) - logging.debug("_error_chain_from_vertex_path %s", vertex_path) - for i in range(len(vertex_path) - 1): - v0 = vertex_path[i] - v1 = vertex_path[i + 1] - if graph.get_edge_data(v0, v1).properties["measurement_error"] == 1: - measurement_errors ^= set( - [(graph.nodes()[v0].time, tuple(graph.nodes()[v0].qubits))] - ) - qubit_errors ^= set(graph.get_edge_data(v0, v1).qubits) - logging.debug( - "_error_chain_for_vertex_path q = %s, m = %s", - qubit_errors, - measurement_errors, - ) - return qubit_errors, measurement_errors - - def _compute_error_correction( - self, - graph: rx.PyGraph, - idxmap: Dict[Tuple[int, List[int]], int], - matching: Set[Tuple[int, int]], - highlighted: List[Tuple[int, Tuple[int]]], - ) -> Tuple[Set[int], Set[Tuple[int, Tuple[int]]]]: - """Compute the qubit and measurement corrections. - - graph : the decoding graph - idxmap : maps (t, qubit_idx) to vertex index - matching : perfect matching computed by _compute_matching - highlighted : list of highlighted vertices - - Returns a tuple of sets, (qubit_errors, measurement_errors) where - qubit_errors contains the indices of qubits with errors and - measurement_errors contains tuples (t, qubit_set) indicating - failed measurements. - """ - used_paths = [] - qubit_errors = set([]) - measurement_errors = set([]) - for p in matching: - v0 = idxmap[highlighted[p[0]]] - v1 = idxmap[highlighted[p[1]]] - # Use the shortest paths between the matched vertices to - # identify all of the qubits in the error chains - path = self.path[v0][v1] - q, m = self._error_chain_from_vertex_path(graph, path) - # Add the error chains modulo two to get the total correction - # (uses set symmetric difference) - qubit_errors ^= q - measurement_errors ^= m - used_paths.append(path) - if self.annotate: - self.annotated_graph = self._make_annotated_graph(graph, used_paths) - return qubit_errors, measurement_errors - - @staticmethod - def _make_annotated_graph(gin: rx.PyGraph, paths: List[List[int]]) -> rx.PyGraph: - """Highlight the vertex paths and return annotated graph. - - gin : decoding graph - paths : list of vertex paths, each given as a list of - vertex indices in the decoding graph. - """ - graph = deepcopy(gin) - for path in paths: - # Highlight the endpoints of the path - for i in [0, -1]: - graph.nodes()[path[i]].properties["highlighted"] = True - # Highlight the edges along the path - for i in range(len(path) - 1): - try: - idx = list(graph.edge_list()).index((path[i], path[i + 1])) - except ValueError: - idx = list(graph.edge_list()).index((path[i + 1], path[i])) - edge = graph.edges()[idx] - edge.properties["highlighted"] = True - return graph diff --git a/src/qiskit_qec/decoders/temp_graph_util.py b/src/qiskit_qec/decoders/temp_graph_util.py index e2fb199f..95c970dd 100644 --- a/src/qiskit_qec/decoders/temp_graph_util.py +++ b/src/qiskit_qec/decoders/temp_graph_util.py @@ -1,11 +1,8 @@ """Temporary module with methods for graphs.""" import json -import os import networkx as nx import rustworkx as rx -from qiskit_qec.utils.decoding_graph_attributes import DecodingGraphEdge, DecodingGraphNode - def ret2net(graph: rx.PyGraph): """Convert rustworkx graph to equivalent networkx graph.""" @@ -36,42 +33,3 @@ def write_graph_to_json(graph: rx.PyGraph, filename: str): from_ret = ret2net(graph) json.dump(nx.node_link_data(from_ret), fp, indent=4, default=str) fp.close() - - -def get_cached_decoding_graph(path): - """ - Returns graph cached in file at path "file" using cache_graph method. - """ - if os.path.isfile(path) and not os.stat(path) == 0: - with open(path, "r+", encoding="utf-8") as file: - json_data = json.loads(file.read()) - net_graph = nx.node_link_graph(json_data) - ret_graph = rx.networkx_converter(net_graph, keep_attributes=True) - for node_index, node in zip(ret_graph.node_indices(), ret_graph.nodes()): - del node["__networkx_node__"] - qubits = node.pop("qubits") - time = node.pop("time") - index = node.pop("index") - is_boundary = node.pop("is_boundary") - properties = node.copy() - node = DecodingGraphNode(is_boundary=is_boundary, time=time, index=index, qubits=qubits) - node.properties = properties - ret_graph[node_index] = node - for edge_index, edge in zip(ret_graph.edge_indices(), ret_graph.edges()): - weight = edge.pop("weight") - qubits = edge.pop("qubits") - properties = edge.copy() - edge = DecodingGraphEdge(weight=weight, qubits=qubits) - edge.properties = properties - ret_graph.update_edge_by_index(edge_index, edge) - return ret_graph - return None - - -def cache_decoding_graph(graph, path): - """ - Cache rustworkx PyGraph to file at path. - """ - net_graph = ret2net(graph) - with open(path, "w+", encoding="utf-8") as file: - json.dump(nx.node_link_data(net_graph), file) diff --git a/src/qiskit_qec/decoders/three_bit_decoder.py b/src/qiskit_qec/decoders/three_bit_decoder.py deleted file mode 100644 index 24667dc3..00000000 --- a/src/qiskit_qec/decoders/three_bit_decoder.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Three bit decoder.""" -from typing import Tuple, List - -from qiskit import QuantumCircuit - -from qiskit_qec.decoders.circuit_matching_decoder import CircuitModelMatchingDecoder -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel - - -class ThreeBitDecoder(CircuitModelMatchingDecoder): - """Simple 3-bit code matching decoder.""" - - def __init__( - self, - n: int, - css_x_gauge_ops: List[Tuple[int]], - css_x_stabilizer_ops: List[Tuple[int]], - css_x_boundary: List[Tuple[int]], - css_z_gauge_ops: List[Tuple[int]], - css_z_stabilizer_ops: List[Tuple[int]], - css_z_boundary: List[Tuple[int]], - circuit: QuantumCircuit, - model: PauliNoiseModel, - basis: str, - round_schedule: str, - blocks: int, - method: str, - uniform: bool, - ): - """Constructor.""" - self.bits_per_round = 2 - super().__init__( - n, - css_x_gauge_ops, - css_x_stabilizer_ops, - css_x_boundary, - css_z_gauge_ops, - css_z_stabilizer_ops, - css_z_boundary, - circuit, - model, - basis, - round_schedule, - blocks, - method, - uniform, - ) - - def _partition_outcomes( - self, blocks: int, round_schedule: str, outcome: List[int] - ) -> Tuple[List[List[int]], List[List[int]], List[int]]: - """Extract measurement outcomes.""" - assert blocks == 2 - assert round_schedule == "z" - x_gauge_outcomes = [] - z_gauge_outcomes = [outcome[0:2], outcome[2:4]] - final_outcomes = outcome[4:7] - return x_gauge_outcomes, z_gauge_outcomes, final_outcomes diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index 0bbe5e52..3e619590 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -18,7 +18,7 @@ Graph used as the basis of decoders. """ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Union, Any, Dict, List, Set, Optional from qiskit_qec.exceptions import QiskitQECError @@ -29,6 +29,7 @@ class DecodingGraphNode: Attributes: - is_boundary (bool): whether or not the node is a boundary node. + - is_logical (bool): whether or not the node is a logical node. - time (int): what syndrome node the node corrsponds to. Doesn't need to be set if it's a boundary node. - qubits (List[int]): List of indices which are stabilized by @@ -38,14 +39,22 @@ class DecodingGraphNode: Are not considered when comparing nodes. """ - def __init__(self, index: int, qubits: List[int] = None, is_boundary=False, time=None) -> None: - if not is_boundary and time is None: + def __init__( + self, + index: int, + qubits: List[int] = None, + is_boundary=False, + is_logical=False, + time=None, + ) -> None: + if not is_boundary and not is_logical and time is None: raise QiskitQECError( - "DecodingGraph node must either have a time or be a boundary node." + "DecodingGraph node must either have a time or be a boundary or logical node." ) self.is_boundary: bool = is_boundary - self.time: Optional[int] = time if not is_boundary else None + self.is_logical: bool = is_logical + self.time: Optional[int] = None if (is_boundary or is_logical) else time self.qubits: List[int] = qubits if qubits else [] self.index: int = index self.properties: Dict[str, Any] = {} @@ -56,14 +65,21 @@ def __getitem__(self, key): elif key in self.properties: return self.properties[key] else: - raise QiskitQECError( + return QiskitQECError( "'" + str(key) + "'" + " is not an an attribute or property of the node." ) - def get(self, key, _): - """A dummy docstring.""" + def get(self, key, default=None): + """Return value for given key.""" # pylint: disable=unnecessary-dunder-call - return self.__getitem__(key) + output = self.__getitem__(key) + if isinstance(output, QiskitQECError): + if default: + return default + else: + raise output + else: + return output def __setitem__(self, key, value): if key in self.__dict__: @@ -79,8 +95,9 @@ def __eq__(self, rhs): self.index == rhs.index and set(self.qubits) == set(rhs.qubits) and self.is_boundary == rhs.is_boundary + and self.is_logical == rhs.is_logical ) - if not self.is_boundary: + if not (self.is_boundary or self.is_logical): result = result and self.time == rhs.time return result @@ -103,12 +120,14 @@ class DecodingGraphEdge: Attributes: - qubits (List[int]): List of indices of code qubits that correspond to this edge. - weight (float): Weight of the edge. + - fault_ids fault_ids: Union[Set[int],List[int]]: In the style of pymatching. - properties (Dict[str, Any]): Decoder/code specific attributes. Are not considered when comparing edges. """ qubits: List[int] weight: float + fault_ids: Union[Set[int], List[int]] = field(default_factory=set) properties: Dict[str, Any] = field(default_factory=dict) def __getitem__(self, key): @@ -117,14 +136,21 @@ def __getitem__(self, key): elif key in self.properties: return self.properties[key] else: - raise QiskitQECError( + return QiskitQECError( "'" + str(key) + "'" + " is not an an attribute or property of the edge." ) - def get(self, key, _): - """A dummy docstring.""" + def get(self, key, default=None): + """Return value for given key.""" # pylint: disable=unnecessary-dunder-call - return self.__getitem__(key) + value = self.__getitem__(key) + if isinstance(value, QiskitQECError): + if default is not None: + return default + else: + raise value + else: + return value def __setitem__(self, key, value): if key in self.__dict__: @@ -147,3 +173,23 @@ def __iter__(self): def __repr__(self): return str(dict(self)) + + +def _nodes2cpp(nodes): + """ + Convert a list of nodes to the form required by C++ functions. + """ + # nodes are a tuple with (q0, q1,t, boundary) + # if there is no q1 or t, -1 is used + cnodes = [] + for node in nodes: + cnode = [] + cnode += node.qubits + cnode += [-1] * (2 - len(node.qubits)) + if node.time is None: + cnode.append(-1) + else: + cnode.append(node.time) + cnode.append(node.is_logical) + cnodes.append(tuple(cnode)) + return cnodes diff --git a/src/qiskit_qec/utils/stim_tools.py b/src/qiskit_qec/utils/stim_tools.py index 9bbb732e..5ae18cb9 100644 --- a/src/qiskit_qec/utils/stim_tools.py +++ b/src/qiskit_qec/utils/stim_tools.py @@ -389,8 +389,8 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: if g.has_edge(dets[0], dets[1]): edge_ind = list(g.edge_list()).index((dets[0], dets[1])) edge_data = g.edges()[edge_ind].properties + old_frame_changes = g.edges()[edge_ind].fault_ids old_p = edge_data["error_probability"] - old_frame_changes = edge_data["fault_ids"] # If frame changes differ, the code has distance 2; just keep whichever was first. if set(old_frame_changes) == set(frame_changes): p = p * (1 - old_p) + old_p * (1 - p) @@ -406,7 +406,8 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: edge = DecodingGraphEdge( qubits=qubits, weight=loga((1 - p) / p), - properties={"fault_ids": set(frame_changes), "error_probability": p}, + fault_ids=set(frame_changes), + properties={"error_probability": p}, ) g.add_edge(dets[0], dets[1], edge) hyperedge[dets[0], dets[1]] = edge @@ -532,10 +533,14 @@ def string2rawlogicals_with_detectors( if all_logicals or str(logical_out) != logical: node = DecodingGraphNode( + is_logical=True, is_boundary=True, - qubits=[], index=index, ) + if "qubits" in logical_op: + node.qubits = logical_op["qubits"] + else: + node.qubits = [] nodes.append(node) return nodes diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 8afb6fba..fc3d4790 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -138,11 +138,7 @@ def test_string2nodes_2(self): (5, 0), "00001", [ - DecodingGraphNode( - is_boundary=True, - qubits=[0], - index=0, - ), + DecodingGraphNode(is_logical=True, is_boundary=True, qubits=[0], index=0), DecodingGraphNode( time=0, qubits=[0, 1], @@ -243,7 +239,7 @@ def single_error_test( string = "".join([str(c) for c in output[::-1]]) nodes = code.string2nodes(string) # check that it doesn't extend over more than two rounds - ts = [node.time for node in nodes if not node.is_boundary] + ts = [node.time for node in nodes if not node.is_logical] if ts: minimal = minimal and (max(ts) - min(ts)) <= 1 # check that it doesn't extend beyond the neigbourhood of a code qubit @@ -254,20 +250,25 @@ def single_error_test( minimal, "Error: Single error creates too many nodes", ) + neutral, flipped_logicals, num = code.check_nodes(nodes) # check that the nodes are neutral - neutral, flipped_logicals, _ = code.check_nodes(nodes) self.assertTrue( neutral and flipped_logicals == [], - "Error: Single error nodes are not neutral: " + string, + "Error: Single error nodes are not neutral for string " + + string + + " which yields " + + str(code.check_nodes(nodes)) + + " for nodes " + + str(nodes), + ) + # and caused by at most a single error + self.assertTrue( + num <= 1, + "Error: Nodes seem to be caused by more than one error for " + + string + + " which yields " + + str(code.check_nodes(nodes)), ) - # and that the given flipped logical makes sense - for node in nodes: - if not node.is_boundary: - for logical in flipped_logicals: - self.assertTrue( - logical in node.qubits, - "Error: Single error appears to flip logical is not part of nodes.", - ) def test_graph_construction(self): """Test single errors for a range of layouts""" @@ -360,7 +361,7 @@ def test_single_error_202s(self): nodes = [ node for node in code.string2nodes(string) - if "conjugate" not in node.properties and not node.is_boundary + if "conjugate" not in node.properties and not node.is_logical ] # require at most two (or three for the trivalent vertex or neighbouring aux) self.assertTrue( @@ -502,9 +503,7 @@ def test_empty_decoding_graph(): """Test initializtion of decoding graphs with None""" DecodingGraph(None) - def clustering_decoder_test( - self, Decoder - ): # NOT run directly by unittest; called by test_graph_constructions + def clustering_decoder_test(self, Decoder): # NOT run directly by unittest """Test decoding of ARCs and RCCs with clustering decoders""" # parameters for test @@ -545,7 +544,6 @@ def clustering_decoder_test( decoder = Decoder(code, decoding_graph=decoding_graph, use_peeling=False) else: decoder = Decoder(code, decoding_graph=decoding_graph) - errors = {z_logical[0]: 0 for z_logical in decoder.measured_logicals} min_error_num = code.d min_error_string = "" for _ in range(N): @@ -555,14 +553,14 @@ def clustering_decoder_test( string = string + " " + "0" * (d - 1) # get and check corrected_z_logicals corrected_z_logicals = decoder.process(string) - for j, z_logical in enumerate(decoder.measured_logicals): - error = corrected_z_logicals[j] != 1 - if error: - error_num = string.split(" ", maxsplit=1)[0].count("0") - if error_num < min_error_num: - min_error_num = error_num - min_error_string = string - errors[z_logical[0]] += error + for node in decoder.decoding_graph.logical_nodes: + if node.index < len(corrected_z_logicals): + error = corrected_z_logicals[node.index] != 1 + if error: + error_num = string.split(" ", maxsplit=1)[0].count("0") + if error_num < min_error_num: + min_error_num = error_num + min_error_string = string # check that min num errors to cause logical errors >d/3 self.assertTrue( min_error_num > d / 3, @@ -573,16 +571,271 @@ def clustering_decoder_test( + str(c) + " with " + min_error_string + + "." + + " Corresponding clusters are " + + str(decoder.cluster(code.string2nodes(string, all_logicals=True))) + ".", ) + def heavy_hex_test(self, Decoder): # NOT run directly by unittest + """Test decoding of heavy hex ARC""" + links = [ + (0, 1, 2), + (2, 3, 4), + (4, 5, 6), + (6, 7, 8), + (8, 9, 10), + (10, 11, 12), + (0, 14, 18), + (4, 15, 22), + (8, 16, 26), + (12, 17, 30), + (18, 19, 20), + (20, 21, 22), + (22, 23, 24), + (24, 25, 26), + (26, 27, 28), + (28, 29, 30), + (30, 31, 32), + (20, 33, 39), + (24, 34, 43), + (28, 35, 47), + (32, 36, 51), + (37, 38, 39), + (39, 40, 41), + (41, 42, 43), + (43, 44, 45), + (45, 46, 47), + (47, 48, 49), + (49, 50, 51), + (37, 52, 56), + (41, 53, 60), + (45, 54, 64), + (49, 55, 68), + (56, 57, 58), + (58, 59, 60), + (60, 61, 62), + (62, 63, 64), + (64, 65, 66), + (66, 67, 68), + (68, 69, 70), + (58, 71, 77), + (62, 72, 81), + (66, 73, 85), + (70, 74, 89), + (75, 76, 77), + (77, 78, 79), + (79, 80, 81), + (81, 82, 83), + (83, 84, 85), + (85, 86, 87), + (87, 88, 89), + (75, 90, 94), + (79, 91, 98), + (83, 92, 102), + (87, 93, 106), + (94, 95, 96), + (96, 97, 98), + (98, 99, 100), + (100, 101, 102), + (102, 103, 104), + (104, 105, 106), + (106, 107, 108), + (96, 109, 114), + (100, 110, 118), + (104, 111, 122), + (108, 112, 126), + (114, 115, 116), + (116, 117, 118), + (118, 119, 120), + (120, 121, 122), + (122, 123, 124), + (124, 125, 126), + ] + schedule = [ + [ + (0, 14), + (2, 3), + (4, 15), + (6, 7), + (8, 16), + (10, 11), + (12, 17), + (18, 19), + (22, 23), + (26, 27), + (30, 31), + (20, 33), + (24, 34), + (28, 35), + (32, 36), + (39, 40), + (43, 44), + (47, 48), + (37, 52), + (41, 53), + (45, 54), + (49, 55), + (56, 57), + (60, 61), + (64, 65), + (68, 69), + (58, 71), + (62, 72), + (66, 73), + (70, 74), + (77, 78), + (81, 82), + (85, 86), + (75, 90), + (79, 91), + (83, 92), + (87, 93), + (94, 95), + (98, 99), + (102, 103), + (106, 107), + (96, 109), + (100, 110), + (104, 111), + (108, 112), + (114, 115), + (118, 119), + (122, 123), + ], + [ + (0, 1), + (4, 5), + (8, 9), + (18, 14), + (22, 15), + (26, 16), + (30, 17), + (20, 21), + (24, 25), + (28, 29), + (39, 33), + (43, 34), + (47, 35), + (51, 36), + (37, 38), + (41, 42), + (45, 46), + (49, 50), + (56, 52), + (60, 53), + (64, 54), + (68, 55), + (58, 59), + (62, 63), + (66, 67), + (77, 71), + (81, 72), + (85, 73), + (89, 74), + (75, 76), + (79, 80), + (83, 84), + (87, 88), + (94, 90), + (98, 91), + (102, 92), + (106, 93), + (96, 97), + (100, 101), + (104, 105), + (114, 109), + (118, 110), + (122, 111), + (126, 112), + (116, 117), + (120, 121), + (124, 125), + ], + [ + (2, 1), + (4, 3), + (6, 5), + (8, 7), + (10, 9), + (12, 11), + (22, 21), + (26, 25), + (30, 29), + (20, 19), + (24, 23), + (28, 27), + (32, 31), + (39, 38), + (43, 42), + (47, 46), + (51, 50), + (41, 40), + (45, 44), + (49, 48), + (60, 59), + (64, 63), + (68, 67), + (58, 57), + (62, 61), + (66, 65), + (70, 69), + (77, 76), + (81, 80), + (85, 84), + (89, 88), + (79, 78), + (83, 82), + (87, 86), + (98, 97), + (102, 101), + (106, 105), + (96, 95), + (100, 99), + (104, 103), + (108, 107), + (118, 117), + (122, 121), + (126, 125), + (116, 115), + (120, 119), + (124, 123), + ], + ] + code = ArcCircuit( + links, 10, schedule=schedule, run_202=False, basis="zx", logical="0", resets=True + ) + if Decoder is UnionFindDecoder: + decoder = Decoder(code, use_peeling=False) + else: + decoder = Decoder(code) + string = ( + "110100001100010110010011110110100011111100000101100101 " + + "01111010000111111110111110111111000010001000111111101110111111101010001 " + + "11110100001000110001011110110111100000111111011011011100011001000110111 " + + "11110100010000010110010100110110000011010110101000010101011100001000111 " + + "11110101010001001010001110001111000011011100111001011100001010001000111 " + + "01010000110001100011010110001110010011000000111010011000000100011010011 " + + "11001000110001110011010011101101010101000110101010010000000000011111011 " + + "11000000100011111001010101101011010101000100111110010000001010001101011 " + + "11000000001000100000010000001011001101000100100111010000110001101111100 " + + "11000001000000001001010001111100010000100111011110000000011000010000101 " + + "01000010000000000000110000001100011000000100000010000010000100010000000" + ) + self.assertTrue( + decoder.process(string)[0] == 0, + "Incorrect decoding for example string with heavy-hex ARC.", + ) + def test_bravyi_haah(self): """Test decoding of ARCs and RCCs with Bravyi Haah""" self.clustering_decoder_test(BravyiHaahDecoder) + self.heavy_hex_test(BravyiHaahDecoder) def test_union_find(self): """Test decoding of ARCs and RCCs with Union Find""" self.clustering_decoder_test(UnionFindDecoder) + self.heavy_hex_test(UnionFindDecoder) if __name__ == "__main__": diff --git a/test/code_circuits/test_surface_codes.py b/test/code_circuits/test_surface_codes.py index 43cf1259..c8cc8b36 100644 --- a/test/code_circuits/test_surface_codes.py +++ b/test/code_circuits/test_surface_codes.py @@ -67,7 +67,7 @@ def test_string2nodes(self): ], [ DecodingGraphNode( - is_boundary=True, + is_logical=True, qubits=[0, 3, 6], index=0, ), @@ -79,7 +79,7 @@ def test_string2nodes(self): ], [ DecodingGraphNode( - is_boundary=True, + is_logical=True, qubits=[2, 5, 8], index=1, ), @@ -117,11 +117,11 @@ def test_string2nodes(self): ), ], [ - DecodingGraphNode(is_boundary=True, qubits=[0, 1, 2], index=0), + DecodingGraphNode(is_logical=True, qubits=[0, 1, 2], index=0), DecodingGraphNode(time=1, qubits=[0, 3], index=0), ], [ - DecodingGraphNode(is_boundary=True, qubits=[8, 7, 6], index=1), + DecodingGraphNode(is_logical=True, qubits=[8, 7, 6], index=1), DecodingGraphNode(time=1, qubits=[5, 8], index=3), ], ] @@ -149,26 +149,26 @@ def test_check_nodes(self): valid = valid and code.check_nodes(nodes) == (True, [], 0) # on one side nodes = [ - DecodingGraphNode(qubits=[0, 1, 2], is_boundary=True, index=0), + DecodingGraphNode(qubits=[0, 1, 2], is_logical=True, index=0), DecodingGraphNode(time=3, qubits=[0, 3], index=0), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) nodes = [DecodingGraphNode(time=3, qubits=[0, 3], index=0)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(time=0, qubits=[0, 1, 2], is_boundary=True, index=0)], + [DecodingGraphNode(qubits=[0, 1, 2], is_logical=True, index=0)], 1.0, ) # and the other nodes = [ - DecodingGraphNode(time=0, qubits=[8, 7, 6], is_boundary=True, index=1), + DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, index=1), DecodingGraphNode(time=3, qubits=[5, 8], index=3), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) nodes = [DecodingGraphNode(time=3, qubits=[5, 8], index=3)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(time=0, qubits=[8, 7, 6], is_boundary=True, index=1)], + [DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, index=1)], 1.0, ) # and in the middle @@ -180,7 +180,7 @@ def test_check_nodes(self): nodes = [DecodingGraphNode(time=3, qubits=[3, 6, 4, 7], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[8, 7, 6], is_boundary=True, index=1)], + [DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, index=1)], 1.0, ) @@ -194,7 +194,7 @@ def test_check_nodes(self): nodes = [DecodingGraphNode(time=3, qubits=[4, 5, 7, 8], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[2, 5, 8], is_boundary=True, index=1)], + [DecodingGraphNode(qubits=[2, 5, 8], is_logical=True, index=1)], 1.0, ) @@ -208,33 +208,33 @@ def test_check_nodes(self): nodes = [DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1)], + [DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1)], 2.0, ) - # wrong boundary + # wrong logical nodes = [ DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), - DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1), ] valid = valid and code.check_nodes(nodes) == ( False, - [DecodingGraphNode(qubits=[0, 1, 2, 3, 4], is_boundary=True, index=0)], + [DecodingGraphNode(qubits=[0, 1, 2, 3, 4], is_logical=True, index=0)], 2, ) - # extra boundary + # extra logical nodes = [ DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), - DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1), ] valid = valid and code.check_nodes(nodes) == (False, [], 0) # ignoring extra nodes = [ DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), - DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_boundary=True, index=1), + DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1), ] - valid = valid and code.check_nodes(nodes, ignore_extra_boundary=True) == (True, [], 1) + valid = valid and code.check_nodes(nodes, ignore_extra_logical=True) == (True, [], 1) self.assertTrue(valid, "A set of nodes did not give the expected outcome for check_nodes.") diff --git a/test/heavy_hex_codes/__init__.py b/test/heavy_hex_codes/__init__.py deleted file mode 100644 index 428fe2e5..00000000 --- a/test/heavy_hex_codes/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2019. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. diff --git a/test/heavy_hex_codes/test_heavy_hex_code.py b/test/heavy_hex_codes/test_heavy_hex_code.py deleted file mode 100644 index e53afe5a..00000000 --- a/test/heavy_hex_codes/test_heavy_hex_code.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Test heavy-hexagon code family definition.""" -import unittest -from qiskit_qec.codes.hhc import HHC - - -class TestHHC(unittest.TestCase): - """Test heavy-hexagon code family.""" - - def test_even_distance(self): - """Only odd distance implemented.""" - with self.assertRaises(Exception): - HHC(4) - - def test_d3(self): - """Check d=3 output.""" - c = HHC(3) - self.assertEqual(c.n, 9) - self.assertEqual(c.k, 1) - self.assertEqual(c.d, 3) - self.assertEqual(c.logical_x, [[0, 1, 2, 3, 4, 5, 6, 7, 8]]) - self.assertEqual(c.logical_z, [[0, 1, 2, 3, 4, 5, 6, 7, 8]]) - self.assertEqual(c.x_boundary, [[0], [1], [2], [6], [7], [8]]) - self.assertEqual(c.z_boundary, [[0], [3], [6], [2], [5], [8]]) - self.assertEqual(c.x_gauges, [[0, 3], [1, 4], [2, 5], [3, 6], [4, 7], [5, 8]]) - self.assertEqual(c.z_gauges, [[0, 1], [1, 2, 4, 5], [3, 4, 6, 7], [7, 8]]) - self.assertEqual(c.x_stabilizers, [[0, 1, 3, 4], [2, 5], [3, 6], [4, 5, 7, 8]]) - self.assertEqual(c.z_stabilizers, [[0, 1, 3, 4, 6, 7], [1, 2, 4, 5, 7, 8]]) - - def test_d5(self): - """Check d=5 output.""" - c = HHC(5) - self.assertEqual(c.n, 25) - self.assertEqual(c.k, 1) - self.assertEqual(c.d, 5) - self.assertEqual( - c.logical_x, - [ - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - ] - ], - ) - self.assertEqual( - c.logical_z, - [ - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - ] - ], - ) - self.assertEqual(c.x_boundary, [[0], [1], [2], [3], [4], [20], [21], [22], [23], [24]]) - self.assertEqual(c.z_boundary, [[0], [5], [10], [15], [20], [4], [9], [14], [19], [24]]) - self.assertEqual( - c.x_gauges, - [ - [0, 5], - [1, 6], - [2, 7], - [3, 8], - [4, 9], - [5, 10], - [6, 11], - [7, 12], - [8, 13], - [9, 14], - [10, 15], - [11, 16], - [12, 17], - [13, 18], - [14, 19], - [15, 20], - [16, 21], - [17, 22], - [18, 23], - [19, 24], - ], - ) - self.assertEqual( - c.z_gauges, - [ - [0, 1], - [2, 3], - [1, 2, 6, 7], - [3, 4, 8, 9], - [5, 6, 10, 11], - [7, 8, 12, 13], - [11, 12, 16, 17], - [13, 14, 18, 19], - [15, 16, 20, 21], - [17, 18, 22, 23], - [21, 22], - [23, 24], - ], - ) - self.assertEqual( - c.x_stabilizers, - [ - [0, 1, 5, 6], - [2, 3, 7, 8], - [4, 9], - [5, 10], - [6, 7, 11, 12], - [8, 9, 13, 14], - [10, 11, 15, 16], - [12, 13, 17, 18], - [14, 19], - [15, 20], - [16, 17, 21, 22], - [18, 19, 23, 24], - ], - ) - self.assertEqual( - c.z_stabilizers, - [ - [0, 1, 5, 6, 10, 11, 15, 16, 20, 21], - [1, 2, 6, 7, 11, 12, 16, 17, 21, 22], - [2, 3, 7, 8, 12, 13, 17, 18, 22, 23], - [3, 4, 8, 9, 13, 14, 18, 19, 23, 24], - ], - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/heavy_hex_codes/test_heavy_hex_decoder.py b/test/heavy_hex_codes/test_heavy_hex_decoder.py deleted file mode 100644 index a02bebf8..00000000 --- a/test/heavy_hex_codes/test_heavy_hex_decoder.py +++ /dev/null @@ -1,322 +0,0 @@ -"""Test a heavy-hexagon code decoder""" -import unittest - -from qiskit_aer import Aer - -from qiskit_qec.codes.hhc import HHC -from qiskit_qec.circuits.hhc_circuit import HHCCircuit -from qiskit_qec.decoders.hhc_decoder import HHCDecoder -from qiskit_qec.decoders.temp_code_util import temp_syndrome -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel -from qiskit_qec.analysis.faultenumerator import FaultEnumerator - - -def make_model(p): - """Make a Pauli model for depolarizing noise.""" - pnm = PauliNoiseModel() - pnm.add_operation( - "cx", - { - "ix": 1, - "iy": 1, - "iz": 1, - "xi": 1, - "xx": 1, - "xy": 1, - "xz": 1, - "yi": 1, - "yx": 1, - "yy": 1, - "yz": 1, - "zi": 1, - "zx": 1, - "zy": 1, - "zz": 1, - }, - ) - pnm.add_operation("id", {"x": 1, "y": 1, "z": 1}) - pnm.add_operation("reset", {"x": 1}) - pnm.add_operation("measure", {"x": 1}) - pnm.add_operation("h", {"x": 1, "y": 1, "z": 1}) - pnm.add_operation("x", {"x": 1, "y": 1, "z": 1}) - pnm.add_operation("y", {"x": 1, "y": 1, "z": 1}) - pnm.add_operation("z", {"x": 1, "y": 1, "z": 1}) - pnm.add_operation("idm", {"x": 1, "y": 1, "z": 1}) - pnm.set_error_probability("cx", p) - pnm.set_error_probability("id", p) - pnm.set_error_probability("reset", p) - pnm.set_error_probability("measure", p) - pnm.set_error_probability("h", p) - pnm.set_error_probability("x", p) - pnm.set_error_probability("y", p) - pnm.set_error_probability("z", p) - pnm.set_error_probability("idm", p) - return pnm - - -class TestHHCDecoder(unittest.TestCase): - """Tests for a heavy-hexagon code decoder.""" - - def setUp(self) -> None: - """Work we can do once.""" - self.model = make_model(0.0001) - - def correct_all_1(self, code, circ, dec, model, good=0, xbasis=False, method="propagator"): - """Test if we can correct all single-location faults.""" - dec.update_edge_weights(model) - fe = FaultEnumerator(circ, method=method, model=model) - failures = 0 - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - if xbasis: - fail = temp_syndrome(corrected_outcomes, [range(code.n)]) - else: - fail = temp_syndrome(corrected_outcomes, [range(code.n)]) - if fail[0] != good: - failures += 1 - print(good, fail, faultpath, corrected_outcomes) - self.assertEqual(failures, 0) - - def no_faults_success(self, code, circ, dec, model, good=0, xbasis=False): - """Test for correct behavior without faults.""" - shots = 10 - seed = 100 - backend = Aer.get_backend("aer_simulator") - options = {"method": "stabilizer", "shots": shots, "seed_simulator": seed} - result = backend.run(circ, **options).result() - counts = result.get_counts(circ) - dec.update_edge_weights(model) - failures = 0 - for outcome, _ in counts.items(): - reversed_outcome = list(map(int, outcome[::-1])) - corrected_outcomes = dec.process(reversed_outcome) - if xbasis: - fail = temp_syndrome(corrected_outcomes, [range(code.n)]) - else: - fail = temp_syndrome(corrected_outcomes, [range(code.n)]) - if fail[0] != good: - print(good, fail, reversed_outcome, corrected_outcomes) - failures += 1 - self.assertEqual(failures, 0) - - def test_d3_2(self): - """Check 3, zx, z, pymatching.""" - blocks = 3 - round_schedule = "zx" - basis = "z" - logical_paulis = "ii" - c = HHC(3) - gen = HHCCircuit( - c, - barriers=True, - idles=True, - distinct_measurement_idle=True, - init_error=True, - group_meas=False, - xprs=False, - blocks=blocks, - round_schedule=round_schedule, - basis=basis, - initial_state="+", - logical_paulis=logical_paulis, - num_initialize=1, - idle_before_measure=False, - ) - circ = gen.syndrome_measurement() - dec = HHCDecoder( - n=c.n, - css_x_gauge_ops=c.x_gauges, - css_x_stabilizer_ops=c.x_stabilizers, - css_x_boundary=c.x_boundary, - css_z_gauge_ops=c.z_gauges, - css_z_stabilizer_ops=c.z_stabilizers, - css_z_boundary=c.z_boundary, - circuit=circ, - model=self.model, - basis=basis, - round_schedule=round_schedule, - blocks=blocks, - method="pymatching", - uniform=False, - ) - self.no_faults_success(c, circ, dec, self.model) - self.correct_all_1(c, circ, dec, self.model) - - def test_d3_3(self): - """Check 3, zx, z, rustworkx.""" - blocks = 3 - round_schedule = "zx" - basis = "z" - logical_paulis = "ii" - c = HHC(3) - gen = HHCCircuit( - c, - barriers=True, - idles=True, - distinct_measurement_idle=True, - init_error=True, - group_meas=False, - xprs=False, - blocks=blocks, - round_schedule=round_schedule, - basis=basis, - initial_state="+", - logical_paulis=logical_paulis, - num_initialize=1, - idle_before_measure=False, - ) - circ = gen.syndrome_measurement() - dec = HHCDecoder( - n=c.n, - css_x_gauge_ops=c.x_gauges, - css_x_stabilizer_ops=c.x_stabilizers, - css_x_boundary=c.x_boundary, - css_z_gauge_ops=c.z_gauges, - css_z_stabilizer_ops=c.z_stabilizers, - css_z_boundary=c.z_boundary, - circuit=circ, - model=self.model, - basis=basis, - round_schedule=round_schedule, - blocks=blocks, - method="rustworkx", - uniform=False, - ) - # self.no_faults_success(c, circ, dec, self.model) - self.correct_all_1(c, circ, dec, self.model) - - def test_d3_5(self): - """Check 1, zxzxzx, z, pymatching, logical=xyzxyz.""" - blocks = 1 - round_schedule = "zxzxzx" - basis = "z" - logical_paulis = "xyzxyz" - c = HHC(3) - gen = HHCCircuit( - c, - barriers=True, - idles=True, - distinct_measurement_idle=True, - init_error=True, - group_meas=False, - xprs=False, - blocks=blocks, - round_schedule=round_schedule, - basis=basis, - initial_state="+", - logical_paulis=logical_paulis, - num_initialize=1, - idle_before_measure=False, - ) - circ = gen.syndrome_measurement() - dec = HHCDecoder( - n=c.n, - css_x_gauge_ops=c.x_gauges, - css_x_stabilizer_ops=c.x_stabilizers, - css_x_boundary=c.x_boundary, - css_z_gauge_ops=c.z_gauges, - css_z_stabilizer_ops=c.z_stabilizers, - css_z_boundary=c.z_boundary, - circuit=circ, - model=self.model, - basis=basis, - round_schedule=round_schedule, - blocks=blocks, - method="pymatching", - uniform=False, - ) - self.no_faults_success(c, circ, dec, self.model) - self.correct_all_1(c, circ, dec, self.model) - - def test_d3_7(self): - """Check 3, zx, z, pymatching, -1 eigenstate.""" - blocks = 3 - round_schedule = "zx" - basis = "z" - logical_paulis = "ii" - c = HHC(3) - gen = HHCCircuit( - c, - barriers=True, - idles=True, - distinct_measurement_idle=True, - init_error=True, - group_meas=False, - xprs=False, - blocks=blocks, - round_schedule=round_schedule, - basis=basis, - initial_state="-", - logical_paulis=logical_paulis, - num_initialize=1, - idle_before_measure=False, - ) - circ = gen.syndrome_measurement() - dec = HHCDecoder( - n=c.n, - css_x_gauge_ops=c.x_gauges, - css_x_stabilizer_ops=c.x_stabilizers, - css_x_boundary=c.x_boundary, - css_z_gauge_ops=c.z_gauges, - css_z_stabilizer_ops=c.z_stabilizers, - css_z_boundary=c.z_boundary, - circuit=circ, - model=self.model, - basis=basis, - round_schedule=round_schedule, - blocks=blocks, - method="pymatching", - uniform=False, - ) - self.no_faults_success(c, circ, dec, self.model, 1) - # The propagator method does not treat Paulis in circ as - # errors, so the observed outcomes are not flipped (0 in arguments) - self.correct_all_1(c, circ, dec, self.model, 0, False, "propagator") - - def test_d3_10(self): - """Check 3, zx, x, pymatching.""" - blocks = 3 - round_schedule = "zx" - basis = "x" - logical_paulis = "ii" - c = HHC(3) - gen = HHCCircuit( - c, - barriers=True, - idles=True, - distinct_measurement_idle=True, - init_error=True, - group_meas=False, - xprs=False, - blocks=blocks, - round_schedule=round_schedule, - basis=basis, - initial_state="+", - logical_paulis=logical_paulis, - num_initialize=1, - idle_before_measure=False, - ) - circ = gen.syndrome_measurement() - dec = HHCDecoder( - n=c.n, - css_x_gauge_ops=c.x_gauges, - css_x_stabilizer_ops=c.x_stabilizers, - css_x_boundary=c.x_boundary, - css_z_gauge_ops=c.z_gauges, - css_z_stabilizer_ops=c.z_stabilizers, - css_z_boundary=c.z_boundary, - circuit=circ, - model=self.model, - basis=basis, - round_schedule=round_schedule, - blocks=blocks, - method="pymatching", - uniform=True, - ) - self.no_faults_success(c, circ, dec, self.model, 0, True) - self.correct_all_1(c, circ, dec, self.model, 0, True) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/matching/__init__.py b/test/matching/__init__.py deleted file mode 100644 index 428fe2e5..00000000 --- a/test/matching/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2019. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. diff --git a/test/matching/test_circuitmatcher.py b/test/matching/test_circuitmatcher.py deleted file mode 100644 index d15c7a8e..00000000 --- a/test/matching/test_circuitmatcher.py +++ /dev/null @@ -1,355 +0,0 @@ -"""Tests for the subsystem CSS circuit-level matching decoder.""" -import unittest - -from qiskit import QuantumCircuit -from qiskit_aer import Aer - -from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.decoders.circuit_matching_decoder import temp_syndrome -from qiskit_qec.decoders.three_bit_decoder import ThreeBitDecoder -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel - - -class TestCircuitMatcher(unittest.TestCase): - """Tests for the three bit decoder example.""" - - def setUp(self) -> None: - # Bit-flip circuit noise model - p = 0.05 - pnm = PauliNoiseModel() - pnm.add_operation("cx", {"ix": 1, "xi": 1, "xx": 1}) - pnm.add_operation("id", {"x": 1}) - pnm.add_operation("reset", {"x": 1}) - pnm.add_operation("measure", {"x": 1}) - pnm.add_operation("x", {"x": 1, "y": 1, "z": 1}) - pnm.set_error_probability("cx", p) - pnm.set_error_probability("x", p) - pnm.set_error_probability("id", p) - pnm.set_error_probability("reset", p) - pnm.set_error_probability("measure", p) - - # 3-bit repetition code - x_stabilizers = [] - z_stabilizers = [[0, 1], [1, 2]] - x_logical = [[0, 1, 2]] - z_logical = [[0]] - x_boundary = [] - z_boundary = [[0], [2]] - - # Construct the 3-bit syndrome measurement circuit - qc = QuantumCircuit(4, 7) - qc.reset(0) - qc.reset(1) - qc.reset(2) - for i in range(2): - qc.reset(3) - qc.cx(0, 3) - qc.cx(1, 3) - qc.measure(3, 0 + 2 * i) - qc.reset(3) - qc.cx(1, 3) - qc.cx(2, 3) - qc.measure(3, 1 + 2 * i) - qc.measure(0, 4) - qc.measure(1, 5) - qc.measure(2, 6) - - self.pnm = pnm - self.qc = qc - self.x_stabilizers = x_stabilizers - self.x_boundary = x_boundary - self.z_stabilizers = z_stabilizers - self.z_boundary = z_boundary - self.z_logical = z_logical - self.x_logical = x_logical - - def test_no_errors(self): - """Test the case with no errors using rustworkx.""" - shots = 100 - seed = 100 - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "rustworkx", - False, - ) - backend = Aer.get_backend("aer_simulator") - options = {"method": "stabilizer", "shots": shots, "seed_simulator": seed} - result = backend.run(self.qc, **options).result() - counts = result.get_counts(self.qc) - dec.update_edge_weights(self.pnm) - failures = 0 - for outcome, _ in counts.items(): - reversed_outcome = list(map(int, outcome[::-1])) - corrected_outcomes = dec.process(reversed_outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - failures += fail[0] - self.assertEqual(failures, 0) - - def test_no_errors_pymatching(self): - """Test the case with no errors using pymatching.""" - shots = 100 - seed = 100 - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "pymatching", - False, - ) - backend = Aer.get_backend("aer_simulator") - options = {"method": "stabilizer", "shots": shots, "seed_simulator": seed} - result = backend.run(self.qc, **options).result() - counts = result.get_counts(self.qc) - dec.update_edge_weights(self.pnm) - failures = 0 - for outcome, _ in counts.items(): - reversed_outcome = list(map(int, outcome[::-1])) - corrected_outcomes = dec.process(reversed_outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - failures += fail[0] - self.assertEqual(failures, 0) - - def test_correct_single_errors(self): - """Test the case with single faults using rustworkx.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "rustworkx", - False, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, method="stabilizer", model=self.pnm) - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - self.assertEqual(fail[0], 0) - - def test_correct_single_errors_uniform(self): - """Test the case with single faults using rustworkx.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "rustworkx", - True, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, method="stabilizer", model=self.pnm) - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - self.assertEqual(fail[0], 0) - - def test_correct_single_errors_pymatching(self): - """Test the case with single faults using pymatching.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "pymatching", - False, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, method="stabilizer", model=self.pnm) - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - self.assertEqual(fail[0], 0) - - def test_correct_single_errors_pymatching_uniform(self): - """Test the case with single faults using pymatching.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "pymatching", - True, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, method="stabilizer", model=self.pnm) - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - self.assertEqual(fail[0], 0) - - def test_error_pairs(self): - """Test the case with two faults using rustworkx.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "rustworkx", - False, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, order=2, method="stabilizer", model=self.pnm) - failures = 0 - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - failures += fail[0] - self.assertEqual(failures, 140) - - def test_error_pairs_uniform(self): - """Test the case with two faults using rustworkx.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "rustworkx", - True, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, order=2, method="stabilizer", model=self.pnm) - failures = 0 - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - failures += fail[0] - self.assertEqual(failures, 152) - - def test_error_pairs_propagator_pymatching(self): - """Test the case with two faults using error propagator and pymatching.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "pymatching", - False, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, order=2, method="propagator", model=self.pnm) - failures = 0 - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - failures += fail[0] - self.assertEqual(failures, 128) - - def test_error_pairs_propagator_pymatching_uniform(self): - """Test the case with two faults using error propagator and pymatching.""" - dec = ThreeBitDecoder( - 3, - self.x_stabilizers, - self.x_stabilizers, - self.x_boundary, - self.z_stabilizers, - self.z_stabilizers, - self.z_boundary, - self.qc, - self.pnm, - "z", - "z", - 2, - "pymatching", - True, - ) - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(self.qc, order=2, method="propagator", model=self.pnm) - failures = 0 - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - failures += fail[0] - # For pymatching v0.x, there are 168 failures, whereas for pymatching >v2.0.0 there are 156. - # The reason for the difference is that many of these test cases have degenerate solutions - # (Both versions of pymatching are giving valid minimum-weight perfect matching solutions, but - # the predictions they make are not always the same when there is more than one valid - # minimum-weight solution.) - self.assertTrue(failures in {156, 168}) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/matching/test_matching.py b/test/matching/test_matching.py new file mode 100644 index 00000000..d3671a56 --- /dev/null +++ b/test/matching/test_matching.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2019. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name + +"""Run codes and decoders.""" + +import itertools +import unittest +from random import choices + +from qiskit_qec.circuits import ArcCircuit, RepetitionCodeCircuit +from qiskit_qec.decoders import PyMatchingDecoder +from qiskit_qec.codes.hhc import HHC +from qiskit_qec.circuits.css_code import CSSCodeCircuit +from qiskit_qec.utils.stim_tools import get_counts_via_stim + + +class TestMatching(unittest.TestCase): + """Test the PyMatching decoder.""" + + def test_repetition_codes(self, code): + """ + Test on repetition codes. + """ + + d = 8 + p = 0.1 + N = 1000 + + codes = [] + codes.append(RepetitionCodeCircuit(8, 1)) + codes.append(ArcCircuit([(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 1)], 1)) + codes.append( + ArcCircuit([(2 * j, 2 * j + 1, (2 * (j + 1)) % (2 * d - 2)) for j in range(d - 1)], 1) + ) + for c, code in enumerate(codes): + matcher = PyMatchingDecoder(code) + min_error_num = code.d + min_error_string = "" + for _ in range(N): + string = "".join([choices(["1", "0"], [1 - p, p])[0] for _ in range(d)]) + for _ in range(code.T): + string = string + " " + "0" * (d - 1) + # get and check corrected_z_logicals + corrected_z_logicals = matcher.process(string) + for node in matcher.decoding_graph.logical_nodes: + if node.index < len(corrected_z_logicals): + error = corrected_z_logicals[node.index] != 1 + if error: + error_num = string.split(" ", maxsplit=1)[0].count("0") + if error_num < min_error_num: + min_error_num = error_num + min_error_string = string + # check that min num errors to cause logical errors >d/2 + self.assertTrue( + min_error_num >= d / 2, + str(min_error_num) + + " errors cause logical error despite d=" + + str(code.d) + + " for code " + + str(c) + + " with " + + min_error_string + + ".", + ) + print(c, min_error_num, min_error_string) + + def test_css_codes(self, code): + """ + Test on CSS codes. + """ + d = 5 + p = 0.01 + N = 1000 + + codes = [CSSCodeCircuit(HHC(d), T=1, basis="x", noise_model=(p, p))] + for c, code in enumerate(codes): + matcher = PyMatchingDecoder(code) + stim_counts = get_counts_via_stim(code.noisy_circuit["1"], N) + errors = 0 + for string in stim_counts: + corrected_z_logicals = matcher.process(string) + if corrected_z_logicals: + if corrected_z_logicals[0] != 1: + errors += 1 + self.assertTrue( + errors / N < p, + "Logical error rate of" + + str(errors / N) + + " despite p=" + + str(p) + + " for code " + + str(c) + + ".", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/matching/test_pymatchingmatcher.py b/test/matching/test_pymatchingmatcher.py deleted file mode 100644 index e9378e1e..00000000 --- a/test/matching/test_pymatchingmatcher.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Tests for the rustworkx matcher subroutines.""" -import unittest - -from typing import Dict, Tuple -import rustworkx as rx -from qiskit_qec.decoders.pymatching_matcher import PyMatchingMatcher -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge - - -class TestPyMatchingMatcher(unittest.TestCase): - """Tests for the pymatching matcher subroutines.""" - - @staticmethod - def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: - """Make a basic decoding graph. - - 4 -- 0 -- 1 -- 2 -- 3 -- (4) - """ - graph = rx.PyGraph(multigraph=False) - idxmap = {} - for i, q in enumerate([[0, 1], [1, 2], [2, 3], [3, 4]]): - node = DecodingGraphNode(time=0, qubits=q, index=i) - node.properties["highlighted"] = False - graph.add_node(node) - idxmap[(0, tuple(q))] = i - node = {"time": 0, "qubits": [], "highlighted": False, "is_boundary": True} - node = DecodingGraphNode(is_boundary=True, qubits=[], index=0) - node.properties["highlighted"] = False - graph.add_node(node) - idxmap[(0, tuple([]))] = 4 - for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = DecodingGraphEdge( - qubits=dat[0], - weight=1, - ) - edge.properties["measurement_error"] = False - edge.properties["highlighted"] = False - graph.add_edge(dat[1], dat[2], edge) - return graph, idxmap - - def setUp(self) -> None: - self.m = PyMatchingMatcher() - - def test_match(self): - """Test matching example.""" - graph, idxmap = self.make_test_graph() - self.m.preprocess(graph) - highlighted = [(0, (0, 1)), (0, (1, 2)), (0, (3, 4)), (0, ())] # must be even - qubit_errors, measurement_errors = self.m.find_errors(graph, idxmap, highlighted) - self.assertEqual(qubit_errors, set([1, 4])) - self.assertEqual(measurement_errors, set()) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/matching/test_repetitionmatcher.py b/test/matching/test_repetitionmatcher.py deleted file mode 100644 index 6de89203..00000000 --- a/test/matching/test_repetitionmatcher.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Tests for the subsystem CSS circuit-level matching decoder.""" -import unittest - -from qiskit_aer import Aer - -from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.decoders.circuit_matching_decoder import temp_syndrome -from qiskit_qec.decoders.repetition_decoder import RepetitionDecoder -from qiskit_qec.circuits.repetition_code import RepetitionCodeCircuit -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel - - -class TestRepetitionCircuitMatcher(unittest.TestCase): - """Tests for the three bit decoder example.""" - - def setUp(self) -> None: - # Bit-flip circuit noise model - p = 0.05 - pnm = PauliNoiseModel() - pnm.add_operation("cx", {"ix": 1, "xi": 1, "xx": 1}) - pnm.add_operation("id", {"x": 1}) - pnm.add_operation("reset", {"x": 1}) - pnm.add_operation("measure", {"x": 1}) - pnm.add_operation("x", {"x": 1, "y": 1, "z": 1}) - pnm.set_error_probability("cx", p) - pnm.set_error_probability("x", p) - pnm.set_error_probability("id", p) - pnm.set_error_probability("reset", p) - pnm.set_error_probability("measure", p) - self.pnm = pnm - - # 3-bit, 2 round repetition code - self.code_circuit = RepetitionCodeCircuit(3, 2) - self.z_logical = self.code_circuit.css_z_logical - - # 5-bit, 2 round repetition code - self.code_circuit_5 = RepetitionCodeCircuit(5, 2) - self.z_logical_5 = self.code_circuit.css_z_logical - - def test_no_errors(self, method="rustworkx"): - """Test the case with no errors using rustworkx.""" - - def gint(c): - """Casts to int if possible""" - if c.isnumeric(): - return int(c) - else: - return c - - shots = 100 - seed = 100 - for logical in ["0", "1"]: - dec = RepetitionDecoder(self.code_circuit, self.pnm, method, False, logical) - qc = self.code_circuit.circuit[logical] - backend = Aer.get_backend("aer_simulator") - options = {"method": "stabilizer", "shots": shots, "seed_simulator": seed} - result = backend.run(qc, **options).result() - counts = result.get_counts(qc) - dec.update_edge_weights(self.pnm) - failures = 0 - for outcome, _ in counts.items(): - reversed_outcome = list(map(gint, outcome[::-1])) - corrected_outcomes = dec.process(reversed_outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - failures += str(fail[0]) != logical - self.assertEqual(failures, 0) - - def test_no_errors_pymatching(self): - """Test the case with no errors using pymatching.""" - self.test_no_errors(method="pymatching") - - def test_correct_single_errors(self, method="rustworkx"): - """Test the case with single faults using rustworkx.""" - for logical in ["0", "1"]: - dec = RepetitionDecoder(self.code_circuit, self.pnm, method, False, logical) - qc = self.code_circuit.circuit[logical] - dec.update_edge_weights(self.pnm) - fe = FaultEnumerator(qc, method="stabilizer", model=self.pnm) - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical) - self.assertEqual(str(fail[0]), logical) - - def test_correct_single_errors_pymatching(self): - """Test the case with two faults using pymatching.""" - self.test_correct_single_errors(method="pymatching") - - def test_error_pairs(self, dec_method="rustworkx", fe_method="stabilizer"): - """Test the case with two faults on a d=5 code using rustworkx.""" - expected_failures = {"0": 0, "1": 0} - for logical in ["0", "1"]: - dec = RepetitionDecoder(self.code_circuit_5, self.pnm, dec_method, False, logical) - dec.update_edge_weights(self.pnm) - qc = self.code_circuit_5.circuit[logical] - fe = FaultEnumerator(qc, order=2, method=fe_method, model=self.pnm) - failures = 0 - for faultpath in fe.generate(): - outcome = faultpath[3] - corrected_outcomes = dec.process(outcome) - fail = temp_syndrome(corrected_outcomes, self.z_logical_5) - failures += str(fail[0]) != logical - self.assertEqual(failures, expected_failures[logical]) - - def test_error_pairs_pymatching(self): - """Test the case with two faults on a d=5 code using pymatching.""" - self.test_error_pairs(dec_method="pymatching", fe_method="stabilizer") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/matching/test_retworkxmatcher.py b/test/matching/test_retworkxmatcher.py deleted file mode 100644 index 0f9f16ca..00000000 --- a/test/matching/test_retworkxmatcher.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Tests for the rustworkx matcher subroutines.""" -import unittest - -from typing import Dict, Tuple -import rustworkx as rx -from qiskit_qec.decoders.rustworkx_matcher import RustworkxMatcher -from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge - - -class TestRustworkxMatcher(unittest.TestCase): - """Tests for the rustworkx matcher subroutines.""" - - @staticmethod - def make_test_graph() -> Tuple[rx.PyGraph, Dict[Tuple[int, Tuple[int]], int]]: - """Make a basic decoding graph. - - 4 -- 0 -- 1 -- 2 -- 3 -- (4) - """ - graph = rx.PyGraph(multigraph=False) - idxmap = {} - basic_config = [[0, 1], [1, 2], [2, 3], [3, 4]] - for i, q in enumerate(basic_config): - node = DecodingGraphNode(time=0, qubits=q, index=i) - node.properties["highlighted"] = False - graph.add_node(node) - idxmap[(0, tuple(q))] = i - node = DecodingGraphNode(time=0, qubits=[], index=len(basic_config) + 1) - node.properties["highlighted"] = False - graph.add_node(node) - idxmap[(0, tuple([]))] = 4 - for dat in [[[0], 0, 4], [[1], 0, 1], [[2], 1, 2], [[3], 2, 3], [[4], 3, 4]]: - edge = DecodingGraphEdge(qubits=dat[0], weight=1) - edge.properties["measurement_error"] = False - edge.properties["highlighted"] = False - graph.add_edge(dat[1], dat[2], edge) - return graph, idxmap - - def setUp(self) -> None: - self.rxm = RustworkxMatcher(annotate=True) - - def test_preprocess(self): - """Test preprocessing example.""" - graph, _ = self.make_test_graph() - self.rxm.preprocess(graph) - self.assertEqual(self.rxm.length[0][1], 1) - self.assertEqual(self.rxm.length[1][3], 2) - self.assertEqual(self.rxm.length[2][4], 2) - self.assertEqual(self.rxm.length[0][3], 2) - self.assertEqual(self.rxm.path[0][2], [0, 1, 2]) - self.assertEqual(self.rxm.path[3][0], [3, 4, 0]) - - def test_match(self): - """Test matching example.""" - graph, idxmap = self.make_test_graph() - self.rxm.preprocess(graph) - highlighted = [(0, (0, 1)), (0, (1, 2)), (0, (3, 4)), (0, ())] # must be even - qubit_errors, measurement_errors = self.rxm.find_errors(graph, idxmap, highlighted) - self.assertEqual(qubit_errors, set([1, 4])) - self.assertEqual(measurement_errors, set()) - - def test_annotate(self): - """Test the annotated graph.""" - graph, idxmap = self.make_test_graph() - self.rxm.preprocess(graph) - highlighted = [(0, (0, 1)), (0, (1, 2)), (0, (3, 4)), (0, ())] # must be even - self.rxm.find_errors(graph, idxmap, highlighted) - self.assertEqual(self.rxm.annotated_graph[0].properties["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[1].properties["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[2].properties["highlighted"], False) - self.assertEqual(self.rxm.annotated_graph[3].properties["highlighted"], True) - self.assertEqual(self.rxm.annotated_graph[4].properties["highlighted"], True) - eim = self.rxm.annotated_graph.edge_index_map() - self.assertEqual(eim[0][2].properties["highlighted"], False) - self.assertEqual(eim[1][2].properties["highlighted"], True) - self.assertEqual(eim[2][2].properties["highlighted"], False) - self.assertEqual(eim[3][2].properties["highlighted"], False) - self.assertEqual(eim[4][2].properties["highlighted"], True) - - -if __name__ == "__main__": - unittest.main()