diff --git a/dev_tools/requirements/deps/runtime.txt b/dev_tools/requirements/deps/runtime.txt index 63b03d5a7..bf7fe82ed 100644 --- a/dev_tools/requirements/deps/runtime.txt +++ b/dev_tools/requirements/deps/runtime.txt @@ -1,5 +1,5 @@ -cirq-core==1.3.0.dev20230925183640 -cirq-ft==1.3.0.dev20230925183640 +cirq-core==1.3.0.dev20230928195458 +cirq-ft==1.3.0.dev20230928195458 # drawing pydot diff --git a/dev_tools/requirements/envs/dev.env.txt b/dev_tools/requirements/envs/dev.env.txt index 8e1f034aa..40a04732c 100644 --- a/dev_tools/requirements/envs/dev.env.txt +++ b/dev_tools/requirements/envs/dev.env.txt @@ -48,15 +48,15 @@ cachetools==5.3.1 # via cirq-ft certifi==2023.7.22 # via requests -cffi==1.15.1 +cffi==1.16.0 # via cryptography charset-normalizer==3.2.0 # via requests -cirq-core==1.3.0.dev20230925183640 +cirq-core==1.3.0.dev20230928195458 # via # -r deps/runtime.txt # cirq-ft -cirq-ft==1.3.0.dev20230925183640 +cirq-ft==1.3.0.dev20230928195458 # via -r deps/runtime.txt click==8.1.7 # via @@ -169,7 +169,7 @@ jupyter-client==8.3.1 # via # ipykernel # nbclient -jupyter-core==5.3.1 +jupyter-core==5.3.2 # via # ipykernel # jupyter-client @@ -392,7 +392,7 @@ rpds-py==0.10.3 # via # jsonschema # referencing -scipy==1.11.2 +scipy==1.11.3 # via # cirq-core # quimb @@ -467,7 +467,7 @@ tqdm==4.66.1 # via # cirq-core # quimb -traitlets==5.10.0 +traitlets==5.10.1 # via # comm # ipykernel diff --git a/dev_tools/requirements/envs/docs.env.txt b/dev_tools/requirements/envs/docs.env.txt index 485bb7987..e4f3b1ca0 100644 --- a/dev_tools/requirements/envs/docs.env.txt +++ b/dev_tools/requirements/envs/docs.env.txt @@ -66,12 +66,12 @@ charset-normalizer==3.2.0 # via # -c envs/dev.env.txt # requests -cirq-core==1.3.0.dev20230925183640 +cirq-core==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt # cirq-ft -cirq-ft==1.3.0.dev20230925183640 +cirq-ft==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -194,7 +194,7 @@ jupyter-client==8.3.1 # -c envs/dev.env.txt # ipykernel # nbclient -jupyter-core==5.3.1 +jupyter-core==5.3.2 # via # -c envs/dev.env.txt # ipykernel @@ -426,7 +426,7 @@ rpds-py==0.10.3 # -c envs/dev.env.txt # jsonschema # referencing -scipy==1.11.2 +scipy==1.11.3 # via # -c envs/dev.env.txt # cirq-core @@ -523,7 +523,7 @@ tqdm==4.66.1 # -c envs/dev.env.txt # cirq-core # quimb -traitlets==5.10.0 +traitlets==5.10.1 # via # -c envs/dev.env.txt # comm diff --git a/dev_tools/requirements/envs/format.env.txt b/dev_tools/requirements/envs/format.env.txt index 8f24ee37f..2aaf435bc 100644 --- a/dev_tools/requirements/envs/format.env.txt +++ b/dev_tools/requirements/envs/format.env.txt @@ -43,12 +43,12 @@ cachetools==5.3.1 # via # -c envs/dev.env.txt # cirq-ft -cirq-core==1.3.0.dev20230925183640 +cirq-core==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt # cirq-ft -cirq-ft==1.3.0.dev20230925183640 +cirq-ft==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -137,7 +137,7 @@ jupyter-client==8.3.1 # via # -c envs/dev.env.txt # nbclient -jupyter-core==5.3.1 +jupyter-core==5.3.2 # via # -c envs/dev.env.txt # jupyter-client @@ -319,7 +319,7 @@ rpds-py==0.10.3 # -c envs/dev.env.txt # jsonschema # referencing -scipy==1.11.2 +scipy==1.11.3 # via # -c envs/dev.env.txt # cirq-core @@ -368,7 +368,7 @@ tqdm==4.66.1 # -c envs/dev.env.txt # cirq-core # quimb -traitlets==5.10.0 +traitlets==5.10.1 # via # -c envs/dev.env.txt # comm diff --git a/dev_tools/requirements/envs/pylint.env.txt b/dev_tools/requirements/envs/pylint.env.txt index 695926bc4..e039002cc 100644 --- a/dev_tools/requirements/envs/pylint.env.txt +++ b/dev_tools/requirements/envs/pylint.env.txt @@ -39,12 +39,12 @@ cachetools==5.3.1 # via # -c envs/dev.env.txt # cirq-ft -cirq-core==1.3.0.dev20230925183640 +cirq-core==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt # cirq-ft -cirq-ft==1.3.0.dev20230925183640 +cirq-ft==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -129,7 +129,7 @@ jupyter-client==8.3.1 # via # -c envs/dev.env.txt # nbclient -jupyter-core==5.3.1 +jupyter-core==5.3.2 # via # -c envs/dev.env.txt # jupyter-client @@ -315,7 +315,7 @@ rpds-py==0.10.3 # -c envs/dev.env.txt # jsonschema # referencing -scipy==1.11.2 +scipy==1.11.3 # via # -c envs/dev.env.txt # cirq-core @@ -367,7 +367,7 @@ tqdm==4.66.1 # -c envs/dev.env.txt # cirq-core # quimb -traitlets==5.10.0 +traitlets==5.10.1 # via # -c envs/dev.env.txt # comm diff --git a/dev_tools/requirements/envs/pytest.env.txt b/dev_tools/requirements/envs/pytest.env.txt index f91182f07..f12a86dc9 100644 --- a/dev_tools/requirements/envs/pytest.env.txt +++ b/dev_tools/requirements/envs/pytest.env.txt @@ -35,12 +35,12 @@ cachetools==5.3.1 # via # -c envs/dev.env.txt # cirq-ft -cirq-core==1.3.0.dev20230925183640 +cirq-core==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt # cirq-ft -cirq-ft==1.3.0.dev20230925183640 +cirq-ft==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -141,7 +141,7 @@ jupyter-client==8.3.1 # -c envs/dev.env.txt # ipykernel # nbclient -jupyter-core==5.3.1 +jupyter-core==5.3.2 # via # -c envs/dev.env.txt # ipykernel @@ -349,7 +349,7 @@ rpds-py==0.10.3 # -c envs/dev.env.txt # jsonschema # referencing -scipy==1.11.2 +scipy==1.11.3 # via # -c envs/dev.env.txt # cirq-core @@ -399,7 +399,7 @@ tqdm==4.66.1 # -c envs/dev.env.txt # cirq-core # quimb -traitlets==5.10.0 +traitlets==5.10.1 # via # -c envs/dev.env.txt # comm diff --git a/dev_tools/requirements/envs/runtime.env.txt b/dev_tools/requirements/envs/runtime.env.txt index 6ad09f06b..be5e141a2 100644 --- a/dev_tools/requirements/envs/runtime.env.txt +++ b/dev_tools/requirements/envs/runtime.env.txt @@ -35,12 +35,12 @@ cachetools==5.3.1 # via # -c envs/dev.env.txt # cirq-ft -cirq-core==1.3.0.dev20230925183640 +cirq-core==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt # cirq-ft -cirq-ft==1.3.0.dev20230925183640 +cirq-ft==1.3.0.dev20230928195458 # via # -c envs/dev.env.txt # -r deps/runtime.txt @@ -117,7 +117,7 @@ jupyter-client==8.3.1 # via # -c envs/dev.env.txt # nbclient -jupyter-core==5.3.1 +jupyter-core==5.3.2 # via # -c envs/dev.env.txt # jupyter-client @@ -290,7 +290,7 @@ rpds-py==0.10.3 # -c envs/dev.env.txt # jsonschema # referencing -scipy==1.11.2 +scipy==1.11.3 # via # -c envs/dev.env.txt # cirq-core @@ -334,7 +334,7 @@ tqdm==4.66.1 # -c envs/dev.env.txt # cirq-core # quimb -traitlets==5.10.0 +traitlets==5.10.1 # via # -c envs/dev.env.txt # comm diff --git a/qualtran/_infra/composite_bloq.py b/qualtran/_infra/composite_bloq.py index a9a1899ac..59db334ba 100644 --- a/qualtran/_infra/composite_bloq.py +++ b/qualtran/_infra/composite_bloq.py @@ -140,7 +140,7 @@ def to_cirq_circuit( """ import cirq - from qualtran.cirq_interop._cirq_interop import _cbloq_to_cirq_circuit + from qualtran.cirq_interop._bloq_to_cirq import _cbloq_to_cirq_circuit if qubit_manager is None: qubit_manager = cirq.ops.SimpleQubitManager() diff --git a/qualtran/bloqs/swap_network_cirq_test.py b/qualtran/bloqs/swap_network_cirq_test.py index 8f8198c1f..f1cebc0e2 100644 --- a/qualtran/bloqs/swap_network_cirq_test.py +++ b/qualtran/bloqs/swap_network_cirq_test.py @@ -74,7 +74,7 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe def test_swap_with_zero_gate_diagram(): gate = SwapWithZeroGate(3, 2, 4) q = cirq.LineQubit.range(cirq.num_qubits(gate)) - circuit = cirq.Circuit(gate.on_registers(**cirq_ft.infra.split_qubits(gate.registers, q))) + circuit = cirq.Circuit(gate.on_registers(**cirq_ft.infra.split_qubits(gate.signature, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/qualtran/cirq_interop/__init__.py b/qualtran/cirq_interop/__init__.py index 170c99962..4c84e8569 100644 --- a/qualtran/cirq_interop/__init__.py +++ b/qualtran/cirq_interop/__init__.py @@ -12,15 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Interoperability with Cirq +"""Bi-directional interop between Qualtran & Cirq using Cirq-FT. isort:skip_file """ -from ._cirq_interop import ( - CirqQuregT, - CirqGateAsBloq, - BloqAsCirqGate, - cirq_optree_to_cbloq, - decompose_from_cirq_op, -) +from ._cirq_to_bloq import CirqQuregT, CirqGateAsBloq, cirq_optree_to_cbloq, decompose_from_cirq_op + +from ._bloq_to_cirq import BloqAsCirqGate diff --git a/qualtran/cirq_interop/_bloq_to_cirq.py b/qualtran/cirq_interop/_bloq_to_cirq.py new file mode 100644 index 000000000..dd545901e --- /dev/null +++ b/qualtran/cirq_interop/_bloq_to_cirq.py @@ -0,0 +1,295 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qualtran Bloqs to Cirq gates/circuits conversion.""" + +from functools import cached_property +from typing import Callable, Dict, Iterable, List, Optional, Tuple + +import cirq +import cirq_ft +import networkx as nx +import numpy as np + +from qualtran import Bloq, Connection, LeftDangle, Register, RightDangle, Side, Signature, Soquet +from qualtran._infra.composite_bloq import _binst_to_cxns +from qualtran.cirq_interop._cirq_to_bloq import _QReg, CirqQuregInT, CirqQuregT + + +class BloqAsCirqGate(cirq_ft.GateWithRegisters): + """A shim for using bloqs in a Cirq circuit. + + Args: + bloq: The bloq to wrap. + reg_to_wires: an optional callable to produce a list of wire symbols for each register + to match Cirq diagrams. + """ + + def __init__(self, bloq: Bloq, reg_to_wires: Optional[Callable[[Register], List[str]]] = None): + for _, regs in bloq.signature.groups(): + if len(regs) > 1: + raise ValueError( + f"Automated cirq conversion doesn't support multiple registers with same name." + f" Found {regs}\n. Please override `bloq.as_cirq_op` for `{bloq=}` instead." + ) + self._bloq = bloq + self._reg_to_wires = reg_to_wires + + @property + def bloq(self) -> Bloq: + """The bloq we're wrapping.""" + return self._bloq + + @cached_property + def signature(self) -> cirq_ft.Signature: + """`cirq_ft.GateWithRegisters` registers.""" + legacy_regs: List[cirq_ft.Register] = [] + for reg in self.bloq.signature: + legacy_regs.append( + cirq_ft.Register( + name=reg.name, + shape=reg.shape, + bitsize=reg.bitsize, + side=cirq_ft.infra.Side(reg.side.value), + ) + ) + return cirq_ft.Signature(legacy_regs) + + @classmethod + def bloq_on( + cls, bloq: Bloq, cirq_quregs: Dict[str, 'CirqQuregT'], qubit_manager: cirq.QubitManager + ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: + """Shim `bloq` into a cirq gate and call it on `cirq_quregs`. + + This is used as a default implementation for `Bloq.as_cirq_op` if a native + cirq conversion is not specified. + + Args: + bloq: The bloq to be wrapped with `BloqAsCirqGate` + cirq_quregs: The cirq qubit registers on which we call the gate. Should correspond to + registers in `self.bloq.signature.lefts()`. + qubit_manager: A `cirq.QubitManager` to allocate new qubits. + + Returns: + op: A cirq operation whose gate is the `BloqAsCirqGate`-wrapped version of `bloq`. + cirq_quregs: The output cirq qubit registers. + """ + return _construct_op_from_gate( + BloqAsCirqGate(bloq=bloq), in_quregs=cirq_quregs, qubit_manager=qubit_manager + ) + + def decompose_from_registers( + self, context: cirq.DecompositionContext, **quregs: CirqQuregT + ) -> cirq.OP_TREE: + """Implementation of the GatesWithRegisters decompose method. + + This delegates to `self.bloq.decompose_bloq()` and converts the result to a cirq circuit. + + Args: + context: `cirq.DecompositionContext` stores options for decomposing gates (eg: + cirq.QubitManager). + **quregs: Appropriately shaped qubit arrays corresponding to Cirq-FT registers defined + as per `self.signature`. + + Returns: + A cirq circuit containing the cirq-exported version of the bloq decomposition. + """ + cbloq = self._bloq.decompose_bloq() + circuit, out_quregs = cbloq.to_cirq_circuit(qubit_manager=context.qubit_manager, **quregs) + qubit_map = {q: q for q in circuit.all_qubits()} + for reg in self.bloq.signature.rights(): + if reg.side == Side.RIGHT: + # Right only registers can get mapped to newly allocated output qubits in `out_regs`. + # Map them back to the original system qubits and deallocate newly allocated qubits. + assert reg.name in quregs and reg.name in out_quregs + assert quregs[reg.name].shape == out_quregs[reg.name].shape + context.qubit_manager.qfree([q for q in out_quregs[reg.name].flatten()]) + qubit_map |= zip(out_quregs[reg.name].flatten(), quregs[reg.name].flatten()) + return circuit.unfreeze(copy=False).transform_qubits(qubit_map) + + def _t_complexity_(self): + """Delegate to the bloq's t complexity.""" + return self._bloq.t_complexity() + + def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: + """Draw cirq diagrams. + + By default, we label each qubit with its register name. If `reg_to_wires` was provided + in the class constructor, we use that to get a list of wire symbols for each register. + """ + + if self._reg_to_wires is not None: + reg_to_wires = self._reg_to_wires + else: + reg_to_wires = lambda reg: [reg.name] * reg.total_bits() + + wire_symbols = [] + for reg in self._bloq.signature: + symbs = reg_to_wires(reg) + assert len(symbs) == reg.total_bits() + wire_symbols.extend(symbs) + if self._reg_to_wires is None: + wire_symbols[0] = self._bloq.pretty_name() + return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) + + def __eq__(self, other): + if not isinstance(other, BloqAsCirqGate): + return False + return self.bloq == other.bloq + + def __hash__(self): + return hash(self.bloq) + + def __str__(self) -> str: + return f'bloq.{self.bloq}' + + def __repr__(self) -> str: + return f'BloqAsCirqGate({self.bloq})' + + +def _track_soq_name_changes(cxns: Iterable[Connection], qvar_to_qreg: Dict[Soquet, _QReg]): + """Track inter-Bloq name changes across the two ends of a connection.""" + for cxn in cxns: + qvar_to_qreg[cxn.right] = qvar_to_qreg[cxn.left] + del qvar_to_qreg[cxn.left] + + +def _bloq_to_cirq_op( + bloq: Bloq, + pred_cxns: Iterable[Connection], + succ_cxns: Iterable[Connection], + qvar_to_qreg: Dict[Soquet, _QReg], + qubit_manager: cirq.QubitManager, +) -> cirq.Operation: + _track_soq_name_changes(pred_cxns, qvar_to_qreg) + in_quregs: Dict[str, CirqQuregT] = { + reg.name: np.empty((*reg.shape, reg.bitsize), dtype=object) + for reg in bloq.signature.lefts() + } + # Construct the cirq qubit registers using input / output connections. + # 1. All input Soquets should already have the correct mapping in `qvar_to_qreg`. + for cxn in pred_cxns: + soq = cxn.right + assert soq in qvar_to_qreg, f"{soq=} should exist in {qvar_to_qreg=}." + in_quregs[soq.reg.name][soq.idx] = qvar_to_qreg[soq].qubits + if soq.reg.side == Side.LEFT: + # Remove soquets for LEFT registers from qvar_to_qreg mapping. + del qvar_to_qreg[soq] + + op, out_quregs = bloq.as_cirq_op(qubit_manager=qubit_manager, **in_quregs) + + # 2. Update the mappings based on output soquets and `out_quregs`. + for cxn in succ_cxns: + soq = cxn.left + assert soq.reg.name in out_quregs, f"{soq=} should exist in {out_quregs=}." + if soq.reg.side == Side.RIGHT: + qvar_to_qreg[soq] = _QReg(out_quregs[soq.reg.name][soq.idx]) + return op + + +def _cbloq_to_cirq_circuit( + signature: Signature, + cirq_quregs: Dict[str, 'CirqQuregInT'], + binst_graph: nx.DiGraph, + qubit_manager: cirq.QubitManager, +) -> Tuple[cirq.FrozenCircuit, Dict[str, 'CirqQuregT']]: + """Propagate `as_cirq_op` calls through a composite bloq's contents to export a `cirq.Circuit`. + + Args: + signature: The cbloq's signature for validating inputs and outputs. + cirq_quregs: Mapping from left register name to Cirq qubit arrays. + binst_graph: The cbloq's binst graph. This is read only. + qubit_manager: A `cirq.QubitManager` to allocate new qubits. + + Returns: + circuit: The cirq.FrozenCircuit version of this composite bloq. + cirq_quregs: The output mapping from right register names to Cirq qubit arrays. + """ + cirq_quregs = {k: np.apply_along_axis(_QReg, -1, v) for k, v in cirq_quregs.items()} + qvar_to_qreg: Dict[Soquet, _QReg] = { + Soquet(LeftDangle, idx=idx, reg=reg): cirq_quregs[reg.name][idx] + for reg in signature.lefts() + for idx in reg.all_idxs() + } + moments: List[cirq.Moment] = [] + for binsts in nx.topological_generations(binst_graph): + moment: List[cirq.Operation] = [] + + for binst in binsts: + if binst is LeftDangle: + continue + pred_cxns, succ_cxns = _binst_to_cxns(binst, binst_graph=binst_graph) + if binst is RightDangle: + _track_soq_name_changes(pred_cxns, qvar_to_qreg) + continue + + op = _bloq_to_cirq_op(binst.bloq, pred_cxns, succ_cxns, qvar_to_qreg, qubit_manager) + if op is not None: + moment.append(op) + if moment: + moments.append(cirq.Moment(moment)) + + # Find output Cirq quregs using `qvar_to_qreg` mapping for registers in `signature.rights()`. + def _f_quregs(reg: Register) -> CirqQuregT: + ret = np.empty(reg.shape + (reg.bitsize,), dtype=object) + for idx in reg.all_idxs(): + soq = Soquet(RightDangle, idx=idx, reg=reg) + ret[idx] = qvar_to_qreg[soq].qubits + return ret + + out_quregs = {reg.name: _f_quregs(reg) for reg in signature.rights()} + + return cirq.FrozenCircuit(moments), out_quregs + + +def _construct_op_from_gate( + gate: cirq_ft.GateWithRegisters, + in_quregs: Dict[str, 'CirqQuregT'], + qubit_manager: cirq.QubitManager, +) -> Tuple[cirq.Operation, Dict[str, 'CirqQuregT']]: + """Allocates / Deallocates qubits for RIGHT / LEFT only registers to construct a Cirq operation + + Args: + gate: A `cirq_ft.GateWithRegisters` which specifies a signature. + in_quregs: Mapping from LEFT register names of `gate` and corresponding cirq qubits. + qubit_manager: For allocating / deallocating qubits for RIGHT / LEFT only registers. + + Returns: + A cirq operation constructed using `gate` and a mapping from RIGHT register names to + corresponding Cirq qubits. + """ + all_quregs: Dict[str, 'CirqQuregT'] = {} + out_quregs: Dict[str, 'CirqQuregT'] = {} + for reg in gate.signature: + full_shape = reg.shape + (reg.bitsize,) + if reg.side & cirq_ft.infra.Side.LEFT: + if reg.name not in in_quregs or in_quregs[reg.name].shape != full_shape: + # Left registers should exist as input to `as_cirq_op`. + raise ValueError(f'Compatible {reg=} must exist in {in_quregs=}') + all_quregs[reg.name] = in_quregs[reg.name] + if reg.side == cirq_ft.infra.Side.RIGHT: + # Right only registers will get allocated as part of `as_cirq_op`. + if reg.name in in_quregs: + raise ValueError(f"RIGHT register {reg=} shouldn't exist in {in_quregs=}.") + all_quregs[reg.name] = np.array(qubit_manager.qalloc(reg.total_bits())).reshape( + full_shape + ) + if reg.side == cirq_ft.infra.Side.LEFT: + # LEFT only registers should be de-allocated and not be part of output. + qubit_manager.qfree(in_quregs[reg.name].flatten()) + + if reg.side & cirq_ft.infra.Side.RIGHT: + # Right registers should be part of the output. + out_quregs[reg.name] = all_quregs[reg.name] + return gate.on_registers(**all_quregs), out_quregs diff --git a/qualtran/cirq_interop/_cirq_interop_test.py b/qualtran/cirq_interop/_bloq_to_cirq_test.py similarity index 53% rename from qualtran/cirq_interop/_cirq_interop_test.py rename to qualtran/cirq_interop/_bloq_to_cirq_test.py index 05ae4ac03..b6cb15b2d 100644 --- a/qualtran/cirq_interop/_cirq_interop_test.py +++ b/qualtran/cirq_interop/_bloq_to_cirq_test.py @@ -11,145 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Dict, Tuple -import attr import cirq import cirq_ft import numpy as np import pytest -import sympy from attrs import frozen -import qualtran -from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature, Soquet, SoquetT +from qualtran import Bloq, BloqBuilder, Signature, Soquet, SoquetT from qualtran.bloqs.and_bloq import MultiAnd from qualtran.bloqs.basic_gates import XGate +from qualtran.bloqs.factoring import ModExp from qualtran.bloqs.swap_network import SwapWithZero -from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Split -from qualtran.cirq_interop import ( - BloqAsCirqGate, - cirq_optree_to_cbloq, - CirqGateAsBloq, - CirqQuregT, - decompose_from_cirq_op, -) +from qualtran.cirq_interop._bloq_to_cirq import _construct_op_from_gate, BloqAsCirqGate, CirqQuregT from qualtran.testing import execute_notebook -def test_cirq_gate(): - x = CirqGateAsBloq(cirq.X) - rx = CirqGateAsBloq(cirq.Rx(rads=0.123 * np.pi)) - toffoli = CirqGateAsBloq(cirq.TOFFOLI) - - for b in [x, rx, toffoli]: - assert len(b.signature) == 1 - assert b.signature[0].side == Side.THRU - - assert x.signature[0].shape == (1,) - assert toffoli.signature[0].shape == (3,) - - assert str(x) == 'CirqGateAsBloq(gate=cirq.X)' - assert x.pretty_name() == 'cirq.X' - assert x.short_name() == 'cirq.X' - - assert rx.pretty_name() == 'cirq.Rx(0.123π)' - assert rx.short_name() == 'cirq.Rx' - - assert toffoli.pretty_name() == 'cirq.TOFFOLI' - assert toffoli.short_name() == 'cirq.TOFFOLI' - - -def test_cirq_circuit_to_cbloq(): - qubits = cirq.LineQubit.range(6) - circuit = cirq.testing.random_circuit(qubits, n_moments=7, op_density=1.0, random_state=52) - cbloq = cirq_optree_to_cbloq(circuit) - - bloq_unitary = cbloq.tensor_contract() - cirq_unitary = circuit.unitary(qubits) - np.testing.assert_allclose(cirq_unitary, bloq_unitary, atol=1e-8) - - -def test_cbloq_to_cirq_circuit(): - qubits = cirq.LineQubit.range(6) - circuit = cirq.testing.random_circuit(qubits, n_moments=7, op_density=1.0, random_state=52) - cbloq = cirq_optree_to_cbloq(circuit) - - # important! we lose moment structure - circuit = cirq.Circuit(circuit.all_operations()) - - # Note: a 1d `shape` bloq register is actually two-dimensional in cirq-world - # because of the implicit `bitsize` dimension (which must be explicit in cirq-world). - # CirqGate has registers of bitsize=1 and shape=(n,); hence the list transpose below. - circuit2, _ = cbloq.to_cirq_circuit( - **{'qubits': [[q] for q in qubits]}, qubit_manager=cirq.ops.SimpleQubitManager() - ) - - assert circuit == circuit2 - - -def test_cirq_optree_to_cbloq(): - @attr.frozen - class CirqGateWithRegisters(cirq_ft.GateWithRegisters): - reg: cirq_ft.Register - - @property - def registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers([self.reg]) - - reg1 = cirq_ft.Register('x', shape=(3, 4), bitsize=2) - reg2 = cirq_ft.Register('y', shape=12, bitsize=2) - anc_reg = cirq_ft.Register('anc', shape=4, bitsize=2) - qubits = cirq.LineQubit.range(24) - anc_qubits = cirq.NamedQubit.range(4, prefix='anc') - circuit = cirq.Circuit( - CirqGateWithRegisters(reg1).on(*qubits), - CirqGateWithRegisters(anc_reg).on(*anc_qubits, *qubits[:4]), - CirqGateWithRegisters(reg2).on(*qubits), - ) - # Test-1: When no signature is specified, the method uses a default signature. Ancilla qubits - # are also included in the signature itself, so no allocations / deallocations are needed. - cbloq = cirq_optree_to_cbloq(circuit) - assert cbloq.signature == qualtran.Signature( - [qualtran.Register(name='qubits', bitsize=1, shape=(28,))] - ) - bloq_instances = [binst for binst, _, _ in cbloq.iter_bloqnections()] - assert all(bloq_instances[i].bloq == Join(2) for i in range(14)) - assert bloq_instances[14].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg1)) - assert bloq_instances[14].bloq.signature == qualtran.Signature( - [qualtran.Register(name='x', bitsize=2, shape=(3, 4))] - ) - assert bloq_instances[15].bloq == CirqGateAsBloq(CirqGateWithRegisters(anc_reg)) - assert bloq_instances[15].bloq.signature == qualtran.Signature( - [qualtran.Register(name='anc', bitsize=2, shape=(4,))] - ) - assert bloq_instances[16].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg2)) - assert bloq_instances[16].bloq.signature == qualtran.Signature( - [qualtran.Register(name='y', bitsize=2, shape=(12,))] - ) - assert all(bloq_instances[-i].bloq == Split(2) for i in range(1, 15)) - # Test-2: If you provide an explicit signature, you must also provide a mapping of cirq qubits - # matching the signature. The additional ancilla allocations are automatically handled. - new_signature = qualtran.Signature( - [ - qualtran.Register('xx', bitsize=3, shape=(3, 2)), - qualtran.Register('yy', bitsize=1, shape=(2, 3)), - ] - ) - cirq_quregs = { - 'xx': np.asarray(qubits[:18]).reshape((3, 2, 3)), - 'yy': np.asarray(qubits[18:]).reshape((2, 3, 1)), - } - cbloq = cirq_optree_to_cbloq(circuit, signature=new_signature, cirq_quregs=cirq_quregs) - assert cbloq.signature == new_signature - # Splits, joins, Alloc, Free are automatically inserted. - bloqs_list = [binst.bloq for binst in cbloq.bloq_instances] - assert bloqs_list.count(Split(3)) == 6 - assert bloqs_list.count(Join(3)) == 6 - assert bloqs_list.count(Allocate(4)) == 1 - assert bloqs_list.count(Free(4)) == 1 - - @frozen class SwapTwoBitsTest(Bloq): @property @@ -238,6 +117,38 @@ def test_multi_and_allocates(): assert sorted(out_quregs.keys()) == ['ctrl', 'junk', 'target'] +def test_contruct_op_from_gate(): + and_gate = cirq_ft.And() + in_quregs = {'ctrl': np.array([*cirq.LineQubit.range(2)]).reshape(2, 1)} + qm = cirq.ops.SimpleQubitManager() + # Allocates new qubits for RIGHT only registers. + op, out_quregs = _construct_op_from_gate(and_gate, in_quregs, qm) + assert len(out_quregs['target']) == 1 + assert op == and_gate.on_registers(**out_quregs) + # Deallocates qubits for LEFT only registers. + and_inv = cirq_ft.And(adjoint=True) + op, inv_out_quregs = _construct_op_from_gate(and_inv, out_quregs, qm) + assert inv_out_quregs == in_quregs + assert op == and_inv.on_registers(**out_quregs) + + +def test_construct_op_from_gate_raises(): + and_gate = cirq_ft.And() + qm = cirq.ops.SimpleQubitManager() + q = [*cirq.LineQubit.range(2)] + in_quregs = {} + with pytest.raises(ValueError, match='Compatible reg.*must exist'): + _ = _construct_op_from_gate(and_gate, in_quregs, qm) + + in_quregs = {'ctrl': np.array(q)} + with pytest.raises(ValueError, match='Compatible reg.*must exist'): + _ = _construct_op_from_gate(and_gate, in_quregs, qm) + + in_quregs = {'ctrl': np.array(q).reshape(2, 1), 'target': np.array([cirq.q('t')])} + with pytest.raises(ValueError, match='RIGHT register.*shouldn\'t exist in'): + _ = _construct_op_from_gate(and_gate, in_quregs, qm) + + def test_bloq_as_cirq_gate_left_register(): bb = BloqBuilder() q = bb.allocate(1) @@ -248,48 +159,6 @@ def test_bloq_as_cirq_gate_left_register(): cirq.testing.assert_has_diagram(circuit, """_c(0): ───Allocate───X───Free───""") -@frozen -class TestCNOT(Bloq): - @property - def signature(self) -> Signature: - return Signature.build(control=1, target=1) - - def decompose_bloq(self) -> 'CompositeBloq': - return decompose_from_cirq_op(self) - - def as_cirq_op( - self, qubit_manager: cirq.QubitManager, **cirq_quregs: 'CirqQuregT' - ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: - (control,) = cirq_quregs['control'] - (target,) = cirq_quregs['target'] - return cirq.CNOT(control, target), cirq_quregs - - -@frozen -class TestCNOTSymbolic(TestCNOT): - @property - def signature(self) -> Signature: - c, t = sympy.Symbol('c'), sympy.Symbol('t') - return Signature.build(control=c, target=t) - - -def test_bloq_decompose_from_cirq_op(): - tb = TestCNOT() - assert len(tb.signature) == 2 - ctrl, trg = tb.signature - assert ctrl.bitsize == 1 - assert ctrl.side == Side.THRU - assert tb.pretty_name() == 'TestCNOT' - - cirq_quregs = tb.signature.get_cirq_quregs() - circuit, _ = tb.decompose_bloq().to_cirq_circuit(**cirq_quregs) - assert circuit == cirq.Circuit(cirq.CNOT(*cirq_quregs['control'], *cirq_quregs['target'])) - assert tb.t_complexity() == cirq_ft.TComplexity(clifford=1) - - with pytest.raises(NotImplementedError): - TestCNOTSymbolic().decompose_bloq() - - def test_bloq_as_cirq_gate_multi_dimensional_signature(): bloq = SwapWithZero(2, 3, 4) cirq_quregs = bloq.signature.get_cirq_quregs() @@ -361,5 +230,63 @@ def test_bloq_as_cirq_gate_multi_dimensional_signature(): ) +def test_bloq_as_cirq_gate_for_mod_exp(): + # ModExp is a good test because, similar to And gate, it has a RIGHT only register. + # but also has a decomposition specified. + mod_exp = ModExp.make_for_shor(4, 3) + gate = BloqAsCirqGate(mod_exp) + # Use Cirq's infrastructure to construct an operation and corresponding decomposition. + quregs = cirq_ft.infra.get_named_qubits(gate.signature) + op = gate.on_registers(**quregs) + # cirq.decompose_once(op) delegates to underlying Bloq's decomposition specified in + # `bloq.decompose_bloq()` and wraps resulting composite bloq in a Cirq op-tree. Note + # how `BloqAsCirqGate.decompose_with_registers()` automatically takes care of mapping + # newly allocated RIGHT registers in the decomposition to the one's specified by the user + # when constructing the original operation (in this case, register `x`). + circuit = cirq.Circuit(op, cirq.decompose_once(op)) + assert cirq_ft.t_complexity(circuit) == 2 * mod_exp.t_complexity() + cirq.testing.assert_has_diagram( + circuit, + ''' +exponent0: ───ModExp──────────────────────────────────────────────────CtrlModMul─── + │ │ +exponent1: ───exponent───────────────────────────────────CtrlModMul───┼──────────── + │ │ │ +exponent2: ───exponent──────────────────────CtrlModMul───┼────────────┼──────────── + │ │ │ │ +exponent3: ───exponent─────────CtrlModMul───┼────────────┼────────────┼──────────── + │ │ │ │ │ +x0: ──────────x──────────|1>───x────────────x────────────x────────────x──────────── + │ │ │ │ │ │ +x1: ──────────x──────────val───x────────────x────────────x────────────x──────────── +''', + ) + # Alternatively, decompose the Bloq and then convert the composite Bloq to a Cirq circuit. + cbloq = mod_exp.decompose_bloq() + # When converting a composite Bloq to a Cirq circuit, we only need to specify the input + # registers. + decomposed_circuit, out_regs = cbloq.to_cirq_circuit(exponent=quregs['exponent']) + # Whereas when directly applying a cirq gate on qubits to get an operations, we need to + # specify both input and output registers. + circuit = cirq.Circuit(gate.on_registers(**out_regs), decomposed_circuit) + assert cirq_ft.t_complexity(circuit) == 2 * mod_exp.t_complexity() + # Notice the newly allocated qubits _C(0) and _C(1) for output register x. + cirq.testing.assert_has_diagram( + circuit, + ''' +_c(0): ───────x──────────|1>───x────────────x────────────x────────────x──────────── + │ │ │ │ │ │ +_c(1): ───────x──────────val───x────────────x────────────x────────────x──────────── + │ │ │ │ │ +exponent0: ───ModExp───────────┼────────────┼────────────┼────────────CtrlModMul─── + │ │ │ │ +exponent1: ───exponent─────────┼────────────┼────────────CtrlModMul──────────────── + │ │ │ +exponent2: ───exponent─────────┼────────────CtrlModMul───────────────────────────── + │ │ +exponent3: ───exponent─────────CtrlModMul──────────────────────────────────────────''', + ) + + def test_notebook(): execute_notebook('cirq_interop') diff --git a/qualtran/cirq_interop/_cirq_interop.py b/qualtran/cirq_interop/_cirq_interop.py deleted file mode 100644 index c4f5978cc..000000000 --- a/qualtran/cirq_interop/_cirq_interop.py +++ /dev/null @@ -1,606 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functionality for the `Bloq.as_cirq_op(...)` protocol""" -import itertools -from functools import cached_property -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union - -import cirq -import cirq_ft -import networkx as nx -import numpy as np -import quimb.tensor as qtn -from attrs import field, frozen -from numpy.typing import NDArray - -from qualtran import ( - Bloq, - BloqBuilder, - BloqInstance, - CompositeBloq, - Connection, - DanglingT, - LeftDangle, - Register, - RightDangle, - Side, - Signature, - Soquet, - SoquetT, -) -from qualtran._infra.composite_bloq import _binst_to_cxns - -CirqQuregT = NDArray[cirq.Qid] -CirqQuregInT = Union[NDArray[cirq.Qid], Sequence[cirq.Qid]] - - -def signature_from_cirq_registers(registers: Iterable[cirq_ft.Register]) -> 'Signature': - return Signature( - [Register(reg.name, bitsize=reg.bitsize, shape=reg.shape) for reg in registers] - ) - - -@frozen -class CirqGateAsBloq(Bloq): - """A Bloq wrapper around a `cirq.Gate`. - - This bloq has one thru-register named "qubits", which is a 1D array of soquets - representing individual qubits. - """ - - gate: cirq.Gate - - def pretty_name(self) -> str: - return f'cirq.{self.gate}' - - def short_name(self) -> str: - g = min(self.gate.__class__.__name__, str(self.gate), key=len) - return f'cirq.{g}' - - @cached_property - def signature(self) -> 'Signature': - return signature_from_cirq_registers(self.cirq_registers) - - @cached_property - def cirq_registers(self) -> cirq_ft.Registers: - if isinstance(self.gate, cirq_ft.GateWithRegisters): - return self.gate.registers - else: - return cirq_ft.Registers( - [cirq_ft.Register('qubits', shape=(cirq.num_qubits(self.gate),), bitsize=1)] - ) - - def decompose_bloq(self) -> 'CompositeBloq': - quregs = self.signature.get_cirq_quregs() - qubit_manager = cirq.ops.SimpleQubitManager() - cirq_op, quregs = self.as_cirq_op(qubit_manager, **quregs) - context = cirq.DecompositionContext(qubit_manager=qubit_manager) - decomposed_optree = cirq.decompose_once(cirq_op, context=context, default=None) - if decomposed_optree is None: - raise NotImplementedError(f"{self} does not support decomposition.") - return cirq_optree_to_cbloq(decomposed_optree, signature=self.signature, cirq_quregs=quregs) - - def add_my_tensors( - self, - tn: qtn.TensorNetwork, - tag: Any, - *, - incoming: Dict[str, 'SoquetT'], - outgoing: Dict[str, 'SoquetT'], - ): - new_shape = [ - *itertools.chain.from_iterable( - (2**reg.bitsize,) * int(np.prod(reg.shape)) - for reg in [*self.signature.rights(), *self.signature.lefts()] - ) - ] - unitary = cirq.unitary(self.gate).reshape(new_shape) - incoming_list = [ - *itertools.chain.from_iterable( - [np.array(incoming[reg.name]).flatten() for reg in self.signature.lefts()] - ) - ] - outgoing_list = [ - *itertools.chain.from_iterable( - [np.array(outgoing[reg.name]).flatten() for reg in self.signature.rights()] - ) - ] - - tn.add( - qtn.Tensor( - data=unitary, inds=outgoing_list + incoming_list, tags=[self.short_name(), tag] - ) - ) - - def as_cirq_op( - self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' - ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: - merged_qubits = np.concatenate( - [cirq_quregs[reg.name].flatten() for reg in self.signature.lefts()] - ) - assert len(merged_qubits) == cirq.num_qubits(self.gate) - return self.gate.on(*merged_qubits), cirq_quregs - - def t_complexity(self) -> 'cirq_ft.TComplexity': - return cirq_ft.t_complexity(self.gate) - - def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': - from qualtran.drawing import directional_text_box - - wire_symbols = cirq.circuit_diagram_info(self.gate).wire_symbols - begin = 0 - symbol: str = soq.pretty() - for reg in self.signature: - finish = begin + int(np.prod(reg.shape)) - if reg == soq.reg: - symbol = np.array(wire_symbols[begin:finish]).reshape(reg.shape)[soq.idx] - begin = finish - return directional_text_box(text=symbol, side=soq.reg.side) - - -@frozen -class _QReg: - """Used as a container for qubits that form a `cirq_ft.Register` of a given bitsize. - - Each instance of `_QReg` would correspond to a `Soquet` in Bloqs and represents and opaque collection - of qubits that together form a quantum register. - """ - - qubits: Tuple[cirq.Qid, ...] = field( - converter=lambda v: (v,) if isinstance(v, cirq.Qid) else tuple(v) - ) - - -def _ensure_in_reg_exists( - bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QReg, Soquet] -) -> None: - """Takes care of splits and joins to make sure `qreg_to_qvar[in_reg]` exists.""" - - if in_reg in qreg_to_qvar: - # This is the easy case when no split / joins are needed. - return - - # a. Split all registers containing at-least one qubit corresponding to `in_reg`. - in_reg_qubits = set(in_reg.qubits) - - new_qreg_to_qvar: Dict[_QReg, Soquet] = {} - for qreg, soq in qreg_to_qvar.items(): - if len(qreg.qubits) > 1 and any(q in qreg.qubits for q in in_reg_qubits): - new_qreg_to_qvar |= {_QReg(q): s for q, s in zip(qreg.qubits, bb.split(soq=soq))} - else: - new_qreg_to_qvar[qreg] = soq - qreg_to_qvar.clear() - - # b. Join all 1-bit registers, corresponding to individual qubits, that make up `in_reg`. - soqs_to_join = [] - for qreg, soq in new_qreg_to_qvar.items(): - if len(in_reg_qubits) > 1 and qreg.qubits and qreg.qubits[0] in in_reg_qubits: - assert len(qreg.qubits) == 1, "Individual qubits should have been split by now." - soqs_to_join.append(soq) - else: - qreg_to_qvar[qreg] = soq - if soqs_to_join: - qreg_to_qvar[in_reg] = bb.join(np.array(soqs_to_join)) - - -def _gather_input_soqs( - bb: BloqBuilder, op_quregs: Dict[str, NDArray[_QReg]], qreg_to_qvar: Dict[_QReg, Soquet] -) -> Dict[str, NDArray[Soquet]]: - qvars_in: Dict[str, NDArray[Soquet]] = {} - for reg_name, quregs in op_quregs.items(): - flat_soqs: List[Soquet] = [] - for qureg in quregs.flatten(): - _ensure_in_reg_exists(bb, qureg, qreg_to_qvar) - flat_soqs.append(qreg_to_qvar[qureg]) - qvars_in[reg_name] = np.array(flat_soqs).reshape(quregs.shape) - return qvars_in - - -def cirq_optree_to_cbloq( - optree: cirq.OP_TREE, - *, - signature: Optional[Signature] = None, - cirq_quregs: Optional[Dict[str, 'NDArray[cirq.Qid]']] = None, -) -> CompositeBloq: - """Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`. - - Each `cirq.Operation` will be wrapped into a `CirqGateAsBloq` wrapper. - The signature of the resultant CompositeBloq is `signature`, if provided. Otherwise, use - one thru-register named "qubits" of shape `(n_qubits,)`. - - For multi-dimensional registers and registers with bitsize>1, this function automatically - splits the input soquets and joins the output soquets to ensure compatibility with the - flat-list-of-qubits expected by Cirq. - - When specifying a signature, users must also specify the `cirq_quregs` argument, which is a - mapping of cirq qubits used in the OP-TREE corresponding to the `signature`. If `signature` - has registers with entry - - - `Register('x', bitsize=2, shape=(3, 4))` and - - `Register('y', bitsize=1, shape=(10, 20))` - - then `cirq_quregs` should have one entry corresponding to each register as follows: - - - key='x'; value=`np.array(cirq_qubits_used_for_x, shape=(3, 4, 2))` and - - key='y'; value=`np.array(cirq_qubits_used_for_y, shape=(10, 20, 1))`. - """ - circuit = cirq.Circuit(optree) - if signature is None: - if cirq_quregs is not None: - raise ValueError("`cirq_quregs` requires specifying `signature`.") - all_qubits = sorted(circuit.all_qubits()) - signature = Signature([Register('qubits', 1, shape=(len(all_qubits),))]) - cirq_quregs = {'qubits': np.array(all_qubits).reshape(len(all_qubits), 1)} - elif cirq_quregs is None: - raise ValueError("`signature` requires specifying `cirq_quregs`.") - - cirq_quregs = {k: np.apply_along_axis(_QReg, -1, v) for k, v in cirq_quregs.items()} - - bb, initial_soqs = BloqBuilder.from_signature(signature, add_registers_allowed=False) - - # 1. Compute qreg_to_qvar. - qreg_to_qvar: Dict[_QReg, Soquet] = {} - for reg in signature.lefts(): - if reg.name not in cirq_quregs: - raise ValueError(f"Register {reg.name} from signature must be present in cirq_quregs.") - soqs = initial_soqs[reg.name] - if isinstance(soqs, Soquet): - soqs = np.array(soqs) - if cirq_quregs[reg.name].shape != soqs.shape: - raise ValueError( - f"Shape {cirq_quregs[reg.name].shape} of cirq register " - f"{reg.name} should be {soqs.shape}." - ) - qreg_to_qvar |= zip(cirq_quregs[reg.name].flatten(), soqs.flatten()) - - # 2. Add allocated qubits to qreg_to_qvar - all_qubits = set(q for qreg in qreg_to_qvar for q in qreg.qubits) - allocated_qubits = _QReg(sorted(circuit.all_qubits() - all_qubits)) - if allocated_qubits.qubits: - qreg_to_qvar |= {allocated_qubits: bb.allocate(len(allocated_qubits.qubits))} - - # 3. Add each operation to the bloq. - for op in circuit.all_operations(): - if op.gate is None: - raise ValueError(f"Only gate operations are supported, not {op}.") - - bloq = CirqGateAsBloq(op.gate) - # 3.1 Find input soquets. - op_quregs: Dict[str, NDArray[_QReg]] = { - k: np.apply_along_axis(_QReg, -1, v) - for k, v in cirq_ft.infra.split_qubits(bloq.cirq_registers, op.qubits).items() - } - qvars_in = _gather_input_soqs(bb, op_quregs, qreg_to_qvar) - # 3.2 Add Bloq - qvars_out = bb.add_d(bloq, **qvars_in) - - # 3.3 Update qubit registers to soquets mapping using output soquets. - for reg_name, quregs in op_quregs.items(): - qreg_to_qvar |= zip(quregs.flatten(), np.array(qvars_out[reg_name]).flatten()) - - # 4. Deallocated newly allocated qubits. - if allocated_qubits.qubits: - _ensure_in_reg_exists(bb, allocated_qubits, qreg_to_qvar) - bb.free(qreg_to_qvar.pop(allocated_qubits)) - - # 5. Combine Soquets to match the right signature. - final_soqs = _gather_input_soqs(bb, cirq_quregs, qreg_to_qvar) - return bb.finalize(**final_soqs) - - -def _get_in_cirq_quregs( - binst: BloqInstance, reg: Register, soq_assign: Dict[Soquet, 'NDArray[cirq.Qid]'] -) -> 'NDArray[cirq.Qid]': - """Pluck out the correct values from `soq_assign` for `reg` on `binst`.""" - full_shape = reg.shape + (reg.bitsize,) - arg = np.empty(full_shape, dtype=object) - - for idx in reg.all_idxs(): - soq = Soquet(binst, reg, idx=idx) - arg[idx] = soq_assign[soq] - - return arg - - -def _update_assign_from_cirq_quregs( - regs: Iterable[Register], - binst: BloqInstance, - cirq_quregs: Dict[str, CirqQuregInT], - soq_assign: Dict[Soquet, CirqQuregT], -) -> None: - """Update `soq_assign` using `cirq_quregs`. - - This helper function is responsible for error checking. We use `regs` to make sure all the - keys are present in the vals dictionary. We check the quregs shapes. - """ - unprocessed_reg_names = set(cirq_quregs.keys()) - for reg in regs: - try: - arr = cirq_quregs[reg.name] - except KeyError: - raise ValueError(f"{binst} requires an input register named {reg.name}") - unprocessed_reg_names.remove(reg.name) - - arr = np.asarray(arr) - full_shape = reg.shape + (reg.bitsize,) - if arr.shape != full_shape: - raise ValueError( - f"Incorrect shape {arr.shape} received for {binst}.{reg.name}. Expected {full_shape}." - ) - - for idx in reg.all_idxs(): - soq = Soquet(binst, reg, idx=idx) - soq_assign[soq] = arr[idx] - - if unprocessed_reg_names: - raise ValueError(f"{binst} had extra cirq_quregs: {unprocessed_reg_names}") - - -def _binst_as_cirq_op( - binst: BloqInstance, - pred_cxns: Iterable[Connection], - soq_assign: Dict[Soquet, NDArray[cirq.Qid]], - qubit_manager: cirq.QubitManager, -) -> Union[cirq.Operation, None]: - """Helper function used in `_cbloq_to_cirq_circuit`. - - Args: - binst: The current BloqInstance on which we wish to call `as_cirq_op`. - pred_cxns: Predecessor connections for the bloq instance. - soq_assign: The current assignment from soquets to cirq qubit arrays. This mapping - is mutated by this function. - qubit_manager: A `cirq.QubitManager` for allocating `cirq.Qid`s. - - Returns: - The operation resulting from `binst.bloq.as_cirq_op(...)`. - """ - # Track inter-Bloq name changes - for cxn in pred_cxns: - soq_assign[cxn.right] = soq_assign[cxn.left] - del soq_assign[cxn.left] - - def _in_vals(reg: Register) -> CirqQuregT: - # close over `binst` and `soq_assign`. - return _get_in_cirq_quregs(binst, reg, soq_assign=soq_assign) - - bloq = binst.bloq - cirq_quregs = {reg.name: _in_vals(reg) for reg in bloq.signature.lefts()} - - op, out_quregs = bloq.as_cirq_op(qubit_manager=qubit_manager, **cirq_quregs) - _update_assign_from_cirq_quregs(bloq.signature.rights(), binst, out_quregs, soq_assign) - return op - - -def decompose_from_cirq_op(bloq: 'Bloq') -> 'CompositeBloq': - """Returns a CompositeBloq constructed using Cirq operations obtained via `bloq.as_cirq_op`. - - This method first checks whether `bloq.signature` is parameterized. If yes, it raises a - NotImplementedError. If not, it uses `cirq_optree_to_cbloq` to wrap the operations obtained - from `bloq.as_cirq_op` into a `CompositeBloq` which has the same signature as `bloq` and returns - the corresponding `CompositeBloq`. - """ - - if any( - cirq.is_parameterized(reg.bitsize) or cirq.is_parameterized(reg.side) - for reg in bloq.signature - ): - raise NotImplementedError(f"{bloq} does not support decomposition.") - - cirq_quregs = bloq.signature.get_cirq_quregs() - cirq_op, cirq_quregs = bloq.as_cirq_op(cirq.ops.SimpleQubitManager(), **cirq_quregs) - if cirq_op is None or ( - isinstance(cirq_op, cirq.Operation) and isinstance(cirq_op.gate, BloqAsCirqGate) - ): - raise NotImplementedError(f"{bloq} does not support decomposition.") - - return cirq_optree_to_cbloq(cirq_op, signature=bloq.signature, cirq_quregs=cirq_quregs) - - -def _cbloq_to_cirq_circuit( - signature: Signature, - cirq_quregs: Dict[str, 'CirqQuregInT'], - binst_graph: nx.DiGraph, - qubit_manager: cirq.QubitManager, -) -> Tuple[cirq.FrozenCircuit, Dict[str, 'CirqQuregT']]: - """Propagate `as_cirq_op` calls through a composite bloq's contents to export a `cirq.Circuit`. - - Args: - signature: The cbloq's signature for validating inputs and outputs. - cirq_quregs: Mapping from left register name to Cirq qubit arrays. - binst_graph: The cbloq's binst graph. This is read only. - qubit_manager: A `cirq.QubitManager` to allocate new qubits. - - Returns: - circuit: The cirq.FrozenCircuit version of this composite bloq. - cirq_quregs: The output mapping from right register names to Cirq qubit arrays. - """ - soq_assign: Dict[Soquet, CirqQuregT] = {} - _update_assign_from_cirq_quregs(signature.lefts(), LeftDangle, cirq_quregs, soq_assign) - moments: List[cirq.Moment] = [] - for binsts in nx.topological_generations(binst_graph): - moment: List[cirq.Operation] = [] - - for binst in binsts: - if isinstance(binst, DanglingT): - continue - - pred_cxns, succ_cxns = _binst_to_cxns(binst, binst_graph=binst_graph) - op = _binst_as_cirq_op(binst, pred_cxns, soq_assign, qubit_manager=qubit_manager) - if op is not None: - moment.append(op) - if moment: - moments.append(cirq.Moment(moment)) - - # Track bloq-to-dangle name changes - if len(list(signature.rights())) > 0: - final_preds, _ = _binst_to_cxns(RightDangle, binst_graph=binst_graph) - for cxn in final_preds: - soq_assign[cxn.right] = soq_assign[cxn.left] - - # Formulate output with expected API - def _f_quregs(reg: Register): - return _get_in_cirq_quregs(RightDangle, reg, soq_assign) - - out_quregs = {reg.name: _f_quregs(reg) for reg in signature.rights()} - - return cirq.FrozenCircuit(moments), out_quregs - - -class BloqAsCirqGate(cirq_ft.GateWithRegisters): - """A shim for using bloqs in a Cirq circuit. - - Args: - bloq: The bloq to wrap. - reg_to_wires: an optional callable to produce a list of wire symbols for each register - to match Cirq diagrams. - """ - - def __init__(self, bloq: Bloq, reg_to_wires: Optional[Callable[[Register], List[str]]] = None): - self._bloq = bloq - self._legacy_regs, self._compat_name_map = self._init_legacy_regs(bloq) - self._reg_to_wires = reg_to_wires - - @property - def bloq(self) -> Bloq: - """The bloq we're wrapping.""" - return self._bloq - - @property - def registers(self) -> cirq_ft.Registers: - """`cirq_ft.GateWithRegisters` registers.""" - return self._legacy_regs - - @staticmethod - def _init_legacy_regs(bloq: Bloq) -> Tuple[cirq_ft.Registers, Mapping[str, Register]]: - """Initialize legacy registers. - - We flatten multidimensional registers and annotate non-thru registers with - modifications to their string name. - - Returns: - legacy_registers: The flattened, cirq GateWithRegisters-style registers - compat_name_map: A mapping from the compatability-shim string names of the legacy - registers back to the original (register, idx) pair. - """ - legacy_regs: List[cirq_ft.Register] = [] - side_suffixes = {Side.LEFT: '_l', Side.RIGHT: '_r', Side.THRU: ''} - compat_name_map = {} - for reg in bloq.signature: - compat_name = f'{reg.name}{side_suffixes[reg.side]}' - compat_name_map[compat_name] = reg - legacy_regs.append( - cirq_ft.Register(name=compat_name, shape=reg.shape, bitsize=reg.bitsize) - ) - return cirq_ft.Registers(legacy_regs), compat_name_map - - @classmethod - def bloq_on( - cls, bloq: Bloq, cirq_quregs: Dict[str, 'CirqQuregT'], qubit_manager: cirq.QubitManager - ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: - """Shim `bloq` into a cirq gate and call it on `cirq_quregs`. - - This is used as a default implementation for `Bloq.as_cirq_op` if a native - cirq conversion is not specified. - - Args: - bloq: The bloq to be wrapped with `BloqAsCirqGate` - cirq_quregs: The cirq qubit registers on which we call the gate. - qubit_manager: A `cirq.QubitManager` to allocate new qubits. - - Returns: - op: A cirq operation whose gate is the `BloqAsCirqGate`-wrapped version of `bloq`. - cirq_quregs: The output cirq qubit registers. - """ - bloq_quregs: Dict[str, 'CirqQuregT'] = {} - out_quregs: Dict[str, 'CirqQuregT'] = {} - for reg in bloq.signature: - if reg.side is Side.THRU: - bloq_quregs[reg.name] = cirq_quregs[reg.name] - out_quregs[reg.name] = cirq_quregs[reg.name] - elif reg.side is Side.LEFT: - bloq_quregs[f'{reg.name}_l'] = cirq_quregs[reg.name] - qubit_manager.qfree(cirq_quregs[reg.name].reshape(-1)) - del cirq_quregs[reg.name] - elif reg.side is Side.RIGHT: - new_qubits = qubit_manager.qalloc(reg.total_bits()) - full_shape = reg.shape + (reg.bitsize,) - out_quregs[reg.name] = np.array(new_qubits).reshape(full_shape) - bloq_quregs[f'{reg.name}_r'] = out_quregs[reg.name] - return BloqAsCirqGate(bloq=bloq).on_registers(**bloq_quregs), out_quregs - - def decompose_from_registers( - self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] - ) -> cirq.OP_TREE: - """Implementation of the GatesWithRegisters decompose method. - - This delegates to `self.bloq.decompose_bloq()` and converts the result to a cirq circuit. - - Args: - context: `cirq.DecompositionContext` stores options for decomposing gates (eg: - cirq.QubitManager). - **quregs: Sequences of cirq qubits as expected for the legacy register shims - of the bloq's registers. - - Returns: - A cirq circuit containing the cirq-exported version of the bloq decomposition. - """ - cbloq = self._bloq.decompose_bloq() - - cirq_quregs: Dict[str, CirqQuregT] = {} - for compat_name, qubits in quregs.items(): - reg = self._compat_name_map[compat_name] - cirq_quregs[reg.name] = np.asarray(qubits) - - circuit, _ = cbloq.to_cirq_circuit(qubit_manager=context.qubit_manager, **cirq_quregs) - return circuit - - def _t_complexity_(self): - """Delegate to the bloq's t complexity.""" - return self._bloq.t_complexity() - - def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - """Draw cirq diagrams. - - By default, we label each qubit with its register name. If `reg_to_wires` was provided - in the class constructor, we use that to get a list of wire symbols for each register. - """ - - if self._reg_to_wires is not None: - reg_to_wires = self._reg_to_wires - else: - reg_to_wires = lambda reg: [reg.name] * reg.total_bits() - - wire_symbols = [] - for reg in self._bloq.signature: - symbs = reg_to_wires(reg) - assert len(symbs) == reg.total_bits() - wire_symbols.extend(symbs) - if self._reg_to_wires is None: - wire_symbols[0] = self._bloq.pretty_name() - return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) - - def __eq__(self, other): - if not isinstance(other, BloqAsCirqGate): - return False - return self.bloq == other.bloq - - def __hash__(self): - return hash(self.bloq) - - def __str__(self) -> str: - return f'bloq.{self.bloq}' - - def __repr__(self) -> str: - return f'BloqAsCirqGate({self.bloq})' diff --git a/qualtran/cirq_interop/_cirq_to_bloq.py b/qualtran/cirq_interop/_cirq_to_bloq.py new file mode 100644 index 000000000..956e0088e --- /dev/null +++ b/qualtran/cirq_interop/_cirq_to_bloq.py @@ -0,0 +1,372 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cirq gates/circuits to Qualtran Bloqs conversion.""" +import itertools +from collections import defaultdict +from functools import cached_property +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union + +import cirq +import cirq_ft +import numpy as np +import quimb.tensor as qtn +from attrs import field, frozen +from numpy.typing import NDArray + +from qualtran import Bloq, BloqBuilder, CompositeBloq, Register, Side, Signature, Soquet, SoquetT + +if TYPE_CHECKING: + from qualtran.drawing import WireSymbol + +CirqQuregT = NDArray[cirq.Qid] +CirqQuregInT = Union[NDArray[cirq.Qid], Sequence[cirq.Qid]] + + +@frozen +class CirqGateAsBloq(Bloq): + """A Bloq wrapper around a `cirq.Gate`, preserving signature if gate is a `GateWithRegisters`.""" + + gate: cirq.Gate + + def pretty_name(self) -> str: + return f'cirq.{self.gate}' + + def short_name(self) -> str: + g = min(self.gate.__class__.__name__, str(self.gate), key=len) + return f'cirq.{g}' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register(reg.name, bitsize=reg.bitsize, shape=reg.shape, side=Side(reg.side.value)) + for reg in self.cirq_registers + ] + ) + + @cached_property + def cirq_registers(self) -> cirq_ft.Signature: + if isinstance(self.gate, cirq_ft.GateWithRegisters): + return self.gate.signature + else: + return cirq_ft.Signature( + [cirq_ft.Register('qubits', shape=(cirq.num_qubits(self.gate),), bitsize=1)] + ) + + def decompose_bloq(self) -> 'CompositeBloq': + in_quregs = self.signature.get_cirq_quregs() + qubit_manager = cirq.ops.SimpleQubitManager() + cirq_op, out_quregs = self.as_cirq_op(qubit_manager, **in_quregs) + context = cirq.DecompositionContext(qubit_manager=qubit_manager) + decomposed_optree = cirq.decompose_once(cirq_op, context=context, default=None) + if decomposed_optree is None: + raise NotImplementedError(f"{self} does not support decomposition.") + return cirq_optree_to_cbloq( + decomposed_optree, signature=self.signature, in_quregs=in_quregs, out_quregs=out_quregs + ) + + def add_my_tensors( + self, + tn: qtn.TensorNetwork, + tag: Any, + *, + incoming: Dict[str, 'SoquetT'], + outgoing: Dict[str, 'SoquetT'], + ): + if not cirq.has_unitary(self.gate): + raise NotImplementedError( + f"CirqGateAsBloq.add_my_tensors is currently supported only for unitary gates. " + f"Found {self.gate}." + ) + unitary_shape = [] + reg_to_idx = defaultdict(list) + for reg in self.cirq_registers: + start = len(unitary_shape) + for i in range(int(np.prod(reg.shape))): + reg_to_idx[reg.name].append(start + i) + unitary_shape.append(2**reg.bitsize) + + unitary_shape = (*unitary_shape, *unitary_shape) + unitary = cirq.unitary(self.gate).reshape(unitary_shape) + idx: List[Union[int, slice]] = [slice(x) for x in unitary_shape] + n = len(unitary_shape) // 2 + for reg in self.signature: + if reg.side == Side.LEFT: + for i in reg_to_idx[reg.name]: + # LEFT register ends, extract right subspace that's equivalent to 0. + idx[i] = 0 + if reg.side == Side.RIGHT: + for i in reg_to_idx[reg.name]: + # Right register begins, extract the left subspace that's equivalent to 0. + idx[i + n] = 0 + unitary = unitary[tuple(idx)] + new_shape = tuple( + [ + *itertools.chain.from_iterable( + (2**reg.bitsize,) * int(np.prod(reg.shape)) + for reg in [*self.signature.rights(), *self.signature.lefts()] + ) + ] + ) + assert unitary.shape == new_shape + incoming_list = [ + *itertools.chain.from_iterable( + [np.array(incoming[reg.name]).flatten() for reg in self.signature.lefts()] + ) + ] + outgoing_list = [ + *itertools.chain.from_iterable( + [np.array(outgoing[reg.name]).flatten() for reg in self.signature.rights()] + ) + ] + + tn.add( + qtn.Tensor( + data=unitary, inds=outgoing_list + incoming_list, tags=[self.short_name(), tag] + ) + ) + + def as_cirq_op( + self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' + ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: + from qualtran.cirq_interop._bloq_to_cirq import _construct_op_from_gate + + if not isinstance(self.gate, cirq_ft.GateWithRegisters): + return self.gate.on(*cirq_quregs['qubits'].flatten()), cirq_quregs + return _construct_op_from_gate( + self.gate, + in_quregs={k: np.array(v) for k, v in cirq_quregs.items()}, + qubit_manager=qubit_manager, + ) + + def t_complexity(self) -> 'cirq_ft.TComplexity': + return cirq_ft.t_complexity(self.gate) + + def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': + from qualtran.drawing import directional_text_box + + wire_symbols = cirq.circuit_diagram_info(self.gate).wire_symbols + begin = 0 + symbol: str = soq.pretty() + for reg in self.signature: + finish = begin + int(np.prod(reg.shape)) + if reg == soq.reg: + symbol = np.array(wire_symbols[begin:finish]).reshape(reg.shape)[soq.idx] + begin = finish + return directional_text_box(text=symbol, side=soq.reg.side) + + +@frozen +class _QReg: + """Used as a container for qubits that form a `cirq_ft.Register` of a given bitsize. + + Each instance of `_QReg` would correspond to a `Soquet` in Bloqs and represents an opaque collection + of qubits that together form a quantum register. + """ + + qubits: Tuple[cirq.Qid, ...] = field( + converter=lambda v: (v,) if isinstance(v, cirq.Qid) else tuple(v) + ) + + +def _ensure_in_reg_exists( + bb: BloqBuilder, in_reg: _QReg, qreg_to_qvar: Dict[_QReg, Soquet] +) -> None: + """Takes care of qubit allocations, split and joins to ensure `qreg_to_qvar[in_reg]` exists.""" + all_mapped_qubits = {q for qreg in qreg_to_qvar for q in qreg.qubits} + qubits_to_allocate: List[cirq.Qid] = [q for q in in_reg.qubits if q not in all_mapped_qubits] + if qubits_to_allocate: + qreg_to_qvar[_QReg(qubits_to_allocate)] = bb.allocate(len(qubits_to_allocate)) + + if in_reg in qreg_to_qvar: + # This is the easy case when no split / joins are needed. + return + + # a. Split all registers containing at-least one qubit corresponding to `in_reg`. + in_reg_qubits = set(in_reg.qubits) + + new_qreg_to_qvar: Dict[_QReg, Soquet] = {} + for qreg, soq in qreg_to_qvar.items(): + if len(qreg.qubits) > 1 and any(q in qreg.qubits for q in in_reg_qubits): + new_qreg_to_qvar |= {_QReg(q): s for q, s in zip(qreg.qubits, bb.split(soq=soq))} + else: + new_qreg_to_qvar[qreg] = soq + qreg_to_qvar.clear() + + # b. Join all 1-bit registers, corresponding to individual qubits, that make up `in_reg`. + soqs_to_join = [] + for qreg, soq in new_qreg_to_qvar.items(): + if len(in_reg_qubits) > 1 and qreg.qubits and qreg.qubits[0] in in_reg_qubits: + assert len(qreg.qubits) == 1, "Individual qubits should have been split by now." + soqs_to_join.append(soq) + else: + qreg_to_qvar[qreg] = soq + if soqs_to_join: + qreg_to_qvar[in_reg] = bb.join(np.array(soqs_to_join)) + + +def _gather_input_soqs( + bb: BloqBuilder, op_quregs: Dict[str, NDArray[_QReg]], qreg_to_qvar: Dict[_QReg, Soquet] +) -> Dict[str, NDArray[Soquet]]: + qvars_in: Dict[str, NDArray[Soquet]] = {} + for reg_name, quregs in op_quregs.items(): + flat_soqs: List[Soquet] = [] + for qureg in quregs.flatten(): + _ensure_in_reg_exists(bb, qureg, qreg_to_qvar) + flat_soqs.append(qreg_to_qvar[qureg]) + qvars_in[reg_name] = np.array(flat_soqs).reshape(quregs.shape) + return qvars_in + + +def cirq_optree_to_cbloq( + optree: cirq.OP_TREE, + *, + signature: Optional[Signature] = None, + in_quregs: Optional[Dict[str, 'CirqQuregT']] = None, + out_quregs: Optional[Dict[str, 'CirqQuregT']] = None, +) -> CompositeBloq: + """Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`. + + Each `cirq.Operation` will be wrapped into a `CirqGateAsBloq` wrapper. + The signature of the resultant CompositeBloq is `signature`, if provided. Otherwise, use + one thru-register named "qubits" of shape `(n_qubits,)`. + + For multi-dimensional registers and registers with bitsize>1, this function automatically + splits the input soquets and joins the output soquets to ensure compatibility with the + flat-list-of-qubits expected by Cirq. + + When specifying a signature, users must also specify the `in_quregs` & `out_quregs` arguments, + which are mappings of cirq qubits used in the OP-TREE corresponding to the `LEFT` & `RIGHT` + registers in `signature`. If `signature` has registers with entry + + - `Register('x', bitsize=2, shape=(3, 4), side=Side.THRU)` + - `Register('y', bitsize=1, shape=(10, 20), side=Side.LEFT)` + - `Register('z', bitsize=1, shape=(10, 20), side=Side.RIGHT)` + + then `in_quregs` should have one entry corresponding to registers `x` and `y` as follows: + + - key='x'; value=`np.array(cirq_qubits_used_for_x, shape=(3, 4, 2))` and + - key='y'; value=`np.array(cirq_qubits_used_for_y, shape=(10, 20, 1))`. + and `out_quregs` should have one entry corresponding to registers `x` and `z` as follows: + + - key='x'; value=`np.array(cirq_qubits_used_for_x, shape=(3, 4, 2))` and + - key='z'; value=`np.array(cirq_qubits_used_for_z, shape=(10, 20, 1))`. + + Any qubit in `optree` which is not part of `in_quregs` and `out_quregs` is considered to be + allocated & deallocated inside the CompositeBloq and does not show up in it's signature. + """ + circuit = cirq.Circuit(optree) + if signature is None: + if in_quregs is not None or out_quregs is not None: + raise ValueError("`in_quregs` / `out_quregs` requires specifying `signature`.") + all_qubits = sorted(circuit.all_qubits()) + signature = Signature([Register('qubits', 1, shape=(len(all_qubits),))]) + in_quregs = out_quregs = {'qubits': np.array(all_qubits).reshape(len(all_qubits), 1)} + elif in_quregs is None or out_quregs is None: + raise ValueError("`signature` requires specifying both `in_quregs` and `out_quregs`.") + + in_quregs = {k: np.apply_along_axis(_QReg, -1, v) for k, v in in_quregs.items()} + out_quregs = {k: np.apply_along_axis(_QReg, -1, v) for k, v in out_quregs.items()} + + bb, initial_soqs = BloqBuilder.from_signature(signature, add_registers_allowed=False) + + # 1. Compute qreg_to_qvar for input qubits in the LEFT signature. + qreg_to_qvar: Dict[_QReg, Soquet] = {} + for reg in signature.lefts(): + if reg.name not in in_quregs: + raise ValueError(f"Register {reg.name} from signature must be present in in_quregs.") + soqs = initial_soqs[reg.name] + if isinstance(soqs, Soquet): + soqs = np.array(soqs) + if in_quregs[reg.name].shape != soqs.shape: + raise ValueError( + f"Shape {in_quregs[reg.name].shape} of cirq register " + f"{reg.name} should be {soqs.shape}." + ) + qreg_to_qvar |= zip(in_quregs[reg.name].flatten(), soqs.flatten()) + + # 2. Add each operation to the composite Bloq. + for op in circuit.all_operations(): + if op.gate is None: + raise ValueError(f"Only gate operations are supported, not {op}.") + + bloq = CirqGateAsBloq(op.gate) + # 3.1 Find input / output registers. + all_op_quregs: Dict[str, NDArray[_QReg]] = { + k: np.apply_along_axis(_QReg, -1, v) + for k, v in cirq_ft.infra.split_qubits(bloq.cirq_registers, op.qubits).items() + } + in_op_quregs: Dict[str, NDArray[_QReg]] = { + reg.name: all_op_quregs[reg.name] for reg in bloq.signature.lefts() + } + # 3.2 Find input Soquets, by potentially allocating new Bloq registers corresponding to + # input Cirq `in_quregs` and updating the `qreg_to_qvar` mapping. + qvars_in = _gather_input_soqs(bb, in_op_quregs, qreg_to_qvar) + + # 3.3 Add Bloq to the `CompositeBloq` compute graph and get corresponding output Soquets. + qvars_out = bb.add_d(bloq, **qvars_in) + + # 3.4 Update `qreg_to_qvar` mapping using output soquets `qvars_out`. + for reg in bloq.signature: + # all_op_quregs should exist for both LEFT & RIGHT registers. + assert reg.name in all_op_quregs + quregs = all_op_quregs[reg.name] + if reg.side == Side.LEFT: + # This register got de-allocated, update the `qreg_to_qvar` mapping. + for q in quregs.flatten(): + _ = qreg_to_qvar.pop(q) + else: + assert quregs.shape == np.array(qvars_out[reg.name]).shape + qreg_to_qvar |= zip(quregs.flatten(), np.array(qvars_out[reg.name]).flatten()) + + # 4. Combine Soquets to match the right signature. + final_soqs_dict = _gather_input_soqs( + bb, {reg.name: out_quregs[reg.name] for reg in signature.rights()}, qreg_to_qvar + ) + final_soqs_set = set(soq for soqs in final_soqs_dict.values() for soq in soqs.flatten()) + # 5. Free all dangling Soquets which are not part of the final soquets set. + for qvar in qreg_to_qvar.values(): + if qvar not in final_soqs_set: + bb.free(qvar) + return bb.finalize(**final_soqs_dict) + + +def decompose_from_cirq_op(bloq: 'Bloq') -> 'CompositeBloq': + """Returns a CompositeBloq constructed using Cirq operations obtained via `bloq.as_cirq_op`. + + This method first checks whether `bloq.signature` is parameterized. If yes, it raises a + NotImplementedError. If not, it uses `cirq_optree_to_cbloq` to wrap the operations obtained + from `bloq.as_cirq_op` into a `CompositeBloq` which has the same signature as `bloq` and returns + the corresponding `CompositeBloq`. + """ + + if any( + cirq.is_parameterized(reg.bitsize) or cirq.is_parameterized(reg.side) + for reg in bloq.signature + ): + raise NotImplementedError(f"{bloq} does not support decomposition.") + + in_quregs = bloq.signature.get_cirq_quregs() + cirq_op, out_quregs = bloq.as_cirq_op(cirq.ops.SimpleQubitManager(), **in_quregs) + from qualtran.cirq_interop._bloq_to_cirq import BloqAsCirqGate + + if cirq_op is None or ( + isinstance(cirq_op, cirq.Operation) and isinstance(cirq_op.gate, BloqAsCirqGate) + ): + raise NotImplementedError(f"{bloq} does not support decomposition.") + + return cirq_optree_to_cbloq( + cirq_op, signature=bloq.signature, in_quregs=in_quregs, out_quregs=out_quregs + ) diff --git a/qualtran/cirq_interop/_cirq_to_bloq_test.py b/qualtran/cirq_interop/_cirq_to_bloq_test.py new file mode 100644 index 000000000..52c1eae58 --- /dev/null +++ b/qualtran/cirq_interop/_cirq_to_bloq_test.py @@ -0,0 +1,203 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Tuple + +import attr +import cirq +import cirq_ft +import numpy as np +import pytest +import sympy +from attrs import frozen + +import qualtran +from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature +from qualtran.bloqs.basic_gates import OneState +from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Split +from qualtran.cirq_interop import ( + cirq_optree_to_cbloq, + CirqGateAsBloq, + CirqQuregT, + decompose_from_cirq_op, +) + + +@frozen +class TestCNOT(Bloq): + @property + def signature(self) -> Signature: + return Signature.build(control=1, target=1) + + def decompose_bloq(self) -> 'CompositeBloq': + return decompose_from_cirq_op(self) + + def as_cirq_op( + self, qubit_manager: cirq.QubitManager, **cirq_quregs: 'CirqQuregT' + ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: + (control,) = cirq_quregs['control'] + (target,) = cirq_quregs['target'] + return cirq.CNOT(control, target), cirq_quregs + + +@frozen +class TestCNOTSymbolic(TestCNOT): + @property + def signature(self) -> Signature: + c, t = sympy.Symbol('c'), sympy.Symbol('t') + return Signature.build(control=c, target=t) + + +def test_cirq_gate_as_bloq_for_trivial_gates(): + x = CirqGateAsBloq(cirq.X) + rx = CirqGateAsBloq(cirq.Rx(rads=0.123 * np.pi)) + toffoli = CirqGateAsBloq(cirq.TOFFOLI) + + for b in [x, rx, toffoli]: + assert len(b.signature) == 1 + assert b.signature[0].side == Side.THRU + + assert x.signature[0].shape == (1,) + assert toffoli.signature[0].shape == (3,) + + assert str(x) == 'CirqGateAsBloq(gate=cirq.X)' + assert x.pretty_name() == 'cirq.X' + assert x.short_name() == 'cirq.X' + + assert rx.pretty_name() == 'cirq.Rx(0.123π)' + assert rx.short_name() == 'cirq.Rx' + + assert toffoli.pretty_name() == 'cirq.TOFFOLI' + assert toffoli.short_name() == 'cirq.TOFFOLI' + + +def test_cirq_gate_as_bloq_tensor_contract_for_and_gate(): + and_gate = cirq_ft.And() + bb = BloqBuilder() + ctrl = [bb.add(OneState()) for _ in range(2)] + ctrl, target = bb.add(CirqGateAsBloq(and_gate), ctrl=ctrl) + cbloq = bb.finalize(ctrl=ctrl, target=target) + state_vector = cbloq.tensor_contract() + assert np.isclose(state_vector[7], 1) + + with pytest.raises(NotImplementedError, match="supported only for unitary gates"): + _ = CirqGateAsBloq(cirq_ft.And(adjoint=True)).as_composite_bloq().tensor_contract() + + +def test_bloq_decompose_from_cirq_op(): + tb = TestCNOT() + assert len(tb.signature) == 2 + ctrl, trg = tb.signature + assert ctrl.bitsize == 1 + assert ctrl.side == Side.THRU + assert tb.pretty_name() == 'TestCNOT' + + cirq_quregs = tb.signature.get_cirq_quregs() + circuit, _ = tb.decompose_bloq().to_cirq_circuit(**cirq_quregs) + assert circuit == cirq.Circuit(cirq.CNOT(*cirq_quregs['control'], *cirq_quregs['target'])) + assert tb.t_complexity() == cirq_ft.TComplexity(clifford=1) + + with pytest.raises(NotImplementedError): + TestCNOTSymbolic().decompose_bloq() + + +def test_cirq_circuit_to_cbloq(): + qubits = cirq.LineQubit.range(6) + circuit = cirq.testing.random_circuit(qubits, n_moments=7, op_density=1.0, random_state=52) + cbloq = cirq_optree_to_cbloq(circuit) + + bloq_unitary = cbloq.tensor_contract() + cirq_unitary = circuit.unitary(qubits) + np.testing.assert_allclose(cirq_unitary, bloq_unitary, atol=1e-8) + + +def test_cbloq_to_cirq_circuit(): + qubits = cirq.LineQubit.range(6) + circuit = cirq.testing.random_circuit(qubits, n_moments=7, op_density=1.0, random_state=52) + cbloq = cirq_optree_to_cbloq(circuit) + + # important! we lose moment structure + circuit = cirq.Circuit(circuit.all_operations()) + + # Note: a 1d `shape` bloq register is actually two-dimensional in cirq-world + # because of the implicit `bitsize` dimension (which must be explicit in cirq-world). + # CirqGate has registers of bitsize=1 and shape=(n,); hence the list transpose below. + circuit2, _ = cbloq.to_cirq_circuit( + **{'qubits': [[q] for q in qubits]}, qubit_manager=cirq.ops.SimpleQubitManager() + ) + + assert circuit == circuit2 + + +def test_cirq_optree_to_cbloq(): + @attr.frozen + class CirqGateWithRegisters(cirq_ft.GateWithRegisters): + reg: cirq_ft.Register + + @property + def signature(self) -> cirq_ft.Signature: + return cirq_ft.Signature([self.reg]) + + reg1 = cirq_ft.Register('x', shape=(3, 4), bitsize=2) + reg2 = cirq_ft.Register('y', shape=12, bitsize=2) + anc_reg = cirq_ft.Register('anc', shape=4, bitsize=2) + qubits = cirq.LineQubit.range(24) + anc_qubits = cirq.NamedQubit.range(4, prefix='anc') + circuit = cirq.Circuit( + CirqGateWithRegisters(reg1).on(*qubits), + CirqGateWithRegisters(anc_reg).on(*anc_qubits, *qubits[:4]), + CirqGateWithRegisters(reg2).on(*qubits), + ) + # Test-1: When no signature is specified, the method uses a default signature. Ancilla qubits + # are also included in the signature itself, so no allocations / deallocations are needed. + cbloq = cirq_optree_to_cbloq(circuit) + assert cbloq.signature == qualtran.Signature( + [qualtran.Register(name='qubits', bitsize=1, shape=(28,))] + ) + bloq_instances = [binst for binst, _, _ in cbloq.iter_bloqnections()] + assert all(bloq_instances[i].bloq == Join(2) for i in range(14)) + assert bloq_instances[14].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg1)) + assert bloq_instances[14].bloq.signature == qualtran.Signature( + [qualtran.Register(name='x', bitsize=2, shape=(3, 4))] + ) + assert bloq_instances[15].bloq == CirqGateAsBloq(CirqGateWithRegisters(anc_reg)) + assert bloq_instances[15].bloq.signature == qualtran.Signature( + [qualtran.Register(name='anc', bitsize=2, shape=(4,))] + ) + assert bloq_instances[16].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg2)) + assert bloq_instances[16].bloq.signature == qualtran.Signature( + [qualtran.Register(name='y', bitsize=2, shape=(12,))] + ) + assert all(bloq_instances[-i].bloq == Split(2) for i in range(1, 15)) + # Test-2: If you provide an explicit signature, you must also provide a mapping of cirq qubits + # matching the signature. The additional ancilla allocations are automatically handled. + new_signature = qualtran.Signature( + [ + qualtran.Register('xx', bitsize=3, shape=(3, 2)), + qualtran.Register('yy', bitsize=1, shape=(2, 3)), + ] + ) + cirq_quregs = { + 'xx': np.asarray(qubits[:18]).reshape((3, 2, 3)), + 'yy': np.asarray(qubits[18:]).reshape((2, 3, 1)), + } + cbloq = cirq_optree_to_cbloq( + circuit, signature=new_signature, in_quregs=cirq_quregs, out_quregs=cirq_quregs + ) + assert cbloq.signature == new_signature + # Splits, joins, Alloc, Free are automatically inserted. + bloqs_list = [binst.bloq for binst in cbloq.bloq_instances] + assert bloqs_list.count(Split(3)) == 6 + assert bloqs_list.count(Join(3)) == 6 + assert bloqs_list.count(Allocate(2)) == 2 + assert bloqs_list.count(Free(2)) == 2 diff --git a/qualtran/cirq_interop/cirq_interop.ipynb b/qualtran/cirq_interop/cirq_interop.ipynb index fa68e993e..dff67c4aa 100644 --- a/qualtran/cirq_interop/cirq_interop.ipynb +++ b/qualtran/cirq_interop/cirq_interop.ipynb @@ -220,29 +220,6 @@ "assert bloq.t_complexity() == cirq_ft.t_complexity(prepare.gate)" ] }, - { - "cell_type": "markdown", - "id": "5065f197-f9f9-4fd2-a3a5-abfd82e3e387", - "metadata": {}, - "source": [ - "Another example as follows is to import the data loading oracle QROM from Cirq-FT into Qualtran by wrapping it into a `CirqGateAsBloq`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba083d2c-c0b2-4fec-ab61-2cae21676169", - "metadata": {}, - "outputs": [], - "source": [ - "cirq_qrom = cirq_ft.QROM.build([10, 20, 30], num_controls=1)\n", - "bloq = CirqGateAsBloq(cirq_qrom)\n", - "cbloq = bloq.decompose_bloq()\n", - "show_bloq(cbloq)\n", - "fig, ax = draw_musical_score(get_musical_score_data(cbloq))\n", - "fig.set_size_inches(16, 5)" - ] - }, { "cell_type": "markdown", "id": "03f03231", @@ -374,6 +351,8 @@ "metadata": {}, "outputs": [], "source": [ + "from qualtran.cirq_interop import BloqAsCirqGate, cirq_optree_to_cbloq\n", + "\n", "@attrs.frozen\n", "class Swap(Bloq):\n", " n: int\n", @@ -400,7 +379,8 @@ "outputs": [], "source": [ "swap = Swap(n=5)\n", - "show_bloq(swap)" + "show_bloq(swap)\n", + "show_bloq(swap.decompose_bloq())" ] }, { @@ -421,8 +401,7 @@ "circuit, _ = swap.as_composite_bloq().to_cirq_circuit(\n", " x=cirq.LineQubit.range(5), y=cirq.LineQubit.range(100,105))\n", "\n", - "op = next(circuit.all_operations())\n", - "op.gate" + "op = next(circuit.all_operations())" ] }, { @@ -528,6 +507,113 @@ "# Note the new precense of `junk` and `target` entries.\n", "out_quregs" ] + }, + { + "cell_type": "markdown", + "id": "896e0392-3f1f-4965-bb13-0a541581ec25", + "metadata": {}, + "source": [ + "## Test `Cirq-FT -> Bloqs -> Cirq-FT` roundtrip using `QROM`\n", + "\n", + "Another example is to import the data loading oracle QROM from Cirq-FT into Qualtran by wrapping it into a `CirqGateAsBloq`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8284a856-f0fa-4589-8dd5-627a5b74af1e", + "metadata": {}, + "outputs": [], + "source": [ + "# Ensure no information is lost in Cirq-FT -> Bloqs -> Cirq-FT conversion.\n", + "cirq_qrom = cirq_ft.QROM.build([10, 20, 30], num_controls=1)\n", + "quregs = cirq_ft.infra.get_named_qubits(cirq_qrom.signature)\n", + "circuit = cirq.Circuit(cirq.decompose_once(cirq_qrom.on_registers(**quregs)))\n", + "print(circuit)\n", + "qrom_gate_via_bloq = BloqAsCirqGate(CirqGateAsBloq(cirq_qrom))\n", + "circuit_roundrip = cirq.Circuit(cirq.decompose_once(qrom_gate_via_bloq.on_registers(**quregs)))\n", + "print(circuit_roundrip)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52765449-acc2-4673-b7f4-f2add2b83c66", + "metadata": {}, + "outputs": [], + "source": [ + "# The same decomposition can be obtained by directly decomposing the corresponding Bloq.\n", + "# 1. Decompose wrapped Bloq directly to obtain the Composite Bloq for QROM.\n", + "bloq = CirqGateAsBloq(cirq_qrom)\n", + "cbloq = bloq.decompose_bloq()\n", + "fig, ax = draw_musical_score(get_musical_score_data(cbloq))\n", + "fig.set_size_inches(16, 5)\n", + "# 2. Convert Cirq decomposed circuit to composite Bloq.\n", + "cbloq = cirq_optree_to_cbloq(circuit, signature=bloq.signature, in_quregs=quregs, out_quregs=quregs)\n", + "fig, ax = draw_musical_score(get_musical_score_data(cbloq))\n", + "fig.set_size_inches(16, 5)" + ] + }, + { + "cell_type": "markdown", + "id": "8ec19639-ae2d-4a66-91eb-8cd4d10068ef", + "metadata": {}, + "source": [ + "## Test `Bloqs -> Cirq-FT -> Bloqs` roundtrip using `ModExp`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f05d340-f3ca-45d2-bde0-cc4bf0bd0c5d", + "metadata": {}, + "outputs": [], + "source": [ + "from qualtran.bloqs.factoring.mod_exp import ModExp\n", + "from qualtran.drawing import show_bloq\n", + "N = 13*17\n", + "n = int(np.ceil(np.log2(N)))\n", + "g = 8\n", + "mod_exp = ModExp(base=g, mod=N, exp_bitsize=32, x_bitsize=32)\n", + "show_bloq(mod_exp)\n", + "cbloq = mod_exp.decompose_bloq()\n", + "fig, ax = draw_musical_score(get_musical_score_data(cbloq))\n", + "fig.set_size_inches(24, 15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f60ad59-b3fe-4745-b0d0-e2f0feb1fbaf", + "metadata": {}, + "outputs": [], + "source": [ + "in_quregs = {'exponent': np.array(cirq.LineQubit.range(32))}\n", + "\n", + "op, out_quregs = BloqAsCirqGate.bloq_on(mod_exp, cirq_quregs=in_quregs, qubit_manager=cirq.ops.SimpleQubitManager())\n", + "\n", + "# 1. Decompose using cirq.decompose_once(op) and then convert back into a CompositeBloq.\n", + "decomposed_circuit = cirq.Circuit(cirq.decompose_once(op))\n", + "cbloq = cirq_optree_to_cbloq(decomposed_circuit, signature=mod_exp.signature, in_quregs=in_quregs, out_quregs=out_quregs)\n", + "fig, ax = draw_musical_score(get_musical_score_data(cbloq))\n", + "fig.set_size_inches(24, 15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d51c58ae-89f6-4f29-b7d7-da0f437f5af3", + "metadata": {}, + "outputs": [], + "source": [ + "# 2. Ensure that Bloq -> BloqAsCirqGate -> CirqGateAsBloq.decompose_bloq() roundtrip works as expected.\n", + "# This makes sure no information is lost when converting from Bloqs -> Cirq-FT -> Bloqs.\n", + "bloq = CirqGateAsBloq(BloqAsCirqGate(mod_exp))\n", + "show_bloq(bloq)\n", + "cbloq = CirqGateAsBloq(op.gate).decompose_bloq()\n", + "fig, ax = draw_musical_score(get_musical_score_data(cbloq))\n", + "fig.set_size_inches(24, 15)" + ] } ], "metadata": { diff --git a/qualtran/serialization/bloq_test.py b/qualtran/serialization/bloq_test.py index 523e83615..c62c1422f 100644 --- a/qualtran/serialization/bloq_test.py +++ b/qualtran/serialization/bloq_test.py @@ -26,7 +26,7 @@ from qualtran.bloqs.controlled_bloq import ControlledBloq from qualtran.bloqs.factoring.mod_exp import ModExp from qualtran.cirq_interop import CirqGateAsBloq -from qualtran.cirq_interop._cirq_interop_test import TestCNOT as TestCNOTCirq +from qualtran.cirq_interop._cirq_to_bloq_test import TestCNOT as TestCNOTCirq from qualtran.protos import registers_pb2 from qualtran.serialization import bloq as bloq_serialization