Skip to content

Commit

Permalink
remove timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
yuewuo committed Aug 26, 2024
1 parent bad41d7 commit 33f25a3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 33 deletions.
6 changes: 4 additions & 2 deletions glue/sample/src/sinter/_decoding_all_built_in_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
'vacuous': VacuousDecoder(),
'pymatching': PyMatchingDecoder(),
'fusion_blossom': FusionBlossomDecoder(),
'hyper_uf': HyperUFDecoder(),
'mwpf': MwpfDecoder(),
# an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049)
'hypergraph_union_find': HyperUFDecoder(),
# Minimum-Weight Parity Factor using similar primal-dual method the blossom algorithm (https://pypi.org/project/mwpf/)
'mw_parity_factor': MwpfDecoder(),
}
74 changes: 43 additions & 31 deletions glue/sample/src/sinter/_decoding_mwpf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import pathlib
from typing import Callable, List, TYPE_CHECKING, Tuple
from typing import Callable, List, TYPE_CHECKING, Tuple, Any

import numpy as np
import stim
Expand All @@ -10,7 +10,13 @@
if TYPE_CHECKING:
import mwpf

DEFAULT_TIMEOUT: float = 10.0 # decoder timeout in seconds

def mwpf_import_error() -> ImportError:
return ImportError(
"The decoder 'MWPF' isn't installed\n"
"To fix this, install the python package 'MWPF' into your environment.\n"
"For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
)


class MwpfCompiledDecoder(CompiledDecoder):
Expand Down Expand Up @@ -57,19 +63,18 @@ class MwpfDecoder(Decoder):
"""Use MWPF to predict observables from detection events."""

def compile_decoder_for_dem(
self, *, dem: "stim.DetectorErrorModel", timeout: float = DEFAULT_TIMEOUT
self,
*,
dem: "stim.DetectorErrorModel",
decoder_cls: Any = None, # decoder class used to construct the MWPF decoder.
# in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins`
# but just provide different plugins for optimizing the primal and/or dual solutions.
# For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
# grows the clusters until the first valid solution appears; some more optimized solvers uses
# one or more plugins to further optimize the solution, which requires longer decoding time.
) -> CompiledDecoder:
try:
import mwpf
except ImportError as ex:
raise ImportError(
"The decoder 'MWPF' isn't installed\n"
"To fix this, install the python package 'MWPF' into your environment.\n"
"For example, if you are using pip, run `pip install MWPF`.\n"
) from ex

solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
dem, timeout=timeout
dem, decoder_cls=decoder_cls
)
return MwpfCompiledDecoder(
solver, fault_masks, dem.num_detectors, dem.num_observables
Expand All @@ -85,20 +90,11 @@ def decode_via_files(
dets_b8_in_path: pathlib.Path,
obs_predictions_b8_out_path: pathlib.Path,
tmp_dir: pathlib.Path,
timeout: float = DEFAULT_TIMEOUT,
decoder_cls: Any = None,
) -> None:
try:
import mwpf
except ImportError as ex:
raise ImportError(
"The decoder 'MWPF' isn't installed\n"
"To fix this, install the python package 'MWPF' into your environment.\n"
"For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
) from ex

error_model = stim.DetectorErrorModel.from_file(dem_path)
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
error_model, timeout=timeout
error_model, decoder_cls=decoder_cls
)
num_det_bytes = math.ceil(num_dets / 8)
with open(dets_b8_in_path, "rb") as dets_in_f:
Expand Down Expand Up @@ -126,12 +122,17 @@ def decode_via_files(


class HyperUFDecoder(MwpfDecoder):
"""Setting timeout to 0 becomes effectively a hypergraph UF decoder"""

def compile_decoder_for_dem(
self, *, dem: "stim.DetectorErrorModel"
) -> CompiledDecoder:
return super().compile_decoder_for_dem(dem=dem, timeout=0.0)
try:
import mwpf
except ImportError as ex:
raise mwpf_import_error() from ex

return super().compile_decoder_for_dem(
dem=dem, decoder_cls=mwpf.SolverSerialUnionFind
)

def decode_via_files(
self,
Expand All @@ -144,6 +145,11 @@ def decode_via_files(
obs_predictions_b8_out_path: pathlib.Path,
tmp_dir: pathlib.Path,
) -> None:
try:
import mwpf
except ImportError as ex:
raise mwpf_import_error() from ex

return super().decode_via_files(
num_shots=num_shots,
num_dets=num_dets,
Expand All @@ -152,7 +158,7 @@ def decode_via_files(
dets_b8_in_path=dets_b8_in_path,
obs_predictions_b8_out_path=obs_predictions_b8_out_path,
tmp_dir=tmp_dir,
timeout=0.0,
decoder_cls=mwpf.SolverSerialUnionFind,
)


Expand Down Expand Up @@ -204,11 +210,14 @@ def _helper(m: stim.DetectorErrorModel, reps: int):


def detector_error_model_to_mwpf_solver_and_fault_masks(
model: stim.DetectorErrorModel, timeout: float = DEFAULT_TIMEOUT
model: stim.DetectorErrorModel, decoder_cls: Any = None
) -> Tuple["mwpf.SolverSerialJointSingleHair", np.ndarray]:
"""Convert a stim error model into a NetworkX graph."""

import mwpf
try:
import mwpf
except ImportError as ex:
raise mwpf_import_error() from ex

num_detectors = model.num_detectors
is_detector_connected = np.full(num_detectors, False, dtype=bool)
Expand Down Expand Up @@ -256,7 +265,10 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
rescaled_edges, # Weighted edges.
)

if decoder_cls is None:
# default to the solver with highest accuracy
decoder_cls = mwpf.SolverSerialJointSingleHair
return (
mwpf.SolverSerialJointSingleHair(initializer, {"primal": {"timeout": timeout}}),
decoder_cls(initializer),
fault_masks,
)

0 comments on commit 33f25a3

Please sign in to comment.