Skip to content

Commit

Permalink
Add my_static_costs override for accurate QubitCounts for qrom bl…
Browse files Browse the repository at this point in the history
…oqs (#1414)

* Add my_static_costs for accurate QubitCounts for qrom bloqs

* Fix mypy, pylint and autogenerate notebooks

* Fix failing tests
  • Loading branch information
tanujkhattar authored Sep 24, 2024
1 parent cbfd457 commit 21ef526
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 23 deletions.
7 changes: 6 additions & 1 deletion qualtran/bloqs/chemistry/sparse/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,19 @@ def test_sparse_costs_against_openfermion(num_spin_orb, num_bits_rot_aa):
# Qualtran (constants are not ignored). The difference arises from
# uncontrolled unary iteration used by QROM, which QROAMClean delegates to.
delta_qrom = -2
# The -4 comes from QROAMCleanAdjoint, which delegates to a QROM and SwapWithZero
# and each of them contributes a -2 factor.
delta_qrom_adjoint = -4
# inequality test difference
# https://github.com/quantumlib/Qualtran/issues/235
lte = LessThanEqual(prep_sparse.num_bits_state_prep, prep_sparse.num_bits_state_prep)
lte_cost = get_toffoli_count(lte) + get_toffoli_count(lte.adjoint())
lte_cost_paper = prep_sparse.num_bits_state_prep # inverted at zero cost
delta_ineq = lte_cost - lte_cost_paper
swap_cost = 8 * (num_spin_orb // 2 - 1).bit_length() + 1 # inverted at zero cost
adjusted_cost_qualtran = cost - delta_qrom - delta_uni_prep - delta_ineq - swap_cost
adjusted_cost_qualtran = (
cost - delta_qrom - delta_uni_prep - delta_ineq - swap_cost - delta_qrom_adjoint
)
cost_of = cost_sparse(
num_spin_orb, unused_lambda, num_non_zero, unused_de, num_bits_state_prep, unused_stps
)[0]
Expand Down
18 changes: 15 additions & 3 deletions qualtran/bloqs/data_loading/qroam_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
if TYPE_CHECKING:
from qualtran import Bloq, BloqBuilder, SoquetT, QDType
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator, CostKey

from qualtran.bloqs.data_loading.select_swap_qrom import _alloc_anc_for_reg, SelectSwapQROM

Expand Down Expand Up @@ -179,7 +179,7 @@ def with_log_block_sizes(
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
block_sizes = prod([2**k for k in self.log_block_sizes])
data_size = prod(self.data_shape)
n_toffoli = ceil(data_size / block_sizes) + block_sizes
n_toffoli = ceil(data_size / block_sizes) + block_sizes - 4 + self.num_controls
return {Toffoli(): n_toffoli}

@cached_property
Expand Down Expand Up @@ -268,7 +268,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
block_sizes = prod([2**k for k in self.log_block_sizes])
data_size = prod(self.qroam_clean.data_shape)
n_toffoli = ceil(data_size / block_sizes) + block_sizes
n_toffoli = ceil(data_size / block_sizes) + block_sizes - 4 + self.qroam_clean.num_controls
return {Toffoli(): n_toffoli}

def adjoint(self) -> 'QROAMClean':
Expand Down Expand Up @@ -299,6 +299,9 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
return Circle()
raise ValueError(f'Unknown register name {name}')

def __str__(self):
return 'QROAMCleanAdjoint'


@attrs.frozen
class QROAMClean(SelectSwapQROM):
Expand Down Expand Up @@ -450,6 +453,15 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
ret[swz] += 1
return ret

def my_static_costs(self, cost_key: "CostKey"):
from qualtran.resource_counting import get_cost_value, QubitCount

if isinstance(cost_key, QubitCount):
qrom_costs = get_cost_value(self.qrom_bloq, QubitCount())
return qrom_costs + sum(self.log_block_sizes)

return NotImplemented

def _build_composite_bloq_with_swz_clean(
self,
bb: 'BloqBuilder',
Expand Down
21 changes: 18 additions & 3 deletions qualtran/bloqs/data_loading/qroam_clean_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,29 @@
QROAMClean,
QROAMCleanAdjointWrapper,
)
from qualtran.symbolics import ceil
from qualtran.resource_counting import get_cost_value, QubitCount
from qualtran.symbolics import ceil, log2


def test_bloq_examples(bloq_autotester):
bloq_autotester(_qroam_clean_multi_data)
bloq_autotester(_qroam_clean_multi_dim)


def test_qroam_clean_qubit_counts():
bloq = _qroam_clean_multi_data.make()
assert get_cost_value(bloq, QubitCount()) == get_cost_value(bloq.decompose_bloq(), QubitCount())
bloq = _qroam_clean_multi_dim.make()
assert get_cost_value(bloq, QubitCount()) == get_cost_value(bloq.decompose_bloq(), QubitCount())
# Symbolic
N, b, k = sympy.symbols('N b k', positive=True, integer=True)
bloq = QROAMClean.build_from_bitsize((N,), (b,), log_block_sizes=(k,))
K = 2**k
# log(N) - k ancilla are required for the nested unary iteration.
expected_qubits = K * b + 2 * ceil(log2(N)) - k - 1
assert sympy.simplify(get_cost_value(bloq, QubitCount()) - expected_qubits) == 0


def test_t_complexity_1d_data_symbolic():
# 1D data, 1 dataset
N, b, k = sympy.symbols('N b k')
Expand All @@ -42,7 +57,7 @@ def test_t_complexity_1d_data_symbolic():
inv_k = sympy.symbols('kinv')
inv_K = 2**inv_k
bloq_inv = bloq_inv.with_log_block_sizes(log_block_sizes=(inv_k,))
expected_toffoli_inv = ceil(N / inv_K) + inv_K
expected_toffoli_inv = ceil(N / inv_K) + inv_K - 4
assert bloq_inv.t_complexity().t == 4 * expected_toffoli_inv


Expand All @@ -58,7 +73,7 @@ def test_t_complexity_2d_data_symbolic():
inv_k1, inv_k2 = sympy.symbols('kinv1, kinv2')
inv_K1, inv_K2 = 2**inv_k1, 2**inv_k2
bloq_inv = bloq_inv.with_log_block_sizes(log_block_sizes=(inv_k1, inv_k2))
expected_toffoli_inv = ceil(N1 * N2 / (inv_K1 * inv_K2)) + inv_K1 * inv_K2
expected_toffoli_inv = ceil(N1 * N2 / (inv_K1 * inv_K2)) + inv_K1 * inv_K2 - 4
assert bloq_inv.t_complexity().t == 4 * expected_toffoli_inv


Expand Down
13 changes: 12 additions & 1 deletion qualtran/bloqs/data_loading/qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from qualtran.symbolics import prod, SymbolicInt

if TYPE_CHECKING:
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, BloqCountT, CostKey, SympySymbolAllocator


def _to_tuple(x: Iterable[NDArray]) -> Sequence[NDArray]:
Expand Down Expand Up @@ -184,6 +184,17 @@ def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo:

return _wire_symbol_to_cirq_diagram_info(self, args)

def my_static_costs(self, cost_key: "CostKey"):
from qualtran.resource_counting import QubitCount

if isinstance(cost_key, QubitCount):
return self.signature.n_qubits() + sum(self.selection_bitsizes) - 1 + self.num_controls

return NotImplemented

def __str__(self):
return f'QROM({self.data_shape}, {self.target_shapes}, {self.target_bitsizes})'

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
if reg is None:
return Text('QROM')
Expand Down
18 changes: 17 additions & 1 deletion qualtran/bloqs/data_loading/qrom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from qualtran.bloqs.data_loading.qrom import _qrom_multi_data, _qrom_multi_dim, _qrom_small, QROM
from qualtran.cirq_interop.t_complexity_protocol import t_complexity
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim, GateHelper
from qualtran.resource_counting import get_cost_value, QECGatesCost
from qualtran.resource_counting import get_cost_value, QECGatesCost, QubitCount
from qualtran.symbolics import ceil, log2


def test_qrom_small(bloq_autotester):
Expand All @@ -40,6 +41,21 @@ def test_qrom_multi_dim(bloq_autotester):
bloq_autotester(_qrom_multi_dim)


def test_qrom_qubit_counts():
bloq = _qrom_small.make()
assert get_cost_value(bloq, QubitCount()) == get_cost_value(bloq.decompose_bloq(), QubitCount())
bloq = _qrom_multi_data.make()
assert get_cost_value(bloq, QubitCount()) == get_cost_value(bloq.decompose_bloq(), QubitCount())
bloq = _qrom_multi_dim.make()
assert get_cost_value(bloq, QubitCount()) == get_cost_value(bloq.decompose_bloq(), QubitCount())
# Symbolic
N, b, c = sympy.symbols('N b c', positive=True, integer=True)
bloq = QROM.build_from_bitsize((N,), (b,), num_controls=c)
# log(N) ancilla are required for the ancilla used in unary iteration.
expected_qubits = 2 * ceil(log2(N)) + b + 2 * c - 1
assert sympy.simplify(get_cost_value(bloq, QubitCount()) - expected_qubits) == 0


@pytest.mark.slow
@pytest.mark.parametrize(
"data,num_controls",
Expand Down
8 changes: 4 additions & 4 deletions qualtran/bloqs/data_loading/select_swap_qrom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@
},
"outputs": [],
"source": [
"N, b, k, c = sympy.symbols('N b k c')\n",
"N, b, k, c = sympy.symbols('N b k c', positive=True, integers=True)\n",
"qroam_symb_dirty_1d = SelectSwapQROM.build_from_bitsize(\n",
" (N,), (b,), log_block_sizes=(k,), num_controls=c\n",
")"
Expand All @@ -256,7 +256,7 @@
},
"outputs": [],
"source": [
"N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c')\n",
"N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c', positive=True, integers=True)\n",
"log_block_sizes = (k1, k2)\n",
"qroam_symb_dirty_2d = SelectSwapQROM.build_from_bitsize(\n",
" (N, M), (b1, b2), log_block_sizes=log_block_sizes, num_controls=c\n",
Expand All @@ -272,7 +272,7 @@
},
"outputs": [],
"source": [
"N, b, k, c = sympy.symbols('N b k c')\n",
"N, b, k, c = sympy.symbols('N b k c', positive=True, integers=True)\n",
"qroam_symb_clean_1d = SelectSwapQROM.build_from_bitsize(\n",
" (N,), (b,), log_block_sizes=(k,), num_controls=c, use_dirty_ancilla=False\n",
")"
Expand All @@ -287,7 +287,7 @@
},
"outputs": [],
"source": [
"N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c')\n",
"N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c', positive=True, integers=True)\n",
"log_block_sizes = (k1, k2)\n",
"qroam_symb_clean_2d = SelectSwapQROM.build_from_bitsize(\n",
" (N, M), (b1, b2), log_block_sizes=log_block_sizes, num_controls=c, use_dirty_ancilla=False\n",
Expand Down
34 changes: 27 additions & 7 deletions qualtran/bloqs/data_loading/select_swap_qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@
import sympy
from numpy.typing import ArrayLike

from qualtran import bloq_example, BloqDocSpec, BQUInt, GateWithRegisters, Register, Signature
from qualtran import (
bloq_example,
BloqDocSpec,
BQUInt,
DecomposeTypeError,
GateWithRegisters,
Register,
Signature,
)
from qualtran.bloqs.arithmetic.bitwise import Xor
from qualtran.bloqs.bookkeeping import Partition
from qualtran.bloqs.data_loading.qrom import QROM
Expand All @@ -33,7 +41,7 @@

if TYPE_CHECKING:
from qualtran import Bloq, BloqBuilder, QDType, SoquetT
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, CostKey, SympySymbolAllocator

SelSwapQROM_T = TypeVar('SelSwapQROM_T', bound='SelectSwapQROM')

Expand Down Expand Up @@ -408,7 +416,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
target = [soqs.pop(reg.name) for reg in self.target_registers]
# Allocate intermediate clean/dirty ancilla for the underlying QROM call.
if is_symbolic(*self.block_sizes):
raise ValueError(
raise DecomposeTypeError(
f"Cannot decompose SelectSwapQROM bloq with symbolic block sizes. Found {self.block_sizes=}"
)
block_sizes = cast(Tuple[int, ...], self.block_sizes)
Expand Down Expand Up @@ -448,6 +456,18 @@ def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo:

return _wire_symbol_to_cirq_diagram_info(self, args)

def my_static_costs(self, cost_key: "CostKey"):
from qualtran.resource_counting import get_cost_value, QubitCount

if isinstance(cost_key, QubitCount):
return (
get_cost_value(self.qrom_bloq, QubitCount())
+ sum(self.log_block_sizes)
+ sum(self.target_bitsizes)
)

return NotImplemented

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
if reg is None:
return Text('QROAM')
Expand Down Expand Up @@ -482,7 +502,7 @@ def _qroam_multi_dim() -> SelectSwapQROM:

@bloq_example
def _qroam_symb_dirty_1d() -> SelectSwapQROM:
N, b, k, c = sympy.symbols('N b k c')
N, b, k, c = sympy.symbols('N b k c', positive=True, integers=True)
qroam_symb_dirty_1d = SelectSwapQROM.build_from_bitsize(
(N,), (b,), log_block_sizes=(k,), num_controls=c
)
Expand All @@ -491,7 +511,7 @@ def _qroam_symb_dirty_1d() -> SelectSwapQROM:

@bloq_example
def _qroam_symb_dirty_2d() -> SelectSwapQROM:
N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c')
N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c', positive=True, integers=True)
log_block_sizes = (k1, k2)
qroam_symb_dirty_2d = SelectSwapQROM.build_from_bitsize(
(N, M), (b1, b2), log_block_sizes=log_block_sizes, num_controls=c
Expand All @@ -501,7 +521,7 @@ def _qroam_symb_dirty_2d() -> SelectSwapQROM:

@bloq_example
def _qroam_symb_clean_1d() -> SelectSwapQROM:
N, b, k, c = sympy.symbols('N b k c')
N, b, k, c = sympy.symbols('N b k c', positive=True, integers=True)
qroam_symb_clean_1d = SelectSwapQROM.build_from_bitsize(
(N,), (b,), log_block_sizes=(k,), num_controls=c, use_dirty_ancilla=False
)
Expand All @@ -510,7 +530,7 @@ def _qroam_symb_clean_1d() -> SelectSwapQROM:

@bloq_example
def _qroam_symb_clean_2d() -> SelectSwapQROM:
N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c')
N, M, b1, b2, k1, k2, c = sympy.symbols('N M b1 b2 k1 k2 c', positive=True, integers=True)
log_block_sizes = (k1, k2)
qroam_symb_clean_2d = SelectSwapQROM.build_from_bitsize(
(N, M), (b1, b2), log_block_sizes=log_block_sizes, num_controls=c, use_dirty_ancilla=False
Expand Down
18 changes: 17 additions & 1 deletion qualtran/bloqs/data_loading/select_swap_qrom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import cirq
import numpy as np
import pytest
import sympy

from qualtran._infra.data_types import QUInt
from qualtran._infra.gate_with_registers import get_named_qubits, split_qubits
Expand All @@ -27,7 +28,8 @@
)
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost, QubitCount
from qualtran.symbolics import ceil, log2
from qualtran.testing import assert_valid_bloq_decomposition


Expand Down Expand Up @@ -192,6 +194,20 @@ def test_qroam_t_complexity():
assert qroam.t_complexity() == TComplexity(t=192, clifford=1082)


def test_selswap_qubit_counts():
bloq = _qroam_multi_data.make()
assert get_cost_value(bloq, QubitCount()) == get_cost_value(bloq.decompose_bloq(), QubitCount())
bloq = _qroam_multi_dim.make()
assert get_cost_value(bloq, QubitCount()) == get_cost_value(bloq.decompose_bloq(), QubitCount())
# Symbolic
N, b, k = sympy.symbols('N b k', positive=True, integer=True)
bloq = SelectSwapQROM.build_from_bitsize((N,), (b,), log_block_sizes=(k,))
K = 2**k
# log(N) - k ancilla are required for the nested unary iteration.
expected_qubits = K * b + b + 2 * ceil(log2(N)) - k - 1
assert sympy.simplify(get_cost_value(bloq, QubitCount()) - expected_qubits) == 0


def test_qroam_many_registers():
# Test > 10 registers which resulted in https://github.com/quantumlib/Qualtran/issues/556
target_bitsizes = (3,) * 10 + (1,) * 2 + (3,)
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/mod_arithmetic/mod_multiplication_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ def test_dirtyoutofplacemontgomerymodmul_symbolic_cost(uncompute):

# Litinski 2023 https://arxiv.org/abs/2306.08585
# Figure/Table 8. Lists modular multiplication as 2.25n^2+9n toffoli.
# The following formula is 2.25n^2+8.25n-1 written with rationals because sympy comparison fails with floats.
# The following formula is 2.25n^2+7.25n-1 written with rationals because sympy comparison fails with floats.
assert isinstance(cost['n_ccz'], sympy.Expr)
assert (
cost['n_ccz'].subs(m, n / 4).expand()
== sympy.Rational(9, 4) * n**2 + sympy.Rational(33, 4) * n - 1
== sympy.Rational(9, 4) * n**2 + sympy.Rational(29, 4) * n - 1
)


Expand Down

0 comments on commit 21ef526

Please sign in to comment.