From 7906302c7d09f2cbe7dc9595b2b65c3bebcb1ecd Mon Sep 17 00:00:00 2001
From: Emma Rosenfeld <emmarosenfeld@google.com>
Date: Wed, 27 Nov 2024 05:11:51 +0000
Subject: [PATCH] persist cirq tags thru to stim

---
 glue/cirq/stimcirq/_cirq_to_stim.py      | 106 +++++++++++++----------
 glue/cirq/stimcirq/_cirq_to_stim_test.py |  26 +++---
 2 files changed, 74 insertions(+), 58 deletions(-)

diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py
index e08abea6..abf762f4 100644
--- a/glue/cirq/stimcirq/_cirq_to_stim.py
+++ b/glue/cirq/stimcirq/_cirq_to_stim.py
@@ -1,7 +1,8 @@
 import functools
 import itertools
 import math
-from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type
+from collections.abc import Callable
+from typing import cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type
 
 import cirq
 import stim
@@ -11,7 +12,7 @@ def cirq_circuit_to_stim_circuit(
     circuit: cirq.AbstractCircuit,
     *,
     qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None,
-    custom_op_conversion_func: Callable | None = None
+    custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None
 ) -> stim.Circuit:
     """Converts a cirq circuit into an equivalent stim circuit.
 
@@ -36,6 +37,9 @@ def cirq_circuit_to_stim_circuit(
         circuit: The circuit to convert.
         qubit_to_index_dict: Optional. Which integer each qubit should get mapped to. If not specified, defaults to
             indexing qubits in the circuit in sorted order.
+        custom_op_conversion_func: Optional. A function which will transform cirq operators into other cirq operators, to be then 
+            converted to STIM. Useful in e.g. the case of non-Clifford operations in a cirq circuit, which are to be replaced
+            by Clifford operations in STIM.
 
     Returns:
         The converted circuit.
@@ -101,7 +105,7 @@ def cirq_circuit_to_stim_data(
     *,
     q2i: Optional[Dict[cirq.Qid, int]] = None,
     flatten: bool = False,
-    custom_op_conversion_func: Callable | None = None,
+    custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None,
 ) -> Tuple[stim.Circuit, List[Tuple[str, int]]]:
     """Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go."""
     if q2i is None:
@@ -145,21 +149,21 @@ def use(
     ) -> Callable[[stim.Circuit, List[int]], None]:
         if len(gates) == 1 and not individuals:
             (g,) = gates
-            return lambda c, t: c.append_operation(g, t)
+            return lambda c, t, tag: c.append(stim.CircuitInstruction(g, t, tag=tag))
 
         if not individuals:
 
-            def do(c, t):
+            def do(c, t, tag):
                 for g in gates:
-                    c.append_operation(g, t)
+                    c.append(stim.CircuitInstruction(g, t, tag=tag))
 
         else:
 
-            def do(c, t):
+            def do(c, t, tag):
                 for g in gates:
-                    c.append_operation(g, t)
+                    c.append(stim.CircuitInstruction(g, t, tag=tag))
                 for g, k in individuals:
-                    c.append_operation(g, [t[k]])
+                    c.append(stim.CircuitInstruction(g, [t[k]], tag))
 
         return do
 
@@ -251,16 +255,17 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]:
         cirq.AsymmetricDepolarizingChannel: cast(
             StimTypeHandler, _stim_append_asymmetric_depolarizing_channel
         ),
-        cirq.BitFlipChannel: lambda c, g, t: c.append_operation(
-            "X_ERROR", t, cast(cirq.BitFlipChannel, g).p
+        cirq.BitFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction(
+            "X_ERROR", t, cast(cirq.BitFlipChannel, g).p, tag=tag)
         ),
-        cirq.PhaseFlipChannel: lambda c, g, t: c.append_operation(
-            "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p
+        cirq.PhaseFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction(
+            "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p, tag=tag)
         ),
-        cirq.PhaseDampingChannel: lambda c, g, t: c.append_operation(
+        cirq.PhaseDampingChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction(
             "Z_ERROR",
             t,
             0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2,
+            tag=tag)
         ),
         cirq.RandomGateChannel: cast(StimTypeHandler, _stim_append_random_gate_channel),
         cirq.DepolarizingChannel: cast(
@@ -270,16 +275,16 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]:
 
 
 def _stim_append_measurement_gate(
-    circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int]
+    circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str
 ):
     for i, b in enumerate(gate.invert_mask):
         if b:
             targets[i] = stim.target_inv(targets[i])
-    circuit.append_operation("M", targets)
+    circuit.append(stim.CircuitInstruction("M", targets, tag=tag))
 
 
 def _stim_append_pauli_measurement_gate(
-    circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int]
+    circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int], tag: str
 ):
     obs: cirq.DensePauliString = gate.observable()
 
@@ -304,11 +309,11 @@ def _stim_append_pauli_measurement_gate(
     if obs.coefficient != 1 and obs.coefficient != -1:
         raise NotImplementedError(f"obs.coefficient={obs.coefficient!r} not in [1, -1]")
 
-    circuit.append_operation("MPP", new_targets)
+    circuit.append(stim.CircuitInstruction("MPP", new_targets, tag=tag))
 
 
 def _stim_append_spp_gate(
-    circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int]
+    circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int], tag: str
 ):
     obs: cirq.DensePauliString = gate.dense_pauli_string
     a = gate.exponent_neg
@@ -329,26 +334,26 @@ def _stim_append_spp_gate(
         return False
     new_targets.pop()
 
-    circuit.append_operation("SPP" if d == 0.5 else "SPP_DAG", new_targets)
+    circuit.append(stim.CircuitInstruction("SPP" if d == 0.5 else "SPP_DAG", new_targets, tag=tag))
     return True
 
 
 def _stim_append_dense_pauli_string_gate(
-    c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int]
+    c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int], tag: str
 ):
     gates = [None, "X", "Y", "Z"]
     for p, k in zip(g.pauli_mask, t):
         if p:
-            c.append_operation(gates[p], [k])
+            c.append(stim.CircuitInstruction(gates[p], [k], tag=tag))
 
 
 def _stim_append_asymmetric_depolarizing_channel(
-    c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int]
+    c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int], tag: str
 ):
     if cirq.num_qubits(g) == 1:
-        c.append_operation("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z])
+        c.append(stim.CircuitInstruction("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z], tag=tag))
     elif cirq.num_qubits(g) == 2:
-        c.append_operation(
+        c.append(stim.CircuitInstruction(
             "PAULI_CHANNEL_2",
             t,
             [
@@ -368,34 +373,35 @@ def _stim_append_asymmetric_depolarizing_channel(
                 g.error_probabilities.get("ZY", 0),
                 g.error_probabilities.get("ZZ", 0),
             ],
+            tag=tag)
         )
     else:
         raise NotImplementedError(f"cirq-to-stim gate {g!r}")
 
 
 def _stim_append_depolarizing_channel(
-    c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int]
+    c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int], tag: str
 ):
     if g.num_qubits() == 1:
-        c.append_operation("DEPOLARIZE1", t, g.p)
+        c.append(stim.CircuitInstruction("DEPOLARIZE1", t, g.p, tag=tag))
     elif g.num_qubits() == 2:
-        c.append_operation("DEPOLARIZE2", t, g.p)
+        c.append(stim.CircuitInstruction("DEPOLARIZE2", t, g.p, tag=tag))
     else:
         raise TypeError(f"Don't know how to turn {g!r} into Stim operations.")
 
 
-def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int]):
+def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int], tag: str):
     if isinstance(g.sub_gate, cirq.BaseDensePauliString) and g.num_controls() == 1:
         gates = [None, "CX", "CY", "CZ"]
         for p, k in zip(g.sub_gate.pauli_mask, t[1:]):
             if p:
-                c.append_operation(gates[p], [t[0], k])
+                c.append(stim.CircuitInstruction(gates[p], [t[0], k], tag=tag))
         if g.sub_gate.coefficient == 1j:
-            c.append_operation("S", t[:1])
+            c.append(stim.CircuitInstruction("S", t[:1], tag=tag))
         elif g.sub_gate.coefficient == -1:
-            c.append_operation("Z", t[:1])
+            c.append(stim.CircuitInstruction("Z", t[:1], tag=tag))
         elif g.sub_gate.coefficient == -1j:
-            c.append_operation("S_DAG", t[:1])
+            c.append(stim.CircuitInstruction("S_DAG", t[:1], tag=tag))
         elif g.sub_gate.coefficient == 1:
             pass
         else:
@@ -408,14 +414,14 @@ def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: Lis
 
 
 def _stim_append_random_gate_channel(
-    c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int]
+    c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int], tag: str
 ):
     if g.sub_gate in [cirq.X, cirq.Y, cirq.Z]:
-        c.append_operation(f"{g.sub_gate}_ERROR", t, g.probability)
+        c.append(stim.CircuitInstruction(f"{g.sub_gate}_ERROR", t, g.probability, tag=tag))
     elif isinstance(g.sub_gate, cirq.DensePauliString):
         target_p = [None, stim.target_x, stim.target_y, stim.target_z]
         pauli_targets = [target_p[p](t) for t, p in zip(t, g.sub_gate.pauli_mask) if p]
-        c.append_operation(f"CORRELATED_ERROR", pauli_targets, g.probability)
+        c.append(stim.CircuitInstruction(f"CORRELATED_ERROR", pauli_targets, g.probability, tag=tag))
     else:
         raise NotImplementedError(
             f"Don't know how to turn probabilistic {g!r} into Stim operations."
@@ -431,7 +437,7 @@ def __init__(self):
         self.flatten = False
 
     def process_circuit_operation_into_repeat_block(
-        self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable | None
+        self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None
     ) -> None:
         if self.flatten or op.repetitions == 1:
             moments = cirq.unroll_circuit_op(
@@ -451,7 +457,7 @@ def process_circuit_operation_into_repeat_block(
         )
         self.out += child.out * op.repetitions
 
-    def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable | None) -> None:
+    def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None) -> None:
         g2f = gate_to_stim_append_func()
         t2f = gate_type_to_stim_append_func()
         for op in operations:
@@ -459,6 +465,16 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con
             op = op.untagged if custom_op_conversion_func is None else custom_op_conversion_func(op)
             gate = op.gate
             targets = [self.q2i[q] for q in op.qubits]
+            if isinstance(op, cirq.TaggedOperation):
+                assert all([isinstance(tag, str) for tag in op.tags]), "I only understand str tags"
+                tag = ""
+                i = 0
+                while i < len(op.tags):
+                    tag += op.tags[i]
+                    tag += ", "
+                    i += 1
+            else:
+                tag = ""
 
             custom_method = getattr(
                 op, "_stim_conversion_", getattr(gate, "_stim_conversion_", None)
@@ -479,27 +495,27 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con
 
             # Special case measurement, because of its metadata.
             if isinstance(gate, cirq.PauliStringPhasorGate):
-                if _stim_append_spp_gate(self.out, gate, targets):
+                if _stim_append_spp_gate(self.out, gate, targets, tag):
                     continue
             if isinstance(gate, cirq.PauliMeasurementGate):
                 self.key_out.append((gate.key, len(targets)))
-                _stim_append_pauli_measurement_gate(self.out, gate, targets)
+                _stim_append_pauli_measurement_gate(self.out, gate, targets, tag)
                 continue
             if isinstance(gate, cirq.MeasurementGate):
                 self.key_out.append((gate.key, len(targets)))
-                _stim_append_measurement_gate(self.out, gate, targets)
+                _stim_append_measurement_gate(self.out, gate, targets, tag)
                 continue
 
             # Look for recognized gate values like cirq.H.
             val_append_func = g2f.get(gate)
             if val_append_func is not None:
-                val_append_func(self.out, targets)
+                val_append_func(self.out, targets, tag)
                 continue
 
             # Look for recognized gate types like cirq.DepolarizingChannel.
             type_append_func = t2f.get(type(gate))
             if type_append_func is not None:
-                type_append_func(self.out, gate, targets)
+                type_append_func(self.out, gate, targets, tag)
                 continue
 
             # Ask unrecognized operations to decompose themselves into simpler operations.
@@ -512,7 +528,7 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con
                     f"- It doesn't have a _stim_conversion_ method.\n"
                 ) from ex
 
-    def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable | None):
+    def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None):
         length_before = len(self.out)
         self.process_operations(moment, custom_op_conversion_func=custom_op_conversion_func)
 
@@ -522,6 +538,6 @@ def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callabl
         ):
             self.out.append_operation("TICK", [])
 
-    def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable | None):
+    def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None):
         for moment in moments:
             self.process_moment(moment, custom_op_conversion_func=custom_op_conversion_func)
diff --git a/glue/cirq/stimcirq/_cirq_to_stim_test.py b/glue/cirq/stimcirq/_cirq_to_stim_test.py
index fb46008e..9d32a1ff 100644
--- a/glue/cirq/stimcirq/_cirq_to_stim_test.py
+++ b/glue/cirq/stimcirq/_cirq_to_stim_test.py
@@ -411,13 +411,13 @@ def test_random_gate_channel():
 
 
 def test_stimcirq_custom_conversion():
+    """ Checks the custom operation conversion functionality. In this test, we specifically
+    convert cirq TaggedOperations with particular tag values to a given STIM operation, 
+    according to the lookup `_tag_lookup`. """
 
     _tag_lookup = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z}
 
     def _op_conversion(op: cirq.Operation) -> cirq.Operation:
-        """" For converting particular tagged cirq.Operator's to a value described by the tag content. 
-        Useful when treating non-Clifford gates in cirq and converting to STIM. 
-        """
         if isinstance(op, cirq.TaggedOperation):
             tag_checks = [tag for tag in op.tags if tag in list(_tag_lookup.keys())]
             if len(tag_checks) == 1:
@@ -437,7 +437,7 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation:
 
     stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
     assert stim_circuit == stim.Circuit(
-        """
+    """
     H 0
     X 1
     TICK
@@ -480,13 +480,13 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation:
 
     stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
     assert stim_circuit == stim.Circuit(
-        """
-        REPEAT 3 {
-                H 0
-                X 1
-                TICK
-                M 0 1
-                TICK
-            }
-        """
+    """
+    REPEAT 3 {
+            H 0
+            X 1
+            TICK
+            M 0 1
+            TICK
+        }
+    """
     )