Skip to content

Commit

Permalink
Fix parameter values for ModExp examples and add post init assertio…
Browse files Browse the repository at this point in the history
…ns for better error messages (#1399)

* Fix parameter values for ModExp examples and add post init assertions for better error messages

* Fix mypy
  • Loading branch information
tanujkhattar authored Sep 9, 2024
1 parent c98612f commit 8a28307
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 20 deletions.
4 changes: 2 additions & 2 deletions qualtran/bloqs/factoring/mod_exp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
},
"outputs": [],
"source": [
"modexp_small = ModExp(base=3, mod=15, exp_bitsize=3, x_bitsize=2048)"
"modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048)"
]
},
{
Expand All @@ -101,7 +101,7 @@
},
"outputs": [],
"source": [
"modexp = ModExp.make_for_shor(big_n=15 * 17, g=9)"
"modexp = ModExp.make_for_shor(big_n=13 * 17, g=9)"
]
},
{
Expand Down
23 changes: 12 additions & 11 deletions qualtran/bloqs/factoring/mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from functools import cached_property
from typing import Dict, Optional, Tuple, Union
from typing import cast, Dict, Optional, Tuple, Union

import attrs
import numpy as np
import sympy
from attrs import frozen

Expand All @@ -38,6 +38,7 @@
from qualtran.drawing import Text, WireSymbol
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.symbolics import is_symbolic


@frozen
Expand Down Expand Up @@ -70,9 +71,9 @@ class ModExp(Bloq):
exp_bitsize: Union[int, sympy.Expr]
x_bitsize: Union[int, sympy.Expr]

def __post_init__(self):
if isinstance(self.base, int) and isinstance(self.mod, int):
assert math.gcd(self.base, self.mod) == 1
def __attrs_post_init__(self):
if not is_symbolic(self.base, self.mod):
assert math.gcd(cast(int, self.base), cast(int, self.mod)) == 1

@cached_property
def signature(self) -> 'Signature':
Expand All @@ -95,9 +96,9 @@ def make_for_shor(cls, big_n: int, g: Optional[int] = None):
if isinstance(big_n, sympy.Expr):
little_n = sympy.ceiling(sympy.log(big_n, 2))
else:
little_n = int(np.ceil(np.log2(big_n)))
little_n = int(math.ceil(math.log2(big_n)))
if g is None:
g = np.random.randint(big_n)
g = random.randint(2, big_n)
return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n)

def _CtrlModMul(self, k: Union[int, sympy.Expr]):
Expand All @@ -111,10 +112,10 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st
exponent = bb.split(exponent)

# https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
base = self.base
base = self.base % self.mod
for j in range(self.exp_bitsize - 1, 0 - 1, -1):
exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x)
base = base * base % self.mod
base = (base * base) % self.mod

return {'exponent': bb.join(exponent, dtype=QUInt(self.exp_bitsize)), 'x': x}

Expand Down Expand Up @@ -145,13 +146,13 @@ def _generalize_k(b: Bloq) -> Optional[Bloq]:

@bloq_example(generalizer=(ignore_split_join, _generalize_k))
def _modexp_small() -> ModExp:
modexp_small = ModExp(base=3, mod=15, exp_bitsize=3, x_bitsize=2048)
modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048)
return modexp_small


@bloq_example(generalizer=(ignore_split_join, _generalize_k))
def _modexp() -> ModExp:
modexp = ModExp.make_for_shor(big_n=15 * 17, g=9)
modexp = ModExp.make_for_shor(big_n=13 * 17, g=9)
return modexp


Expand Down
6 changes: 4 additions & 2 deletions qualtran/bloqs/factoring/mod_exp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def test_mod_exp_consistent_classical():

# Choose a base smaller than mod.
base = rs.randint(1, mod)
while np.gcd(base, mod) != 1:
base = rs.randint(1, mod)

bloq = ModExp(base=base, exp_bitsize=ne, x_bitsize=n, mod=mod)
ret1 = bloq.call_classically(exponent=exponent)
Expand All @@ -65,7 +67,7 @@ def test_modexp_symb_manual():


def test_mod_exp_consistent_counts():
bloq = ModExp(base=8, exp_bitsize=3, x_bitsize=10, mod=50)
bloq = ModExp(base=11, exp_bitsize=3, x_bitsize=10, mod=50)

counts1 = bloq.bloq_counts()

Expand All @@ -86,7 +88,7 @@ def generalize(b: Bloq) -> Optional[Bloq]:


def test_mod_exp_t_complexity():
bloq = ModExp(base=8, exp_bitsize=3, x_bitsize=10, mod=50)
bloq = ModExp(base=11, exp_bitsize=3, x_bitsize=10, mod=50)
tcomp = bloq.t_complexity()
assert tcomp.t > 0

Expand Down
9 changes: 4 additions & 5 deletions qualtran/bloqs/mod_arithmetic/mod_multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numbers
from functools import cached_property
from typing import Dict, Optional, Tuple, Union
from typing import cast, Dict, Optional, Tuple, Union

import attrs
import numpy as np
Expand Down Expand Up @@ -180,12 +181,10 @@ class CModMulK(Bloq):
mod: Union[int, sympy.Expr]

def __attrs_post_init__(self):
if isinstance(self.k, sympy.Expr):
if is_symbolic(self.k, self.mod):
return
if isinstance(self.mod, sympy.Expr):
return

assert 0 < self.k < self.mod
assert math.gcd(cast(int, self.k), cast(int, self.mod)) == 1

@cached_property
def signature(self) -> 'Signature':
Expand Down

0 comments on commit 8a28307

Please sign in to comment.