From 72248b01203f1660e468b40f82531f040a138efa Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 17 Aug 2023 13:48:09 -0700 Subject: [PATCH] Use Cirq-FT's multi-dimensional registers directly in BloqAsCirqGate (#353) * Use Cirq-FT's multi-dimensional registers directly in BloqAsCirqGate * Fix failing tests --------- Co-authored-by: Fionn Malone --- qualtran/_infra/composite_bloq.py | 3 +- qualtran/bloqs/swap_network_cirq_test.py | 9 ++- qualtran/cirq_interop/_cirq_interop.py | 63 ++++++----------- qualtran/cirq_interop/_cirq_interop_test.py | 75 ++++++++++++++++++++- 4 files changed, 101 insertions(+), 49 deletions(-) diff --git a/qualtran/_infra/composite_bloq.py b/qualtran/_infra/composite_bloq.py index eca989b81..a9a1899ac 100644 --- a/qualtran/_infra/composite_bloq.py +++ b/qualtran/_infra/composite_bloq.py @@ -907,8 +907,7 @@ def add(self, bloq: Bloq, **in_soqs: SoquetInT) -> Union[None, SoquetT, Tuple[So unpacking. In this final case, the ordering is according to `bloq.signature` and irrespective of the order of `**in_soqs`. """ - binst = BloqInstance(bloq, i=self._new_binst_i()) - outs = tuple(soq for _, soq in self._add_binst(binst, in_soqs=in_soqs)) + outs = self.add_t(bloq, **in_soqs) if len(outs) == 0: return None if len(outs) == 1: diff --git a/qualtran/bloqs/swap_network_cirq_test.py b/qualtran/bloqs/swap_network_cirq_test.py index f0dd4b3b0..390f0534e 100644 --- a/qualtran/bloqs/swap_network_cirq_test.py +++ b/qualtran/bloqs/swap_network_cirq_test.py @@ -39,12 +39,11 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe # Allocate selection and target qubits. all_qubits = cirq.LineQubit.range(cirq.num_qubits(gate)) selection = all_qubits[:selection_bitsize] - targets = { - f'targets_{i}': all_qubits[st : st + target_bitsize] - for i, st in enumerate(range(selection_bitsize, len(all_qubits), target_bitsize)) - } + targets = np.asarray(all_qubits[selection_bitsize:]).reshape( + (n_target_registers, target_bitsize) + ) # Create a circuit. - circuit = cirq.Circuit(gate.on_registers(selection=selection, **targets)) + circuit = cirq.Circuit(gate.on_registers(selection=selection, targets=targets)) # Load data[i] in i'th target register; where each register is of size target_bitsize data = [random.randint(0, 2**target_bitsize - 1) for _ in range(n_target_registers)] diff --git a/qualtran/cirq_interop/_cirq_interop.py b/qualtran/cirq_interop/_cirq_interop.py index 0d4dba77a..0036aebc1 100644 --- a/qualtran/cirq_interop/_cirq_interop.py +++ b/qualtran/cirq_interop/_cirq_interop.py @@ -211,7 +211,9 @@ def _update_assign_from_cirq_quregs( arr = np.asarray(arr) full_shape = reg.shape + (reg.bitsize,) if arr.shape != full_shape: - raise ValueError(f"Incorrect shape {arr.shape} received for {binst}.{reg.name}") + raise ValueError( + f"Incorrect shape {arr.shape} received for {binst}.{reg.name}. Expected {full_shape}." + ) for idx in reg.all_idxs(): soq = Soquet(binst, reg, idx=idx) @@ -356,9 +358,7 @@ def registers(self) -> LegacyRegisters: return self._legacy_regs @staticmethod - def _init_legacy_regs( - bloq: Bloq, - ) -> Tuple[LegacyRegisters, Mapping[str, Tuple[Register, Tuple[int, ...]]]]: + def _init_legacy_regs(bloq: Bloq) -> Tuple[LegacyRegisters, Mapping[str, Register]]: """Initialize legacy registers. We flatten multidimensional registers and annotate non-thru registers with @@ -373,18 +373,10 @@ def _init_legacy_regs( side_suffixes = {Side.LEFT: '_l', Side.RIGHT: '_r', Side.THRU: ''} compat_name_map = {} for reg in bloq.signature: - if not reg.shape: - compat_name = f'{reg.name}{side_suffixes[reg.side]}' - compat_name_map[compat_name] = (reg, ()) - legacy_regs.append(LegacyRegister(name=compat_name, shape=reg.bitsize)) - continue - - for idx in reg.all_idxs(): - idx_str = '_'.join(str(i) for i in idx) - compat_name = f'{reg.name}{side_suffixes[reg.side]}_{idx_str}' - compat_name_map[compat_name] = (reg, idx) - legacy_regs.append(LegacyRegister(name=compat_name, shape=reg.bitsize)) - + compat_name = f'{reg.name}{side_suffixes[reg.side]}' + compat_name_map[compat_name] = reg + full_shape = reg.shape + (reg.bitsize,) + legacy_regs.append(LegacyRegister(name=compat_name, shape=full_shape)) return LegacyRegisters(legacy_regs), compat_name_map @classmethod @@ -405,27 +397,25 @@ def bloq_on( op: A cirq operation whose gate is the `BloqAsCirqGate`-wrapped version of `bloq`. cirq_quregs: The output cirq qubit registers. """ - flat_qubits: List[cirq.Qid] = [] + bloq_quregs: Dict[str, 'CirqQuregT'] = {} out_quregs: Dict[str, 'CirqQuregT'] = {} for reg in bloq.signature: if reg.side is Side.THRU: - for i, q in enumerate(cirq_quregs[reg.name].reshape(-1)): - flat_qubits.append(q) + bloq_quregs[reg.name] = cirq_quregs[reg.name] out_quregs[reg.name] = cirq_quregs[reg.name] elif reg.side is Side.LEFT: - for i, q in enumerate(cirq_quregs[reg.name].reshape(-1)): - flat_qubits.append(q) + bloq_quregs[f'{reg.name}_l'] = cirq_quregs[reg.name] qubit_manager.qfree(cirq_quregs[reg.name].reshape(-1)) del cirq_quregs[reg.name] elif reg.side is Side.RIGHT: new_qubits = qubit_manager.qalloc(reg.total_bits()) - flat_qubits.extend(new_qubits) - out_quregs[reg.name] = np.array(new_qubits).reshape(reg.shape + (reg.bitsize,)) - - return BloqAsCirqGate(bloq=bloq).on(*flat_qubits), out_quregs + full_shape = reg.shape + (reg.bitsize,) + out_quregs[reg.name] = np.array(new_qubits).reshape(full_shape) + bloq_quregs[f'{reg.name}_r'] = out_quregs[reg.name] + return BloqAsCirqGate(bloq=bloq).on_registers(**bloq_quregs), out_quregs def decompose_from_registers( - self, context: cirq.DecompositionContext, **qubit_regs: Sequence[cirq.Qid] + self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: """Implementation of the GatesWithRegisters decompose method. @@ -434,7 +424,7 @@ def decompose_from_registers( Args: context: `cirq.DecompositionContext` stores options for decomposing gates (eg: cirq.QubitManager). - **qubit_regs: Sequences of cirq qubits as expected for the legacy register shims + **quregs: Sequences of cirq qubits as expected for the legacy register shims of the bloq's registers. Returns: @@ -442,20 +432,10 @@ def decompose_from_registers( """ cbloq = self._bloq.decompose_bloq() - # Initialize shapely qubit registers to pass to bloqs infrastructure cirq_quregs: Dict[str, CirqQuregT] = {} - for reg in self._bloq.signature: - if reg.shape: - shape = reg.shape + (reg.bitsize,) - cirq_quregs[reg.name] = np.empty(shape, dtype=object) - - # Shapefy the provided cirq qubits - for compat_name, qubits in qubit_regs.items(): - reg, idx = self._compat_name_map[compat_name] - if idx == (): - cirq_quregs[reg.name] = np.asarray(qubits) - else: - cirq_quregs[reg.name][idx] = np.asarray(qubits) + for compat_name, qubits in quregs.items(): + reg = self._compat_name_map[compat_name] + cirq_quregs[reg.name] = np.asarray(qubits) circuit, _ = cbloq.to_cirq_circuit(qubit_manager=context.qubit_manager, **cirq_quregs) return circuit @@ -481,7 +461,8 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ symbs = reg_to_wires(reg) assert len(symbs) == reg.total_bits() wire_symbols.extend(symbs) - + if self._reg_to_wires is None: + wire_symbols[0] = self._bloq.pretty_name() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def __eq__(self, other): diff --git a/qualtran/cirq_interop/_cirq_interop_test.py b/qualtran/cirq_interop/_cirq_interop_test.py index 2365180e7..45c1cf8c5 100644 --- a/qualtran/cirq_interop/_cirq_interop_test.py +++ b/qualtran/cirq_interop/_cirq_interop_test.py @@ -23,7 +23,9 @@ from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature, Soquet, SoquetT from qualtran.bloqs.and_bloq import MultiAnd from qualtran.bloqs.basic_gates import XGate +from qualtran.bloqs.swap_network import SwapWithZero from qualtran.cirq_interop import ( + BloqAsCirqGate, cirq_optree_to_cbloq, CirqGateAsBloq, CirqQuregT, @@ -178,7 +180,7 @@ def test_bloq_as_cirq_gate_left_register(): bb.free(q) cbloq = bb.finalize() circuit, _ = cbloq.to_cirq_circuit() - cirq.testing.assert_has_diagram(circuit, """_c(0): ───alloc───X───free───""") + cirq.testing.assert_has_diagram(circuit, """_c(0): ───Allocate───X───Free───""") @frozen @@ -223,5 +225,76 @@ def test_bloq_decompose_from_cirq_op(): TestCNOTSymbolic().decompose_bloq() +def test_bloq_as_cirq_gate_multi_dimensional_signature(): + bloq = SwapWithZero(2, 3, 4) + cirq_quregs = bloq.signature.get_cirq_quregs() + op = BloqAsCirqGate(bloq).on_registers(**cirq_quregs) + cirq.testing.assert_has_diagram( + cirq.Circuit(op), + ''' +selection0: ──────SwapWithZero─── + │ +selection1: ──────selection────── + │ +targets[0, 0]: ───targets──────── + │ +targets[0, 1]: ───targets──────── + │ +targets[0, 2]: ───targets──────── + │ +targets[1, 0]: ───targets──────── + │ +targets[1, 1]: ───targets──────── + │ +targets[1, 2]: ───targets──────── + │ +targets[2, 0]: ───targets──────── + │ +targets[2, 1]: ───targets──────── + │ +targets[2, 2]: ───targets──────── + │ +targets[3, 0]: ───targets──────── + │ +targets[3, 1]: ───targets──────── + │ +targets[3, 2]: ───targets──────── +''', + ) + cbloq = bloq.decompose_bloq() + cirq.testing.assert_has_diagram( + cbloq.to_cirq_circuit(**cirq_quregs)[0], + ''' +selection0: ──────────────────────────────@(approx)─── + │ +selection1: ──────@(approx)───@(approx)───┼─────────── + │ │ │ +targets[0, 0]: ───×(x)────────┼───────────×(x)──────── + │ │ │ +targets[0, 1]: ───×(x)────────┼───────────×(x)──────── + │ │ │ +targets[0, 2]: ───×(x)────────┼───────────×(x)──────── + │ │ │ +targets[1, 0]: ───×(y)────────┼───────────┼─────────── + │ │ │ +targets[1, 1]: ───×(y)────────┼───────────┼─────────── + │ │ │ +targets[1, 2]: ───×(y)────────┼───────────┼─────────── + │ │ +targets[2, 0]: ───────────────×(x)────────×(y)──────── + │ │ +targets[2, 1]: ───────────────×(x)────────×(y)──────── + │ │ +targets[2, 2]: ───────────────×(x)────────×(y)──────── + │ +targets[3, 0]: ───────────────×(y)──────────────────── + │ +targets[3, 1]: ───────────────×(y)──────────────────── + │ +targets[3, 2]: ───────────────×(y)──────────────────── +''', + ) + + def test_notebook(): execute_notebook('cirq_interop')