Skip to content

Commit

Permalink
Add variable edge weight for UF (#428)
Browse files Browse the repository at this point in the history
* variable edge weight

* improve reweighting
  • Loading branch information
quantumjim authored May 3, 2024
1 parent d61cb32 commit c580e06
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 37 deletions.
74 changes: 49 additions & 25 deletions src/qiskit_qec/decoders/hdrg_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
"""Hard decision renormalization group decoders."""

from abc import ABC
from copy import copy, deepcopy
from copy import copy
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

from rustworkx import PyGraph, connected_components, distance_matrix
import numpy as np

from qiskit_qec.decoders.decoding_graph import DecodingGraph
from qiskit_qec.utils import DecodingGraphEdge
Expand Down Expand Up @@ -291,7 +292,11 @@ class UnionFindDecoder(ClusteringDecoder):
by the peeling decoder for compatible codes or by the standard HDRG
method in general.
See arXiv:1709.06218v3 for more details.
To avoid using the peeling decoder, and instead use the standard
method for clustering decoders to get corrections, set `use_peeling=False`.
Growth unit is 0.5 by default, but can be changed with `growth_unit`.
To use half the minimum boundarye edge weight for each clustering round,
set `growth_unit=None`.
"""

def __init__(
Expand All @@ -300,14 +305,17 @@ def __init__(
decoding_graph: DecodingGraph = None,
use_peeling=True,
use_is_cluster_neutral=False,
growth_unit=0.5,
) -> None:
super().__init__(code, decoding_graph=deepcopy(decoding_graph))
super().__init__(code, decoding_graph=decoding_graph)
self.graph = self.decoding_graph.graph
self.clusters: Dict[int, UnionFindDecoderCluster] = {}
self.odd_cluster_roots: List[int] = []
self.use_peeling = use_peeling
self.use_is_cluster_neutral = use_is_cluster_neutral
self._clusters4peeling = []
self.growth_unit = growth_unit
self._growth_unit = None

def process(self, string: str, predecoder=None):
"""
Expand All @@ -324,7 +332,6 @@ def process(self, string: str, predecoder=None):
"""

if self.use_peeling:
self.graph = deepcopy(self.decoding_graph.graph)
highlighted_nodes = self.code.string2nodes(string, all_logicals=True)
if predecoder:
highlighted_nodes = predecoder(highlighted_nodes)
Expand Down Expand Up @@ -377,6 +384,11 @@ def cluster(self, nodes: List):
the given node as keys and an integer specifying their cluster as the corresponding
value.
"""
if self.growth_unit:
self._growth_unit = self.growth_unit
else:
self._growth_unit = 0

node_indices = [self.decoding_graph.node_index(node) for node in nodes]
for node_index in self.graph.node_indexes():
self.graph[node_index].properties["syndrome"] = node_index in node_indices
Expand Down Expand Up @@ -439,7 +451,7 @@ def _create_new_cluster(self, node_index):
self.odd_cluster_roots.insert(0, node_index)
boundary_edges = []
for edge_index, neighbour, data in self.neighbouring_edges(node_index):
boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, data))
boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, copy(data)))
self.clusters[node_index] = UnionFindDecoderCluster(
boundary=boundary_edges,
fully_grown_edges=set(),
Expand All @@ -452,7 +464,7 @@ def _create_new_cluster(self, node_index):

def _grow_and_merge_clusters(self) -> Set[int]:
fusion_edge_list = self._grow_clusters()
return self._merge_clusters(fusion_edge_list)
self._merge_clusters(fusion_edge_list)

def _grow_clusters(self) -> List[FusionEntry]:
"""
Expand All @@ -463,10 +475,18 @@ def _grow_clusters(self) -> List[FusionEntry]:
clusters that will be merged in the next step.
"""
fusion_edge_list: List[FusionEntry] = []

if not self.growth_unit:
min_weight = np.inf
for root in self.odd_cluster_roots:
cluster = self.clusters[root]
for edge in cluster.boundary:
min_weight = max(min(min_weight, edge.data.weight), 1e-6)
self._growth_unit = min_weight / 2
for root in self.odd_cluster_roots:
cluster = self.clusters[root]
for edge in cluster.boundary:
edge.data.properties["growth"] += 0.5
edge.data.properties["growth"] += self._growth_unit
if (
edge.data.properties["growth"] >= edge.data.weight
and not edge.data.properties["fully_grown"]
Expand Down Expand Up @@ -508,9 +528,8 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]):
Args:
fusion_edge_list (List[FusionEntry]): List of edges that connect two
clusters that was computed in _grow_clusters().
Returns:
new_neutral_cluster_roots (List[int]): List of roots of newly neutral clusters
"""

new_neutral_clusters = []
for entry in fusion_edge_list:
root_u, root_v = self.find(entry.u), self.find(entry.v)
Expand Down Expand Up @@ -546,22 +565,27 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]):

# see if the cluster is neutral and update odd_cluster_roots accordingly
fully_neutral = False
for nodes in [
[self.graph[node] for node in cluster.atypical_nodes],
[
self.graph[node]
for node in cluster.atypical_nodes
| (set(list(cluster.boundary_nodes)[:1]) if cluster.boundary_nodes else set())
],
]:
if self.use_is_cluster_neutral:
fully_neutral = self.code.is_cluster_neutral(nodes)
else:
neutral, extras, num = self.code.check_nodes(nodes)
for node in extras:
neutral = neutral and (not node.is_boundary)
neutral = neutral and num <= len(cluster.edge_support)
fully_neutral = fully_neutral or neutral
if self._growth_unit: # assume non-neutral while growing along 0-weight edges
for nodes in [
[self.graph[node] for node in cluster.atypical_nodes],
[
self.graph[node]
for node in cluster.atypical_nodes
| (
set(list(cluster.boundary_nodes)[:1])
if cluster.boundary_nodes
else set()
)
],
]:
if self.use_is_cluster_neutral:
fully_neutral = self.code.is_cluster_neutral(nodes)
else:
neutral, extras, num = self.code.check_nodes(nodes)
for node in extras:
neutral = neutral and (not node.is_boundary)
neutral = neutral and num <= len(cluster.edge_support)
fully_neutral = fully_neutral or neutral
if fully_neutral:
if new_root in self.odd_cluster_roots:
self.odd_cluster_roots.remove(new_root)
Expand Down
22 changes: 10 additions & 12 deletions src/qiskit_qec/decoders/pymatching_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,17 @@ def matched_edges(
or a list of binaries indicating which node is highlighted.
Returns: list of DecodingGraphEdge-s included in the matching
"""
if isinstance(syndrome[0], DecodingGraphNode):
syndrome = self.nodes_to_detections(syndrome)
edge_dets = list(self.graph.edge_list())
edges = self.graph.edges()
matched_det_pairs = self.matcher.decode_to_edges_array(syndrome)
det_pairs = []
for pair in matched_det_pairs:
if pair[1] == -1:
pair[-1] = pair[-1] + len(self.graph.nodes())
pair.sort()
det_pairs.append(tuple(pair))
mached_edges = [edges[edge_dets.index(det_pair)] for det_pair in det_pairs]
return mached_edges
if syndrome:
if isinstance(syndrome[0], DecodingGraphNode):
syndrome = self.nodes_to_detections(syndrome)
matched_det_pairs = self.matcher.decode_to_edges_array(syndrome)
for pair in matched_det_pairs:
if pair[1] == -1:
pair[-1] = pair[-1] + len(self.graph.nodes())
pair.sort()
det_pairs.append(tuple(pair))
return det_pairs

def nodes_to_detections(self, syndrome_nodes: List[DecodingGraphNode]) -> List[int]:
"""Converts nodes to detector indices to be used by pymatching.Matching.decode"""
Expand Down

0 comments on commit c580e06

Please sign in to comment.