Skip to content

Commit

Permalink
stimcirq
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Jan 24, 2024
1 parent 8a5d855 commit dd5c43c
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 0 deletions.
2 changes: 2 additions & 0 deletions glue/cirq/stimcirq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__version__ = '1.13.dev0'
from ._cirq_to_stim import cirq_circuit_to_stim_circuit
from ._cx_swap_gate import CXSwapGate
from ._cz_swap_gate import CZSwapGate
from ._det_annotation import DetAnnotation
from ._obs_annotation import CumulativeObservableAnnotation
from ._shift_coords_annotation import ShiftCoordsAnnotation
Expand All @@ -20,5 +21,6 @@
"SweepPauli": SweepPauli,
"TwoQubitAsymmetricDepolarizingChannel": TwoQubitAsymmetricDepolarizingChannel,
"CXSwapGate": CXSwapGate,
"CZSwapGate": CZSwapGate,
}
JSON_RESOLVER = JSON_RESOLVERS_DICT.get
46 changes: 46 additions & 0 deletions glue/cirq/stimcirq/_cz_swap_gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Any, Dict, List

import cirq
import stim


@cirq.value_equality
class CZSwapGate(cirq.Gate):
"""Handles explaining stim's CZSWAP gates to cirq."""

def _num_qubits_(self) -> int:
return 2

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> List[str]:
return ['ZSWAP', 'ZSWAP']

def _value_equality_values_(self):
return ()

def _decompose_(self, qubits):
a, b = qubits
yield cirq.SWAP(a, b)
yield cirq.CZ(a, b)

def _stim_conversion_(self, edit_circuit: stim.Circuit, targets: List[int], **kwargs):
edit_circuit.append_operation('CZSWAP', targets)

def __pow__(self, power: int) -> 'CZSwapGate':
if power == +1:
return self
if power == -1:
return self
return NotImplemented

def __str__(self) -> str:
return 'CZSWAP'

def __repr__(self):
return f'stimcirq.CZSwapGate()'

@staticmethod
def _json_namespace_() -> str:
return ''

def _json_dict_(self) -> Dict[str, Any]:
return {}
72 changes: 72 additions & 0 deletions glue/cirq/stimcirq/_cz_swap_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import cirq
import stim
import stimcirq


def test_stim_conversion():
a, b, c = cirq.LineQubit.range(3)

cirq_circuit = cirq.Circuit(
stimcirq.CZSwapGate().on(a, b),
stimcirq.CZSwapGate().on(b, c),
)
stim_circuit = stim.Circuit(
"""
CZSWAP 0 1
TICK
CZSWAP 1 2
TICK
"""
)
assert stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) == stim_circuit
assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq_circuit


def test_diagram():
a, b = cirq.LineQubit.range(2)
cirq.testing.assert_has_diagram(
cirq.Circuit(
stimcirq.CZSwapGate()(a, b),
stimcirq.CZSwapGate()(a, b),
),
"""
0: ---ZSWAP---ZSWAP---
| |
1: ---ZSWAP---ZSWAP---
""",
use_unicode_characters=False,
)


def test_inverse():
a = stimcirq.CZSwapGate()
assert a**+1 == a
assert a**-1 == a


def test_repr():
val = stimcirq.CZSwapGate()
assert eval(repr(val), {"stimcirq": stimcirq}) == val


def test_equality():
eq = cirq.testing.EqualsTester()
eq.add_equality_group(stimcirq.CZSwapGate(), stimcirq.CZSwapGate())


def test_json_serialization():
a, b, d = cirq.LineQubit.range(3)
c = cirq.Circuit(
stimcirq.CZSwapGate()(a, b),
stimcirq.CZSwapGate()(b, d),
)
json = cirq.to_json(c)
c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER])
assert c == c2


def test_json_backwards_compat_exact():
raw = stimcirq.CZSwapGate()
packed = '{\n "cirq_type": "CZSwapGate"\n}'
assert cirq.to_json(raw) == packed
assert cirq.read_json(json_text=packed, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw
2 changes: 2 additions & 0 deletions glue/cirq/stimcirq/_stim_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import stim

from ._cx_swap_gate import CXSwapGate
from ._cz_swap_gate import CZSwapGate
from ._det_annotation import DetAnnotation
from ._measure_and_or_reset_gate import MeasureAndOrResetGate
from ._obs_annotation import CumulativeObservableAnnotation
Expand Down Expand Up @@ -424,6 +425,7 @@ def handler(
measure=False, reset=True, basis='X', invert_measure=False, key=''
)
),
"CZSWAP": gate(CZSwapGate()),
"CXSWAP": gate(CXSwapGate(inverted=False)),
"SWAPCX": gate(CXSwapGate(inverted=True)),
"RY": gate(
Expand Down

0 comments on commit dd5c43c

Please sign in to comment.