diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index d33e4840..9b104329 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -831,6 +831,8 @@ def append( name: str, targets: Union[int, stim.GateTarget, Iterable[Union[int, stim.GateTarget]]], arg: Union[float, Iterable[float]], + *, + tag: str = "", ) -> None: pass @overload @@ -844,6 +846,8 @@ def append( name: object, targets: object = (), arg: object = None, + *, + tag: str = '', ) -> None: """Appends an operation into the circuit. @@ -894,6 +898,7 @@ def append( compatibility reasons, `cirq.append_operation` (but not `cirq.append`) will default to a single 0.0 argument for gates that take exactly one argument. + tag: A customizable string attached to the instruction. """ ``` diff --git a/doc/stim.pyi b/doc/stim.pyi index b9e1c185..059bc23e 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -303,6 +303,8 @@ class Circuit: name: str, targets: Union[int, stim.GateTarget, Iterable[Union[int, stim.GateTarget]]], arg: Union[float, Iterable[float]], + *, + tag: str = "", ) -> None: pass @overload @@ -316,6 +318,8 @@ class Circuit: name: object, targets: object = (), arg: object = None, + *, + tag: str = '', ) -> None: """Appends an operation into the circuit. @@ -366,6 +370,7 @@ class Circuit: compatibility reasons, `cirq.append_operation` (but not `cirq.append`) will default to a single 0.0 argument for gates that take exactly one argument. + tag: A customizable string attached to the instruction. """ def append_from_stim_program_text( self, @@ -398,6 +403,8 @@ class Circuit: name: object, targets: object = (), arg: object = None, + *, + tag: str = '', ) -> None: """[DEPRECATED] use stim.Circuit.append instead """ diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py index 248dcd53..8b7df855 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim.py +++ b/glue/cirq/stimcirq/_cirq_to_stim.py @@ -7,8 +7,18 @@ import stim +def _forward_single_str_tag(op: cirq.CircuitOperation) -> str: + tags = [tag for tag in op.tags if isinstance(tag, str)] + if len(tags) == 1: + return tags[0] + return "" + + def cirq_circuit_to_stim_circuit( - circuit: cirq.AbstractCircuit, *, qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None + circuit: cirq.AbstractCircuit, + *, + qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None, + tag_func: Callable[[cirq.Operation], str] = _forward_single_str_tag, ) -> stim.Circuit: """Converts a cirq circuit into an equivalent stim circuit. @@ -33,6 +43,10 @@ def cirq_circuit_to_stim_circuit( circuit: The circuit to convert. qubit_to_index_dict: Optional. Which integer each qubit should get mapped to. If not specified, defaults to indexing qubits in the circuit in sorted order. + tag_func: Controls the tag attached to the stim instructions the cirq operation turns + into. If not specified, defaults to checking for string tags on the circuit operation + and if there is exactly one string tag then using that tag (otherwise not specifying a + tag). Returns: The converted circuit. @@ -85,21 +99,29 @@ def _stim_conversion_( # The indices of qubits the gate is operating on. targets: List[int], + # A custom string associated with the operation, which can be tagged + # onto any operations appended to the stim circuit. + tag: str, + # Forward compatibility with future arguments. **kwargs): edit_circuit.append_operation("H", targets) """ - return cirq_circuit_to_stim_data(circuit, q2i=qubit_to_index_dict, flatten=False)[0] + return cirq_circuit_to_stim_data(circuit, q2i=qubit_to_index_dict, flatten=False, tag_func=tag_func)[0] def cirq_circuit_to_stim_data( - circuit: cirq.AbstractCircuit, *, q2i: Optional[Dict[cirq.Qid, int]] = None, flatten: bool = False, + circuit: cirq.AbstractCircuit, + *, + q2i: Optional[Dict[cirq.Qid, int]] = None, + flatten: bool = False, + tag_func: Callable[[cirq.Operation], str] = _forward_single_str_tag, ) -> Tuple[stim.Circuit, List[Tuple[str, int]]]: """Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go.""" if q2i is None: q2i = {q: i for i, q in enumerate(sorted(circuit.all_qubits()))} - helper = CirqToStimHelper() + helper = CirqToStimHelper(tag_func=tag_func) helper.q2i = q2i helper.flatten = flatten @@ -115,11 +137,11 @@ def cirq_circuit_to_stim_data( return helper.out, helper.key_out -StimTypeHandler = Callable[[stim.Circuit, cirq.Gate, List[int]], None] +StimTypeHandler = Callable[[stim.Circuit, cirq.Gate, List[int], str], None] @functools.lru_cache(maxsize=1) -def gate_to_stim_append_func() -> Dict[cirq.Gate, Callable[[stim.Circuit, List[int]], None]]: +def gate_to_stim_append_func() -> Dict[cirq.Gate, Callable[[stim.Circuit, List[int], str], None]]: """A dictionary mapping specific gate instances to stim circuit appending functions.""" x = (cirq.X, False) y = (cirq.Y, False) @@ -128,29 +150,29 @@ def gate_to_stim_append_func() -> Dict[cirq.Gate, Callable[[stim.Circuit, List[i ny = (cirq.Y, True) nz = (cirq.Z, True) - def do_nothing(c, t): + def do_nothing(_gates, _targets, tag): pass def use( *gates: str, individuals: Sequence[Tuple[str, int]] = () - ) -> Callable[[stim.Circuit, List[int]], None]: + ) -> Callable[[stim.Circuit, List[int], str], None]: if len(gates) == 1 and not individuals: (g,) = gates - return lambda c, t: c.append_operation(g, t) + return lambda c, t, tag: c.append(g, t, tag=tag) if not individuals: - def do(c, t): + def do(c, t, tag: str): for g in gates: - c.append_operation(g, t) + c.append(g, t, tag=tag) else: - def do(c, t): + def do(c, t, tag: str): for g in gates: - c.append_operation(g, t) + c.append(g, t, tag=tag) for g, k in individuals: - c.append_operation(g, [t[k]]) + c.append(g, [t[k]], tag=tag) return do @@ -238,14 +260,14 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: cirq.AsymmetricDepolarizingChannel: cast( StimTypeHandler, _stim_append_asymmetric_depolarizing_channel ), - cirq.BitFlipChannel: lambda c, g, t: c.append_operation( - "X_ERROR", t, cast(cirq.BitFlipChannel, g).p + cirq.BitFlipChannel: lambda c, g, t, tag: c.append( + "X_ERROR", t, cast(cirq.BitFlipChannel, g).p, tag=tag ), - cirq.PhaseFlipChannel: lambda c, g, t: c.append_operation( - "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p + cirq.PhaseFlipChannel: lambda c, g, t, tag: c.append( + "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p, tag=tag ), - cirq.PhaseDampingChannel: lambda c, g, t: c.append_operation( - "Z_ERROR", t, 0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2 + cirq.PhaseDampingChannel: lambda c, g, t, tag: c.append( + "Z_ERROR", t, 0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2, tag=tag ), cirq.RandomGateChannel: cast(StimTypeHandler, _stim_append_random_gate_channel), cirq.DepolarizingChannel: cast(StimTypeHandler, _stim_append_depolarizing_channel), @@ -253,16 +275,16 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: def _stim_append_measurement_gate( - circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str ): for i, b in enumerate(gate.invert_mask): if b: targets[i] = stim.target_inv(targets[i]) - circuit.append_operation("M", targets) + circuit.append("M", targets, tag=tag) def _stim_append_pauli_measurement_gate( - circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int], tag: str ): obs: cirq.DensePauliString = gate.observable() @@ -287,11 +309,11 @@ def _stim_append_pauli_measurement_gate( if obs.coefficient != 1 and obs.coefficient != -1: raise NotImplementedError(f"obs.coefficient={obs.coefficient!r} not in [1, -1]") - circuit.append_operation("MPP", new_targets) + circuit.append("MPP", new_targets, tag=tag) def _stim_append_spp_gate( - circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int], tag: str ): obs: cirq.DensePauliString = gate.dense_pauli_string a = gate.exponent_neg @@ -312,26 +334,26 @@ def _stim_append_spp_gate( return False new_targets.pop() - circuit.append_operation("SPP" if d == 0.5 else "SPP_DAG", new_targets) + circuit.append("SPP" if d == 0.5 else "SPP_DAG", new_targets, tag=tag) return True def _stim_append_dense_pauli_string_gate( - c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int] + c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int], tag: str ): gates = [None, "X", "Y", "Z"] for p, k in zip(g.pauli_mask, t): if p: - c.append_operation(gates[p], [k]) + c.append(gates[p], [k], tag=tag) def _stim_append_asymmetric_depolarizing_channel( - c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int] + c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int], tag: str ): if cirq.num_qubits(g) == 1: - c.append_operation("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z]) + c.append("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z], tag=tag) elif cirq.num_qubits(g) == 2: - c.append_operation( + c.append( "PAULI_CHANNEL_2", t, [ @@ -350,33 +372,34 @@ def _stim_append_asymmetric_depolarizing_channel( g.error_probabilities.get('ZX', 0), g.error_probabilities.get('ZY', 0), g.error_probabilities.get('ZZ', 0), - ] + ], + tag=tag, ) else: raise NotImplementedError(f'cirq-to-stim gate {g!r}') -def _stim_append_depolarizing_channel(c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int]): +def _stim_append_depolarizing_channel(c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int], tag: str): if g.num_qubits() == 1: - c.append_operation("DEPOLARIZE1", t, g.p) + c.append("DEPOLARIZE1", t, g.p, tag=tag) elif g.num_qubits() == 2: - c.append_operation("DEPOLARIZE2", t, g.p) + c.append("DEPOLARIZE2", t, g.p, tag=tag) else: raise TypeError(f"Don't know how to turn {g!r} into Stim operations.") -def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int]): +def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int], tag: str): if isinstance(g.sub_gate, cirq.BaseDensePauliString) and g.num_controls() == 1: gates = [None, "CX", "CY", "CZ"] for p, k in zip(g.sub_gate.pauli_mask, t[1:]): if p: - c.append_operation(gates[p], [t[0], k]) + c.append(gates[p], [t[0], k], tag=tag) if g.sub_gate.coefficient == 1j: - c.append_operation("S", t[:1]) + c.append("S", t[:1], tag=tag) elif g.sub_gate.coefficient == -1: - c.append_operation("Z", t[:1]) + c.append("Z", t[:1], tag=tag) elif g.sub_gate.coefficient == -1j: - c.append_operation("S_DAG", t[:1]) + c.append("S_DAG", t[:1], tag=tag) elif g.sub_gate.coefficient == 1: pass else: @@ -386,13 +409,13 @@ def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: Lis raise TypeError(f"Don't know how to turn controlled gate {g!r} into Stim operations.") -def _stim_append_random_gate_channel(c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int]): +def _stim_append_random_gate_channel(c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int], tag: str): if g.sub_gate in [cirq.X, cirq.Y, cirq.Z]: - c.append_operation(f"{g.sub_gate}_ERROR", t, g.probability) + c.append(f"{g.sub_gate}_ERROR", t, g.probability, tag=tag) elif isinstance(g.sub_gate, cirq.DensePauliString): target_p = [None, stim.target_x, stim.target_y, stim.target_z] pauli_targets = [target_p[p](t) for t, p in zip(t, g.sub_gate.pauli_mask) if p] - c.append_operation(f"CORRELATED_ERROR", pauli_targets, g.probability) + c.append(f"CORRELATED_ERROR", pauli_targets, g.probability, tag=tag) else: raise NotImplementedError( f"Don't know how to turn probabilistic {g!r} into Stim operations." @@ -400,33 +423,35 @@ def _stim_append_random_gate_channel(c: stim.Circuit, g: cirq.RandomGateChannel, class CirqToStimHelper: - def __init__(self): + def __init__(self, tag_func: Callable[[cirq.Operation], str]): self.key_out: List[Tuple[str, int]] = [] self.out = stim.Circuit() self.q2i = {} self.have_seen_loop = False self.flatten = False + self.tag_func = tag_func - def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation) -> None: + def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation, tag: str) -> None: if self.flatten or op.repetitions == 1: moments = cirq.unroll_circuit_op(cirq.Circuit(op), deep=False, tags_to_check=None).moments self.process_moments(moments) self.out = self.out[:-1] # Remove a trailing TICK (to avoid double TICK) return - child = CirqToStimHelper() + child = CirqToStimHelper(tag_func=self.tag_func) child.key_out = self.key_out child.q2i = self.q2i child.have_seen_loop = True self.have_seen_loop = True child.process_moments(op.transform_qubits(lambda q: op.qubit_map.get(q, q)).circuit) - self.out += child.out * op.repetitions + self.out.append(stim.CircuitRepeatBlock(op.repetitions, child.out, tag=tag)) def process_operations(self, operations: Iterable[cirq.Operation]) -> None: g2f = gate_to_stim_append_func() t2f = gate_type_to_stim_append_func() for op in operations: assert isinstance(op, cirq.Operation) + tag = self.tag_func(op) op = op.untagged gate = op.gate targets = [self.q2i[q] for q in op.qubits] @@ -441,36 +466,37 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None: edit_measurement_key_lengths=self.key_out, targets=targets, have_seen_loop=self.have_seen_loop, + tag=tag, ) continue if isinstance(op, cirq.CircuitOperation): - self.process_circuit_operation_into_repeat_block(op) + self.process_circuit_operation_into_repeat_block(op, tag=tag) continue # Special case measurement, because of its metadata. if isinstance(gate, cirq.PauliStringPhasorGate): - if _stim_append_spp_gate(self.out, gate, targets): + if _stim_append_spp_gate(self.out, gate, targets, tag=tag): continue if isinstance(gate, cirq.PauliMeasurementGate): self.key_out.append((gate.key, len(targets))) - _stim_append_pauli_measurement_gate(self.out, gate, targets) + _stim_append_pauli_measurement_gate(self.out, gate, targets, tag=tag) continue if isinstance(gate, cirq.MeasurementGate): self.key_out.append((gate.key, len(targets))) - _stim_append_measurement_gate(self.out, gate, targets) + _stim_append_measurement_gate(self.out, gate, targets, tag=tag) continue # Look for recognized gate values like cirq.H. val_append_func = g2f.get(gate) if val_append_func is not None: - val_append_func(self.out, targets) + val_append_func(self.out, targets, tag=tag) continue # Look for recognized gate types like cirq.DepolarizingChannel. type_append_func = t2f.get(type(gate)) if type_append_func is not None: - type_append_func(self.out, gate, targets) + type_append_func(self.out, gate, targets, tag=tag) continue # Ask unrecognized operations to decompose themselves into simpler operations. @@ -489,7 +515,7 @@ def process_moment(self, moment: cirq.Moment): # Append a TICK, unless it was already handled by an internal REPEAT block. if length_before == len(self.out) or not isinstance(self.out[-1], stim.CircuitRepeatBlock): - self.out.append_operation("TICK", []) + self.out.append("TICK", []) def process_moments(self, moments: Iterable[cirq.Moment]): for moment in moments: diff --git a/glue/cirq/stimcirq/_cirq_to_stim_test.py b/glue/cirq/stimcirq/_cirq_to_stim_test.py index 65f3f1d2..977d12d9 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim_test.py +++ b/glue/cirq/stimcirq/_cirq_to_stim_test.py @@ -74,21 +74,21 @@ def assert_unitary_gate_converts_correctly(gate: cirq.Gate): # If the gate is translated correctly, the measurement will always be zero. c = stim.Circuit() - c.append_operation("H", range(n)) + c.append("H", range(n)) for i in range(n): - c.append_operation("CNOT", [i, i + n]) - c.append_operation("H", [2 * n]) + c.append("CNOT", [i, i + n]) + c.append("H", [2 * n]) for q, p in pre.items(): - c.append_operation(f"C{p}", [2 * n, q.x]) + c.append(f"C{p}", [2 * n, q.x]) qs = cirq.LineQubit.range(n) conv_gate, _ = cirq_circuit_to_stim_data(cirq.Circuit(gate(*qs)), q2i={q: q.x for q in qs}) c += conv_gate for q, p in post.items(): - c.append_operation(f"C{p}", [2 * n, q.x]) + c.append(f"C{p}", [2 * n, q.x]) if post.coefficient == -1: - c.append_operation("Z", [2 * n]) - c.append_operation("H", [2 * n]) - c.append_operation("M", [2 * n]) + c.append("Z", [2 * n]) + c.append("H", [2 * n]) + c.append("M", [2 * n]) correct = np.count_nonzero(c.compile_sampler().sample_bit_packed(10)) == 0 assert correct, f"{gate!r} failed to turn {pre} into {post}.\nConverted to:\n{conv_gate}\n" @@ -254,13 +254,13 @@ def _stim_conversion_( **kwargs, ): edit_measurement_key_lengths.append(("custom", 2)) - edit_circuit.append_operation("M", [stim.target_inv(targets[0])]) - edit_circuit.append_operation("M", [targets[0]]) - edit_circuit.append_operation("DETECTOR", [stim.target_rec(-1)]) + edit_circuit.append("M", [stim.target_inv(targets[0])]) + edit_circuit.append("M", [targets[0]]) + edit_circuit.append("DETECTOR", [stim.target_rec(-1)]) class SecondLastMeasurementWasDeterministicOperation(cirq.Operation): - def _stim_conversion_(self, edit_circuit: stim.Circuit, **kwargs): - edit_circuit.append_operation("DETECTOR", [stim.target_rec(-2)]) + def _stim_conversion_(self, edit_circuit: stim.Circuit, tag: str, **kwargs): + edit_circuit.append("DETECTOR", [stim.target_rec(-2)], tag=tag) def with_qubits(self, *new_qubits): raise NotImplementedError() @@ -392,3 +392,21 @@ def test_random_gate_channel(): E(0.25) X1 TICK """) + + +def test_custom_tagging(): + assert stimcirq.cirq_circuit_to_stim_circuit( + cirq.Circuit( + cirq.X(cirq.LineQubit(0)).with_tags('test'), + cirq.X(cirq.LineQubit(0)).with_tags((2, 3, 4)), + cirq.H(cirq.LineQubit(0)).with_tags('a', 'b'), + ), + tag_func=lambda op: "PAIR" if len(op.tags) == 2 else repr(op.tags), + ) == stim.Circuit(""" + X[('test',)] 0 + TICK + X[((2, 3, 4),)] 0 + TICK + H[PAIR] 0 + TICK + """) diff --git a/glue/cirq/stimcirq/_cx_swap_gate.py b/glue/cirq/stimcirq/_cx_swap_gate.py index 1b523b6f..947c4313 100644 --- a/glue/cirq/stimcirq/_cx_swap_gate.py +++ b/glue/cirq/stimcirq/_cx_swap_gate.py @@ -35,8 +35,8 @@ def _decompose_(self, qubits): yield cirq.CNOT(a, b) yield cirq.SWAP(a, b) - def _stim_conversion_(self, edit_circuit: stim.Circuit, targets: List[int], **kwargs): - edit_circuit.append_operation('SWAPCX' if self.inverted else 'CXSWAP', targets) + def _stim_conversion_(self, edit_circuit: stim.Circuit, targets: List[int], tag: str, **kwargs): + edit_circuit.append('SWAPCX' if self.inverted else 'CXSWAP', targets, tag=tag) def __pow__(self, power: int) -> 'CXSwapGate': if power == +1: diff --git a/glue/cirq/stimcirq/_cz_swap_gate.py b/glue/cirq/stimcirq/_cz_swap_gate.py index 6f9ded68..b8c8eca4 100644 --- a/glue/cirq/stimcirq/_cz_swap_gate.py +++ b/glue/cirq/stimcirq/_cz_swap_gate.py @@ -22,8 +22,8 @@ def _decompose_(self, qubits): yield cirq.SWAP(a, b) yield cirq.CZ(a, b) - def _stim_conversion_(self, edit_circuit: stim.Circuit, targets: List[int], **kwargs): - edit_circuit.append_operation('CZSWAP', targets) + def _stim_conversion_(self, edit_circuit: stim.Circuit, targets: List[int], tag: str, **kwargs): + edit_circuit.append('CZSWAP', targets, tag=tag) def __pow__(self, power: int) -> 'CZSwapGate': if power == +1: diff --git a/glue/cirq/stimcirq/_det_annotation.py b/glue/cirq/stimcirq/_det_annotation.py index b03dec27..5de0c760 100644 --- a/glue/cirq/stimcirq/_det_annotation.py +++ b/glue/cirq/stimcirq/_det_annotation.py @@ -77,8 +77,10 @@ def _is_comment_(self) -> bool: def _stim_conversion_( self, + *, edit_circuit: stim.Circuit, edit_measurement_key_lengths: List[Tuple[str, int]], + tag: str, have_seen_loop: bool = False, **kwargs, ): @@ -111,4 +113,4 @@ def _stim_conversion_( f" in an earlier moment (or earlier in the same moment's operation order)." ) - edit_circuit.append_operation("DETECTOR", rec_targets, self.coordinate_metadata) + edit_circuit.append("DETECTOR", rec_targets, self.coordinate_metadata, tag=tag) diff --git a/glue/cirq/stimcirq/_measure_and_or_reset_gate.py b/glue/cirq/stimcirq/_measure_and_or_reset_gate.py index 8960168a..4e4aaf59 100644 --- a/glue/cirq/stimcirq/_measure_and_or_reset_gate.py +++ b/glue/cirq/stimcirq/_measure_and_or_reset_gate.py @@ -91,15 +91,15 @@ def _stim_op_name(self) -> str: result += self.basis return result - def _stim_conversion_(self, edit_circuit: stim.Circuit, targets: List[int], **kwargs): + def _stim_conversion_(self, *, edit_circuit: stim.Circuit, targets: List[int], tag: str, **kwargs): if self.invert_measure: targets[0] = stim.target_inv(targets[0]) if self.measure_flip_probability: edit_circuit.append_operation( - self._stim_op_name(), targets, self.measure_flip_probability + self._stim_op_name(), targets, self.measure_flip_probability, tag=tag ) else: - edit_circuit.append_operation(self._stim_op_name(), targets) + edit_circuit.append_operation(self._stim_op_name(), targets, tag=tag) def __str__(self) -> str: result = self._stim_op_name() diff --git a/glue/cirq/stimcirq/_obs_annotation.py b/glue/cirq/stimcirq/_obs_annotation.py index fe068455..c0fa246d 100644 --- a/glue/cirq/stimcirq/_obs_annotation.py +++ b/glue/cirq/stimcirq/_obs_annotation.py @@ -75,9 +75,11 @@ def _is_comment_(self) -> bool: def _stim_conversion_( self, + *, edit_circuit: stim.Circuit, edit_measurement_key_lengths: List[Tuple[str, int]], have_seen_loop: bool = False, + tag: str, **kwargs, ): # Ideally these references would all be resolved ahead of time, to avoid the redundant @@ -109,4 +111,4 @@ def _stim_conversion_( f" in an earlier moment (or earlier in the same moment's operation order)." ) - edit_circuit.append_operation("OBSERVABLE_INCLUDE", rec_targets, self.observable_index) + edit_circuit.append("OBSERVABLE_INCLUDE", rec_targets, self.observable_index, tag=tag) diff --git a/glue/cirq/stimcirq/_shift_coords_annotation.py b/glue/cirq/stimcirq/_shift_coords_annotation.py index 4fc886ae..06ba4b68 100644 --- a/glue/cirq/stimcirq/_shift_coords_annotation.py +++ b/glue/cirq/stimcirq/_shift_coords_annotation.py @@ -49,5 +49,5 @@ def _decompose_(self): def _is_comment_(self) -> bool: return True - def _stim_conversion_(self, edit_circuit: stim.Circuit, **kwargs): - edit_circuit.append_operation("SHIFT_COORDS", [], self.shift) + def _stim_conversion_(self, *, edit_circuit: stim.Circuit, tag: str, **kwargs): + edit_circuit.append_operation("SHIFT_COORDS", [], self.shift, tag=tag) diff --git a/glue/cirq/stimcirq/_stim_to_cirq.py b/glue/cirq/stimcirq/_stim_to_cirq.py index d97d1aa5..1be3ba7c 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq.py +++ b/glue/cirq/stimcirq/_stim_to_cirq.py @@ -84,8 +84,12 @@ def process_gate_instruction( m = cirq.num_qubits(gate) if not all(t.is_qubit_target for t in targets) or len(targets) % m != 0: raise NotImplementedError(f"instruction={instruction!r}") + if instruction.tag: + tags = [instruction.tag] + else: + tags = () for k in range(0, len(targets), m): - self.append_operation(gate(*[cirq.LineQubit(t.value) for t in targets[k : k + m]])) + self.append_operation(gate(*[cirq.LineQubit(t.value) for t in targets[k : k + m]]).with_tags(*tags)) def process_tick(self, instruction: stim.CircuitInstruction) -> None: self.full_circuit += self.tick_circuit or cirq.Moment() @@ -128,7 +132,7 @@ def process_pauli_channel_2(self, instruction: stim.CircuitInstruction) -> None: self.process_gate_instruction(gate, instruction) def process_repeat_block(self, block: stim.CircuitRepeatBlock): - if self.flatten or block.repeat_count == 1: + if self.flatten or (block.repeat_count == 1 and block.tag == ""): self.process_circuit(block.repeat_count, block.body_copy()) return @@ -141,6 +145,10 @@ def process_repeat_block(self, block: stim.CircuitRepeatBlock): child.process_circuit(1, block.body_copy()) # Circuit operation will always be in their own cirq.Moment + if block.tag == "": + tags = () + else: + tags = (block.tag,) if len(self.tick_circuit): self.full_circuit += self.tick_circuit self.full_circuit += cirq.Moment( @@ -148,7 +156,7 @@ def process_repeat_block(self, block: stim.CircuitRepeatBlock): cirq.FrozenCircuit(child.full_circuit + child.tick_circuit), repetitions=block.repeat_count, use_repetition_ids=False, - ) + ).with_tags(*tags) ) self.tick_circuit = cirq.Circuit() @@ -168,6 +176,10 @@ def process_measurement_instruction( flip_probability = args[0] targets: List[stim.GateTarget] = instruction.targets_copy() + if instruction.tag: + tags = [instruction.tag] + else: + tags = () for t in targets: if not t.is_qubit_target: raise NotImplementedError(f"instruction={instruction!r}") @@ -180,7 +192,7 @@ def process_measurement_instruction( invert_measure=t.is_inverted_result_target, key=key, measure_flip_probability=flip_probability, - ).resolve(cirq.LineQubit(t.value)) + ).resolve(cirq.LineQubit(t.value)).with_tags(*tags) ) def process_circuit(self, repetitions: int, circuit: stim.Circuit) -> None: @@ -219,6 +231,10 @@ def process_mpp(self, instruction: stim.CircuitInstruction) -> None: raise NotImplementedError("Noisy MPP") targets: List[stim.GateTarget] = instruction.targets_copy() + if instruction.tag: + tags = [instruction.tag] + else: + tags = () start = 0 while start < len(targets): next_start = start + 1 @@ -230,13 +246,17 @@ def process_mpp(self, instruction: stim.CircuitInstruction) -> None: obs = _stim_targets_to_dense_pauli_string(group) qubits = [cirq.LineQubit(t.value) for t in group] key = str(self.get_next_measure_id()) - self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits)) + self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags)) def process_spp_dag(self, instruction: stim.CircuitInstruction) -> None: self.process_spp(instruction, dag=True) def process_spp(self, instruction: stim.CircuitInstruction, dag: bool = False) -> None: targets: List[stim.GateTarget] = instruction.targets_copy() + if instruction.tag: + tags = [instruction.tag] + else: + tags = () start = 0 while start < len(targets): next_start = start + 1 @@ -250,13 +270,17 @@ def process_spp(self, instruction: stim.CircuitInstruction, dag: bool = False) - self.append_operation(cirq.PauliStringPhasorGate( obs, exponent_neg=-0.5 if dag else 0.5, - ).on(*qubits)) + ).on(*qubits).with_tags(*tags)) def process_m_pair(self, instruction: stim.CircuitInstruction, basis: str) -> None: args = instruction.gate_args_copy() if args and args[0]: raise NotImplementedError("Noisy M" + basis*2) + if instruction.tag: + tags = [instruction.tag] + else: + tags = () targets: List[stim.GateTarget] = instruction.targets_copy() for k in range(0, len(targets), 2): obs = cirq.DensePauliString(basis * 2) @@ -264,7 +288,7 @@ def process_m_pair(self, instruction: stim.CircuitInstruction, basis: str) -> No obs *= -1 qubits = [cirq.LineQubit(targets[0].value), cirq.LineQubit(targets[1].value)] key = str(self.get_next_measure_id()) - self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits)) + self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits).with_tags(*tags)) def process_mxx(self, instruction: stim.CircuitInstruction) -> None: self.process_m_pair(instruction, "X") @@ -286,12 +310,16 @@ def process_mpad(self, instruction: stim.CircuitInstruction) -> None: self.append_operation(cirq.PauliMeasurementGate(obs, key=key).on(*qubits)) def process_correlated_error(self, instruction: stim.CircuitInstruction) -> None: + if instruction.tag: + tags = [instruction.tag] + else: + tags = () args = instruction.gate_args_copy() probability = args[0] if args else 0 targets = instruction.targets_copy() qubits = [cirq.LineQubit(t.value) for t in targets] self.append_operation( - _stim_targets_to_dense_pauli_string(targets).on(*qubits).with_probability(probability) + _stim_targets_to_dense_pauli_string(targets).on(*qubits).with_probability(probability).with_tags(*tags) ) def coords_after_offset( @@ -316,20 +344,28 @@ def resolve_measurement_record_keys( return [str(self.num_measurements_seen + t.value) for t in targets], [] def process_detector(self, instruction: stim.CircuitInstruction) -> None: + if instruction.tag: + tags = [instruction.tag] + else: + tags = () coords = self.coords_after_offset(instruction.gate_args_copy()) keys, rels = self.resolve_measurement_record_keys(instruction.targets_copy()) self.append_operation( - DetAnnotation(parity_keys=keys, relative_keys=rels, coordinate_metadata=coords) + DetAnnotation(parity_keys=keys, relative_keys=rels, coordinate_metadata=coords).with_tags(*tags) ) def process_observable_include(self, instruction: stim.CircuitInstruction) -> None: + if instruction.tag: + tags = [instruction.tag] + else: + tags = () args = instruction.gate_args_copy() index = 0 if not args else int(args[0]) keys, rels = self.resolve_measurement_record_keys(instruction.targets_copy()) self.append_operation( CumulativeObservableAnnotation( parity_keys=keys, relative_keys=rels, observable_index=index - ) + ).with_tags(*tags) ) def process_qubit_coords(self, instruction: stim.CircuitInstruction) -> None: @@ -341,9 +377,13 @@ def process_qubit_coords(self, instruction: stim.CircuitInstruction) -> None: self.qubit_coords[t.value] = cirq.GridQubit(*coords) def process_shift_coords(self, instruction: stim.CircuitInstruction) -> None: + if instruction.tag: + tags = [instruction.tag] + else: + tags = () args = instruction.gate_args_copy() if not self.flatten: - self.append_operation(ShiftCoordsAnnotation(args)) + self.append_operation(ShiftCoordsAnnotation(args).with_tags(*tags)) for k, a in enumerate(args): self.origin[k] += a @@ -364,6 +404,10 @@ def __init__(self, pauli_gate: cirq.Pauli, gate: cirq.Gate): def __call__( self, tracker: 'CircuitTranslationTracker', instruction: stim.CircuitInstruction ) -> None: + if instruction.tag: + tags = [instruction.tag] + else: + tags = () targets: List[stim.GateTarget] = instruction.targets_copy() for k in range(0, len(targets), 2): a = targets[k] @@ -379,13 +423,13 @@ def __call__( stim_sweep_bit_index=a.value, cirq_sweep_symbol=f'sweep[{a.value}]', pauli=self.pauli_gate, - ).on(cirq.LineQubit(b.value)) + ).on(cirq.LineQubit(b.value)).with_tags(*tags) ) else: if not a.is_qubit_target or not b.is_qubit_target: raise NotImplementedError(f"instruction={instruction!r}") tracker.append_operation( - self.gate(cirq.LineQubit(a.value), cirq.LineQubit(b.value)) + self.gate(cirq.LineQubit(a.value), cirq.LineQubit(b.value)).with_tags(*tags) ) class OneToOneMeasurementHandler: @@ -422,7 +466,7 @@ def get_handler_table() -> Dict[ noise = CircuitTranslationTracker.OneToOneNoisyGateHandler sweep_gate = CircuitTranslationTracker.SweepableGateHandler - def not_impl(message) -> Callable[[Any], None]: + def not_impl(message) -> Callable[[Any, Any], None]: def handler( tracker: CircuitTranslationTracker, instruction: stim.CircuitInstruction ) -> None: diff --git a/glue/cirq/stimcirq/_stim_to_cirq_test.py b/glue/cirq/stimcirq/_stim_to_cirq_test.py index 42fb3230..30bfdcf8 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq_test.py +++ b/glue/cirq/stimcirq/_stim_to_cirq_test.py @@ -682,3 +682,72 @@ def test_stim_circuit_to_cirq_circuit_spp(): SPP Z0 TICK """) + + +def test_tags_convert(): + assert stimcirq.stim_circuit_to_cirq_circuit(stim.Circuit(""" + H[my_tag] 0 + """)) == cirq.Circuit( + cirq.H(cirq.LineQubit(0)).with_tags('my_tag'), + ) + + +@pytest.mark.parametrize('gate', sorted(stim.gate_data().keys())) +def test_every_operation_converts_tags(gate: str): + if gate in [ + "ELSE_CORRELATED_ERROR", + "HERALDED_ERASE", + "HERALDED_PAULI_CHANNEL_1", + "TICK", + "REPEAT", + "MPAD", + "QUBIT_COORDS", + ]: + pytest.skip() + + data = stim.gate_data(gate) + stim_circuit = stim.Circuit() + arg = None + targets = [0, 1] + if data.num_parens_arguments_range.start: + arg = [2**-6] * data.num_parens_arguments_range.start + if data.takes_pauli_targets: + targets = [stim.target_x(0), stim.target_y(1)] + if data.takes_measurement_record_targets and not data.is_unitary: + stim_circuit.append("M", [0], tag='custom_tag') + targets = [stim.target_rec(-1)] + if gate == 'SHIFT_COORDS': + targets = [] + if gate == 'OBSERVABLE_INCLUDE': + arg = [1] + stim_circuit.append(gate, targets, arg, tag='custom_tag') + cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) + assert any(cirq_circuit.all_operations()) + for op in cirq_circuit.all_operations(): + assert op.tags == ('custom_tag',) + restored_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) + assert restored_circuit.pop() == stim.CircuitInstruction("TICK") + assert all(instruction.tag == 'custom_tag' for instruction in restored_circuit) + if gate not in ['MXX', 'MYY', 'MZZ']: + assert restored_circuit == stim_circuit + + +def test_loop_tagging(): + stim_circuit = stim.Circuit(""" + REPEAT[custom-tag] 5 { + H[tag2] 0 + TICK + } + """) + cirq_circuit = stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) + assert cirq_circuit == cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.H(cirq.LineQubit(0)).with_tags('tag2'), + ), + repetitions=5, + use_repetition_ids=False, + ).with_tags('custom-tag') + ) + restored_circuit = stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) + assert restored_circuit == stim_circuit diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index b9e1c185..059bc23e 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -303,6 +303,8 @@ class Circuit: name: str, targets: Union[int, stim.GateTarget, Iterable[Union[int, stim.GateTarget]]], arg: Union[float, Iterable[float]], + *, + tag: str = "", ) -> None: pass @overload @@ -316,6 +318,8 @@ class Circuit: name: object, targets: object = (), arg: object = None, + *, + tag: str = '', ) -> None: """Appends an operation into the circuit. @@ -366,6 +370,7 @@ class Circuit: compatibility reasons, `cirq.append_operation` (but not `cirq.append`) will default to a single 0.0 argument for gates that take exactly one argument. + tag: A customizable string attached to the instruction. """ def append_from_stim_program_text( self, @@ -398,6 +403,8 @@ class Circuit: name: object, targets: object = (), arg: object = None, + *, + tag: str = '', ) -> None: """[DEPRECATED] use stim.Circuit.append instead """ diff --git a/src/stim/circuit/circuit.cc b/src/stim/circuit/circuit.cc index c7f8e9b7..b158069e 100644 --- a/src/stim/circuit/circuit.cc +++ b/src/stim/circuit/circuit.cc @@ -331,7 +331,7 @@ void Circuit::safe_append(CircuitInstruction operation, bool block_fusion) { } } -void Circuit::safe_append_ua(std::string_view gate_name, const std::vector &targets, double singleton_arg) { +void Circuit::safe_append_ua(std::string_view gate_name, const std::vector &targets, double singleton_arg, std::string_view tag) { const auto &gate = GATE_DATA.at(gate_name); std::vector converted; @@ -340,11 +340,11 @@ void Circuit::safe_append_ua(std::string_view gate_name, const std::vector &targets, const std::vector &args) { + std::string_view gate_name, const std::vector &targets, const std::vector &args, std::string_view tag) { const auto &gate = GATE_DATA.at(gate_name); std::vector converted; @@ -353,7 +353,7 @@ void Circuit::safe_append_u( converted.push_back({e}); } - safe_append(CircuitInstruction(gate.id, args, converted, "")); + safe_append(CircuitInstruction(gate.id, args, converted, tag)); } void Circuit::safe_insert(size_t index, const CircuitInstruction &instruction) { diff --git a/src/stim/circuit/circuit.h b/src/stim/circuit/circuit.h index 51f5596f..69eae879 100644 --- a/src/stim/circuit/circuit.h +++ b/src/stim/circuit/circuit.h @@ -108,10 +108,10 @@ struct Circuit { /// Safely adds an operation at the end of the circuit, copying its data into the circuit's jagged data as needed. void safe_append(CircuitInstruction operation, bool block_fusion = false); /// Safely adds an operation at the end of the circuit, copying its data into the circuit's jagged data as needed. - void safe_append_ua(std::string_view gate_name, const std::vector &targets, double singleton_arg); + void safe_append_ua(std::string_view gate_name, const std::vector &targets, double singleton_arg, std::string_view tag = ""); /// Safely adds an operation at the end of the circuit, copying its data into the circuit's jagged data as needed. void safe_append_u( - std::string_view gate_name, const std::vector &targets, const std::vector &args = {}); + std::string_view gate_name, const std::vector &targets, const std::vector &args = {}, std::string_view tag = ""); /// Safely copies a repeat block to the end of the circuit. void append_repeat_block(uint64_t repeat_count, const Circuit &body, std::string_view tag); /// Safely moves a repeat block to the end of the circuit. diff --git a/src/stim/circuit/circuit.pybind.cc b/src/stim/circuit/circuit.pybind.cc index 6d6acd97..f6f133c2 100644 --- a/src/stim/circuit/circuit.pybind.cc +++ b/src/stim/circuit/circuit.pybind.cc @@ -249,6 +249,7 @@ void circuit_append( const pybind11::object &obj, const pybind11::object &targets, const pybind11::object &arg, + std::string_view tag, bool backwards_compat) { // Extract single target or list of targets. std::vector raw_targets; @@ -276,20 +277,20 @@ void circuit_append( // Extract single argument or list of arguments. try { auto d = pybind11::cast(used_arg); - self.safe_append_ua(gate_name, raw_targets, d); + self.safe_append_ua(gate_name, raw_targets, d, tag); return; } catch (const pybind11::cast_error &ex) { } try { auto args = pybind11::cast>(used_arg); - self.safe_append_u(gate_name, raw_targets, args); + self.safe_append_u(gate_name, raw_targets, args, tag); return; } catch (const pybind11::cast_error &ex) { } throw std::invalid_argument("Arg must be a double or sequence of doubles."); } else if (pybind11::isinstance(obj)) { - if (!raw_targets.empty() || !arg.is_none()) { - throw std::invalid_argument("Can't specify `targets` or `arg` when appending a stim.CircuitInstruction."); + if (!raw_targets.empty() || !arg.is_none() || !tag.empty()) { + throw std::invalid_argument("Can't specify `targets` or `arg` or `tag` when appending a stim.CircuitInstruction."); } const PyCircuitInstruction &instruction = pybind11::cast(obj); @@ -301,8 +302,8 @@ void circuit_append( pybind11::cast(instruction.tag), }); } else if (pybind11::isinstance(obj)) { - if (!raw_targets.empty() || !arg.is_none()) { - throw std::invalid_argument("Can't specify `targets` or `arg` when appending a stim.CircuitRepeatBlock."); + if (!raw_targets.empty() || !arg.is_none() || !tag.empty()) { + throw std::invalid_argument("Can't specify `targets` or `arg` or `tag` when appending a stim.CircuitRepeatBlock."); } const CircuitRepeatBlock &block = pybind11::cast(obj); @@ -315,12 +316,12 @@ void circuit_append( } } void circuit_append_backwards_compat( - Circuit &self, const pybind11::object &obj, const pybind11::object &targets, const pybind11::object &arg) { - circuit_append(self, obj, targets, arg, true); + Circuit &self, const pybind11::object &obj, const pybind11::object &targets, const pybind11::object &arg, std::string_view tag) { + circuit_append(self, obj, targets, arg, tag, true); } void circuit_append_strict( - Circuit &self, const pybind11::object &obj, const pybind11::object &targets, const pybind11::object &arg) { - circuit_append(self, obj, targets, arg, false); + Circuit &self, const pybind11::object &obj, const pybind11::object &targets, const pybind11::object &arg, std::string_view tag) { + circuit_append(self, obj, targets, arg, tag, false); } pybind11::class_ stim_pybind::pybind_circuit(pybind11::module &m) { @@ -1107,10 +1108,12 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_ None: + @overload def append(self, name: str, targets: Union[int, stim.GateTarget, Iterable[Union[int, stim.GateTarget]]], arg: Union[float, Iterable[float]], *, tag: str = "") -> None: @overload def append(self, name: Union[stim.CircuitOperation, stim.CircuitRepeatBlock]) -> None: Note: `stim.Circuit.append_operation` is an alias of `stim.Circuit.append`. @@ -1160,6 +1163,7 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_