diff --git a/qualtran/_infra/composite_bloq.py b/qualtran/_infra/composite_bloq.py index eca989b81..a9a1899ac 100644 --- a/qualtran/_infra/composite_bloq.py +++ b/qualtran/_infra/composite_bloq.py @@ -907,8 +907,7 @@ def add(self, bloq: Bloq, **in_soqs: SoquetInT) -> Union[None, SoquetT, Tuple[So unpacking. In this final case, the ordering is according to `bloq.signature` and irrespective of the order of `**in_soqs`. """ - binst = BloqInstance(bloq, i=self._new_binst_i()) - outs = tuple(soq for _, soq in self._add_binst(binst, in_soqs=in_soqs)) + outs = self.add_t(bloq, **in_soqs) if len(outs) == 0: return None if len(outs) == 1: diff --git a/qualtran/bloqs/arithmetic.py b/qualtran/bloqs/arithmetic.py index 84abbe762..9882c4f6f 100644 --- a/qualtran/bloqs/arithmetic.py +++ b/qualtran/bloqs/arithmetic.py @@ -12,17 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import cached_property -from typing import Dict, Tuple, Union - -import cirq from attrs import frozen -from cirq_ft import LessThanEqualGate as CirqLessThanEqual -from cirq_ft import LessThanGate as CirqLessThanGate from cirq_ft import TComplexity -from qualtran import Bloq, CompositeBloq, Register, Signature -from qualtran.cirq_interop import CirqQuregT, decompose_from_cirq_op +from qualtran import Bloq, Register, Signature @frozen @@ -219,74 +212,3 @@ def t_complexity(self): # See: https://github.com/quantumlib/cirq-qubitization/issues/219 # See: https://github.com/quantumlib/cirq-qubitization/issues/217 return TComplexity(t=8 * self.bitsize) - - -@frozen -class LessThanEqual(Bloq): - r"""Implements $U|x,y,z\rangle = |x, y, z \oplus {x \le y}\rangle$. - - Args: - x_bitsize: bitsize of x register. - y_bitsize: bitsize of y register. - - Registers: - - x, y: Registers to compare against eachother. - - z: Register to hold result of comparison. - """ - - x_bitsize: int - y_bitsize: int - - @cached_property - def signature(self) -> Signature: - return Signature( - [ - Register("x", bitsize=self.x_bitsize), - Register("y", bitsize=self.y_bitsize), - Register("z", bitsize=1), - ] - ) - - def decompose_bloq(self) -> 'CompositeBloq': - return decompose_from_cirq_op(self) - - def as_cirq_op( - self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' - ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: - less_than = CirqLessThanEqual(x_bitsize=self.x_bitsize, y_bitsize=self.y_bitsize) - x = cirq_quregs['x'] - y = cirq_quregs['y'] - z = cirq_quregs['z'] - return (less_than.on(*x, *y, *z), cirq_quregs) - - -@frozen -class LessThanConstant(Bloq): - r"""Implements $U_a|x\rangle = U_a|x\rangle|z\rangle = |x\rangle |z ^ (x < a)\rangle" - - Args: - bitsize: bitsize of x register. - val: integer to compare x against (a above.) - - Registers: - - x: Registers to compare against val. - - z: Register to hold result of comparison. - """ - - bitsize: int - val: int - - @cached_property - def signature(self) -> Signature: - return Signature.build(x=self.bitsize, z=1) - - def decompose_bloq(self) -> 'CompositeBloq': - return decompose_from_cirq_op(self) - - def as_cirq_op( - self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' - ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: - less_than = CirqLessThanGate(bitsize=self.bitsize, less_than_val=self.val) - x = cirq_quregs['x'] - z = cirq_quregs['z'] - return (less_than.on(*x, *z), cirq_quregs) diff --git a/qualtran/bloqs/arithmetic_test.py b/qualtran/bloqs/arithmetic_test.py index bb60208cf..6afdaead6 100644 --- a/qualtran/bloqs/arithmetic_test.py +++ b/qualtran/bloqs/arithmetic_test.py @@ -12,21 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cirq_ft.algos import LessThanEqualGate as CirqLessThanEquals -from cirq_ft.algos import LessThanGate as CirqLessThanConstant -from cirq_ft.infra import t_complexity - -import qualtran.testing as qlt_testing from qualtran import BloqBuilder, Register -from qualtran.bloqs.arithmetic import ( - Add, - GreaterThan, - LessThanConstant, - LessThanEqual, - Product, - Square, - SumOfSquares, -) +from qualtran.bloqs.arithmetic import Add, GreaterThan, Product, Square, SumOfSquares from qualtran.testing import execute_notebook @@ -109,19 +96,5 @@ def test_greater_than(): cbloq = bb.finalize(a=q0, b=q1, result=anc) -def test_less_than_equal(): - lte = LessThanEqual(5, 5) - qlt_testing.assert_valid_bloq_decomposition(lte) - cirq_lte = CirqLessThanEquals(5, 5) - assert lte.decompose_bloq().t_complexity() == t_complexity(cirq_lte) - - -def test_less_than_constant(): - ltc = LessThanConstant(5, 7) - qlt_testing.assert_valid_bloq_decomposition(ltc) - cirq_ltc = CirqLessThanConstant(5, 7) - assert ltc.decompose_bloq().t_complexity() == t_complexity(cirq_ltc) - - def test_notebook(): execute_notebook('arithmetic') diff --git a/qualtran/bloqs/qrom.py b/qualtran/bloqs/qrom.py deleted file mode 100644 index d42f24598..000000000 --- a/qualtran/bloqs/qrom.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import cached_property -from typing import Dict, Sequence, Tuple - -import cirq -from attrs import frozen -from cirq_ft import QROM as CirqQROM -from numpy.typing import NDArray - -from qualtran import Bloq, CompositeBloq, Register, Signature -from qualtran.cirq_interop import CirqQuregT, decompose_from_cirq_op - - -@frozen -class QROM(Bloq): - """Gate to load data[l] in the target register when the selection stores an index l. - - In the case of multi-dimensional data[p,q,r,...] we use multiple named - selection registers [selection0, selection1, selection2, ...] to index and - load the data. - - Args: - data: List of numpy ndarrays specifying the data to load. If the length - of this list is greater than one then we use the same selection indices - to load each dataset (for example, to load alt and keep data for - state preparation). Each data set is required to have the same - shape and to be of integer type. - selection_bitsizes: The number of bits used to represent each selection register - corresponding to the size of each dimension of the array. Should be - the same length as the shape of each of the datasets. - data_bitsizes: The number of bits used to represent the data - registers. This can be deduced from the maximum element of each of the - datasets. Should be of length len(data), i.e. the number of datasets. - num_controls: The number of controls registers. - """ - - data: Sequence[NDArray] - selection_bitsizes: Tuple[int, ...] - data_bitsizes: Tuple[int, ...] - num_controls: int = 0 - - @cached_property - def signature(self) -> Signature: - regs = [ - Register(f"selection{i}", bitsize=bs) for i, bs in enumerate(self.selection_bitsizes) - ] - regs += [Register(f"target{i}", bitsize=bs) for i, bs in enumerate(self.data_bitsizes)] - if self.num_controls > 0: - regs += [Register("control", bitsize=self.num_controls)] - return Signature(regs) - - def decompose_bloq(self) -> 'CompositeBloq': - return decompose_from_cirq_op(self) - - def as_cirq_op( - self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' - ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: - qrom = CirqQROM( - data=self.data, - selection_bitsizes=self.selection_bitsizes, - target_bitsizes=self.data_bitsizes, - num_controls=self.num_controls, - ) - return (qrom.on_registers(**cirq_quregs), cirq_quregs) - - def __hash__(self): - # This is not a great hash. No guarantees. - # See: https://github.com/quantumlib/Qualtran/issues/339 - return hash(self.signature) - - def __eq__(self, other) -> bool: - return self.signature == other.signature - - def __ne__(self, other) -> bool: - return self.signature != other.signature diff --git a/qualtran/bloqs/qrom_test.py b/qualtran/bloqs/qrom_test.py deleted file mode 100644 index a28658d89..000000000 --- a/qualtran/bloqs/qrom_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np -from cirq_ft import QROM as CirqQROM -from cirq_ft.infra import t_complexity - -import qualtran.testing as qlt_testing -from qualtran.bloqs.qrom import QROM - - -def test_qrom_decomp(): - data = np.ones((10, 10)) - sel_bitsizes = tuple((s - 1).bit_length() for s in data.shape) - qrom = QROM(data=[data], data_bitsizes=(4,), selection_bitsizes=sel_bitsizes) - qlt_testing.assert_valid_bloq_decomposition(qrom) - - -def test_qrom_decomp_with_control(): - data = np.ones((10, 10)) - sel_bitsizes = tuple((s - 1).bit_length() for s in data.shape) - qrom = QROM(data=[data], data_bitsizes=(4,), selection_bitsizes=sel_bitsizes, num_controls=1) - qlt_testing.assert_valid_bloq_decomposition(qrom) - - -def test_tcomplexity(): - data = np.ones((10, 10)) - sel_bitsizes = tuple((s - 1).bit_length() for s in data.shape) - qrom = QROM([data], selection_bitsizes=sel_bitsizes, data_bitsizes=(4,)) - cbloq = qrom.decompose_bloq() - cqrom = CirqQROM(data=[data], selection_bitsizes=sel_bitsizes, target_bitsizes=(4,)) - assert cbloq.t_complexity() == t_complexity(cqrom) - - -def test_hashing(): - data = np.ones((10, 10)) - sel_bitsizes = tuple((s - 1).bit_length() for s in data.shape) - qrom = QROM([data], selection_bitsizes=sel_bitsizes, data_bitsizes=(4,)) - assert hash(qrom) == hash(qrom) - qrom_2 = QROM([data], selection_bitsizes=sel_bitsizes, data_bitsizes=(5,)) - assert qrom == qrom - assert qrom_2 != qrom diff --git a/qualtran/bloqs/swap_network_cirq_test.py b/qualtran/bloqs/swap_network_cirq_test.py index f0dd4b3b0..390f0534e 100644 --- a/qualtran/bloqs/swap_network_cirq_test.py +++ b/qualtran/bloqs/swap_network_cirq_test.py @@ -39,12 +39,11 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe # Allocate selection and target qubits. all_qubits = cirq.LineQubit.range(cirq.num_qubits(gate)) selection = all_qubits[:selection_bitsize] - targets = { - f'targets_{i}': all_qubits[st : st + target_bitsize] - for i, st in enumerate(range(selection_bitsize, len(all_qubits), target_bitsize)) - } + targets = np.asarray(all_qubits[selection_bitsize:]).reshape( + (n_target_registers, target_bitsize) + ) # Create a circuit. - circuit = cirq.Circuit(gate.on_registers(selection=selection, **targets)) + circuit = cirq.Circuit(gate.on_registers(selection=selection, targets=targets)) # Load data[i] in i'th target register; where each register is of size target_bitsize data = [random.randint(0, 2**target_bitsize - 1) for _ in range(n_target_registers)] diff --git a/qualtran/cirq_interop/_cirq_interop.py b/qualtran/cirq_interop/_cirq_interop.py index 0d4dba77a..a42e2f129 100644 --- a/qualtran/cirq_interop/_cirq_interop.py +++ b/qualtran/cirq_interop/_cirq_interop.py @@ -14,6 +14,7 @@ """Functionality for the `Bloq.as_cirq_op(...)` protocol""" +import itertools from functools import cached_property from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union @@ -23,8 +24,6 @@ import numpy as np import quimb.tensor as qtn from attrs import frozen -from cirq_ft import Register as LegacyRegister -from cirq_ft import Registers as LegacyRegisters from numpy.typing import NDArray from qualtran import ( @@ -43,11 +42,16 @@ SoquetT, ) from qualtran._infra.composite_bloq import _binst_to_cxns +from qualtran.bloqs.util_bloqs import Allocate, Free CirqQuregT = NDArray[cirq.Qid] CirqQuregInT = Union[NDArray[cirq.Qid], Sequence[cirq.Qid]] +def signature_from_cirq_registers(registers: Iterable[cirq_ft.Register]) -> 'Signature': + return Signature([Register(reg.name, bitsize=1, shape=reg.shape) for reg in registers]) + + @frozen class CirqGateAsBloq(Bloq): """A Bloq wrapper around a `cirq.Gate`. @@ -67,11 +71,24 @@ def short_name(self) -> str: @cached_property def signature(self) -> 'Signature': - return Signature([Register('qubits', 1, shape=(self.n_qubits,))]) + return signature_from_cirq_registers(self.cirq_registers) @cached_property - def n_qubits(self): - return cirq.num_qubits(self.gate) + def cirq_registers(self) -> cirq_ft.Registers: + if isinstance(self.gate, cirq_ft.GateWithRegisters): + return self.gate.registers + else: + return cirq_ft.Registers.build(qubits=cirq.num_qubits(self.gate)) + + def decompose_bloq(self) -> 'CompositeBloq': + quregs = self.signature.get_cirq_quregs() + qubit_manager = cirq.ops.SimpleQubitManager() + cirq_op, quregs = self.as_cirq_op(qubit_manager, **quregs) + context = cirq.DecompositionContext(qubit_manager=qubit_manager) + decomposed_optree = cirq.decompose_once(cirq_op, context=context, default=None) + if decomposed_optree is None: + raise NotImplementedError(f"{self} does not support decomposition.") + return cirq_optree_to_cbloq(decomposed_optree, signature=self.signature, cirq_quregs=quregs) def add_my_tensors( self, @@ -81,28 +98,68 @@ def add_my_tensors( incoming: Dict[str, 'SoquetT'], outgoing: Dict[str, 'SoquetT'], ): - unitary = cirq.unitary(self.gate).reshape((2,) * 2 * self.n_qubits) + unitary = cirq.unitary(self.gate).reshape((2,) * 2 * self.cirq_registers.total_bits()) + incoming_list = [ + *itertools.chain.from_iterable( + [np.array(incoming[reg.name]).flatten() for reg in self.signature.lefts()] + ) + ] + outgoing_list = [ + *itertools.chain.from_iterable( + [np.array(outgoing[reg.name]).flatten() for reg in self.signature.rights()] + ) + ] tn.add( qtn.Tensor( - data=unitary, - inds=outgoing['qubits'].tolist() + incoming['qubits'].tolist(), - tags=[self.short_name(), tag], + data=unitary, inds=outgoing_list + incoming_list, tags=[self.short_name(), tag] ) ) def as_cirq_op( - self, qubit_manager: 'cirq.QubitManager', qubits: 'CirqQuregT' + self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' ) -> Tuple['cirq.Operation', Dict[str, 'CirqQuregT']]: - assert qubits.shape == (self.n_qubits, 1) - return self.gate.on(*qubits[:, 0]), {'qubits': qubits} + merged_qubits = np.concatenate( + [cirq_quregs[reg.name].flatten() for reg in self.signature.lefts()] + ) + assert len(merged_qubits) == cirq.num_qubits(self.gate) + return self.gate.on(*merged_qubits), cirq_quregs def t_complexity(self) -> 'cirq_ft.TComplexity': return cirq_ft.t_complexity(self.gate) + def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': + from qualtran.drawing import directional_text_box + + wire_symbols = cirq.circuit_diagram_info(self.gate).wire_symbols + begin = 0 + symbol: str = soq.pretty() + for reg in self.signature: + finish = begin + np.product(reg.shape) + if reg == soq.reg: + symbol = np.array(wire_symbols[begin:finish]).reshape(reg.shape)[soq.idx] + begin = finish + return directional_text_box(text=symbol, side=soq.reg.side) + + +def _split_qvars_for_regs( + qvars: Sequence[Soquet], signature: Signature +) -> Dict[str, NDArray[Soquet]]: + """Split a flat list of soquets into a dictionary corresponding to `signature`.""" + qvars_regs = {} + base = 0 + for reg in signature: + assert reg.bitsize == 1 + qvars_regs[reg.name] = np.array(qvars[base : base + reg.total_bits()]).reshape(reg.shape) + base += reg.total_bits() + return qvars_regs + def cirq_optree_to_cbloq( - optree: cirq.OP_TREE, *, signature: Optional[Signature] = None + optree: cirq.OP_TREE, + *, + signature: Optional[Signature] = None, + cirq_quregs: Optional[Dict[str, 'NDArray[cirq.Qid]']] = None, ) -> CompositeBloq: """Convert a Cirq OP-TREE into a `CompositeBloq` with signature `signature`. @@ -111,67 +168,87 @@ def cirq_optree_to_cbloq( If `signature` is not None, the signature of the resultant CompositeBloq is `signature`. For multi-dimensional registers and registers with > 1 bitsize, this function automatically splits the input soquets into a flat list and joins the output soquets into the correct shape - to ensure compatibility with the flat API expected by Cirq. + to ensure compatibility with the flat API expected by Cirq. When specifying a signature, users + must also specify the `cirq_quregs` argument, which is a mapping of cirq qubits used in the + OP-TREE corresponding to the `signature`. If `signature` has registers with entry + - `Register('x', bitsize=2, shape=(3, 4))` and + - `Register('y', bitsize=1, shape=(10, 20))` + then `cirq_quregs` should have one entry corresponding to each register as follows: + - key='x'; value=`np.array(cirq_qubits_used_in_optree, shape=(3, 4, 2))` and + - key='y'; value=`np.array(cirq_qubits_used_in_optree, shape=(10, 20, 1))`. If `signature` is None, the resultant composite bloq will have one thru-register named "qubits" of shape `(n_qubits,)`. """ - # "qubits" means cirq qubits | "qvars" means bloq Soquets circuit = cirq.Circuit(optree) - all_qubits = sorted(circuit.all_qubits()) + # "qubits" means cirq qubits | "qvars" means bloq Soquets if signature is None: + assert cirq_quregs is None + all_qubits = sorted(circuit.all_qubits()) signature = Signature([Register('qubits', 1, shape=(len(all_qubits),))]) + cirq_quregs = {'qubits': np.array(all_qubits).reshape(len(all_qubits), 1)} + + assert signature is not None and cirq_quregs is not None + bb, initial_soqs = BloqBuilder.from_signature(signature, add_registers_allowed=False) # Magic to make sure signature of the CompositeBloq matches `Signature`. - qvars = {} + qubit_to_qvar = {} for reg in signature.lefts(): + assert reg.name in cirq_quregs soqs = initial_soqs[reg.name] + if isinstance(soqs, Soquet): + soqs = np.asarray(soqs)[np.newaxis, ...] if reg.bitsize > 1: - # Need to split all soquets here. - if isinstance(soqs, Soquet): - qvars[reg.name] = bb.split(soqs) - else: - qvars[reg.name] = np.concatenate([bb.split(soq) for soq in soqs.reshape(-1)]) - else: - if isinstance(soqs, Soquet): - qvars[reg.name] = [soqs] - else: - qvars[reg.name] = soqs.reshape(-1) - - qubit_to_qvar = dict(zip(all_qubits, np.concatenate([*qvars.values()]))) + soqs = np.array([bb.split(soq) for soq in soqs.flatten()]) + soqs = soqs.reshape(reg.shape + (reg.bitsize,)) + assert cirq_quregs[reg.name].shape == soqs.shape + qubit_to_qvar |= zip(cirq_quregs[reg.name].flatten(), soqs.flatten()) + allocated_qubits = set() for op in circuit.all_operations(): if op.gate is None: raise ValueError(f"Only gate operations are supported, not {op}.") bloq = CirqGateAsBloq(op.gate) - qvars_for_op = np.array([qubit_to_qvar[qubit] for qubit in op.qubits]) - qvars_for_op_out = bb.add(bloq, qubits=qvars_for_op) - qubit_to_qvar |= zip(op.qubits, qvars_for_op_out) + for q in op.qubits: + if q not in qubit_to_qvar: + qubit_to_qvar[q] = bb.add(Allocate(1)) + allocated_qubits.add(q) + + qvars_in = [qubit_to_qvar[qubit] for qubit in op.qubits] + qvars_out = bb.add_t(bloq, **_split_qvars_for_regs(qvars_in, bloq.signature)) + qubit_to_qvar |= zip( + op.qubits, itertools.chain.from_iterable([arr.flatten() for arr in qvars_out]) + ) - qvar_vals_out = np.array([qubit_to_qvar[qubit] for qubit in all_qubits]) + for q in allocated_qubits: + bb.add(Free(1), free=qubit_to_qvar[q]) + qvars = np.array([*qubit_to_qvar.values()]) final_soqs = {} idx = 0 for reg in signature.rights(): name = reg.name - soqs = qvar_vals_out[idx : idx + len(qvars[name])] - idx = idx + len(qvars[name]) + assert name in cirq_quregs + soqs = qvars[idx : idx + np.product(cirq_quregs[name].shape)] + idx = idx + np.product(cirq_quregs[name].shape) if reg.bitsize > 1: # Need to combine the soquets here. if len(soqs) == reg.bitsize: final_soqs[name] = bb.join(soqs) else: final_soqs[name] = np.array( - bb.join(subsoqs) for subsoqs in soqs[:: reg.bitsize] + [ + bb.join(soqs[st : st + reg.bitsize]) + for st in range(0, len(soqs), reg.bitsize) + ] ).reshape(reg.shape) else: - if len(soqs) == 1: + if len(soqs) == 1 and reg.shape == (): final_soqs[name] = soqs[0] else: final_soqs[name] = soqs.reshape(reg.shape) - return bb.finalize(**final_soqs) @@ -211,7 +288,9 @@ def _update_assign_from_cirq_quregs( arr = np.asarray(arr) full_shape = reg.shape + (reg.bitsize,) if arr.shape != full_shape: - raise ValueError(f"Incorrect shape {arr.shape} received for {binst}.{reg.name}") + raise ValueError( + f"Incorrect shape {arr.shape} received for {binst}.{reg.name}. Expected {full_shape}." + ) for idx in reg.all_idxs(): soq = Soquet(binst, reg, idx=idx) @@ -278,7 +357,7 @@ def decompose_from_cirq_op(bloq: 'Bloq') -> 'CompositeBloq': ): raise NotImplementedError(f"{bloq} does not support decomposition.") - return cirq_optree_to_cbloq(cirq_op, signature=bloq.signature) + return cirq_optree_to_cbloq(cirq_op, signature=bloq.signature, cirq_quregs=cirq_quregs) def _cbloq_to_cirq_circuit( @@ -351,14 +430,12 @@ def bloq(self) -> Bloq: return self._bloq @property - def registers(self) -> LegacyRegisters: + def registers(self) -> cirq_ft.Registers: """`cirq_ft.GateWithRegisters` registers.""" return self._legacy_regs @staticmethod - def _init_legacy_regs( - bloq: Bloq, - ) -> Tuple[LegacyRegisters, Mapping[str, Tuple[Register, Tuple[int, ...]]]]: + def _init_legacy_regs(bloq: Bloq) -> Tuple[cirq_ft.Registers, Mapping[str, Register]]: """Initialize legacy registers. We flatten multidimensional registers and annotate non-thru registers with @@ -369,23 +446,15 @@ def _init_legacy_regs( compat_name_map: A mapping from the compatability-shim string names of the legacy registers back to the original (register, idx) pair. """ - legacy_regs: List[LegacyRegister] = [] + legacy_regs: List[cirq_ft.Register] = [] side_suffixes = {Side.LEFT: '_l', Side.RIGHT: '_r', Side.THRU: ''} compat_name_map = {} for reg in bloq.signature: - if not reg.shape: - compat_name = f'{reg.name}{side_suffixes[reg.side]}' - compat_name_map[compat_name] = (reg, ()) - legacy_regs.append(LegacyRegister(name=compat_name, shape=reg.bitsize)) - continue - - for idx in reg.all_idxs(): - idx_str = '_'.join(str(i) for i in idx) - compat_name = f'{reg.name}{side_suffixes[reg.side]}_{idx_str}' - compat_name_map[compat_name] = (reg, idx) - legacy_regs.append(LegacyRegister(name=compat_name, shape=reg.bitsize)) - - return LegacyRegisters(legacy_regs), compat_name_map + compat_name = f'{reg.name}{side_suffixes[reg.side]}' + compat_name_map[compat_name] = reg + full_shape = reg.shape + (reg.bitsize,) + legacy_regs.append(cirq_ft.Register(name=compat_name, shape=full_shape)) + return cirq_ft.Registers(legacy_regs), compat_name_map @classmethod def bloq_on( @@ -405,27 +474,25 @@ def bloq_on( op: A cirq operation whose gate is the `BloqAsCirqGate`-wrapped version of `bloq`. cirq_quregs: The output cirq qubit registers. """ - flat_qubits: List[cirq.Qid] = [] + bloq_quregs: Dict[str, 'CirqQuregT'] = {} out_quregs: Dict[str, 'CirqQuregT'] = {} for reg in bloq.signature: if reg.side is Side.THRU: - for i, q in enumerate(cirq_quregs[reg.name].reshape(-1)): - flat_qubits.append(q) + bloq_quregs[reg.name] = cirq_quregs[reg.name] out_quregs[reg.name] = cirq_quregs[reg.name] elif reg.side is Side.LEFT: - for i, q in enumerate(cirq_quregs[reg.name].reshape(-1)): - flat_qubits.append(q) + bloq_quregs[f'{reg.name}_l'] = cirq_quregs[reg.name] qubit_manager.qfree(cirq_quregs[reg.name].reshape(-1)) del cirq_quregs[reg.name] elif reg.side is Side.RIGHT: new_qubits = qubit_manager.qalloc(reg.total_bits()) - flat_qubits.extend(new_qubits) - out_quregs[reg.name] = np.array(new_qubits).reshape(reg.shape + (reg.bitsize,)) - - return BloqAsCirqGate(bloq=bloq).on(*flat_qubits), out_quregs + full_shape = reg.shape + (reg.bitsize,) + out_quregs[reg.name] = np.array(new_qubits).reshape(full_shape) + bloq_quregs[f'{reg.name}_r'] = out_quregs[reg.name] + return BloqAsCirqGate(bloq=bloq).on_registers(**bloq_quregs), out_quregs def decompose_from_registers( - self, context: cirq.DecompositionContext, **qubit_regs: Sequence[cirq.Qid] + self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: """Implementation of the GatesWithRegisters decompose method. @@ -434,7 +501,7 @@ def decompose_from_registers( Args: context: `cirq.DecompositionContext` stores options for decomposing gates (eg: cirq.QubitManager). - **qubit_regs: Sequences of cirq qubits as expected for the legacy register shims + **quregs: Sequences of cirq qubits as expected for the legacy register shims of the bloq's registers. Returns: @@ -442,20 +509,10 @@ def decompose_from_registers( """ cbloq = self._bloq.decompose_bloq() - # Initialize shapely qubit registers to pass to bloqs infrastructure cirq_quregs: Dict[str, CirqQuregT] = {} - for reg in self._bloq.signature: - if reg.shape: - shape = reg.shape + (reg.bitsize,) - cirq_quregs[reg.name] = np.empty(shape, dtype=object) - - # Shapefy the provided cirq qubits - for compat_name, qubits in qubit_regs.items(): - reg, idx = self._compat_name_map[compat_name] - if idx == (): - cirq_quregs[reg.name] = np.asarray(qubits) - else: - cirq_quregs[reg.name][idx] = np.asarray(qubits) + for compat_name, qubits in quregs.items(): + reg = self._compat_name_map[compat_name] + cirq_quregs[reg.name] = np.asarray(qubits) circuit, _ = cbloq.to_cirq_circuit(qubit_manager=context.qubit_manager, **cirq_quregs) return circuit @@ -481,7 +538,8 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ symbs = reg_to_wires(reg) assert len(symbs) == reg.total_bits() wire_symbols.extend(symbs) - + if self._reg_to_wires is None: + wire_symbols[0] = self._bloq.pretty_name() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def __eq__(self, other): diff --git a/qualtran/cirq_interop/_cirq_interop_test.py b/qualtran/cirq_interop/_cirq_interop_test.py index 2365180e7..96681cd66 100644 --- a/qualtran/cirq_interop/_cirq_interop_test.py +++ b/qualtran/cirq_interop/_cirq_interop_test.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict, Tuple +import attr import cirq import cirq_ft import numpy as np @@ -20,10 +21,14 @@ import sympy from attrs import frozen +import qualtran from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature, Soquet, SoquetT from qualtran.bloqs.and_bloq import MultiAnd from qualtran.bloqs.basic_gates import XGate +from qualtran.bloqs.swap_network import SwapWithZero +from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Split from qualtran.cirq_interop import ( + BloqAsCirqGate, cirq_optree_to_cbloq, CirqGateAsBloq, CirqQuregT, @@ -83,6 +88,66 @@ def test_cbloq_to_cirq_circuit(): assert circuit == circuit2 +def test_cirq_optree_to_cbloq(): + @attr.frozen + class CirqGateWithRegisters(cirq_ft.GateWithRegisters): + reg: cirq_ft.Register + + @property + def registers(self) -> cirq_ft.Registers: + return cirq_ft.Registers([self.reg]) + + reg1 = cirq_ft.Register('x', shape=(3, 4, 2)) + reg2 = cirq_ft.Register('y', shape=(12, 2)) + anc_reg = cirq_ft.Register('anc', shape=(2, 3)) + qubits = cirq.LineQubit.range(24) + anc_qubits = cirq.NamedQubit.range(3, prefix='anc') + circuit = cirq.Circuit( + CirqGateWithRegisters(reg1).on(*qubits), + CirqGateWithRegisters(anc_reg).on(*anc_qubits, *qubits[:3]), + CirqGateWithRegisters(reg2).on(*qubits), + ) + # Test-1: When no signature is specified, the method uses a default signature. Ancilla qubits + # are also included in the signature itself, so no allocations / deallocations are needed. + cbloq = cirq_optree_to_cbloq(circuit) + assert cbloq.signature == qualtran.Signature( + [qualtran.Register(name='qubits', bitsize=1, shape=(27,))] + ) + bloq_instances = [binst for binst, _, _ in cbloq.iter_bloqnections()] + assert bloq_instances[0].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg1)) + assert bloq_instances[0].bloq.signature == qualtran.Signature( + [qualtran.Register(name='x', bitsize=1, shape=(3, 4, 2))] + ) + assert bloq_instances[1].bloq == CirqGateAsBloq(CirqGateWithRegisters(anc_reg)) + assert bloq_instances[1].bloq.signature == qualtran.Signature( + [qualtran.Register(name='anc', bitsize=1, shape=(2, 3))] + ) + assert bloq_instances[2].bloq == CirqGateAsBloq(CirqGateWithRegisters(reg2)) + assert bloq_instances[2].bloq.signature == qualtran.Signature( + [qualtran.Register(name='y', bitsize=1, shape=(12, 2))] + ) + # Test-2: If you provide an explicit signature, you must also provide a mapping of cirq qubits + # matching the signature. The additional ancilla allocations are automatically handled. + new_signature = qualtran.Signature( + [ + qualtran.Register('xx', bitsize=3, shape=(3, 2)), + qualtran.Register('yy', bitsize=1, shape=(2, 3)), + ] + ) + cirq_quregs = { + 'xx': np.asarray(qubits[:18]).reshape((3, 2, 3)), + 'yy': np.asarray(qubits[18:]).reshape((2, 3, 1)), + } + cbloq = cirq_optree_to_cbloq(circuit, signature=new_signature, cirq_quregs=cirq_quregs) + assert cbloq.signature == new_signature + # Splits, joins, Alloc, Free are automatically inserted. + bloqs_list = [binst.bloq for binst in cbloq.bloq_instances] + assert bloqs_list.count(Split(3)) == 6 + assert bloqs_list.count(Join(3)) == 6 + assert bloqs_list.count(Allocate(1)) == 3 + assert bloqs_list.count(Free(1)) == 3 + + @frozen class SwapTwoBitsTest(Bloq): @property @@ -178,7 +243,7 @@ def test_bloq_as_cirq_gate_left_register(): bb.free(q) cbloq = bb.finalize() circuit, _ = cbloq.to_cirq_circuit() - cirq.testing.assert_has_diagram(circuit, """_c(0): ───alloc───X───free───""") + cirq.testing.assert_has_diagram(circuit, """_c(0): ───Allocate───X───Free───""") @frozen @@ -223,5 +288,76 @@ def test_bloq_decompose_from_cirq_op(): TestCNOTSymbolic().decompose_bloq() +def test_bloq_as_cirq_gate_multi_dimensional_signature(): + bloq = SwapWithZero(2, 3, 4) + cirq_quregs = bloq.signature.get_cirq_quregs() + op = BloqAsCirqGate(bloq).on_registers(**cirq_quregs) + cirq.testing.assert_has_diagram( + cirq.Circuit(op), + ''' +selection0: ──────SwapWithZero─── + │ +selection1: ──────selection────── + │ +targets[0, 0]: ───targets──────── + │ +targets[0, 1]: ───targets──────── + │ +targets[0, 2]: ───targets──────── + │ +targets[1, 0]: ───targets──────── + │ +targets[1, 1]: ───targets──────── + │ +targets[1, 2]: ───targets──────── + │ +targets[2, 0]: ───targets──────── + │ +targets[2, 1]: ───targets──────── + │ +targets[2, 2]: ───targets──────── + │ +targets[3, 0]: ───targets──────── + │ +targets[3, 1]: ───targets──────── + │ +targets[3, 2]: ───targets──────── +''', + ) + cbloq = bloq.decompose_bloq() + cirq.testing.assert_has_diagram( + cbloq.to_cirq_circuit(**cirq_quregs)[0], + ''' +selection0: ──────────────────────────────@(approx)─── + │ +selection1: ──────@(approx)───@(approx)───┼─────────── + │ │ │ +targets[0, 0]: ───×(x)────────┼───────────×(x)──────── + │ │ │ +targets[0, 1]: ───×(x)────────┼───────────×(x)──────── + │ │ │ +targets[0, 2]: ───×(x)────────┼───────────×(x)──────── + │ │ │ +targets[1, 0]: ───×(y)────────┼───────────┼─────────── + │ │ │ +targets[1, 1]: ───×(y)────────┼───────────┼─────────── + │ │ │ +targets[1, 2]: ───×(y)────────┼───────────┼─────────── + │ │ +targets[2, 0]: ───────────────×(x)────────×(y)──────── + │ │ +targets[2, 1]: ───────────────×(x)────────×(y)──────── + │ │ +targets[2, 2]: ───────────────×(x)────────×(y)──────── + │ +targets[3, 0]: ───────────────×(y)──────────────────── + │ +targets[3, 1]: ───────────────×(y)──────────────────── + │ +targets[3, 2]: ───────────────×(y)──────────────────── +''', + ) + + def test_notebook(): execute_notebook('cirq_interop') diff --git a/qualtran/cirq_interop/cirq_interop.ipynb b/qualtran/cirq_interop/cirq_interop.ipynb index ca96d877e..5d56f3766 100644 --- a/qualtran/cirq_interop/cirq_interop.ipynb +++ b/qualtran/cirq_interop/cirq_interop.ipynb @@ -165,6 +165,61 @@ "cirq.Circuit(circuit.all_operations()) == circuit2" ] }, + { + "cell_type": "markdown", + "id": "115b1c2f-001c-4c03-aefa-150373800184", + "metadata": {}, + "source": [ + "# Importing Cirq-FT algorithms to Bloqs\n", + "`CirqGateAsBloq` also supports wrapping Cirq-FT's `GateWithRegisters` objects. As an example, we show how you can directly import `SELECT` and `PREPARE` primitives for the 2D Hubbard model from Cirq-FT into Bloqs. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee1f5334-da8c-40d8-9b6f-55b7f0183d30", + "metadata": {}, + "outputs": [], + "source": [ + "import cirq_ft\n", + "from cirq_ft.algos.hubbard_model import SelectHubbard, PrepareHubbard\n", + "import cirq_ft.infra.testing as cq_testing\n", + "x_dim, y_dim, t = 2, 2, 5\n", + "mu = 4 * t\n", + "# SELECT and PREPARE for 2D Hubbard Model\n", + "prepare = cq_testing.GateHelper(PrepareHubbard(x_dim=x_dim, y_dim=x_dim, t=t, mu=mu))\n", + "select = cq_testing.GateHelper(SelectHubbard(x_dim=x_dim, y_dim=y_dim, control_val=1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a591511-54dd-4040-8328-ff01eb62782d", + "metadata": {}, + "outputs": [], + "source": [ + "from qualtran.drawing.musical_score import get_musical_score_data, draw_musical_score\n", + "print(cirq.Circuit(cirq.decompose_once(select.operation)))\n", + "bloq = CirqGateAsBloq(select.gate)\n", + "fig, ax = draw_musical_score(get_musical_score_data(bloq.decompose_bloq()))\n", + "fig.set_size_inches(30, 12)\n", + "assert bloq.t_complexity() == cirq_ft.t_complexity(select.gate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6170e55-94a9-43f2-8503-e9a8e2d0f193", + "metadata": {}, + "outputs": [], + "source": [ + "print(cirq.Circuit(cirq.decompose_once(prepare.operation)))\n", + "bloq = CirqGateAsBloq(prepare.gate)\n", + "fig, ax = draw_musical_score(get_musical_score_data(bloq.decompose_bloq()))\n", + "fig.set_size_inches(30, 12)\n", + "assert bloq.t_complexity() == cirq_ft.t_complexity(prepare.gate)" + ] + }, { "cell_type": "markdown", "id": "03f03231", @@ -370,9 +425,7 @@ "cell_type": "code", "execution_count": null, "id": "3fd9ebd7", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "import cirq_ft.infra.testing as cq_testing\n", @@ -470,7 +523,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/qualtran/drawing/graphviz.py b/qualtran/drawing/graphviz.py index ac578626b..442b0a3f1 100644 --- a/qualtran/drawing/graphviz.py +++ b/qualtran/drawing/graphviz.py @@ -378,13 +378,9 @@ def get_binst_header_text(self, binst: BloqInstance): def soq_label(self, soq: Soquet): from qualtran.bloqs.util_bloqs import Join, Split - from qualtran.cirq_interop import CirqGateAsBloq if isinstance(soq.binst, BloqInstance) and isinstance(soq.binst.bloq, (Split, Join)): return '' - if isinstance(soq.binst, BloqInstance) and isinstance(soq.binst.bloq, CirqGateAsBloq): - (ii,) = soq.idx - return f'q{ii}' return soq.pretty() def get_default_text(self, reg: Register) -> str: