diff --git a/src/qiskit_qec/circuits/code_circuit.py b/src/qiskit_qec/circuits/code_circuit.py index d0b6f489..b0a991b8 100644 --- a/src/qiskit_qec/circuits/code_circuit.py +++ b/src/qiskit_qec/circuits/code_circuit.py @@ -55,7 +55,7 @@ def string2nodes(self, string, **kwargs): pass @abstractmethod - def check_nodes(self, nodes, ignore_extra_boundary=False): + def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -65,6 +65,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): nodes (list): List of nodes, of the type produced by `string2nodes`. ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are ignored. + minimal (bool): Whether output should only reflect the minimal error + case. Returns: neutral (bool): Whether the nodes independently correspond to a valid set of errors. diff --git a/src/qiskit_qec/circuits/repetition_code.py b/src/qiskit_qec/circuits/repetition_code.py index 9620e155..d9e4a018 100644 --- a/src/qiskit_qec/circuits/repetition_code.py +++ b/src/qiskit_qec/circuits/repetition_code.py @@ -370,12 +370,12 @@ def flatten_nodes(nodes: List[DecodingGraphNode]): for node in nodes: if nodes_per_link[node.properties["link qubit"]] % 2: flat_node = copy(node) - # FIXME: Seems unsafe. flat_node.time = None - flat_nodes.append(flat_node) + if flat_node not in flat_nodes: + flat_nodes.append(flat_node) return flat_nodes - def check_nodes(self, nodes, ignore_extra_boundary=False): + def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -385,6 +385,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): nodes (list): List of nodes, of the type produced by `string2nodes`. ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are ignored. + minimal (bool): Whether output should only reflect the minimal error + case. Returns: neutral (bool): Whether the nodes independently correspond to a valid set of errors. @@ -422,9 +424,17 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): # and majority error_c_max = str((int(error_c_min) + 1) % 2) - # calculate all required info for the max to see if that is fully neutral - # if not, calculate and output for the min case - for error_c in [error_c_max, error_c_min]: + # list the colours with the max error one first + # (unless we do min only) + error_cs = [] + if minimal: + error_cs.append(error_c_max) + error_cs.append(error_c_min) + + # see what happens for both colours + # if neutral for maximal, it's neutral + # otherwise, it is whatever it is for the minimal + for error_c in error_cs: num_errors = colors.count(error_c) # determine the corresponding flipped logicals @@ -1036,6 +1046,16 @@ def _process_string(self, string): return new_string + def string2raw_logicals(self, string): + """ + Extracts raw logicals from output string. + Args: + string (string): Results string from which to extract logicals + Returns: + list: Raw values for logical operators that correspond to nodes. + """ + return _separate_string(self._process_string(string))[0] + def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]: """ Convert output string from circuits into a set of nodes. @@ -1114,10 +1134,11 @@ def flatten_nodes(nodes: List[DecodingGraphNode]): if nodes_per_link[node.properties["link qubit"]] % 2: flat_node = deepcopy(node) flat_node.time = None - flat_nodes.append(flat_node) + if flat_node not in flat_nodes: + flat_nodes.append(flat_node) return flat_nodes - def check_nodes(self, nodes, ignore_extra_boundary=False): + def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -1127,6 +1148,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): nodes (list): List of nodes, of the type produced by `string2nodes`. ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are ignored. + minimal (bool): Whether output should only reflect the minimal error + case. Returns: neutral (bool): Whether the nodes independently correspond to a valid set of errors. @@ -1151,10 +1174,10 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): nodes = self.flatten_nodes(nodes) link_qubits = set(node.properties["link qubit"] for node in nodes) node_color = {0: 0} - neutral = True + base_neutral = True link_graph = self._get_link_graph() ns_to_do = set(n for n in range(1, len(link_graph.nodes()))) - while ns_to_do and neutral: + while ns_to_do and base_neutral: # go through all coloured nodes newly_colored = {} for n, c in node_color.items(): @@ -1175,14 +1198,14 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): newly_colored[nn] = (c + dc) % 2 # if it is coloured, check the colour is correct else: - neutral = neutral and (node_color[nn] == (c + dc) % 2) + base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2) for nn, c in newly_colored.items(): node_color[nn] = c ns_to_do.remove(nn) # see which qubits for logical zs are needed flipped_logicals_all = [[], []] - if neutral: + if base_neutral: for inside_c in range(2): for n, c in node_color.items(): qubit = link_graph.nodes()[n] @@ -1196,22 +1219,28 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): for n, c in node_color.items(): num_nodes[c] += 1 - if num_nodes[0] == num_nodes[1]: - min_cs = [0, 1] - else: - min_cs = [int(sum(node_color.values()) < len(node_color) / 2)] + # list the colours with the max error one first + # (unless we do min only) + min_color = int(sum(node_color.values()) < len(node_color) / 2) + cs = [] + if not minimal: + cs.append((min_color + 1) % 2) + cs.append(min_color) # see what happens for both colours - # once full neutrality us found, go for it! - for c in min_cs: - this_neutral = neutral + # if neutral for maximal, it's neutral + # otherwise, it is whatever it is for the minimal + for c in cs: + + neutral = base_neutral num_errors = num_nodes[c] flipped_logicals = flipped_logicals_all[c] # if unneeded logical zs are given, cluster is not neutral # (unless this is ignored) if (not ignore_extra_boundary) and given_logicals.difference(flipped_logicals): - this_neutral = False + neutral = False + flipped_logicals = set() # otherwise, report only needed logicals that aren't given else: flipped_logicals = flipped_logicals.difference(given_logicals) @@ -1225,8 +1254,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): ) flipped_logical_nodes.append(node) - if this_neutral and flipped_logical_nodes == []: - neutral = this_neutral + if neutral and flipped_logical_nodes == []: break else: @@ -1250,6 +1278,8 @@ def is_cluster_neutral(self, atypical_nodes): to the method. Args: atypical_nodes (dictionary in the form of the return value of string2nodes) + ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are + ignored. """ neutral, logicals, _ = self.check_nodes(atypical_nodes) return neutral and not logicals diff --git a/src/qiskit_qec/circuits/surface_code.py b/src/qiskit_qec/circuits/surface_code.py index 7770d1e9..dc800dcb 100644 --- a/src/qiskit_qec/circuits/surface_code.py +++ b/src/qiskit_qec/circuits/surface_code.py @@ -69,12 +69,12 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True): # set info needed for css codes self.css_x_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.xplaqs] self.css_x_stabilizer_ops = self.css_x_gauge_ops - self.css_x_logical = self._logicals["x"][0] - self.css_x_boundary = self._logicals["x"][0] + self._logicals["x"][1] + self.css_x_logical = [self._logicals["x"][0]] + self.css_x_boundary = [self._logicals["x"][0] + self._logicals["x"][1]] self.css_z_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.zplaqs] self.css_z_stabilizer_ops = self.css_z_gauge_ops - self.css_z_logical = self._logicals["z"][0] - self.css_z_boundary = self._logicals["z"][0] + self._logicals["z"][1] + self.css_z_logical = [self._logicals["z"][0]] + self.css_z_boundary = [self._logicals["z"][0] + self._logicals["z"][1]] self.round_schedule = self.basis self.blocks = T @@ -342,7 +342,7 @@ def string2raw_logicals(self, string): Z[0] = (Z[0] + int(final_readout[j * self.d])) % 2 # evaluated using right side Z[1] = (Z[1] + int(final_readout[(j + 1) * self.d - 1])) % 2 - return str(Z[0]) + " " + str(Z[1]) + return [str(Z[0]), str(Z[1])] def _process_string(self, string): # get logical readout @@ -353,7 +353,7 @@ def _process_string(self, string): # the space separated string of syndrome changes then gets a # double space separated logical value on the end - new_string = measured_Z + " " + syndrome_changes + new_string = " ".join(measured_Z) + " " + syndrome_changes return new_string @@ -416,7 +416,7 @@ def string2nodes(self, string, **kwargs): nodes.append(node) return nodes - def check_nodes(self, nodes, ignore_extra_boundary=False): + def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): """ Determines whether a given set of nodes are neutral. If so, also determines any additional logical readout qubits that would be @@ -426,6 +426,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False): nodes (list): List of nodes, of the type produced by `string2nodes`. ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are ignored. + minimal (bool): Whether output should only reflect the minimal error + case. Returns: neutral (bool): Whether the nodes independently correspond to a valid set of errors. diff --git a/src/qiskit_qec/decoders/hdrg_decoders.py b/src/qiskit_qec/decoders/hdrg_decoders.py index a3bc6a99..67773a40 100644 --- a/src/qiskit_qec/decoders/hdrg_decoders.py +++ b/src/qiskit_qec/decoders/hdrg_decoders.py @@ -19,6 +19,7 @@ from copy import copy, deepcopy from dataclasses import dataclass from typing import Dict, List, Set, Tuple +from abc import ABC from rustworkx import connected_components, distance_matrix, PyGraph from qiskit_qec.circuits.repetition_code import ArcCircuit @@ -26,7 +27,7 @@ from qiskit_qec.utils import DecodingGraphNode, DecodingGraphEdge -class ClusteringDecoder: +class ClusteringDecoder(ABC): """ Generic base class for clustering decoders. """ @@ -37,21 +38,6 @@ def __init__( decoding_graph: DecodingGraph = None, ): self.code = code_circuit - if decoding_graph: - self.decoding_graph = decoding_graph - else: - self.decoding_graph = DecodingGraph(self.code) - - -class BravyiHaahDecoder(ClusteringDecoder): - """Decoder based on finding connected components within the decoding graph.""" - - def __init__( - self, - code_circuit, - decoding_graph: DecodingGraph = None, - ): - super().__init__(code_circuit, decoding_graph) if hasattr(self.code, "_xbasis"): if self.code._xbasis: @@ -65,6 +51,63 @@ def __init__( else: self.code_index = {j: j for j in range(self.code.n)} + if decoding_graph: + self.decoding_graph = decoding_graph + else: + self.decoding_graph = DecodingGraph(self.code) + + def get_corrections(self, string, clusters): + """ + Turn a set of neutral clusters into corrections. + + Args: + string (str): Output string of the code + clusters (dict): Dictionary with the indices of the given node + as keys and an integer specifying their cluster as the corresponding + value. + Returns: + corrected_logicals (list): A list of integers that are 0 or 1. + These are the corrected values of the final transversal + measurement, corresponding to the logical operators of + self.measured_logicals. + """ + + # get the list of bulk nodes for each cluster + cluster_nodes = {c: [] for c in clusters.values()} + for n, c in clusters.items(): + node = self.decoding_graph.graph[n] + if not node.is_boundary: + cluster_nodes[c].append(node) + + # get the list of required logicals for each cluster + cluster_logicals = {} + for c, nodes in cluster_nodes.items(): + _, logical_nodes, _ = self.code.check_nodes(nodes, minimal=True) + z_logicals = [node.qubits[0] for node in logical_nodes] + cluster_logicals[c] = z_logicals + + # get the net effect on each logical + net_z_logicals = {z_logical[0]: 0 for z_logical in self.measured_logicals} + for c, z_logicals in cluster_logicals.items(): + for z_logical in self.measured_logicals: + if z_logical[0] in z_logicals: + net_z_logicals[z_logical[0]] += 1 + for z_logical, num in net_z_logicals.items(): + net_z_logicals[z_logical] = num % 2 + + corrected_z_logicals = [] + string = string.split(" ")[0] + for z_logical in self.measured_logicals: + raw_logical = int(string[-1 - self.code_index[z_logical[0]]]) + corrected_logical = (raw_logical + net_z_logicals[z_logical[0]]) % 2 + corrected_z_logicals.append(corrected_logical) + + return corrected_z_logicals + + +class BravyiHaahDecoder(ClusteringDecoder): + """Decoder based on finding connected components within the decoding graph.""" + def _cluster(self, ns, dist_max): """ Finds connected components in the given nodes, for nodes connected by at most the given distance @@ -183,44 +226,12 @@ def process(self, string): measurement, corresponding to the logical operators of self.measured_logicals. """ - code = self.code - decoding_graph = self.decoding_graph # turn string into nodes and cluster - nodes = code.string2nodes(string, all_logicals=True) + nodes = self.code.string2nodes(string, all_logicals=True) clusters = self.cluster(nodes) - # get the list of bulk nodes for each cluster - cluster_nodes = {c: [] for c in clusters.values()} - for n, c in clusters.items(): - node = decoding_graph.graph[n] - if not node.is_boundary: - cluster_nodes[c].append(node) - - # get the list of required logicals for each cluster - cluster_logicals = {} - for c, nodes in cluster_nodes.items(): - _, logical_nodes, _ = code.check_nodes(nodes) - z_logicals = [node.qubits[0] for node in logical_nodes] - cluster_logicals[c] = z_logicals - - # get the net effect on each logical - net_z_logicals = {z_logical[0]: 0 for z_logical in self.measured_logicals} - for c, z_logicals in cluster_logicals.items(): - for z_logical in self.measured_logicals: - if z_logical[0] in z_logicals: - net_z_logicals[z_logical[0]] += 1 - for z_logical, num in net_z_logicals.items(): - net_z_logicals[z_logical] = num % 2 - - corrected_z_logicals = [] - string = string.split(" ")[0] - for z_logical in self.measured_logicals: - raw_logical = int(string[-1 - self.code_index[z_logical[0]]]) - corrected_logical = (raw_logical + net_z_logicals[z_logical[0]]) % 2 - corrected_z_logicals.append(corrected_logical) - - return corrected_z_logicals + return self.get_corrections(string, clusters) @dataclass @@ -298,13 +309,11 @@ def __init__( code, decoding_graph: DecodingGraph = None, ) -> None: - super().__init__(code, decoding_graph) + super().__init__(code, decoding_graph=decoding_graph) self.graph = deepcopy(self.decoding_graph.graph) self.clusters: Dict[int, UnionFindDecoderCluster] = {} - # FIXME: Use a better datastructure - # It needs to support inserting at specific index, unique elements and - # sorted insert self.odd_cluster_roots: List[int] = [] + self._clusters4peeling = [] def process(self, string: str): """ @@ -319,42 +328,49 @@ def process(self, string: str): self.z_logicals. """ self.graph = deepcopy(self.decoding_graph.graph) - string = "".join([str(c) for c in string[::-1]]) - output = [int(bit) for bit in list(string.split(" ", maxsplit=self.code.d)[0])][::-1] highlighted_nodes = self.code.string2nodes(string, all_logicals=True) - if not highlighted_nodes: - return output # There's nothing for us to do here - clusters = self.cluster(highlighted_nodes) + # call cluster to do the clustering, but actually use the peeling form + self.cluster(highlighted_nodes) + clusters = self._clusters4peeling + + # determine the net logical z + net_z_logicals = {tuple(z_logical): 0 for z_logical in self.measured_logicals} for cluster_nodes, _ in clusters: erasure = self.graph.subgraph(cluster_nodes) - if isinstance(self.code, ArcCircuit): - # NOTE: it just corrects for final logical readout - for node in erasure.nodes(): - if node.is_boundary: - # FIXME: Find a general way to go from physical qubit - # index to code qubit index - qubit_to_be_corrected = int(node.qubits[0] / 2) - output[qubit_to_be_corrected] = (output[qubit_to_be_corrected] + 1) % 2 - continue - flipped_qubits = self.peeling(erasure) for qubit_to_be_corrected in flipped_qubits: - output[qubit_to_be_corrected] = (output[qubit_to_be_corrected] + 1) % 2 + for z_logical in net_z_logicals: + if qubit_to_be_corrected in z_logical: + net_z_logicals[z_logical] += 1 + for z_logical, num in net_z_logicals.items(): + net_z_logicals[z_logical] = num % 2 - return output + # apply this to the raw readout + corrected_z_logicals = [] + raw_logicals = self.code.string2raw_logicals(string) + for j, z_logical in enumerate(self.measured_logicals): + raw_logical = int(raw_logicals[j]) + corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2 + corrected_z_logicals.append(corrected_logical) + + return corrected_z_logicals - def cluster(self, nodes) -> List[List[int]]: + def cluster(self, nodes): """ Create clusters using the union-find algorithm. Args: nodes (List): List of non-typical nodes in the syndrome graph, of the type produced by `string2nodes`. + standard_form (Bool): Whether to use the standard form of + the clusters for clustering decoders, or the form used internally + by the class. Returns: - FIXME: Make this more expressive. - clusters (List[List[int]]): List of Lists of indices of nodes in clusters + clusters (dict): Dictionary with the indices of + the given node as keys and an integer specifying their cluster as the corresponding + value. """ node_indices = [self.graph.nodes().index(node) for node in nodes] for node_index, _ in enumerate(self.graph.nodes()): @@ -373,13 +389,24 @@ def cluster(self, nodes) -> List[List[int]]: while self.odd_cluster_roots: self._grow_and_merge_clusters() - clusters = [] + # compile info into standard clusters dict + clusters = {} + for c, cluster in self.clusters.items(): + # determine which nodes exactly are in the neutral cluster + neutral_nodes = list(cluster.atypical_nodes | cluster.boundary_nodes) + # put them in the required dict + for n in neutral_nodes: + clusters[n] = c + + # also compile into form required for peeling + self._clusters4peeling = [] for _, cluster in self.clusters.items(): if not cluster.atypical_nodes: continue - clusters.append( + self._clusters4peeling.append( (list(cluster.nodes), list(cluster.atypical_nodes | cluster.boundary_nodes)) ) + return clusters def find(self, u: int) -> int: @@ -627,6 +654,7 @@ def __init__(self, code, decoding_graph: DecodingGraph = None) -> None: super().__init__(code, decoding_graph) self.graph = deepcopy(self.decoding_graph.graph) self.r = 1 + self._clusters4peeling = [] def process(self, string: str): """ @@ -659,7 +687,8 @@ def process(self, string: str): if not highlighted_nodes: return output # There's nothing for us to do here - clusters = self.cluster(highlighted_nodes) + self.cluster(highlighted_nodes) + clusters = self._clusters4peeling flattened_highlighted_nodes: List[DecodingGraphNode] = [] for highlighted_node in highlighted_nodes: @@ -679,7 +708,17 @@ def process(self, string: str): return output - def cluster(self, nodes: List[DecodingGraphNode]): + def cluster(self, nodes): + """ + Args: + nodes (List): List of non-typical nodes in the syndrome graph, + of the type produced by `string2nodes`. + + Returns: + clusters (dict): Ddictionary with the indices of + the given node as keys and an integer specifying their cluster as the corresponding + value. + """ self.clusters: Dict[int, UnionFindDecoderCluster] = {} self.odd_cluster_roots = [] @@ -711,12 +750,23 @@ def cluster(self, nodes: List[DecodingGraphNode]): neutral_clusters += self._collect_neutral_clusters() + # compile info into standard clusters dict + clusters = {} + for c, cluster in enumerate(neutral_clusters): + # determine which nodes exactly are in the neutral cluster + neutral_nodes = list(cluster.atypical_nodes | cluster.boundary_nodes) + # put them in the required dict + for n in neutral_nodes: + clusters[n] = c + neutral_cluster_nodes: List[List[int]] = [] for cluster in neutral_clusters: neutral_cluster_nodes.append( (list(cluster.nodes), list(cluster.atypical_nodes | cluster.boundary_nodes)) ) + self._clusters4peeling = neutral_cluster_nodes + return neutral_cluster_nodes def _add_node(self, node): diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index c06f9944..7903f10f 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -70,6 +70,9 @@ def __iter__(self): for attr, value in self.__dict__.items(): yield attr, value + def __repr__(self): + return str(dict(self)) + @dataclass class DecodingGraphEdge: @@ -100,3 +103,6 @@ def __hash__(self) -> int: def __iter__(self): for attr, value in self.__dict__.items(): yield attr, value + + def __repr__(self): + return str(dict(self)) diff --git a/test/code_circuits/test_rep_codes.py b/test/code_circuits/test_rep_codes.py index f618d7ff..6b72fc89 100644 --- a/test/code_circuits/test_rep_codes.py +++ b/test/code_circuits/test_rep_codes.py @@ -29,7 +29,7 @@ from qiskit_qec.decoders.decoding_graph import DecodingGraph from qiskit_qec.utils import DecodingGraphNode from qiskit_qec.analysis.faultenumerator import FaultEnumerator -from qiskit_qec.decoders.hdrg_decoders import BravyiHaahDecoder +from qiskit_qec.decoders.hdrg_decoders import BravyiHaahDecoder, UnionFindDecoder def get_syndrome(code, noise_model, shots=1024): @@ -255,7 +255,8 @@ def single_error_test( # check that the nodes are neutral neutral, flipped_logicals, _ = code.check_nodes(nodes) self.assertTrue( - neutral and flipped_logicals == [], "Error: Single error nodes are not neutral" + neutral and flipped_logicals == [], + "Error: Single error nodes are not neutral: " + string, ) # and that the given flipped logical makes sense for node in nodes: @@ -499,47 +500,51 @@ def test_empty_decoding_graph(): """Test initializtion of decoding graphs with None""" DecodingGraph(None) - def test_clustering_decoder(self): - """Test decoding of ARCs and RCCs with ClusteringDecoder""" + def clustering_decoder_test( + self, Decoder + ): # NOT run directly by unittest; called by test_graph_constructions + """Test decoding of ARCs and RCCs with clustering decoders""" # parameters for test d = 8 p = 0.1 N = 1000 - codes = [] - # first make a bunch of ARCs - # crossed line - links_cross = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 2)] - links_cross.append((2 * (d - 2), 2 * (d - 2) + 1, 2 * (int(d / 2)))) - links_cross.append(((2 * (int(d / 2))), 2 * (d - 1), 2 * (d - 1) + 1)) - # ladder (works for even d) - half_d = int(d / 2) - links_ladder = [] - for row in [0, 1]: - for j in range(half_d - 1): - delta = row * (2 * half_d - 1) - links_ladder.append((delta + 2 * j, delta + 2 * j + 1, delta + 2 * (j + 1))) - q = links_ladder[-1][2] + 1 - for j in range(half_d): - delta = 2 * half_d - 1 - links_ladder.append((2 * j, q, delta + 2 * j)) - q += 1 - # line - links_line = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 1)] - # add them to the code list - for links in [links_ladder, links_line, links_cross]: - codes.append(ArcCircuit(links, 0)) - # then an RCC - codes.append(RepetitionCode(d, 1)) + # first an RCC + codes = [RepetitionCode(d, 1)] + # then a linear ARC + links = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 1)] + codes.append(ArcCircuit(links, 0)) + # then make a bunch of non-linear ARCs + # TODO: make these work for union find too + if Decoder is not UnionFindDecoder: + # crossed line + links_cross = [(2 * j, 2 * j + 1, 2 * (j + 1)) for j in range(d - 2)] + links_cross.append((2 * (d - 2), 2 * (d - 2) + 1, 2 * (int(d / 2)))) + links_cross.append(((2 * (int(d / 2))), 2 * (d - 1), 2 * (d - 1) + 1)) + # ladder (works for even d) + half_d = int(d / 2) + links_ladder = [] + for row in [0, 1]: + for j in range(half_d - 1): + delta = row * (2 * half_d - 1) + links_ladder.append((delta + 2 * j, delta + 2 * j + 1, delta + 2 * (j + 1))) + q = links_ladder[-1][2] + 1 + for j in range(half_d): + delta = 2 * half_d - 1 + links_ladder.append((2 * j, q, delta + 2 * j)) + q += 1 + # add them to the code list + for links in [links_ladder, links_cross]: + codes.append(ArcCircuit(links, 0)) # now run them all and check it works - for code in codes: - code = ArcCircuit(links, 0) + for c, code in enumerate(codes): decoding_graph = DecodingGraph(code) - decoder = BravyiHaahDecoder(code, decoding_graph=decoding_graph) + decoder = Decoder(code, decoding_graph=decoding_graph) errors = {z_logical[0]: 0 for z_logical in decoder.measured_logicals} min_error_num = code.d - for sample in range(N): + min_error_string = "" + for _ in range(N): # generate random string string = "".join([choices(["1", "0"], [1 - p, p])[0] for _ in range(d)]) for _ in range(code.T): @@ -549,20 +554,32 @@ def test_clustering_decoder(self): for j, z_logical in enumerate(decoder.measured_logicals): error = corrected_z_logicals[j] != 1 if error: - min_error_num = min(min_error_num, string.count("0")) + error_num = string.count("0") + if error_num < min_error_num: + min_error_num = error_num + min_error_string = string errors[z_logical[0]] += error - # check that error rates are at least
d/3 - for z_logical in decoder.measured_logicals: - self.assertTrue( - errors[z_logical[0]] / (sample + 1) < p**2, - "Logical error rate greater than p^2.", - ) + # check that min num errors to cause logical errors >d/3 self.assertTrue( min_error_num > d / 3, - str(min_error_num) + "errors cause logical error despite d=" + str(code.d), + str(min_error_num) + + " errors cause logical error despite d=" + + str(code.d) + + " for code " + + str(c) + + " with " + + min_error_string + + ".", ) + def test_bravyi_haah(self): + """Test decoding of ARCs and RCCs with Bravyi Haah""" + self.clustering_decoder_test(BravyiHaahDecoder) + + def test_union_find(self): + """Test decoding of ARCs and RCCs with Union Find""" + self.clustering_decoder_test(UnionFindDecoder) + if __name__ == "__main__": unittest.main() diff --git a/test/union_find/test_union_find.py b/test/union_find/test_union_find.py index 460a455d..08f2335a 100644 --- a/test/union_find/test_union_find.py +++ b/test/union_find/test_union_find.py @@ -12,168 +12,30 @@ """Tests for template.""" -from random import choices import unittest -import math from unittest import TestCase -from qiskit_qec.analysis.faultenumerator import FaultEnumerator from qiskit_qec.decoders import UnionFindDecoder -from qiskit_qec.circuits import SurfaceCodeCircuit, RepetitionCodeCircuit, ArcCircuit -from qiskit_qec.decoders.temp_code_util import temp_syndrome -from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel +from qiskit_qec.circuits import SurfaceCodeCircuit class UnionFindDecoderTest(TestCase): - """Tests will be here.""" - - def setUp(self) -> None: - # Bit-flip circuit noise model - p = 0.05 - noise_model = PauliNoiseModel() - noise_model.add_operation("cx", {"ix": 1, "xi": 1, "xx": 1}) - noise_model.add_operation("id", {"x": 1}) - noise_model.add_operation("reset", {"x": 1}) - noise_model.add_operation("measure", {"x": 1}) - noise_model.add_operation("x", {"x": 1, "y": 1, "z": 1}) - noise_model.set_error_probability("cx", p) - noise_model.set_error_probability("x", p) - noise_model.set_error_probability("id", p) - noise_model.set_error_probability("reset", p) - noise_model.set_error_probability("measure", p) - self.noise_model = noise_model - - self.fault_enumeration_method = "stabilizer" - - return super().setUp() + """Tests for the UnionFind decoder not covered elsewhere""" def test_surface_code_d3(self): """ Test the union find decoder on a surface code with d=3 and T=3 - with faults inserted by FaultEnumerator by checking if the syndromes - have even parity (if it's a valid code state) and if the logical value measured - is the one encoded by the circuit. + with faults inserted into the final readout. """ for logical in ["0", "1"]: - code = SurfaceCodeCircuit(d=3, T=3) - decoder = UnionFindDecoder(code) - fault_enumerator = FaultEnumerator( - code.circuit[logical], method=self.fault_enumeration_method, model=self.noise_model - ) - for fault in fault_enumerator.generate(): - outcome = "".join([str(x) for x in fault[3]]) - corrected_outcome = decoder.process(outcome) - stabilizers = temp_syndrome(corrected_outcome, code.css_z_stabilizer_ops) - for syndrome in stabilizers: - self.assertEqual(syndrome, 0) - logical_measurement = temp_syndrome(corrected_outcome, [code.css_z_logical])[0] - self.assertEqual(str(logical_measurement), logical) - - def test_repetition_code_d5(self): - """ - Test the union find decoder on a repetition code with d=3 and T=3 - with faults inserted by FaultEnumerator by checking if the syndromes - have even parity (if it's a valid code state) and if the logical value measured - is the one encoded by the circuit. - """ - for logical in ["0", "1"]: - code = RepetitionCodeCircuit(d=5, T=5) - decoder = UnionFindDecoder(code) - fault_enumerator = FaultEnumerator( - code.circuit[logical], method=self.fault_enumeration_method, model=self.noise_model - ) - for fault in fault_enumerator.generate(): - outcome = "".join([str(x) for x in fault[3]]) - corrected_outcome = decoder.process(outcome) - stabilizers = temp_syndrome(corrected_outcome, code.css_z_stabilizer_ops) - for syndrome in stabilizers: - self.assertEqual(syndrome, 0) - logical_measurement = temp_syndrome(corrected_outcome, code.css_z_logical)[0] - self.assertEqual(str(logical_measurement), logical) - - def test_circular_arc_code(self): - """ - Test the union find decoder on a circular ARC code with faults inserted - by FaultEnumerator by checking if the syndromes have even parity - (if it's a valid code state) and if the logical value measured - is the one encoded by the circuit (only logical 0 for the ARC circuit, - see issue #309). - """ - links = [(0, 1, 2), (2, 3, 4), (4, 5, 6), (6, 7, 0)] - code = ArcCircuit(links=links, T=len(links), resets=False) - decoder = UnionFindDecoder(code) - fault_enumerator = FaultEnumerator( - code.circuit[code.base], method=self.fault_enumeration_method, model=self.noise_model - ) - for fault in fault_enumerator.generate(): - outcome = "".join([str(x) for x in fault[3]]) - corrected_outcome = decoder.process(outcome) - logical_measurement = temp_syndrome( - corrected_outcome, [[int(q / 2) for q in code.z_logicals]] - )[0] - self.assertEqual(str(logical_measurement), "0") - - def test_error_rates(self): - """ - Test the error rates using some ARCs. - """ - d = 8 - p = 0.01 - samples = 1000 - - testcases = [] - testcases = [ - "".join([choices(["0", "1"], [1 - p, p])[0] for _ in range(d)]) for _ in range(samples) - ] - codes = self.construct_codes(d) - - # now run them all and check it works - for code in codes: + code = SurfaceCodeCircuit(d=3, T=1) decoder = UnionFindDecoder(code) - if isinstance(code, ArcCircuit): - z_logicals = code.z_logicals - else: - z_logicals = code.css_z_logical[0] - - logical_errors = 0 - min_flips_for_logical = code.d - for sample in range(samples): - # generate random string - string = "" - for _ in range(code.T): - string += "0" * (d - 1) + " " - string += testcases[sample] - # get and check corrected_z_logicals - outcome = decoder.process(string) - logical_outcome = sum([outcome[int(z_logical / 2)] for z_logical in z_logicals]) % 2 - if not logical_outcome == 0: - logical_errors += 1 - min_flips_for_logical = min(min_flips_for_logical, string.count("1")) - - # check that error rates are at least
d/3 - self.assertTrue( - logical_errors / samples - < (math.factorial(d)) / (math.factorial(int(d / 2)) ** 2) * p**4, - "Logical error rate shouldn't exceed d!/((d/2)!^2)*p^(d/2).", - ) - self.assertTrue( - min_flips_for_logical >= d / 2, - "Minimum amount of errors that also causes logical errors shouldn't be lower than d/2.", - ) - - def construct_codes(self, d): - """ - Construct codes for the logical error rate test. - """ - # parameters for test - codes = [] - # add them to the code list - # TODO: Add ARCs to tests as soon as a better general alternative to the peeling is found, - # instead of just looking at what logicals are affected by checking for boundary nodes. - # for links in [links_ladder, links_line, links_cross]: - # codes.append(ArcCircuit(links, 0)) - codes.append(RepetitionCodeCircuit(d=d, T=1)) - return codes + for j in range(code.n): + string = logical * (j) + str((1 + int(logical)) % 2) + logical * (code.n - j - 1) + string += " 0000 0000" + corrected_outcome = decoder.process(string) + self.assertTrue( + corrected_outcome[0] == int(logical), "Correction for surface code failed." + ) if __name__ == "__main__":