From 433c7b3a6acbe2ea4b2fff07415eef9fbae41f97 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Fri, 2 Jun 2023 16:06:00 +0200 Subject: [PATCH] 370 speedup (#371) * use sets rather than lists * speed up check_nodes * speed up check nodes * add lower bound errors for flattened nodes * undo flattened node count * reinstate flatten nodes * expand UF tests * use higher defaulty value for 202s * fix tests --- src/qiskit_qec/circuits/repetition_code.py | 135 +++++++++++++-------- src/qiskit_qec/decoders/decoding_graph.py | 4 +- src/qiskit_qec/decoders/hdrg_decoders.py | 67 +++++----- test/code_circuits/test_rep_codes.py | 49 ++++---- 4 files changed, 147 insertions(+), 108 deletions(-) diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index d9e4a018..04b64cb0 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -17,7 +17,7 @@ """Generates circuits based on repetition codes.""" from typing import List, Optional, Tuple -from copy import copy, deepcopy +from copy import deepcopy import numpy as np import rustworkx as rx @@ -349,32 +349,6 @@ def string2raw_logicals(self, string): """ return _separate_string(self._process_string(string))[0] - @staticmethod - def flatten_nodes(nodes: List[DecodingGraphNode]): - """ - Removes time information from a set of nodes, and consolidates those on - the same position at different times. - Args: - nodes (list): List of nodes, of the type produced by `string2nodes`, to be flattened. - Returns: - flat_nodes (list): List of flattened nodes. - """ - nodes_per_link = {} - for node in nodes: - link_qubit = node.properties["link qubit"] - if link_qubit in nodes_per_link: - nodes_per_link[link_qubit] += 1 - else: - nodes_per_link[link_qubit] = 1 - flat_nodes = [] - for node in nodes: - if nodes_per_link[node.properties["link qubit"]] % 2: - flat_node = copy(node) - flat_node.time = None - if flat_node not in flat_nodes: - flat_nodes.append(flat_node) - return flat_nodes - def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also @@ -541,7 +515,7 @@ def __init__( max_dist: int = 2, schedule: Optional[list] = None, run_202: bool = True, - rounds_per_202: int = 7, + rounds_per_202: int = 9, conditional_reset: bool = False, ): """ @@ -567,7 +541,8 @@ def __init__( run_202 (bool): Whether to run [[2,0,2]] sequences. This will be overwritten if T is not high enough (at least rounds_per_202xlen(links)). rounds_per_202 (int): Number of rounds that are part of the 202, including the typical link - measurements at the beginning and edge. At least 5 are required to detect conjugate errors. + measurements at the beginning and edge. At least 9 are required to get an event dedicated to + conjugate errors. conditional_reset: Whether to apply conditional resets (an x conditioned on the result of the previous measurement), rather than a reset gate. 
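+
+            A minimal usage sketch (illustrative only: the link triples and the
+            value of T below are arbitrary, not taken from any device):
+
+                links = [(0, 1, 2), (2, 3, 4)]
+                # T=10 is below rounds_per_202 * len(links), so [[2,0,2]] rounds
+                # would not fit anyway; disable them explicitly
+                code = ArcCircuit(links, T=10, run_202=False)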
""" @@ -591,6 +566,7 @@ def __init__( self._scheduling() else: self.schedule = schedule + self._get_cycles() self._preparation() # determine the placement of [2,0,2] rounds @@ -627,7 +603,6 @@ def __init__( self._readout() def _get_link_graph(self, max_dist=1): - # FIXME: Migrate link graph to new Edge type graph = rx.PyGraph() for link in self.links: add_edge(graph, (link[0], link[2]), {"distance": 1, "link qubit": link[1]}) @@ -642,6 +617,30 @@ def _get_link_graph(self, max_dist=1): add_edge(graph, (node0, node1), {"distance": dist}) return graph + def _get_cycles(self): + """ + For each edge in the link graph (expressed in terms of the pair of qubits), the + set of qubits around adjacent cycles is found. + """ + + link_graph = self._get_link_graph() + lg_edges = set(link_graph.edge_list()) + lg_nodes = link_graph.nodes() + cycles = rx.cycle_basis(link_graph) + cycle_dict = {(lg_nodes[edge[0]], lg_nodes[edge[1]]): list(edge) for edge in lg_edges} + for cycle in cycles: + edges = [] + cl = len(cycle) + for j in range(cl): + for edge in [(cycle[j], cycle[(j + 1) % cl]), (cycle[(j + 1) % cl], cycle[j])]: + if edge in lg_edges: + edges.append((lg_nodes[edge[0]], lg_nodes[edge[1]])) + for edge in edges: + cycle_dict[edge] += cycle + for edge, ns in cycle_dict.items(): + cycle_dict[edge] = set(ns) + self.cycle_dict = cycle_dict + def _coloring(self): """ Creates a graph with a weight=1 edge for each link, and additional edges up to `max_weight`. @@ -1159,6 +1158,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): num_errors (int): Minimum number of errors required to create nodes. """ + nodes = self.flatten_nodes(nodes) + # see which qubits for logical zs are given and collect bulk nodes given_logicals = [] bulk_nodes = [] @@ -1171,16 +1172,30 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): # see whether the bulk nodes are neutral if bulk_nodes: - nodes = self.flatten_nodes(nodes) + # bicolor the nodes of the link graph, such that node edges connect unlike edges link_qubits = set(node.properties["link qubit"] for node in nodes) - node_color = {0: 0} - base_neutral = True link_graph = self._get_link_graph() - ns_to_do = set(n for n in range(1, len(link_graph.nodes()))) - while ns_to_do and base_neutral: - # go through all coloured nodes + # all the qubits around cycles of the node edges have to be covered + ns_to_do = set() + for edge in [tuple(node.qubits) for node in bulk_nodes]: + ns_to_do = ns_to_do.union(self.cycle_dict[edge]) + # start with one of these + if ns_to_do: + n = ns_to_do.pop() + else: + n = 0 + node_color = {n: 0} + recently_colored = node_color.copy() + base_neutral = True + # count the number of nodes for each colour throughout + num_nodes = [1, 0] + last_num = [None, None] + fully_converged = False + last_converged = False + while base_neutral and not fully_converged: + # go through all nodes coloured in the last pass newly_colored = {} - for n, c in node_color.items(): + for n, c in recently_colored.items(): # look at all the code qubits that are neighbours incident_es = link_graph.incident_edges(n) for e in incident_es: @@ -1195,30 +1210,52 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): # if the neighbour is not yet coloured, colour it # different color if edge is given node, same otherwise if nn not in node_color: - newly_colored[nn] = (c + dc) % 2 + new_c = (c + dc) % 2 + newly_colored[nn] = new_c + num_nodes[new_c] += 1 # if it is coloured, check the colour is correct else: 
base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2) for nn, c in newly_colored.items(): node_color[nn] = c - ns_to_do.remove(nn) + if nn in ns_to_do: + ns_to_do.remove(nn) + recently_colored = newly_colored.copy() + # process is converged once one colour has stoppped growing + # once ns_to_do is empty + converged = (not ns_to_do) and ( + (num_nodes[0] == last_num[0] != 0) or (num_nodes[1] == last_num[1] != 0) + ) + fully_converged = converged and last_converged + if not fully_converged: + last_num = num_nodes.copy() + last_converged = converged + # see how many qubits are in the converged colour, and determine the min colour + for c in range(2): + if num_nodes[c] == last_num[c]: + conv_color = c + if num_nodes[conv_color] <= self.d / 2: + min_color = conv_color + else: + min_color = (conv_color + 1) % 2 + # calculate the number of nodes for the other + num_nodes[(min_color + 1) % 2] = link_graph.num_nodes() - num_nodes[min_color] + # get the set of min nodes + min_ns = set() + for n, c in node_color.items(): + if c == min_color: + min_ns.add(n) # see which qubits for logical zs are needed flipped_logicals_all = [[], []] if base_neutral: - for inside_c in range(2): - for n, c in node_color.items(): - qubit = link_graph.nodes()[n] - if qubit in self.z_logicals and c == inside_c: - flipped_logicals_all[int(inside_c)].append(qubit) + for qubit in self.z_logicals: + n = link_graph.nodes().index(qubit) + dc = not n in min_ns + flipped_logicals_all[(min_color + dc) % 2].append(qubit) for j in range(2): flipped_logicals_all[j] = set(flipped_logicals_all[j]) - # count the number of nodes for each colour - num_nodes = [0, 0] - for n, c in node_color.items(): - num_nodes[c] += 1 - # list the colours with the max error one first # (unless we do min only) min_color = int(sum(node_color.values()) < len(node_color) / 2) diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index 084c817d..774cf361 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -161,7 +161,7 @@ def get_error_probs( for string in counts: # list of i for which v_i=1 - error_nodes = self.code.string2nodes(string, logical=logical) + error_nodes = set(self.code.string2nodes(string, logical=logical)) for node0 in error_nodes: n0 = self.graph.nodes().index(node0) @@ -222,7 +222,7 @@ def get_error_probs( for edge in self.graph.edge_list() } for string in counts: - error_nodes = self.code.string2nodes(string, logical=logical) + error_nodes = set(self.code.string2nodes(string, logical=logical)) for edge in self.graph.edge_list(): element = "" for j in range(2): diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 67773a40..60a788b7 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -304,15 +304,12 @@ class UnionFindDecoder(ClusteringDecoder): See arXiv:1709.06218v3 for more details. 
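+
+    A minimal usage sketch (mirroring the test code in this PR; ``code`` is a
+    supported code object and ``string`` a raw measurement outcome string):
+
+        decoding_graph = DecodingGraph(code)
+        decoder = UnionFindDecoder(code, decoding_graph=decoding_graph, use_peeling=False)
+        corrected_z_logicals = decoder.process(string)
+
+    With ``use_peeling=False`` the corrections come from the clusters alone
+    (via ``get_corrections``) rather than from the peeling step.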
""" - def __init__( - self, - code, - decoding_graph: DecodingGraph = None, - ) -> None: + def __init__(self, code, decoding_graph: DecodingGraph = None, use_peeling=True) -> None: super().__init__(code, decoding_graph=decoding_graph) self.graph = deepcopy(self.decoding_graph.graph) self.clusters: Dict[int, UnionFindDecoderCluster] = {} self.odd_cluster_roots: List[int] = [] + self.use_peeling = use_peeling self._clusters4peeling = [] def process(self, string: str): @@ -327,34 +324,40 @@ def process(self, string: str): measurement, corresponding to the logical operators of self.z_logicals. """ - self.graph = deepcopy(self.decoding_graph.graph) - highlighted_nodes = self.code.string2nodes(string, all_logicals=True) - - # call cluster to do the clustering, but actually use the peeling form - self.cluster(highlighted_nodes) - clusters = self._clusters4peeling - # determine the net logical z - net_z_logicals = {tuple(z_logical): 0 for z_logical in self.measured_logicals} - for cluster_nodes, _ in clusters: - erasure = self.graph.subgraph(cluster_nodes) - flipped_qubits = self.peeling(erasure) - for qubit_to_be_corrected in flipped_qubits: - for z_logical in net_z_logicals: - if qubit_to_be_corrected in z_logical: - net_z_logicals[z_logical] += 1 - for z_logical, num in net_z_logicals.items(): - net_z_logicals[z_logical] = num % 2 - - # apply this to the raw readout - corrected_z_logicals = [] - raw_logicals = self.code.string2raw_logicals(string) - for j, z_logical in enumerate(self.measured_logicals): - raw_logical = int(raw_logicals[j]) - corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2 - corrected_z_logicals.append(corrected_logical) - - return corrected_z_logicals + if self.use_peeling: + self.graph = deepcopy(self.decoding_graph.graph) + highlighted_nodes = self.code.string2nodes(string, all_logicals=True) + + # call cluster to do the clustering, but actually use the peeling form + self.cluster(highlighted_nodes) + clusters = self._clusters4peeling + + # determine the net logical z + net_z_logicals = {tuple(z_logical): 0 for z_logical in self.measured_logicals} + for cluster_nodes, _ in clusters: + erasure = self.graph.subgraph(cluster_nodes) + flipped_qubits = self.peeling(erasure) + for qubit_to_be_corrected in flipped_qubits: + for z_logical in net_z_logicals: + if qubit_to_be_corrected in z_logical: + net_z_logicals[z_logical] += 1 + for z_logical, num in net_z_logicals.items(): + net_z_logicals[z_logical] = num % 2 + + # apply this to the raw readout + corrected_z_logicals = [] + raw_logicals = self.code.string2raw_logicals(string) + for j, z_logical in enumerate(self.measured_logicals): + raw_logical = int(raw_logicals[j]) + corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2 + corrected_z_logicals.append(corrected_logical) + return corrected_z_logicals + else: + # turn string into nodes and cluster + nodes = self.code.string2nodes(string, all_logicals=True) + clusters = self.cluster(nodes) + return self.get_corrections(string, clusters) def cluster(self, nodes): """ diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 6b72fc89..763fe70d 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -306,7 +306,7 @@ def test_202s(self): T = 15 # first, do they appear when needed for run_202 in [True, False]: - code = ArcCircuit(links, T=T, run_202=run_202) + code = ArcCircuit(links, T=T, run_202=run_202, rounds_per_202=5) running_202 = False for t in range(T): tau, 
_, _ = code._get_202(t) @@ -318,7 +318,7 @@ def test_202s(self): + "Error: [[2,0,2]] codes present when not required." * (not run_202), ) # second, do they yield non-trivial outputs yet trivial nodes - code = ArcCircuit(links, T=T, run_202=True, logical="1") + code = ArcCircuit(links, T=T, run_202=True, logical="1", rounds_per_202=5) backend = Aer.get_backend("aer_simulator") counts = backend.run(code.circuit[code.basis]).result().get_counts() self.assertTrue(len(counts) > 1, "No randomness in the results for [[2,0,2]] circuits.") @@ -331,7 +331,7 @@ def test_single_error_202s(self): """Test a range of single errors for a code with [[2,0,2]] codes.""" links = [(0, 1, 2), (2, 3, 4), (4, 5, 0), (2, 7, 6)] for T in [21, 25]: - code = ArcCircuit(links, T, run_202=True, barriers=True, logical="1") + code = ArcCircuit(links, T, run_202=True, barriers=True, logical="1", rounds_per_202=5) assert code.run_202 # insert errors on a selection of qubits during a selection of rounds qc = code.circuit[code.base] @@ -516,31 +516,30 @@ def clustering_decoder_test( links = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 1)] codes.append(ArcCircuit(links, 0)) # then make a bunch of non-linear ARCs - # TODO: make these work for union find too - if Decoder is not UnionFindDecoder: - # crossed line - links_cross = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 2)] - links_cross.append((2 * (d - 2), 2 * (d - 2) + 1, 2 * (int(d / 2)))) - links_cross.append(((2 * (int(d / 2))), 2 * (d - 1), 2 * (d - 1) + 1)) - # ladder (works for even d) - half_d = int(d / 2) - links_ladder = [] - for row in [0, 1]: - for j in range(half_d - 1): - delta = row * (2 * half_d - 1) - links_ladder.append((delta + 2 * j, delta + 2 * j + 1, delta + 2 * (j + 1))) - q = links_ladder[-1][2] + 1 - for j in range(half_d): - delta = 2 * half_d - 1 - links_ladder.append((2 * j, q, delta + 2 * j)) - q += 1 - # add them to the code list - for links in [links_ladder, links_cross]: - codes.append(ArcCircuit(links, 0)) + links_cross = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 2)] + links_cross.append((2 * (d - 2), 2 * (d - 2) + 1, 2 * (int(d / 2)))) + links_cross.append(((2 * (int(d / 2))), 2 * (d - 1), 2 * (d - 1) + 1)) + codes.append(ArcCircuit(links_cross, 0)) + # ladder (works for even d) + half_d = int(d / 2) + links_ladder = [] + for row in [0, 1]: + for j in range(half_d - 1): + delta = row * (2 * half_d - 1) + links_ladder.append((delta + 2 * j, delta + 2 * j + 1, delta + 2 * (j + 1))) + q = links_ladder[-1][2] + 1 + for j in range(half_d): + delta = 2 * half_d - 1 + links_ladder.append((2 * j, q, delta + 2 * j)) + q += 1 + codes.append(ArcCircuit(links_ladder, 0)) # now run them all and check it works for c, code in enumerate(codes): decoding_graph = DecodingGraph(code) - decoder = Decoder(code, decoding_graph=decoding_graph) + if c == 3 and Decoder is UnionFindDecoder: + decoder = Decoder(code, decoding_graph=decoding_graph, use_peeling=False) + else: + decoder = Decoder(code, decoding_graph=decoding_graph) errors = {z_logical[0]: 0 for z_logical in decoder.measured_logicals} min_error_num = code.d min_error_string = ""