Skip to content

Commit

Permalink
add cluster_node_limit for mwpf decoder to better tune decoding tim…
Browse files Browse the repository at this point in the history
…e and accuracy
  • Loading branch information
yuewuo committed Nov 20, 2024
1 parent 7985ebb commit 2d907bd
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ jobs:
- run: bazel build :stim_dev_wheel
- run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl
- run: pip install -e glue/sample
- run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1
- run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.5
- run: pytest glue/sample
- run: dev/doctest_proper.py --module sinter
- run: sinter help
Expand Down
37 changes: 24 additions & 13 deletions glue/sample/src/sinter/_decoding_mwpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError:
return ImportError(
"The decoder 'MWPF' isn't installed\n"
"To fix this, install the python package 'MWPF' into your environment.\n"
"For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
"For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
)


Expand Down Expand Up @@ -75,12 +75,18 @@ def compile_decoder_for_dem(
# For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
# grows the clusters until the first valid solution appears; some more optimized solvers uses
# one or more plugins to further optimize the solution, which requires longer decoding time.
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster.
) -> CompiledDecoder:
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
dem, decoder_cls=decoder_cls
dem,
decoder_cls=decoder_cls,
cluster_node_limit=cluster_node_limit,
)
return MwpfCompiledDecoder(
solver, fault_masks, dem.num_detectors, dem.num_observables
solver,
fault_masks,
dem.num_detectors,
dem.num_observables,
)

def decode_via_files(
Expand Down Expand Up @@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
def deduplicate_hyperedges(
hyperedges: List[Tuple[List[int], float, int]]
) -> List[Tuple[List[int], float, int]]:
indices: dict[frozenset[int], int] = dict()
indices: dict[frozenset[int], Tuple[int, float]] = dict()
result: List[Tuple[List[int], float, int]] = []
for dets, weight, mask in hyperedges:
dets_set = frozenset(dets)
if dets_set in indices:
idx = indices[dets_set]
idx, min_weight = indices[dets_set]
p1 = 1 / (1 + math.exp(weight))
p2 = 1 / (1 + math.exp(result[idx][1]))
p = p1 * (1 - p2) + p2 * (1 - p1)
# not sure why would this fail? two hyperedges with different masks?
# assert mask == result[idx][2], (result[idx], (dets, weight, mask))
result[idx] = (dets, math.log((1 - p) / p), result[idx][2])
# choosing the mask from the most likely error
new_mask = result[idx][2]
if weight < min_weight:
indices[dets_set] = (idx, weight)
new_mask = mask
result[idx] = (dets, math.log((1 - p) / p), new_mask)
else:
indices[dets_set] = len(result)
indices[dets_set] = (len(result), weight)
result.append((dets, weight, mask))
return result


def detector_error_model_to_mwpf_solver_and_fault_masks(
model: stim.DetectorErrorModel, decoder_cls: Any = None
model: stim.DetectorErrorModel,
decoder_cls: Any = None,
cluster_node_limit: int = 50,
) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
"""Convert a stim error model into a NetworkX graph."""

Expand All @@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]):
# Accept it and keep going, though of course decoding will probably perform terribly.
return
if p > 0.5:
# mwpf doesn't support negative edge weights.
# mwpf doesn't support negative edge weights (yet, will be supported in the next version).
# approximate them as weight 0.
p = 0.5
weight = math.log((1 - p) / p)
Expand All @@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
# mwpf package panic on duplicate edges, thus we need to handle them here
hyperedges = deduplicate_hyperedges(hyperedges)

# fix the input by connecting an edge to all isolated vertices
# fix the input by connecting an edge to all isolated vertices; will be supported in the next version
for idx in range(num_detectors):
if not is_detector_connected[idx]:
hyperedges.append(([idx], 0, 0))
Expand All @@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
decoder_cls = mwpf.SolverSerialJointSingleHair
return (
(
decoder_cls(initializer)
decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
if num_detectors > 0 and len(rescaled_edges) > 0
else None
),
Expand Down

0 comments on commit 2d907bd

Please sign in to comment.