From 48a664ef35e9bcbd5c2d5539b0667e1a270d5593 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Mon, 27 Nov 2023 17:11:42 -0800 Subject: [PATCH] Add `sinter.FusionBlossomCompiledDecoder` --- .../src/sinter/_decoding_fusion_blossom.py | 46 +++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/glue/sample/src/sinter/_decoding_fusion_blossom.py b/glue/sample/src/sinter/_decoding_fusion_blossom.py index bcf45cf09..966b69a4f 100644 --- a/glue/sample/src/sinter/_decoding_fusion_blossom.py +++ b/glue/sample/src/sinter/_decoding_fusion_blossom.py @@ -1,19 +1,58 @@ import math import pathlib -from typing import Callable, List, TYPE_CHECKING -from typing import Tuple +from typing import Callable, List, TYPE_CHECKING, Tuple import numpy as np import stim -from sinter._decoding_decoder_class import Decoder +from sinter._decoding_decoder_class import Decoder, CompiledDecoder if TYPE_CHECKING: import fusion_blossom +class FusionBlossomCompiledDecoder(CompiledDecoder): + def __init__(self, solver: 'fusion_blossom.SolverSerial', fault_masks: 'np.ndarray', num_dets: int, num_obs: int): + self.solver = solver + self.fault_masks = fault_masks + self.num_dets = num_dets + self.num_obs = num_obs + + def decode_shots_bit_packed( + self, + *, + bit_packed_detection_event_data: 'np.ndarray', + ) -> 'np.ndarray': + num_shots = bit_packed_detection_event_data.shape[0] + predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8) + import fusion_blossom + for shot in range(num_shots): + dets_sparse = np.flatnonzero(np.unpackbits(bit_packed_detection_event_data[shot], count=self.num_dets, bitorder='little')) + syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse) + self.solver.solve(syndrome) + prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])) + predictions[shot] = np.packbits(prediction, bitorder='little') + self.solver.clear() + return predictions + + class FusionBlossomDecoder(Decoder): """Use fusion blossom to predict observables from detection events.""" + + def compile_decoder_for_dem(self, *, dem: 'stim.DetectorErrorModel') -> CompiledDecoder: + try: + import fusion_blossom + except ImportError as ex: + raise ImportError( + "The decoder 'fusion_blossom' isn't installed\n" + "To fix this, install the python package 'fusion_blossom' into your environment.\n" + "For example, if you are using pip, run `pip install fusion_blossom`.\n" + ) from ex + + solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(dem) + return FusionBlossomCompiledDecoder(solver, fault_masks, dem.num_detectors, dem.num_observables) + + def decode_via_files(self, *, num_shots: int, @@ -36,7 +75,6 @@ def decode_via_files(self, error_model = stim.DetectorErrorModel.from_file(dem_path) solver, fault_masks = detector_error_model_to_fusion_blossom_solver_and_fault_masks(error_model) num_det_bytes = math.ceil(num_dets / 8) - with open(dets_b8_in_path, 'rb') as dets_in_f: with open(obs_predictions_b8_out_path, 'wb') as obs_out_f: for _ in range(num_shots):