Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implement Windowed Modular Exponentiation #1468

Draft
wants to merge 29 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d70debe
Add rsa files - needs a lot of work just stashing it for now
fpapa250 Sep 25, 2024
74bab90
Merge branch 'main' into rsa-improvements
fpapa250 Sep 28, 2024
e5daf75
Made some structure changes RSA
fpapa250 Sep 28, 2024
dcf1284
Rework rsa mod exp bloqs to work in a rsa phase estimation circuit
fpapa250 Sep 29, 2024
dd9d251
Fix mypy issues
fpapa250 Sep 29, 2024
cb82d39
Add some classical simulation
fpapa250 Oct 1, 2024
14927c0
Implement primitives for ModExp
fpapa250 Oct 2, 2024
1ec788f
Merge branch 'main' into mod_exp_subroutines
fpapa250 Oct 2, 2024
fbdd7e1
Fix serialization test error
fpapa250 Oct 2, 2024
92c76e7
Merge branch 'mod_exp_subroutines' of github.com:fpapa250/Qualtran in…
fpapa250 Oct 2, 2024
5fdc417
Change Union -> SymbolicInt
fpapa250 Oct 2, 2024
6435600
Fix nits
fpapa250 Oct 6, 2024
f06e2f8
Better symbolic messages
fpapa250 Oct 7, 2024
dd45a8f
Better symbolic decomposition error messages
fpapa250 Oct 7, 2024
9fb62fd
Merge branch 'main' into mod_exp_subroutines
fpapa250 Oct 10, 2024
a9b9f56
Fix merge conflicts
fpapa250 Oct 10, 2024
3bf7e63
Fixed docstring to be more readable (hopefully)
fpapa250 Oct 11, 2024
284c552
Refactor RSA to have a phase estimation circuit and a classical simul…
fpapa250 Oct 12, 2024
885b33e
Merge branch 'main' into rsa-improvements
fpapa250 Oct 12, 2024
bd41420
Fix notebook specs merge conflict
fpapa250 Oct 12, 2024
f665ef0
Merge branch 'mod_exp_subroutines' into rsa-window
fpapa250 Oct 13, 2024
7d2bcba
Super WIP windowed mod exp
fpapa250 Oct 13, 2024
ba18d5f
Bloq decomposition complete - needs testing
fpapa250 Oct 14, 2024
45b58b4
More work on decomposition
fpapa250 Oct 15, 2024
d7460a4
stash current changes
fpapa250 Oct 16, 2024
db4828c
Merge branch 'main' into rsa-window
fpapa250 Oct 22, 2024
36e42b3
Partial bugfix windowed arithmetic
fpapa250 Oct 23, 2024
3138cae
Fix merge conflicts
fpapa250 Oct 23, 2024
edbcc51
stash changes for now
fpapa250 Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
More work on decomposition
  • Loading branch information
fpapa250 committed Oct 15, 2024
commit 45b58b4bd8d248694eb2112a25ffb6ab8d2ba5f1
46 changes: 31 additions & 15 deletions qualtran/bloqs/factoring/rsa/rsa_mod_exp.py
Original file line number Diff line number Diff line change
@@ -40,6 +40,8 @@
from qualtran.bloqs.basic_gates.z_basis import IntState
from qualtran.bloqs.data_loading.qroam_clean import QROAMClean
from qualtran.bloqs.mod_arithmetic import CModMulK
from qualtran.bloqs.mod_arithmetic.mod_addition import ModAdd
from qualtran.bloqs.mod_arithmetic.mod_subtraction import ModSub
from qualtran.drawing import Text, WireSymbol
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.resource_counting.generalizers import ignore_split_join
@@ -83,8 +85,8 @@ class ModExp(Bloq):
mod: 'SymbolicInt'
exp_bitsize: 'SymbolicInt'
x_bitsize: 'SymbolicInt'
exp_window_size: 'SymbolicInt' = 1
mult_window_size: 'SymbolicInt' = 1
exp_window_size: Optional['SymbolicInt'] = None
mult_window_size: Optional['SymbolicInt'] = None

def __attrs_post_init__(self):
if not is_symbolic(self.base, self.mod):
@@ -100,7 +102,7 @@ def signature(self) -> 'Signature':
)

@classmethod
def make_for_shor(cls, big_n: 'SymbolicInt', g: Optional['SymbolicInt'] = None, exp_window_size: Optional['SymbolicInt'] = 1, mult_window_size: Optional['SymbolicInt'] = 1):
def make_for_shor(cls, big_n: 'SymbolicInt', g: Optional['SymbolicInt'] = None, exp_window_size: Optional['SymbolicInt'] = None, mult_window_size: Optional['SymbolicInt'] = None):
"""Factory method that sets up the modular exponentiation for a factoring run.

Args:
@@ -154,7 +156,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st
x = bb.add(IntState(val=1, bitsize=self.x_bitsize))
exponent = bb.split(exponent)

if self.exp_window_size > 1 or self.mult_window_size > 1:
if self.exp_window_size is not None and self.mult_window_size is not None:
k = self.base

a = bb.split(x)
@@ -170,9 +172,10 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st
data = list([(ke * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke in kes)
ei_i = bb.join(ei[i], QUInt((self.exp_window_size)))
mi_i = bb.join(mi[j], QUInt((self.mult_window_size)))
ei_i, mi_i, t = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i)
t, b = bb.add(Add(QUInt(self.x_bitsize), QUInt(self.x_bitsize)), a=t, b=b)
ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t)
ei_i, mi_i, t, *junk = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i)
t, b = bb.add(ModAdd(self.x_bitsize, self.mod), x=t, y=b)
junk_mapping = {f'junk_target{i}_': junk[i] for i in range(len(junk))}
ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t, **junk_mapping)
ei[i] = bb.split(ei_i)
mi[j] = bb.split(mi_i)

@@ -185,9 +188,10 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st
data = list([(ke_inv * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke_inv in kes_inv)
ei_i = bb.join(ei[i], QUInt((self.exp_window_size)))
mi_i = bb.join(mi[j], QUInt((self.mult_window_size)))
ei_i, mi_i, t = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i)
t, a = bb.add(SubtractFrom(QUInt(self.x_bitsize)), a=t, b=a)
ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t)
ei_i, mi_i, t, *junk = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i)
t, a = bb.add(ModSub(QUInt(self.x_bitsize), self.mod), x=t, y=a)
junk_mapping = {f'junk_target{i}_': junk[i] for i in range(len(junk))}
ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t, **junk_mapping)
ei[i] = bb.split(ei_i)
mi[j] = bb.split(mi_i)

@@ -201,7 +205,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st

x = bb.join(a, QUInt(self.x_bitsize))
exponent = np.concatenate(ei, axis=None)
bb.free(b)
bb.free(b, dirty=True)
else:
# https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
base = self.base % self.mod
@@ -212,10 +216,10 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st
return {'exponent': bb.join(exponent, dtype=QUInt(self.exp_bitsize)), 'x': x}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if self.exp_window_size > 1 or self.mult_window_size > 1:
if self.exp_window_size is not None and self.mult_window_size is not None:
bloq_counts = (self.exp_bitsize // self.exp_window_size) * (self.x_bitsize // self.mult_window_size)
return {self.qrom: 2 * bloq_counts, self.qrom.adjoint(): 2 * bloq_counts, Add(): bloq_counts, SubtractFrom(): bloq_counts,
Swap(self.x_bitsize): self.exp_bitsize // self.exp_window_size}
return {}#return {self.qrom: 2 * bloq_counts, self.qrom.adjoint(): 2 * bloq_counts, Add(): bloq_counts, SubtractFrom(): bloq_counts,
#Swap(self.x_bitsize): self.exp_bitsize // self.exp_window_size}
else:
k = ssa.new_symbol('k')
return {self._CtrlModMul(k=k): self.exp_bitsize, IntState(val=1, bitsize=self.x_bitsize): 1}
@@ -253,11 +257,23 @@ def _modexp() -> ModExp:
return modexp


@bloq_example(generalizer=(ignore_split_join, _generalize_k))
def _modexp_window() -> ModExp:
modexp_window = ModExp.make_for_shor(big_n=13 * 17, g=9, exp_window_size=2, mult_window_size=2)
return modexp_window


@bloq_example
def _modexp_symb() -> ModExp:
g, N, n_e, n_x = sympy.symbols('g N n_e, n_x')
modexp_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x)
return modexp_symb

@bloq_example
def _modexp_window_symb() -> ModExp:
g, N, n_e, n_x, w_e, w_m = sympy.symbols('g N n_e, n_x w_e w_m')
modexp_window_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x, exp_window_size=w_e, mult_window_size=w_m)
return modexp_window_symb


_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb))
_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb, _modexp_window, _modexp_window_symb))
4 changes: 2 additions & 2 deletions qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@

from qualtran import Bloq
from qualtran.bloqs.bookkeeping import Join, Split
from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp
from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, _modexp_window, _modexp_window_symb, ModExp
from qualtran.bloqs.mod_arithmetic import CModMulK
from qualtran.drawing import Text
from qualtran.resource_counting import SympySymbolAllocator
@@ -95,7 +95,7 @@ def test_mod_exp_t_complexity():
assert tcomp.t > 0


@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small])
@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small, _modexp_window, _modexp_window_symb])
def test_modexp(bloq_autotester, bloq):
bloq_autotester(bloq)