Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hone Bloq<->Cirq adapters #1472

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -

return _wire_symbol_from_gate(self, self.signature, reg, idx)

# Part-2: Cirq-FT style interface can be used to implemented algorithms by Bloq authors.
# Part-2: Cirq-FT style interface can be used to implement algorithms by Bloq authors.

def _num_qubits_(self) -> int:
return total_bits(self.signature)
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/basic_gates/global_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
SoquetT,
)
from qualtran.bloqs.basic_gates.rotation import ZPowGate
from qualtran.cirq_interop import CirqGateAsBloqBase
from qualtran.cirq_interop import CirqGateAsBloqMixin
from qualtran.symbolics import pi, sarg, sexp, SymbolicComplex, SymbolicFloat

if TYPE_CHECKING:
import quimb.tensor as qtn


@frozen
class GlobalPhase(CirqGateAsBloqBase):
class GlobalPhase(CirqGateAsBloqMixin):
r"""Applies a global phase to the circuit as a whole.

The unitary effect is to multiply the state vector by the complex scalar
Expand Down
16 changes: 8 additions & 8 deletions qualtran/bloqs/basic_gates/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from attrs import frozen

from qualtran import bloq_example, BloqDocSpec, CompositeBloq, DecomposeTypeError, Register
from qualtran.cirq_interop import CirqGateAsBloqBase
from qualtran.cirq_interop import CirqGateAsBloqMixin
from qualtran.drawing import Text, TextBox, WireSymbol
from qualtran.symbolics import SymbolicFloat


@frozen
class ZPowGate(CirqGateAsBloqBase):
class ZPowGate(CirqGateAsBloqMixin):
r"""A gate that rotates around the Z axis of the Bloch sphere.

The unitary matrix of `ZPowGate(exponent=t, global_shift=s)` is:
Expand Down Expand Up @@ -107,7 +107,7 @@ def _z_pow() -> ZPowGate:


@frozen
class CZPowGate(CirqGateAsBloqBase):
class CZPowGate(CirqGateAsBloqMixin):
exponent: float = 1.0
global_shift: float = 0.0
eps: SymbolicFloat = 1e-11
Expand All @@ -131,7 +131,7 @@ def __str__(self):


@frozen
class XPowGate(CirqGateAsBloqBase):
class XPowGate(CirqGateAsBloqMixin):
r"""A gate that rotates around the X axis of the Bloch sphere.

The unitary matrix of `XPowGate(exponent=t, global_shift=s)` is:
Expand Down Expand Up @@ -205,7 +205,7 @@ def _x_pow() -> XPowGate:


@frozen
class YPowGate(CirqGateAsBloqBase):
class YPowGate(CirqGateAsBloqMixin):
r"""A gate that rotates around the Y axis of the Bloch sphere.

The unitary matrix of `YPowGate(exponent=t)` is:
Expand Down Expand Up @@ -279,7 +279,7 @@ def _y_pow() -> YPowGate:


@frozen
class Rz(CirqGateAsBloqBase):
class Rz(CirqGateAsBloqMixin):
"""Single-qubit Rz gate.

Args:
Expand Down Expand Up @@ -320,7 +320,7 @@ def __str__(self):


@frozen
class Rx(CirqGateAsBloqBase):
class Rx(CirqGateAsBloqMixin):
angle: Union[sympy.Expr, float]
eps: SymbolicFloat = 1e-11

Expand All @@ -344,7 +344,7 @@ def __str__(self):


@frozen
class Ry(CirqGateAsBloqBase):
class Ry(CirqGateAsBloqMixin):
angle: Union[sympy.Expr, float]
eps: SymbolicFloat = 1e-11

Expand Down
2 changes: 1 addition & 1 deletion qualtran/cirq_interop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._cirq_to_bloq import (
CirqQuregT,
CirqGateAsBloq,
CirqGateAsBloqBase,
CirqGateAsBloqMixin,
cirq_optree_to_cbloq,
cirq_gate_to_bloq,
decompose_from_cirq_style_method,
Expand Down
33 changes: 11 additions & 22 deletions qualtran/cirq_interop/_bloq_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@

"""Qualtran Bloqs to Cirq gates/circuits conversion."""

from functools import cached_property
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import cirq
import networkx as nx
import numpy as np
from numpy.typing import NDArray

from qualtran import (
Bloq,
Expand All @@ -40,7 +38,6 @@
_get_all_and_output_quregs_from_input,
merge_qubits,
split_qubits,
total_bits,
)
from qualtran.cirq_interop._cirq_to_bloq import _QReg, CirqQuregInT, CirqQuregT
from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager
Expand All @@ -58,7 +55,9 @@ def _cirq_style_decompose_from_decompose_bloq(
# Input qubits can get de-allocated by cbloq.to_cirq_circuit_and_quregs, thus mark them as managed.
qm = InteropQubitManager(context.qubit_manager)
qm.manage_qubits(merge_qubits(bloq.signature.lefts(), **in_quregs))
circuit, out_quregs = cbloq.to_cirq_circuit_and_quregs(qubit_manager=qm, **in_quregs)
circuit, out_quregs = _cbloq_to_cirq_circuit(
cbloq.signature, in_quregs, cbloq._binst_graph, qubit_manager=qm
)
qubit_map = {q: q for q in circuit.all_qubits()}
for reg in bloq.signature.rights():
if reg.side == Side.RIGHT:
Expand Down Expand Up @@ -93,11 +92,6 @@ def bloq(self) -> Bloq:
"""The bloq we're wrapping."""
return self._bloq

@cached_property
def signature(self) -> Signature:
"""`GateWithRegisters` registers."""
return self.bloq.signature

@classmethod
def bloq_on(
cls, bloq: Bloq, cirq_quregs: Dict[str, 'CirqQuregT'], qubit_manager: cirq.QubitManager # type: ignore[type-var]
Expand All @@ -120,15 +114,16 @@ def bloq_on(
all_quregs, out_quregs = _get_all_and_output_quregs_from_input(
bloq.signature, qubit_manager, in_quregs=cirq_quregs
)
return BloqAsCirqGate(bloq=bloq).on_registers(**all_quregs), out_quregs
cirq_op = BloqAsCirqGate(bloq=bloq).on(*merge_qubits(bloq.signature, **all_quregs))
return cirq_op, out_quregs

def _num_qubits_(self) -> int:
return total_bits(self.signature)
return self.bloq.signature.n_qubits()

def _decompose_with_context_(
self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None
) -> cirq.OP_TREE:
quregs = split_qubits(self.signature, qubits)
quregs = split_qubits(self.bloq.signature, qubits)
if context is None:
context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager())
try:
Expand All @@ -143,22 +138,16 @@ def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE:
return self._decompose_with_context_(qubits)

def _unitary_(self):
if all(reg.side == Side.THRU for reg in self.signature):
if all(reg.side == Side.THRU for reg in self.bloq.signature):
try:
_ = self.bloq.decompose_bloq() # check for decomposability
return NotImplemented
except (DecomposeNotImplementedError, DecomposeTypeError):
tensor = self.bloq.tensor_contract()
if tensor.ndim != 2:
return NotImplemented
return tensor
except NotImplementedError:
return NotImplemented
return NotImplemented

def on_registers(
self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] # type: ignore[type-var]
) -> cirq.Operation:
return self.on(*merge_qubits(self.signature, **qubit_regs))

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
"""Draw cirq diagrams.

Expand Down
8 changes: 4 additions & 4 deletions qualtran/cirq_interop/_bloq_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from attrs import frozen

from qualtran import Bloq, BloqBuilder, ConnectionT, Signature, Soquet, SoquetT
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran._infra.gate_with_registers import get_named_qubits, merge_qubits
from qualtran.bloqs.basic_gates import Toffoli, XGate, YGate
from qualtran.bloqs.factoring import ModExp
from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd
Expand Down Expand Up @@ -226,8 +226,8 @@ def test_bloq_as_cirq_gate_for_mod_exp():
mod_exp = ModExp.make_for_shor(4, 3)
gate = BloqAsCirqGate(mod_exp)
# Use Cirq's infrastructure to construct an operation and corresponding decomposition.
quregs = get_named_qubits(gate.signature)
op = gate.on_registers(**quregs)
quregs = get_named_qubits(mod_exp.signature)
op = gate.on(*merge_qubits(mod_exp.signature, **quregs))
# cirq.decompose_once(op) delegates to underlying Bloq's decomposition specified in
# `bloq.decompose_bloq()` and wraps resulting composite bloq in a Cirq op-tree. Note
# how `BloqAsCirqGate.decompose_with_registers()` automatically takes care of mapping
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_bloq_as_cirq_gate_for_mod_exp():
decomposed_circuit, out_regs = cbloq.to_cirq_circuit_and_quregs(exponent=quregs['exponent'])
# Whereas when directly applying a cirq gate on qubits to get an operations, we need to
# specify both input and output registers.
circuit = cirq.Circuit(gate.on_registers(**out_regs), decomposed_circuit)
circuit = cirq.Circuit(gate.on(*merge_qubits(cbloq.signature, **out_regs)), decomposed_circuit)
# Notice the newly allocated qubits _C(0) and _C(1) for output register x.
cirq.testing.assert_has_diagram(
circuit,
Expand Down
77 changes: 47 additions & 30 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _get_cirq_quregs(signature: Signature, qm: InteropQubitManager):
return ret


class CirqGateAsBloqBase(GateWithRegisters, metaclass=abc.ABCMeta):
class _CirqGateAsBloqBase(Bloq, metaclass=abc.ABCMeta):
"""A Bloq wrapper around a `cirq.Gate`"""

@property
Expand All @@ -79,23 +79,20 @@ def cirq_gate(self) -> cirq.Gate: ...

@cached_property
def signature(self) -> 'Signature':
if isinstance(self.cirq_gate, Bloq):
return self.cirq_gate.signature
nqubits = cirq.num_qubits(self.cirq_gate)
return (
Signature([Register('q', QBit(), shape=nqubits)])
if nqubits > 1
else Signature.build(q=nqubits)
)

def decompose_bloq(self) -> 'CompositeBloq':
return decompose_from_cirq_style_method(self)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: CirqQuregT
) -> cirq.OP_TREE:
op = (
self.cirq_gate.on_registers(**quregs)
if isinstance(self.cirq_gate, GateWithRegisters)
else self.cirq_gate.on(*quregs.get('q', np.array(())).flatten())
)
op = self.cirq_gate.on(*quregs.get('q', np.array(())).flatten())
try:
return cirq.decompose_once(op)
except TypeError as e:
Expand All @@ -111,38 +108,20 @@ def my_tensors(
def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', **in_quregs: 'CirqQuregT'
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
if isinstance(self.cirq_gate, GateWithRegisters):
return self.cirq_gate.as_cirq_op(qubit_manager, **in_quregs)
qubits = in_quregs.get('q', np.array([])).flatten()
return self.cirq_gate.on(*qubits), in_quregs

# Delegate all cirq-style protocols to underlying gate
def _unitary_(self):
return cirq.unitary(self.cirq_gate, default=None)

def _circuit_diagram_info_(
self, args: cirq.CircuitDiagramInfoArgs
) -> Optional[cirq.CircuitDiagramInfo]:
return cirq.circuit_diagram_info(self.cirq_gate, default=None)

def __str__(self):
return str(self.cirq_gate)

def __pow__(self, power):
return CirqGateAsBloq(gate=cirq.pow(self.cirq_gate, power))

def adjoint(self) -> 'Bloq':
return CirqGateAsBloq(gate=cirq.inverse(self.cirq_gate))

def __str__(self):
return f'cirq.{self.cirq_gate}'


@frozen
class CirqGateAsBloq(CirqGateAsBloqBase):
class CirqGateAsBloq(_CirqGateAsBloqBase):
gate: cirq.Gate

def __str__(self) -> str:
g = min(self.cirq_gate.__class__.__name__, str(self.cirq_gate), key=len)
return f'cirq.{g}'

@property
def cirq_gate(self) -> cirq.Gate:
return self.gate
Expand All @@ -153,6 +132,44 @@ def my_static_costs(self, cost_key: 'CostKey'):
if t_count is None:
raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}")
return GateCounts(t=t_count.t, rotation=t_count.rotations, clifford=t_count.clifford)
return NotImplemented


class CirqGateAsBloqMixin(_CirqGateAsBloqBase, GateWithRegisters, metaclass=abc.ABCMeta):
"""A mixin to bootstrap a bloq from a Cirq gate.

Bloq authors can inherit from this abstract class and override the `cirq_gate` property
to get a bloq adapted from the cirq gate. Authors can continue to customize the bloq
by overriding methods (like costs, string representations, ...).

This uses the same machinery as `CirqGateAsBloq`. That adapter will always be of type
`CirqGateAsBloq` and additional information cannot be annotated on the Cirq gate.

This interface can be used to bootstrap a bloq if you have a complicated operation
implemented as a Cirq gate that you cannot change to be a `GateWithRegisters` and for
which writing a native Bloq would be prohibitive.
"""

@property
@abc.abstractmethod
def cirq_gate(self) -> cirq.Gate: ...

def _unitary_(self):
return cirq.unitary(self.cirq_gate, default=None)

def _circuit_diagram_info_(
self, args: cirq.CircuitDiagramInfoArgs
) -> Optional[cirq.CircuitDiagramInfo]:
return cirq.circuit_diagram_info(self.cirq_gate, default=None)

def __str__(self):
return str(self.cirq_gate)

def __pow__(self, power):
return CirqGateAsBloq(gate=cirq.pow(self.cirq_gate, power))

def adjoint(self) -> 'Bloq':
return CirqGateAsBloq(gate=cirq.inverse(self.cirq_gate))


def _cirq_wire_symbol_to_qualtran_wire_symbol(symbol: str, side: Side) -> 'WireSymbol':
Expand Down
Loading
Loading