Skip to content

Commit

Permalink
Bugfix in QROAMCleanAdjointWrapper to correctly pass num_controls
Browse files Browse the repository at this point in the history
… to `QROAMCleanAdjoint` (#1394)

* Bugfix in QROAMCleanAdjointWrapper

* More improvements to QROM bloqs

* Revert SelectSwapQROM.with_log_block_sizes and handle use_dirty_ancilla properly
  • Loading branch information
tanujkhattar authored Sep 9, 2024
1 parent b5722c0 commit c98612f
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 23 deletions.
5 changes: 5 additions & 0 deletions qualtran/bloqs/data_loading/qroam_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def get_optimal_log_block_size_clean_ancilla(
k = log2(qroam_block_size)
if is_symbolic(k):
return k
if k < 0:
return 0
k_int = np.array([np.floor(k), np.ceil(k)])
return int(k_int[np.argmin(qroam_cost(2**k_int, data_size, bitsize, adjoint))])

Expand Down Expand Up @@ -233,9 +235,11 @@ def qroam_clean_adjoint_bloq(self) -> 'QROAMCleanAdjoint':
if self.qroam_clean.has_data():
return QROAMCleanAdjoint.build_from_data(
*self.qroam_clean.batched_data_permuted,
target_bitsizes=self.qroam_clean.target_bitsizes,
target_shapes=(self.qroam_clean.block_sizes,)
* len(self.qroam_clean.batched_data_permuted),
log_block_sizes=self.log_block_sizes,
num_controls=self.qroam_clean.num_controls,
)
else:
return QROAMCleanAdjoint.build_from_bitsize(
Expand All @@ -244,6 +248,7 @@ def qroam_clean_adjoint_bloq(self) -> 'QROAMCleanAdjoint':
target_shapes=(self.qroam_clean.block_sizes,)
* len(self.qroam_clean.target_bitsizes),
log_block_sizes=self.log_block_sizes,
num_controls=self.qroam_clean.num_controls,
)

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
Expand Down
20 changes: 7 additions & 13 deletions qualtran/bloqs/data_loading/qroam_clean_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,18 @@ def test_qroam_clean_classical_sim():
# 1D data, 1 dataset
N, max_N, log_block_sizes = 25, 2**10, 3
data = rng.integers(max_N, size=N)
bloq = QROAMClean.build_from_data(data, log_block_sizes=log_block_sizes)
bloq = QROAMClean.build_from_data(data, log_block_sizes=log_block_sizes, num_controls=1)
cbloq = bloq.decompose_bloq()
bloq_inv = bloq.adjoint()
assert isinstance(bloq_inv, QROAMCleanAdjointWrapper)
for x in range(N):
vals = bloq.call_classically(selection=x)
cvals = cbloq.call_classically(selection=x)
assert vals[0:2] == cvals[0:2] == (x, data[x])
assert np.array_equal(vals[2], cvals[2])
target_with_junk = np.array([vals[1], *vals[2]]) # type: ignore[misc]
vals = bloq.call_classically(selection=x, control=1)
cvals = cbloq.call_classically(selection=x, control=1)
assert vals[0:3] == cvals[0:3] == (1, x, data[x])
assert np.array_equal(vals[3], cvals[3])
assert bloq_inv.call_classically(
selection=vals[0], target0_=vals[1], junk_target0_=vals[2]
) == (x,)
control=vals[0], selection=vals[1], target0_=vals[2], junk_target0_=vals[3]
) == (1, x)

# 2D data, 1 datasets
N, M, max_N, log_block_sizes = 7, 11, 2**5, (2, 3)
Expand All @@ -107,7 +106,6 @@ def test_qroam_clean_classical_sim():
cvals = cbloq.call_classically(selection0=x, selection1=y)
assert vals[0:3] == cvals[0:3] == (x, y, data[x][y])
assert np.array_equal(vals[3], cvals[3])
# target_with_junk = np.array([vals[2], *vals[3]]).reshape(2 ** np.array(log_block_sizes)) # type: ignore[misc]
assert bloq_inv.call_classically(
selection0=x, selection1=y, target0_=vals[2], junk_target0_=vals[3]
) == (x, y)
Expand All @@ -128,8 +126,6 @@ def test_qroam_clean_classical_sim_multi_dataset():
cvals = cbloq.call_classically(selection=x)
assert vals[0:3] == cvals[0:3] == (x, data[0][x], data[1][x])
assert np.array_equal(vals[3], cvals[3]) and np.array_equal(vals[4], cvals[4])
targets_with_junk0 = np.array([vals[1], *vals[3]]) # type: ignore[misc]
targets_with_junk1 = np.array([vals[2], *vals[4]]) # type: ignore[misc]
assert bloq_inv.call_classically(
selection=vals[0],
target0_=vals[1],
Expand All @@ -154,8 +150,6 @@ def test_qroam_clean_classical_sim_multi_dataset():
cvals = cbloq.call_classically(selection0=x, selection1=y)
assert vals[0:4] == cvals[0:4] == (x, y, data[0][x][y], data[1][x][y])
assert np.array_equal(vals[4], cvals[4]) and np.array_equal(vals[5], cvals[5])
targets_with_junk0 = np.array([vals[2], *vals[4]]).reshape(2**log_block_sizes) # type: ignore[misc]
targets_with_junk1 = np.array([vals[3], *vals[5]]).reshape(2**log_block_sizes) # type: ignore[misc]
assert bloq_inv.call_classically(
selection0=x,
selection1=y,
Expand Down
9 changes: 8 additions & 1 deletion qualtran/bloqs/data_loading/qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sympy
from numpy.typing import ArrayLike, NDArray

from qualtran import bloq_example, BloqDocSpec, QUInt, Register
from qualtran import bloq_example, BloqDocSpec, DecomposeTypeError, QUInt, Register
from qualtran._infra.gate_with_registers import merge_qubits
from qualtran.bloqs.arithmetic import XorK
from qualtran.bloqs.basic_gates import CNOT
Expand Down Expand Up @@ -153,6 +153,13 @@ def decompose_zero_selection(
yield cirq.inverse(multi_controlled_and)
context.qubit_manager.qfree(list(junk.flatten()) + [and_target])

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
if self.has_data():
return super().decompose_from_registers(context=context, **quregs)
raise DecomposeTypeError(f"Cannot decompose symbolic {self} with no data.")

def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int):
if not self.has_data():
return False
Expand Down
32 changes: 23 additions & 9 deletions qualtran/bloqs/data_loading/select_swap_qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@


def find_optimal_log_block_size(
iteration_length: SymbolicInt, target_bitsize: SymbolicInt
iteration_length: SymbolicInt, target_bitsize: SymbolicInt, use_dirty_ancilla: bool = False
) -> SymbolicInt:
"""Find optimal block size, which is a power of 2, for SelectSwapQROM.
This functions returns the optimal `k` s.t.
* k is in an integer and k >= 0.
* iteration_length/2^k + target_bitsize*(2^k - 1) is minimized.
* iteration_length/2^k + target_bitsize*(2^k - 1) is minimized if use_dirty_ancilla is False
* iteration_length/2^k + 2*target_bitsize*(2^k - 1) is minimized if use_dirty_ancilla is True
The corresponding block size for SelectSwapQROM would be 2^k.
"""
if not use_dirty_ancilla:
target_bitsize = 2 * target_bitsize
k: SymbolicFloat = 0.5 * log2(iteration_length / target_bitsize)
if is_symbolic(k):
return ceil(k)
Expand All @@ -62,6 +66,14 @@ def value(kk: List[int]):
return int(k_int[np.argmin(value(k_int))]) # obtain optimal k


def _find_optimal_log_block_size_helper(qrom: 'SelectSwapQROM') -> Tuple[SymbolicInt, ...]:
target_bitsize = sum(qrom.target_bitsizes) * sum(prod(shape) for shape in qrom.target_shapes)
return tuple(
find_optimal_log_block_size(ilen, target_bitsize, qrom.use_dirty_ancilla)
for ilen in qrom.data_shape
)


def _alloc_anc_for_reg(
bb: 'BloqBuilder', dtype: 'QDType', shape: Tuple[int, ...], dirty: bool
) -> 'SoquetT':
Expand Down Expand Up @@ -120,6 +132,9 @@ class SelectSwapQROM(QROMBase, GateWithRegisters): # type: ignore[misc]

log_block_sizes: Tuple[SymbolicInt, ...] = attrs.field(
converter=lambda x: tuple(x.tolist() if isinstance(x, np.ndarray) else x)
if x is not None
else x,
default=None,
)
use_dirty_ancilla: bool = True

Expand All @@ -129,6 +144,8 @@ def __attrs_post_init__(self):
raise ValueError(
f"{type(self)} currently only supports target registers of shape (). Found {self.target_shapes}"
)
if self.log_block_sizes is None:
object.__setattr__(self, "log_block_sizes", _find_optimal_log_block_size_helper(self))

@cached_property
def signature(self) -> Signature:
Expand All @@ -137,13 +154,6 @@ def signature(self) -> Signature:
)

# Builder methods and helpers.
@log_block_sizes.default
def _default_log_block_sizes(self) -> Tuple[SymbolicInt, ...]:
target_bitsize = sum(self.target_bitsizes) * sum(
prod(shape) for shape in self.target_shapes
)
return tuple(find_optimal_log_block_size(ilen, target_bitsize) for ilen in self.data_shape)

def is_symbolic(self) -> bool:
return super().is_symbolic() or is_symbolic(*self.log_block_sizes)

Expand All @@ -160,6 +170,8 @@ def build_from_data(
*data, target_bitsizes=target_bitsizes, num_controls=num_controls
)
qroam = attrs.evolve(qroam, use_dirty_ancilla=use_dirty_ancilla)
if log_block_sizes is None:
log_block_sizes = _find_optimal_log_block_size_helper(qroam)
return qroam.with_log_block_sizes(log_block_sizes=log_block_sizes)

@classmethod
Expand All @@ -180,6 +192,8 @@ def build_from_bitsize(
num_controls=num_controls,
)
qroam = attrs.evolve(qroam, use_dirty_ancilla=use_dirty_ancilla)
if log_block_sizes is None:
log_block_sizes = _find_optimal_log_block_size_helper(qroam)
return qroam.with_log_block_sizes(log_block_sizes=log_block_sizes)

def with_log_block_sizes(
Expand Down
11 changes: 11 additions & 0 deletions qualtran/bloqs/data_loading/select_swap_qrom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,14 @@ def test_tensor_contraction(use_dirty_ancilla: bool):
)
qrom = QROM.build_from_data(data)
np.testing.assert_allclose(qrom.tensor_contract(), qroam.tensor_contract(), atol=1e-8)


def test_select_swap_block_sizes():
data = [*range(1600)]
qroam_opt = SelectSwapQROM.build_from_data(data)
qroam_subopt = SelectSwapQROM.build_from_data(data, log_block_sizes=(8,))
assert qroam_opt.block_sizes == (16,)
assert qroam_opt.t_complexity().t < qroam_subopt.t_complexity().t

qroam = SelectSwapQROM.build_from_data(data, use_dirty_ancilla=False)
assert qroam.block_sizes == (8,)

0 comments on commit c98612f

Please sign in to comment.