From c9cf30de237986f76ab30b19ba735067c845b7e7 Mon Sep 17 00:00:00 2001 From: James Wootton Date: Fri, 19 Jan 2024 19:57:01 +0100 Subject: [PATCH] add predecoders (#410) * add predecoders * speed up decoding for linear arcs --------- Co-authored-by: grace-harper <119029214+grace-harper@users.noreply.github.com> --- src/qiskit_qec/circuits/repetition_code.py | 73 ++++++- src/qiskit_qec/decoders/__init__.py | 2 +- src/qiskit_qec/decoders/decoding_graph.py | 97 ++++++++- src/qiskit_qec/decoders/hdrg_decoders.py | 197 ++---------------- test/code_circuits/test_rep_codes.py | 5 +- test/{union_find => decoding}/__init__.py | 4 +- test/decoding/test_cleaner.py | 85 ++++++++ .../test_union_find.py | 0 test/union_find/test_clayg.py | 177 ---------------- 9 files changed, 276 insertions(+), 364 deletions(-) rename test/{union_find => decoding}/__init__.py (87%) create mode 100644 test/decoding/test_cleaner.py rename test/{union_find => decoding}/test_union_find.py (100%) delete mode 100644 test/union_find/test_clayg.py diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 4ae904d2..f507b1da 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -631,6 +631,11 @@ def _get_cycles(self): """ self.link_graph = self._get_link_graph() + self.degree = {} + for n, q in enumerate(self.link_graph.nodes()): + self.degree[q] = self.link_graph.degree(n) + degrees = list(self.degree.values()) + self._linear = degrees.count(1) == 2 and degrees.count(2) == len(degrees) - 2 lg_edges = set(self.link_graph.edge_list()) lg_nodes = self.link_graph.nodes() ng = nx.Graph() @@ -1357,8 +1362,11 @@ def is_cluster_neutral(self, atypical_nodes: dict): Args: atypical_nodes: dictionary in the form of the return value of string2nodes """ - neutral, logicals, _ = self.check_nodes(atypical_nodes) - return neutral and not logicals + if self._linear: + return not bool(len(atypical_nodes) % 2) + else: + neutral, logicals, _ = self.check_nodes(atypical_nodes) + return neutral and not logicals def transpile(self, backend, echo=("X", "X"), echo_num=(2, 0)): """ @@ -1649,3 +1657,64 @@ def get_error_coords( return error_coords, sample_coords else: return error_coords + + def clean_code(self, string): + """ + Given an output string of the code, obvious code qubit errors are identified and their effects + are removed. + + Args: + string (str): Output string of the code. + + Returns: + string (str): Modifed output string of the code. + """ + + # get the parities for the rounds and turn them into lists of integers + # (also turn them the right way around) + parities = [] + for rstring in string.split(" ")[1:]: + parities.append([int(p) for p in rstring][::-1]) + parities = parities[::-1] + + # calculate the final parities from the final readout and add them on + final = string.split(" ")[0] + final_parities = [0] * self.num_qubits[1] + for c0, a, c1 in self.links: + final_parities[-self.link_index[a] - 1] = ( + int(final[-self.code_index[c0] - 1]) + int(final[-self.code_index[c1] - 1]) + ) % 2 + parities.append(final_parities[::-1]) + + flips = {c: 0 for c in self.code_index} + for rparities in parities: + # see how many links around each code qubit detect a flip + link_count = {c: 0 for c in self.code_index} + for c0, a, c1 in self.links: + # we'll need to determine whether the as yet uncorrected parity + # checks from this round should be flipped, based on results + # from previous rounds + flip = (flips[c0] + flips[c1]) % 2 + b = self.link_index[a] + for c in [c0, c1]: + link_count[c] += (rparities[b] + flip) % 2 + # if it's all of them, assume a flip + for c in link_count: + if link_count[c] == self.degree[c]: + flips[c] = (flips[c] + 1) % 2 + # modify the parities to remove the effect + for c0, a, c1 in self.links: + flip = (flips[c0] + flips[c1]) % 2 + b = self.link_index[a] + rparities[b] = (rparities[b] + flip) % 2 + # turn the results back into a string + new_string = "" + for rparities in parities[:-1][::-1]: + new_string += " " + "".join([str(p) for p in rparities][::-1]) + final_string = [int(p) for p in string.split(" ", maxsplit=1)[0]] + for c, flip in flips.items(): + b = self.code_index[c] + final_string[-b - 1] = (final_string[-b - 1] + flip) % 2 + final_string = "".join([str(p) for p in final_string]) + + return final_string + new_string diff --git a/src/qiskit_qec/decoders/__init__.py b/src/qiskit_qec/decoders/__init__.py index ee3a9e0d..71262f42 100644 --- a/src/qiskit_qec/decoders/__init__.py +++ b/src/qiskit_qec/decoders/__init__.py @@ -34,4 +34,4 @@ from .circuit_matching_decoder import CircuitModelMatchingDecoder from .repetition_decoder import RepetitionDecoder from .three_bit_decoder import ThreeBitDecoder -from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder, ClAYGDecoder +from .hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index 28ac6767..68c2293f 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -66,6 +66,47 @@ def __init__(self, code, brute=False, graph=None): if node.is_boundary: self._logical_nodes.append(node) + self.update_attributes() + + def update_attributes(self): + """ + Calculates properties of the graph used by `node_index` and `edge_in_graph`. + If `graph` is updated this method should called to update these properties. + """ + self._edge_set = set(self.graph.edge_list()) + self._node_index = {} + for n, node in enumerate(self.graph.nodes()): + clean_node = copy.deepcopy(node) + clean_node.properties = {} + self._node_index[clean_node] = n + + def node_index(self, node): + """ + Given a node of `graph`, returns the corrsponding index. + + Args: + node (DecodingGraphNode): Node of the graph. + + Returns: + n (int): Index corresponding to the node within the graph. + """ + clean_node = copy.deepcopy(node) + clean_node.properties = {} + return self._node_index[clean_node] + + def edge_in_graph(self, edge): + """ + Given a pair of node indices for `graph`, determines whether + the edge exists within the graph. + + Args: + edge (tuple): Pair of node indices for the graph. + + Returns: + in_graph (bool): Whether the edge is within the graph. + """ + return edge in self._edge_set + def _make_syndrome_graph(self): if not self.brute and hasattr(self.code, "_make_syndrome_graph"): self.graph, self.hyperedges = self.code._make_syndrome_graph() @@ -170,7 +211,7 @@ def get_error_probs( error_nodes = set(self.code.string2nodes(string, logical=logical)) for node0 in error_nodes: - n0 = self.graph.nodes().index(node0) + n0 = self.node_index(node0) av_v[n0] += counts[string] for n1 in neighbours[n0]: node1 = self.graph[n1] @@ -341,8 +382,8 @@ def weight_fn(edge): source = E[source_index] target = E[target_index] if target != source: - ns = self.graph.nodes().index(source) - nt = self.graph.nodes().index(target) + ns = self.node_index(source) + nt = self.node_index(target) distance = distance_matrix[ns][nt] if np.isfinite(distance): qubits = list(set(source.qubits).intersection(target.qubits)) @@ -350,6 +391,56 @@ def weight_fn(edge): E.add_edge(source_index, target_index, DecodingGraphEdge(qubits, distance)) return E + def clean_measurements(self, nodes: List): + """ + Removes pairs of nodes that obviously correspond to measurement errors + from a list of nodes. + + Args: + nodes: A list of nodes. + Returns: + nodes: The input list of nodes, with pairs removed if they obviously + correspond to a measurement error. + + """ + + # order the nodes by where and when + node_pos = {} + for node in nodes: + if not node.is_boundary: + if node.index not in node_pos: + node_pos[node.index] = {} + node_pos[node.index][node.time] = self.node_index(node) + # find pairs corresponding to time-like edges + all_pairs = set() + for node_times in node_pos.values(): + ts = list(node_times.keys()) + ts.sort() + for j in range(len(ts) - 1): + if ts[j + 1] - ts[j] <= 2: + n0 = node_times[ts[j]] + n1 = node_times[ts[j + 1]] + if self.edge_in_graph((n0, n1)) or self.edge_in_graph((n1, n0)): + all_pairs.add((n0, n1)) + # filter out those that share nodes + all_nodes = set() + common_nodes = set() + for pair in all_pairs: + for n in pair: + if n in all_nodes: + common_nodes.add(n) + all_nodes.add(n) + paired_ns = set() + for pair in all_pairs: + if pair[0] not in common_nodes: + if pair[1] not in common_nodes: + for n in pair: + paired_ns.add(n) + # return the nodes that were not paired + ns = set(self.node_index(node) for node in nodes) + unpaired_ns = ns.difference(paired_ns) + return [self.graph.nodes()[n] for n in unpaired_ns] + class CSSDecodingGraph: """ diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index 3f1c2bbe..4d035f54 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -185,8 +185,8 @@ def cluster(self, nodes): # get indices for nodes and boundary nodes dg = self.decoding_graph.graph - ns = set(dg.nodes().index(node) for node in nodes) - bns = set(dg.nodes().index(node) for node in self._get_boundary_nodes()) + ns = set(self.decoding_graph.node_index(node) for node in nodes) + bns = set(self.decoding_graph.node_index(node) for node in self._get_boundary_nodes()) dist_max = 0 final_clusters = {} @@ -212,12 +212,16 @@ def cluster(self, nodes): return final_clusters - def process(self, string): + def process(self, string, predecoder=None): """ Process an output string and return corrected final outcomes. Args: string (str): Output string of the code. + predecoder (callable): Function that takes in and returns + a list of nodes. Used to do preprocessing on the nodes + corresponding to the input string. + Returns: corrected_logicals (list): A list of integers that are 0 or 1. These are the corrected values of the final transversal @@ -227,6 +231,10 @@ def process(self, string): # turn string into nodes and cluster nodes = self.code.string2nodes(string, all_logicals=True) + # apply predecoder if one is given + if predecoder: + nodes = predecoder(nodes) + # then cluster clusters = self.cluster(nodes) return self.get_corrections(string, clusters) @@ -310,12 +318,15 @@ def __init__(self, code, decoding_graph: DecodingGraph = None, use_peeling=True) self.use_peeling = use_peeling self._clusters4peeling = [] - def process(self, string: str): + def process(self, string: str, predecoder=None): """ Process an output string and return corrected final outcomes. Args: string (str): Output string of the code. + predecoder (callable): Function that takes in and returns + a list of nodes. Used to do preprocessing on the nodes + corresponding to the input string. Returns: corrected_z_logicals (list): A list of integers that are 0 or 1. These are the corrected values of the final transversal @@ -326,6 +337,8 @@ def process(self, string: str): if self.use_peeling: self.graph = deepcopy(self.decoding_graph.graph) highlighted_nodes = self.code.string2nodes(string, all_logicals=True) + if predecoder: + highlighted_nodes = predecoder(highlighted_nodes) # call cluster to do the clustering, but actually use the peeling form self.cluster(highlighted_nodes) @@ -354,6 +367,8 @@ def process(self, string: str): else: # turn string into nodes and cluster nodes = self.code.string2nodes(string, all_logicals=True) + if predecoder: + nodes = predecoder(nodes) clusters = self.cluster(nodes) return self.get_corrections(string, clusters) @@ -370,7 +385,7 @@ def cluster(self, nodes: List): the given node as keys and an integer specifying their cluster as the corresponding value. """ - node_indices = [self.graph.nodes().index(node) for node in nodes] + node_indices = [self.decoding_graph.node_index(node) for node in nodes] for node_index, _ in enumerate(self.graph.nodes()): self.graph[node_index].properties["syndrome"] = node_index in node_indices self.graph[node_index].properties["root"] = node_index @@ -637,175 +652,3 @@ def neighbouring_edges(self, node_index) -> List[Tuple[int, int, DecodingGraphEd self.graph.incident_edge_index_map(node_index) ).items() ] - - -class ClAYGDecoder(UnionFindDecoder): - """ - Decoder that is very similar to the Union Find decoder, but instead of adding clusters all at once, - adds them separated by syndrome round with a growth and merge phase in between. - Then it just proceeds like the Union Find decoder. - - FIXME: Use the Union Find infrastructure and just change the self.cluster() method. Problem is that - the peeling decoder needs a modified version the graph with the syndrome nodes marked, which is done - in the process method. For now it is mostly its separate thing, but merging them shouldn't be - too big of a hassle. - Merge method should also be modified, as boundary clusters are not marked as odd clusters. - """ - - def __init__(self, code, decoding_graph: DecodingGraph = None) -> None: - super().__init__(code, decoding_graph) - self.graph = deepcopy(self.decoding_graph.graph) - self.r = 1 - self._clusters4peeling = [] - - def process(self, string: str): - """ - Process an output string and return corrected final outcomes. - Args: - string (str): Output string of the code. - Returns: - corrected_z_logicals (list): A list of integers that are 0 or 1. - These are the corrected values of the final transversal - measurement, corresponding to the logical operators of - self.z_logicals. - """ - - nodes_at_time_zero = [] - for index, node in zip( - self.decoding_graph.graph.node_indices(), self.decoding_graph.graph.nodes() - ): - if node.time == 0 or node.is_boundary: - nodes_at_time_zero.append(index) - self.graph = self.decoding_graph.graph.subgraph(nodes_at_time_zero) - for index, node in zip(self.graph.node_indices(), self.graph.nodes()): - node.properties["root"] = index - for edge in self.graph.edges(): - edge.properties["growth"] = 0 - edge.properties["fully_grown"] = False - - string = "".join([str(c) for c in string[::-1]]) - output = [int(bit) for bit in list(string.split(" ", maxsplit=self.code.d)[0])][::-1] - highlighted_nodes = self.code.string2nodes(string, all_logicals=True) - if not highlighted_nodes: - return output # There's nothing for us to do here - - self.cluster(highlighted_nodes) - clusters = self._clusters4peeling - - flattened_highlighted_nodes: List[DecodingGraphNode] = [] - for highlighted_node in highlighted_nodes: - highlighted_node.time = 0 - flattened_highlighted_nodes.append(self.graph.nodes().index(highlighted_node)) - - for cluster_nodes, cluster_atypical_nodes in clusters: - if not cluster_nodes: - continue - erasure_graph = deepcopy(self.graph) - for node in cluster_nodes: - erasure_graph[node].properties["syndrome"] = node in cluster_atypical_nodes - erasure = erasure_graph.subgraph(cluster_nodes) - qubits_to_be_corrected = self.peeling(erasure) - for idx in qubits_to_be_corrected: - output[idx] = (output[idx] + 1) % 2 - - return output - - def cluster(self, nodes): - """ - Args: - nodes (List): List of non-typical nodes in the syndrome graph, - of the type produced by `string2nodes`. - - Returns: - clusters (dict): Ddictionary with the indices of - the given node as keys and an integer specifying their cluster as the corresponding - value. - """ - self.clusters: Dict[int, UnionFindDecoderCluster] = {} - self.odd_cluster_roots = [] - - times: List[List[DecodingGraphNode]] = [[] for _ in range(self.code.T + 1)] - boundaries = [] - for node in deepcopy(nodes): - if node.is_boundary: - boundaries.append(node) - else: - times[node.time].append(node) - node.time = 0 - # FIXME: I am not sure when the optimal time to add the boundaries is. Maybe the middle? - # for node in boundaries: - times.insert(len(times) // 2, boundaries) - - neutral_clusters = [] - for time in times: - if not time: - continue - for node in time: - self._add_node(node) - neutral_clusters += self._collect_neutral_clusters() - for _ in range(self.r): - self._grow_and_merge_clusters() - neutral_clusters += self._collect_neutral_clusters() - - while self.odd_cluster_roots: - self._grow_and_merge_clusters() - - neutral_clusters += self._collect_neutral_clusters() - - # compile info into standard clusters dict - clusters = {} - for c, cluster in enumerate(neutral_clusters): - # determine which nodes exactly are in the neutral cluster - neutral_nodes = list(cluster.atypical_nodes | cluster.boundary_nodes) - # put them in the required dict - for n in neutral_nodes: - clusters[n] = c - - neutral_cluster_nodes: List[List[int]] = [] - for cluster in neutral_clusters: - neutral_cluster_nodes.append( - (list(cluster.nodes), list(cluster.atypical_nodes | cluster.boundary_nodes)) - ) - - self._clusters4peeling = neutral_cluster_nodes - - return neutral_cluster_nodes - - def _add_node(self, node): - node_index = self.graph.nodes().index(node) - root = self.find(node_index) - cluster = self.clusters.get(root) - if cluster and not node.is_boundary: - # Add the node to the cluster or remove it if it's already present - if node_index in cluster.atypical_nodes: - cluster.atypical_nodes.remove(node_index) - else: - cluster.atypical_nodes.add(node_index) - else: - self.graph[node_index].properties["root"] = node_index - self._create_new_cluster(node_index) - - def _collect_neutral_clusters(self): - neutral_clusters = [] - for root, cluster in self.clusters.copy().items(): - if self.code.is_cluster_neutral( - [ - self.graph[node] - for node in cluster.atypical_nodes - | (set([list(cluster.boundary_nodes)[0]]) if cluster.boundary_nodes else set()) - ] - ): - if root in self.odd_cluster_roots: - self.odd_cluster_roots.remove(root) - cluster = self.clusters.pop(root) - if cluster.atypical_nodes: - neutral_clusters.append(cluster) - for edge in cluster.fully_grown_edges: - self.graph.edges()[edge].properties["fully_grown"] = False - for edge in cluster.boundary: - self.graph.edges()[edge.index].properties["growth"] = 0 - for node in cluster.nodes: - if self.graph[node].is_boundary: - self._create_new_cluster(node) - self.graph[node].properties["root"] = node - return neutral_clusters diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index 7a76430c..8afb6fba 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -513,15 +513,18 @@ def clustering_decoder_test( N = 1000 # first an RCC - codes = [RepetitionCode(d, 1)] + codes = [] + codes.append(RepetitionCode(d, 1)) # then a linear ARC links = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 1)] codes.append(ArcCircuit(links, 0, logical="1")) + self.assertTrue(codes[-1]._linear, "Linear ARC not recognised as such") # then make a bunch of non-linear ARCs links_cross = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 2)] links_cross.append((2 * (d - 2), 2 * (d - 2) + 1, 2 * (int(d / 2)))) links_cross.append(((2 * (int(d / 2))), 2 * (d - 1), 2 * (d - 1) + 1)) codes.append(ArcCircuit(links_cross, 0)) + self.assertTrue(not codes[-1]._linear, "Non-inear ARC not recognised as such") # ladder (works for even d) half_d = int(d / 2) links_ladder = [] diff --git a/test/union_find/__init__.py b/test/decoding/__init__.py similarity index 87% rename from test/union_find/__init__.py rename to test/decoding/__init__.py index d7988275..428fe2e5 100644 --- a/test/union_find/__init__.py +++ b/test/decoding/__init__.py @@ -1,6 +1,6 @@ # This code is part of Qiskit. # -# (C) Copyright IBM 2017, 2019. +# (C) Copyright IBM 2019. # # This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory @@ -9,5 +9,3 @@ # Any modifications or derivative works of this code must retain this # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. - -"""Union Find Decoder Tests.""" diff --git a/test/decoding/test_cleaner.py b/test/decoding/test_cleaner.py new file mode 100644 index 00000000..716c7b9a --- /dev/null +++ b/test/decoding/test_cleaner.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name + +"""Run codes and decoders.""" + +import unittest + +from qiskit_qec.circuits import ArcCircuit +from qiskit_qec.decoders import DecodingGraph + + +class TestCleaner(unittest.TestCase): + """Test predecoders""" + + def test_measurement_cleaner(self): + """Test measurement cleaner""" + links = [(0, 1, 2), (2, 3, 4), (0, 5, 8), (2, 6, 10), (4, 7, 12), (8, 9, 10), (10, 11, 12)] + code = ArcCircuit(links, 3) + decoding_graph = DecodingGraph(code) + + # test that isolated measurement errors are removed + string = "010000 0000000 0000000 0100100" + cleaned_nodes = decoding_graph.clean_measurements(code.string2nodes(string)) + self.assertTrue( + len(cleaned_nodes) == 3, + "Wrong number of cleaned nodes for isolated measurement errors.", + ) + + # test that neighbouring ones aren't + string = "000000 0000100 0000000 0000100" + cleaned_nodes = decoding_graph.clean_measurements(code.string2nodes(string)) + self.assertTrue( + len(cleaned_nodes) == 4, + "Wrong number of cleaned nodes for neighbouring measurement errors.", + ) + + def test_code_cleaner(self): + """Test code cleaner""" + links = [(0, 1, 2), (2, 3, 4), (0, 5, 8), (2, 6, 10), (4, 7, 12), (8, 9, 10), (10, 11, 12)] + code = ArcCircuit(links, 2) + + error_string = "Single error handled incorrectly by code cleaner" + # test that single errors are corrected + for j in range(6): + string = "1" * j + "0" + "1" * (6 - j - 1) + " 0000000 0000000" + self.assertTrue(code.clean_code(string) == "111111 0000000 0000000", error_string) + test_strings = [ + "000010 0001011 0001011", + "000010 0001011 0000000", + "000001 0000101 0000101", + ] + for string in test_strings: + self.assertTrue(code.clean_code(string) == "000000 0000000 0000000", error_string) + self.assertTrue( + code.clean_code("000010 0001011 0001000") == "000000 0000000 0001000", error_string + ) + error_string = "code cleaer acts non-trivially on ambiguous syndrome" + # test syndromes that it shouldn't tackle + self.assertTrue( + code.clean_code("000000 0000000 1000001") == "000000 0000000 1000001", error_string + ) + self.assertTrue( + code.clean_code("000000 0000000 1000001") == "000000 0000000 1000001", error_string + ) + self.assertTrue( + code.clean_code("000000 0000011 0000011") == "000000 0000011 0000011", error_string + ) + # test a case where it causes a logical error + self.assertTrue( + code.clean_code("101010 0001011 0000000") == "000000 0000000 0000000", + "Code cleaner acts incorrectly on complex syndrome", + ) diff --git a/test/union_find/test_union_find.py b/test/decoding/test_union_find.py similarity index 100% rename from test/union_find/test_union_find.py rename to test/decoding/test_union_find.py diff --git a/test/union_find/test_clayg.py b/test/union_find/test_clayg.py deleted file mode 100644 index d6a2ed0c..00000000 --- a/test/union_find/test_clayg.py +++ /dev/null @@ -1,177 +0,0 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2023. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. - -"""Tests for template.""" - -import math -import random -import unittest - -from qiskit_qec.circuits import RepetitionCodeCircuit -from qiskit_qec.decoders import ClAYGDecoder -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel - - -def flip_with_probability(p, val): - """ - Flips parity of val with probability p. - """ - if random.random() <= p: - val = (val + 1) % 2 - return val - - -def noisy_surface_code_outcome(d, p): - """ - Generates outcome for surface code with phenomenological noise built in. - """ - string = "" - qubits = [0 for _ in range(d**2)] - for _ in range(d): - for qubit in qubits: - qubit = flip_with_probability(p, qubit) - # Top ancillas - for i in [2 * i for i in range((d - 1) // 2)]: - ancilla = (qubits[i] + qubits[i + 1]) % 2 - ancilla = flip_with_probability(p, ancilla) - string += str(ancilla) - for row in range(d - 1): - offset = (row + 1) % 2 - for topleft in [offset + row * d + 2 * i for i in range((d - 1) // 2)]: - ancilla = ( - qubits[topleft] - + qubits[topleft + 1] - + qubits[topleft + d] - + qubits[topleft + d + 1] - ) % 2 - ancilla = flip_with_probability(p, ancilla) - string += str(ancilla) - for i in [d * (d - 1) + 1 + 2 * i for i in range((d - 1) // 2)]: - ancilla = (qubits[i] + qubits[i + 1]) % 2 - ancilla = flip_with_probability(p, ancilla) - string += str(ancilla) - string += " " - for qubit in qubits: - qubit = flip_with_probability(p, qubit) - string += str(qubit) - return string - - -class ClAYGDecoderTest(unittest.TestCase): - """Tests will be here.""" - - def setUp(self) -> None: - # Bit-flip circuit noise model - p = 0.05 - noise_model = PauliNoiseModel() - noise_model.add_operation("cx", {"ix": 1, "xi": 1, "xx": 1}) - noise_model.add_operation("id", {"x": 1}) - noise_model.add_operation("reset", {"x": 1}) - noise_model.add_operation("measure", {"x": 1}) - noise_model.add_operation("x", {"x": 1, "y": 1, "z": 1}) - noise_model.set_error_probability("cx", p) - noise_model.set_error_probability("x", p) - noise_model.set_error_probability("id", p) - noise_model.set_error_probability("reset", p) - noise_model.set_error_probability("measure", p) - self.noise_model = noise_model - - self.fault_enumeration_method = "stabilizer" - - return super().setUp() - - def test_surface_code_d3(self): - """ - Test the ClAYG decoder on a surface code with d=3 and T=3 - with faults inserted by FaultEnumerator by checking if the syndromes - have even parity (if it's a valid code state) and if the logical value measured - is the one encoded by the circuit. - - This test won't complete atm, as the ClAYG decoder isn't able to decode some errors produced - by this error model, because of the placement of the boundary nodes and clusters neutralizing - when they shouldn't. - """ - return - - def test_repetition_code_d5(self): - """ - Test the ClAYG decoder on a repetition code with d=5 and T=5 - with faults inserted by FaultEnumerator by checking if the syndromes - have even parity (if it's a valid code state) and if the logical value measured - is the one encoded by the circuit. - - This test won't complete atm, as the ClAYG decoder isn't able to decode some errors produced - by this error model, because of the placement of the boundary nodes and clusters neutralizing - when they shouldn't. - """ - return - - def test_error_rates(self): - """ - Test the error rates using some repetition codes. - """ - d = 8 - p = 0.01 - samples = 1000 - - testcases = [] - testcases = [ - "".join([random.choices(["0", "1"], [1 - p, p])[0] for _ in range(d)]) - for _ in range(samples) - ] - codes = self.construct_codes(d) - - # now run them all and check it works - for code in codes: - decoder = ClAYGDecoder(code) - z_logicals = code.css_z_logical[0] - - logical_errors = 0 - min_flips_for_logical = code.d - for sample in range(samples): - # generate random string - string = "" - for _ in range(code.T): - string += "0" * (d - 1) + " " - string += testcases[sample] - # get and check corrected_z_logicals - outcome = decoder.process(string) - # pylint: disable=consider-using-generator - logical_outcome = sum([outcome[int(z_logical / 2)] for z_logical in z_logicals]) % 2 - if not logical_outcome == 0: - logical_errors += 1 - min_flips_for_logical = min(min_flips_for_logical, string.count("1")) - - # check that error rates are at least d/3 - self.assertTrue( - logical_errors / samples - < (math.factorial(d)) / (math.factorial(int(d / 2)) ** 2) * p**4, - "Logical error rate shouldn't exceed d!/((d/2)!^2)*p^(d/2).", - ) - self.assertTrue( - min_flips_for_logical >= d / 2, - "Minimum amount of errors that also causes logical errors shouldn't be lower than d/2.", - ) - - def construct_codes(self, d): - """ - Construct codes for the logical error rate test. - """ - codes = [] - # TODO: Add more codes - codes.append(RepetitionCodeCircuit(d=d, T=1)) - return codes - - -if __name__ == "__main__": - unittest.main()