From cc232f60e5a1c7c484ec96da76cd2e240f028ce4 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 25 Aug 2023 15:00:12 -0700 Subject: [PATCH] Support cirq registers in `CirqGateAsBloq` conversion (#355) * Support cirq registers in CirqGateAsBloq conversion * Fix failing test * Address nits, add tests and fix a couple of bugs * Fix typo in test and add comment --- qualtran/cirq_interop/_cirq_interop.py | 165 ++++++++++++++------ qualtran/cirq_interop/_cirq_interop_test.py | 63 ++++++++ qualtran/cirq_interop/cirq_interop.ipynb | 61 +++++++- qualtran/drawing/graphviz.py | 4 - 4 files changed, 241 insertions(+), 52 deletions(-) diff --git a/qualtran/cirq_interop/_cirq_interop.py b/qualtran/cirq_interop/_cirq_interop.py index 0036aebc1..a42e2f129 100644 --- a/qualtran/cirq_interop/_cirq_interop.py +++ b/qualtran/cirq_interop/_cirq_interop.py @@ -14,6 +14,7 @@ """Functionality for the `Bloq.as_cirq_op(...)` protocol""" +import itertools from functools import cached_property from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union @@ -23,8 +24,6 @@ import numpy as np import quimb.tensor as qtn from attrs import frozen -from cirq_ft import Register as LegacyRegister -from cirq_ft import Registers as LegacyRegisters from numpy.typing import NDArray from qualtran import ( @@ -43,11 +42,16 @@ SoquetT, ) from qualtran._infra.composite_bloq import _binst_to_cxns +from qualtran.bloqs.util_bloqs import Allocate, Free CirqQuregT = NDArray[cirq.Qid] CirqQuregInT = Union[NDArray[cirq.Qid], Sequence[cirq.Qid]] +def signature_from_cirq_registers(registers: Iterable[cirq_ft.Register]) -> 'Signature': + return Signature([Register(reg.name, bitsize=1, shape=reg.shape) for reg in registers]) + + @frozen class CirqGateAsBloq(Bloq): """A Bloq wrapper around a `cirq.Gate`. @@ -67,11 +71,24 @@ def short_name(self) -> str: @cached_property def signature(self) -> 'Signature': - return Signature([Register('qubits', 1, shape=(self.n_qubits,))]) + return signature_from_cirq_registers(self.cirq_registers) @cached_property - def n_qubits(self): - return cirq.num_qubits(self.gate) + def cirq_registers(self) -> cirq_ft.Registers: + if isinstance(self.gate, cirq_ft.GateWithRegisters): + return self.gate.registers + else: + return cirq_ft.Registers.build(qubits=cirq.num_qubits(self.gate)) + + def decompose_bloq(self) -> 'CompositeBloq': + quregs = self.signature.get_cirq_quregs() + qubit_manager = cirq.ops.SimpleQubitManager() + cirq_op, quregs = self.as_cirq_op(qubit_manager, **quregs) + context = cirq.DecompositionContext(qubit_manager=qubit_manager) + decomposed_optree = cirq.decompose_once(cirq_op, context=context, default=None) + if decomposed_optree is None: + raise NotImplementedError(f"{self} does not support decomposition.") + return cirq_optree_to_cbloq(decomposed_optree, signature=self.signature, cirq_quregs=quregs) def add_my_tensors( self, @@ -81,28 +98,68 @@ def add_my_tensors( incoming: Dict[str, 'SoquetT'], outgoing: Dict[str, 'SoquetT'], ): - unitary = cirq.unitary(self.gate).reshape((2,) * 2 * self.n_qubits) + unitary = cirq.unitary(self.gate).reshape((2,) * 2 * self.cirq_registers.total_bits()) + incoming_list = [ + *itertools.chain.from_iterable( + [np.array(incoming[reg.name]).flatten() for reg in self.signature.lefts()] + ) + ] + outgoing_list = [ + *itertools.chain.from_iterable( + [np.array(outgoing[reg.name]).flatten() for reg in self.signature.rights()] + ) + ] tn.add( qtn.Tensor( - data=unitary, - inds=outgoing['qubits'].tolist() + incoming['qubits'].tolist(), - tags=[self.short_name(), tag], + data=unitary, inds=outgoing_list + incoming_list, tags=[self.short_name(), tag] ) ) def as_cirq_op( - self, qubit_manager: 'cirq.QubitManager', qubits: 'CirqQuregT' + self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: - assert qubits.shape == (self.n_qubits, 1) - return self.gate.on(*qubits[:, 0]), {'qubits': qubits} + merged_qubits = np.concatenate( + [cirq_quregs[reg.name].flatten() for reg in self.signature.lefts()] + ) + assert len(merged_qubits) == cirq.num_qubits(self.gate) + return self.gate.on(*merged_qubits), cirq_quregs def t_complexity(self) -> 'cirq_ft.TComplexity': return cirq_ft.t_complexity(self.gate) + def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': + from qualtran.drawing import directional_text_box + + wire_symbols = cirq.circuit_diagram_info(self.gate).wire_symbols + begin = 0 + symbol: str = soq.pretty() + for reg in self.signature: + finish = begin + np.product(reg.shape) + if reg == soq.reg: + symbol = np.array(wire_symbols[begin:finish]).reshape(reg.shape)[soq.idx] + begin = finish + return directional_text_box(text=symbol, side=soq.reg.side) + + +def _split_qvars_for_regs( + qvars: Sequence[Soquet], signature: Signature +) -> Dict[str, NDArray[Soquet]]: + """Split a flat list of soquets into a dictionary corresponding to `signature`.""" + qvars_regs = {} + base = 0 + for reg in signature: + assert reg.bitsize == 1 + qvars_regs[reg.name] = np.array(qvars[base : base + reg.total_bits()]).reshape(reg.shape) + base += reg.total_bits() + return qvars_regs + def cirq_optree_to_cbloq( - optree: cirq.OP_TREE, *, signature: Optional[Signature] = None + optree: cirq.OP_TREE, + *, + signature: Optional[Signature] = None, + cirq_quregs: Optional[Dict[str, 'NDArray[cirq.Qid]']] = None, ) -> CompositeBloq: """Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`. @@ -111,67 +168,87 @@ def cirq_optree_to_cbloq( If `signature` is not None, the signature of the resultant CompositeBloq is `signature`. For multi-dimensional registers and registers with > 1 bitsize, this function automatically splits the input soquets into a flat list and joins the output soquets into the correct shape - to ensure compatibility with the flat API expected by Cirq. + to ensure compatibility with the flat API expected by Cirq. When specifying a signature, users + must also specify the `cirq_quregs` argument, which is a mapping of cirq qubits used in the + OP-TREE corresponding to the `signature`. If `signature` has registers with entry + - `Register('x', bitsize=2, shape=(3, 4))` and + - `Register('y', bitsize=1, shape=(10, 20))` + then `cirq_quregs` should have one entry corresponding to each register as follows: + - key='x'; value=`np.array(cirq_qubits_used_in_optree, shape=(3, 4, 2))` and + - key='y'; value=`np.array(cirq_qubits_used_in_optree, shape=(10, 20, 1))`. If `signature` is None, the resultant composite bloq will have one thru-register named "qubits" of shape `(n_qubits,)`. """ - # "qubits" means cirq qubits | "qvars" means bloq Soquets circuit = cirq.Circuit(optree) - all_qubits = sorted(circuit.all_qubits()) + # "qubits" means cirq qubits | "qvars" means bloq Soquets if signature is None: + assert cirq_quregs is None + all_qubits = sorted(circuit.all_qubits()) signature = Signature([Register('qubits', 1, shape=(len(all_qubits),))]) + cirq_quregs = {'qubits': np.array(all_qubits).reshape(len(all_qubits), 1)} + + assert signature is not None and cirq_quregs is not None + bb, initial_soqs = BloqBuilder.from_signature(signature, add_registers_allowed=False) # Magic to make sure signature of the CompositeBloq matches `Signature`. - qvars = {} + qubit_to_qvar = {} for reg in signature.lefts(): + assert reg.name in cirq_quregs soqs = initial_soqs[reg.name] + if isinstance(soqs, Soquet): + soqs = np.asarray(soqs)[np.newaxis, ...] if reg.bitsize > 1: - # Need to split all soquets here. - if isinstance(soqs, Soquet): - qvars[reg.name] = bb.split(soqs) - else: - qvars[reg.name] = np.concatenate([bb.split(soq) for soq in soqs.reshape(-1)]) - else: - if isinstance(soqs, Soquet): - qvars[reg.name] = [soqs] - else: - qvars[reg.name] = soqs.reshape(-1) - - qubit_to_qvar = dict(zip(all_qubits, np.concatenate([*qvars.values()]))) + soqs = np.array([bb.split(soq) for soq in soqs.flatten()]) + soqs = soqs.reshape(reg.shape + (reg.bitsize,)) + assert cirq_quregs[reg.name].shape == soqs.shape + qubit_to_qvar |= zip(cirq_quregs[reg.name].flatten(), soqs.flatten()) + allocated_qubits = set() for op in circuit.all_operations(): if op.gate is None: raise ValueError(f"Only gate operations are supported, not {op}.") bloq = CirqGateAsBloq(op.gate) - qvars_for_op = np.array([qubit_to_qvar[qubit] for qubit in op.qubits]) - qvars_for_op_out = bb.add(bloq, qubits=qvars_for_op) - qubit_to_qvar |= zip(op.qubits, qvars_for_op_out) + for q in op.qubits: + if q not in qubit_to_qvar: + qubit_to_qvar[q] = bb.add(Allocate(1)) + allocated_qubits.add(q) + + qvars_in = [qubit_to_qvar[qubit] for qubit in op.qubits] + qvars_out = bb.add_t(bloq, **_split_qvars_for_regs(qvars_in, bloq.signature)) + qubit_to_qvar |= zip( + op.qubits, itertools.chain.from_iterable([arr.flatten() for arr in qvars_out]) + ) - qvar_vals_out = np.array([qubit_to_qvar[qubit] for qubit in all_qubits]) + for q in allocated_qubits: + bb.add(Free(1), free=qubit_to_qvar[q]) + qvars = np.array([*qubit_to_qvar.values()]) final_soqs = {} idx = 0 for reg in signature.rights(): name = reg.name - soqs = qvar_vals_out[idx : idx + len(qvars[name])] - idx = idx + len(qvars[name]) + assert name in cirq_quregs + soqs = qvars[idx : idx + np.product(cirq_quregs[name].shape)] + idx = idx + np.product(cirq_quregs[name].shape) if reg.bitsize > 1: # Need to combine the soquets here. if len(soqs) == reg.bitsize: final_soqs[name] = bb.join(soqs) else: final_soqs[name] = np.array( - bb.join(subsoqs) for subsoqs in soqs[:: reg.bitsize] + [ + bb.join(soqs[st : st + reg.bitsize]) + for st in range(0, len(soqs), reg.bitsize) + ] ).reshape(reg.shape) else: - if len(soqs) == 1: + if len(soqs) == 1 and reg.shape == (): final_soqs[name] = soqs[0] else: final_soqs[name] = soqs.reshape(reg.shape) - return bb.finalize(**final_soqs) @@ -280,7 +357,7 @@ def decompose_from_cirq_op(bloq: 'Bloq') -> 'CompositeBloq': ): raise NotImplementedError(f"{bloq} does not support decomposition.") - return cirq_optree_to_cbloq(cirq_op, signature=bloq.signature) + return cirq_optree_to_cbloq(cirq_op, signature=bloq.signature, cirq_quregs=cirq_quregs) def _cbloq_to_cirq_circuit( @@ -353,12 +430,12 @@ def bloq(self) -> Bloq: return self._bloq @property - def registers(self) -> LegacyRegisters: + def registers(self) -> cirq_ft.Registers: """`cirq_ft.GateWithRegisters` registers.""" return self._legacy_regs @staticmethod - def _init_legacy_regs(bloq: Bloq) -> Tuple[LegacyRegisters, Mapping[str, Register]]: + def _init_legacy_regs(bloq: Bloq) -> Tuple[cirq_ft.Registers, Mapping[str, Register]]: """Initialize legacy registers. We flatten multidimensional registers and annotate non-thru registers with @@ -369,15 +446,15 @@ def _init_legacy_regs(bloq: Bloq) -> Tuple[LegacyRegisters, Mapping[str, Registe compat_name_map: A mapping from the compatability-shim string names of the legacy registers back to the original (register, idx) pair. """ - legacy_regs: List[LegacyRegister] = [] + legacy_regs: List[cirq_ft.Register] = [] side_suffixes = {Side.LEFT: '_l', Side.RIGHT: '_r', Side.THRU: ''} compat_name_map = {} for reg in bloq.signature: compat_name = f'{reg.name}{side_suffixes[reg.side]}' compat_name_map[compat_name] = reg full_shape = reg.shape + (reg.bitsize,) - legacy_regs.append(LegacyRegister(name=compat_name, shape=full_shape)) - return LegacyRegisters(legacy_regs), compat_name_map + legacy_regs.append(cirq_ft.Register(name=compat_name, shape=full_shape)) + return cirq_ft.Registers(legacy_regs), compat_name_map @classmethod def bloq_on( diff --git a/qualtran/cirq_interop/_cirq_interop_test.py b/qualtran/cirq_interop/_cirq_interop_test.py index 45c1cf8c5..96681cd66 100644 --- a/qualtran/cirq_interop/_cirq_interop_test.py +++ b/qualtran/cirq_interop/_cirq_interop_test.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict, Tuple +import attr import cirq import cirq_ft import numpy as np @@ -20,10 +21,12 @@ import sympy from attrs import frozen +import qualtran from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature, Soquet, SoquetT from qualtran.bloqs.and_bloq import MultiAnd from qualtran.bloqs.basic_gates import XGate from qualtran.bloqs.swap_network import SwapWithZero +from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Split from qualtran.cirq_interop import ( BloqAsCirqGate, cirq_optree_to_cbloq, @@ -85,6 +88,66 @@ def test_cbloq_to_cirq_circuit(): assert circuit == circuit2 +def test_cirq_optree_to_cbloq(): + @attr.frozen + class CirqGateWithRegisters(cirq_ft.GateWithRegisters): + reg: cirq_ft.Register + + @property + def registers(self) -> cirq_ft.Registers: + return cirq_ft.Registers([self.reg]) + + reg1 = cirq_ft.Register('x', shape=(3, 4, 2)) + reg2 = cirq_ft.Register('y', shape=(12, 2)) + anc_reg = cirq_ft.Register('anc', shape=(2, 3)) + qubits = cirq.LineQubit.range(24) + anc_qubits = cirq.NamedQubit.range(3, prefix='anc') + circuit = cirq.Circuit( + CirqGateWithRegisters(reg1).on(*qubits), + CirqGateWithRegisters(anc_reg).on(*anc_qubits, *qubits[:3]), + CirqGateWithRegisters(reg2).on(*qubits), + ) + # Test-1: When no signature is specified, the method uses a default signature. Ancilla qubits + # are also included in the signature itself, so no allocations / deallocations are needed. + cbloq = cirq_optree_to_cbloq(circuit) + assert cbloq.signature == qualtran.Signature( + [qualtran.Register(name='qubits', bitsize=1, shape=(27,))] + ) + bloq_instances = [binst for binst, _, _ in cbloq.iter_bloqnections()] + assert bloq_instances[0].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg1)) + assert bloq_instances[0].bloq.signature == qualtran.Signature( + [qualtran.Register(name='x', bitsize=1, shape=(3, 4, 2))] + ) + assert bloq_instances[1].bloq == CirqGateAsBloq(CirqGateWithRegisters(anc_reg)) + assert bloq_instances[1].bloq.signature == qualtran.Signature( + [qualtran.Register(name='anc', bitsize=1, shape=(2, 3))] + ) + assert bloq_instances[2].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg2)) + assert bloq_instances[2].bloq.signature == qualtran.Signature( + [qualtran.Register(name='y', bitsize=1, shape=(12, 2))] + ) + # Test-2: If you provide an explicit signature, you must also provide a mapping of cirq qubits + # matching the signature. The additional ancilla allocations are automatically handled. + new_signature = qualtran.Signature( + [ + qualtran.Register('xx', bitsize=3, shape=(3, 2)), + qualtran.Register('yy', bitsize=1, shape=(2, 3)), + ] + ) + cirq_quregs = { + 'xx': np.asarray(qubits[:18]).reshape((3, 2, 3)), + 'yy': np.asarray(qubits[18:]).reshape((2, 3, 1)), + } + cbloq = cirq_optree_to_cbloq(circuit, signature=new_signature, cirq_quregs=cirq_quregs) + assert cbloq.signature == new_signature + # Splits, joins, Alloc, Free are automatically inserted. + bloqs_list = [binst.bloq for binst in cbloq.bloq_instances] + assert bloqs_list.count(Split(3)) == 6 + assert bloqs_list.count(Join(3)) == 6 + assert bloqs_list.count(Allocate(1)) == 3 + assert bloqs_list.count(Free(1)) == 3 + + @frozen class SwapTwoBitsTest(Bloq): @property diff --git a/qualtran/cirq_interop/cirq_interop.ipynb b/qualtran/cirq_interop/cirq_interop.ipynb index ca96d877e..5d56f3766 100644 --- a/qualtran/cirq_interop/cirq_interop.ipynb +++ b/qualtran/cirq_interop/cirq_interop.ipynb @@ -165,6 +165,61 @@ "cirq.Circuit(circuit.all_operations()) == circuit2" ] }, + { + "cell_type": "markdown", + "id": "115b1c2f-001c-4c03-aefa-150373800184", + "metadata": {}, + "source": [ + "# Importing Cirq-FT algorithms to Bloqs\n", + "`CirqGateAsBloq` also supports wrapping Cirq-FT's `GateWithRegisters` objects. As an example, we show how you can directly import `SELECT` and `PREPARE` primitives for the 2D Hubbard model from Cirq-FT into Bloqs. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee1f5334-da8c-40d8-9b6f-55b7f0183d30", + "metadata": {}, + "outputs": [], + "source": [ + "import cirq_ft\n", + "from cirq_ft.algos.hubbard_model import SelectHubbard, PrepareHubbard\n", + "import cirq_ft.infra.testing as cq_testing\n", + "x_dim, y_dim, t = 2, 2, 5\n", + "mu = 4 * t\n", + "# SELECT and PREPARE for 2D Hubbard Model\n", + "prepare = cq_testing.GateHelper(PrepareHubbard(x_dim=x_dim, y_dim=x_dim, t=t, mu=mu))\n", + "select = cq_testing.GateHelper(SelectHubbard(x_dim=x_dim, y_dim=y_dim, control_val=1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a591511-54dd-4040-8328-ff01eb62782d", + "metadata": {}, + "outputs": [], + "source": [ + "from qualtran.drawing.musical_score import get_musical_score_data, draw_musical_score\n", + "print(cirq.Circuit(cirq.decompose_once(select.operation)))\n", + "bloq = CirqGateAsBloq(select.gate)\n", + "fig, ax = draw_musical_score(get_musical_score_data(bloq.decompose_bloq()))\n", + "fig.set_size_inches(30, 12)\n", + "assert bloq.t_complexity() == cirq_ft.t_complexity(select.gate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6170e55-94a9-43f2-8503-e9a8e2d0f193", + "metadata": {}, + "outputs": [], + "source": [ + "print(cirq.Circuit(cirq.decompose_once(prepare.operation)))\n", + "bloq = CirqGateAsBloq(prepare.gate)\n", + "fig, ax = draw_musical_score(get_musical_score_data(bloq.decompose_bloq()))\n", + "fig.set_size_inches(30, 12)\n", + "assert bloq.t_complexity() == cirq_ft.t_complexity(prepare.gate)" + ] + }, { "cell_type": "markdown", "id": "03f03231", @@ -370,9 +425,7 @@ "cell_type": "code", "execution_count": null, "id": "3fd9ebd7", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "import cirq_ft.infra.testing as cq_testing\n", @@ -470,7 +523,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/qualtran/drawing/graphviz.py b/qualtran/drawing/graphviz.py index ac578626b..442b0a3f1 100644 --- a/qualtran/drawing/graphviz.py +++ b/qualtran/drawing/graphviz.py @@ -378,13 +378,9 @@ def get_binst_header_text(self, binst: BloqInstance): def soq_label(self, soq: Soquet): from qualtran.bloqs.util_bloqs import Join, Split - from qualtran.cirq_interop import CirqGateAsBloq if isinstance(soq.binst, BloqInstance) and isinstance(soq.binst.bloq, (Split, Join)): return '' - if isinstance(soq.binst, BloqInstance) and isinstance(soq.binst.bloq, CirqGateAsBloq): - (ii,) = soq.idx - return f'q{ii}' return soq.pretty() def get_default_text(self, reg: Register) -> str: