Skip to content

Commit

Permalink
Add tests for bloq counts.
Browse files Browse the repository at this point in the history
  • Loading branch information
fdmalone committed Oct 8, 2023
1 parent 0a0ef2a commit a100f2a
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion qualtran/bloqs/swap_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@
# limitations under the License.

import random
from typing import Set, Tuple

import cirq
import cirq_ft
import cirq_ft.infra.testing as cq_testing
import numpy as np
import pytest

from qualtran import BloqBuilder
from qualtran import Bloq, BloqBuilder
from qualtran.bloqs.basic_gates import TGate
from qualtran.bloqs.basic_gates.z_basis import IntState
from qualtran.bloqs.swap_network import CSwapApprox, SwapWithZero
from qualtran.bloqs.util_bloqs import ArbitraryClifford
from qualtran.simulation.quimb_sim import flatten_for_tensor_contraction
from qualtran.testing import assert_valid_bloq_decomposition, execute_notebook

Expand Down Expand Up @@ -101,6 +104,13 @@ def test_swap_with_zero_classically():
print(sel, out_data)


def get_t_count_and_clifford(bc: Set[Tuple[int, Bloq]]) -> Tuple[int, int]:
"""Get the t count and clifford cost from bloq count."""
cliff_cost = sum([x[0] for x in bc if isinstance(x[1], ArbitraryClifford)])
t_cost = sum([x[0] for x in bc if isinstance(x[1], TGate)])
return t_cost, cliff_cost


@pytest.mark.parametrize("n", [*range(1, 6)])
def test_t_complexity(n):
g = cirq_ft.MultiTargetCSwap(n)
Expand All @@ -109,6 +119,31 @@ def test_t_complexity(n):
g = cirq_ft.MultiTargetCSwapApprox(n)
cq_testing.assert_decompose_is_consistent_with_t_complexity(g)

@pytest.mark.parametrize("n", [*range(2, 6)])
def test_cswap_approx_bloq_counts(n):
csa = CSwapApprox(n)
bc = csa.bloq_counts()
t_cost, cliff_cost = get_t_count_and_clifford(bc)
assert csa.t_complexity().clifford == cliff_cost
assert csa.t_complexity().t == t_cost

@pytest.mark.parametrize(
"selection_bitsize, target_bitsize, n_target_registers, want",
[
[3, 5, 1, cirq_ft.TComplexity(t=0, clifford=0)],
[2, 2, 3, cirq_ft.TComplexity(t=16, clifford=86)],
[2, 3, 4, cirq_ft.TComplexity(t=36, clifford=195)],
[3, 2, 5, cirq_ft.TComplexity(t=32, clifford=172)],
[4, 1, 10, cirq_ft.TComplexity(t=36, clifford=189)],
],
)
def test_swap_with_zero_bloq_counts(selection_bitsize, target_bitsize, n_target_registers, want):
gate = SwapWithZero(selection_bitsize, target_bitsize, n_target_registers)
bc = list(gate.bloq_counts())[0]
t_cost, cliff_cost = get_t_count_and_clifford(bc[1].bloq_counts())
assert bc[0]*t_cost == want.t
assert bc[0]*cliff_cost == want.clifford


@pytest.mark.parametrize(
"selection_bitsize, target_bitsize, n_target_registers, want",
Expand Down

0 comments on commit a100f2a

Please sign in to comment.