Skip to content

Commit

Permalink
Add sinter.FusionBlossomCompiledDecoder (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc authored Nov 28, 2023
1 parent fb69c2d commit 97a9cd6
Showing 1 changed file with 42 additions and 4 deletions.
46 changes: 42 additions & 4 deletions glue/sample/src/sinter/_decoding_fusion_blossom.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,58 @@
import math
import pathlib
from typing import Callable, List, TYPE_CHECKING
from typing import Tuple
from typing import Callable, List, TYPE_CHECKING, Tuple

import numpy as np
import stim

from sinter._decoding_decoder_class import Decoder
from sinter._decoding_decoder_class import Decoder, CompiledDecoder

if TYPE_CHECKING:
import fusion_blossom


class FusionBlossomCompiledDecoder(CompiledDecoder):
def __init__(self, solver: 'fusion_blossom.SolverSerial', fault_masks: 'np.ndarray', num_dets: int, num_obs: int):
self.solver = solver
self.fault_masks = fault_masks
self.num_dets = num_dets
self.num_obs = num_obs

def decode_shots_bit_packed(
self,
*,
bit_packed_detection_event_data: 'np.ndarray',
) -> 'np.ndarray':
num_shots = bit_packed_detection_event_data.shape[0]
predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8)
import fusion_blossom
for shot in range(num_shots):
dets_sparse = np.flatnonzero(np.unpackbits(bit_packed_detection_event_data[shot], count=self.num_dets, bitorder='little'))
syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
self.solver.solve(syndrome)
prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
predictions[shot] = np.packbits(prediction, bitorder='little')
self.solver.clear()
return predictions


class FusionBlossomDecoder(Decoder):
"""Use fusion blossom to predict observables from detection events."""

def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder:
try:
import fusion_blossom
except ImportError as ex:
raise ImportError(
"The decoder 'fusion_blossom' isn't installed\n"
"To fix this, install the python package 'fusion_blossom' into your environment.\n"
"For example, if you are using pip, run `pip install fusion_blossom`.\n"
) from ex

solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(dem)
return FusionBlossomCompiledDecoder(solver, fault_masks, dem.num_detectors, dem.num_observables)


def decode_via_files(self,
*,
num_shots: int,
Expand All @@ -36,7 +75,6 @@ def decode_via_files(self,
error_model = stim.DetectorErrorModel.from_file(dem_path)
solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(error_model)
num_det_bytes = math.ceil(num_dets / 8)

with open(dets_b8_in_path, 'rb') as dets_in_f:
with open(obs_predictions_b8_out_path, 'wb') as obs_out_f:
for _ in range(num_shots):
Expand Down

0 comments on commit 97a9cd6

Please sign in to comment.