From 7906302c7d09f2cbe7dc9595b2b65c3bebcb1ecd Mon Sep 17 00:00:00 2001 From: Emma Rosenfeld <emmarosenfeld@google.com> Date: Wed, 27 Nov 2024 05:11:51 +0000 Subject: [PATCH] persist cirq tags thru to stim --- glue/cirq/stimcirq/_cirq_to_stim.py | 106 +++++++++++++---------- glue/cirq/stimcirq/_cirq_to_stim_test.py | 26 +++--- 2 files changed, 74 insertions(+), 58 deletions(-) diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py index e08abea6..abf762f4 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim.py +++ b/glue/cirq/stimcirq/_cirq_to_stim.py @@ -1,7 +1,8 @@ import functools import itertools import math -from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type +from collections.abc import Callable +from typing import cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type import cirq import stim @@ -11,7 +12,7 @@ def cirq_circuit_to_stim_circuit( circuit: cirq.AbstractCircuit, *, qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None, - custom_op_conversion_func: Callable | None = None + custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None ) -> stim.Circuit: """Converts a cirq circuit into an equivalent stim circuit. @@ -36,6 +37,9 @@ def cirq_circuit_to_stim_circuit( circuit: The circuit to convert. qubit_to_index_dict: Optional. Which integer each qubit should get mapped to. If not specified, defaults to indexing qubits in the circuit in sorted order. + custom_op_conversion_func: Optional. A function which will transform cirq operators into other cirq operators, to be then + converted to STIM. Useful in e.g. the case of non-Clifford operations in a cirq circuit, which are to be replaced + by Clifford operations in STIM. Returns: The converted circuit. @@ -101,7 +105,7 @@ def cirq_circuit_to_stim_data( *, q2i: Optional[Dict[cirq.Qid, int]] = None, flatten: bool = False, - custom_op_conversion_func: Callable | None = None, + custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None, ) -> Tuple[stim.Circuit, List[Tuple[str, int]]]: """Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go.""" if q2i is None: @@ -145,21 +149,21 @@ def use( ) -> Callable[[stim.Circuit, List[int]], None]: if len(gates) == 1 and not individuals: (g,) = gates - return lambda c, t: c.append_operation(g, t) + return lambda c, t, tag: c.append(stim.CircuitInstruction(g, t, tag=tag)) if not individuals: - def do(c, t): + def do(c, t, tag): for g in gates: - c.append_operation(g, t) + c.append(stim.CircuitInstruction(g, t, tag=tag)) else: - def do(c, t): + def do(c, t, tag): for g in gates: - c.append_operation(g, t) + c.append(stim.CircuitInstruction(g, t, tag=tag)) for g, k in individuals: - c.append_operation(g, [t[k]]) + c.append(stim.CircuitInstruction(g, [t[k]], tag)) return do @@ -251,16 +255,17 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: cirq.AsymmetricDepolarizingChannel: cast( StimTypeHandler, _stim_append_asymmetric_depolarizing_channel ), - cirq.BitFlipChannel: lambda c, g, t: c.append_operation( - "X_ERROR", t, cast(cirq.BitFlipChannel, g).p + cirq.BitFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction( + "X_ERROR", t, cast(cirq.BitFlipChannel, g).p, tag=tag) ), - cirq.PhaseFlipChannel: lambda c, g, t: c.append_operation( - "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p + cirq.PhaseFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction( + "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p, tag=tag) ), - cirq.PhaseDampingChannel: lambda c, g, t: c.append_operation( + cirq.PhaseDampingChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction( "Z_ERROR", t, 0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2, + tag=tag) ), cirq.RandomGateChannel: cast(StimTypeHandler, _stim_append_random_gate_channel), cirq.DepolarizingChannel: cast( @@ -270,16 +275,16 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: def _stim_append_measurement_gate( - circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str ): for i, b in enumerate(gate.invert_mask): if b: targets[i] = stim.target_inv(targets[i]) - circuit.append_operation("M", targets) + circuit.append(stim.CircuitInstruction("M", targets, tag=tag)) def _stim_append_pauli_measurement_gate( - circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int], tag: str ): obs: cirq.DensePauliString = gate.observable() @@ -304,11 +309,11 @@ def _stim_append_pauli_measurement_gate( if obs.coefficient != 1 and obs.coefficient != -1: raise NotImplementedError(f"obs.coefficient={obs.coefficient!r} not in [1, -1]") - circuit.append_operation("MPP", new_targets) + circuit.append(stim.CircuitInstruction("MPP", new_targets, tag=tag)) def _stim_append_spp_gate( - circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int], tag: str ): obs: cirq.DensePauliString = gate.dense_pauli_string a = gate.exponent_neg @@ -329,26 +334,26 @@ def _stim_append_spp_gate( return False new_targets.pop() - circuit.append_operation("SPP" if d == 0.5 else "SPP_DAG", new_targets) + circuit.append(stim.CircuitInstruction("SPP" if d == 0.5 else "SPP_DAG", new_targets, tag=tag)) return True def _stim_append_dense_pauli_string_gate( - c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int] + c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int], tag: str ): gates = [None, "X", "Y", "Z"] for p, k in zip(g.pauli_mask, t): if p: - c.append_operation(gates[p], [k]) + c.append(stim.CircuitInstruction(gates[p], [k], tag=tag)) def _stim_append_asymmetric_depolarizing_channel( - c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int] + c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int], tag: str ): if cirq.num_qubits(g) == 1: - c.append_operation("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z]) + c.append(stim.CircuitInstruction("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z], tag=tag)) elif cirq.num_qubits(g) == 2: - c.append_operation( + c.append(stim.CircuitInstruction( "PAULI_CHANNEL_2", t, [ @@ -368,34 +373,35 @@ def _stim_append_asymmetric_depolarizing_channel( g.error_probabilities.get("ZY", 0), g.error_probabilities.get("ZZ", 0), ], + tag=tag) ) else: raise NotImplementedError(f"cirq-to-stim gate {g!r}") def _stim_append_depolarizing_channel( - c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int] + c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int], tag: str ): if g.num_qubits() == 1: - c.append_operation("DEPOLARIZE1", t, g.p) + c.append(stim.CircuitInstruction("DEPOLARIZE1", t, g.p, tag=tag)) elif g.num_qubits() == 2: - c.append_operation("DEPOLARIZE2", t, g.p) + c.append(stim.CircuitInstruction("DEPOLARIZE2", t, g.p, tag=tag)) else: raise TypeError(f"Don't know how to turn {g!r} into Stim operations.") -def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int]): +def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int], tag: str): if isinstance(g.sub_gate, cirq.BaseDensePauliString) and g.num_controls() == 1: gates = [None, "CX", "CY", "CZ"] for p, k in zip(g.sub_gate.pauli_mask, t[1:]): if p: - c.append_operation(gates[p], [t[0], k]) + c.append(stim.CircuitInstruction(gates[p], [t[0], k], tag=tag)) if g.sub_gate.coefficient == 1j: - c.append_operation("S", t[:1]) + c.append(stim.CircuitInstruction("S", t[:1], tag=tag)) elif g.sub_gate.coefficient == -1: - c.append_operation("Z", t[:1]) + c.append(stim.CircuitInstruction("Z", t[:1], tag=tag)) elif g.sub_gate.coefficient == -1j: - c.append_operation("S_DAG", t[:1]) + c.append(stim.CircuitInstruction("S_DAG", t[:1], tag=tag)) elif g.sub_gate.coefficient == 1: pass else: @@ -408,14 +414,14 @@ def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: Lis def _stim_append_random_gate_channel( - c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int] + c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int], tag: str ): if g.sub_gate in [cirq.X, cirq.Y, cirq.Z]: - c.append_operation(f"{g.sub_gate}_ERROR", t, g.probability) + c.append(stim.CircuitInstruction(f"{g.sub_gate}_ERROR", t, g.probability, tag=tag)) elif isinstance(g.sub_gate, cirq.DensePauliString): target_p = [None, stim.target_x, stim.target_y, stim.target_z] pauli_targets = [target_p[p](t) for t, p in zip(t, g.sub_gate.pauli_mask) if p] - c.append_operation(f"CORRELATED_ERROR", pauli_targets, g.probability) + c.append(stim.CircuitInstruction(f"CORRELATED_ERROR", pauli_targets, g.probability, tag=tag)) else: raise NotImplementedError( f"Don't know how to turn probabilistic {g!r} into Stim operations." @@ -431,7 +437,7 @@ def __init__(self): self.flatten = False def process_circuit_operation_into_repeat_block( - self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable | None + self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None ) -> None: if self.flatten or op.repetitions == 1: moments = cirq.unroll_circuit_op( @@ -451,7 +457,7 @@ def process_circuit_operation_into_repeat_block( ) self.out += child.out * op.repetitions - def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable | None) -> None: + def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None) -> None: g2f = gate_to_stim_append_func() t2f = gate_type_to_stim_append_func() for op in operations: @@ -459,6 +465,16 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con op = op.untagged if custom_op_conversion_func is None else custom_op_conversion_func(op) gate = op.gate targets = [self.q2i[q] for q in op.qubits] + if isinstance(op, cirq.TaggedOperation): + assert all([isinstance(tag, str) for tag in op.tags]), "I only understand str tags" + tag = "" + i = 0 + while i < len(op.tags): + tag += op.tags[i] + tag += ", " + i += 1 + else: + tag = "" custom_method = getattr( op, "_stim_conversion_", getattr(gate, "_stim_conversion_", None) @@ -479,27 +495,27 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con # Special case measurement, because of its metadata. if isinstance(gate, cirq.PauliStringPhasorGate): - if _stim_append_spp_gate(self.out, gate, targets): + if _stim_append_spp_gate(self.out, gate, targets, tag): continue if isinstance(gate, cirq.PauliMeasurementGate): self.key_out.append((gate.key, len(targets))) - _stim_append_pauli_measurement_gate(self.out, gate, targets) + _stim_append_pauli_measurement_gate(self.out, gate, targets, tag) continue if isinstance(gate, cirq.MeasurementGate): self.key_out.append((gate.key, len(targets))) - _stim_append_measurement_gate(self.out, gate, targets) + _stim_append_measurement_gate(self.out, gate, targets, tag) continue # Look for recognized gate values like cirq.H. val_append_func = g2f.get(gate) if val_append_func is not None: - val_append_func(self.out, targets) + val_append_func(self.out, targets, tag) continue # Look for recognized gate types like cirq.DepolarizingChannel. type_append_func = t2f.get(type(gate)) if type_append_func is not None: - type_append_func(self.out, gate, targets) + type_append_func(self.out, gate, targets, tag) continue # Ask unrecognized operations to decompose themselves into simpler operations. @@ -512,7 +528,7 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con f"- It doesn't have a _stim_conversion_ method.\n" ) from ex - def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable | None): + def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None): length_before = len(self.out) self.process_operations(moment, custom_op_conversion_func=custom_op_conversion_func) @@ -522,6 +538,6 @@ def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callabl ): self.out.append_operation("TICK", []) - def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable | None): + def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None): for moment in moments: self.process_moment(moment, custom_op_conversion_func=custom_op_conversion_func) diff --git a/glue/cirq/stimcirq/_cirq_to_stim_test.py b/glue/cirq/stimcirq/_cirq_to_stim_test.py index fb46008e..9d32a1ff 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim_test.py +++ b/glue/cirq/stimcirq/_cirq_to_stim_test.py @@ -411,13 +411,13 @@ def test_random_gate_channel(): def test_stimcirq_custom_conversion(): + """ Checks the custom operation conversion functionality. In this test, we specifically + convert cirq TaggedOperations with particular tag values to a given STIM operation, + according to the lookup `_tag_lookup`. """ _tag_lookup = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z} def _op_conversion(op: cirq.Operation) -> cirq.Operation: - """" For converting particular tagged cirq.Operator's to a value described by the tag content. - Useful when treating non-Clifford gates in cirq and converting to STIM. - """ if isinstance(op, cirq.TaggedOperation): tag_checks = [tag for tag in op.tags if tag in list(_tag_lookup.keys())] if len(tag_checks) == 1: @@ -437,7 +437,7 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation: stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) assert stim_circuit == stim.Circuit( - """ + """ H 0 X 1 TICK @@ -480,13 +480,13 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation: stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) assert stim_circuit == stim.Circuit( - """ - REPEAT 3 { - H 0 - X 1 - TICK - M 0 1 - TICK - } - """ + """ + REPEAT 3 { + H 0 + X 1 + TICK + M 0 1 + TICK + } + """ )