diff --git a/src/qiskit_qec/circuits/arctools.h b/src/qiskit_qec/circuits/arctools.h index 9835abad..2b6d32fb 100644 --- a/src/qiskit_qec/circuits/arctools.h +++ b/src/qiskit_qec/circuits/arctools.h @@ -7,19 +7,19 @@ #include std::vector check_nodes( - std::vector> nodes, bool ignore_extra_logicals, bool minimal, + std::vector> nodes, bool ignore_extras, bool minimal, std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, - std::vector z_logicals + std::map extras ); bool is_cluster_neutral( - std::vector> nodes, bool ignore_extra_logicals, bool minimal, + std::vector> nodes, bool ignore_extras, bool minimal, std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, - std::vector z_logicals, + std::map extras, bool linear ); diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py index d6bb28b3..0842e4e9 100644 --- a/src/qiskit_qec/circuits/code_circuit.py +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -53,7 +53,7 @@ def string2nodes(self, string, **kwargs): pass @abstractmethod - def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): + def check_nodes(self, nodes, ignore_extras=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -61,7 +61,7 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_logical (bool): If `True`, undeeded logical nodes are + ignore_extras (bool): If `True`, undeeded logical and boundary nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. @@ -76,10 +76,10 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): pass @abstractmethod - def is_cluster_neutral(self, atypical_nodes): + def is_cluster_neutral(self, nodes): """ Determines whether or not the cluster is neutral, meaning that one or more - errors could have caused the set of atypical nodes (syndrome changes) passed + errors could have caused the set of nodes (syndrome changes) passed to the method. Default version here assumes that it is as simple as an an even/odd assessment @@ -87,6 +87,6 @@ def is_cluster_neutral(self, atypical_nodes): more complex codes. It also should be used with care, by only supplying sets of nodes for which the even/odd assessment is valid. Args: - atypical_nodes (dictionary in the form of the return value of string2nodes) + nodes (dictionary in the form of the return value of string2nodes) """ - return not bool(len(atypical_nodes) % 2) + return not bool(len(nodes) % 2) diff --git a/src/qiskit_qec/circuits/css_code.py b/src/qiskit_qec/circuits/css_code.py index 67a8f285..5f0699d6 100644 --- a/src/qiskit_qec/circuits/css_code.py +++ b/src/qiskit_qec/circuits/css_code.py @@ -279,10 +279,10 @@ def string2raw_logicals(self, string): log_outs = string2logical_meas(string, self.logicals, self.circuit["0"].clbits) return log_outs - def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): + def check_nodes(self, nodes, ignore_extras=False, minimal=False): raise NotImplementedError - def is_cluster_neutral(self, atypical_nodes): + def is_cluster_neutral(self, nodes): raise NotImplementedError def stim_detectors(self): diff --git a/src/qiskit_qec/circuits/intern/arctools.cpp b/src/qiskit_qec/circuits/intern/arctools.cpp index af7c08b6..b9372f0a 100644 --- a/src/qiskit_qec/circuits/intern/arctools.cpp +++ b/src/qiskit_qec/circuits/intern/arctools.cpp @@ -2,30 +2,34 @@ #include bool is_cluster_neutral( - std::vector> nodes, bool ignore_extra_logicals, bool minimal, - std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, std::vector z_logicals, + std::vector> nodes, bool ignore_extras, bool minimal, + std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, std::map extras, bool linear ) { if (linear) { return nodes.size()%2==0; } else { - std::vector output = check_nodes(nodes, ignore_extra_logicals, minimal, cycle_dict, link_graph, link_neighbors, z_logicals); - return (output[0]==1) and (output.size()==2); + std::vector output = check_nodes(nodes, ignore_extras, minimal, cycle_dict, link_graph, link_neighbors, extras); + bool no_boundary = true; + for (int j = 2; j < output.size(); j++) { + no_boundary = no_boundary and (extras[output[j]]<2); + }; + return (output[0]==1) and no_boundary; } - }; std::vector check_nodes( - std::vector> nodes, bool ignore_extra_logicals, bool minimal, - std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, std::vector z_logicals + std::vector> nodes, bool ignore_extras, bool minimal, + std::map, std::set> cycle_dict, std::vector> link_graph, std::map> link_neighbors, std::map extras ) { - // output[0] is neutral (as int), output[1] is num_errors, rest is list of given logicals + // output[0] is neutral (as int), output[1] is num_errors, rest is list of given extras std::vector output; - // we convert to flat nodes, which are a std::tuple with (q0, q1, boundary) + // we convert to flat nodes, which are a std::tuple with (q0, q1, boundary/logical) + // here boundary/logical is 0 if the node is neither boundary or logical, 1 for logical, 2 for bounary and 3 for both // if we have an even number of corresponding nodes, they cancel - std::map, int> node_counts; + std::map, int> node_counts; for (auto & node : nodes) { node_counts[std::make_tuple(std::get<0>(node), std::get<1>(node), std::get<3>(node))] = 0; } @@ -39,12 +43,12 @@ std::vector check_nodes( flat_nodes.push_back(node_count.first); } } - // see what logicals and bulk nodes are given - std::set given_logicals; + // see what extras and bulk nodes are given + std::set given_extras; std::set> bulk_nodes; for (auto & node : flat_nodes) { - if (std::get<2>(node)) { - given_logicals.insert(std::get<0>(node)); + if (std::get<2>(node) > 0) { + given_extras.insert(std::get<0>(node)); } else { bulk_nodes.insert(node); } @@ -52,12 +56,12 @@ std::vector check_nodes( if (bulk_nodes.size()==0){ // without bulk nodes, neutral only if no logical nodes are given (unless this is ignored) - int neutral = (ignore_extra_logicals || given_logicals.size()==0); + int neutral = (ignore_extras || given_extras.size()==0); int num_errors = 0; // compile the output output.push_back(neutral); output.push_back(num_errors); - // no flipped logicals need to be added + // no flipped extras need to be added } else { std::map parities; // check how many times the bulk nodes turn up in each cycle @@ -79,7 +83,7 @@ std::vector check_nodes( output.push_back(0); // number of errors not counted output.push_back(-1); - // no flipped logicals need to be added + // no flipped extras need to be added } else { // now we must bicolor the qubits of the link graph, such that node edges connect unlike edges @@ -163,53 +167,53 @@ std::vector check_nodes( if (not minimal){ cs.push_back((min_color+1)%2); } - // determine which flipped logicals correspond to which colour - std::vector> color_logicals = {{}, {}}; - for (auto & q: z_logicals){ - if (color.find(q) == color.end()){ - color[q] = (conv_color+1)%2; + // determine which flipped extras correspond to which colour + std::vector> color_extras = {{}, {}}; + for (auto & qe: extras){ + if (color.find(qe.first) == color.end()){ + color[qe.first] = (conv_color+1)%2; } - color_logicals[color[q]].insert(q); + color_extras[color[qe.first]].insert(qe.first); } - // see if we can find a color for which we have no extra logicals - // and see what additional logicals are required - std::set flipped_logicals; - std::set flipped_ng_logicals; - std::vector extra_logicals; + // see if we can find a color for which we have no extra extras + // and see what additional extras are required + std::set flipped_extras; + std::set flipped_ng_extras; + std::vector extra_extras; bool done = false; int j = 0; while (not done){ - flipped_logicals = {}; - flipped_ng_logicals = {}; - // see which logicals for this colour have been flipped - for (auto & q: color_logicals[cs[j]]){ - flipped_logicals.insert(q); + flipped_extras = {}; + flipped_ng_extras = {}; + // see which extras for this colour have been flipped + for (auto & q: color_extras[cs[j]]){ + flipped_extras.insert(q); // and which of those were not given - if (given_logicals.find(q) == given_logicals.end()) { - flipped_ng_logicals.insert(q); + if (given_extras.find(q) == given_extras.end()) { + flipped_ng_extras.insert(q); } } - // see which extra logicals are given - extra_logicals = {}; - if (not ignore_extra_logicals) { - for (auto & q: given_logicals){ - if ((flipped_logicals.find(q) == flipped_logicals.end())) { - extra_logicals.push_back(q); + // see which extra extras are given + extra_extras = {}; + if (not ignore_extras) { + for (auto & q: given_extras){ + if ((flipped_extras.find(q) == flipped_extras.end())) { + extra_extras.push_back(q); } } } - // if we have no extra logicals, or we've run out of colours, we move on + // if we have no extra extras, or we've run out of colours, we move on // otherwise we try the next colour - done = (extra_logicals.size()==0) || (j+1)==cs.size(); + done = (extra_extras.size()==0) || (j+1)==cs.size(); if (not done){ j++; } } // construct output - output.push_back(extra_logicals.size()==0); // neutral + output.push_back(extra_extras.size()==0); // neutral output.push_back(num_nodes[cs[j]]); // num_errors - for (auto & q: flipped_ng_logicals){ + for (auto & q: flipped_ng_extras){ output.push_back(q); } diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index a11cc366..e80931da 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -295,7 +295,7 @@ def string2nodes(self, string, **kwargs): separated_string = _separate_string(string) nodes = [] - # logical nodes + # logical/boundary nodes boundary = separated_string[0] # [, ] for bqec_index, belement in enumerate(boundary[::-1]): if all_logicals or belement != logical: @@ -326,7 +326,7 @@ def string2raw_logicals(self, string): """ return [string.split(" ", maxsplit=1)[0][-1]] - def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): + def check_nodes(self, nodes, ignore_extras=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -334,7 +334,7 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_logical (bool): If `True`, undeeded boundary nodes are + ignore_extras (bool): If `True`, undeeded boundary nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. @@ -397,7 +397,7 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): # if unneeded logical zs are given, cluster is not neutral # (unless this is ignored) - if (not ignore_extra_logical) and given_logicals.difference(flipped_logicals): + if (not ignore_extras) and given_logicals.difference(flipped_logicals): neutral = False # otherwise, report only needed logicals that aren't given else: @@ -418,15 +418,15 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): return neutral, flipped_logical_nodes, num_errors - def is_cluster_neutral(self, atypical_nodes): + def is_cluster_neutral(self, nodes): """ Determines whether or not the cluster is neutral, meaning that one or more - errors could have caused the set of atypical nodes (syndrome changes) passed + errors could have caused the set of nodes (syndrome changes) passed to the method. Args: - atypical_nodes (dictionary in the form of the return value of string2nodes) + nodes (list of nodes) """ - return not bool(len(atypical_nodes) % 2) + return not bool(len(nodes) % 2) def partition_outcomes( self, round_schedule: str, outcome: List[int] @@ -785,13 +785,17 @@ def _preparation(self): self.circuit[basis].x(self.code_qubit) self._basis_change(basis) - # use degree 1 code qubits for logical z readouts + # use degree 1 code qubits for logical z readouts (and boundary) graph = self._get_coupling_graph() + self._leaves = False z_logicals = [] + self.boundary = [] for n, node in enumerate(graph.nodes()): if graph.degree(n) == 1: z_logicals.append(node) - # if there are none, just use the first + self.boundary.append(node) + self._leaves = True + # if there are none, just use the first (not boundary) if not z_logicals: z_logicals = [min(self.code_index.keys())] self.z_logicals = z_logicals @@ -1070,6 +1074,7 @@ def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]: tau = 0 node = DecodingGraphNode( is_logical=is_logical, + is_boundary=(is_logical and self._leaves), time=syn_round if not is_logical else None, qubits=code_qubits, index=elem_num, @@ -1131,7 +1136,20 @@ def _links2cpp(self): link_neighbors[node].append(nodes[j]) return link_graph, link_neighbors - def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): + def _extras2cpp(self): + """ + Returns logical and boundary nodes as tuples. First value is the qubit, + second is 1 for logical only, 2 for boundary only and 3 for both. + """ + extras = {} + for q in self.z_logicals: + extras[q] = 1 + 2 * (q in self.boundary) + for q in self.boundary: + if q not in self.z_logicals: + extras[q] = 2 + return extras + + def check_nodes(self, nodes, ignore_extras=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -1139,7 +1157,7 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_logical (bool): If `True`, undeeded boundary nodes are + ignore_extras (bool): If `True`, undeeded boundary and logical nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. @@ -1156,36 +1174,43 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): cpp_output = _c_check_nodes( nodes, - ignore_extra_logical, + ignore_extras, minimal, self.cycle_dict, self._cpp_link_graph, self._cpp_link_neighbors, - self.z_logicals, + self._extras2cpp(), ) neutral = bool(cpp_output[0]) num_errors = cpp_output[1] - flipped_logical_nodes = [] - for flipped_logical in cpp_output[2::]: + flipped_extra_nodes = [] + for flipped_extra in cpp_output[2::]: + is_logical = flipped_extra in self.z_logicals + is_boundary = flipped_extra in self.boundary + if is_logical: + index = self.z_logicals.index(flipped_extra) + else: + index = self.boundary.index(flipped_extra) node = DecodingGraphNode( - is_logical=True, - qubits=[flipped_logical], - index=self.z_logicals.index(flipped_logical), + is_logical=is_logical, + is_boundary=is_boundary, + qubits=[flipped_extra], + index=index, ) - flipped_logical_nodes.append(node) + flipped_extra_nodes.append(node) - return neutral, flipped_logical_nodes, num_errors + return neutral, flipped_extra_nodes, num_errors - def is_cluster_neutral(self, atypical_nodes: dict): + def is_cluster_neutral(self, nodes: dict): """ Determines whether or not the cluster is neutral, meaning that one or more - errors could have caused the set of atypical nodes (syndrome changes) passed + errors could have caused the set of nodes (syndrome changes) passed to the method. Args: - atypical_nodes: dictionary in the form of the return value of string2nodes + nodes: dictionary in the form of the return value of string2nodes """ - nodes = _nodes2cpp(atypical_nodes) + nodes = _nodes2cpp(nodes) return _c_is_cluster_neutral( nodes, False, @@ -1193,7 +1218,7 @@ def is_cluster_neutral(self, atypical_nodes: dict): self.cycle_dict, self._cpp_link_graph, self._cpp_link_neighbors, - self.z_logicals, + self._extras2cpp(), self._linear, ) diff --git a/src/qiskit_qec/circuits/stim_code_circuit.py b/src/qiskit_qec/circuits/stim_code_circuit.py index 3325e5d8..8dade69c 100644 --- a/src/qiskit_qec/circuits/stim_code_circuit.py +++ b/src/qiskit_qec/circuits/stim_code_circuit.py @@ -587,5 +587,5 @@ def _make_syndrome_graph(self): graph, hyperedges = detector_error_model_to_rx_graph(e, detectors=self.detectors) return graph, hyperedges - def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): + def check_nodes(self, nodes, ignore_extras=False, minimal=False): raise NotImplementedError diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 6a7f2ba1..a779186d 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -395,6 +395,7 @@ def string2nodes(self, string, **kwargs): if all_logicals or belement != logical: node = DecodingGraphNode( is_logical=True, + is_boundary=True, qubits=self._logicals[self.basis][-bqec_index - 1], index=1 - bqec_index, ) @@ -414,7 +415,7 @@ def string2nodes(self, string, **kwargs): nodes.append(node) return nodes - def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): + def check_nodes(self, nodes, ignore_extras=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -422,7 +423,7 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): would be required to make the cluster. Args: nodes (list): List of nodes, of the type produced by `string2nodes`. - ignore_extra_logical (bool): If `True`, undeeded logical nodes are + ignore_extras (bool): If `True`, undeeded logical nodes are ignored. minimal (bool): Whether output should only reflect the minimal error case. @@ -445,7 +446,7 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): coords = self._xplaq_coords if (len(bulk_nodes) % 2) == 0: - if (len(logical_nodes) % 2) == 0 or ignore_extra_logical: + if (len(logical_nodes) % 2) == 0 or ignore_extras: neutral = True flipped_logicals = set() # estimate num_errors from size @@ -481,7 +482,7 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): # if unneeded logical zs are given, cluster is not neutral # (unless this is ignored) - if (not ignore_extra_logical) and given_logicals.difference(flipped_logicals): + if (not ignore_extras) and given_logicals.difference(flipped_logicals): neutral = False # otherwise, report only needed logicals that aren't given else: @@ -492,18 +493,21 @@ def check_nodes(self, nodes, ignore_extra_logical=False, minimal=False): flipped_logical_nodes = [] for elem in flipped_logicals: node = DecodingGraphNode( - is_logical=True, qubits=self._logicals[self.basis][elem], index=elem + is_logical=True, + is_boundary=True, + qubits=self._logicals[self.basis][elem], + index=elem, ) flipped_logical_nodes.append(node) return neutral, flipped_logical_nodes, num_errors - def is_cluster_neutral(self, atypical_nodes): + def is_cluster_neutral(self, nodes): """ Determines whether or not the cluster is neutral, meaning that one or more - errors could have caused the set of atypical nodes (syndrome changes) passed + errors could have caused the set of nodes (syndrome changes) passed to the method. Args: - atypical_nodes (dictionary in the form of the return value of string2nodes) + nodes (dictionary in the form of the return value of string2nodes) """ - return not bool(len(atypical_nodes) % 2) + return not bool(len(nodes) % 2) diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 4fa96b68..d6ed8e7d 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -96,16 +96,21 @@ def get_corrections(self, string, clusters): class BravyiHaahDecoder(ClusteringDecoder): """Decoder based on finding connected components within the decoding graph.""" + def __init__( + self, + code_circuit, + decoding_graph: DecodingGraph = None, + ): + super().__init__(code_circuit, decoding_graph) + self._distance = distance_matrix(self.decoding_graph.graph) + def _cluster(self, ns, dist_max): """ Finds connected components in the given nodes, for nodes connected by at most the given distance in the given decoding graph. """ - # calculate distance for the graph dg = self.decoding_graph.graph - distance = distance_matrix(dg) - # create empty `DecodingGraph` cluster_graph = DecodingGraph(None) cg = cluster_graph.graph @@ -120,7 +125,7 @@ def _cluster(self, ns, dist_max): for n0 in ns: for n1 in ns: if n0 < n1: - dist = distance[n0, n1] + dist = self._distance[n0, n1] if dist <= dist_max: cg.add_edge(d2c[n0], d2c[n1], {"distance": dist}) # find the connected components of cg @@ -134,9 +139,7 @@ def _cluster(self, ns, dist_max): # check the neutrality of each connected component con_nodes = [cg[n] for n in con_comp] - neutral, logicals, num_errors = self.code.check_nodes( - con_nodes, ignore_extra_logical=True - ) + neutral, logicals, num_errors = self.code.check_nodes(con_nodes, ignore_extras=True) # it's fully neutral if no extra logicals are needed # and if the error num is less than the max dist @@ -266,6 +269,7 @@ class UnionFindDecoderCluster: boundary_nodes: Set[int] nodes: Set[int] fully_grown_edges: Set[int] + edge_support: Set[Tuple[int]] size: int @@ -287,17 +291,22 @@ class UnionFindDecoder(ClusteringDecoder): by the peeling decoder for compatible codes or by the standard HDRG method in general. - TODO: Add weights to edges of graph according to Huang et al (see. arXiv:2004.04693, section III) - See arXiv:1709.06218v3 for more details. """ - def __init__(self, code, decoding_graph: DecodingGraph = None, use_peeling=True) -> None: - super().__init__(code, decoding_graph=decoding_graph) - self.graph = deepcopy(self.decoding_graph.graph) + def __init__( + self, + code, + decoding_graph: DecodingGraph = None, + use_peeling=True, + use_is_cluster_neutral=False, + ) -> None: + super().__init__(code, decoding_graph=deepcopy(decoding_graph)) + self.graph = self.decoding_graph.graph self.clusters: Dict[int, UnionFindDecoderCluster] = {} self.odd_cluster_roots: List[int] = [] self.use_peeling = use_peeling + self.use_is_cluster_neutral = use_is_cluster_neutral self._clusters4peeling = [] def process(self, string: str, predecoder=None): @@ -369,7 +378,7 @@ def cluster(self, nodes: List): value. """ node_indices = [self.decoding_graph.node_index(node) for node in nodes] - for node_index, _ in enumerate(self.graph.nodes()): + for node_index in self.graph.node_indexes(): self.graph[node_index].properties["syndrome"] = node_index in node_indices self.graph[node_index].properties["root"] = node_index @@ -434,6 +443,7 @@ def _create_new_cluster(self, node_index): self.clusters[node_index] = UnionFindDecoderCluster( boundary=boundary_edges, fully_grown_edges=set(), + edge_support=set(), atypical_nodes=set([node_index]) if not node.is_logical else set([]), boundary_nodes=set([node_index]) if node.is_logical else set([]), nodes=set([node_index]), @@ -476,6 +486,7 @@ def _grow_clusters(self) -> List[FusionEntry]: self.clusters[edge.neighbour_vertex] = UnionFindDecoderCluster( boundary=boundary_edges, fully_grown_edges=set(), + edge_support=set(), atypical_nodes=set(), boundary_nodes=set([edge.neighbour_vertex]) if self.graph[edge.neighbour_vertex].is_logical @@ -517,6 +528,9 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]): entry.connecting_edge.data.properties["growth"] = 0 entry.connecting_edge.data.properties["fully_grown"] = True cluster.fully_grown_edges.add(entry.connecting_edge.index) + cluster.edge_support.add( + tuple(self.graph.get_edge_data_by_index(entry.connecting_edge.index).qubits) + ) # Merge boundaries cluster.boundary += other_cluster.boundary @@ -527,18 +541,28 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]): cluster.atypical_nodes |= other_cluster.atypical_nodes cluster.boundary_nodes |= other_cluster.boundary_nodes cluster.fully_grown_edges |= other_cluster.fully_grown_edges + cluster.edge_support |= other_cluster.edge_support cluster.size += other_cluster.size - # update odd_cluster_roots - if self.code.is_cluster_neutral( - [self.graph[node] for node in cluster.atypical_nodes] - ) or self.code.is_cluster_neutral( + # see if the cluster is neutral and update odd_cluster_roots accordingly + fully_neutral = False + for nodes in [ + [self.graph[node] for node in cluster.atypical_nodes], [ self.graph[node] for node in cluster.atypical_nodes | (set(list(cluster.boundary_nodes)[:1]) if cluster.boundary_nodes else set()) - ] - ): + ], + ]: + if self.use_is_cluster_neutral: + fully_neutral = self.code.is_cluster_neutral(nodes) + else: + neutral, extras, num = self.code.check_nodes(nodes) + for node in extras: + neutral = neutral and (not node.is_boundary) + neutral = neutral and num <= len(cluster.edge_support) + fully_neutral = fully_neutral or neutral + if fully_neutral: if new_root in self.odd_cluster_roots: self.odd_cluster_roots.remove(new_root) new_neutral_clusters.append(new_root) diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index 3e619590..9571b4b2 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -179,7 +179,9 @@ def _nodes2cpp(nodes): """ Convert a list of nodes to the form required by C++ functions. """ - # nodes are a tuple with (q0, q1,t, boundary) + # nodes are a tuple with (q0, q1,t, extra) + # extra is 0 if neither logical nor boundary + # 1 for logical, 2 for boundary, 3 for both # if there is no q1 or t, -1 is used cnodes = [] for node in nodes: @@ -190,6 +192,6 @@ def _nodes2cpp(nodes): cnode.append(-1) else: cnode.append(node.time) - cnode.append(node.is_logical) + cnode.append(1 * node.is_logical + 2 * node.is_boundary) cnodes.append(tuple(cnode)) return cnodes diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index fc3d4790..a5d3b576 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -242,6 +242,11 @@ def single_error_test( ts = [node.time for node in nodes if not node.is_logical] if ts: minimal = minimal and (max(ts) - min(ts)) <= 1 + # check that it corresponds to more than one node (or none) + self.assertTrue( + len(nodes) != 1, + "Error: Single error creates only one node", + ) # check that it doesn't extend beyond the neigbourhood of a code qubit flat_nodes = code.flatten_nodes(nodes) link_qubits = set(node.properties["link qubit"] for node in flat_nodes) diff --git a/test/code_circuits/test_surface_codes.py b/test/code_circuits/test_surface_codes.py index c8cc8b36..22ba1417 100644 --- a/test/code_circuits/test_surface_codes.py +++ b/test/code_circuits/test_surface_codes.py @@ -68,6 +68,7 @@ def test_string2nodes(self): [ DecodingGraphNode( is_logical=True, + is_boundary=True, qubits=[0, 3, 6], index=0, ), @@ -80,6 +81,7 @@ def test_string2nodes(self): [ DecodingGraphNode( is_logical=True, + is_boundary=True, qubits=[2, 5, 8], index=1, ), @@ -117,11 +119,11 @@ def test_string2nodes(self): ), ], [ - DecodingGraphNode(is_logical=True, qubits=[0, 1, 2], index=0), + DecodingGraphNode(is_logical=True, is_boundary=True, qubits=[0, 1, 2], index=0), DecodingGraphNode(time=1, qubits=[0, 3], index=0), ], [ - DecodingGraphNode(is_logical=True, qubits=[8, 7, 6], index=1), + DecodingGraphNode(is_logical=True, is_boundary=True, qubits=[8, 7, 6], index=1), DecodingGraphNode(time=1, qubits=[5, 8], index=3), ], ] @@ -133,7 +135,15 @@ def test_string2nodes(self): generated_nodes = code.string2nodes(string) self.assertTrue( generated_nodes == nodes, - "Incorrect nodes for basis = " + basis + " for string = " + string + ".", + "Nodes for basis = " + + basis + + " and string = " + + string + + " are\ + \n" + + str(generated_nodes) + + " not\n" + + str(nodes), ) def test_check_nodes(self): @@ -149,26 +159,26 @@ def test_check_nodes(self): valid = valid and code.check_nodes(nodes) == (True, [], 0) # on one side nodes = [ - DecodingGraphNode(qubits=[0, 1, 2], is_logical=True, index=0), + DecodingGraphNode(qubits=[0, 1, 2], is_logical=True, is_boundary=True, index=0), DecodingGraphNode(time=3, qubits=[0, 3], index=0), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) nodes = [DecodingGraphNode(time=3, qubits=[0, 3], index=0)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[0, 1, 2], is_logical=True, index=0)], + [DecodingGraphNode(qubits=[0, 1, 2], is_logical=True, is_boundary=True, index=0)], 1.0, ) # and the other nodes = [ - DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, index=1), + DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, is_boundary=True, index=1), DecodingGraphNode(time=3, qubits=[5, 8], index=3), ] valid = valid and code.check_nodes(nodes) == (True, [], 1.0) nodes = [DecodingGraphNode(time=3, qubits=[5, 8], index=3)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, index=1)], + [DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, is_boundary=True, index=1)], 1.0, ) # and in the middle @@ -180,7 +190,7 @@ def test_check_nodes(self): nodes = [DecodingGraphNode(time=3, qubits=[3, 6, 4, 7], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, index=1)], + [DecodingGraphNode(qubits=[8, 7, 6], is_logical=True, is_boundary=True, index=1)], 1.0, ) @@ -194,7 +204,7 @@ def test_check_nodes(self): nodes = [DecodingGraphNode(time=3, qubits=[4, 5, 7, 8], index=2)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[2, 5, 8], is_logical=True, index=1)], + [DecodingGraphNode(qubits=[2, 5, 8], is_logical=True, is_boundary=True, index=1)], 1.0, ) @@ -208,33 +218,43 @@ def test_check_nodes(self): nodes = [DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7)] valid = valid and code.check_nodes(nodes) == ( True, - [DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1)], + [ + DecodingGraphNode( + qubits=[24, 23, 22, 21, 20], is_logical=True, is_boundary=True, index=1 + ) + ], 2.0, ) # wrong logical nodes = [ DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), - DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1), + DecodingGraphNode( + qubits=[24, 23, 22, 21, 20], is_logical=True, is_boundary=True, index=1 + ), ] valid = valid and code.check_nodes(nodes) == ( False, - [DecodingGraphNode(qubits=[0, 1, 2, 3, 4], is_logical=True, index=0)], + [DecodingGraphNode(qubits=[0, 1, 2, 3, 4], is_logical=True, is_boundary=True, index=0)], 2, ) # extra logical nodes = [ DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), - DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1), + DecodingGraphNode( + qubits=[24, 23, 22, 21, 20], is_logical=True, is_boundary=True, index=1 + ), ] valid = valid and code.check_nodes(nodes) == (False, [], 0) # ignoring extra nodes = [ DecodingGraphNode(time=3, qubits=[7, 12, 8, 13], index=4), DecodingGraphNode(time=3, qubits=[11, 16, 12, 17], index=7), - DecodingGraphNode(qubits=[24, 23, 22, 21, 20], is_logical=True, index=1), + DecodingGraphNode( + qubits=[24, 23, 22, 21, 20], is_logical=True, is_boundary=True, index=1 + ), ] - valid = valid and code.check_nodes(nodes, ignore_extra_logical=True) == (True, [], 1) + valid = valid and code.check_nodes(nodes, ignore_extras=True) == (True, [], 1) self.assertTrue(valid, "A set of nodes did not give the expected outcome for check_nodes.") diff --git a/test/decoding/test_union_find.py b/test/decoding/test_union_find.py index 08f2335a..53016b95 100644 --- a/test/decoding/test_union_find.py +++ b/test/decoding/test_union_find.py @@ -16,6 +16,7 @@ from unittest import TestCase from qiskit_qec.decoders import UnionFindDecoder from qiskit_qec.circuits import SurfaceCodeCircuit +from qiskit_qec.circuits.repetition_code import ArcCircuit class UnionFindDecoderTest(TestCase): @@ -28,14 +29,42 @@ def test_surface_code_d3(self): """ for logical in ["0", "1"]: code = SurfaceCodeCircuit(d=3, T=1) - decoder = UnionFindDecoder(code) - for j in range(code.n): - string = logical * (j) + str((1 + int(logical)) % 2) + logical * (code.n - j - 1) - string += " 0000 0000" - corrected_outcome = decoder.process(string) - self.assertTrue( - corrected_outcome[0] == int(logical), "Correction for surface code failed." - ) + for use_is_cluster_neutral in [True, False]: + decoder = UnionFindDecoder(code, use_is_cluster_neutral=use_is_cluster_neutral) + for j in range(code.n): + string = ( + logical * (j) + str((1 + int(logical)) % 2) + logical * (code.n - j - 1) + ) + string += " 0000 0000" + corrected_outcome = decoder.process(string) + self.assertTrue( + corrected_outcome[0] == int(logical), "Correction for surface code failed." + ) + + def test_hourglass_ARC(self): + """ + Tests that clustering is done correctly on an ARC with (possibly misleading) loops + """ + links = [ + (0, 1, 2), + (0, 3, 4), + (2, 5, 6), + (4, 7, 8), + (6, 9, 8), + (8, 11, 10), + (10, 13, 12), + (12, 15, 14), + (12, 17, 16), + (14, 19, 18), + (16, 21, 20), + (18, 22, 20), + ] + code = ArcCircuit(links, 0) + decoder = UnionFindDecoder(code) + cluster = decoder.cluster(code.string2nodes("00001110000")) + self.assertTrue( + len(set(cluster.values())) == 1, "Clustering doesn't create single cluster." + ) if __name__ == "__main__":