Skip to content

Commit

Permalink
Bugfix in QROM.build_call_graph for symbolic case (#1321)
Browse files Browse the repository at this point in the history
Bugfix in QROM.build_call_graph for symbolic case
  • Loading branch information
tanujkhattar authored Aug 21, 2024
1 parent 0559dfd commit 96cf463
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
4 changes: 3 additions & 1 deletion qualtran/bloqs/data_loading/qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
if self.has_data():
return super().build_call_graph(ssa=ssa)
n_and = prod(self.data_shape) - 2 + self.num_controls
n_cnot = prod(self.target_bitsizes) * prod(self.data_shape)
n_cnot = prod(
bitsize * prod(sh) for bitsize, sh in zip(self.target_bitsizes, self.target_shapes)
) * prod(self.data_shape)
return {(And(), n_and), (And().adjoint(), n_and), (CNOT(), n_cnot)}

def adjoint(self) -> 'QROM':
Expand Down
16 changes: 13 additions & 3 deletions qualtran/bloqs/data_loading/qrom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,21 @@ def test_t_complexity(data):
def test_t_complexity_symbolic():
N, M = sympy.symbols('N M')
b1, b2 = sympy.symbols('b1 b2')
t1, t2 = sympy.symbols('t1 t2')
c = sympy.Symbol('c')
qrom_symb = QROM.build_from_bitsize((N, M), (b1, b2), num_controls=c)
qrom_symb = QROM.build_from_bitsize(
(N, M), (b1, b2), target_shapes=((t1,), (t2,)), num_controls=c
)
t_counts = qrom_symb.t_complexity()
assert t_counts.t == 4 * (N * M - 2 + c)
assert t_counts
n_and = N * M - 2 + c
assert t_counts.t == 4 * n_and
from qualtran.bloqs.mcmt import And

assert (
t_counts.clifford
== N * M * b1 * b2 * t1 * t2
+ (And().t_complexity().clifford + And().adjoint().t_complexity().clifford) * n_and
)


def _assert_qrom_has_diagram(qrom: QROM, expected_diagram: str):
Expand Down
6 changes: 6 additions & 0 deletions qualtran/bloqs/swap_network/swap_with_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ def on_classical_vals(
def adjoint(self) -> 'SwapWithZero':
return attrs.evolve(self, uncompute=not self.uncompute)

def pretty_name(self) -> str:
return 'SwapWithZero†' if self.uncompute else 'SwapWithZero'

def __str__(self) -> str:
return 'SwapWithZero†' if self.uncompute else 'SwapWithZero'


@bloq_example(generalizer=ignore_split_join)
def _swz() -> SwapWithZero:
Expand Down

0 comments on commit 96cf463

Please sign in to comment.