From a053918954534f407ab13902b774beb0077bd27a Mon Sep 17 00:00:00 2001 From: Nour Yosri Date: Mon, 4 Nov 2024 15:21:20 -0800 Subject: [PATCH] Modify KaliskiModInverse to support zero --- qualtran/bloqs/mod_arithmetic/mod_division.py | 161 ++++++++++++++---- .../bloqs/mod_arithmetic/mod_division_test.py | 15 +- 2 files changed, 142 insertions(+), 34 deletions(-) diff --git a/qualtran/bloqs/mod_arithmetic/mod_division.py b/qualtran/bloqs/mod_arithmetic/mod_division.py index 575a1aada..005323b96 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_division.py +++ b/qualtran/bloqs/mod_arithmetic/mod_division.py @@ -17,7 +17,7 @@ import numpy as np import sympy -from attrs import evolve, frozen +from attrs import evolve, field, frozen from qualtran import ( Bloq, @@ -65,16 +65,23 @@ def signature(self) -> 'Signature': Register('v', QMontgomeryUInt(self.bitsize)), Register('m', QBit()), Register('f', QBit()), + Register('is_terminal', QBit()), ] ) - def on_classical_vals(self, v: int, m: int, f: int) -> Dict[str, 'ClassicalValT']: + def on_classical_vals( + self, v: int, m: int, f: int, is_terminal: int + ) -> Dict[str, 'ClassicalValT']: + print('here') + assert False m ^= f & (v == 0) + assert is_terminal == 0 + is_terminal ^= m f ^= m - return {'v': v, 'm': m, 'f': f} + return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal} def build_composite_bloq( - self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet + self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet, is_terminal: Soquet ) -> Dict[str, 'SoquetT']: if is_symbolic(self.bitsize): raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}') @@ -89,7 +96,8 @@ def build_composite_bloq( f = ctrls[-1] v = bb.join(v_arr) m, f = bb.add(CNOT(), ctrl=m, target=f) - return {'v': v, 'm': m, 'f': f} + m, is_terminal = bb.add(CNOT(), ctrl=m, target=is_terminal) + return {'v': v, 'm': m, 'f': f, 'is_terminal': is_terminal} def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': if is_symbolic(self.bitsize): @@ -408,16 +416,27 @@ def signature(self) -> 'Signature': Register('s', QMontgomeryUInt(self.bitsize)), Register('m', QBit()), Register('f', QBit()), + Register('is_terminal', QBit()), ] ) def build_composite_bloq( - self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet + self, + bb: 'BloqBuilder', + u: Soquet, + v: Soquet, + r: Soquet, + s: Soquet, + m: Soquet, + f: Soquet, + is_terminal: Soquet, ) -> Dict[str, 'SoquetT']: a = bb.allocate(1) b = bb.allocate(1) - v, m, f = bb.add(_KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f) + v, m, f, is_terminal = bb.add( + _KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f, is_terminal=is_terminal + ) u, v, b, a, m, f = bb.add( _KaliskiIterationStep2(self.bitsize), u=u, v=v, b=b, a=a, m=m, f=f ) @@ -434,7 +453,7 @@ def build_composite_bloq( bb.free(a) bb.free(b) - return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal} def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { @@ -447,7 +466,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': } def on_classical_vals( - self, u: int, v: int, r: int, s: int, m: int, f: int + self, u: int, v: int, r: int, s: int, m: int, f: int, is_terminal: int ) -> Dict[str, 'ClassicalValT']: """This is the Kaliski algorithm as described in Fig7 of https://arxiv.org/pdf/2001.09580. @@ -456,6 +475,7 @@ def on_classical_vals( of `f` and `m`. """ assert m == 0 + is_terminal = f == 1 and v == 0 if f == 0: # When `f = 0` this means that the algorithm is nearly over and that we just need to # double the value of `r`. @@ -484,7 +504,7 @@ def on_classical_vals( if swap: u, v = v, u r, s = s, r - return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f, 'is_terminal': is_terminal} @frozen @@ -504,6 +524,7 @@ def signature(self) -> 'Signature': Register('s', QMontgomeryUInt(self.bitsize)), Register('m', QAny(2 * self.bitsize)), Register('f', QBit()), + Register('terminal_condition', QAny(2 * self.bitsize)), ] ) @@ -512,17 +533,33 @@ def _kaliski_iteration(self): return _KaliskiIteration(self.bitsize, self.mod) def build_composite_bloq( - self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet + self, + bb: 'BloqBuilder', + u: Soquet, + v: Soquet, + r: Soquet, + s: Soquet, + m: Soquet, + f: Soquet, + terminal_condition: Soquet, ) -> Dict[str, 'SoquetT']: f = bb.add(XGate(), q=f) u = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=u) s = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=s) m_arr = bb.split(m) + terminal_condition_arr = bb.split(terminal_condition) for i in range(2 * self.bitsize): - u, v, r, s, m_arr[i], f = bb.add( - self._kaliski_iteration, u=u, v=v, r=r, s=s, m=m_arr[i], f=f + u, v, r, s, m_arr[i], f, terminal_condition_arr[i] = bb.add( + self._kaliski_iteration, + u=u, + v=v, + r=r, + s=s, + m=m_arr[i], + f=f, + is_terminal=terminal_condition_arr[i], ) r = bb.add(BitwiseNot(QMontgomeryUInt(self.bitsize)), x=r) @@ -531,8 +568,43 @@ def build_composite_bloq( u = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=u) s = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=s) + # This is an extra step not present in the original Kaliski algorithm in order to + # handle the case of x=0. The invariant of the Kaliski algorithm is that that end of the + # algorithm u=1, s=0, r=mod inverse. This happens for all cases where the modular inverse + # exists (i.e. gcd(x, mod) = 1). + # The case where the input is zero is important. Although mathematically the inverse + # doesn't exist. For the bloq to be unitary it needs to map zero to itself. + # When the input is zero, the terminal values of the registers are r=mod, u=v=mod^1=mod-1 + # (assuming odd modulus). + # So we clean those registers conditioned on the first terminal qubit which is set + # if and only if the input is zero. + terminal_condition_arr[0], r = bb.add( + XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(), + ctrl=terminal_condition_arr[0], + x=r, + ) + terminal_condition_arr[0], u = bb.add( + XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(), + ctrl=terminal_condition_arr[0], + x=u, + ) + terminal_condition_arr[0], s = bb.add( + XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(), + ctrl=terminal_condition_arr[0], + x=s, + ) + m = bb.join(m_arr) - return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + terminal_condition = bb.join(terminal_condition_arr) + return { + 'u': u, + 'v': v, + 'r': r, + 's': s, + 'm': m, + 'f': f, + 'terminal_condition': terminal_condition, + } def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { @@ -542,6 +614,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': XGate(): 1, XorK(QMontgomeryUInt(self.bitsize), self.mod): 2, XorK(QMontgomeryUInt(self.bitsize), 1): 2, + XorK(QMontgomeryUInt(self.bitsize), self.mod).controlled(): 1, + XorK(QMontgomeryUInt(self.bitsize), self.mod - 1).controlled(): 2, } @@ -575,7 +649,7 @@ class KaliskiModInverse(Bloq): """ bitsize: 'SymbolicInt' - mod: 'SymbolicInt' + mod: 'SymbolicInt' = field(validator=lambda _, __, v: is_symbolic(v) or v % 2 == 1) uncompute: bool = False @cached_property @@ -584,12 +658,12 @@ def signature(self) -> 'Signature': return Signature( [ Register('x', QMontgomeryUInt(self.bitsize)), - Register('m', QAny(2 * self.bitsize), side=side), + Register('junk', QAny(4 * self.bitsize), side=side), ] ) def build_composite_bloq( - self, bb: 'BloqBuilder', x: Soquet, m: Optional[Soquet] = None, f: Optional[Soquet] = None + self, bb: 'BloqBuilder', x: Soquet, junk: Optional[Soquet] = None ) -> Dict[str, 'SoquetT']: u = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) r = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) @@ -597,9 +671,12 @@ def build_composite_bloq( f = bb.allocate(1) if self.uncompute: - assert m is not None - u, x, r, s, m, f = cast( - Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet], + assert junk is not None + junk_arr = bb.split(junk) + m = bb.join(junk_arr[: 2 * self.bitsize]) + terminal_condition = bb.join(junk_arr[2 * self.bitsize :]) + u, x, r, s, m, f, terminal_condition = cast( + Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet], bb.add_from( _KaliskiModInverseImpl(self.bitsize, self.mod).adjoint(), u=u, @@ -608,6 +685,7 @@ def build_composite_bloq( s=s, m=m, f=f, + terminal_condition=terminal_condition, ), ) bb.free(u) @@ -615,22 +693,31 @@ def build_composite_bloq( bb.free(s) bb.free(m) bb.free(f) + bb.free(terminal_condition) return {'x': x} m = bb.allocate(2 * self.bitsize) - u, v, x, s, m, f = bb.add_from( - _KaliskiModInverseImpl(self.bitsize, self.mod), u=u, v=x, r=r, s=s, m=m, f=f + terminal_condition = bb.allocate(2 * self.bitsize) + u, v, x, s, m, f, terminal_condition = cast( + Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet, Soquet], + bb.add_from( + _KaliskiModInverseImpl(self.bitsize, self.mod), + u=u, + v=x, + r=r, + s=s, + m=m, + f=f, + terminal_condition=terminal_condition, + ), ) - assert isinstance(u, Soquet) - assert isinstance(v, Soquet) - assert isinstance(s, Soquet) - assert isinstance(f, Soquet) bb.free(u) bb.free(v) bb.free(s) bb.free(f) - return {'x': x, 'm': m} + junk = bb.join(np.concatenate([bb.split(m), bb.split(terminal_condition)])) + return {'x': x, 'junk': junk} def adjoint(self) -> 'KaliskiModInverse': return evolve(self, uncompute=not self.uncompute) @@ -638,17 +725,25 @@ def adjoint(self) -> 'KaliskiModInverse': def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return _KaliskiModInverseImpl(self.bitsize, self.mod).build_call_graph(ssa) - def on_classical_vals(self, x: int, m: int = 0) -> Dict[str, 'ClassicalValT']: - u, v, r, s, f = int(self.mod), x, 0, 1, 1 + def on_classical_vals(self, x: int, junk: int = 0) -> Dict[str, 'ClassicalValT']: + mod = int(self.mod) + u, v, r, s, f = mod, x, 0, 1, 1 + terminal_condition = m = 0 iteration = _KaliskiModInverseImpl(self.bitsize, self.mod)._kaliski_iteration for _ in range(2 * int(self.bitsize)): - u, v, r, s, m_i, f = iteration.call_classically(u=u, v=v, r=r, s=s, m=0, f=f) + u, v, r, s, m_i, f, is_terminal = iteration.call_classically( + u=u, v=v, r=r, s=s, m=0, f=f, is_terminal=0 + ) m = (m << 1) | m_i - assert u == 1 - assert s == self.mod + terminal_condition = (terminal_condition << 1) | is_terminal + assert u == 1 or (x == 0 and u == mod) + assert s == self.mod or (x == 0 and s == 1) assert f == 0 assert v == 0 - return {'x': self.mod - r, 'm': m} + return { + 'x': (self.mod - r) if r else 0, + 'junk': m * 2 ** (2 * self.bitsize) + terminal_condition, + } @bloq_example diff --git a/qualtran/bloqs/mod_arithmetic/mod_division_test.py b/qualtran/bloqs/mod_arithmetic/mod_division_test.py index 02646ef82..093f0908f 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_division_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_division_test.py @@ -36,11 +36,24 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod): continue x_montgomery = dtype.uint_to_montgomery(x, mod) res = blq.call_classically(x=x_montgomery) + print(x, x_montgomery) assert res == cblq.call_classically(x=x_montgomery) assert len(res) == 2 assert res[0] == dtype.montgomery_inverse(x_montgomery, mod) assert dtype.montgomery_product(int(res[0]), x_montgomery, mod) == R - assert blq.adjoint().call_classically(x=res[0], m=res[1]) == (x_montgomery,) + assert blq.adjoint().call_classically(x=res[0], junk=res[1]) == (x_montgomery,) + + +@pytest.mark.parametrize('bitsize', [5, 6]) +@pytest.mark.parametrize('mod', [3, 5, 7, 11, 13, 15]) +def test_kaliski_mod_inverse_classical_action_zero(bitsize, mod): + blq = KaliskiModInverse(bitsize, mod) + cblq = blq.decompose_bloq() + # When x = 0 the terminal condition is achieved at the first iteration, this corresponds to + # m_0 = is_terminal_0 = 1 and all other bits = 0. + junk = 2 ** (4 * bitsize - 1) + 2 ** (2 * bitsize - 1) + assert blq.call_classically(x=0) == cblq.call_classically(x=0) == (0, junk) + assert blq.adjoint().call_classically(x=0, junk=junk) == (0,) @pytest.mark.parametrize('bitsize', [5, 6])