Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add variable edge weight for UF #428

Merged
merged 2 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 49 additions & 25 deletions src/qiskit_qec/decoders/hdrg_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
"""Hard decision renormalization group decoders."""

from abc import ABC
from copy import copy, deepcopy
from copy import copy
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

from rustworkx import PyGraph, connected_components, distance_matrix
import numpy as np

from qiskit_qec.decoders.decoding_graph import DecodingGraph
from qiskit_qec.utils import DecodingGraphEdge
Expand Down Expand Up @@ -291,7 +292,11 @@ class UnionFindDecoder(ClusteringDecoder):
by the peeling decoder for compatible codes or by the standard HDRG
method in general.

See arXiv:1709.06218v3 for more details.
grace-harper marked this conversation as resolved.
Show resolved Hide resolved
To avoid using the peeling decoder, and instead use the standard
method for clustering decoders to get corrections, set `use_peeling=False`.
Growth unit is 0.5 by default, but can be changed with `growth_unit`.
To use half the minimum boundarye edge weight for each clustering round,
set `growth_unit=None`.
"""

def __init__(
Expand All @@ -300,14 +305,17 @@ def __init__(
decoding_graph: DecodingGraph = None,
use_peeling=True,
use_is_cluster_neutral=False,
growth_unit=0.5,
) -> None:
super().__init__(code, decoding_graph=deepcopy(decoding_graph))
super().__init__(code, decoding_graph=decoding_graph)
self.graph = self.decoding_graph.graph
self.clusters: Dict[int, UnionFindDecoderCluster] = {}
self.odd_cluster_roots: List[int] = []
self.use_peeling = use_peeling
self.use_is_cluster_neutral = use_is_cluster_neutral
self._clusters4peeling = []
self.growth_unit = growth_unit
self._growth_unit = None

def process(self, string: str, predecoder=None):
"""
Expand All @@ -324,7 +332,6 @@ def process(self, string: str, predecoder=None):
"""

if self.use_peeling:
self.graph = deepcopy(self.decoding_graph.graph)
highlighted_nodes = self.code.string2nodes(string, all_logicals=True)
if predecoder:
highlighted_nodes = predecoder(highlighted_nodes)
Expand Down Expand Up @@ -377,6 +384,11 @@ def cluster(self, nodes: List):
the given node as keys and an integer specifying their cluster as the corresponding
value.
"""
if self.growth_unit:
self._growth_unit = self.growth_unit
else:
self._growth_unit = 0

node_indices = [self.decoding_graph.node_index(node) for node in nodes]
for node_index in self.graph.node_indexes():
self.graph[node_index].properties["syndrome"] = node_index in node_indices
Expand Down Expand Up @@ -439,7 +451,7 @@ def _create_new_cluster(self, node_index):
self.odd_cluster_roots.insert(0, node_index)
boundary_edges = []
for edge_index, neighbour, data in self.neighbouring_edges(node_index):
boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, data))
boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, copy(data)))
self.clusters[node_index] = UnionFindDecoderCluster(
boundary=boundary_edges,
fully_grown_edges=set(),
Expand All @@ -452,7 +464,7 @@ def _create_new_cluster(self, node_index):

def _grow_and_merge_clusters(self) -> Set[int]:
fusion_edge_list = self._grow_clusters()
return self._merge_clusters(fusion_edge_list)
self._merge_clusters(fusion_edge_list)

def _grow_clusters(self) -> List[FusionEntry]:
"""
Expand All @@ -463,10 +475,18 @@ def _grow_clusters(self) -> List[FusionEntry]:
clusters that will be merged in the next step.
"""
fusion_edge_list: List[FusionEntry] = []

if not self.growth_unit:
min_weight = np.inf
for root in self.odd_cluster_roots:
cluster = self.clusters[root]
for edge in cluster.boundary:
min_weight = max(min(min_weight, edge.data.weight), 1e-6)
self._growth_unit = min_weight / 2
for root in self.odd_cluster_roots:
cluster = self.clusters[root]
for edge in cluster.boundary:
edge.data.properties["growth"] += 0.5
edge.data.properties["growth"] += self._growth_unit
if (
edge.data.properties["growth"] >= edge.data.weight
and not edge.data.properties["fully_grown"]
Expand Down Expand Up @@ -508,9 +528,8 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]):
Args:
fusion_edge_list (List[FusionEntry]): List of edges that connect two
clusters that was computed in _grow_clusters().
Returns:
new_neutral_cluster_roots (List[int]): List of roots of newly neutral clusters
"""

new_neutral_clusters = []
for entry in fusion_edge_list:
root_u, root_v = self.find(entry.u), self.find(entry.v)
Expand Down Expand Up @@ -546,22 +565,27 @@ def _merge_clusters(self, fusion_edge_list: List[FusionEntry]):

# see if the cluster is neutral and update odd_cluster_roots accordingly
fully_neutral = False
for nodes in [
[self.graph[node] for node in cluster.atypical_nodes],
[
self.graph[node]
for node in cluster.atypical_nodes
| (set(list(cluster.boundary_nodes)[:1]) if cluster.boundary_nodes else set())
],
]:
if self.use_is_cluster_neutral:
fully_neutral = self.code.is_cluster_neutral(nodes)
else:
neutral, extras, num = self.code.check_nodes(nodes)
for node in extras:
neutral = neutral and (not node.is_boundary)
neutral = neutral and num <= len(cluster.edge_support)
fully_neutral = fully_neutral or neutral
if self._growth_unit: # assume non-neutral while growing along 0-weight edges
for nodes in [
[self.graph[node] for node in cluster.atypical_nodes],
[
self.graph[node]
for node in cluster.atypical_nodes
| (
set(list(cluster.boundary_nodes)[:1])
if cluster.boundary_nodes
else set()
)
],
]:
if self.use_is_cluster_neutral:
fully_neutral = self.code.is_cluster_neutral(nodes)
else:
neutral, extras, num = self.code.check_nodes(nodes)
for node in extras:
neutral = neutral and (not node.is_boundary)
neutral = neutral and num <= len(cluster.edge_support)
fully_neutral = fully_neutral or neutral
if fully_neutral:
if new_root in self.odd_cluster_roots:
self.odd_cluster_roots.remove(new_root)
Expand Down
22 changes: 10 additions & 12 deletions src/qiskit_qec/decoders/pymatching_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,17 @@ def matched_edges(
or a list of binaries indicating which node is highlighted.
Returns: list of DecodingGraphEdge-s included in the matching
"""
if isinstance(syndrome[0], DecodingGraphNode):
syndrome = self.nodes_to_detections(syndrome)
edge_dets = list(self.graph.edge_list())
edges = self.graph.edges()
matched_det_pairs = self.matcher.decode_to_edges_array(syndrome)
det_pairs = []
for pair in matched_det_pairs:
if pair[1] == -1:
pair[-1] = pair[-1] + len(self.graph.nodes())
pair.sort()
det_pairs.append(tuple(pair))
mached_edges = [edges[edge_dets.index(det_pair)] for det_pair in det_pairs]
return mached_edges
if syndrome:
if isinstance(syndrome[0], DecodingGraphNode):
syndrome = self.nodes_to_detections(syndrome)
matched_det_pairs = self.matcher.decode_to_edges_array(syndrome)
for pair in matched_det_pairs:
if pair[1] == -1:
pair[-1] = pair[-1] + len(self.graph.nodes())
pair.sort()
det_pairs.append(tuple(pair))
return det_pairs

def nodes_to_detections(self, syndrome_nodes: List[DecodingGraphNode]) -> List[int]:
"""Converts nodes to detector indices to be used by pymatching.Matching.decode"""
Expand Down
Loading