Skip to content

Commit

Permalink
Swap network bloq counts (#382)
Browse files Browse the repository at this point in the history
* Add swap network T complexity and bloq counts.

* Add tests for bloq counts.

* Fix formatting.

* Determine number of swaps more neatly.
  • Loading branch information
fdmalone authored Oct 13, 2023
1 parent a997abe commit a27f271
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
37 changes: 36 additions & 1 deletion qualtran/bloqs/swap_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,25 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, Tuple, TYPE_CHECKING, Union
from typing import Dict, Optional, Set, Tuple, TYPE_CHECKING, Union

import cirq
import cirq_ft
import numpy as np
import sympy
from attrs import frozen
from cirq_ft import MultiTargetCSwapApprox
from numpy.typing import NDArray

from qualtran import Bloq, BloqBuilder, Register, Signature, Soquet, SoquetT
from qualtran.bloqs.basic_gates import TGate
from qualtran.bloqs.util_bloqs import ArbitraryClifford
from qualtran.cirq_interop import decompose_from_cirq_op

if TYPE_CHECKING:
from qualtran import CompositeBloq
from qualtran.cirq_interop import CirqQuregT
from qualtran.resource_counting import SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -84,6 +90,27 @@ def on_classical_vals(
def short_name(self) -> str:
return '~swap'

def t_complexity(self) -> cirq_ft.TComplexity:
"""TComplexity as explained in Appendix B.2.c of https://arxiv.org/abs/1812.00954"""
n = self.bitsize
# 4 * n: G gates, each wth 1 T and 4 single qubit cliffords
# 4 * n: CNOTs
# 2 * n - 1: CNOTs from 1 MultiTargetCNOT
return cirq_ft.TComplexity(t=4 * n, clifford=22 * n - 1)

def bloq_counts(
self, ssa: Optional['SympySymbolAllocator'] = None
) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]:
n = self.bitsize
# 4 * n: G gates, each wth 1 T and 4 single qubit cliffords
# 4 * n: CNOTs
# 2 * n - 1: CNOTs from 1 MultiTargetCNOT
return {
(4 * n, TGate()),
(16 * n, ArbitraryClifford(n=1)),
(6 * n - 1, ArbitraryClifford(n=2)),
}


@frozen
class SwapWithZero(Bloq):
Expand Down Expand Up @@ -125,3 +152,11 @@ def build_composite_bloq(
)

return {'selection': bb.join(selection), 'targets': targets}

def bloq_counts(
self, ssa: Optional['SympySymbolAllocator'] = None
) -> Set[Tuple[Union[int, sympy.Expr], Bloq]]:
num_swaps = np.floor(
sum([self.n_target_registers / (2 ** (j + 1)) for j in range(self.selection_bitsize)])
)
return {(num_swaps, CSwapApprox(self.target_bitsize))}
39 changes: 38 additions & 1 deletion qualtran/bloqs/swap_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@
# limitations under the License.

import random
from typing import Set, Tuple

import cirq
import cirq_ft
import cirq_ft.infra.testing as cq_testing
import numpy as np
import pytest

from qualtran import BloqBuilder
from qualtran import Bloq, BloqBuilder
from qualtran.bloqs.basic_gates import TGate
from qualtran.bloqs.basic_gates.z_basis import IntState
from qualtran.bloqs.swap_network import CSwapApprox, SwapWithZero
from qualtran.bloqs.util_bloqs import ArbitraryClifford
from qualtran.simulation.quimb_sim import flatten_for_tensor_contraction
from qualtran.testing import assert_valid_bloq_decomposition, execute_notebook

Expand Down Expand Up @@ -101,6 +104,13 @@ def test_swap_with_zero_classically():
print(sel, out_data)


def get_t_count_and_clifford(bc: Set[Tuple[int, Bloq]]) -> Tuple[int, int]:
"""Get the t count and clifford cost from bloq count."""
cliff_cost = sum([x[0] for x in bc if isinstance(x[1], ArbitraryClifford)])
t_cost = sum([x[0] for x in bc if isinstance(x[1], TGate)])
return t_cost, cliff_cost


@pytest.mark.parametrize("n", [*range(1, 6)])
def test_t_complexity(n):
g = cirq_ft.MultiTargetCSwap(n)
Expand All @@ -110,6 +120,33 @@ def test_t_complexity(n):
cq_testing.assert_decompose_is_consistent_with_t_complexity(g)


@pytest.mark.parametrize("n", [*range(2, 6)])
def test_cswap_approx_bloq_counts(n):
csa = CSwapApprox(n)
bc = csa.bloq_counts()
t_cost, cliff_cost = get_t_count_and_clifford(bc)
assert csa.t_complexity().clifford == cliff_cost
assert csa.t_complexity().t == t_cost


@pytest.mark.parametrize(
"selection_bitsize, target_bitsize, n_target_registers, want",
[
[3, 5, 1, cirq_ft.TComplexity(t=0, clifford=0)],
[2, 2, 3, cirq_ft.TComplexity(t=16, clifford=86)],
[2, 3, 4, cirq_ft.TComplexity(t=36, clifford=195)],
[3, 2, 5, cirq_ft.TComplexity(t=32, clifford=172)],
[4, 1, 10, cirq_ft.TComplexity(t=36, clifford=189)],
],
)
def test_swap_with_zero_bloq_counts(selection_bitsize, target_bitsize, n_target_registers, want):
gate = SwapWithZero(selection_bitsize, target_bitsize, n_target_registers)
bc = list(gate.bloq_counts())[0]
t_cost, cliff_cost = get_t_count_and_clifford(bc[1].bloq_counts())
assert bc[0] * t_cost == want.t
assert bc[0] * cliff_cost == want.clifford


@pytest.mark.parametrize(
"selection_bitsize, target_bitsize, n_target_registers, want",
[
Expand Down

0 comments on commit a27f271

Please sign in to comment.