Skip to content

Commit

Permalink
Allow error graph to be constructed from list of nodes (#360)
Browse files Browse the repository at this point in the history
* Update decoding_graph.py

* black

* lint

---------

Co-authored-by: grace-harper <[email protected]>
  • Loading branch information
quantumjim and grace-harper authored Apr 21, 2023
1 parent 9ff5b25 commit 603c677
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/qiskit_qec/decoders/decoding_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def __init__(self, code, brute=False, graph=None):
else:
self._make_syndrome_graph()

self._logical_nodes = []
for node in self.graph.nodes():
if node.is_boundary:
self._logical_nodes.append(node)

def _make_syndrome_graph(self):
if not self.brute and hasattr(self.code, "_make_syndrome_graph"):
self.graph, self.hyperedges = self.code._make_syndrome_graph()
Expand Down Expand Up @@ -288,19 +293,26 @@ def weight_syndrome_graph(self, counts, method: str = METHOD_SPITZ):
edge_data.weight = w
self.graph.update_edge(edge[0], edge[1], edge_data)

def make_error_graph(self, string: str):
def make_error_graph(self, data, all_logicals=True):
"""Returns error graph.
Args:
string (str): A string describing the output from the code.
data: Either an ouput string of the code, or a list of
nodes for the code.
Returns:
The subgraph of graph which corresponds to the non-trivial
syndrome elements in the given string.
"""

E = rx.PyGraph(multigraph=False)
nodes = self.code.string2nodes(string, all_logicals=True)
if isinstance(data, str):
nodes = self.code.string2nodes(data, all_logicals=all_logicals)
else:
if all_logicals:
nodes = list(set(data).union(set(self._logical_nodes)))
else:
nodes = data
for node in nodes:
if node not in E.nodes():
E.add_node(node)
Expand All @@ -317,11 +329,11 @@ def weight_fn(edge):
source = E[source_index]
target = E[target_index]
if target != source:
distance = distance_matrix[self.graph.nodes().index(source)][
self.graph.nodes().index(target)
]
qubits = list(set(source.qubits).intersection(target.qubits))
ns = self.graph.nodes().index(source)
nt = self.graph.nodes().index(target)
distance = distance_matrix[ns][nt]
if np.isfinite(distance):
qubits = list(set(source.qubits).intersection(target.qubits))
distance = int(distance)
E.add_edge(source_index, target_index, DecodingGraphEdge(qubits, distance))
return E
Expand Down

0 comments on commit 603c677

Please sign in to comment.