Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify KaliskiModInverse to support zero #1486

Merged
merged 3 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 128 additions & 33 deletions qualtran/bloqs/mod_arithmetic/mod_division.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np
import sympy
from attrs import evolve, frozen
from attrs import evolve, field, frozen

from qualtran import (
Bloq,
Expand Down Expand Up @@ -65,16 +65,23 @@ def signature(self) -> 'Signature':
Register('v', QMontgomeryUInt(self.bitsize)),
Register('m', QBit()),
Register('f', QBit()),
Register('is_terminal', QBit()),
]
)

def on_classical_vals(self, v: int, m: int, f: int) -> Dict[str, 'ClassicalValT']:
def on_classical_vals(
self, v: int, m: int, f: int, is_terminal: int
) -> Dict[str, 'ClassicalValT']:
print('here')
assert False
m ^= f & (v == 0)
assert is_terminal == 0
is_terminal ^= m
f ^= m
return {'v': v, 'm': m, 'f': f}
return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal}

def build_composite_bloq(
self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet
self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet, is_terminal: Soquet
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.bitsize):
raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}')
Expand All @@ -89,7 +96,8 @@ def build_composite_bloq(
f = ctrls[-1]
v = bb.join(v_arr)
m, f = bb.add(CNOT(), ctrl=m, target=f)
return {'v': v, 'm': m, 'f': f}
m, is_terminal = bb.add(CNOT(), ctrl=m, target=is_terminal)
return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if is_symbolic(self.bitsize):
Expand Down Expand Up @@ -408,16 +416,27 @@ def signature(self) -> 'Signature':
Register('s', QMontgomeryUInt(self.bitsize)),
Register('m', QBit()),
Register('f', QBit()),
Register('is_terminal', QBit()),
]
)

def build_composite_bloq(
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet
self,
bb: 'BloqBuilder',
u: Soquet,
v: Soquet,
r: Soquet,
s: Soquet,
m: Soquet,
f: Soquet,
is_terminal: Soquet,
) -> Dict[str, 'SoquetT']:
a = bb.allocate(1)
b = bb.allocate(1)

v, m, f = bb.add(_KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f)
v, m, f, is_terminal = bb.add(
_KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f, is_terminal=is_terminal
)
u, v, b, a, m, f = bb.add(
_KaliskiIterationStep2(self.bitsize), u=u, v=v, b=b, a=a, m=m, f=f
)
Expand All @@ -434,7 +453,7 @@ def build_composite_bloq(

bb.free(a)
bb.free(b)
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f}
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
Expand All @@ -447,7 +466,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
}

def on_classical_vals(
self, u: int, v: int, r: int, s: int, m: int, f: int
self, u: int, v: int, r: int, s: int, m: int, f: int, is_terminal: int
) -> Dict[str, 'ClassicalValT']:
"""This is the Kaliski algorithm as described in Fig7 of https://arxiv.org/pdf/2001.09580.

Expand All @@ -456,6 +475,7 @@ def on_classical_vals(
of `f` and `m`.
"""
assert m == 0
is_terminal = f == 1 and v == 0
if f == 0:
# When `f = 0` this means that the algorithm is nearly over and that we just need to
# double the value of `r`.
Expand Down Expand Up @@ -484,7 +504,7 @@ def on_classical_vals(
if swap:
u, v = v, u
r, s = s, r
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f}
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal}


@frozen
Expand All @@ -504,6 +524,7 @@ def signature(self) -> 'Signature':
Register('s', QMontgomeryUInt(self.bitsize)),
Register('m', QAny(2 * self.bitsize)),
Register('f', QBit()),
Register('terminal_condition', QAny(2 * self.bitsize)),
]
)

Expand All @@ -512,17 +533,33 @@ def _kaliski_iteration(self):
return _KaliskiIteration(self.bitsize, self.mod)

def build_composite_bloq(
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet
self,
bb: 'BloqBuilder',
u: Soquet,
v: Soquet,
r: Soquet,
s: Soquet,
m: Soquet,
f: Soquet,
terminal_condition: Soquet,
) -> Dict[str, 'SoquetT']:
f = bb.add(XGate(), q=f)
u = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=u)
s = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=s)

m_arr = bb.split(m)
terminal_condition_arr = bb.split(terminal_condition)

for i in range(2 * self.bitsize):
u, v, r, s, m_arr[i], f = bb.add(
self._kaliski_iteration, u=u, v=v, r=r, s=s, m=m_arr[i], f=f
u, v, r, s, m_arr[i], f, terminal_condition_arr[i] = bb.add(
self._kaliski_iteration,
u=u,
v=v,
r=r,
s=s,
m=m_arr[i],
f=f,
is_terminal=terminal_condition_arr[i],
)

r = bb.add(BitwiseNot(QMontgomeryUInt(self.bitsize)), x=r)
Expand All @@ -531,8 +568,43 @@ def build_composite_bloq(
u = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=u)
s = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=s)

# This is an extra step not present in the original Kaliski algorithm in order to
# handle the case of x=0. The invariant of the Kaliski algorithm is that that end of the
# algorithm u=1, s=0, r=mod inverse. This happens for all cases where the modular inverse
# exists (i.e. gcd(x, mod) = 1).
# The case where the input is zero is important. Although mathematically the inverse
# doesn't exist. For the bloq to be unitary it needs to map zero to itself.
# When the input is zero, the terminal values of the registers are r=mod, u=v=mod^1=mod-1
# (assuming odd modulus).
# So we clean those registers conditioned on the first terminal qubit which is set
# if and only if the input is zero.
terminal_condition_arr[0], r = bb.add(
XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(),
ctrl=terminal_condition_arr[0],
x=r,
)
terminal_condition_arr[0], u = bb.add(
XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(),
ctrl=terminal_condition_arr[0],
x=u,
)
terminal_condition_arr[0], s = bb.add(
XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(),
ctrl=terminal_condition_arr[0],
x=s,
)

m = bb.join(m_arr)
return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f}
terminal_condition = bb.join(terminal_condition_arr)
return {
'u': u,
'v': v,
'r': r,
's': s,
'm': m,
'f': f,
'terminal_condition': terminal_condition,
}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
Expand All @@ -542,6 +614,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
XGate(): 1,
XorK(QMontgomeryUInt(self.bitsize), self.mod): 2,
XorK(QMontgomeryUInt(self.bitsize), 1): 2,
XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(): 1,
XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(): 2,
}


Expand Down Expand Up @@ -575,7 +649,7 @@ class KaliskiModInverse(Bloq):
"""

bitsize: 'SymbolicInt'
mod: 'SymbolicInt'
mod: 'SymbolicInt' = field(validator=lambda _, __, v: is_symbolic(v) or v % 2 == 1)
uncompute: bool = False

@cached_property
Expand All @@ -584,22 +658,25 @@ def signature(self) -> 'Signature':
return Signature(
[
Register('x', QMontgomeryUInt(self.bitsize)),
Register('m', QAny(2 * self.bitsize), side=side),
Register('junk', QAny(4 * self.bitsize), side=side),
]
)

def build_composite_bloq(
self, bb: 'BloqBuilder', x: Soquet, m: Optional[Soquet] = None, f: Optional[Soquet] = None
self, bb: 'BloqBuilder', x: Soquet, junk: Optional[Soquet] = None
) -> Dict[str, 'SoquetT']:
u = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize))
r = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize))
s = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize))
f = bb.allocate(1)

if self.uncompute:
assert m is not None
u, x, r, s, m, f = cast(
Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet],
assert junk is not None
junk_arr = bb.split(junk)
m = bb.join(junk_arr[: 2 * self.bitsize])
terminal_condition = bb.join(junk_arr[2 * self.bitsize :])
u, x, r, s, m, f, terminal_condition = cast(
Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet],
bb.add_from(
_KaliskiModInverseImpl(self.bitsize, self.mod).adjoint(),
u=u,
Expand All @@ -608,47 +685,65 @@ def build_composite_bloq(
s=s,
m=m,
f=f,
terminal_condition=terminal_condition,
),
)
bb.free(u)
bb.free(r)
bb.free(s)
bb.free(m)
bb.free(f)
bb.free(terminal_condition)
return {'x': x}

m = bb.allocate(2 * self.bitsize)
u, v, x, s, m, f = bb.add_from(
_KaliskiModInverseImpl(self.bitsize, self.mod), u=u, v=x, r=r, s=s, m=m, f=f
terminal_condition = bb.allocate(2 * self.bitsize)
u, v, x, s, m, f, terminal_condition = cast(
Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet],
bb.add_from(
_KaliskiModInverseImpl(self.bitsize, self.mod),
u=u,
v=x,
r=r,
s=s,
m=m,
f=f,
terminal_condition=terminal_condition,
),
)

assert isinstance(u, Soquet)
assert isinstance(v, Soquet)
assert isinstance(s, Soquet)
assert isinstance(f, Soquet)
bb.free(u)
bb.free(v)
bb.free(s)
bb.free(f)
return {'x': x, 'm': m}
junk = bb.join(np.concatenate([bb.split(m), bb.split(terminal_condition)]))
return {'x': x, 'junk': junk}

def adjoint(self) -> 'KaliskiModInverse':
return evolve(self, uncompute=not self.uncompute)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return _KaliskiModInverseImpl(self.bitsize, self.mod).build_call_graph(ssa)

def on_classical_vals(self, x: int, m: int = 0) -> Dict[str, 'ClassicalValT']:
u, v, r, s, f = int(self.mod), x, 0, 1, 1
def on_classical_vals(self, x: int, junk: int = 0) -> Dict[str, 'ClassicalValT']:
mod = int(self.mod)
u, v, r, s, f = mod, x, 0, 1, 1
terminal_condition = m = 0
iteration = _KaliskiModInverseImpl(self.bitsize, self.mod)._kaliski_iteration
for _ in range(2 * int(self.bitsize)):
u, v, r, s, m_i, f = iteration.call_classically(u=u, v=v, r=r, s=s, m=0, f=f)
u, v, r, s, m_i, f, is_terminal = iteration.call_classically(
u=u, v=v, r=r, s=s, m=0, f=f, is_terminal=0
)
m = (m << 1) | m_i
assert u == 1
assert s == self.mod
terminal_condition = (terminal_condition << 1) | is_terminal
assert u == 1 or (x == 0 and u == mod)
assert s == self.mod or (x == 0 and s == 1)
assert f == 0
assert v == 0
return {'x': self.mod - r, 'm': m}
return {
'x': (self.mod - r) if r else 0,
'junk': m * 2 ** (2 * self.bitsize) + terminal_condition,
}


@bloq_example
Expand Down
15 changes: 14 additions & 1 deletion qualtran/bloqs/mod_arithmetic/mod_division_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,24 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod):
continue
x_montgomery = dtype.uint_to_montgomery(x, mod)
res = blq.call_classically(x=x_montgomery)
print(x, x_montgomery)
assert res == cblq.call_classically(x=x_montgomery)
assert len(res) == 2
assert res[0] == dtype.montgomery_inverse(x_montgomery, mod)
assert dtype.montgomery_product(int(res[0]), x_montgomery, mod) == R
assert blq.adjoint().call_classically(x=res[0], m=res[1]) == (x_montgomery,)
assert blq.adjoint().call_classically(x=res[0], junk=res[1]) == (x_montgomery,)


@pytest.mark.parametrize('bitsize', [5, 6])
@pytest.mark.parametrize('mod', [3, 5, 7, 11, 13, 15])
def test_kaliski_mod_inverse_classical_action_zero(bitsize, mod):
blq = KaliskiModInverse(bitsize, mod)
cblq = blq.decompose_bloq()
# When x = 0 the terminal condition is achieved at the first iteration, this corresponds to
# m_0 = is_terminal_0 = 1 and all other bits = 0.
junk = 2 ** (4 * bitsize - 1) + 2 ** (2 * bitsize - 1)
assert blq.call_classically(x=0) == cblq.call_classically(x=0) == (0, junk)
assert blq.adjoint().call_classically(x=0, junk=junk) == (0,)


@pytest.mark.parametrize('bitsize', [5, 6])
Expand Down
Loading