From 603c677a204c73dd54368acd4c739f58b743e3fa Mon Sep 17 00:00:00 2001 From: James Wootton Date: Fri, 21 Apr 2023 20:16:43 +0200 Subject: [PATCH] Allow error graph to be constructed from list of nodes (#360) * Update decoding_graph.py * black * lint --------- Co-authored-by: grace-harper <119029214+grace-harper@users.noreply.github.com> --- src/qiskit_qec/decoders/decoding_graph.py | 26 +++++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/qiskit_qec/decoders/decoding_graph.py b/src/qiskit_qec/decoders/decoding_graph.py index 4bf1392d..084c817d 100644 --- a/src/qiskit_qec/decoders/decoding_graph.py +++ b/src/qiskit_qec/decoders/decoding_graph.py @@ -57,6 +57,11 @@ def __init__(self, code, brute=False, graph=None): else: self._make_syndrome_graph() + self._logical_nodes = [] + for node in self.graph.nodes(): + if node.is_boundary: + self._logical_nodes.append(node) + def _make_syndrome_graph(self): if not self.brute and hasattr(self.code, "_make_syndrome_graph"): self.graph, self.hyperedges = self.code._make_syndrome_graph() @@ -288,11 +293,12 @@ def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ): edge_data.weight = w self.graph.update_edge(edge[0], edge[1], edge_data) - def make_error_graph(self, string: str): + def make_error_graph(self, data, all_logicals=True): """Returns error graph. Args: - string (str): A string describing the output from the code. + data: Either an ouput string of the code, or a list of + nodes for the code. Returns: The subgraph of graph which corresponds to the non-trivial @@ -300,7 +306,13 @@ def make_error_graph(self, string: str): """ E = rx.PyGraph(multigraph=False) - nodes = self.code.string2nodes(string, all_logicals=True) + if isinstance(data, str): + nodes = self.code.string2nodes(data, all_logicals=all_logicals) + else: + if all_logicals: + nodes = list(set(data).union(set(self._logical_nodes))) + else: + nodes = data for node in nodes: if node not in E.nodes(): E.add_node(node) @@ -317,11 +329,11 @@ def weight_fn(edge): source = E[source_index] target = E[target_index] if target != source: - distance = distance_matrix[self.graph.nodes().index(source)][ - self.graph.nodes().index(target) - ] - qubits = list(set(source.qubits).intersection(target.qubits)) + ns = self.graph.nodes().index(source) + nt = self.graph.nodes().index(target) + distance = distance_matrix[ns][nt] if np.isfinite(distance): + qubits = list(set(source.qubits).intersection(target.qubits)) distance = int(distance) E.add_edge(source_index, target_index, DecodingGraphEdge(qubits, distance)) return E