From 33f25a3ff00064be9d3c5b066e3eac3f0d0c1fd8 Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Sun, 25 Aug 2024 20:55:59 -0400 Subject: [PATCH] remove timeout --- .../sinter/_decoding_all_built_in_decoders.py | 6 +- glue/sample/src/sinter/_decoding_mwpf.py | 74 +++++++++++-------- 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding_all_built_in_decoders.py index a27b3e9a..4f011cdf 100644 --- a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py +++ b/glue/sample/src/sinter/_decoding_all_built_in_decoders.py @@ -10,6 +10,8 @@ 'vacuous': VacuousDecoder(), 'pymatching': PyMatchingDecoder(), 'fusion_blossom': FusionBlossomDecoder(), - 'hyper_uf': HyperUFDecoder(), - 'mwpf': MwpfDecoder(), + # an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049) + 'hypergraph_union_find': HyperUFDecoder(), + # Minimum-Weight Parity Factor using similar primal-dual method the blossom algorithm (https://pypi.org/project/mwpf/) + 'mw_parity_factor': MwpfDecoder(), } diff --git a/glue/sample/src/sinter/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding_mwpf.py index 636de1f0..9841915d 100644 --- a/glue/sample/src/sinter/_decoding_mwpf.py +++ b/glue/sample/src/sinter/_decoding_mwpf.py @@ -1,6 +1,6 @@ import math import pathlib -from typing import Callable, List, TYPE_CHECKING, Tuple +from typing import Callable, List, TYPE_CHECKING, Tuple, Any import numpy as np import stim @@ -10,7 +10,13 @@ if TYPE_CHECKING: import mwpf -DEFAULT_TIMEOUT: float = 10.0 # decoder timeout in seconds + +def mwpf_import_error() -> ImportError: + return ImportError( + "The decoder 'MWPF' isn't installed\n" + "To fix this, install the python package 'MWPF' into your environment.\n" + "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" + ) class MwpfCompiledDecoder(CompiledDecoder): @@ -57,19 +63,18 @@ class MwpfDecoder(Decoder): """Use MWPF to predict observables from detection events.""" def compile_decoder_for_dem( - self, *, dem: "stim.DetectorErrorModel", timeout: float = DEFAULT_TIMEOUT + self, + *, + dem: "stim.DetectorErrorModel", + decoder_cls: Any = None, # decoder class used to construct the MWPF decoder. + # in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins` + # but just provide different plugins for optimizing the primal and/or dual solutions. + # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only + # grows the clusters until the first valid solution appears; some more optimized solvers uses + # one or more plugins to further optimize the solution, which requires longer decoding time. ) -> CompiledDecoder: - try: - import mwpf - except ImportError as ex: - raise ImportError( - "The decoder 'MWPF' isn't installed\n" - "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF`.\n" - ) from ex - solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - dem, timeout=timeout + dem, decoder_cls=decoder_cls ) return MwpfCompiledDecoder( solver, fault_masks, dem.num_detectors, dem.num_observables @@ -85,20 +90,11 @@ def decode_via_files( dets_b8_in_path: pathlib.Path, obs_predictions_b8_out_path: pathlib.Path, tmp_dir: pathlib.Path, - timeout: float = DEFAULT_TIMEOUT, + decoder_cls: Any = None, ) -> None: - try: - import mwpf - except ImportError as ex: - raise ImportError( - "The decoder 'MWPF' isn't installed\n" - "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" - ) from ex - error_model = stim.DetectorErrorModel.from_file(dem_path) solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - error_model, timeout=timeout + error_model, decoder_cls=decoder_cls ) num_det_bytes = math.ceil(num_dets / 8) with open(dets_b8_in_path, "rb") as dets_in_f: @@ -126,12 +122,17 @@ def decode_via_files( class HyperUFDecoder(MwpfDecoder): - """Setting timeout to 0 becomes effectively a hypergraph UF decoder""" - def compile_decoder_for_dem( self, *, dem: "stim.DetectorErrorModel" ) -> CompiledDecoder: - return super().compile_decoder_for_dem(dem=dem, timeout=0.0) + try: + import mwpf + except ImportError as ex: + raise mwpf_import_error() from ex + + return super().compile_decoder_for_dem( + dem=dem, decoder_cls=mwpf.SolverSerialUnionFind + ) def decode_via_files( self, @@ -144,6 +145,11 @@ def decode_via_files( obs_predictions_b8_out_path: pathlib.Path, tmp_dir: pathlib.Path, ) -> None: + try: + import mwpf + except ImportError as ex: + raise mwpf_import_error() from ex + return super().decode_via_files( num_shots=num_shots, num_dets=num_dets, @@ -152,7 +158,7 @@ def decode_via_files( dets_b8_in_path=dets_b8_in_path, obs_predictions_b8_out_path=obs_predictions_b8_out_path, tmp_dir=tmp_dir, - timeout=0.0, + decoder_cls=mwpf.SolverSerialUnionFind, ) @@ -204,11 +210,14 @@ def _helper(m: stim.DetectorErrorModel, reps: int): def detector_error_model_to_mwpf_solver_and_fault_masks( - model: stim.DetectorErrorModel, timeout: float = DEFAULT_TIMEOUT + model: stim.DetectorErrorModel, decoder_cls: Any = None ) -> Tuple["mwpf.SolverSerialJointSingleHair", np.ndarray]: """Convert a stim error model into a NetworkX graph.""" - import mwpf + try: + import mwpf + except ImportError as ex: + raise mwpf_import_error() from ex num_detectors = model.num_detectors is_detector_connected = np.full(num_detectors, False, dtype=bool) @@ -256,7 +265,10 @@ def handle_detector_coords(detector: int, coords: np.ndarray): rescaled_edges, # Weighted edges. ) + if decoder_cls is None: + # default to the solver with highest accuracy + decoder_cls = mwpf.SolverSerialJointSingleHair return ( - mwpf.SolverSerialJointSingleHair(initializer, {"primal": {"timeout": timeout}}), + decoder_cls(initializer), fault_masks, )