From d0c49b6601a7f70b540e3852fcf3be0bc373b90c Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 16 May 2024 17:50:30 -0700 Subject: [PATCH] First attempt at bugfix --- qualtran/_infra/adjoint.py | 15 +++++++----- qualtran/bloqs/mcmt/and_bloq.py | 1 + .../prepare_uniform_superposition_test.py | 13 ++++++++++ qualtran/bloqs/util_bloqs.py | 15 ++++++++++++ qualtran/cirq_interop/_bloq_to_cirq_test.py | 2 +- .../cirq_interop/_interop_qubit_manager.py | 24 +++++++++++++++++-- 6 files changed, 61 insertions(+), 9 deletions(-) diff --git a/qualtran/_infra/adjoint.py b/qualtran/_infra/adjoint.py index 16b9d2d3d..99a78b2fb 100644 --- a/qualtran/_infra/adjoint.py +++ b/qualtran/_infra/adjoint.py @@ -142,12 +142,15 @@ def decompose_bloq(self) -> 'CompositeBloq': """The decomposition is the adjoint of `subbloq`'s decomposition.""" return self.subbloq.decompose_bloq().adjoint() - def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var] - ) -> cirq.OP_TREE: - if isinstance(self.subbloq, GateWithRegisters): - return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs)) - return super().decompose_from_registers(context=context, **quregs) + # def decompose_from_registers( + # self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var] + # ) -> cirq.OP_TREE: + # if isinstance(self.subbloq, GateWithRegisters) or hasattr( + # self.subbloq, 'decompose_from_registers' + # ): + # yield cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs)) + # else: + # yield super().decompose_from_registers(context=context, **quregs) def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' diff --git a/qualtran/bloqs/mcmt/and_bloq.py b/qualtran/bloqs/mcmt/and_bloq.py index 95dd44b57..f12ca4f62 100644 --- a/qualtran/bloqs/mcmt/and_bloq.py +++ b/qualtran/bloqs/mcmt/and_bloq.py @@ -320,6 +320,7 @@ def _decompose_via_tree( def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> Iterator[cirq.OP_TREE]: + assert 'junk' in quregs control, ancilla, target = ( quregs['ctrl'].flatten(), quregs.get('junk', np.array([])).flatten(), diff --git a/qualtran/bloqs/state_preparation/prepare_uniform_superposition_test.py b/qualtran/bloqs/state_preparation/prepare_uniform_superposition_test.py index 912e65e24..36bfc14db 100644 --- a/qualtran/bloqs/state_preparation/prepare_uniform_superposition_test.py +++ b/qualtran/bloqs/state_preparation/prepare_uniform_superposition_test.py @@ -95,3 +95,16 @@ def test_prepare_uniform_superposition_consistent_protocols(): PrepareUniformSuperposition(5, cvs=()), PrepareUniformSuperposition(5, cvs=[]), ) + + +def test_prepare_uniform_superposition_adjoint(): + n = 3 + target = cirq.NamedQubit.range((n - 1).bit_length(), prefix='target') + control = [cirq.NamedQubit('control')] + op = PrepareUniformSuperposition(n, cvs=(0,)).on_registers(ctrl=control, target=target) + gqm = cirq.GreedyQubitManager(prefix="_ancilla", maximize_reuse=True) + context = cirq.DecompositionContext(gqm) + circuit = cirq.Circuit(op, cirq.decompose(cirq.inverse(op), context=context)) + identity = cirq.Circuit(cirq.identity_each(*circuit.all_qubits())).final_state_vector() + result = cirq.Simulator(dtype=np.complex128).simulate(circuit) + np.testing.assert_allclose(result.final_state_vector, identity, atol=1e-8) diff --git a/qualtran/bloqs/util_bloqs.py b/qualtran/bloqs/util_bloqs.py index 4137a0376..a803aa43a 100644 --- a/qualtran/bloqs/util_bloqs.py +++ b/qualtran/bloqs/util_bloqs.py @@ -365,6 +365,15 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym assert reg.name == 'reg' return directional_text_box('alloc', Side.RIGHT) + def as_cirq_op( + self, qubit_manager: 'cirq.QubitManager' + ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: + shape = (*self.signature[0].shape, self.signature[0].bitsize) + return ( + None, + {'reg': np.array(qubit_manager.qalloc(self.signature.n_qubits())).reshape(shape)}, + ) + @frozen class Free(Bloq): @@ -415,6 +424,12 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym assert reg.name == 'reg' return directional_text_box('free', Side.LEFT) + def as_cirq_op( + self, qubit_manager: 'cirq.QubitManager', reg: 'CirqQuregT' + ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: + qubit_manager.qfree(reg.flatten().tolist()) + return (None, {}) + @frozen class ArbitraryClifford(Bloq): diff --git a/qualtran/cirq_interop/_bloq_to_cirq_test.py b/qualtran/cirq_interop/_bloq_to_cirq_test.py index 9d9ab0f2e..7bb7b0042 100644 --- a/qualtran/cirq_interop/_bloq_to_cirq_test.py +++ b/qualtran/cirq_interop/_bloq_to_cirq_test.py @@ -187,7 +187,7 @@ def test_bloq_as_cirq_gate_left_register(): bb.free(q) cbloq = bb.finalize() circuit = cbloq.to_cirq_circuit() - cirq.testing.assert_has_diagram(circuit, """_c(0): ───alloc───X───free───""") + cirq.testing.assert_has_diagram(circuit, """_c(0): ───X───""") def test_bloq_as_cirq_gate_for_mod_exp(): diff --git a/qualtran/cirq_interop/_interop_qubit_manager.py b/qualtran/cirq_interop/_interop_qubit_manager.py index 0cb2cccd1..650290b44 100644 --- a/qualtran/cirq_interop/_interop_qubit_manager.py +++ b/qualtran/cirq_interop/_interop_qubit_manager.py @@ -28,10 +28,30 @@ def __init__(self, qm: Optional[cirq.QubitManager] = None): self._managed_qubits: Set[cirq.Qid] = set() def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']: - return self._qm.qalloc(n, dim) + ret = [] + qubits_to_free = [] + while len(ret) < n: + new_alloc = self._qm.qalloc(n - len(ret), dim) + for q in new_alloc: + if q in self._managed_qubits: + qubits_to_free.append(q) + else: + ret.append(q) + self._qm.qfree(qubits_to_free) + return ret def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']: - return self._qm.qborrow(n, dim) + ret = [] + qubits_to_free = [] + while len(ret) < n: + new_alloc = self._qm.qborrow(n - len(ret), dim) + for q in new_alloc: + if q in self._managed_qubits: + qubits_to_free.append(q) + else: + ret.append(q) + self._qm.qfree(qubits_to_free) + return ret def manage_qubits(self, qubits: Iterable[cirq.Qid]): self._managed_qubits |= set(qubits)