From 2cf9930b828c2301accd465fa0f86a48fbe11c6a Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 29 Sep 2023 17:11:00 -0700 Subject: [PATCH] Extend functionality for CirqGateAsBloq.tensor_contract and add more tests --- qualtran/cirq_interop/_bloq_to_cirq_test.py | 5 +++ qualtran/cirq_interop/_cirq_to_bloq.py | 42 ++++++++++++++++++--- qualtran/cirq_interop/_cirq_to_bloq_test.py | 18 ++++++++- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/qualtran/cirq_interop/_bloq_to_cirq_test.py b/qualtran/cirq_interop/_bloq_to_cirq_test.py index ae70329f5..6eae91e40 100644 --- a/qualtran/cirq_interop/_bloq_to_cirq_test.py +++ b/qualtran/cirq_interop/_bloq_to_cirq_test.py @@ -238,6 +238,11 @@ def test_bloq_as_cirq_gate_for_mod_exp(): # Use Cirq's infrastructure to construct an operation and corresponding decomposition. quregs = cirq_ft.infra.get_named_qubits(gate.signature) op = gate.on_registers(**quregs) + # cirq.decompose_once(op) delegates to underlying Bloq's decomposition specified in + # `bloq.decompose_bloq()` and wraps resulting composite bloq in a Cirq op-tree. Note + # how `BloqAsCirqGate.decompose_with_registers()` automatically takes care of mapping + # newly allocated RIGHT registers in the decomposition to the one's specified by the user + # when constructing the original operation (in this case, register `x`). circuit = cirq.Circuit(op, cirq.decompose_once(op)) assert cirq_ft.t_complexity(circuit) == 2 * mod_exp.t_complexity() cirq.testing.assert_has_diagram( diff --git a/qualtran/cirq_interop/_cirq_to_bloq.py b/qualtran/cirq_interop/_cirq_to_bloq.py index 357b051dc..956e0088e 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq.py +++ b/qualtran/cirq_interop/_cirq_to_bloq.py @@ -14,6 +14,7 @@ """Cirq gates/circuits to Qualtran Bloqs conversion.""" import itertools +from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union @@ -84,13 +85,42 @@ def add_my_tensors( incoming: Dict[str, 'SoquetT'], outgoing: Dict[str, 'SoquetT'], ): - new_shape = [ - *itertools.chain.from_iterable( - (2**reg.bitsize,) * int(np.prod(reg.shape)) - for reg in [*self.signature.rights(), *self.signature.lefts()] + if not cirq.has_unitary(self.gate): + raise NotImplementedError( + f"CirqGateAsBloq.add_my_tensors is currently supported only for unitary gates. " + f"Found {self.gate}." ) - ] - unitary = cirq.unitary(self.gate).reshape(new_shape) + unitary_shape = [] + reg_to_idx = defaultdict(list) + for reg in self.cirq_registers: + start = len(unitary_shape) + for i in range(int(np.prod(reg.shape))): + reg_to_idx[reg.name].append(start + i) + unitary_shape.append(2**reg.bitsize) + + unitary_shape = (*unitary_shape, *unitary_shape) + unitary = cirq.unitary(self.gate).reshape(unitary_shape) + idx: List[Union[int, slice]] = [slice(x) for x in unitary_shape] + n = len(unitary_shape) // 2 + for reg in self.signature: + if reg.side == Side.LEFT: + for i in reg_to_idx[reg.name]: + # LEFT register ends, extract right subspace that's equivalent to 0. + idx[i] = 0 + if reg.side == Side.RIGHT: + for i in reg_to_idx[reg.name]: + # Right register begins, extract the left subspace that's equivalent to 0. + idx[i + n] = 0 + unitary = unitary[tuple(idx)] + new_shape = tuple( + [ + *itertools.chain.from_iterable( + (2**reg.bitsize,) * int(np.prod(reg.shape)) + for reg in [*self.signature.rights(), *self.signature.lefts()] + ) + ] + ) + assert unitary.shape == new_shape incoming_list = [ *itertools.chain.from_iterable( [np.array(incoming[reg.name]).flatten() for reg in self.signature.lefts()] diff --git a/qualtran/cirq_interop/_cirq_to_bloq_test.py b/qualtran/cirq_interop/_cirq_to_bloq_test.py index 488f8e497..52c1eae58 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq_test.py +++ b/qualtran/cirq_interop/_cirq_to_bloq_test.py @@ -22,7 +22,8 @@ from attrs import frozen import qualtran -from qualtran import Bloq, CompositeBloq, Side, Signature +from qualtran import Bloq, BloqBuilder, CompositeBloq, Side, Signature +from qualtran.bloqs.basic_gates import OneState from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Split from qualtran.cirq_interop import ( cirq_optree_to_cbloq, @@ -57,7 +58,7 @@ def signature(self) -> Signature: return Signature.build(control=c, target=t) -def test_cirq_gate(): +def test_cirq_gate_as_bloq_for_trivial_gates(): x = CirqGateAsBloq(cirq.X) rx = CirqGateAsBloq(cirq.Rx(rads=0.123 * np.pi)) toffoli = CirqGateAsBloq(cirq.TOFFOLI) @@ -80,6 +81,19 @@ def test_cirq_gate(): assert toffoli.short_name() == 'cirq.TOFFOLI' +def test_cirq_gate_as_bloq_tensor_contract_for_and_gate(): + and_gate = cirq_ft.And() + bb = BloqBuilder() + ctrl = [bb.add(OneState()) for _ in range(2)] + ctrl, target = bb.add(CirqGateAsBloq(and_gate), ctrl=ctrl) + cbloq = bb.finalize(ctrl=ctrl, target=target) + state_vector = cbloq.tensor_contract() + assert np.isclose(state_vector[7], 1) + + with pytest.raises(NotImplementedError, match="supported only for unitary gates"): + _ = CirqGateAsBloq(cirq_ft.And(adjoint=True)).as_composite_bloq().tensor_contract() + + def test_bloq_decompose_from_cirq_op(): tb = TestCNOT() assert len(tb.signature) == 2