diff --git a/dev_tools/autogenerate-bloqs-notebooks-v2.py b/dev_tools/autogenerate-bloqs-notebooks-v2.py index c98c6cd3e..64eed620c 100644 --- a/dev_tools/autogenerate-bloqs-notebooks-v2.py +++ b/dev_tools/autogenerate-bloqs-notebooks-v2.py @@ -524,6 +524,7 @@ bloq_specs=[ qualtran.bloqs.mod_arithmetic.mod_multiplication._MOD_DBL_DOC, qualtran.bloqs.mod_arithmetic.mod_multiplication._C_MOD_MUL_K_DOC, + qualtran.bloqs.mod_arithmetic.mod_multiplication._DIRTY_OUT_OF_PLACE_MONTGOMERY_MOD_MUL_DOC, ], ), NotebookSpecV2( diff --git a/qualtran/bloqs/factoring/ecc/ec_add.py b/qualtran/bloqs/factoring/ecc/ec_add.py index e0ff64e6a..74a57706a 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.py +++ b/qualtran/bloqs/factoring/ecc/ec_add.py @@ -18,8 +18,17 @@ from qualtran import Bloq, bloq_example, BloqDocSpec, QUInt, Register, Signature from qualtran.bloqs.arithmetic._shims import MultiCToffoli -from qualtran.bloqs.mod_arithmetic import CModAdd, CModNeg, CModSub, ModAdd, ModNeg, ModSub -from qualtran.bloqs.mod_arithmetic._shims import ModDbl, ModInv, ModMul +from qualtran.bloqs.mod_arithmetic import ( + CModAdd, + CModNeg, + CModSub, + DirtyOutOfPlaceMontgomeryModMul, + ModAdd, + ModDbl, + ModNeg, + ModSub, +) +from qualtran.bloqs.mod_arithmetic._shims import ModInv from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -63,7 +72,7 @@ def signature(self) -> 'Signature': ] ) - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: # litinksi return { MultiCToffoli(n=self.n): 18, @@ -74,7 +83,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': ModNeg(QUInt(self.n), mod=self.mod): 2, CModNeg(QUInt(self.n), mod=self.mod): 1, ModDbl(QUInt(self.n), mod=self.mod): 2, - ModMul(n=self.n, mod=self.mod): 10, + DirtyOutOfPlaceMontgomeryModMul(bitsize=self.n, window_size=4, mod=self.mod): 10, ModInv(n=self.n, mod=self.mod): 4, } diff 
--git a/qualtran/bloqs/mod_arithmetic/__init__.py b/qualtran/bloqs/mod_arithmetic/__init__.py index deba42bb0..ff0d3a3da 100644 --- a/qualtran/bloqs/mod_arithmetic/__init__.py +++ b/qualtran/bloqs/mod_arithmetic/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._shims import ModInv, ModMul +from ._shims import ModInv from .mod_addition import CModAdd, CModAddK, CtrlScaleModAdd, ModAdd, ModAddK -from .mod_multiplication import CModMulK, ModDbl +from .mod_multiplication import CModMulK, DirtyOutOfPlaceMontgomeryModMul, ModDbl from .mod_subtraction import CModNeg, CModSub, ModNeg, ModSub diff --git a/qualtran/bloqs/mod_arithmetic/_shims.py b/qualtran/bloqs/mod_arithmetic/_shims.py index c6e630235..65b61b4ca 100644 --- a/qualtran/bloqs/mod_arithmetic/_shims.py +++ b/qualtran/bloqs/mod_arithmetic/_shims.py @@ -32,7 +32,6 @@ from qualtran.bloqs.basic_gates import CNOT, CSwap, Swap, Toffoli from qualtran.bloqs.mod_arithmetic.mod_multiplication import ModDbl from qualtran.drawing import Text, TextBox, WireSymbol -from qualtran.symbolics import ceil, log2 if TYPE_CHECKING: from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -114,37 +113,3 @@ def wire_symbol( elif reg.name == 'out': return TextBox('$x^{-1}$') raise ValueError(f'Unrecognized register name {reg.name}') - - -@frozen -class ModMul(Bloq): - n: int - mod: int - - @cached_property - def signature(self) -> 'Signature': - return Signature( - [ - Register('x', QUInt(self.n)), - Register('y', QUInt(self.n)), - Register('out', QUInt(self.n)), - ] - ) - - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - # Roetteler montgomery - return {Toffoli(): ceil(16 * self.n**2 * log2(self.n) - 26.3 * self.n**2)} - - def wire_symbol( - self, reg: Optional['Register'], idx: Tuple[int, ...] 
= tuple() - ) -> 'WireSymbol': - if reg is None: - return Text("") - if reg.name in ['x', 'y']: - return TextBox(reg.name) - elif reg.name == 'out': - return TextBox('x*y') - raise ValueError(f'Unrecognized register name {reg.name}') - - def __str__(self): - return self.__class__.__name__ diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.ipynb b/qualtran/bloqs/mod_arithmetic/mod_multiplication.ipynb index dc4c62bdf..a7f5a4aac 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.ipynb +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "6b63cfe5", + "id": "b23ed079", "metadata": { "cq.autogen": "title_cell" }, @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d72f6711", + "id": "5cde4184", "metadata": { "cq.autogen": "top_imports" }, @@ -30,7 +30,7 @@ }, { "cell_type": "markdown", - "id": "d3899162", + "id": "33013eac", "metadata": { "cq.autogen": "ModDbl.bloq_doc.md" }, @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53515719", + "id": "934adfa2", "metadata": { "cq.autogen": "ModDbl.bloq_doc.py" }, @@ -65,7 +65,7 @@ }, { "cell_type": "markdown", - "id": "b5e0c374", + "id": "854a2f34", "metadata": { "cq.autogen": "ModDbl.example_instances.md" }, @@ -76,7 +76,7 @@ { "cell_type": "code", "execution_count": null, - "id": "550f264b", + "id": "a376c520", "metadata": { "cq.autogen": "ModDbl.moddbl_small" }, @@ -88,7 +88,7 @@ { "cell_type": "code", "execution_count": null, - "id": "89df68f0", + "id": "559f3f97", "metadata": { "cq.autogen": "ModDbl.moddbl_large" }, @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "acd85b81", + "id": "e90cf054", "metadata": { "cq.autogen": "ModDbl.graphical_signature.md" }, @@ -111,7 +111,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c588ee92", + "id": "75c4294c", "metadata": { "cq.autogen": "ModDbl.graphical_signature.py" }, @@ -124,7 +124,7 @@ }, { "cell_type": "markdown", 
- "id": "3cfc35a0", + "id": "ef3bfee0", "metadata": { "cq.autogen": "ModDbl.call_graph.md" }, @@ -135,7 +135,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c5211901", + "id": "bd1dfb09", "metadata": { "cq.autogen": "ModDbl.call_graph.py" }, @@ -149,7 +149,7 @@ }, { "cell_type": "markdown", - "id": "e21338a3", + "id": "03dac121", "metadata": { "cq.autogen": "CModMulK.bloq_doc.md" }, @@ -172,7 +172,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08cc01f5", + "id": "b735fef0", "metadata": { "cq.autogen": "CModMulK.bloq_doc.py" }, @@ -183,7 +183,7 @@ }, { "cell_type": "markdown", - "id": "4a8585a7", + "id": "0d8c1a4b", "metadata": { "cq.autogen": "CModMulK.example_instances.md" }, @@ -194,7 +194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7d72393a", + "id": "1986bbf9", "metadata": { "cq.autogen": "CModMulK.modmul_symb" }, @@ -209,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "521a0b51", + "id": "ecdbe3f4", "metadata": { "cq.autogen": "CModMulK.modmul" }, @@ -220,7 +220,7 @@ }, { "cell_type": "markdown", - "id": "b51e0ac8", + "id": "52e944f4", "metadata": { "cq.autogen": "CModMulK.graphical_signature.md" }, @@ -231,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "686db91d", + "id": "1e28aa1f", "metadata": { "cq.autogen": "CModMulK.graphical_signature.py" }, @@ -244,7 +244,7 @@ }, { "cell_type": "markdown", - "id": "0749b88f", + "id": "8e34d67f", "metadata": { "cq.autogen": "CModMulK.call_graph.md" }, @@ -255,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29decc82", + "id": "f8645e0c", "metadata": { "cq.autogen": "CModMulK.call_graph.py" }, @@ -266,6 +266,143 @@ "show_call_graph(modmul_symb_g)\n", "show_counts_sigma(modmul_symb_sigma)" ] + }, + { + "cell_type": "markdown", + "id": "849371cb", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.bloq_doc.md" + }, + "source": [ + "## `DirtyOutOfPlaceMontgomeryModMul`\n", + "Perform windowed montgomery 
modular multiplication.\n", + "\n", + "Applies the transformation\n", + "$$\n", + " \\ket{x}\\ket{y}\\ket{0}\\ket{0}\\ket{0} \\rightarrow \\ket{x}\\ket{y}\\ket{xy2^{-n}}\\ket{h}\\ket{c}\n", + "$$\n", + "\n", + "Where:\n", + "\n", + "- $n$ is the bitsize.\n", + "- $x, y$ are in montgomery form\n", + "- $h$ is an ancilla register that represents intermediate values.\n", + "- $c$ is whether a final modular reduction was applied or not.\n", + "\n", + "#### Parameters\n", + " - `bitsize`: size of the numbers.\n", + " - `window_size`: size of the window.\n", + " - `mod`: The integer modulus.\n", + " - `uncompute`: whether to compute or uncompute. \n", + "\n", + "#### Registers\n", + " - `x`: The first integer\n", + " - `y`: The second integer\n", + " - `target`: product in montgomery form $xy 2^{-n}$\n", + " - `qrom_indices`: concatenation of the indices used to query QROM.\n", + " - `reduced`: whether a final modular reduction was applied. \n", + "\n", + "#### References\n", + " - [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126 133 Cat Qubits](https://arxiv.org/abs/2302.06639). Appendix C4.\n", + " - [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). 
page 8.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d08d7f3a", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.mod_arithmetic import DirtyOutOfPlaceMontgomeryModMul" + ] + }, + { + "cell_type": "markdown", + "id": "56c6466e", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b573e60f", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.dirtyoutofplacemontgomerymodmul_small" + }, + "outputs": [], + "source": [ + "dirtyoutofplacemontgomerymodmul_small = DirtyOutOfPlaceMontgomeryModMul(6, 2, 7)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f339fb21", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.dirtyoutofplacemontgomerymodmul_medium" + }, + "outputs": [], + "source": [ + "dirtyoutofplacemontgomerymodmul_medium = DirtyOutOfPlaceMontgomeryModMul(\n", + " bitsize=16, window_size=4, mod=2**15 - 1\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c832c4a9", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99ee86e1", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([dirtyoutofplacemontgomerymodmul_small, dirtyoutofplacemontgomerymodmul_medium],\n", + " ['`dirtyoutofplacemontgomerymodmul_small`', '`dirtyoutofplacemontgomerymodmul_medium`'])" + ] + }, + { + "cell_type": "markdown", + "id": "1de095e7", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "bde436f4", + "metadata": { + "cq.autogen": "DirtyOutOfPlaceMontgomeryModMul.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "dirtyoutofplacemontgomerymodmul_small_g, dirtyoutofplacemontgomerymodmul_small_sigma = dirtyoutofplacemontgomerymodmul_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(dirtyoutofplacemontgomerymodmul_small_g)\n", + "show_counts_sigma(dirtyoutofplacemontgomerymodmul_small_sigma)" + ] } ], "metadata": { diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index 536cfae6b..c38de6634 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -15,34 +15,42 @@ import math import numbers from functools import cached_property -from typing import cast, Dict, Optional, Tuple, Union +from typing import cast, Dict, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union import attrs import numpy as np import sympy from attrs import frozen +from numpy.typing import NDArray from qualtran import ( Bloq, bloq_example, BloqBuilder, BloqDocSpec, + DecomposeNotImplementedError, QBit, QMontgomeryUInt, QUInt, Register, + Side, Signature, Soquet, SoquetT, ) -from qualtran.bloqs.arithmetic.addition import AddK +from qualtran.bloqs.arithmetic import Add, AddK, CAdd, Xor +from qualtran.bloqs.arithmetic.comparison import LessThanConstant from qualtran.bloqs.basic_gates import CNOT, CSwap, XGate +from qualtran.bloqs.data_loading.qroam_clean import QROAMClean from qualtran.bloqs.mod_arithmetic.mod_addition import CtrlScaleModAdd from qualtran.drawing import Circle, directional_text_box, Text, WireSymbol -from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator from 
qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join from qualtran.simulation.classical_sim import ClassicalValT -from qualtran.symbolics import is_symbolic +from qualtran.symbolics import is_symbolic, Shaped + +if TYPE_CHECKING: + from qualtran.symbolics import SymbolicInt @frozen @@ -265,3 +273,475 @@ def _modmul_symb() -> CModMulK: _C_MOD_MUL_K_DOC = BloqDocSpec(bloq_cls=CModMulK, examples=(_modmul_symb, _modmul)) + + +@frozen +class SingleWindowModMul(Bloq): + r"""Performs modular multiplication on a single window. + + This bloq is used as a subroutine in DirtyOutOfPlaceMontgomeryModMul. + + Applies + $$ + \ket{x}\ket{y}\ket{t}\ket{0} \rightarrow \ket{x}\ket{y}\ket{t+xy \mod p} \ket{xy \mod 2^w} + $$ + + Where: + + - $w$ is the window size. + - $p$ is the modulus. + + Args: + window_size: size of the window (=size of the first register). + bitsize: size of the second register. + mod: The integer modulus. + + Registers: + x: The first integer as an array of bits (`window_size` bits). + y: The second integer (`bitsize` bits) + target: product accumulation array of bits. + qrom_index: contains the value $xy \mod 2^w$ (starts at 0). 
+ """ + + window_size: 'SymbolicInt' + bitsize: 'SymbolicInt' + mod: 'SymbolicInt' + + def __attrs_post_init__(self): + if not is_symbolic(self.bitsize, self.window_size): + assert self.bitsize % self.window_size == 0 + + @property + def signature(self) -> 'Signature': + return Signature( + [ + Register('x', QBit(), shape=(self.window_size,)), + Register('y', QMontgomeryUInt(self.bitsize)), + Register('target', QBit(), shape=(self.window_size + self.bitsize,)), + Register('qrom_index', QMontgomeryUInt(self.window_size)), + ] + ) + + @cached_property + def qrom(self) -> QROAMClean: + if is_symbolic(self.bitsize) or is_symbolic(self.window_size) or is_symbolic(self.mod): + log_block_sizes = None + if is_symbolic(self.bitsize) and not is_symbolic(self.window_size): + # We assume that bitsize is much larger than window_size + log_block_sizes = (0,) + return QROAMClean( + [Shaped((2**self.window_size,))], + selection_bitsizes=(self.window_size,), + target_bitsizes=(self.window_size + self.bitsize,), + log_block_sizes=log_block_sizes, + ) + inv_mod = pow(self.mod, 2 ** (self.window_size - 1) - 1, 2**self.window_size) + N = 2**self.window_size + data = (-np.arange(N) * inv_mod) % N + data *= self.mod + return QROAMClean( + [data], + selection_bitsizes=(self.window_size,), + target_bitsizes=(self.window_size + self.bitsize,), + ) + + def on_classical_vals(self, x: Sequence[int], y: int, target: Sequence[int], qrom_index: int): + if is_symbolic(self.bitsize) or is_symbolic(self.window_size): + raise ValueError(f'classical action is not supported for {self}') + dtype = QMontgomeryUInt(self.window_size + self.bitsize) + target_val = QMontgomeryUInt.from_bits(dtype, target) + for i in range(self.window_size): + if x[i]: + target_val += y << i + qrom_index = target_val & (2**self.window_size - 1) + Tm = self.qrom.data[0][qrom_index] + target_val = (target_val + Tm) >> self.window_size + target = QMontgomeryUInt.to_bits(dtype, target_val) + return {'target': target, 
'qrom_index': qrom_index, 'x': x, 'y': y} + + def build_composite_bloq( + self, bb: 'BloqBuilder', x: NDArray[Soquet], y: Soquet, target: NDArray[Soquet], qrom_index: Soquet # type: ignore[type-var] + ): + if is_symbolic(self.window_size): + raise DecomposeNotImplementedError(f'symbolic decomposition not supported for {self}') + for i in range(self.window_size): + z = bb.join(target[-self.bitsize - 1 - i : len(target) - i]) + x[i], y, z = bb.add( + CAdd(QMontgomeryUInt(self.bitsize), QMontgomeryUInt(self.bitsize + 1)), + ctrl=x[i], + a=y, + b=z, + ) + z_arr = bb.split(z) + target[-self.bitsize - 1 - i : len(target) - i] = z_arr + + m = bb.join(target[-self.window_size :], QMontgomeryUInt(self.window_size)) + m, qrom_index = bb.add(Xor(QMontgomeryUInt(self.window_size)), x=m, y=qrom_index) + target[-self.window_size :] = bb.split(m) + + qrom_index, qrom_target, *junk = bb.add(self.qrom, selection=qrom_index) + z = bb.join(target) + qrom_target, z = bb.add( + Add(QMontgomeryUInt(self.bitsize + self.window_size)), a=qrom_target, b=z + ) + if junk: + assert len(junk) == 1 + qrom_index = bb.add( + self.qrom.adjoint(), + selection=qrom_index, + target0_=qrom_target, + junk_target0_=junk[0], + ) + else: + qrom_index = bb.add(self.qrom.adjoint(), selection=qrom_index, target0_=qrom_target) + target_arr = bb.split(z) + target_arr = np.roll(target_arr, self.window_size) + + return {'x': x, 'y': y, 'target': target_arr, 'qrom_index': qrom_index} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: + return { + CAdd( + QMontgomeryUInt(self.bitsize), QMontgomeryUInt(self.bitsize + 1) + ): self.window_size, + Xor(QMontgomeryUInt(self.window_size)): 1, + Add(QMontgomeryUInt(self.bitsize + self.window_size)): 1, + self.qrom: 1, + self.qrom.adjoint(): 1, + } + + +@frozen +class _DirtyOutOfPlaceMontgomeryModMulImpl(Bloq): + r"""Perform windowed montgomery modular multiplication. 
+ + Applies the transformation + $$ + \ket{x}\ket{y}\ket{0}\ket{0}\ket{0} \rightarrow \ket{x}\ket{y}\ket{xy2^{-n}}\ket{h}\ket{c} + $$ + + Where: + + - $n$ is the bitsize. + - $x, y$ are in montgomery form + - $h$ is an ancilla register that represents intermediate values. + - $c$ is whether a final modular reduction was applied or not. + + Note: this is an internal implementation class that assumes the target registers (see above) are clean. + + Args: + bitsize: size of the numbers. + window_size: size of the window. + mod: The integer modulus. + uncompute: whether to compute or uncompute. + + Registers: + x: The first integer + y: The second integer + target: product in montgomery form $xy 2^{-n}$ + qrom_indices: concatenation of the indices used to query QROM. + reduced: whether a final modular reduction was applied. + + References: + [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126 133 Cat Qubits](https://arxiv.org/abs/2302.06639) + Appendix C4. + + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + page 8. + """ + bitsize: 'SymbolicInt' + window_size: 'SymbolicInt' + mod: 'SymbolicInt' + + def __attrs_post_init__(self): + if isinstance(self.mod, int): + assert self.mod > 1 and self.mod % 2 == 1 # Must be an odd integer greater than 1. 
+ + if isinstance(self.mod, int) and isinstance(self.bitsize, int): + assert 2 * self.mod - 1 < 2**self.bitsize, f'bitsize={self.bitsize} is too small' + + if isinstance(self.window_size, int) and isinstance(self.bitsize, int): + assert self.window_size <= self.bitsize + + @cached_property + def signature(self) -> 'Signature': + num_windows = ( + self.bitsize + self.window_size - 1 + ) // self.window_size # = ceil(self.bitsize/self.window_size) + return Signature( + [ + Register('x', QMontgomeryUInt(self.bitsize)), + Register('y', QMontgomeryUInt(self.bitsize)), + Register('target', QMontgomeryUInt(self.bitsize)), + Register('qrom_indices', QMontgomeryUInt(num_windows * self.window_size)), + Register('reduced', QBit()), + ] + ) + + @cached_property + def _window(self): + return SingleWindowModMul(self.window_size, self.bitsize, self.mod) + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + x: Soquet, + y: Soquet, + target: Soquet, + qrom_indices: Soquet, + reduced: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.window_size) or is_symbolic(self.bitsize) or is_symbolic(self.mod): + raise DecomposeNotImplementedError(f'symbolic decomposition not supported for {self}') + x_arr = bb.split(x) + x_arr = np.flip(x_arr) + + target_arr = np.concatenate([bb.split(bb.allocate(self.window_size)), bb.split(target)]) + qrom_indices_arr = bb.split(qrom_indices) + + for i in range(0, self.bitsize, self.window_size): + (x_arr[i : i + self.window_size], y, target_arr, qrom_index) = bb.add( + self._window, + x=x_arr[i : i + self.window_size], + y=y, + target=target_arr, + qrom_index=bb.join(qrom_indices_arr[i : i + self.window_size]), + ) + qrom_indices_arr[i : i + self.window_size] = bb.split(qrom_index) + + # Free ancillas and join + bb.free(bb.join(target_arr[: -self.bitsize])) + x_arr = np.flip(x_arr[: self.bitsize]) + x = bb.join(x_arr) + qrom_indices = bb.join(qrom_indices_arr) + + # Modular reduction + target = bb.join(target_arr[-self.bitsize :]) + reduced 
= bb.add(XGate(), q=reduced) + target, reduced = bb.add(LessThanConstant(self.bitsize, self.mod), x=target, target=reduced) + (reduced,), target = bb.add( + AddK(self.bitsize, self.mod, cvs=(1,), signed=False), ctrls=(reduced,), x=target + ) + + return {'x': x, 'y': y, 'target': target, 'qrom_indices': qrom_indices, 'reduced': reduced} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: + num_windows = (self.bitsize + self.window_size - 1) // self.window_size + return { + AddK(self.bitsize, self.mod, cvs=(1,), signed=False): 1, + LessThanConstant(bitsize=self.bitsize, less_than_val=self.mod): 1, + XGate(): 1, + self._window: num_windows, + } + + +@frozen +class DirtyOutOfPlaceMontgomeryModMul(Bloq): + r"""Perform windowed montgomery modular multiplication. + + Applies the transformation + $$ + \ket{x}\ket{y}\ket{0}\ket{0}\ket{0} \rightarrow \ket{x}\ket{y}\ket{xy2^{-n}}\ket{h}\ket{c} + $$ + + Where: + + - $n$ is the bitsize. + - $x, y$ are in montgomery form + - $h$ is an ancilla register that represents intermediate values. + - $c$ is whether a final modular reduction was applied or not. + + Args: + bitsize: size of the numbers. + window_size: size of the window. + mod: The integer modulus. + uncompute: whether to compute or uncompute. + + Registers: + x: The first integer + y: The second integer + target: product in montgomery form $xy 2^{-n}$ + qrom_indices: concatenation of the indices used to query QROM. + reduced: whether a final modular reduction was applied. + + References: + [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126 133 Cat Qubits](https://arxiv.org/abs/2302.06639) + Appendix C4. + + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + page 8. 
+ """ + bitsize: 'SymbolicInt' + window_size: 'SymbolicInt' + mod: 'SymbolicInt' + uncompute: bool = False + + def __attrs_post_init__(self): + if isinstance(self.mod, int): + assert self.mod > 1 and self.mod % 2 == 1 # Must be an odd integer greater than 1. + + if isinstance(self.mod, int) and isinstance(self.bitsize, int): + assert 2 * self.mod - 1 < 2**self.bitsize, f'bitsize={self.bitsize} is too small' + + if isinstance(self.window_size, int) and isinstance(self.bitsize, int): + assert self.bitsize % self.window_size == 0 + + @cached_property + def signature(self) -> 'Signature': + num_windows = ( + self.bitsize + self.window_size - 1 + ) // self.window_size # = ceil(self.bitsize/self.window_size) + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('x', QMontgomeryUInt(self.bitsize)), + Register('y', QMontgomeryUInt(self.bitsize)), + Register('target', QMontgomeryUInt(self.bitsize), side=side), + Register( + 'qrom_indices', QMontgomeryUInt(num_windows * self.window_size), side=side + ), + Register('reduced', QBit(), side=side), + ] + ) + + def adjoint(self) -> 'DirtyOutOfPlaceMontgomeryModMul': + return attrs.evolve(self, uncompute=self.uncompute ^ True) + + @cached_property + def _inversion_data(self) -> np.typing.NDArray: + inv_mod = pow(self.mod, 2 ** (self.window_size - 1) - 1, 2**self.window_size) + N = 2**self.window_size + data = (-np.arange(N) * inv_mod) % N + data *= self.mod + return data + + def _classical_action_window( + self, + x: 'ClassicalValT', + y: 'ClassicalValT', + target: 'ClassicalValT', + qrom_indices: 'ClassicalValT', + ): + # This method implements same logic as SingleWindowModMul.on_classical_vals except that it works on integers rather than bit arrays. + # Calls to this function are equivalent to calls to self._window.call_classically given the appropiate conversion int <-> bitarray. 
+ if is_symbolic(self.bitsize) or is_symbolic(self.window_size) or is_symbolic(self.mod): + raise ValueError(f'classical action is not supported for {self}') + for i in range(self.window_size): + if (x >> i) & 1: + target += y << i + m = target & (2**self.window_size - 1) + Tm = self._inversion_data[m] + target += Tm + target >>= self.window_size + qrom_indices = (qrom_indices << self.window_size) | m + return target, qrom_indices + + def on_classical_vals( + self, + x: 'ClassicalValT', + y: 'ClassicalValT', + target: Optional['ClassicalValT'] = None, + qrom_indices: Optional['ClassicalValT'] = None, + reduced: Optional['ClassicalValT'] = None, + ) -> Dict[str, ClassicalValT]: + if is_symbolic(self.bitsize) or is_symbolic(self.window_size) or is_symbolic(self.mod): + raise ValueError(f'classical action is not supported for {self}') + if self.uncompute: + assert ( + target is not None and target == (x * y * pow(2, self.bitsize, self.mod)) % self.mod + ) + assert qrom_indices is not None + assert reduced is not None + return {'x': x, 'y': y} + assert target is None + assert qrom_indices is None + assert reduced is None + + if not (0 < x < self.mod and 0 < y < self.mod): + return {'x': x, 'y': y, 'target': 0, 'qrom_indices': 0, 'reduced': 0} + + target = 0 + qrom_indices = 0 + reduced = 0 + for i in range(0, self.bitsize, self.window_size): + target, qrom_indices = self._classical_action_window(x >> i, y, target, qrom_indices) + + if target >= self.mod: + target -= self.mod + reduced = 1 + + montgomery_prod = (x * y * pow(2, self.bitsize * (self.mod - 2), self.mod)) % self.mod + assert target == montgomery_prod + return {'x': x, 'y': y, 'target': target, 'qrom_indices': qrom_indices, 'reduced': reduced} + + @cached_property + def _mod_mul_impl(self) -> Bloq: + b: Bloq = _DirtyOutOfPlaceMontgomeryModMulImpl(self.bitsize, self.window_size, self.mod) + if self.uncompute: + b = b.adjoint() + return b + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + x: Soquet, 
+ y: Soquet, + target: Optional[Soquet] = None, + qrom_indices: Optional[Soquet] = None, + reduced: Optional[Soquet] = None, + ) -> Dict[str, 'SoquetT']: + if self.uncompute: + assert target is not None + assert qrom_indices is not None + assert reduced is not None + + x, y, target, qrom_indices, reduced = bb.add_from( # type: ignore + self._mod_mul_impl, + x=x, + y=y, + target=target, + qrom_indices=qrom_indices, + reduced=reduced, + ) + + bb.free(reduced) + bb.free(qrom_indices) + bb.free(target) + return {'x': x, 'y': y} + + target = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) + num_windows = (self.bitsize + self.window_size - 1) // self.window_size + qrom_indices = bb.allocate( + num_windows * self.window_size, QMontgomeryUInt(num_windows * self.window_size) + ) + reduced = bb.allocate(1) + + x, y, target, qrom_indices, reduced = bb.add_from( + self._mod_mul_impl, x=x, y=y, target=target, qrom_indices=qrom_indices, reduced=reduced + ) + return {'x': x, 'y': y, 'target': target, 'qrom_indices': qrom_indices, 'reduced': reduced} + + def build_call_graph( + self, ssa: 'SympySymbolAllocator' + ) -> Union[Set['BloqCountT'], BloqCountDictT]: + return self._mod_mul_impl.build_call_graph(ssa) + + +@bloq_example(generalizer=[ignore_alloc_free, ignore_split_join]) +def _dirtyoutofplacemontgomerymodmul_small() -> DirtyOutOfPlaceMontgomeryModMul: + dirtyoutofplacemontgomerymodmul_small = DirtyOutOfPlaceMontgomeryModMul(6, 2, 7) + return dirtyoutofplacemontgomerymodmul_small + + +@bloq_example(generalizer=[ignore_alloc_free, ignore_split_join]) +def _dirtyoutofplacemontgomerymodmul_medium() -> DirtyOutOfPlaceMontgomeryModMul: + dirtyoutofplacemontgomerymodmul_medium = DirtyOutOfPlaceMontgomeryModMul( + bitsize=16, window_size=4, mod=2**15 - 1 + ) + return dirtyoutofplacemontgomerymodmul_medium + + +_DIRTY_OUT_OF_PLACE_MONTGOMERY_MOD_MUL_DOC = BloqDocSpec( + bloq_cls=DirtyOutOfPlaceMontgomeryModMul, + examples=(_dirtyoutofplacemontgomerymodmul_small, 
_dirtyoutofplacemontgomerymodmul_medium), +) diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication_test.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication_test.py index c5e16682b..3be8b3793 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication_test.py @@ -13,6 +13,7 @@ # limitations under the License. import attrs +import numpy as np import pytest import sympy @@ -20,12 +21,16 @@ from qualtran import QMontgomeryUInt, QUInt from qualtran.bloqs.mod_arithmetic.mod_addition import CtrlScaleModAdd from qualtran.bloqs.mod_arithmetic.mod_multiplication import ( + _dirtyoutofplacemontgomerymodmul_medium, + _dirtyoutofplacemontgomerymodmul_small, _moddbl_large, _moddbl_small, _modmul, _modmul_symb, CModMulK, + DirtyOutOfPlaceMontgomeryModMul, ModDbl, + SingleWindowModMul, ) from qualtran.resource_counting import get_cost_value, QECGatesCost, SympySymbolAllocator from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join @@ -132,6 +137,138 @@ def test_examples_modmul(bloq_autotester): bloq_autotester(_modmul) +def test_examples_dirtyoutofplacemontgomerymodmul_small(bloq_autotester): + bloq_autotester(_dirtyoutofplacemontgomerymodmul_small) + + +def test_examples_dirtyoutofplacemontgomerymodmul_medium(bloq_autotester): + bloq_autotester(_dirtyoutofplacemontgomerymodmul_medium) + + @pytest.mark.notebook def test_notebook(): qlt_testing.execute_notebook('mod_multiplication') + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize('uncompute', [True, False]) +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(6, 10) for m in range(1, n + 1) if n % m == 0] +) +def test_dirtyoutofplacemontgomerymodmul_decomposition(n, m, p, uncompute): + b = DirtyOutOfPlaceMontgomeryModMul(n, m, p, uncompute) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize('uncompute', [True, False]) 
+@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(6, 10) for m in range(1, n + 1) if n % m == 0] +) +def test_dirtyoutofplacemontgomerymodmul_bloq_counts(n, m, p, uncompute): + b = DirtyOutOfPlaceMontgomeryModMul(n, m, p, uncompute) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +@pytest.mark.parametrize('uncompute', [True, False]) +def test_dirtyoutofplacemontgomerymodmul_symbolic_cost(uncompute): + n, m, p = sympy.symbols('n m p', integer=True) + + # In Litinski 2023 https://arxiv.org/abs/2306.08585 a window size of 4 is used. + # The cost function generally has floor/ceil division that disappear for bitsize=0 mod 4. + # This is why instead of using bitsize=n directly, we use bitsize=4*m=n. + b = DirtyOutOfPlaceMontgomeryModMul(4 * m, 4, p, uncompute) + cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() + assert cost['n_t'] == 0 + + # Litinski 2023 https://arxiv.org/abs/2306.08585 + # Figure/Table 8. Lists modular multiplication as 2.25n^2+9n toffoli. + # The following formula is 2.25n^2+8.25n-1 written with rationals because sympy comparison fails with floats. 
+ assert isinstance(cost['n_ccz'], sympy.Expr) + assert ( + cost['n_ccz'].subs(m, n / 4).expand() + == sympy.Rational(9, 4) * n**2 + sympy.Rational(33, 4) * n - 1 + ) + + +@pytest.mark.parametrize('p', (3, 5, 7)) +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(5, 8) for m in range(1, n + 1) if n % m == 0] +) +def test_dirtyoutofplacemontgomerymodmul_classical_action(n, m, p): + b = DirtyOutOfPlaceMontgomeryModMul(n, m, p, False) + qlt_testing.assert_consistent_classical_action(b, x=range(1, p), y=range(1, p)) + + +@pytest.mark.parametrize('p', (3, 5, 7)) +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(5, 8) for m in range(1, n + 1) if n % m == 0] +) +def test_singlewindowmodmul_decomposition(n, m, p): + b = SingleWindowModMul(window_size=m, bitsize=n, mod=p) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize('p', (3, 5, 7)) +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(5, 8) for m in range(1, n + 1) if n % m == 0] +) +def test_singlewindowmodmul_bloq_counts(n, m, p): + b = SingleWindowModMul(window_size=m, bitsize=n, mod=p) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +@pytest.mark.slow +@pytest.mark.parametrize('p', (3, 5, 7)) +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(5, 8) for m in range(1, n + 1) if n % m == 0] +) +def test_singlewindowmodmul_classical_action(n, m, p): + b = SingleWindowModMul(window_size=m, bitsize=n, mod=p) + cb = b.decompose_bloq() + for x in range(1, min(p, 2**m)): + for y in range(1, p): + for target in range(2**n): + bloq_res = b.call_classically( + x=np.array(QUInt(m).to_bits(x)), + y=y, + target=np.array(QUInt(n + m).to_bits(target)), + qrom_index=0, + ) + decomposed_res = cb.call_classically( + x=np.array(QUInt(m).to_bits(x)), + y=y, + target=np.array(QUInt(n + m).to_bits(target)), + qrom_index=0, + ) + np.testing.assert_equal(bloq_res[0], decomposed_res[0]) # x + assert bloq_res[1] == 
decomposed_res[1] # y + np.testing.assert_equal(bloq_res[2], decomposed_res[2]) # target + assert bloq_res[3] == decomposed_res[3] # qrom_index + + +def test_singlewindowmodmul_classical_action_fast(): + n = 4 + m = 2 + p = 5 + b = SingleWindowModMul(window_size=m, bitsize=n, mod=p) + cb = b.decompose_bloq() + for x in range(1, min(p, 2**m)): + for y in range(1, p): + for target in range(2**n): + bloq_res = b.call_classically( + x=np.array(QUInt(m).to_bits(x)), + y=y, + target=np.array(QUInt(n + m).to_bits(target)), + qrom_index=0, + ) + decomposed_res = cb.call_classically( + x=np.array(QUInt(m).to_bits(x)), + y=y, + target=np.array(QUInt(n + m).to_bits(target)), + qrom_index=0, + ) + np.testing.assert_equal(bloq_res[0], decomposed_res[0]) # x + assert bloq_res[1] == decomposed_res[1] # y + np.testing.assert_equal(bloq_res[2], decomposed_res[2]) # target + assert bloq_res[3] == decomposed_res[3] # qrom_index diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index ee8025a4e..45518ad08 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -337,6 +337,8 @@ "qualtran.bloqs.mod_arithmetic.mod_subtraction.CModNeg": qualtran.bloqs.mod_arithmetic.mod_subtraction.CModNeg, "qualtran.bloqs.mod_arithmetic.mod_multiplication.ModDbl": qualtran.bloqs.mod_arithmetic.mod_multiplication.ModDbl, "qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK": qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK, + "qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul, + "qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul, "qualtran.bloqs.factoring.mod_exp.ModExp": qualtran.bloqs.factoring.mod_exp.ModExp, "qualtran.bloqs.for_testing.atom.TestAtom": qualtran.bloqs.for_testing.atom.TestAtom, 
"qualtran.bloqs.for_testing.atom.TestGWRAtom": qualtran.bloqs.for_testing.atom.TestGWRAtom,