diff --git a/requirements.txt b/requirements.txt index 2a776083..c7960dce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ sympy>=1.9 numpy>=1.21.0 ipython ipywidgets>=8.0.5 +stim>=1.11.0 diff --git a/src/qiskit_qec/circuits/__init__.py b/src/qiskit_qec/circuits/__init__.py index 63447aaa..1a0db342 100644 --- a/src/qiskit_qec/circuits/__init__.py +++ b/src/qiskit_qec/circuits/__init__.py @@ -30,3 +30,4 @@ from .code_circuit import CodeCircuit from .repetition_code import RepetitionCodeCircuit, ArcCircuit from .surface_code import SurfaceCodeCircuit +from .css_code import CSSCodeCircuit diff --git a/src/qiskit_qec/circuits/css_code.py b/src/qiskit_qec/circuits/css_code.py new file mode 100644 index 00000000..bb5439ab --- /dev/null +++ b/src/qiskit_qec/circuits/css_code.py @@ -0,0 +1,490 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2019. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name, disable=no-name-in-module + +"""Generates circuits for CSS codes.""" +from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister +from qiskit_aer.noise import depolarizing_error, pauli_error + +from stim import Circuit as StimCircuit +from stim import target_rec as StimTarget_rec + +from qiskit_qec.utils import DecodingGraphNode +from qiskit_qec.circuits.code_circuit import CodeCircuit +from qiskit_qec.utils.stim_tools import ( + noisify_circuit, + get_stim_circuits, + detector_error_model_to_rx_graph, +) +from qiskit_qec.codes import StabSubSystemCode +from qiskit_qec.operators.pauli_list import PauliList +from qiskit_qec.linear.symplectic import normalizer +from qiskit_qec.exceptions import QiskitQECError + + +class CSSCodeCircuit(CodeCircuit): + """ + CodeCircuit class for generic CSS codes. + """ + + def __init__( + self, code, T: int, basis: str = "z", round_schedule: str = "zx", noise_model=None + ): + """ + Args: + code: A CSS code class which is either + a) StabSubSystemCode + b) a class with the following methods: + 'x_gauges' (as a list of list of qubit indices), + 'z_gauges', + 'x_stabilizers', + 'z_stabilizers', + 'logical_x', + 'logical_z', + 'n' (number of qubits), + T: Number of syndrome measurement rounds + basis: basis for encoding ('x' or 'z') + round_schedule: Order in which to measureme gauge operators ('zx' or 'xz') + noise_model: Pauli noise model used in the construction of noisy circuits. + If a tuple, a pnenomological noise model is used with the entries being + probabity of depolarizing noise on code qubits between rounds and + probability of measurement errors, respectively. + Examples: + The QuantumCircuit of a memory experiment for the distance-3 HeavyHEX code + >>> from qiskit_qec.codes.hhc import HHC + >>> from qiskit_qec.circuits.css_code import CSSCodeCircuit + >>> code = CSSCodeCircuit(HHC(3),T=3,basis='x',noise_model=(0.01,0.01),round_schedule='xz') + >>> code.circuit['0'] + """ + + super().__init__() + + self.code = code + self._get_code_properties() + self.T = T + self.basis = basis + self.base = "0" + self.round_schedule = round_schedule + self.noise_model = noise_model + self._phenom = isinstance(noise_model, tuple) + if self._phenom: + p_depol, p_meas = self.noise_model + self._depol_error = depolarizing_error(p_depol, 1) + self._meas_error = pauli_error([("X", p_meas), ("I", 1 - p_meas)]) + + circuit = {} + states = ["0", "1"] + if self._phenom: + states += ["0n", "1n"] + for state in states: + qc = QuantumCircuit() + qregs = [] + qregs.append(QuantumRegister(code.n, name="code qubits")) + qregs.append(QuantumRegister(len(self.z_gauges), name="z auxs")) + qregs.append(QuantumRegister(len(self.x_gauges), name="x auxs")) + for qreg in qregs: + qc.add_register(qreg) + self._prepare_initial_state(qc, qregs, state) + self._perform_syndrome_measurements(qc, qregs, state) + creg = ClassicalRegister(code.n, name="final_readout") + qc.add_register(creg) + self._final_readout(qc, qregs, creg, state) + circuit[state] = qc + + self.circuit = {} + self.noisy_circuit = {} + for state, qc in circuit.items(): + if state[-1] == "n" and self._phenom: + self.noisy_circuit[state[0]] = qc + else: + self.circuit[state] = qc + if noise_model and not self._phenom: + for state, qc in circuit.items(): + self.noisy_circuit[state] = noisify_circuit(qc, noise_model) + + self._gauges4stabilizers = [] + self._stabilizers = [self.x_stabilizers, self.z_stabilizers] + self._gauges = [self.x_gauges, self.z_gauges] + for j in range(2): + self._gauges4stabilizers.append([]) + for stabilizer in self._stabilizers[j]: + gauges = [] + for g, gauge in enumerate(self._gauges[j]): + if set(stabilizer).intersection(set(gauge)) == set(gauge): + gauges.append(g) + self._gauges4stabilizers[j].append(gauges) + + def _get_code_properties(self): + if isinstance(self.code, StabSubSystemCode): + is_css = True + + raw_gauges = self.code.gauge_group.generators + center, log, conj_log = normalizer(self.code.generators.matrix) + raw_stabilizers = PauliList(center) + raw_logicals = PauliList(log) + PauliList(conj_log) + + gauges = [[], []] + stabilizers = [[], []] + logicals = [[], []] + + for ( + raw_ops, + ops, + ) in zip([raw_gauges, raw_stabilizers, raw_logicals], [gauges, stabilizers, logicals]): + for op in raw_ops: + op = str(op) + for j, pauli in enumerate(["X", "Z"]): + if (op.count(pauli) + op.count("I")) == self.code.n: + ops[j].append([k for k, p in enumerate(op[::-1]) if p == pauli]) + is_css = is_css and (len(ops[0]) + len(ops[1])) == len(raw_ops) + + # extra stabilizers: the product of all others + for j in range(2): + combined = [] + for stabilizer in stabilizers[j]: + combined += stabilizer + stabilizers[j].append([]) + for q in combined: + if combined.count(q) % 2: + stabilizers[j][-1].append(q) + + if is_css: + self.x_gauges = gauges[0] + self.z_gauges = gauges[1] + self.x_stabilizers = stabilizers[0] + self.z_stabilizers = stabilizers[1] + self.logical_x = logicals[0] + self.logical_z = logicals[1] + else: + raise QiskitQECError("Code is not obviously CSS.") + + else: + # otherwise assume it has the info + self.x_gauges = self.code.x_gauges + self.z_gauges = self.code.z_gauges + self.x_stabilizers = self.code.x_stabilizers + self.z_stabilizers = self.code.z_stabilizers + self.logical_x = self.code.logical_x + self.logical_z = self.code.logical_z + + def _prepare_initial_state(self, qc, qregs, state): + if state[0] == "1": + if self.basis == "z": + qc.x(self.logical_x[0]) + else: + qc.x(self.logical_z[0]) + if self.basis == "x": + qc.h(qregs[0]) + + def _perform_syndrome_measurements(self, qc, qregs, state): + for t in range(self.T): + if state[-1] == "n" and self._phenom: + for q in qregs[0]: + qc.append(self._depol_error, [q]) + # gauge measurements + if self.round_schedule == "zx": + self._z_gauge_measurements(qc, t, state) + self._x_gauge_measurements(qc, t, state) + elif self.round_schedule == "xz": + self._x_gauge_measurements(qc, t, state) + self._z_gauge_measurements(qc, t, state) + else: + raise NotImplementedError( + "Round schedule " + self.round_schedule + " not supported." + ) + + def _final_readout(self, qc, qregs, creg, state): + if self.basis == "x": + qc.h(qregs[0]) + if state[-1] == "n" and self._phenom: + for q in qregs[0]: + qc.append(self._meas_error, [q]) + qc.measure(qregs[0], creg) + + def _z_gauge_measurements(self, qc, t, state): + creg = ClassicalRegister(len(self.z_gauges), name="round_" + str(t) + "_z_bits") + qc.add_register(creg) + for g, z_gauge in enumerate(self.z_gauges): + for q in z_gauge: + qc.cx(qc.qregs[0][q], qc.qregs[1][g]) + if state[-1] == "n" and self._phenom: + qc.append(self._meas_error, [qc.qregs[1][g]]) + qc.measure(qc.qregs[1][g], creg[g]) + qc.reset(qc.qregs[1][g]) + + def _x_gauge_measurements(self, qc, t, state): + creg = ClassicalRegister(len(self.x_gauges), name="round_" + str(t) + "_x_bits") + qc.add_register(creg) + for g, x_gauge in enumerate(self.x_gauges): + for q in x_gauge: + qc.h(qc.qregs[0][q]) + qc.cx(qc.qregs[0][q], qc.qregs[2][g]) + qc.h(qc.qregs[0][q]) + if state[-1] == "n" and self._phenom: + qc.append(self._meas_error, [qc.qregs[2][g]]) + qc.measure(qc.qregs[2][g], creg[g]) + qc.reset(qc.qregs[2][g]) + + def string2nodes(self, string, **kwargs): + """ + Convert output string from circuits into a set of nodes for + `DecodingGraph`. + Args: + string (string): Results string to convert. + kwargs (dict): Any additional keyword arguments. + logical (str): Logical value whose results are used ('0' as default). + all_logicals (bool): Whether to include logical nodes + irrespective of value. (False as default). + """ + + all_logicals = kwargs.get("all_logicals") + logical = kwargs.get("logical") + if logical is None: + logical = "0" + + output = string.split(" ")[::-1] + gauge_outs = [[], []] + for t in range(self.T): + gauge_outs[0].append( + [int(b) for b in output[2 * t + self.round_schedule.find("x")]][::-1] + ) + gauge_outs[1].append( + [int(b) for b in output[2 * t + self.round_schedule.find("z")]][::-1] + ) + final_outs = [int(b) for b in output[-1]] + + stabilizer_outs = [] + for j in range(2): + stabilizer_outs.append([]) + for t in range(self.T): + round_outs = [] + for gs in self._gauges4stabilizers[j]: + out = 0 + for g in gs: + out += gauge_outs[j][t][g] + out = out % 2 + round_outs.append(out) + stabilizer_outs[j].append(round_outs) + + bases = ["x", "z"] + j = bases.index(self.basis) + final_gauges = [] + for gauge in self._gauges[j]: + out = 0 + for q in gauge: + out += final_outs[-q - 1] + out = out % 2 + final_gauges.append(out) + final_stabilizers = [] + for gs in self._gauges4stabilizers[j]: + out = 0 + for g in gs: + out += final_gauges[g] + out = out % 2 + final_stabilizers.append(out) + stabilizer_outs[j].append(final_stabilizers) + + stabilizer_changes = [] + for j in range(2): + stabilizer_changes.append([]) + for t in range(self.T + (bases[j] == self.basis)): + stabilizer_changes[j].append([]) + for e in range(len(stabilizer_outs[j][t])): + if t == 0 and j == bases.index(self.basis): + stabilizer_changes[j][t].append(stabilizer_outs[j][t][e]) + else: + stabilizer_changes[j][t].append( + (stabilizer_outs[j][t][e] + stabilizer_outs[j][t - 1][e]) % 2 + ) + + nodes = [] + for j in range(2): + for t, round_changes in enumerate(stabilizer_changes[j]): + for e, change in enumerate(round_changes): + if change == 1: + node = DecodingGraphNode(time=t, qubits=self._stabilizers[j][e], index=e) + node.properties["basis"] = bases[j] + nodes.append(node) + + if self.basis == "x": + logicals = self.logical_x + else: + logicals = self.logical_z + + for index, logical_op in enumerate(logicals): + logical_out = 0 + for q in logical_op: + logical_out += final_outs[-q - 1] + logical_out = logical_out % 2 + + if all_logicals or str(logical_out) != logical: + node = DecodingGraphNode( + is_boundary=True, + qubits=logical, + index=index, + ) + node.properties["basis"] = self.basis + nodes.append(node) + + return nodes + + def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False): + raise NotImplementedError + + def is_cluster_neutral(self, atypical_nodes): + raise NotImplementedError + + def stim_circuit_with_detectors(self): + """Converts the qiskit circuits into stim ciruits and add detectors. + This is required for the stim-based construction of the DecodingGraph. + """ + stim_circuits, _ = get_stim_circuits(self.noisy_circuit) + measurements_per_cycle = len(self.x_gauges) + len(self.z_gauges) + + if self.round_schedule[0] == "x": + measurement_round_offset = [0, len(self.x_gauges)] + else: + measurement_round_offset = [len(self.z_gauges), 0] + + ## 0th round of measurements + if self.basis == "x": + for stabind, stabilizer in enumerate(self.x_stabilizers): + record_targets = [] + for gauge_ind in self._gauges4stabilizers[0][stabind]: + record_targets.append( + StimTarget_rec( + measurement_round_offset[0] + + gauge_ind + - (self.T * measurements_per_cycle + self.code.n) + ) + ) + qubits_and_time = stabilizer.copy() + qubits_and_time.extend([0]) + stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) + stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) + else: + for stabind, stabilizer in enumerate(self.z_stabilizers): + record_targets = [] + for gauge_ind in self._gauges4stabilizers[1][stabind]: + record_targets.append( + StimTarget_rec( + measurement_round_offset[1] + + gauge_ind + - (self.T * measurements_per_cycle + self.code.n) + ) + ) + qubits_and_time = stabilizer.copy() + qubits_and_time.extend([0]) + stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) + stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) + + # adding first x and then z stabilizer comparisons + for j in range(2): + circuit = StimCircuit() + for t in range( + 1, self.T + ): # compare stabilizer measurements with previous in each round + for gind, gs in enumerate(self._gauges4stabilizers[j]): + record_targets = [] + for gauge_ind in gs: + record_targets.append( + StimTarget_rec( + t * measurements_per_cycle + + measurement_round_offset[j] + + gauge_ind + - (self.T * measurements_per_cycle + self.code.n) + ) + ) + record_targets.append( + StimTarget_rec( + (t - 1) * measurements_per_cycle + + measurement_round_offset[j] + + gauge_ind + - (self.T * measurements_per_cycle + self.code.n) + ) + ) + qubits_and_time = self._stabilizers[j][gind].copy() + qubits_and_time.extend([t]) + circuit.append("DETECTOR", record_targets, qubits_and_time) + stim_circuits["0"] += circuit + stim_circuits["1"] += circuit + + ## final measurements + if self.basis == "x": + for stabind, stabilizer in enumerate(self.x_stabilizers): + record_targets = [] + for q in stabilizer: + record_targets.append(StimTarget_rec(q - self.code.n)) + for gauge_ind in self._gauges4stabilizers[0][stabind]: + record_targets.append( + StimTarget_rec( + measurement_round_offset[0] + + gauge_ind + - self.code.n + - measurements_per_cycle + ) + ) + qubits_and_time = stabilizer.copy() + qubits_and_time.extend([self.T]) + stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) + stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) + stim_circuits["0"].append( + "OBSERVABLE_INCLUDE", + [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_x[0])], + 0, + ) + stim_circuits["1"].append( + "OBSERVABLE_INCLUDE", + [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_x[0])], + 0, + ) + else: + for stabind, stabilizer in enumerate(self.z_stabilizers): + record_targets = [] + for q in stabilizer: + record_targets.append(StimTarget_rec(q - self.code.n)) + for gauge_ind in self._gauges4stabilizers[1][stabind]: + record_targets.append( + StimTarget_rec( + measurement_round_offset[1] + + gauge_ind + - self.code.n + - measurements_per_cycle + ) + ) + qubits_and_time = stabilizer.copy() + qubits_and_time.extend([self.T]) + stim_circuits["0"].append("DETECTOR", record_targets, qubits_and_time) + stim_circuits["1"].append("DETECTOR", record_targets, qubits_and_time) + stim_circuits["0"].append( + "OBSERVABLE_INCLUDE", + [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_z[0])], + 0, + ) + stim_circuits["1"].append( + "OBSERVABLE_INCLUDE", + [StimTarget_rec(q - self.code.n) for q in sorted(self.logical_z[0])], + 0, + ) + + return stim_circuits + + def _make_syndrome_graph(self): + stim_circuit = self.stim_circuit_with_detectors()["0"] + e = stim_circuit.detector_error_model( + decompose_errors=True, approximate_disjoint_errors=True + ) + graph, hyperedges = detector_error_model_to_rx_graph(e) + return graph, hyperedges diff --git a/src/qiskit_qec/utils/__init__.py b/src/qiskit_qec/utils/__init__.py index 8809eff3..11c4c162 100644 --- a/src/qiskit_qec/utils/__init__.py +++ b/src/qiskit_qec/utils/__init__.py @@ -31,4 +31,6 @@ """ from . import indexer, pauli_rep, visualizations + +from .stim_tools import get_counts_via_stim, get_stim_circuits, noisify_circuit from .decoding_graph_attributes import DecodingGraphNode, DecodingGraphEdge diff --git a/src/qiskit_qec/utils/decoding_graph_attributes.py b/src/qiskit_qec/utils/decoding_graph_attributes.py index 7903f10f..c650e72c 100644 --- a/src/qiskit_qec/utils/decoding_graph_attributes.py +++ b/src/qiskit_qec/utils/decoding_graph_attributes.py @@ -50,6 +50,26 @@ def __init__(self, index: int, qubits: List[int] = None, is_boundary=False, time self.index: int = index self.properties: Dict[str, Any] = {} + def __getitem__(self, key): + if key in self.__dict__: + return self.__dict__[key] + elif key in self.properties: + return self.properties[key] + else: + raise QiskitQECError( + "'" + str(key) + "'" + " is not an an attribute or property of the node." + ) + + def get(self, key, _): + """A dummy docstring.""" + return self.__getitem__(key) + + def __setitem__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self.properties[key] = value + def __eq__(self, rhs): if not isinstance(rhs, DecodingGraphNode): return NotImplemented @@ -88,9 +108,28 @@ class DecodingGraphEdge: qubits: List[int] weight: float - # TODO: Should code/decoder specific properties be accounted for when comparing edges properties: Dict[str, Any] = field(default_factory=dict) + def __getitem__(self, key): + if key in self.__dict__: + return self.__dict__[key] + elif key in self.properties: + return self.properties[key] + else: + raise QiskitQECError( + "'" + str(key) + "'" + " is not an an attribute or property of the edge." + ) + + def get(self, key, _): + """A dummy docstring.""" + return self.__getitem__(key) + + def __setitem__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self.properties[key] = value + def __eq__(self, rhs) -> bool: if not isinstance(rhs, DecodingGraphNode): return NotImplemented diff --git a/src/qiskit_qec/utils/stim_tools.py b/src/qiskit_qec/utils/stim_tools.py new file mode 100644 index 00000000..b5ca721c --- /dev/null +++ b/src/qiskit_qec/utils/stim_tools.py @@ -0,0 +1,340 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +# pylint: disable=invalid-name, disable=no-name-in-module + +"""Tools to use functionality from Stim.""" +from typing import Union, List, Dict +from math import log +from stim import Circuit as StimCircuit +from stim import DetectorErrorModel as StimDetectorErrorModel +from stim import DemInstruction as StimDemInstruction +from stim import DemTarget as StimDemTarget + +import numpy as np +import rustworkx as rx + +from qiskit import QuantumCircuit +from qiskit_aer.noise.errors.quantum_error import QuantumChannelInstruction +from qiskit_aer.noise import pauli_error +from qiskit_qec.utils.decoding_graph_attributes import DecodingGraphNode, DecodingGraphEdge +from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel + + +def get_stim_circuits(circuit_dict: Dict[int, QuantumCircuit]): + """Converts compatible qiskit circuits to stim circuits. + Dictionaries are not complete. For the stim definitions see: + https://github.com/quantumlib/Stim/blob/main/doc/gates.md + Args: + circuit_dict: Compatible gates are paulis, controlled paulis, h, s, + and sdg, swap, reset, measure and barrier. Compatible noise operators + correspond to a single or two qubit pauli channel. + + Returns: + stim_circuits, stim_measurement_data + """ + stim_circuits = {} + stim_measurement_data = {} + for circ_label, circuit in circuit_dict.items(): + stim_circuit = StimCircuit() + + qiskit_to_stim_dict = { + "id": "I", + "x": "X", + "y": "Y", + "z": "Z", + "h": "H", + "s": "S", + "sdg": "S_DAG", + "cx": "CNOT", + "cy": "CY", + "cz": "CZ", + "swap": "SWAP", + "reset": "R", + "measure": "M", + "barrier": "TICK", + } + pauli_error_1_stim_order = {"id": 0, "I": 0, "X": 1, "x": 1, "Y": 2, "y": 2, "Z": 3, "z": 3} + pauli_error_2_stim_order = { + "II": 0, + "IX": 1, + "IY": 2, + "IZ": 3, + "XI": 4, + "XX": 5, + "XY": 6, + "XZ": 7, + "YI": 8, + "YX": 9, + "YY": 10, + "YZ": 11, + "ZI": 12, + "ZX": 13, + "ZY": 14, + "ZZ": 15, + } + + measurement_data = [] + register_offset = {} + previous_offset = 0 + for inst, qargs, cargs in circuit.data: + for qubit in qargs: + if qubit._register.name not in register_offset: + register_offset[qubit._register.name] = previous_offset + previous_offset += qubit._register.size + + qubit_indices = [ + qargs[i]._index + register_offset[qargs[i]._register.name] + for i in range(len(qargs)) + ] + + if isinstance(inst, QuantumChannelInstruction): + qerror = inst._quantum_error + pauli_errors_types = qerror.to_dict()["instructions"] + pauli_probs = qerror.to_dict()["probabilities"] + if pauli_errors_types[0][0]["name"] in pauli_error_1_stim_order: + probs = 4 * [0.0] + for pind, ptype in enumerate(pauli_errors_types): + probs[pauli_error_1_stim_order[ptype[0]["name"]]] = pauli_probs[pind] + stim_circuit.append("PAULI_CHANNEL_1", qubit_indices, probs[1:]) + elif pauli_errors_types[0][0]["params"][0] in pauli_error_2_stim_order: + # here the name is always 'pauli' and the params gives the Pauli type + probs = 16 * [0.0] + for pind, ptype in enumerate(pauli_errors_types): + probs[pauli_error_2_stim_order[ptype[0]["params"][0]]] = pauli_probs[pind] + stim_circuit.append("PAULI_CHANNEL_2", qubit_indices, probs[1:]) + else: + raise Exception("Unexpected operations: " + str([inst, qargs, cargs])) + else: + # Gates and measurements + if inst.name in qiskit_to_stim_dict: + if len(cargs) > 0: # keeping track of measurement indices in stim + measurement_data.append( + [ + cargs[0]._index + register_offset[qargs[0]._register.name], + qargs[0]._register.name, + ] + ) + if qiskit_to_stim_dict[inst.name] == "TICK": # barrier + stim_circuit.append("TICK") + else: # gates/measurements acting on qubits + stim_circuit.append(qiskit_to_stim_dict[inst.name], qubit_indices) + else: + raise Exception("Unexpected operations: " + str([inst, qargs, cargs])) + + stim_circuits[circ_label] = stim_circuit + stim_measurement_data[circ_label] = measurement_data + return stim_circuits, stim_measurement_data + + +def get_counts_via_stim( + circuits: Union[List, QuantumCircuit], shots: int = 4000, noise_model: PauliNoiseModel = None +): + """Returns a qiskit compatible dictionary of measurement outcomes + + Args: + circuit: Qiskit circuit compatible with `get_stim_circuits` or list thereof. + shots: Number of samples to be generated. + noise_model: Pauli noise model for any additional noise to be applied. + + Returns: + counts: Counts dictionary in standard Qiskit form or list thereof. + """ + + if noise_model: + circuits = noisify_circuit(circuits, noise_model) + + single_circuit = isinstance(circuits, QuantumCircuit) + if single_circuit: + circuits = [circuits] + + counts = [] + for circuit in circuits: + stim_circuits, stim_measurement_data = get_stim_circuits({"": circuit}) + stim_circuit = stim_circuits[""] + measurement_data = stim_measurement_data[""] + + stim_samples = stim_circuit.compile_sampler().sample(shots=shots) + qiskit_counts = {} + for stim_sample in stim_samples: + prev_reg = measurement_data[-1][1] + qiskit_count = "" + for idx, meas in enumerate(measurement_data[::-1]): + _, reg = meas + if reg != prev_reg: + qiskit_count += " " + qiskit_count += str(int(stim_sample[-idx - 1])) + prev_reg = reg + if qiskit_count in qiskit_counts: + qiskit_counts[qiskit_count] += 1 + else: + qiskit_counts[qiskit_count] = 1 + counts.append(qiskit_counts) + + if single_circuit: + counts = counts[0] + + return counts + + +def detector_error_model_to_rx_graph(model: StimDetectorErrorModel) -> rx.PyGraph: + """Convert a stim error model into a RustworkX graph. + It assumes that the stim circuit does not contain repeat blocks. + Later on repeat blocks should be handled to make this function compatible with + user-defined stim circuits. + """ + + g = rx.PyGraph(multigraph=False) + + index_to_DecodingGraphNode = {} + + for instruction in model: + if instruction.type == "detector": + a = np.array(instruction.args_copy()) + time = a[-1] + qubits = [int(qubit_ind) for qubit_ind in a[:-1]] + for t in instruction.targets_copy(): + node = DecodingGraphNode(index=t.val, time=time, qubits=qubits) + index_to_DecodingGraphNode[t.val] = node + g.add_node(node) + + trivial_boundary_node = DecodingGraphNode(index=model.num_detectors, time=0, is_boundary=True) + g.add_node(trivial_boundary_node) + index_to_DecodingGraphNode[model.num_detectors] = trivial_boundary_node + + def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: Dict): + if p == 0: + return + if len(dets) == 0: + return + if len(dets) == 1: + dets = [dets[0], model.num_detectors] + if len(dets) > 2: + raise NotImplementedError( + f"Error with more than 2 symptoms can't become an edge or boundary edge: {dets!r}." + ) + if g.has_edge(dets[0], dets[1]): + edge_ind = list(g.edge_list()).index((dets[0], dets[1])) + edge_data = g.edges()[edge_ind].properties + old_p = edge_data["error_probability"] + old_frame_changes = edge_data["fault_ids"] + # If frame changes differ, the code has distance 2; just keep whichever was first. + if set(old_frame_changes) == set(frame_changes): + p = p * (1 - old_p) + old_p * (1 - p) + g.remove_edge(dets[0], dets[1]) + if p > 0.5: + p = 1 - p + if p > 0: + qubits = list( + set(index_to_DecodingGraphNode[dets[0]].qubits).intersection( + index_to_DecodingGraphNode[dets[1]].qubits + ) + ) + edge = DecodingGraphEdge( + qubits=qubits, + weight=log((1 - p) / p), + properties={"fault_ids": set(frame_changes), "error_probability": p}, + ) + g.add_edge(dets[0], dets[1], edge) + hyperedge[dets[0], dets[1]] = edge + + hyperedges = [] + + for instruction in model: + if isinstance(instruction, StimDemInstruction): + if instruction.type == "error": + dets: List[int] = [] + frames: List[int] = [] + t: StimDemTarget + p = instruction.args_copy()[0] + hyperedge = {} + for t in instruction.targets_copy(): + if t.is_relative_detector_id(): + dets.append(t.val) + elif t.is_logical_observable_id(): + frames.append(t.val) + elif t.is_separator(): + # Treat each component of a decomposed error as an independent error. + handle_error(p, dets, frames, hyperedge) + frames = [] + dets = [] + # Handle last component. + handle_error(p, dets, frames, hyperedge) + if len(hyperedge) > 1: + hyperedges.append(hyperedge) + elif instruction.type == "detector": + pass + elif instruction.type == "logical_observable": + pass + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + return g, hyperedges + + +def noisify_circuit(circuits: Union[List, QuantumCircuit], noise_model: PauliNoiseModel): + """ + Inserts error operations into a circuit according to a pauli noise model. + + Args: + circuits: Circuit or list thereof to which noise is added. + noise_model: Pauli noise model used to define types of errors to add to circuit. + + Returns: + noisy_circuits: Corresponding circuit or list thereof. + """ + + single_circuit = isinstance(circuits, QuantumCircuit) + if single_circuit: + circuits = [circuits] + + # create pauli errors for all errors in noise model + errors = {} + for g, noise in noise_model.to_dict().items(): + errors[g] = [] + for pauli, prob in noise["chan"].items(): + pauli = pauli.upper() + errors[g].append(pauli_error([(pauli, prob), ("I" * len(pauli), 1 - prob)])) + + noisy_circuits = [] + for qc in circuits: + noisy_qc = QuantumCircuit() + for qreg in qc.qregs: + noisy_qc.add_register(qreg) + for creg in qc.cregs: + noisy_qc.add_register(creg) + + for gate in qc: + g = gate[0].name + qubits = gate[1] + pre_error = g == "reset" + # add gate if it needs to go before the error + if pre_error: + noisy_qc.append(gate) + # then the error + if g in errors: + for error_op in errors[g]: + noisy_qc.append(error_op, qubits) + # add gate if it needs to go after the error + if not pre_error: + noisy_qc.append(gate) + + noisy_circuits.append(noisy_qc) + + if single_circuit: + noisy_circuits = noisy_circuits[0] + + return noisy_circuits diff --git a/test/code_circuits/test_css_codes_with_stim.py b/test/code_circuits/test_css_codes_with_stim.py new file mode 100644 index 00000000..2a6cb9b4 --- /dev/null +++ b/test/code_circuits/test_css_codes_with_stim.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Test for the CSSCodeCircuit class for Heavy-HEX code with pymatching""" +import unittest +import pymatching + +from qiskit_qec.codes.hhc import HHC +from qiskit_qec.circuits.css_code import CSSCodeCircuit +from qiskit_qec.decoders.decoding_graph import DecodingGraph + + +class TestCircuitMatcher(unittest.TestCase): + """Test for the CSSCodeCircuit class for Heavy-HEX code with pymatching""" + + def log_failure_dists(self, error_rate: float): + """Constructs the stim circuit and the decoding graph (via a stim DetectorErrorModel) + for the heavy-hex code and tests it on 10_000 samples for distance 3 and 5. + Returns the logical failure for distance 3 and 5 at the specified error rate. + Below ~100_000 shotss, the runtime is limited by the decoding graph construction, + not the sampling.""" + dist_list = [3, 5] + num_shots = 10_000 + log_fail_d = [] + for d in dist_list: + code = HHC(d) + css_code = CSSCodeCircuit(code, T=d, basis="x", noise_model=(error_rate, error_rate)) + graph = DecodingGraph(css_code).graph + m = pymatching.Matching(graph) + stim_circuit = css_code.stim_circuit_with_detectors()["0"] + stim_sampler = stim_circuit.compile_detector_sampler() + num_correct = 0 + stim_samples = stim_sampler.sample(num_shots, append_observables=True) + for sample in stim_samples: + actual_observable = sample[-1] + detectors_only = sample.copy() + detectors_only[-1] = 0 + predicted_observable = m.decode(detectors_only)[0] + num_correct += actual_observable == predicted_observable + log_fail_d.append((num_shots - num_correct) / num_shots) + return log_fail_d + + def test_HHC(self): + """Tests the order of logical failure rates for two distances. + One test is below threshold, one is above. (Threshold ~3.5% for the test code) + Runtime is approx 5 seconds.""" + error_rate = 0.01 # this should be below threshold + log_fail_dist = self.log_failure_dists(error_rate) + self.assertTrue(log_fail_dist[0] > 2 * log_fail_dist[1]) + + error_rate = 0.1 # this should be above threshold + log_fail_dist = self.log_failure_dists(error_rate) + self.assertTrue(log_fail_dist[0] < log_fail_dist[1]) + + +if __name__ == "__main__": + unittest.main()