Skip to content

Commit

Permalink
Add a test util for classical action and refactor factoring/mod_mul (
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri authored Aug 27, 2024
1 parent f2eabb2 commit d240ed2
Show file tree
Hide file tree
Showing 21 changed files with 697 additions and 509 deletions.
8 changes: 5 additions & 3 deletions dev_tools/autogenerate-bloqs-notebooks-v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,11 @@
),
NotebookSpecV2(
title='Modular Multiplication',
module=qualtran.bloqs.factoring.mod_mul,
bloq_specs=[qualtran.bloqs.factoring.mod_mul._MODMUL_DOC],
directory=f'{SOURCE_DIR}/bloqs/factoring',
module=qualtran.bloqs.mod_arithmetic.mod_multiplication,
bloq_specs=[
qualtran.bloqs.mod_arithmetic.mod_multiplication._MOD_DBL_DOC,
qualtran.bloqs.mod_arithmetic.mod_multiplication._C_MOD_MUL_K_DOC,
],
),
NotebookSpecV2(
title='Modular Exponentiation',
Expand Down
2 changes: 1 addition & 1 deletion docs/bloqs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Bloqs Library

mod_arithmetic/mod_addition.ipynb
mod_arithmetic/mod_subtraction.ipynb
factoring/mod_mul.ipynb
mod_arithmetic/mod_multiplication.ipynb
factoring/mod_exp.ipynb
factoring/ecc/ec_add.ipynb
factoring/ecc/ecc.ipynb
Expand Down
9 changes: 8 additions & 1 deletion qualtran/bloqs/arithmetic/bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Optional, Sequence, TYPE_CHECKING
from typing import Dict, Optional, Sequence, TYPE_CHECKING

import numpy as np
import sympy
Expand All @@ -26,6 +26,7 @@
DecomposeTypeError,
QAny,
QDType,
QMontgomeryUInt,
QUInt,
Register,
Signature,
Expand Down Expand Up @@ -221,6 +222,12 @@ def wire_symbol(

return TextBox("~x")

def on_classical_vals(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
x = -x - 1
if isinstance(self.dtype, (QUInt, QMontgomeryUInt)):
x %= 2**self.dtype.bitsize
return {'x': x}


@bloq_example
def _bitwise_not() -> BitwiseNot:
Expand Down
14 changes: 13 additions & 1 deletion qualtran/bloqs/arithmetic/bitwise_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import numpy as np
import pytest

from qualtran import BloqBuilder, QAny, QUInt
import qualtran.testing as qlt_testing
from qualtran import BloqBuilder, QAny, QInt, QMontgomeryUInt, QUInt
from qualtran.bloqs.arithmetic.bitwise import (
_bitwise_not,
_bitwise_not_symb,
Expand Down Expand Up @@ -172,3 +173,14 @@ def test_bitwise_not_diagram():
x3: ───~x───
''',
)


@pytest.mark.parametrize('dtype', [QUInt, QMontgomeryUInt, QInt])
@pytest.mark.parametrize('bitsize', range(2, 6))
def test_bitwisenot_classical_action(dtype, bitsize):
b = BitwiseNot(dtype(bitsize))
if dtype is QInt:
valid_range = range(-(2 ** (bitsize - 1)), 2 ** (bitsize - 1))
else:
valid_range = range(2**bitsize)
qlt_testing.assert_consistent_classical_action(b, x=valid_range)
1 change: 0 additions & 1 deletion qualtran/bloqs/factoring/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@
# limitations under the License.

from .mod_exp import ModExp
from .mod_mul import CtrlModMul
2 changes: 1 addition & 1 deletion qualtran/bloqs/factoring/ecc/ec_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
(CModSub(QUInt(self.n), mod=self.mod), 4),
(ModNeg(QUInt(self.n), mod=self.mod), 2),
(CModNeg(QUInt(self.n), mod=self.mod), 1),
(ModDbl(n=self.n, mod=self.mod), 2),
(ModDbl(QUInt(self.n), mod=self.mod), 2),
(ModMul(n=self.n, mod=self.mod), 10),
(ModInv(n=self.n, mod=self.mod), 4),
}
Expand Down
13 changes: 9 additions & 4 deletions qualtran/bloqs/factoring/mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import cached_property
from typing import Dict, Optional, Set, Tuple, Union

Expand All @@ -33,7 +34,7 @@
SoquetT,
)
from qualtran.bloqs.basic_gates import IntState
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.bloqs.mod_arithmetic import CModMulK
from qualtran.drawing import Text, WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting.generalizers import ignore_split_join
Expand Down Expand Up @@ -69,6 +70,10 @@ class ModExp(Bloq):
exp_bitsize: Union[int, sympy.Expr]
x_bitsize: Union[int, sympy.Expr]

def __post_init__(self):
if isinstance(self.base, int) and isinstance(self.mod, int):
assert math.gcd(self.base, self.mod) == 1

@cached_property
def signature(self) -> 'Signature':
return Signature(
Expand Down Expand Up @@ -96,8 +101,8 @@ def make_for_shor(cls, big_n: int, g: Optional[int] = None):
return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n)

def _CtrlModMul(self, k: Union[int, sympy.Expr]):
"""Helper method to return a `CtrlModMul` with attributes forwarded."""
return CtrlModMul(k=k, bitsize=self.x_bitsize, mod=self.mod)
"""Helper method to return a `CModMulK` with attributes forwarded."""
return CModMulK(QUInt(self.x_bitsize), k=k, mod=self.mod)

def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[str, 'SoquetT']:
if isinstance(self.exp_bitsize, sympy.Expr):
Expand Down Expand Up @@ -135,7 +140,7 @@ def wire_symbol(


def _generalize_k(b: Bloq) -> Optional[Bloq]:
if isinstance(b, CtrlModMul):
if isinstance(b, CModMulK):
return attrs.evolve(b, k=_K)

return b
Expand Down
19 changes: 9 additions & 10 deletions qualtran/bloqs/factoring/mod_exp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,25 @@
from qualtran import Bloq
from qualtran.bloqs.bookkeeping import Join, Split
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_symb, ModExp
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
from qualtran.bloqs.mod_arithmetic import CModMulK
from qualtran.drawing import Text
from qualtran.resource_counting import SympySymbolAllocator
from qualtran.testing import execute_notebook


# TODO: Fix ModExp and improve this test
def test_mod_exp_consistent_classical():
rs = np.random.RandomState(52)

# 100 random attribute choices.
for _ in range(100):
# Sample moduli in a range. Set x_bitsize=n big enough to fit.
mod = rs.randint(4, 123)
mod = 7 * 13
n = int(np.ceil(np.log2(mod)))
n = rs.randint(n, n + 10)

# Choose an exponent in a range. Set exp_bitsize=ne bit enough to fit.
exponent = rs.randint(1, 20)
ne = int(np.ceil(np.log2(exponent)))
ne = rs.randint(ne, ne + 10)
exponent = rs.randint(1, 2**n)
ne = 2 * n

# Choose a base smaller than mod.
base = rs.randint(1, mod)
Expand All @@ -59,30 +58,30 @@ def test_modexp_symb_manual():
counts = modexp.bloq_counts()
counts_by_bloq = {bloq.pretty_name(): n for bloq, n in counts.items()}
assert counts_by_bloq['|1>'] == 1
assert counts_by_bloq['CtrlModMul'] == n_e
assert counts_by_bloq['CModMulK'] == n_e

b, x = modexp.call_classically(exponent=sympy.Symbol('b'))
assert str(x) == 'Mod(g**b, N)'


def test_mod_exp_consistent_counts():
bloq = ModExp(base=8, exp_bitsize=3, x_bitsize=10, mod=50)

counts1 = bloq.bloq_counts()

ssa = SympySymbolAllocator()
my_k = ssa.new_symbol('k')

def generalize(b: Bloq) -> Optional[Bloq]:
if isinstance(b, CtrlModMul):
# Symbolic k in `CtrlModMul`.
if isinstance(b, CModMulK):
# Symbolic k in `CModMulK`.
return attrs.evolve(b, k=my_k)
if isinstance(b, (Split, Join)):
# Ignore these
return None
return b

counts2 = bloq.decompose_bloq().bloq_counts(generalizer=generalize)

assert counts1 == counts2


Expand Down
189 changes: 0 additions & 189 deletions qualtran/bloqs/factoring/mod_mul.ipynb

This file was deleted.

Loading

0 comments on commit d240ed2

Please sign in to comment.