From 8a283071b9bdb27e42f20a2698634dbb0e600f8c Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 9 Sep 2024 14:09:47 -0700 Subject: [PATCH] Fix parameter values for `ModExp` examples and add post init assertions for better error messages (#1399) * Fix parameter values for ModExp examples and add post init assertions for better error messages * Fix mypy --- qualtran/bloqs/factoring/mod_exp.ipynb | 4 ++-- qualtran/bloqs/factoring/mod_exp.py | 23 ++++++++++--------- qualtran/bloqs/factoring/mod_exp_test.py | 6 +++-- .../mod_arithmetic/mod_multiplication.py | 9 ++++---- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/qualtran/bloqs/factoring/mod_exp.ipynb b/qualtran/bloqs/factoring/mod_exp.ipynb index 88f76f1ed..77c87aa15 100644 --- a/qualtran/bloqs/factoring/mod_exp.ipynb +++ b/qualtran/bloqs/factoring/mod_exp.ipynb @@ -89,7 +89,7 @@ }, "outputs": [], "source": [ - "modexp_small = ModExp(base=3, mod=15, exp_bitsize=3, x_bitsize=2048)" + "modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048)" ] }, { @@ -101,7 +101,7 @@ }, "outputs": [], "source": [ - "modexp = ModExp.make_for_shor(big_n=15 * 17, g=9)" + "modexp = ModExp.make_for_shor(big_n=13 * 17, g=9)" ] }, { diff --git a/qualtran/bloqs/factoring/mod_exp.py b/qualtran/bloqs/factoring/mod_exp.py index cabc6e0bd..129711797 100644 --- a/qualtran/bloqs/factoring/mod_exp.py +++ b/qualtran/bloqs/factoring/mod_exp.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import random from functools import cached_property -from typing import Dict, Optional, Tuple, Union +from typing import cast, Dict, Optional, Tuple, Union import attrs -import numpy as np import sympy from attrs import frozen @@ -38,6 +38,7 @@ from qualtran.drawing import Text, WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.resource_counting.generalizers import ignore_split_join +from qualtran.symbolics import is_symbolic @frozen @@ -70,9 +71,9 @@ class ModExp(Bloq): exp_bitsize: Union[int, sympy.Expr] x_bitsize: Union[int, sympy.Expr] - def __post_init__(self): - if isinstance(self.base, int) and isinstance(self.mod, int): - assert math.gcd(self.base, self.mod) == 1 + def __attrs_post_init__(self): + if not is_symbolic(self.base, self.mod): + assert math.gcd(cast(int, self.base), cast(int, self.mod)) == 1 @cached_property def signature(self) -> 'Signature': @@ -95,9 +96,9 @@ def make_for_shor(cls, big_n: int, g: Optional[int] = None): if isinstance(big_n, sympy.Expr): little_n = sympy.ceiling(sympy.log(big_n, 2)) else: - little_n = int(np.ceil(np.log2(big_n))) + little_n = int(math.ceil(math.log2(big_n))) if g is None: - g = np.random.randint(big_n) + g = random.randint(2, big_n) return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n) def _CtrlModMul(self, k: Union[int, sympy.Expr]): @@ -111,10 +112,10 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st exponent = bb.split(exponent) # https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method - base = self.base + base = self.base % self.mod for j in range(self.exp_bitsize - 1, 0 - 1, -1): exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x) - base = base * base % self.mod + base = (base * base) % self.mod return {'exponent': bb.join(exponent, dtype=QUInt(self.exp_bitsize)), 'x': x} @@ -145,13 +146,13 @@ def _generalize_k(b: Bloq) -> Optional[Bloq]: @bloq_example(generalizer=(ignore_split_join, _generalize_k)) def _modexp_small() -> ModExp: - modexp_small = ModExp(base=3, mod=15, exp_bitsize=3, x_bitsize=2048) + modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048) return modexp_small @bloq_example(generalizer=(ignore_split_join, _generalize_k)) def _modexp() -> ModExp: - modexp = ModExp.make_for_shor(big_n=15 * 17, g=9) + modexp = ModExp.make_for_shor(big_n=13 * 17, g=9) return modexp diff --git a/qualtran/bloqs/factoring/mod_exp_test.py b/qualtran/bloqs/factoring/mod_exp_test.py index 0c978b2a7..d331d39f0 100644 --- a/qualtran/bloqs/factoring/mod_exp_test.py +++ b/qualtran/bloqs/factoring/mod_exp_test.py @@ -44,6 +44,8 @@ def test_mod_exp_consistent_classical(): # Choose a base smaller than mod. base = rs.randint(1, mod) + while np.gcd(base, mod) != 1: + base = rs.randint(1, mod) bloq = ModExp(base=base, exp_bitsize=ne, x_bitsize=n, mod=mod) ret1 = bloq.call_classically(exponent=exponent) @@ -65,7 +67,7 @@ def test_modexp_symb_manual(): def test_mod_exp_consistent_counts(): - bloq = ModExp(base=8, exp_bitsize=3, x_bitsize=10, mod=50) + bloq = ModExp(base=11, exp_bitsize=3, x_bitsize=10, mod=50) counts1 = bloq.bloq_counts() @@ -86,7 +88,7 @@ def generalize(b: Bloq) -> Optional[Bloq]: def test_mod_exp_t_complexity(): - bloq = ModExp(base=8, exp_bitsize=3, x_bitsize=10, mod=50) + bloq = ModExp(base=11, exp_bitsize=3, x_bitsize=10, mod=50) tcomp = bloq.t_complexity() assert tcomp.t > 0 diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index cdf389a52..536cfae6b 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import numbers from functools import cached_property -from typing import Dict, Optional, Tuple, Union +from typing import cast, Dict, Optional, Tuple, Union import attrs import numpy as np @@ -180,12 +181,10 @@ class CModMulK(Bloq): mod: Union[int, sympy.Expr] def __attrs_post_init__(self): - if isinstance(self.k, sympy.Expr): + if is_symbolic(self.k, self.mod): return - if isinstance(self.mod, sympy.Expr): - return - assert 0 < self.k < self.mod + assert math.gcd(cast(int, self.k), cast(int, self.mod)) == 1 @cached_property def signature(self) -> 'Signature':