From 01be362a2356c0c605cb9aa946d7617adbb52996 Mon Sep 17 00:00:00 2001 From: Matthew Harrigan Date: Wed, 16 Oct 2024 16:53:42 -0700 Subject: [PATCH] Hone Bloq<->Cirq adapters CirqGateAsBloq only has bloq API. BloqAsCirqGate only has cirq API --- qualtran/_infra/gate_with_registers.py | 2 +- qualtran/bloqs/basic_gates/global_phase.py | 4 +- qualtran/bloqs/basic_gates/rotation.py | 16 ++--- qualtran/cirq_interop/__init__.py | 2 +- qualtran/cirq_interop/_bloq_to_cirq.py | 33 +++------ qualtran/cirq_interop/_bloq_to_cirq_test.py | 8 +-- qualtran/cirq_interop/_cirq_to_bloq.py | 77 +++++++++++++-------- qualtran/cirq_interop/_cirq_to_bloq_test.py | 31 ++------- 8 files changed, 81 insertions(+), 92 deletions(-) diff --git a/qualtran/_infra/gate_with_registers.py b/qualtran/_infra/gate_with_registers.py index e59bd4a20..99026c66a 100644 --- a/qualtran/_infra/gate_with_registers.py +++ b/qualtran/_infra/gate_with_registers.py @@ -319,7 +319,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - return _wire_symbol_from_gate(self, self.signature, reg, idx) - # Part-2: Cirq-FT style interface can be used to implemented algorithms by Bloq authors. + # Part-2: Cirq-FT style interface can be used to implement algorithms by Bloq authors. def _num_qubits_(self) -> int: return total_bits(self.signature) diff --git a/qualtran/bloqs/basic_gates/global_phase.py b/qualtran/bloqs/basic_gates/global_phase.py index cae67a2fa..3782f38ea 100644 --- a/qualtran/bloqs/basic_gates/global_phase.py +++ b/qualtran/bloqs/basic_gates/global_phase.py @@ -31,7 +31,7 @@ SoquetT, ) from qualtran.bloqs.basic_gates.rotation import ZPowGate -from qualtran.cirq_interop import CirqGateAsBloqBase +from qualtran.cirq_interop import CirqGateAsBloqMixin from qualtran.symbolics import pi, sarg, sexp, SymbolicComplex, SymbolicFloat if TYPE_CHECKING: @@ -39,7 +39,7 @@ @frozen -class GlobalPhase(CirqGateAsBloqBase): +class GlobalPhase(CirqGateAsBloqMixin): r"""Applies a global phase to the circuit as a whole. The unitary effect is to multiply the state vector by the complex scalar diff --git a/qualtran/bloqs/basic_gates/rotation.py b/qualtran/bloqs/basic_gates/rotation.py index 9adb439e1..372a636e0 100644 --- a/qualtran/bloqs/basic_gates/rotation.py +++ b/qualtran/bloqs/basic_gates/rotation.py @@ -21,13 +21,13 @@ from attrs import frozen from qualtran import bloq_example, BloqDocSpec, CompositeBloq, DecomposeTypeError, Register -from qualtran.cirq_interop import CirqGateAsBloqBase +from qualtran.cirq_interop import CirqGateAsBloqMixin from qualtran.drawing import Text, TextBox, WireSymbol from qualtran.symbolics import SymbolicFloat @frozen -class ZPowGate(CirqGateAsBloqBase): +class ZPowGate(CirqGateAsBloqMixin): r"""A gate that rotates around the Z axis of the Bloch sphere. The unitary matrix of `ZPowGate(exponent=t, global_shift=s)` is: @@ -107,7 +107,7 @@ def _z_pow() -> ZPowGate: @frozen -class CZPowGate(CirqGateAsBloqBase): +class CZPowGate(CirqGateAsBloqMixin): exponent: float = 1.0 global_shift: float = 0.0 eps: SymbolicFloat = 1e-11 @@ -131,7 +131,7 @@ def __str__(self): @frozen -class XPowGate(CirqGateAsBloqBase): +class XPowGate(CirqGateAsBloqMixin): r"""A gate that rotates around the X axis of the Bloch sphere. The unitary matrix of `XPowGate(exponent=t, global_shift=s)` is: @@ -205,7 +205,7 @@ def _x_pow() -> XPowGate: @frozen -class YPowGate(CirqGateAsBloqBase): +class YPowGate(CirqGateAsBloqMixin): r"""A gate that rotates around the Y axis of the Bloch sphere. The unitary matrix of `YPowGate(exponent=t)` is: @@ -279,7 +279,7 @@ def _y_pow() -> YPowGate: @frozen -class Rz(CirqGateAsBloqBase): +class Rz(CirqGateAsBloqMixin): """Single-qubit Rz gate. Args: @@ -320,7 +320,7 @@ def __str__(self): @frozen -class Rx(CirqGateAsBloqBase): +class Rx(CirqGateAsBloqMixin): angle: Union[sympy.Expr, float] eps: SymbolicFloat = 1e-11 @@ -344,7 +344,7 @@ def __str__(self): @frozen -class Ry(CirqGateAsBloqBase): +class Ry(CirqGateAsBloqMixin): angle: Union[sympy.Expr, float] eps: SymbolicFloat = 1e-11 diff --git a/qualtran/cirq_interop/__init__.py b/qualtran/cirq_interop/__init__.py index f292f268c..fe0121a67 100644 --- a/qualtran/cirq_interop/__init__.py +++ b/qualtran/cirq_interop/__init__.py @@ -20,7 +20,7 @@ from ._cirq_to_bloq import ( CirqQuregT, CirqGateAsBloq, - CirqGateAsBloqBase, + CirqGateAsBloqMixin, cirq_optree_to_cbloq, cirq_gate_to_bloq, decompose_from_cirq_style_method, diff --git a/qualtran/cirq_interop/_bloq_to_cirq.py b/qualtran/cirq_interop/_bloq_to_cirq.py index 70f6d122c..79f90a3aa 100644 --- a/qualtran/cirq_interop/_bloq_to_cirq.py +++ b/qualtran/cirq_interop/_bloq_to_cirq.py @@ -14,13 +14,11 @@ """Qualtran Bloqs to Cirq gates/circuits conversion.""" -from functools import cached_property -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Dict, Iterable, List, Optional, Sequence, Tuple import cirq import networkx as nx import numpy as np -from numpy.typing import NDArray from qualtran import ( Bloq, @@ -40,7 +38,6 @@ _get_all_and_output_quregs_from_input, merge_qubits, split_qubits, - total_bits, ) from qualtran.cirq_interop._cirq_to_bloq import _QReg, CirqQuregInT, CirqQuregT from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager @@ -58,7 +55,9 @@ def _cirq_style_decompose_from_decompose_bloq( # Input qubits can get de-allocated by cbloq.to_cirq_circuit_and_quregs, thus mark them as managed. qm = InteropQubitManager(context.qubit_manager) qm.manage_qubits(merge_qubits(bloq.signature.lefts(), **in_quregs)) - circuit, out_quregs = cbloq.to_cirq_circuit_and_quregs(qubit_manager=qm, **in_quregs) + circuit, out_quregs = _cbloq_to_cirq_circuit( + cbloq.signature, in_quregs, cbloq._binst_graph, qubit_manager=qm + ) qubit_map = {q: q for q in circuit.all_qubits()} for reg in bloq.signature.rights(): if reg.side == Side.RIGHT: @@ -93,11 +92,6 @@ def bloq(self) -> Bloq: """The bloq we're wrapping.""" return self._bloq - @cached_property - def signature(self) -> Signature: - """`GateWithRegisters` registers.""" - return self.bloq.signature - @classmethod def bloq_on( cls, bloq: Bloq, cirq_quregs: Dict[str, 'CirqQuregT'], qubit_manager: cirq.QubitManager # type: ignore[type-var] @@ -120,15 +114,16 @@ def bloq_on( all_quregs, out_quregs = _get_all_and_output_quregs_from_input( bloq.signature, qubit_manager, in_quregs=cirq_quregs ) - return BloqAsCirqGate(bloq=bloq).on_registers(**all_quregs), out_quregs + cirq_op = BloqAsCirqGate(bloq=bloq).on(*merge_qubits(bloq.signature, **all_quregs)) + return cirq_op, out_quregs def _num_qubits_(self) -> int: - return total_bits(self.signature) + return self.bloq.signature.n_qubits() def _decompose_with_context_( self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None ) -> cirq.OP_TREE: - quregs = split_qubits(self.signature, qubits) + quregs = split_qubits(self.bloq.signature, qubits) if context is None: context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) try: @@ -143,22 +138,16 @@ def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE: return self._decompose_with_context_(qubits) def _unitary_(self): - if all(reg.side == Side.THRU for reg in self.signature): + if all(reg.side == Side.THRU for reg in self.bloq.signature): try: - _ = self.bloq.decompose_bloq() # check for decomposability - return NotImplemented - except (DecomposeNotImplementedError, DecomposeTypeError): tensor = self.bloq.tensor_contract() if tensor.ndim != 2: return NotImplemented return tensor + except NotImplementedError: + return NotImplemented return NotImplemented - def on_registers( - self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] # type: ignore[type-var] - ) -> cirq.Operation: - return self.on(*merge_qubits(self.signature, **qubit_regs)) - def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: """Draw cirq diagrams. diff --git a/qualtran/cirq_interop/_bloq_to_cirq_test.py b/qualtran/cirq_interop/_bloq_to_cirq_test.py index 95dae1cef..ade620df0 100644 --- a/qualtran/cirq_interop/_bloq_to_cirq_test.py +++ b/qualtran/cirq_interop/_bloq_to_cirq_test.py @@ -19,7 +19,7 @@ from attrs import frozen from qualtran import Bloq, BloqBuilder, ConnectionT, Signature, Soquet, SoquetT -from qualtran._infra.gate_with_registers import get_named_qubits +from qualtran._infra.gate_with_registers import get_named_qubits, merge_qubits from qualtran.bloqs.basic_gates import Toffoli, XGate, YGate from qualtran.bloqs.factoring import ModExp from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd @@ -226,8 +226,8 @@ def test_bloq_as_cirq_gate_for_mod_exp(): mod_exp = ModExp.make_for_shor(4, 3) gate = BloqAsCirqGate(mod_exp) # Use Cirq's infrastructure to construct an operation and corresponding decomposition. - quregs = get_named_qubits(gate.signature) - op = gate.on_registers(**quregs) + quregs = get_named_qubits(mod_exp.signature) + op = gate.on(*merge_qubits(mod_exp.signature, **quregs)) # cirq.decompose_once(op) delegates to underlying Bloq's decomposition specified in # `bloq.decompose_bloq()` and wraps resulting composite bloq in a Cirq op-tree. Note # how `BloqAsCirqGate.decompose_with_registers()` automatically takes care of mapping @@ -257,7 +257,7 @@ def test_bloq_as_cirq_gate_for_mod_exp(): decomposed_circuit, out_regs = cbloq.to_cirq_circuit_and_quregs(exponent=quregs['exponent']) # Whereas when directly applying a cirq gate on qubits to get an operations, we need to # specify both input and output registers. - circuit = cirq.Circuit(gate.on_registers(**out_regs), decomposed_circuit) + circuit = cirq.Circuit(gate.on(*merge_qubits(cbloq.signature, **out_regs)), decomposed_circuit) # Notice the newly allocated qubits _C(0) and _C(1) for output register x. cirq.testing.assert_has_diagram( circuit, diff --git a/qualtran/cirq_interop/_cirq_to_bloq.py b/qualtran/cirq_interop/_cirq_to_bloq.py index f71810df9..88713537d 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq.py +++ b/qualtran/cirq_interop/_cirq_to_bloq.py @@ -70,7 +70,7 @@ def _get_cirq_quregs(signature: Signature, qm: InteropQubitManager): return ret -class CirqGateAsBloqBase(GateWithRegisters, metaclass=abc.ABCMeta): +class _CirqGateAsBloqBase(Bloq, metaclass=abc.ABCMeta): """A Bloq wrapper around a `cirq.Gate`""" @property @@ -79,8 +79,6 @@ def cirq_gate(self) -> cirq.Gate: ... @cached_property def signature(self) -> 'Signature': - if isinstance(self.cirq_gate, Bloq): - return self.cirq_gate.signature nqubits = cirq.num_qubits(self.cirq_gate) return ( Signature([Register('q', QBit(), shape=nqubits)]) @@ -88,14 +86,13 @@ def signature(self) -> 'Signature': else Signature.build(q=nqubits) ) + def decompose_bloq(self) -> 'CompositeBloq': + return decompose_from_cirq_style_method(self) + def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: CirqQuregT ) -> cirq.OP_TREE: - op = ( - self.cirq_gate.on_registers(**quregs) - if isinstance(self.cirq_gate, GateWithRegisters) - else self.cirq_gate.on(*quregs.get('q', np.array(())).flatten()) - ) + op = self.cirq_gate.on(*quregs.get('q', np.array(())).flatten()) try: return cirq.decompose_once(op) except TypeError as e: @@ -111,38 +108,20 @@ def my_tensors( def as_cirq_op( self, qubit_manager: 'cirq.QubitManager', **in_quregs: 'CirqQuregT' ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: - if isinstance(self.cirq_gate, GateWithRegisters): - return self.cirq_gate.as_cirq_op(qubit_manager, **in_quregs) qubits = in_quregs.get('q', np.array([])).flatten() return self.cirq_gate.on(*qubits), in_quregs - # Delegate all cirq-style protocols to underlying gate - def _unitary_(self): - return cirq.unitary(self.cirq_gate, default=None) - - def _circuit_diagram_info_( - self, args: cirq.CircuitDiagramInfoArgs - ) -> Optional[cirq.CircuitDiagramInfo]: - return cirq.circuit_diagram_info(self.cirq_gate, default=None) - - def __str__(self): - return str(self.cirq_gate) - - def __pow__(self, power): - return CirqGateAsBloq(gate=cirq.pow(self.cirq_gate, power)) - def adjoint(self) -> 'Bloq': return CirqGateAsBloq(gate=cirq.inverse(self.cirq_gate)) + def __str__(self): + return f'cirq.{self.cirq_gate}' + @frozen -class CirqGateAsBloq(CirqGateAsBloqBase): +class CirqGateAsBloq(_CirqGateAsBloqBase): gate: cirq.Gate - def __str__(self) -> str: - g = min(self.cirq_gate.__class__.__name__, str(self.cirq_gate), key=len) - return f'cirq.{g}' - @property def cirq_gate(self) -> cirq.Gate: return self.gate @@ -153,6 +132,44 @@ def my_static_costs(self, cost_key: 'CostKey'): if t_count is None: raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}") return GateCounts(t=t_count.t, rotation=t_count.rotations, clifford=t_count.clifford) + return NotImplemented + + +class CirqGateAsBloqMixin(_CirqGateAsBloqBase, GateWithRegisters, metaclass=abc.ABCMeta): + """A mixin to bootstrap a bloq from a Cirq gate. + + Bloq authors can inherit from this abstract class and override the `cirq_gate` property + to get a bloq adapted from the cirq gate. Authors can continue to customize the bloq + by overriding methods (like costs, string representations, ...). + + This uses the same machinery as `CirqGateAsBloq`. That adapter will always be of type + `CirqGateAsBloq` and additional information cannot be annotated on the Cirq gate. + + This interface can be used to bootstrap a bloq if you have a complicated operation + implemented as a Cirq gate that you cannot change to be a `GateWithRegisters` and for + which writing a native Bloq would be prohibitive. + """ + + @property + @abc.abstractmethod + def cirq_gate(self) -> cirq.Gate: ... + + def _unitary_(self): + return cirq.unitary(self.cirq_gate, default=None) + + def _circuit_diagram_info_( + self, args: cirq.CircuitDiagramInfoArgs + ) -> Optional[cirq.CircuitDiagramInfo]: + return cirq.circuit_diagram_info(self.cirq_gate, default=None) + + def __str__(self): + return str(self.cirq_gate) + + def __pow__(self, power): + return CirqGateAsBloq(gate=cirq.pow(self.cirq_gate, power)) + + def adjoint(self) -> 'Bloq': + return CirqGateAsBloq(gate=cirq.inverse(self.cirq_gate)) def _cirq_wire_symbol_to_qualtran_wire_symbol(symbol: str, side: Side) -> 'WireSymbol': diff --git a/qualtran/cirq_interop/_cirq_to_bloq_test.py b/qualtran/cirq_interop/_cirq_to_bloq_test.py index a016d85a6..030afefc4 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq_test.py +++ b/qualtran/cirq_interop/_cirq_to_bloq_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterator, Tuple +from typing import Dict, Tuple import attr import cirq @@ -34,7 +34,7 @@ SoquetT, ) from qualtran._infra.gate_with_registers import get_named_qubits -from qualtran.bloqs.basic_gates import CNOT, GlobalPhase, OneState +from qualtran.bloqs.basic_gates import GlobalPhase, OneState, ZeroState from qualtran.bloqs.bookkeeping import Allocate, Free, Join, Split from qualtran.bloqs.mcmt.and_bloq import And from qualtran.cirq_interop import cirq_optree_to_cbloq, CirqGateAsBloq, CirqQuregT @@ -75,7 +75,7 @@ def test_cirq_gate_as_bloq_for_trivial_gates(): assert toffoli.signature[0].shape == (3,) assert str(x) == 'cirq.X' - assert str(rx) == 'cirq.Rx' + assert str(rx) == 'cirq.Rx(0.123π)' assert str(toffoli) == 'cirq.TOFFOLI' @@ -83,8 +83,10 @@ def test_cirq_gate_as_bloq_tensor_contract_for_and_gate(): and_gate = And() bb = BloqBuilder() ctrl = [bb.add(OneState()) for _ in range(2)] - ctrl, target = bb.add(CirqGateAsBloq(and_gate), ctrl=ctrl) - cbloq = bb.finalize(ctrl=ctrl, target=target) + target = bb.add(ZeroState()) + q = [*ctrl, target] + c0, c1, target = bb.add(CirqGateAsBloq(and_gate), q=q) + cbloq = bb.finalize(ctrl=np.array([c0, c1]), target=target) state_vector = cbloq.tensor_contract() assert np.isclose(state_vector[7], 1) @@ -201,25 +203,6 @@ def signature(self) -> Signature: assert bloqs_list.count(Free(QAny(2))) == 2 -def test_cirq_gate_as_bloq_for_left_only_gates(): - class LeftOnlyGate(GateWithRegisters): - @property - def signature(self): - return Signature([Register('junk', QAny(2), side=Side.LEFT)]) - - def decompose_from_registers(self, *, context, junk) -> Iterator[cirq.OP_TREE]: - yield cirq.CNOT(*junk) - yield cirq.reset_each(*junk) - - # Using InteropQubitManager enables support for LeftOnlyGate's in CirqGateAsBloq. - cbloq = CirqGateAsBloq(gate=LeftOnlyGate()).decompose_bloq() - bloqs_list = [binst.bloq for binst in cbloq.bloq_instances] - assert bloqs_list.count(Split(QAny(2))) == 1 - assert bloqs_list.count(Free(QBit())) == 2 - assert bloqs_list.count(CNOT()) == 1 - assert bloqs_list.count(CirqGateAsBloq(cirq.ResetChannel())) == 2 - - def test_cirq_gate_as_bloq_decompose_raises(): bloq = CirqGateAsBloq(cirq.X) with pytest.raises(DecomposeNotImplementedError, match="does not declare a decomposition"):