Skip to content

Commit

Permalink
Create pymatching_decoder.py (#415)
Browse files Browse the repository at this point in the history
Creates a matching object which can be used to decode a counts string via the method process.
  • Loading branch information
hetenyib authored Dec 15, 2023
1 parent 1549e02 commit 9badddb
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions src/qiskit_qec/decoders/pymatching_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=invalid-name, disable=no-name-in-module, disable=no-member

"""PyMatching"""
from typing import List, Union
from pymatching import Matching
from qiskit_qec.decoders.decoding_graph import (
DecodingGraphNode,
DecodingGraphEdge,
DecodingGraph,
)


class PyMatching:
"""
Matching decoder using PyMatching.
"""

def __init__(self, decoding_graph: DecodingGraph):
"""Setting up the matching object"""
self.decoding_graph = decoding_graph
self.graph = decoding_graph.graph
self.pymatching = self.matching()
self.indexer = None
super().__init__()

def matching(self) -> Matching:
return Matching(self.graph)

def logical_flips(
self, syndrome: Union[List[DecodingGraphNode], List[int]]
) -> List[int]:
"""
Args:
syndromes: a) list of DecodingGraphNode objects returnes by string2nodes, or
b) list of binaries indicating which node is highlighted, e.g., the output of a stim detector sampler
Returns: list of binaries indicating which logical is flipped
"""
if isinstance(syndrome[0], DecodingGraphNode):
syndrome = self.nodes_to_detections(syndrome)
return self.pymatching.decode(syndrome)

def process(self, string: str) -> List[int]:
"""
Converts qiskit counts string into a list of flipped logicals
Args: counts string
Returns: list of corrected logicals (0 or 1)
"""
nodes = self.decoding_graph.code.string2nodes(string)
raw_logicals = self.decoding_graph.code.string2raw_logicals(string)

logical_flips = self.logical_flips(nodes)

corrected_logicals = [
(raw + flip) % 2 for raw, flip in zip(raw_logicals, logical_flips)
]

return corrected_logicals

def matched_edges(
self, syndrome: Union[List[DecodingGraphNode], List[int]]
) -> List[DecodingGraphEdge]:
"""
Args:
syndromes: a) list of DecodingGraphNode objects returnes by string2nodes, or
b) list of binaries indicating which node is highlighted
Returns: list of DecodingGraphEdge-s included in the matching
"""
if isinstance(syndrome[0], DecodingGraphNode):
syndrome = self.nodes_to_detections(syndrome)
edge_dets = list(self.graph.edge_list())
edges = self.graph.edges()
matched_det_pairs = self.pymatching.decode_to_edges_array(syndrome)
det_pairs = []
for pair in matched_det_pairs:
if pair[1] == -1:
pair[-1] = pair[-1] + len(self.graph.nodes())
pair.sort()
det_pairs.append(tuple(pair))
mached_edges = [edges[edge_dets.index(det_pair)] for det_pair in det_pairs]
return mached_edges

def nodes_to_detections(self, syndrome_nodes: List[DecodingGraphNode]) -> List[int]:
"""Converts nodes to detector indices to be used by pymatching.Matching.decode"""
graph_nodes = self.graph.nodes()
detections = [0] * len(graph_nodes)
for i, node in enumerate(graph_nodes):
if node in syndrome_nodes:
detections[i] = 1
return detections

0 comments on commit 9badddb

Please sign in to comment.