370 speedup (#371)
* use sets rather than lists

* speed up check_nodes

* speed up check nodes

* add lower bound errors for flattened nodes

* undo flattened node count

* reinstate flatten nodes

* expand UF tests

* use higher default value for 202s

* fix tests
quantumjim authored Jun 2, 2023
1 parent daf0999 commit 433c7b3
Showing 4 changed files with 147 additions and 108 deletions.
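
One of the changes behind the speedup is replacing list membership tests with set membership tests in hot loops (for example, the nodes returned by string2nodes in get_error_probs below). A minimal illustration of why this matters, using throwaway names rather than code from this commit:

import timeit

nodes_list = list(range(10_000))
nodes_set = set(nodes_list)

# membership tests are O(n) on a list but O(1) on average for a set,
# so loops that repeatedly ask "is this node highlighted?" get much faster
print(timeit.timeit(lambda: 9_999 in nodes_list, number=1000))
print(timeit.timeit(lambda: 9_999 in nodes_set, number=1000))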
135 changes: 86 additions & 49 deletions src/qiskit_qec/circuits/repetition_code.py
@@ -17,7 +17,7 @@
"""Generates circuits based on repetition codes."""
from typing import List, Optional, Tuple

from copy import copy, deepcopy
from copy import deepcopy
import numpy as np
import rustworkx as rx

@@ -349,32 +349,6 @@ def string2raw_logicals(self, string):
"""
return _separate_string(self._process_string(string))[0]

@staticmethod
def flatten_nodes(nodes: List[DecodingGraphNode]):
"""
Removes time information from a set of nodes, and consolidates those on
the same position at different times.
Args:
nodes (list): List of nodes, of the type produced by `string2nodes`, to be flattened.
Returns:
flat_nodes (list): List of flattened nodes.
"""
nodes_per_link = {}
for node in nodes:
link_qubit = node.properties["link qubit"]
if link_qubit in nodes_per_link:
nodes_per_link[link_qubit] += 1
else:
nodes_per_link[link_qubit] = 1
flat_nodes = []
for node in nodes:
if nodes_per_link[node.properties["link qubit"]] % 2:
flat_node = copy(node)
flat_node.time = None
if flat_node not in flat_nodes:
flat_nodes.append(flat_node)
return flat_nodes
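
For reference, the flatten_nodes helper shown above keeps one time-free copy of each node whose link qubit appears an odd number of times and discards the rest. A standalone sketch of that behaviour, using a simplified stand-in for DecodingGraphNode rather than the library class:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Node:
    # simplified stand-in for DecodingGraphNode: only the fields used here
    link_qubit: int
    time: Optional[int] = None

def flatten(nodes):
    # count how often each link qubit is flagged
    counts = {}
    for node in nodes:
        counts[node.link_qubit] = counts.get(node.link_qubit, 0) + 1
    # keep a single time-free copy of every link qubit flagged an odd number of times
    flat = []
    for node in nodes:
        if counts[node.link_qubit] % 2:
            flat_node = Node(node.link_qubit, None)
            if flat_node not in flat:
                flat.append(flat_node)
    return flat

print(flatten([Node(3, time=0), Node(3, time=1), Node(5, time=2)]))  # only link qubit 5 survives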

def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
"""
Determines whether a given set of nodes are neutral. If so, also
@@ -541,7 +515,7 @@ def __init__(
max_dist: int = 2,
schedule: Optional[list] = None,
run_202: bool = True,
rounds_per_202: int = 7,
rounds_per_202: int = 9,
conditional_reset: bool = False,
):
"""
@@ -567,7 +541,8 @@ def __init__(
run_202 (bool): Whether to run [[2,0,2]] sequences. This will be overwritten if T is not high
enough (at least rounds_per_202xlen(links)).
rounds_per_202 (int): Number of rounds that are part of the 202, including the typical link
measurements at the beginning and edge. At least 5 are required to detect conjugate errors.
measurements at the beginning and edge. At least 9 are required to get an event dedicated to
conjugate errors.
conditional_reset: Whether to apply conditional resets (an x conditioned on the result of the
previous measurement), rather than a reset gate.
"""
@@ -591,6 +566,7 @@ def __init__(
self._scheduling()
else:
self.schedule = schedule
self._get_cycles()
self._preparation()

# determine the placement of [2,0,2] rounds
@@ -627,7 +603,6 @@ def __init__(
self._readout()

def _get_link_graph(self, max_dist=1):
# FIXME: Migrate link graph to new Edge type
graph = rx.PyGraph()
for link in self.links:
add_edge(graph, (link[0], link[2]), {"distance": 1, "link qubit": link[1]})
@@ -642,6 +617,30 @@ def _get_link_graph(self, max_dist=1):
add_edge(graph, (node0, node1), {"distance": dist})
return graph

def _get_cycles(self):
"""
For each edge in the link graph (expressed in terms of the pair of qubits), the
set of qubits around adjacent cycles is found.
"""

link_graph = self._get_link_graph()
lg_edges = set(link_graph.edge_list())
lg_nodes = link_graph.nodes()
cycles = rx.cycle_basis(link_graph)
cycle_dict = {(lg_nodes[edge[0]], lg_nodes[edge[1]]): list(edge) for edge in lg_edges}
for cycle in cycles:
edges = []
cl = len(cycle)
for j in range(cl):
for edge in [(cycle[j], cycle[(j + 1) % cl]), (cycle[(j + 1) % cl], cycle[j])]:
if edge in lg_edges:
edges.append((lg_nodes[edge[0]], lg_nodes[edge[1]]))
for edge in edges:
cycle_dict[edge] += cycle
for edge, ns in cycle_dict.items():
cycle_dict[edge] = set(ns)
self.cycle_dict = cycle_dict
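
The new _get_cycles precomputes, for every link (keyed by its pair of code qubits), the set of node indices on the basis cycles that contain it, so check_nodes can restrict its colouring to those nodes. An illustrative standalone version of the same construction on a bare rustworkx graph, not the class method itself:

import rustworkx as rx

graph = rx.PyGraph()
graph.add_nodes_from([0, 2, 4])          # payloads are code-qubit labels
graph.add_edges_from([(0, 1, None), (1, 2, None), (2, 0, None)])  # three links forming a cycle

lg_edges = set(graph.edge_list())        # edges as (node index, node index) pairs
lg_nodes = graph.nodes()                 # node payloads, addressed by node index
cycle_dict = {(lg_nodes[e[0]], lg_nodes[e[1]]): list(e) for e in lg_edges}

for cycle in rx.cycle_basis(graph):      # each basis cycle is a list of node indices
    cl = len(cycle)
    for j in range(cl):
        for edge in [(cycle[j], cycle[(j + 1) % cl]), (cycle[(j + 1) % cl], cycle[j])]:
            if edge in lg_edges:
                # every edge on the cycle collects all of the cycle's node indices
                cycle_dict[(lg_nodes[edge[0]], lg_nodes[edge[1]])] += cycle

cycle_dict = {edge: set(ns) for edge, ns in cycle_dict.items()}
print(cycle_dict)                        # each stored edge maps to the node-index set {0, 1, 2}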

def _coloring(self):
"""
Creates a graph with a weight=1 edge for each link, and additional edges up to `max_weight`.
@@ -1159,6 +1158,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
num_errors (int): Minimum number of errors required to create nodes.
"""

nodes = self.flatten_nodes(nodes)

# see which qubits for logical zs are given and collect bulk nodes
given_logicals = []
bulk_nodes = []
@@ -1171,16 +1172,30 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):

# see whether the bulk nodes are neutral
if bulk_nodes:
nodes = self.flatten_nodes(nodes)
# bicolor the nodes of the link graph, such that node edges connect unlike edges
link_qubits = set(node.properties["link qubit"] for node in nodes)
node_color = {0: 0}
base_neutral = True
link_graph = self._get_link_graph()
ns_to_do = set(n for n in range(1, len(link_graph.nodes())))
while ns_to_do and base_neutral:
# go through all coloured nodes
# all the qubits around cycles of the node edges have to be covered
ns_to_do = set()
for edge in [tuple(node.qubits) for node in bulk_nodes]:
ns_to_do = ns_to_do.union(self.cycle_dict[edge])
# start with one of these
if ns_to_do:
n = ns_to_do.pop()
else:
n = 0
node_color = {n: 0}
recently_colored = node_color.copy()
base_neutral = True
# count the number of nodes for each colour throughout
num_nodes = [1, 0]
last_num = [None, None]
fully_converged = False
last_converged = False
while base_neutral and not fully_converged:
# go through all nodes coloured in the last pass
newly_colored = {}
for n, c in node_color.items():
for n, c in recently_colored.items():
# look at all the code qubits that are neighbours
incident_es = link_graph.incident_edges(n)
for e in incident_es:
@@ -1195,30 +1210,52 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
# if the neighbour is not yet coloured, colour it
# different color if edge is given node, same otherwise
if nn not in node_color:
newly_colored[nn] = (c + dc) % 2
new_c = (c + dc) % 2
newly_colored[nn] = new_c
num_nodes[new_c] += 1
# if it is coloured, check the colour is correct
else:
base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2)
for nn, c in newly_colored.items():
node_color[nn] = c
ns_to_do.remove(nn)
if nn in ns_to_do:
ns_to_do.remove(nn)
recently_colored = newly_colored.copy()
# process is converged once one colour has stopped growing
# once ns_to_do is empty
converged = (not ns_to_do) and (
(num_nodes[0] == last_num[0] != 0) or (num_nodes[1] == last_num[1] != 0)
)
fully_converged = converged and last_converged
if not fully_converged:
last_num = num_nodes.copy()
last_converged = converged
# see how many qubits are in the converged colour, and determine the min colour
for c in range(2):
if num_nodes[c] == last_num[c]:
conv_color = c
if num_nodes[conv_color] <= self.d / 2:
min_color = conv_color
else:
min_color = (conv_color + 1) % 2
# calculate the number of nodes for the other
num_nodes[(min_color + 1) % 2] = link_graph.num_nodes() - num_nodes[min_color]
# get the set of min nodes
min_ns = set()
for n, c in node_color.items():
if c == min_color:
min_ns.add(n)

# see which qubits for logical zs are needed
flipped_logicals_all = [[], []]
if base_neutral:
for inside_c in range(2):
for n, c in node_color.items():
qubit = link_graph.nodes()[n]
if qubit in self.z_logicals and c == inside_c:
flipped_logicals_all[int(inside_c)].append(qubit)
for qubit in self.z_logicals:
n = link_graph.nodes().index(qubit)
dc = not n in min_ns
flipped_logicals_all[(min_color + dc) % 2].append(qubit)
for j in range(2):
flipped_logicals_all[j] = set(flipped_logicals_all[j])

# count the number of nodes for each colour
num_nodes = [0, 0]
for n, c in node_color.items():
num_nodes[c] += 1

# list the colours with the max error one first
# (unless we do min only)
min_color = int(sum(node_color.values()) < len(node_color) / 2)
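The rewritten check_nodes above two-colours the link graph so that the endpoints of each highlighted edge receive opposite colours, grows the colouring outwards from the affected cycles, and stops once one colour class stops growing; the smaller class then defines the minimal correction. A stripped-down sketch of that parity colouring on a bare graph, without the convergence shortcut or boundary handling of the real method:

import rustworkx as rx

def parity_coloring(graph, highlighted_edges, start=0):
    # Colour nodes 0/1 so that colours differ exactly across highlighted edges.
    # Returns (colours, consistent); consistent is False when no bipartition
    # can explain the highlighting (base_neutral False in the decoder's terms).
    highlighted = {frozenset(e) for e in highlighted_edges}
    colours = {start: 0}
    frontier = [start]
    consistent = True
    while frontier and consistent:
        new_frontier = []
        for n in frontier:
            for m in graph.neighbors(n):
                flip = frozenset((n, m)) in highlighted
                wanted = (colours[n] + flip) % 2
                if m not in colours:
                    colours[m] = wanted
                    new_frontier.append(m)
                elif colours[m] != wanted:
                    consistent = False
        frontier = new_frontier
    return colours, consistent

g = rx.PyGraph()
g.add_nodes_from(range(4))
g.add_edges_from([(0, 1, None), (1, 2, None), (2, 3, None), (3, 0, None)])
print(parity_coloring(g, highlighted_edges=[(0, 1), (2, 3)]))  # splits the cycle into {0, 3} and {1, 2}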
4 changes: 2 additions & 2 deletions src/qiskit_qec/decoders/decoding_graph.py
@@ -161,7 +161,7 @@ def get_error_probs(

for string in counts:
# list of i for which v_i=1
error_nodes = self.code.string2nodes(string, logical=logical)
error_nodes = set(self.code.string2nodes(string, logical=logical))

for node0 in error_nodes:
n0 = self.graph.nodes().index(node0)
@@ -222,7 +222,7 @@ def get_error_probs(
for edge in self.graph.edge_list()
}
for string in counts:
error_nodes = self.code.string2nodes(string, logical=logical)
error_nodes = set(self.code.string2nodes(string, logical=logical))
for edge in self.graph.edge_list():
element = ""
for j in range(2):
67 changes: 35 additions & 32 deletions src/qiskit_qec/decoders/hdrg_decoders.py
@@ -304,15 +304,12 @@ class UnionFindDecoder(ClusteringDecoder):
See arXiv:1709.06218v3 for more details.
"""

def __init__(
self,
code,
decoding_graph: DecodingGraph = None,
) -> None:
def __init__(self, code, decoding_graph: DecodingGraph = None, use_peeling=True) -> None:
super().__init__(code, decoding_graph=decoding_graph)
self.graph = deepcopy(self.decoding_graph.graph)
self.clusters: Dict[int, UnionFindDecoderCluster] = {}
self.odd_cluster_roots: List[int] = []
self.use_peeling = use_peeling
self._clusters4peeling = []

def process(self, string: str):
@@ -327,34 +324,40 @@ def process(self, string: str):
measurement, corresponding to the logical operators of
self.z_logicals.
"""
self.graph = deepcopy(self.decoding_graph.graph)
highlighted_nodes = self.code.string2nodes(string, all_logicals=True)

# call cluster to do the clustering, but actually use the peeling form
self.cluster(highlighted_nodes)
clusters = self._clusters4peeling

# determine the net logical z
net_z_logicals = {tuple(z_logical): 0 for z_logical in self.measured_logicals}
for cluster_nodes, _ in clusters:
erasure = self.graph.subgraph(cluster_nodes)
flipped_qubits = self.peeling(erasure)
for qubit_to_be_corrected in flipped_qubits:
for z_logical in net_z_logicals:
if qubit_to_be_corrected in z_logical:
net_z_logicals[z_logical] += 1
for z_logical, num in net_z_logicals.items():
net_z_logicals[z_logical] = num % 2

# apply this to the raw readout
corrected_z_logicals = []
raw_logicals = self.code.string2raw_logicals(string)
for j, z_logical in enumerate(self.measured_logicals):
raw_logical = int(raw_logicals[j])
corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2
corrected_z_logicals.append(corrected_logical)

return corrected_z_logicals
if self.use_peeling:
self.graph = deepcopy(self.decoding_graph.graph)
highlighted_nodes = self.code.string2nodes(string, all_logicals=True)

# call cluster to do the clustering, but actually use the peeling form
self.cluster(highlighted_nodes)
clusters = self._clusters4peeling

# determine the net logical z
net_z_logicals = {tuple(z_logical): 0 for z_logical in self.measured_logicals}
for cluster_nodes, _ in clusters:
erasure = self.graph.subgraph(cluster_nodes)
flipped_qubits = self.peeling(erasure)
for qubit_to_be_corrected in flipped_qubits:
for z_logical in net_z_logicals:
if qubit_to_be_corrected in z_logical:
net_z_logicals[z_logical] += 1
for z_logical, num in net_z_logicals.items():
net_z_logicals[z_logical] = num % 2

# apply this to the raw readout
corrected_z_logicals = []
raw_logicals = self.code.string2raw_logicals(string)
for j, z_logical in enumerate(self.measured_logicals):
raw_logical = int(raw_logicals[j])
corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2
corrected_z_logicals.append(corrected_logical)
return corrected_z_logicals
else:
# turn string into nodes and cluster
nodes = self.code.string2nodes(string, all_logicals=True)
clusters = self.cluster(nodes)
return self.get_corrections(string, clusters)

def cluster(self, nodes):
"""
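A usage sketch of the new use_peeling switch; the repetition-code constructor used here is an assumption for illustration and is not part of this diff:

from qiskit_qec.circuits.repetition_code import RepetitionCodeCircuit  # code class and arguments assumed
from qiskit_qec.decoders.hdrg_decoders import UnionFindDecoder

code = RepetitionCodeCircuit(d=5, T=5)
peeling_decoder = UnionFindDecoder(code)                        # default: correct via peeling
clustering_decoder = UnionFindDecoder(code, use_peeling=False)  # correct via get_corrections

# both decoders expose process(string), returning corrected logical z values
# for a measurement outcome string produced by running the code circuits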