Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup redundant code for computing t complexity #1317

Merged
merged 13 commits into from
Aug 23, 2024
19 changes: 6 additions & 13 deletions qualtran/bloqs/basic_gates/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Optional, Protocol, runtime_checkable, Tuple, Union
from typing import Optional, Tuple, Union

import attrs
import cirq
Expand All @@ -27,13 +27,6 @@
from qualtran.symbolics import SymbolicFloat


@runtime_checkable
class _HasEps(Protocol):
"""Protocol for typing `RotationBloq` base class mixin that has accuracy specified as eps."""

eps: float


@frozen
class ZPowGate(CirqGateAsBloqBase):
r"""A gate that rotates around the Z axis of the Bloch sphere.
Expand Down Expand Up @@ -115,7 +108,7 @@ def _z_pow() -> ZPowGate:
class CZPowGate(CirqGateAsBloqBase):
exponent: float = 1.0
global_shift: float = 0.0
eps: float = 1e-11
eps: SymbolicFloat = 1e-11

def decompose_bloq(self) -> 'CompositeBloq':
raise DecomposeTypeError(f"{self} is atomic")
Expand Down Expand Up @@ -183,7 +176,7 @@ class XPowGate(CirqGateAsBloqBase):
"""
exponent: Union[sympy.Expr, float] = 1.0
global_shift: float = 0.0
eps: float = 1e-11
eps: SymbolicFloat = 1e-11

def decompose_bloq(self) -> 'CompositeBloq':
raise DecomposeTypeError(f"{self} is atomic")
Expand Down Expand Up @@ -253,7 +246,7 @@ class YPowGate(CirqGateAsBloqBase):
"""
exponent: Union[sympy.Expr, float] = 1.0
global_shift: float = 0.0
eps: float = 1e-11
eps: SymbolicFloat = 1e-11

def decompose_bloq(self) -> 'CompositeBloq':
raise DecomposeTypeError(f"{self} is atomic")
Expand Down Expand Up @@ -321,7 +314,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
@frozen
class Rx(CirqGateAsBloqBase):
angle: Union[sympy.Expr, float]
eps: float = 1e-11
eps: SymbolicFloat = 1e-11

def decompose_bloq(self) -> 'CompositeBloq':
raise DecomposeTypeError(f"{self} is atomic")
Expand All @@ -342,7 +335,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
@frozen
class Ry(CirqGateAsBloqBase):
angle: Union[sympy.Expr, float]
eps: float = 1e-11
eps: SymbolicFloat = 1e-11

def decompose_bloq(self) -> 'CompositeBloq':
raise DecomposeTypeError(f"{self} is atomic")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
"import numpy as np\n",
"import sympy\n",
"\n",
"from qualtran.resource_counting.t_counts_from_sigma import _get_all_rotation_types\n",
"from qualtran.resource_counting.classify_bloqs import bloq_is_rotation\n",
"from qualtran.resource_counting.generalizers import PHI\n",
"from qualtran.cirq_interop.t_complexity_protocol import TComplexity\n",
"from qualtran import Bloq\n",
Expand All @@ -130,11 +130,10 @@
"\n",
"\n",
"def t_and_rot_counts_from_sigma(sigma: Dict['Bloq', Union[int, 'sympy.Expr']]) -> Tuple[int, int]:\n",
" rotation_types = _get_all_rotation_types()\n",
" ret = sigma.get(TGate(), 0)\n",
" n_rot = 0\n",
" for bloq, counts in sigma.items():\n",
" if isinstance(bloq, rotation_types):\n",
" if bloq_is_rotation(bloq):\n",
" n_rot += counts\n",
" return ret, n_rot\n",
"\n",
Expand Down
7 changes: 4 additions & 3 deletions qualtran/bloqs/data_loading/select_swap_qrom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost
from qualtran.testing import assert_valid_bloq_decomposition


Expand Down Expand Up @@ -187,8 +187,9 @@ def test_qroam_t_complexity():
qroam = SelectSwapQROM.build_from_data(
[1, 2, 3, 4, 5, 6, 7, 8], target_bitsizes=(4,), log_block_sizes=(2,)
)
_, sigma = qroam.call_graph()
assert t_counts_from_sigma(sigma) == qroam.t_complexity().t == 192
gate_counts = get_cost_value(qroam, QECGatesCost())
assert gate_counts == GateCounts(t=192, clifford=1082)
assert qroam.t_complexity() == TComplexity(t=192, clifford=1082)


def test_qroam_many_registers():
Expand Down
37 changes: 9 additions & 28 deletions qualtran/resource_counting/t_counts_from_sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import sys
from typing import cast, Mapping, Optional, Tuple, Type, TYPE_CHECKING
from typing import Mapping

import cirq

from qualtran import Bloq, Controlled
from qualtran.symbolics import ceil, SymbolicInt

if TYPE_CHECKING:
from qualtran import Bloq
from qualtran.bloqs.basic_gates.rotation import _HasEps


def _get_all_rotation_types() -> Tuple[Type['_HasEps'], ...]:
"""Returns all classes defined in bloqs.basic_gates which have an attribute `eps`."""
from qualtran.bloqs.basic_gates import GlobalPhase
from qualtran.bloqs.basic_gates.rotation import _HasEps

bloqs_to_exclude = [GlobalPhase]

return tuple(
cast(Type['_HasEps'], v) # Can't use `issubclass` with protocols with attributes.
for (_, v) in inspect.getmembers(sys.modules['qualtran.bloqs.basic_gates'], inspect.isclass)
if isinstance(v, _HasEps) and v not in bloqs_to_exclude
)


def t_counts_from_sigma(
sigma: Mapping['Bloq', SymbolicInt],
rotation_types: Optional[Tuple[Type['_HasEps'], ...]] = None,
) -> SymbolicInt:
def t_counts_from_sigma(sigma: Mapping['Bloq', SymbolicInt]) -> SymbolicInt:
"""Aggregates T-counts from a sigma dictionary by summing T-costs for all rotation bloqs."""
from qualtran.bloqs.basic_gates import TGate
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.resource_counting.classify_bloqs import bloq_is_rotation

if rotation_types is None:
rotation_types = _get_all_rotation_types()
ret = sigma.get(TGate(), 0) + sigma.get(TGate().adjoint(), 0)
for bloq, counts in sigma.items():
if isinstance(bloq, rotation_types) and not cirq.has_stabilizer_effect(bloq):
if bloq_is_rotation(bloq) and not cirq.has_stabilizer_effect(bloq):
if isinstance(bloq, Controlled):
# TODO native controlled rotation bloqs missing (CRz, CRy etc.)
anurudhp marked this conversation as resolved.
Show resolved Hide resolved
bloq = bloq.subbloq
assert hasattr(bloq, 'eps')
ret += ceil(TComplexity.rotation_cost(bloq.eps)) * counts
return ret
76 changes: 32 additions & 44 deletions qualtran/resource_counting/t_counts_from_sigma_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,66 +11,54 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import sympy

from qualtran import Bloq
from qualtran.bloqs.basic_gates import (
CZPowGate,
Rx,
Ry,
Rz,
SU2RotationGate,
TGate,
Toffoli,
XPowGate,
YPowGate,
ZPowGate,
)
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.resource_counting.t_counts_from_sigma import (
_get_all_rotation_types,
t_counts_from_sigma,
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma
from qualtran.symbolics import SymbolicFloat

EPS: SymbolicFloat = sympy.Symbol("eps")


@pytest.mark.parametrize(
("bloq", "t_count"), [(TGate(), 1), pytest.param(Toffoli(), 4, marks=pytest.mark.xfail)]
)
def test_t_counts_from_sigma_known(bloq: Bloq, t_count: int):
assert t_counts_from_sigma({bloq: 1}) == t_count


def test_all_rotation_types():
assert set(_get_all_rotation_types()) == {
CZPowGate,
Rx,
Ry,
Rz,
XPowGate,
YPowGate,
ZPowGate,
SU2RotationGate,
}
@pytest.mark.parametrize(
"bloq",
[
ZPowGate(0.01, eps=EPS),
Rz(0.01, eps=EPS),
Rx(0.01, eps=EPS),
XPowGate(0.01, eps=EPS),
Ry(0.01, eps=EPS),
YPowGate(0.01, eps=EPS),
CZPowGate(0.01, eps=EPS),
],
)
def test_t_counts_from_sigma_for_rotation_with_eps(bloq: Bloq):
expected_t_count = TComplexity.rotation_cost(EPS)
assert t_counts_from_sigma({bloq: 1}) == expected_t_count


def test_t_counts_from_sigma():
z_eps1, z_eps2, x_eps, y_eps, cz_eps = sympy.symbols('z_eps1, z_eps2, x_eps, y_eps, cz_eps')
sigma = {
ZPowGate(eps=z_eps1): 1,
ZPowGate(eps=z_eps2): 2,
ZPowGate(0.01, eps=z_eps1): 1,
ZPowGate(0.01, eps=z_eps2): 2,
Rz(0.01, eps=z_eps2): 3,
Rx(0.01, eps=x_eps): 4,
XPowGate(eps=x_eps): 5,
XPowGate(0.01, eps=x_eps): 5,
Ry(0.01, eps=y_eps): 6,
YPowGate(eps=y_eps): 7,
YPowGate(0.01, eps=y_eps): 7,
CZPowGate(eps=cz_eps): 20,
CZPowGate(0.01, eps=cz_eps): 20,
TGate(): 100,
Toffoli(): 200,
}
expected_t_count = (
+100
+ 1 * TComplexity.rotation_cost(z_eps1)
+ 5 * TComplexity.rotation_cost(z_eps2)
+ 9 * TComplexity.rotation_cost(x_eps)
+ 13 * TComplexity.rotation_cost(y_eps)
+ 20 * TComplexity.rotation_cost(cz_eps)
)
assert t_counts_from_sigma(sigma) == expected_t_count
anurudhp marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize(
"bloq", [ZPowGate(eps=EPS), XPowGate(eps=EPS), YPowGate(eps=EPS), CZPowGate(eps=EPS)]
)
def test_t_counts_from_sigma_zero(bloq: Bloq):
assert t_counts_from_sigma({bloq: 1}) == 0
Loading