diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index d6ed8e7d..754492dc 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -17,11 +17,12 @@ """Hard decision renormalization group decoders.""" from abc import ABC -from copy import copy, deepcopy +from copy import copy from dataclasses import dataclass from typing import Dict, List, Set, Tuple from rustworkx import PyGraph, connected_components, distance_matrix +import numpy as np from qiskit_qec.decoders.decoding_graph import DecodingGraph from qiskit_qec.utils import DecodingGraphEdge @@ -291,7 +292,11 @@ class UnionFindDecoder(ClusteringDecoder): by the peeling decoder for compatible codes or by the standard HDRG method in general. - See arXiv:1709.06218v3 for more details. + To avoid using the peeling decoder, and instead use the standard + method for clustering decoders to get corrections, set `use_peeling=False`. + Growth unit is 0.5 by default, but can be changed with `growth_unit`. + To use half the minimum boundarye edge weight for each clustering round, + set `growth_unit=None`. """ def __init__( @@ -300,14 +305,17 @@ def __init__( decoding_graph: DecodingGraph = None, use_peeling=True, use_is_cluster_neutral=False, + growth_unit=0.5, ) -> None: - super().__init__(code, decoding_graph=deepcopy(decoding_graph)) + super().__init__(code, decoding_graph=decoding_graph) self.graph = self.decoding_graph.graph self.clusters: Dict[int, UnionFindDecoderCluster] = {} self.odd_cluster_roots: List[int] = [] self.use_peeling = use_peeling self.use_is_cluster_neutral = use_is_cluster_neutral self._clusters4peeling = [] + self.growth_unit = growth_unit + self._growth_unit = None def process(self, string: str, predecoder=None): """ @@ -324,7 +332,6 @@ def process(self, string: str, predecoder=None): """ if self.use_peeling: - self.graph = deepcopy(self.decoding_graph.graph) highlighted_nodes = self.code.string2nodes(string, all_logicals=True) if predecoder: highlighted_nodes = predecoder(highlighted_nodes) @@ -377,6 +384,11 @@ def cluster(self, nodes: List): the given node as keys and an integer specifying their cluster as the corresponding value. """ + if self.growth_unit: + self._growth_unit = self.growth_unit + else: + self._growth_unit = 0 + node_indices = [self.decoding_graph.node_index(node) for node in nodes] for node_index in self.graph.node_indexes(): self.graph[node_index].properties["syndrome"] = node_index in node_indices @@ -439,7 +451,7 @@ def _create_new_cluster(self, node_index): self.odd_cluster_roots.insert(0, node_index) boundary_edges = [] for edge_index, neighbour, data in self.neighbouring_edges(node_index): - boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, data)) + boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, copy(data))) self.clusters[node_index] = UnionFindDecoderCluster( boundary=boundary_edges, fully_grown_edges=set(), @@ -452,7 +464,7 @@ def _create_new_cluster(self, node_index): def _grow_and_merge_clusters(self) -> Set[int]: fusion_edge_list = self._grow_clusters() - return self._merge_clusters(fusion_edge_list) + self._merge_clusters(fusion_edge_list) def _grow_clusters(self) -> List[FusionEntry]: """ @@ -463,10 +475,18 @@ def _grow_clusters(self) -> List[FusionEntry]: clusters that will be merged in the next step. """ fusion_edge_list: List[FusionEntry] = [] + + if not self.growth_unit: + min_weight = np.inf + for root in self.odd_cluster_roots: + cluster = self.clusters[root] + for edge in cluster.boundary: + min_weight = max(min(min_weight, edge.data.weight), 1e-6) + self._growth_unit = min_weight / 2 for root in self.odd_cluster_roots: cluster = self.clusters[root] for edge in cluster.boundary: - edge.data.properties["growth"] += 0.5 + edge.data.properties["growth"] += self._growth_unit if ( edge.data.properties["growth"] >= edge.data.weight and not edge.data.properties["fully_grown"] @@ -508,9 +528,8 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]): Args: fusion_edge_list (List[FusionEntry]): List of edges that connect two clusters that was computed in _grow_clusters(). - Returns: - new_neutral_cluster_roots (List[int]): List of roots of newly neutral clusters """ + new_neutral_clusters = [] for entry in fusion_edge_list: root_u, root_v = self.find(entry.u), self.find(entry.v) @@ -546,22 +565,27 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]): # see if the cluster is neutral and update odd_cluster_roots accordingly fully_neutral = False - for nodes in [ - [self.graph[node] for node in cluster.atypical_nodes], - [ - self.graph[node] - for node in cluster.atypical_nodes - | (set(list(cluster.boundary_nodes)[:1]) if cluster.boundary_nodes else set()) - ], - ]: - if self.use_is_cluster_neutral: - fully_neutral = self.code.is_cluster_neutral(nodes) - else: - neutral, extras, num = self.code.check_nodes(nodes) - for node in extras: - neutral = neutral and (not node.is_boundary) - neutral = neutral and num <= len(cluster.edge_support) - fully_neutral = fully_neutral or neutral + if self._growth_unit: # assume non-neutral while growing along 0-weight edges + for nodes in [ + [self.graph[node] for node in cluster.atypical_nodes], + [ + self.graph[node] + for node in cluster.atypical_nodes + | ( + set(list(cluster.boundary_nodes)[:1]) + if cluster.boundary_nodes + else set() + ) + ], + ]: + if self.use_is_cluster_neutral: + fully_neutral = self.code.is_cluster_neutral(nodes) + else: + neutral, extras, num = self.code.check_nodes(nodes) + for node in extras: + neutral = neutral and (not node.is_boundary) + neutral = neutral and num <= len(cluster.edge_support) + fully_neutral = fully_neutral or neutral if fully_neutral: if new_root in self.odd_cluster_roots: self.odd_cluster_roots.remove(new_root) diff --git a/src/qiskit_qec/decoders/pymatching_decoder.py b/src/qiskit_qec/decoders/pymatching_decoder.py index 1e1f24ab..cf9bd685 100644 --- a/src/qiskit_qec/decoders/pymatching_decoder.py +++ b/src/qiskit_qec/decoders/pymatching_decoder.py @@ -89,19 +89,17 @@ def matched_edges( or a list of binaries indicating which node is highlighted. Returns: list of DecodingGraphEdge-s included in the matching """ - if isinstance(syndrome[0], DecodingGraphNode): - syndrome = self.nodes_to_detections(syndrome) - edge_dets = list(self.graph.edge_list()) - edges = self.graph.edges() - matched_det_pairs = self.matcher.decode_to_edges_array(syndrome) det_pairs = [] - for pair in matched_det_pairs: - if pair[1] == -1: - pair[-1] = pair[-1] + len(self.graph.nodes()) - pair.sort() - det_pairs.append(tuple(pair)) - mached_edges = [edges[edge_dets.index(det_pair)] for det_pair in det_pairs] - return mached_edges + if syndrome: + if isinstance(syndrome[0], DecodingGraphNode): + syndrome = self.nodes_to_detections(syndrome) + matched_det_pairs = self.matcher.decode_to_edges_array(syndrome) + for pair in matched_det_pairs: + if pair[1] == -1: + pair[-1] = pair[-1] + len(self.graph.nodes()) + pair.sort() + det_pairs.append(tuple(pair)) + return det_pairs def nodes_to_detections(self, syndrome_nodes: List[DecodingGraphNode]) -> List[int]: """Converts nodes to detector indices to be used by pymatching.Matching.decode"""