From c98612fa275ea9d4506f8969800bb9c163034e17 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 9 Sep 2024 11:55:09 -0700 Subject: [PATCH] Bugfix in `QROAMCleanAdjointWrapper` to correctly pass `num_controls` to `QROAMCleanAdjoint` (#1394) * Bugfix in QROAMCleanAdjointWrapper * More improvements to QROM bloqs * Revert SelectSwapQROM.with_log_block_sizes and handle use_dirty_ancilla properly --- qualtran/bloqs/data_loading/qroam_clean.py | 5 +++ .../bloqs/data_loading/qroam_clean_test.py | 20 ++++-------- qualtran/bloqs/data_loading/qrom.py | 9 +++++- .../bloqs/data_loading/select_swap_qrom.py | 32 +++++++++++++------ .../data_loading/select_swap_qrom_test.py | 11 +++++++ 5 files changed, 54 insertions(+), 23 deletions(-) diff --git a/qualtran/bloqs/data_loading/qroam_clean.py b/qualtran/bloqs/data_loading/qroam_clean.py index 30e2ec25d..e7dc9c78e 100644 --- a/qualtran/bloqs/data_loading/qroam_clean.py +++ b/qualtran/bloqs/data_loading/qroam_clean.py @@ -71,6 +71,8 @@ def get_optimal_log_block_size_clean_ancilla( k = log2(qroam_block_size) if is_symbolic(k): return k + if k < 0: + return 0 k_int = np.array([np.floor(k), np.ceil(k)]) return int(k_int[np.argmin(qroam_cost(2**k_int, data_size, bitsize, adjoint))]) @@ -233,9 +235,11 @@ def qroam_clean_adjoint_bloq(self) -> 'QROAMCleanAdjoint': if self.qroam_clean.has_data(): return QROAMCleanAdjoint.build_from_data( *self.qroam_clean.batched_data_permuted, + target_bitsizes=self.qroam_clean.target_bitsizes, target_shapes=(self.qroam_clean.block_sizes,) * len(self.qroam_clean.batched_data_permuted), log_block_sizes=self.log_block_sizes, + num_controls=self.qroam_clean.num_controls, ) else: return QROAMCleanAdjoint.build_from_bitsize( @@ -244,6 +248,7 @@ def qroam_clean_adjoint_bloq(self) -> 'QROAMCleanAdjoint': target_shapes=(self.qroam_clean.block_sizes,) * len(self.qroam_clean.target_bitsizes), log_block_sizes=self.log_block_sizes, + num_controls=self.qroam_clean.num_controls, ) def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']: diff --git a/qualtran/bloqs/data_loading/qroam_clean_test.py b/qualtran/bloqs/data_loading/qroam_clean_test.py index 8e8782677..96f30af53 100644 --- a/qualtran/bloqs/data_loading/qroam_clean_test.py +++ b/qualtran/bloqs/data_loading/qroam_clean_test.py @@ -80,19 +80,18 @@ def test_qroam_clean_classical_sim(): # 1D data, 1 dataset N, max_N, log_block_sizes = 25, 2**10, 3 data = rng.integers(max_N, size=N) - bloq = QROAMClean.build_from_data(data, log_block_sizes=log_block_sizes) + bloq = QROAMClean.build_from_data(data, log_block_sizes=log_block_sizes, num_controls=1) cbloq = bloq.decompose_bloq() bloq_inv = bloq.adjoint() assert isinstance(bloq_inv, QROAMCleanAdjointWrapper) for x in range(N): - vals = bloq.call_classically(selection=x) - cvals = cbloq.call_classically(selection=x) - assert vals[0:2] == cvals[0:2] == (x, data[x]) - assert np.array_equal(vals[2], cvals[2]) - target_with_junk = np.array([vals[1], *vals[2]]) # type: ignore[misc] + vals = bloq.call_classically(selection=x, control=1) + cvals = cbloq.call_classically(selection=x, control=1) + assert vals[0:3] == cvals[0:3] == (1, x, data[x]) + assert np.array_equal(vals[3], cvals[3]) assert bloq_inv.call_classically( - selection=vals[0], target0_=vals[1], junk_target0_=vals[2] - ) == (x,) + control=vals[0], selection=vals[1], target0_=vals[2], junk_target0_=vals[3] + ) == (1, x) # 2D data, 1 datasets N, M, max_N, log_block_sizes = 7, 11, 2**5, (2, 3) @@ -107,7 +106,6 @@ def test_qroam_clean_classical_sim(): cvals = cbloq.call_classically(selection0=x, selection1=y) assert vals[0:3] == cvals[0:3] == (x, y, data[x][y]) assert np.array_equal(vals[3], cvals[3]) - # target_with_junk = np.array([vals[2], *vals[3]]).reshape(2 ** np.array(log_block_sizes)) # type: ignore[misc] assert bloq_inv.call_classically( selection0=x, selection1=y, target0_=vals[2], junk_target0_=vals[3] ) == (x, y) @@ -128,8 +126,6 @@ def test_qroam_clean_classical_sim_multi_dataset(): cvals = cbloq.call_classically(selection=x) assert vals[0:3] == cvals[0:3] == (x, data[0][x], data[1][x]) assert np.array_equal(vals[3], cvals[3]) and np.array_equal(vals[4], cvals[4]) - targets_with_junk0 = np.array([vals[1], *vals[3]]) # type: ignore[misc] - targets_with_junk1 = np.array([vals[2], *vals[4]]) # type: ignore[misc] assert bloq_inv.call_classically( selection=vals[0], target0_=vals[1], @@ -154,8 +150,6 @@ def test_qroam_clean_classical_sim_multi_dataset(): cvals = cbloq.call_classically(selection0=x, selection1=y) assert vals[0:4] == cvals[0:4] == (x, y, data[0][x][y], data[1][x][y]) assert np.array_equal(vals[4], cvals[4]) and np.array_equal(vals[5], cvals[5]) - targets_with_junk0 = np.array([vals[2], *vals[4]]).reshape(2**log_block_sizes) # type: ignore[misc] - targets_with_junk1 = np.array([vals[3], *vals[5]]).reshape(2**log_block_sizes) # type: ignore[misc] assert bloq_inv.call_classically( selection0=x, selection1=y, diff --git a/qualtran/bloqs/data_loading/qrom.py b/qualtran/bloqs/data_loading/qrom.py index 61ac5a7b3..59ace8544 100644 --- a/qualtran/bloqs/data_loading/qrom.py +++ b/qualtran/bloqs/data_loading/qrom.py @@ -22,7 +22,7 @@ import sympy from numpy.typing import ArrayLike, NDArray -from qualtran import bloq_example, BloqDocSpec, QUInt, Register +from qualtran import bloq_example, BloqDocSpec, DecomposeTypeError, QUInt, Register from qualtran._infra.gate_with_registers import merge_qubits from qualtran.bloqs.arithmetic import XorK from qualtran.bloqs.basic_gates import CNOT @@ -153,6 +153,13 @@ def decompose_zero_selection( yield cirq.inverse(multi_controlled_and) context.qubit_manager.qfree(list(junk.flatten()) + [and_target]) + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + if self.has_data(): + return super().decompose_from_registers(context=context, **quregs) + raise DecomposeTypeError(f"Cannot decompose symbolic {self} with no data.") + def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int): if not self.has_data(): return False diff --git a/qualtran/bloqs/data_loading/select_swap_qrom.py b/qualtran/bloqs/data_loading/select_swap_qrom.py index 52952719e..74b86c269 100644 --- a/qualtran/bloqs/data_loading/select_swap_qrom.py +++ b/qualtran/bloqs/data_loading/select_swap_qrom.py @@ -39,15 +39,19 @@ def find_optimal_log_block_size( - iteration_length: SymbolicInt, target_bitsize: SymbolicInt + iteration_length: SymbolicInt, target_bitsize: SymbolicInt, use_dirty_ancilla: bool = False ) -> SymbolicInt: """Find optimal block size, which is a power of 2, for SelectSwapQROM. This functions returns the optimal `k` s.t. * k is in an integer and k >= 0. - * iteration_length/2^k + target_bitsize*(2^k - 1) is minimized. + * iteration_length/2^k + target_bitsize*(2^k - 1) is minimized if use_dirty_ancilla is False + * iteration_length/2^k + 2*target_bitsize*(2^k - 1) is minimized if use_dirty_ancilla is True + The corresponding block size for SelectSwapQROM would be 2^k. """ + if not use_dirty_ancilla: + target_bitsize = 2 * target_bitsize k: SymbolicFloat = 0.5 * log2(iteration_length / target_bitsize) if is_symbolic(k): return ceil(k) @@ -62,6 +66,14 @@ def value(kk: List[int]): return int(k_int[np.argmin(value(k_int))]) # obtain optimal k +def _find_optimal_log_block_size_helper(qrom: 'SelectSwapQROM') -> Tuple[SymbolicInt, ...]: + target_bitsize = sum(qrom.target_bitsizes) * sum(prod(shape) for shape in qrom.target_shapes) + return tuple( + find_optimal_log_block_size(ilen, target_bitsize, qrom.use_dirty_ancilla) + for ilen in qrom.data_shape + ) + + def _alloc_anc_for_reg( bb: 'BloqBuilder', dtype: 'QDType', shape: Tuple[int, ...], dirty: bool ) -> 'SoquetT': @@ -120,6 +132,9 @@ class SelectSwapQROM(QROMBase, GateWithRegisters): # type: ignore[misc] log_block_sizes: Tuple[SymbolicInt, ...] = attrs.field( converter=lambda x: tuple(x.tolist() if isinstance(x, np.ndarray) else x) + if x is not None + else x, + default=None, ) use_dirty_ancilla: bool = True @@ -129,6 +144,8 @@ def __attrs_post_init__(self): raise ValueError( f"{type(self)} currently only supports target registers of shape (). Found {self.target_shapes}" ) + if self.log_block_sizes is None: + object.__setattr__(self, "log_block_sizes", _find_optimal_log_block_size_helper(self)) @cached_property def signature(self) -> Signature: @@ -137,13 +154,6 @@ def signature(self) -> Signature: ) # Builder methods and helpers. - @log_block_sizes.default - def _default_log_block_sizes(self) -> Tuple[SymbolicInt, ...]: - target_bitsize = sum(self.target_bitsizes) * sum( - prod(shape) for shape in self.target_shapes - ) - return tuple(find_optimal_log_block_size(ilen, target_bitsize) for ilen in self.data_shape) - def is_symbolic(self) -> bool: return super().is_symbolic() or is_symbolic(*self.log_block_sizes) @@ -160,6 +170,8 @@ def build_from_data( *data, target_bitsizes=target_bitsizes, num_controls=num_controls ) qroam = attrs.evolve(qroam, use_dirty_ancilla=use_dirty_ancilla) + if log_block_sizes is None: + log_block_sizes = _find_optimal_log_block_size_helper(qroam) return qroam.with_log_block_sizes(log_block_sizes=log_block_sizes) @classmethod @@ -180,6 +192,8 @@ def build_from_bitsize( num_controls=num_controls, ) qroam = attrs.evolve(qroam, use_dirty_ancilla=use_dirty_ancilla) + if log_block_sizes is None: + log_block_sizes = _find_optimal_log_block_size_helper(qroam) return qroam.with_log_block_sizes(log_block_sizes=log_block_sizes) def with_log_block_sizes( diff --git a/qualtran/bloqs/data_loading/select_swap_qrom_test.py b/qualtran/bloqs/data_loading/select_swap_qrom_test.py index 49a0bab72..975120bbf 100644 --- a/qualtran/bloqs/data_loading/select_swap_qrom_test.py +++ b/qualtran/bloqs/data_loading/select_swap_qrom_test.py @@ -232,3 +232,14 @@ def test_tensor_contraction(use_dirty_ancilla: bool): ) qrom = QROM.build_from_data(data) np.testing.assert_allclose(qrom.tensor_contract(), qroam.tensor_contract(), atol=1e-8) + + +def test_select_swap_block_sizes(): + data = [*range(1600)] + qroam_opt = SelectSwapQROM.build_from_data(data) + qroam_subopt = SelectSwapQROM.build_from_data(data, log_block_sizes=(8,)) + assert qroam_opt.block_sizes == (16,) + assert qroam_opt.t_complexity().t < qroam_subopt.t_complexity().t + + qroam = SelectSwapQROM.build_from_data(data, use_dirty_ancilla=False) + assert qroam.block_sizes == (8,)