From 6e4b7ad72d624e6917600c55d2442f0ca0ad8365 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 16 Oct 2024 22:32:43 +0200 Subject: [PATCH 1/8] `PlusEqualProduct` version of `GFMultiplication` for GF($2^m$) (#1457) * PlusEqualProduct version of multiplication for GF(2^m) * Regenerate notebooks and update docstring * Fix mypy --- .../gf_arithmetic/gf2_multiplication.ipynb | 9 +-- .../bloqs/gf_arithmetic/gf2_multiplication.py | 57 +++++++++++++------ .../gf_arithmetic/gf2_multiplication_test.py | 49 +++++++++++++++- 3 files changed, 92 insertions(+), 23 deletions(-) diff --git a/qualtran/bloqs/gf_arithmetic/gf2_multiplication.ipynb b/qualtran/bloqs/gf_arithmetic/gf2_multiplication.ipynb index 03bb4ddca..b001a27dd 100644 --- a/qualtran/bloqs/gf_arithmetic/gf2_multiplication.ipynb +++ b/qualtran/bloqs/gf_arithmetic/gf2_multiplication.ipynb @@ -56,12 +56,13 @@ "gates.\n", "\n", "#### Parameters\n", - " - `bitsize`: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of qubits in each of the two input registers $a$ and $b$ that should be multiplied. \n", + " - `bitsize`: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of qubits in each of the two input registers $a$ and $b$ that should be multiplied.\n", + " - `plus_equal_prod`: If True, implements the `PlusEqualProduct` version that applies the map $|x\\rangle |y\\rangle |z\\rangle \\rightarrow |x\\rangle |y\\rangle |x + z\\rangle$. \n", "\n", "#### Registers\n", " - `x`: Input THRU register of size $m$ that stores elements from $GF(2^m)$.\n", " - `y`: Input THRU register of size $m$ that stores elements from $GF(2^m)$.\n", - " - `result`: Output RIGHT register of size $m$ that stores the product $x * y$ in $GF(2^m)$. \n", + " - `result`: Register of size $m$ that stores the product $x * y$ in $GF(2^m)$. If plus_equal_prod is True - result is a THRU register and stores $result + x * y$. If plus_equal_prod is False - result is a RIGHT register and stores $x * y$. \n", "\n", "#### References\n", " - [On the Design and Optimization of a Quantum Polynomial-Time Attack on Elliptic Curve Cryptography](https://arxiv.org/abs/0710.1093). \n", @@ -99,7 +100,7 @@ }, "outputs": [], "source": [ - "gf16_multiplication = GF2Multiplication(4)" + "gf16_multiplication = GF2Multiplication(4, plus_equal_prod=True)" ] }, { @@ -114,7 +115,7 @@ "import sympy\n", "\n", "m = sympy.Symbol('m')\n", - "gf2_multiplication_symbolic = GF2Multiplication(m)" + "gf2_multiplication_symbolic = GF2Multiplication(m, plus_equal_prod=False)" ] }, { diff --git a/qualtran/bloqs/gf_arithmetic/gf2_multiplication.py b/qualtran/bloqs/gf_arithmetic/gf2_multiplication.py index 4f91ecaa0..48e3a7f1e 100644 --- a/qualtran/bloqs/gf_arithmetic/gf2_multiplication.py +++ b/qualtran/bloqs/gf_arithmetic/gf2_multiplication.py @@ -51,18 +51,19 @@ class SynthesizeLRCircuit(Bloq): """Synthesize linear reversible circuit using CNOT gates. Args: - matrix: An n x m matrix describing the linear transformation. + matrix: An n x n matrix describing the linear transformation. References: [Efficient Synthesis of Linear Reversible Circuits](https://arxiv.org/abs/quant-ph/0302002) """ matrix: Union[Shaped, np.ndarray] = attrs.field(eq=_data_or_shape_to_tuple) + is_adjoint: bool = False def __attrs_post_init__(self): assert len(self.matrix.shape) == 2 n, m = self.matrix.shape - assert is_symbolic(n, m) or n >= m + assert is_symbolic(n, m) or n == m @cached_property def signature(self) -> 'Signature': @@ -72,10 +73,13 @@ def signature(self) -> 'Signature': def on_classical_vals(self, *, q: 'ClassicalValT') -> Dict[str, 'ClassicalValT']: matrix = self.matrix assert isinstance(matrix, np.ndarray) + if self.is_adjoint: + matrix = np.linalg.inv(matrix) + assert np.allclose(matrix, matrix.astype(int)) + matrix = matrix.astype(int) _, m = matrix.shape assert isinstance(q, np.ndarray) - q_in = q[:m] - return {'q': (matrix @ q_in) % 2} + return {'q': (matrix @ q) % 2} def build_call_graph( self, ssa: 'SympySymbolAllocator' @@ -83,6 +87,9 @@ def build_call_graph( n = self.matrix.shape[0] return {CNOT(): ceil(n**2 / log2(n))} + def adjoint(self) -> 'SynthesizeLRCircuit': + return attrs.evolve(self, is_adjoint=not self.is_adjoint) + @attrs.frozen class GF2Multiplication(Bloq): @@ -108,11 +115,15 @@ class GF2Multiplication(Bloq): Args: bitsize: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of qubits in each of the two input registers $a$ and $b$ that should be multiplied. + plus_equal_prod: If True, implements the `PlusEqualProduct` version that applies the + map $|x\rangle |y\rangle |z\rangle \rightarrow |x\rangle |y\rangle |x + z\rangle$. Registers: x: Input THRU register of size $m$ that stores elements from $GF(2^m)$. y: Input THRU register of size $m$ that stores elements from $GF(2^m)$. - result: Output RIGHT register of size $m$ that stores the product $x * y$ in $GF(2^m)$. + result: Register of size $m$ that stores the product $x * y$ in $GF(2^m)$. + If plus_equal_prod is True - result is a THRU register and stores $result + x * y$. + If plus_equal_prod is False - result is a RIGHT register and stores $x * y$. References: @@ -124,14 +135,16 @@ class GF2Multiplication(Bloq): """ bitsize: SymbolicInt + plus_equal_prod: bool = False @cached_property def signature(self) -> 'Signature': + result_side = Side.THRU if self.plus_equal_prod else Side.RIGHT return Signature( [ Register('x', dtype=self.qgf), Register('y', dtype=self.qgf), - Register('result', dtype=self.qgf, side=Side.RIGHT), + Register('result', dtype=self.qgf, side=result_side), ] ) @@ -143,7 +156,7 @@ def qgf(self) -> QGF: def reduction_matrix_q(self) -> np.ndarray: m = int(self.bitsize) f = self.qgf.gf_type.irreducible_poly - M = np.zeros((m - 1, m)) + M = np.zeros((m, m)) alpha = [1] + [0] * m for i in range(m - 1): # x ** (m + i) % f @@ -151,6 +164,7 @@ def reduction_matrix_q(self) -> np.ndarray: coeffs = coeffs + [0] * (m - len(coeffs)) M[i] = coeffs alpha += [0] + M[m - 1][m - 1] = 1 return np.transpose(M) @cached_property @@ -162,14 +176,18 @@ def synthesize_reduction_matrix_q(self) -> SynthesizeLRCircuit: else SynthesizeLRCircuit(self.reduction_matrix_q) ) - def build_composite_bloq( - self, bb: 'BloqBuilder', *, x: 'Soquet', y: 'Soquet' - ) -> Dict[str, 'Soquet']: + def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'Soquet') -> Dict[str, 'Soquet']: if is_symbolic(self.bitsize): raise DecomposeTypeError(f"Cannot decompose symbolic {self}") - result = bb.allocate(dtype=self.qgf) + x, y = soqs['x'], soqs['y'] + result = soqs['result'] if self.plus_equal_prod else bb.allocate(dtype=self.qgf) x, y, result = bb.split(x)[::-1], bb.split(y)[::-1], bb.split(result)[::-1] m = int(self.bitsize) + + # Step-0: PlusEqualProduct special case. + if self.plus_equal_prod: + result = bb.add(self.synthesize_reduction_matrix_q.adjoint(), q=result) + # Step-1: Multiply Monomials. for i in range(m): for j in range(i + 1, m): @@ -199,16 +217,21 @@ def build_call_graph( self, ssa: 'SympySymbolAllocator' ) -> Union['BloqCountDictT', Set['BloqCountT']]: m = self.bitsize - return {Toffoli(): m**2, self.synthesize_reduction_matrix_q: 1} + plus_equal_prod = ( + {self.synthesize_reduction_matrix_q.adjoint(): 1} if self.plus_equal_prod else {} + ) + return {Toffoli(): m**2, self.synthesize_reduction_matrix_q: 1} | plus_equal_prod - def on_classical_vals(self, *, x, y) -> Dict[str, 'ClassicalValT']: - assert isinstance(x, self.qgf.gf_type) and isinstance(y, self.qgf.gf_type) - return {'x': x, 'y': y, 'result': x * y} + def on_classical_vals(self, **vals) -> Dict[str, 'ClassicalValT']: + assert all(isinstance(val, self.qgf.gf_type) for val in vals.values()) + x, y = vals['x'], vals['y'] + result = vals['result'] if self.plus_equal_prod else self.qgf.gf_type(0) + return {'x': x, 'y': y, 'result': result + x * y} @bloq_example def _gf16_multiplication() -> GF2Multiplication: - gf16_multiplication = GF2Multiplication(4) + gf16_multiplication = GF2Multiplication(4, plus_equal_prod=True) return gf16_multiplication @@ -217,7 +240,7 @@ def _gf2_multiplication_symbolic() -> GF2Multiplication: import sympy m = sympy.Symbol('m') - gf2_multiplication_symbolic = GF2Multiplication(m) + gf2_multiplication_symbolic = GF2Multiplication(m, plus_equal_prod=False) return gf2_multiplication_symbolic diff --git a/qualtran/bloqs/gf_arithmetic/gf2_multiplication_test.py b/qualtran/bloqs/gf_arithmetic/gf2_multiplication_test.py index 94eaba106..a37b0cddf 100644 --- a/qualtran/bloqs/gf_arithmetic/gf2_multiplication_test.py +++ b/qualtran/bloqs/gf_arithmetic/gf2_multiplication_test.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pytest from galois import GF +from qualtran import QGF from qualtran.bloqs.gf_arithmetic.gf2_multiplication import ( _gf2_multiplication_symbolic, _gf16_multiplication, GF2Multiplication, + SynthesizeLRCircuit, ) from qualtran.testing import assert_consistent_classical_action @@ -31,9 +34,51 @@ def test_gf2_multiplication_symbolic(bloq_autotester): bloq_autotester(_gf2_multiplication_symbolic) +def test_synthesize_lr_circuit(): + m = 2 + matrix = GF2Multiplication(m).reduction_matrix_q + bloq = SynthesizeLRCircuit(matrix) + bloq_adj = bloq.adjoint() + QGFM, GFM = QGF(2, m), GF(2**m) + for i in GFM.elements: + bloq_out = bloq.call_classically(q=np.array(QGFM.to_bits(i)))[0] + bloq_adj_out = bloq_adj.call_classically(q=bloq_out)[0] + assert isinstance(bloq_adj_out, np.ndarray) + assert i == QGFM.from_bits([*bloq_adj_out]) + + +@pytest.mark.slow +@pytest.mark.parametrize('m', [3, 4, 5]) +def test_synthesize_lr_circuit_slow(m): + matrix = GF2Multiplication(m).reduction_matrix_q + bloq = SynthesizeLRCircuit(matrix) + bloq_adj = bloq.adjoint() + QGFM, GFM = QGF(2, m), GF(2**m) + for i in GFM.elements: + bloq_out = bloq.call_classically(q=np.array(QGFM.to_bits(i)))[0] + bloq_adj_out = bloq_adj.call_classically(q=bloq_out)[0] + assert isinstance(bloq_adj_out, np.ndarray) + assert i == QGFM.from_bits([*bloq_adj_out]) + + +def test_gf2_plus_equal_prod_classical_sim_quick(): + m = 2 + bloq = GF2Multiplication(m, plus_equal_prod=True) + GFM = GF(2**m) + assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements, result=GFM.elements) + + +@pytest.mark.slow +def test_gf2_plus_equal_prod_classical_sim(): + m = 3 + bloq = GF2Multiplication(m, plus_equal_prod=True) + GFM = GF(2**m) + assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements, result=GFM.elements) + + def test_gf2_multiplication_classical_sim_quick(): m = 2 - bloq = GF2Multiplication(m) + bloq = GF2Multiplication(m, plus_equal_prod=False) GFM = GF(2**m) assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements) @@ -41,6 +86,6 @@ def test_gf2_multiplication_classical_sim_quick(): @pytest.mark.slow @pytest.mark.parametrize('m', [3, 4, 5]) def test_gf2_multiplication_classical_sim(m): - bloq = GF2Multiplication(m) + bloq = GF2Multiplication(m, plus_equal_prod=False) GFM = GF(2**m) assert_consistent_classical_action(bloq, x=GFM.elements, y=GFM.elements) From 0ef9f1deac8155d10f12134760bb9e519bc771b4 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 16 Oct 2024 22:48:54 +0200 Subject: [PATCH 2/8] Optimized implementation of `GF2Inverse` (#1459) * Optimized implementation of GF2Inverse * Address nits * Regenerate notebook --- .../bloqs/gf_arithmetic/gf2_inverse.ipynb | 11 +- qualtran/bloqs/gf_arithmetic/gf2_inverse.py | 127 +++++++++++++----- .../bloqs/gf_arithmetic/gf2_inverse_test.py | 16 ++- 3 files changed, 116 insertions(+), 38 deletions(-) diff --git a/qualtran/bloqs/gf_arithmetic/gf2_inverse.ipynb b/qualtran/bloqs/gf_arithmetic/gf2_inverse.ipynb index 6a5a4b73b..dec031115 100644 --- a/qualtran/bloqs/gf_arithmetic/gf2_inverse.ipynb +++ b/qualtran/bloqs/gf_arithmetic/gf2_inverse.ipynb @@ -58,7 +58,10 @@ " a^{-1} = a^{2^m - 2}\n", "$$\n", "\n", - "Thus, the inverse can be obtained via $m - 1$ squaring and multiplication operations.\n", + "The exponential $a^{2^m - 2}$ is computed using $\\mathcal{O}(m)$ squaring and\n", + "$\\mathcal{O}(\\log_2(m))$ multiplications via Itoh-Tsujii inversion. The algorithm is described on\n", + "page 4 and 5 of Ref[1] and resembles binary exponentiation. The inverse is computed as $B_{n-1}^2$,\n", + "where $B_1 = x$ and $B_{i+j} = B_i B_j^{2^i}$.\n", "\n", "#### Parameters\n", " - `bitsize`: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of qubits in the input register whose inverse should be calculated. \n", @@ -66,7 +69,11 @@ "#### Registers\n", " - `x`: Input THRU register of size $m$ that stores elements from $GF(2^m)$.\n", " - `result`: Output RIGHT register of size $m$ that stores $x^{-1}$ from $GF(2^m)$.\n", - " - `junk`: Output RIGHT register of size $m$ and shape ($m - 2$) that stores results from intermediate multiplications.\n" + " - `junk`: Output RIGHT register of size $m$ and shape ($m - 2$) that stores results from intermediate multiplications. \n", + "\n", + "#### References\n", + " - [Efficient quantum circuits for binary elliptic curve arithmetic: reducing T -gate complexity](https://arxiv.org/abs/1209.6348). Section 2.3\n", + " - [Structure of parallel multipliers for a class of fields GF(2^m)](https://doi.org/10.1016/0890-5401(89)90045-X)\n" ] }, { diff --git a/qualtran/bloqs/gf_arithmetic/gf2_inverse.py b/qualtran/bloqs/gf_arithmetic/gf2_inverse.py index 20d0cf54d..eb0e80f4f 100644 --- a/qualtran/bloqs/gf_arithmetic/gf2_inverse.py +++ b/qualtran/bloqs/gf_arithmetic/gf2_inverse.py @@ -30,10 +30,11 @@ from qualtran.bloqs.gf_arithmetic.gf2_addition import GF2Addition from qualtran.bloqs.gf_arithmetic.gf2_multiplication import GF2Multiplication from qualtran.bloqs.gf_arithmetic.gf2_square import GF2Square -from qualtran.symbolics import is_symbolic, SymbolicInt +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join +from qualtran.symbolics import bit_length, ceil, is_symbolic, log2, SymbolicInt if TYPE_CHECKING: - from qualtran import BloqBuilder, Soquet + from qualtran import BloqBuilder, Soquet, SoquetT from qualtran.resource_counting import BloqCountDictT, BloqCountT, CostKey, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT @@ -62,7 +63,10 @@ class GF2Inverse(Bloq): a^{-1} = a^{2^m - 2} $$ - Thus, the inverse can be obtained via $m - 1$ squaring and multiplication operations. + The exponential $a^{2^m - 2}$ is computed using $\mathcal{O}(m)$ squaring and + $\mathcal{O}(\log_2(m))$ multiplications via Itoh-Tsujii inversion. The algorithm is described on + page 4 and 5 of Ref[1] and resembles binary exponentiation. The inverse is computed as $B_{n-1}^2$, + where $B_1 = x$ and $B_{i+j} = B_i B_j^{2^i}$. Args: bitsize: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of @@ -73,6 +77,12 @@ class GF2Inverse(Bloq): result: Output RIGHT register of size $m$ that stores $x^{-1}$ from $GF(2^m)$. junk: Output RIGHT register of size $m$ and shape ($m - 2$) that stores results from intermediate multiplications. + + References: + [Efficient quantum circuits for binary elliptic curve arithmetic: reducing T -gate complexity](https://arxiv.org/abs/1209.6348). + Section 2.3 + + [Structure of parallel multipliers for a class of fields GF(2^m)](https://doi.org/10.1016/0890-5401(89)90045-X) """ bitsize: SymbolicInt @@ -80,8 +90,8 @@ class GF2Inverse(Bloq): @cached_property def signature(self) -> 'Signature': junk_reg = ( - [Register('junk', dtype=self.qgf, shape=(self.bitsize - 2,), side=Side.RIGHT)] - if is_symbolic(self.bitsize) or self.bitsize > 2 + [Register('junk', dtype=self.qgf, shape=(self.n_junk_regs,), side=Side.RIGHT)] + if is_symbolic(self.bitsize) or self.bitsize > 1 else [] ) return Signature( @@ -96,60 +106,111 @@ def signature(self) -> 'Signature': def qgf(self) -> QGF: return QGF(characteristic=2, degree=self.bitsize) + @cached_property + def n_junk_regs(self) -> SymbolicInt: + return 2 * bit_length(self.bitsize - 1) + self.bitsize_hamming_weight + + @cached_property + def bitsize_hamming_weight(self) -> SymbolicInt: + """Hamming weight of self.bitsize - 1""" + return ( + bit_length(self.bitsize - 1) + if is_symbolic(self.bitsize) + else int(self.bitsize - 1).bit_count() + ) + def my_static_costs(self, cost_key: 'CostKey'): + from qualtran._infra.gate_with_registers import total_bits from qualtran.resource_counting import QubitCount if isinstance(cost_key, QubitCount): - return self.signature.n_qubits() + return total_bits(self.signature.rights()) return NotImplemented - def build_composite_bloq(self, bb: 'BloqBuilder', *, x: 'Soquet') -> Dict[str, 'Soquet']: + def build_composite_bloq(self, bb: 'BloqBuilder', *, x: 'Soquet') -> Dict[str, 'SoquetT']: if is_symbolic(self.bitsize): raise DecomposeTypeError(f"Cannot decompose symbolic {self}") + result = bb.allocate(dtype=self.qgf) if self.bitsize == 1: x, result = bb.add(GF2Addition(self.bitsize), x=x, y=result) return {'x': x, 'result': result} - x = bb.add(GF2Square(self.bitsize), x=x) - x, result = bb.add(GF2Addition(self.bitsize), x=x, y=result) - junk = [] - for i in range(2, self.bitsize): - x = bb.add(GF2Square(self.bitsize), x=x) - x, result, new_result = bb.add(GF2Multiplication(self.bitsize), x=x, y=result) - junk.append(result) - result = new_result - x = bb.add(GF2Square(self.bitsize), x=x) - return {'x': x, 'result': result} | ({'junk': np.array(junk)} if junk else {}) + beta = bb.allocate(dtype=self.qgf) + x, beta = bb.add(GF2Addition(self.bitsize), x=x, y=beta) + is_first = True + bitsize_minus_one = int(self.bitsize - 1) + for i in range(bitsize_minus_one.bit_length()): + if (1 << i) & bitsize_minus_one: + if is_first: + beta, result = bb.add(GF2Addition(self.bitsize), x=beta, y=result) + is_first = False + else: + for j in range(2**i): + result = bb.add(GF2Square(self.bitsize), x=result) + beta, result, new_result = bb.add( + GF2Multiplication(self.bitsize), x=beta, y=result + ) + junk.append(result) + result = new_result + beta_squared = bb.allocate(dtype=self.qgf) + beta, beta_squared = bb.add(GF2Addition(self.bitsize), x=beta, y=beta_squared) + for j in range(2**i): + beta_squared = bb.add(GF2Square(self.bitsize), x=beta_squared) + beta, beta_squared, beta_new = bb.add( + GF2Multiplication(self.bitsize), x=beta, y=beta_squared + ) + junk.extend([beta, beta_squared]) + beta = beta_new + junk.append(beta) + result = bb.add(GF2Square(self.bitsize), x=result) + assert len(junk) == self.n_junk_regs, f'{len(junk)=}, {self.n_junk_regs=}' + return {'x': x, 'result': result, 'junk': np.array(junk)} def build_call_graph( self, ssa: 'SympySymbolAllocator' ) -> Union['BloqCountDictT', Set['BloqCountT']]: - if is_symbolic(self.bitsize) or self.bitsize > 2: - return { - GF2Addition(self.bitsize): 1, - GF2Square(self.bitsize): self.bitsize - 1, - GF2Multiplication(self.bitsize): self.bitsize - 2, - } - return {GF2Addition(self.bitsize): 1} | ( - {GF2Square(self.bitsize): 1} if self.bitsize == 2 else {} - ) + if not is_symbolic(self.bitsize) and self.bitsize == 1: + return {GF2Addition(self.bitsize): 1} + square_count = self.bitsize + 2 ** ceil(log2(self.bitsize)) - 1 + if not is_symbolic(self.bitsize): + n = self.bitsize - 1 + square_count -= n & (-n) + return { + GF2Addition(self.bitsize): 2 + ceil(log2(self.bitsize)), + GF2Square(self.bitsize): square_count, + GF2Multiplication(self.bitsize): ceil(log2(self.bitsize)) + + self.bitsize_hamming_weight + - 1, + } def on_classical_vals(self, *, x) -> Dict[str, 'ClassicalValT']: assert isinstance(x, self.qgf.gf_type) - x_temp = x**2 - result = x_temp junk = [] - for i in range(2, int(self.bitsize)): - junk.append(result) - x_temp = x_temp * x_temp - result = result * x_temp + bitsize_minus_one = int(self.bitsize - 1) + beta = x + result = self.qgf.gf_type(0) + is_first = True + for i in range(bitsize_minus_one.bit_length()): + if (1 << i) & bitsize_minus_one: + if is_first: + is_first = False + result = beta + else: + for j in range(2**i): + result = result**2 + junk.append(result) + result = result * beta + beta_squared = beta ** (2 ** (2**i)) + junk.extend([beta, beta_squared]) + beta = beta * beta_squared + junk.append(beta) return {'x': x, 'result': x ** (-1), 'junk': np.array(junk)} -@bloq_example +@bloq_example(generalizer=[ignore_split_join, ignore_alloc_free]) def _gf16_inverse() -> GF2Inverse: gf16_inverse = GF2Inverse(4) return gf16_inverse diff --git a/qualtran/bloqs/gf_arithmetic/gf2_inverse_test.py b/qualtran/bloqs/gf_arithmetic/gf2_inverse_test.py index 3b3371283..91b793c67 100644 --- a/qualtran/bloqs/gf_arithmetic/gf2_inverse_test.py +++ b/qualtran/bloqs/gf_arithmetic/gf2_inverse_test.py @@ -22,7 +22,9 @@ GF2Inverse, ) from qualtran.resource_counting import get_cost_value, QECGatesCost, QubitCount -from qualtran.testing import assert_consistent_classical_action +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join +from qualtran.symbolics import ceil, log2 +from qualtran.testing import assert_consistent_classical_action, assert_equivalent_bloq_counts def test_gf16_inverse(bloq_autotester): @@ -36,8 +38,10 @@ def test_gf2_inverse_symbolic(bloq_autotester): def test_gf2_inverse_symbolic_toffoli_complexity(): bloq = _gf2_inverse_symbolic.make() m = bloq.bitsize - assert get_cost_value(bloq, QECGatesCost()).total_toffoli_only() - m**2 * (m - 2) == 0 - assert sympy.simplify(get_cost_value(bloq, QubitCount()) - m**2) == 0 + expected_expr = m**2 * (2 * ceil(log2(m)) - 1) + assert get_cost_value(bloq, QECGatesCost()).total_toffoli_only() - expected_expr == 0 + expected_expr = m * (3 * ceil(log2(m)) + 2) + assert sympy.simplify(get_cost_value(bloq, QubitCount()) - expected_expr) == 0 def test_gf2_inverse_classical_sim_quick(): @@ -53,3 +57,9 @@ def test_gf2_inverse_classical_sim(m): bloq = GF2Inverse(m) GFM = GF(2**m) assert_consistent_classical_action(bloq, x=GFM.elements[1:]) + + +@pytest.mark.parametrize('m', [*range(1, 12)]) +def test_gf2_equivalent_bloq_counts(m): + bloq = GF2Inverse(m) + assert_equivalent_bloq_counts(bloq, generalizer=[ignore_split_join, ignore_alloc_free]) From 3210c3063fe08764d7a4be8637f25f4504aef649 Mon Sep 17 00:00:00 2001 From: Steve Habegger <95657217+shab5@users.noreply.github.com> Date: Thu, 17 Oct 2024 11:16:43 -0700 Subject: [PATCH 3/8] [qwa] Export docs for bloqs without examples (#1473) * Export docs for bloqs without examples * Clear notebook output --- dev_tools/ui-export.ipynb | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/dev_tools/ui-export.ipynb b/dev_tools/ui-export.ipynb index c1d9d416f..353b6e250 100644 --- a/dev_tools/ui-export.ipynb +++ b/dev_tools/ui-export.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -45,12 +45,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import attrs\n", "import hashlib\n", + "import json\n", "\n", "from qualtran import CompositeBloq\n", "from qualtran.bloqs.rotations.programmable_rotation_gate_array import ProgrammableRotationGateArray\n", @@ -108,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -132,13 +133,6 @@ " for child_bloq, _ in call_graph.succ[bloq].items():\n", " write_example(child_bloq)\n", "\n", - " Path(f'ui_export/{bloq.__class__.__name__}').mkdir(parents=True, exist_ok=True)\n", - "\n", - " doc_name = f'ui_export/{bloq.__class__.__name__}/docs.txt'\n", - " if not os.path.isfile(doc_name):\n", - " with open(doc_name, 'w') as doc_file:\n", - " doc_file.write('\\n'.join(get_markdown_docstring(bloq.__class__)))\n", - "\n", " file_name = f'ui_export/{bloq.__class__.__name__}/{bloq_filename(bloq)}'\n", " if not os.path.isfile(file_name):\n", " bloq_dict = {\n", @@ -160,6 +154,13 @@ "for section in NB_BY_SECTION:\n", " for notebook_spec in section[1]:\n", " for bloq_spec in notebook_spec.bloq_specs:\n", + " Path(f'ui_export/{bloq_spec.bloq_cls.__name__}').mkdir(parents=True, exist_ok=True)\n", + "\n", + " doc_name = f'ui_export/{bloq_spec.bloq_cls.__name__}/docs.txt'\n", + " if not os.path.isfile(doc_name):\n", + " with open(doc_name, 'w') as doc_file:\n", + " doc_file.write('\\n'.join(get_markdown_docstring(bloq_spec.bloq_cls)))\n", + "\n", " for example in bloq_spec.examples:\n", " write_example(example.make())" ] From 0220df23d2c58339ef17004e8ae3c63397510d86 Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Thu, 17 Oct 2024 16:18:39 -0700 Subject: [PATCH 4/8] Implement CtrlScaleModAdd and CModAddK bloqs for Modular Exponentiation (#1432) * Add some classical simulation * Implement primitives for ModExp * Fix serialization test error * Change Union -> SymbolicInt * Fix nits * Better symbolic decomposition error messages * Fix merge conflicts * Fixed docstring to be more readable (hopefully) * Address nits --------- Co-authored-by: Matthew Harrigan --- .../qualtran_dev_tools/notebook_specs.py | 2 + qualtran/bloqs/basic_gates/z_basis.py | 5 +- .../bloqs/mod_arithmetic/mod_addition.ipynb | 240 ++++++++++++++++++ qualtran/bloqs/mod_arithmetic/mod_addition.py | 125 ++++++++- .../bloqs/mod_arithmetic/mod_addition_test.py | 81 ++++-- qualtran/serialization/resolver_dict.py | 4 +- 6 files changed, 421 insertions(+), 36 deletions(-) diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index f96b511e4..1b5f219b4 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -493,6 +493,8 @@ qualtran.bloqs.mod_arithmetic.mod_addition._MOD_ADD_DOC, qualtran.bloqs.mod_arithmetic.mod_addition._MOD_ADD_K_DOC, qualtran.bloqs.mod_arithmetic.mod_addition._C_MOD_ADD_DOC, + qualtran.bloqs.mod_arithmetic.mod_addition._C_MOD_ADD_K_DOC, + qualtran.bloqs.mod_arithmetic.mod_addition._CTRL_SCALE_MOD_ADD_DOC, ], ), NotebookSpecV2( diff --git a/qualtran/bloqs/basic_gates/z_basis.py b/qualtran/bloqs/basic_gates/z_basis.py index b0f39b132..40bd02255 100644 --- a/qualtran/bloqs/basic_gates/z_basis.py +++ b/qualtran/bloqs/basic_gates/z_basis.py @@ -42,6 +42,7 @@ ) from qualtran.bloqs.bookkeeping import ArbitraryClifford from qualtran.drawing import Circle, directional_text_box, Text, TextBox, WireSymbol +from qualtran.symbolics import SymbolicInt if TYPE_CHECKING: import cirq @@ -453,7 +454,7 @@ class IntState(_IntVector): val: The register of size `bitsize` which initializes the value `val`. """ - def __init__(self, val: Union[int, sympy.Expr], bitsize: Union[int, sympy.Expr]): + def __init__(self, val: SymbolicInt, bitsize: SymbolicInt): self.__attrs_init__(val=val, bitsize=bitsize, state=True) @@ -478,7 +479,7 @@ class IntEffect(_IntVector): val: The register of size `bitsize` which de-allocates the value `val`. """ - def __init__(self, val: int, bitsize: int): + def __init__(self, val: SymbolicInt, bitsize: SymbolicInt): self.__attrs_init__(val=val, bitsize=bitsize, state=False) diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb b/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb index dab6f57c5..b081513d2 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb +++ b/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb @@ -394,6 +394,246 @@ "show_call_graph(cmodadd_example_g)\n", "show_counts_sigma(cmodadd_example_sigma)" ] + }, + { + "cell_type": "markdown", + "id": "0523961e", + "metadata": { + "cq.autogen": "CModAddK.bloq_doc.md" + }, + "source": [ + "## `CModAddK`\n", + "Perform x += k mod m for constant k, m and quantum x.\n", + "\n", + "#### Parameters\n", + " - `k`: The integer to add to `x`.\n", + " - `mod`: The modulus for the addition.\n", + " - `bitsize`: The bitsize of the `x` register. \n", + "\n", + "#### Registers\n", + " - `ctrl`: The control bit\n", + " - `x`: The register to perform the in-place modular addition. \n", + "\n", + "#### References\n", + " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Gidney and Ekerå 2019. The reference implementation in section 2.2 uses CModAddK, but the circuit that it points to is just ModAdd (not ModAddK). This ModAdd is less efficient than the circuit later introduced in the Litinski paper so we choose to use that since it is more efficient and already implemented in Qualtran.\n", + " - [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). Litinski et al. 2023. This CModAdd circuit uses 2 fewer additions than the implementation referenced in the paper above. Because of this we choose to use this CModAdd bloq instead.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "341211fa", + "metadata": { + "cq.autogen": "CModAddK.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.mod_arithmetic import CModAddK" + ] + }, + { + "cell_type": "markdown", + "id": "6b00aef7", + "metadata": { + "cq.autogen": "CModAddK.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8877f17", + "metadata": { + "cq.autogen": "CModAddK.cmod_add_k" + }, + "outputs": [], + "source": [ + "n, m, k = sympy.symbols('n m k')\n", + "cmod_add_k = CModAddK(bitsize=n, mod=m, k=k)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87fe8b8f", + "metadata": { + "cq.autogen": "CModAddK.cmod_add_k_small" + }, + "outputs": [], + "source": [ + "cmod_add_k_small = CModAddK(bitsize=4, mod=7, k=1)" + ] + }, + { + "cell_type": "markdown", + "id": "f44f0268", + "metadata": { + "cq.autogen": "CModAddK.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7a02d6e", + "metadata": { + "cq.autogen": "CModAddK.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([cmod_add_k, cmod_add_k_small],\n", + " ['`cmod_add_k`', '`cmod_add_k_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "48b87ff2", + "metadata": { + "cq.autogen": "CModAddK.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6e6770b", + "metadata": { + "cq.autogen": "CModAddK.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "cmod_add_k_g, cmod_add_k_sigma = cmod_add_k.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(cmod_add_k_g)\n", + "show_counts_sigma(cmod_add_k_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "21f93349", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.bloq_doc.md" + }, + "source": [ + "## `CtrlScaleModAdd`\n", + "Perform y += x*k mod m for constant k, m and quantum x, y.\n", + "\n", + "#### Parameters\n", + " - `k`: The constant integer to scale `x` before adding into `y`.\n", + " - `mod`: The modulus of the addition\n", + " - `bitsize`: The size of the two registers. \n", + "\n", + "#### Registers\n", + " - `ctrl`: The control bit\n", + " - `x`: The 'source' quantum register containing the integer to be scaled and added to `y`.\n", + " - `y`: The 'destination' quantum register to which the addition will apply. \n", + "\n", + "#### References\n", + " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Construction based on description in section 2.2 paragraph 4. We add n And/And† bloqs because the bloq is controlled, but the construction also involves modular addition controlled on the qubits comprising register x.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ac170e5", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.mod_arithmetic import CtrlScaleModAdd" + ] + }, + { + "cell_type": "markdown", + "id": "58ee7de2", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73c2c6f7", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.ctrl_scale_mod_add" + }, + "outputs": [], + "source": [ + "n, m, k = sympy.symbols('n m k')\n", + "ctrl_scale_mod_add = CtrlScaleModAdd(bitsize=n, mod=m, k=k)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e822d5eb", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.ctrl_scale_mod_add_small" + }, + "outputs": [], + "source": [ + "ctrl_scale_mod_add_small = CtrlScaleModAdd(bitsize=4, mod=7, k=1)" + ] + }, + { + "cell_type": "markdown", + "id": "fe4c8957", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4dc8923", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([ctrl_scale_mod_add, ctrl_scale_mod_add_small],\n", + " ['`ctrl_scale_mod_add`', '`ctrl_scale_mod_add_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "97d6888d", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd734f6b", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "ctrl_scale_mod_add_g, ctrl_scale_mod_add_sigma = ctrl_scale_mod_add.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(ctrl_scale_mod_add_g)\n", + "show_counts_sigma(ctrl_scale_mod_add_sigma)" + ] } ], "metadata": { diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition.py b/qualtran/bloqs/mod_arithmetic/mod_addition.py index efe224946..a8186cdfe 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition.py @@ -23,6 +23,7 @@ Bloq, bloq_example, BloqDocSpec, + DecomposeTypeError, GateWithRegisters, QBit, QMontgomeryUInt, @@ -35,8 +36,9 @@ from qualtran.bloqs.arithmetic.addition import Add, AddK from qualtran.bloqs.arithmetic.comparison import CLinearDepthGreaterThan, LinearDepthGreaterThan from qualtran.bloqs.arithmetic.controlled_addition import CAdd -from qualtran.bloqs.basic_gates import XGate +from qualtran.bloqs.basic_gates import IntEffect, IntState, XGate from qualtran.bloqs.bookkeeping import Cast +from qualtran.bloqs.mcmt.and_bloq import And from qualtran.drawing import Circle, Text, TextBox, WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.resource_counting.generalizers import ignore_split_join @@ -89,7 +91,7 @@ def on_classical_vals( def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']: if is_symbolic(self.bitsize): - raise NotImplementedError(f'symbolic decomposition is not supported for {self}') + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") # Allocate ancilla bits for use in addition. junk_bit = bb.allocate(n=1) sign = bb.allocate(n=1) @@ -269,6 +271,19 @@ class CModAddK(Bloq): Registers: ctrl: The control bit x: The register to perform the in-place modular addition. + + References: + [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). + Gidney and Ekerå 2019. + The reference implementation in section 2.2 uses CModAddK, but the circuit that it points + to is just ModAdd (not ModAddK). This ModAdd is less efficient than the circuit later + introduced in the Litinski paper so we choose to use that since it is more efficient and + already implemented in Qualtran. + + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). + Litinski et al. 2023. + This CModAdd circuit uses 2 fewer additions than the implementation referenced in the paper + above. Because of this we choose to use this CModAdd bloq instead. """ k: Union[int, sympy.Expr] @@ -279,12 +294,53 @@ class CModAddK(Bloq): def signature(self) -> 'Signature': return Signature([Register('ctrl', QBit()), Register('x', QUInt(self.bitsize))]) + def build_composite_bloq( + self, bb: 'BloqBuilder', ctrl: 'Soquet', x: 'Soquet' + ) -> Dict[str, 'SoquetT']: + k = bb.add(IntState(bitsize=self.bitsize, val=self.k)) + ctrl, k, x = bb.add(CModAdd(QUInt(self.bitsize), mod=self.mod), ctrl=ctrl, x=k, y=x) + bb.add(IntEffect(bitsize=self.bitsize, val=self.k), val=k) + return {'ctrl': ctrl, 'x': x} + + def on_classical_vals( + self, ctrl: 'ClassicalValT', x: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 0: + return {'ctrl': 0, 'x': x} + + assert ctrl == 1, 'Bad ctrl value.' + x = (x + self.k) % self.mod + return {'ctrl': ctrl, 'x': x} + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - k = ssa.new_symbol('k') - return {AddK(k=k, bitsize=self.bitsize).controlled(): 5} + return {CModAdd(QUInt(self.bitsize), mod=self.mod): 1} - def short_name(self) -> str: - return f'x += {self.k} % {self.mod}' + def wire_symbol( + self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() + ) -> 'WireSymbol': + if reg is None: + return Text(f"mod {self.mod}") + if reg.name == 'ctrl': + return Circle() + if reg.name == 'x': + return TextBox(f'x += {self.k}') + raise ValueError(f"Unknown register {reg}") + + +@bloq_example(generalizer=ignore_split_join) +def _cmod_add_k() -> CModAddK: + n, m, k = sympy.symbols('n m k') + cmod_add_k = CModAddK(bitsize=n, mod=m, k=k) + return cmod_add_k + + +@bloq_example +def _cmod_add_k_small() -> CModAddK: + cmod_add_k_small = CModAddK(bitsize=4, mod=7, k=1) + return cmod_add_k_small + + +_C_MOD_ADD_K_DOC = BloqDocSpec(bloq_cls=CModAddK, examples=[_cmod_add_k, _cmod_add_k_small]) @frozen @@ -300,6 +356,12 @@ class CtrlScaleModAdd(Bloq): ctrl: The control bit x: The 'source' quantum register containing the integer to be scaled and added to `y`. y: The 'destination' quantum register to which the addition will apply. + + References: + [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). + Construction based on description in section 2.2 paragraph 4. We add n And/And† bloqs + because the bloq is controlled, but the construction also involves modular addition + controlled on the qubits comprising register x. """ k: Union[int, sympy.Expr] @@ -316,9 +378,38 @@ def signature(self) -> 'Signature': ] ) + def build_composite_bloq( + self, bb: 'BloqBuilder', ctrl: 'Soquet', x: 'Soquet', y: 'Soquet' + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + x_split = bb.split(x) + for i in range(int(self.bitsize)): + and_ctrl = [ctrl, x_split[i]] + and_ctrl, ancilla = bb.add(And(), ctrl=and_ctrl) + ancilla, y = bb.add( + CModAddK( + k=((self.k * 2 ** (self.bitsize - 1 - i)) % self.mod), + bitsize=self.bitsize, + mod=self.mod, + ), + ctrl=ancilla, + x=y, + ) + and_ctrl = bb.add(And().adjoint(), ctrl=and_ctrl, target=ancilla) + ctrl = and_ctrl[0] + x_split[i] = and_ctrl[1] + x = bb.join(x_split, dtype=QUInt(self.bitsize)) + + return {'ctrl': ctrl, 'x': x, 'y': y} + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': k = ssa.new_symbol('k') - return {CModAddK(k=k, bitsize=self.bitsize, mod=self.mod): self.bitsize} + return { + CModAddK(k=k, bitsize=self.bitsize, mod=self.mod): self.bitsize, + And(): self.bitsize, + And().adjoint(): self.bitsize, + } def on_classical_vals( self, ctrl: 'ClassicalValT', x: 'ClassicalValT', y: 'ClassicalValT' @@ -344,6 +435,24 @@ def wire_symbol( raise ValueError(f"Unknown register {reg}") +@bloq_example(generalizer=ignore_split_join) +def _ctrl_scale_mod_add() -> CtrlScaleModAdd: + n, m, k = sympy.symbols('n m k') + ctrl_scale_mod_add = CtrlScaleModAdd(bitsize=n, mod=m, k=k) + return ctrl_scale_mod_add + + +@bloq_example +def _ctrl_scale_mod_add_small() -> CtrlScaleModAdd: + ctrl_scale_mod_add_small = CtrlScaleModAdd(bitsize=4, mod=7, k=1) + return ctrl_scale_mod_add_small + + +_CTRL_SCALE_MOD_ADD_DOC = BloqDocSpec( + bloq_cls=CtrlScaleModAdd, examples=[_ctrl_scale_mod_add, _ctrl_scale_mod_add_small] +) + + @frozen class CModAdd(Bloq): r"""Controlled Modular Addition. @@ -390,6 +499,8 @@ def on_classical_vals( def build_composite_bloq( self, bb: 'BloqBuilder', ctrl, x: Soquet, y: Soquet ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.dtype.bitsize): + raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}') y_arr = bb.split(y) ancilla = bb.allocate(1) x = bb.add(Cast(self.dtype, QUInt(self.dtype.bitsize)), reg=x) diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py index 5b37b355e..455670bb5 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py @@ -18,10 +18,21 @@ import pytest import sympy +import qualtran.testing as qlt_testing from qualtran import QMontgomeryUInt, QUInt from qualtran.bloqs.arithmetic import Add from qualtran.bloqs.mod_arithmetic import CModAdd, CModAddK, CtrlScaleModAdd, ModAdd, ModAddK -from qualtran.bloqs.mod_arithmetic.mod_addition import _cmodadd_example +from qualtran.bloqs.mod_arithmetic.mod_addition import ( + _cmod_add_k, + _cmod_add_k_small, + _cmodadd_example, + _ctrl_scale_mod_add, + _ctrl_scale_mod_add_small, + _mod_add, + _mod_add_k, + _mod_add_k_large, + _mod_add_k_small, +) from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join @@ -33,6 +44,29 @@ ) +@pytest.mark.parametrize( + "bloq", + [ + _mod_add, + _mod_add_k, + _mod_add_k_small, + _mod_add_k_large, + _cmod_add_k, + _cmod_add_k_small, + _ctrl_scale_mod_add, + _ctrl_scale_mod_add_small, + _cmodadd_example, + ], +) +def test_examples(bloq_autotester, bloq): + bloq_autotester(bloq) + + +@pytest.mark.notebook +def test_notebook(): + execute_notebook('mod_addition') + + def identity_map(n: int): """Returns a dict of size `2**n` mapping each integer in range [0, 2**n) to itself.""" return {i: i for i in range(2**n)} @@ -61,22 +95,6 @@ def test_add_mod_n_gate_counts(bitsize): assert bloq.t_complexity() == add_constant_mod_n_ref_t_complexity_(bloq) -def test_ctrl_scale_mod_add(): - bloq = CtrlScaleModAdd(k=123, mod=13 * 17, bitsize=8) - - counts = bloq.bloq_counts() - ((bloq, n),) = counts.items() - assert n == 8 - - -def test_ctrl_mod_add_k(): - bloq = CModAddK(k=123, mod=13 * 17, bitsize=8) - - counts = bloq.bloq_counts() - ((bloq, n),) = counts.items() - assert n == 5 - - @pytest.mark.parametrize('bitsize,p', [(1, 1), (2, 3), (5, 8)]) def test_mod_add_valid_decomp(bitsize, p): bloq = ModAdd(bitsize=bitsize, mod=p) @@ -131,6 +149,26 @@ def test_classical_action_cmodadd_fast(control, bitsize): assert b.call_classically(ctrl=c, x=x, y=y) == cb.call_classically(ctrl=c, x=x, y=y) +@pytest.mark.slow +@pytest.mark.parametrize( + ['prime', 'bitsize', 'k'], + [(p, n, k) for p in (13, 17, 23) for n in range(p.bit_length(), 8) for k in range(1, p)], +) +def test_cscalemodadd_classical_action(bitsize, prime, k): + b = CtrlScaleModAdd(bitsize=bitsize, mod=prime, k=k) + qlt_testing.assert_consistent_classical_action(b, ctrl=(0, 1), x=range(prime), y=range(prime)) + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['prime', 'bitsize', 'k'], + [(p, n, k) for p in (13, 17, 23) for n in range(p.bit_length(), 8) for k in range(1, p)], +) +def test_cmodaddk_classical_action(bitsize, prime, k): + b = CModAddK(bitsize=bitsize, mod=prime, k=k) + qlt_testing.assert_consistent_classical_action(b, ctrl=(0, 1), x=range(prime)) + + @pytest.mark.parametrize('control', range(2)) @pytest.mark.parametrize( ['prime', 'bitsize'], @@ -154,15 +192,6 @@ def test_cmodadd_cost(control, dtype): assert cost.total_t_count() == 4 * n_toffolis -def test_cmodadd_example(bloq_autotester): - bloq_autotester(_cmodadd_example) - - -@pytest.mark.notebook -def test_notebook(): - execute_notebook('mod_addition') - - def test_cmod_add_complexity_vs_ref(): n, k = sympy.symbols('n k', integer=True, positive=True) bloq = CModAdd(QUInt(n), mod=k) diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index dad2907ea..f53ea457b 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -328,8 +328,10 @@ "qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint": qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint, "qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM": qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM, "qualtran.bloqs.mod_arithmetic.CModAddK": qualtran.bloqs.mod_arithmetic.CModAddK, - "qualtran.bloqs.mod_arithmetic.mod_addition.CModAdd": qualtran.bloqs.mod_arithmetic.CModAdd, + "qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd": qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd, + "qualtran.bloqs.mod_arithmetic.mod_addition.CModAdd": qualtran.bloqs.mod_arithmetic.mod_addition.CModAdd, "qualtran.bloqs.mod_arithmetic.mod_addition.ModAddK": qualtran.bloqs.mod_arithmetic.mod_addition.ModAddK, + "qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK": qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK, "qualtran.bloqs.mod_arithmetic.mod_addition.CtrlScaleModAdd": qualtran.bloqs.mod_arithmetic.CtrlScaleModAdd, "qualtran.bloqs.mod_arithmetic.ModAdd": qualtran.bloqs.mod_arithmetic.ModAdd, "qualtran.bloqs.mod_arithmetic.ModSub": qualtran.bloqs.mod_arithmetic.ModSub, From 3117a84efb57a1222f0a865fd29f3043c6aa704b Mon Sep 17 00:00:00 2001 From: Anurudh Peduri <7265746+anurudhp@users.noreply.github.com> Date: Thu, 17 Oct 2024 16:45:20 -0700 Subject: [PATCH 5/8] Default `Bloq.get_ctrl_system`: Use `And` ladder to reduce multiple controls to single control (#1456) * default ctrl system: use And ladder to reduce controls to a single control bit * fix tests * test: `XGate().controlled(...)` --------- Co-authored-by: Matthew Harrigan --- qualtran/_infra/bloq.py | 7 +++- qualtran/_infra/gate_with_registers_test.py | 4 ++- qualtran/bloqs/basic_gates/identity_test.py | 40 +-------------------- qualtran/bloqs/basic_gates/x_basis_test.py | 14 ++++++++ qualtran/bloqs/qsp/generalized_qsp_test.py | 2 +- 5 files changed, 25 insertions(+), 42 deletions(-) diff --git a/qualtran/_infra/bloq.py b/qualtran/_infra/bloq.py index 7ccd20037..ec5e2d735 100644 --- a/qualtran/_infra/bloq.py +++ b/qualtran/_infra/bloq.py @@ -394,7 +394,12 @@ def _my_add_controlled( add_controlled: A function with the signature documented above that the system can use to automatically wire up the new control registers. """ - from qualtran import Controlled + from qualtran import Controlled, CtrlSpec + from qualtran.bloqs.mcmt.controlled_via_and import ControlledViaAnd + + if ctrl_spec != CtrlSpec(): + # reduce controls to a single qubit + return ControlledViaAnd.make_ctrl_system(self, ctrl_spec=ctrl_spec) return Controlled.make_ctrl_system(self, ctrl_spec=ctrl_spec) diff --git a/qualtran/_infra/gate_with_registers_test.py b/qualtran/_infra/gate_with_registers_test.py index a3735a889..31fd6f863 100644 --- a/qualtran/_infra/gate_with_registers_test.py +++ b/qualtran/_infra/gate_with_registers_test.py @@ -151,8 +151,10 @@ def test_gate_with_registers_decompose_from_context_auto_generated(): def test_non_unitary_controlled(): + from qualtran.bloqs.mcmt.controlled_via_and import ControlledViaAnd + bloq = BloqWithDecompose() - assert bloq.controlled(control_values=[0]) == Controlled(bloq, CtrlSpec(cvs=0)) + assert bloq.controlled(control_values=[0]) == ControlledViaAnd(bloq, CtrlSpec(cvs=0)) @pytest.mark.notebook diff --git a/qualtran/bloqs/basic_gates/identity_test.py b/qualtran/bloqs/basic_gates/identity_test.py index ea3a4255e..3a45b8980 100644 --- a/qualtran/bloqs/basic_gates/identity_test.py +++ b/qualtran/bloqs/basic_gates/identity_test.py @@ -15,16 +15,14 @@ import numpy as np import pytest import sympy -from attrs import frozen -from qualtran import Bloq, BloqBuilder, CtrlSpec, QInt, QUInt, Signature, Soquet, SoquetT +from qualtran import BloqBuilder from qualtran.bloqs.basic_gates import OneState from qualtran.bloqs.basic_gates.identity import _identity, _identity_n, _identity_symb, Identity from qualtran.simulation.classical_sim import ( format_classical_truth_table, get_classical_truth_table, ) -from qualtran.symbolics import SymbolicInt from qualtran.testing import execute_notebook @@ -95,42 +93,6 @@ def test_identity_controlled(): assert Identity(n).controlled() == Identity(n + 1) -@frozen -class TestIdentityDecomposition(Bloq): - """helper to test Identity.get_ctrl_system""" - - bitsize: SymbolicInt - - @property - def signature(self) -> 'Signature': - return Signature.build(q=self.bitsize) - - def build_composite_bloq(self, bb: 'BloqBuilder', q: Soquet) -> dict[str, 'SoquetT']: - q = bb.add(Identity(self.bitsize), q=q) - q = bb.add(Identity(self.bitsize), q=q) - return {'q': q} - - -@pytest.mark.parametrize("n", [4, sympy.Symbol("n")]) -@pytest.mark.parametrize( - "ctrl_spec", - [ - CtrlSpec(cvs=(np.array([1, 0, 1]),)), - CtrlSpec(qdtypes=(QUInt(3), QInt(3)), cvs=(np.array(0b010), np.array(0b001))), - ], -) -def test_identity_get_ctrl_system(n: SymbolicInt, ctrl_spec: CtrlSpec): - m = ctrl_spec.num_qubits - - bloq = TestIdentityDecomposition(n) - ctrl_bloq = bloq.controlled(ctrl_spec) - - _ = ctrl_bloq.decompose_bloq() - - _, sigma = ctrl_bloq.call_graph() - assert sigma == {Identity(n + m): 2} - - @pytest.mark.notebook def test_notebook(): execute_notebook('identity') diff --git a/qualtran/bloqs/basic_gates/x_basis_test.py b/qualtran/bloqs/basic_gates/x_basis_test.py index 58851757e..d6b3b670d 100644 --- a/qualtran/bloqs/basic_gates/x_basis_test.py +++ b/qualtran/bloqs/basic_gates/x_basis_test.py @@ -92,3 +92,17 @@ def test_x_truth_table(): 0 -> 1 1 -> 0""" ) + + +def test_controlled_x(): + from qualtran import CtrlSpec, QUInt + from qualtran.bloqs.basic_gates import CNOT + from qualtran.bloqs.mcmt import And + + def _keep_and(b): + return isinstance(b, And) + + n = 8 + bloq = XGate().controlled(CtrlSpec(qdtypes=QUInt(n), cvs=1)) + _, sigma = bloq.call_graph(keep=_keep_and) + assert sigma == {And(): n - 1, CNOT(): 1, And().adjoint(): n - 1, XGate(): 4 * (n - 1)} diff --git a/qualtran/bloqs/qsp/generalized_qsp_test.py b/qualtran/bloqs/qsp/generalized_qsp_test.py index c64899b19..c925da2a8 100644 --- a/qualtran/bloqs/qsp/generalized_qsp_test.py +++ b/qualtran/bloqs/qsp/generalized_qsp_test.py @@ -201,7 +201,7 @@ def catch_rotations(bloq: Bloq) -> Bloq: expected_sigma: dict[Bloq, int] = {arbitrary_rotation: degree + 1} if degree > negative_power: - expected_sigma[Controlled(U, CtrlSpec(cvs=0))] = degree - negative_power + expected_sigma[U.controlled(ctrl_spec=CtrlSpec(cvs=0))] = degree - negative_power if negative_power > 0: expected_sigma[Controlled(U.adjoint(), CtrlSpec())] = min(degree, negative_power) if negative_power > degree: From c227032a41292ad9f8f95619d063fbbcaff9ad47 Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Fri, 18 Oct 2024 13:05:44 -0700 Subject: [PATCH 6/8] Add RSA Phase Estimate Bloq and Move ModExp to rsa/ subdirectory (#1428) * Add rsa files - needs a lot of work just stashing it for now * Made some structure changes RSA * Rework rsa mod exp bloqs to work in a rsa phase estimation circuit * Fix mypy issues * Better symbolic messages * Refactor RSA to have a phase estimation circuit and a classical simulable modular exponentiation circuit * Fix notebook specs merge conflict * remove unecessary x values for classical simulation test * fix nits * Better documentation init * Fix broken link * Fix random issue and cirq interop import of modexp * Fix broken import msft interop * Fix another dependency of ModExp --------- Co-authored-by: Matthew Harrigan --- .../qualtran_dev_tools/notebook_specs.py | 12 +- docs/bloqs/index.rst | 2 +- qualtran/_infra/Bloqs-Tutorial.ipynb | 2 +- qualtran/bloqs/factoring/__init__.py | 4 +- .../_ecc_shims.py => _factoring_shims.py} | 3 +- .../factoring/ecc/ec_phase_estimate_r.py | 2 +- qualtran/bloqs/factoring/mod_add_test.py | 0 qualtran/bloqs/factoring/mod_exp.html | 23 -- qualtran/bloqs/factoring/mod_exp.ipynb | 209 ----------- qualtran/bloqs/factoring/rsa/__init__.py | 37 ++ .../{ => rsa}/factoring-via-modexp.ipynb | 5 +- qualtran/bloqs/factoring/rsa/rsa.ipynb | 331 ++++++++++++++++++ .../{mod_exp.py => rsa/rsa_mod_exp.py} | 50 ++- .../rsa_mod_exp_test.py} | 28 +- .../bloqs/factoring/rsa/rsa_phase_estimate.py | 150 ++++++++ .../factoring/rsa/rsa_phase_estimate_test.py | 27 ++ qualtran/cirq_interop/_bloq_to_cirq_test.py | 2 +- qualtran/cirq_interop/cirq_interop.ipynb | 2 +- qualtran/serialization/bloq_test.py | 2 +- qualtran/serialization/resolver_dict.py | 7 +- .../msft_resource_estimator_interop.ipynb | 2 +- 21 files changed, 611 insertions(+), 289 deletions(-) rename qualtran/bloqs/factoring/{ecc/_ecc_shims.py => _factoring_shims.py} (95%) delete mode 100644 qualtran/bloqs/factoring/mod_add_test.py delete mode 100644 qualtran/bloqs/factoring/mod_exp.html delete mode 100644 qualtran/bloqs/factoring/mod_exp.ipynb create mode 100644 qualtran/bloqs/factoring/rsa/__init__.py rename qualtran/bloqs/factoring/{ => rsa}/factoring-via-modexp.ipynb (98%) create mode 100644 qualtran/bloqs/factoring/rsa/rsa.ipynb rename qualtran/bloqs/factoring/{mod_exp.py => rsa/rsa_mod_exp.py} (78%) rename qualtran/bloqs/factoring/{mod_exp_test.py => rsa/rsa_mod_exp_test.py} (83%) create mode 100644 qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py create mode 100644 qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index 1b5f219b4..7c66311fe 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -83,7 +83,7 @@ import qualtran.bloqs.data_loading.qrom_base import qualtran.bloqs.data_loading.select_swap_qrom import qualtran.bloqs.factoring.ecc -import qualtran.bloqs.factoring.mod_exp +import qualtran.bloqs.factoring.rsa import qualtran.bloqs.gf_arithmetic.gf2_add_k import qualtran.bloqs.gf_arithmetic.gf2_addition import qualtran.bloqs.gf_arithmetic.gf2_inverse @@ -517,10 +517,12 @@ ], ), NotebookSpecV2( - title='Modular Exponentiation', - module=qualtran.bloqs.factoring.mod_exp, - bloq_specs=[qualtran.bloqs.factoring.mod_exp._MODEXP_DOC], - directory=f'{SOURCE_DIR}/bloqs/factoring', + title='Factoring RSA', + module=qualtran.bloqs.factoring.rsa, + bloq_specs=[ + qualtran.bloqs.factoring.rsa.rsa_phase_estimate._RSA_PE_BLOQ_DOC, + qualtran.bloqs.factoring.rsa.rsa_mod_exp._RSA_MODEXP_DOC, + ], ), NotebookSpecV2( title='Elliptic Curve Addition', diff --git a/docs/bloqs/index.rst b/docs/bloqs/index.rst index 0906dd71b..16c591baa 100644 --- a/docs/bloqs/index.rst +++ b/docs/bloqs/index.rst @@ -83,7 +83,7 @@ Bloqs Library mod_arithmetic/mod_addition.ipynb mod_arithmetic/mod_subtraction.ipynb mod_arithmetic/mod_multiplication.ipynb - factoring/mod_exp.ipynb + factoring/rsa/rsa.ipynb factoring/ecc/ec_add.ipynb factoring/ecc/ecc.ipynb diff --git a/qualtran/_infra/Bloqs-Tutorial.ipynb b/qualtran/_infra/Bloqs-Tutorial.ipynb index 219dc48f4..ca64c4755 100644 --- a/qualtran/_infra/Bloqs-Tutorial.ipynb +++ b/qualtran/_infra/Bloqs-Tutorial.ipynb @@ -914,7 +914,7 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.bloqs.factoring import ModExp\n", + "from qualtran.bloqs.factoring.rsa import ModExp\n", "\n", "mod_exp = ModExp(base=8, mod=13*17, exp_bitsize=3, x_bitsize=1024)\n", "show_bloq(mod_exp)" diff --git a/qualtran/bloqs/factoring/__init__.py b/qualtran/bloqs/factoring/__init__.py index 59a92dad8..15780de77 100644 --- a/qualtran/bloqs/factoring/__init__.py +++ b/qualtran/bloqs/factoring/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,5 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .mod_exp import ModExp diff --git a/qualtran/bloqs/factoring/ecc/_ecc_shims.py b/qualtran/bloqs/factoring/_factoring_shims.py similarity index 95% rename from qualtran/bloqs/factoring/ecc/_ecc_shims.py rename to qualtran/bloqs/factoring/_factoring_shims.py index 4b602e73f..896e103b8 100644 --- a/qualtran/bloqs/factoring/ecc/_ecc_shims.py +++ b/qualtran/bloqs/factoring/_factoring_shims.py @@ -19,11 +19,12 @@ from qualtran import Bloq, CompositeBloq, DecomposeTypeError, QBit, Register, Side, Signature from qualtran.drawing import RarrowTextBox, Text, WireSymbol +from qualtran.symbolics import SymbolicInt @frozen class MeasureQFT(Bloq): - n: int + n: 'SymbolicInt' @cached_property def signature(self) -> 'Signature': diff --git a/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py b/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py index f4ddb9d19..ffe03f2f9 100644 --- a/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py +++ b/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py @@ -33,7 +33,7 @@ from qualtran.bloqs.basic_gates import PlusState from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator -from ._ecc_shims import MeasureQFT +from .._factoring_shims import MeasureQFT from .ec_add_r import ECAddR from .ec_point import ECPoint diff --git a/qualtran/bloqs/factoring/mod_add_test.py b/qualtran/bloqs/factoring/mod_add_test.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/qualtran/bloqs/factoring/mod_exp.html b/qualtran/bloqs/factoring/mod_exp.html deleted file mode 100644 index 9d015ba4c..000000000 --- a/qualtran/bloqs/factoring/mod_exp.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - shor - - - - - - -
- -
- - - - \ No newline at end of file diff --git a/qualtran/bloqs/factoring/mod_exp.ipynb b/qualtran/bloqs/factoring/mod_exp.ipynb deleted file mode 100644 index 77c87aa15..000000000 --- a/qualtran/bloqs/factoring/mod_exp.ipynb +++ /dev/null @@ -1,209 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "5947b041", - "metadata": { - "cq.autogen": "title_cell" - }, - "source": [ - "# Modular Exponentiation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ce6b0d51", - "metadata": { - "cq.autogen": "top_imports" - }, - "outputs": [], - "source": [ - "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", - "from qualtran import QBit, QInt, QUInt, QAny\n", - "from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n", - "from typing import *\n", - "import numpy as np\n", - "import sympy\n", - "import cirq" - ] - }, - { - "cell_type": "markdown", - "id": "2f68374c", - "metadata": { - "cq.autogen": "ModExp.bloq_doc.md" - }, - "source": [ - "## `ModExp`\n", - "Perform $b^e \\mod{m}$ for constant `base` $b$, `mod` $m$, and quantum `exponent` $e$.\n", - "\n", - "Modular exponentiation is the main computational primitive for quantum factoring algorithms.\n", - "We follow [GE2019]'s \"reference implementation\" for factoring. See `ModExp.make_for_shor`\n", - "to set the class attributes for a factoring run.\n", - "\n", - "This bloq decomposes into controlled modular exponentiation for each exponent bit.\n", - "\n", - "#### Parameters\n", - " - `base`: The integer base of the exponentiation\n", - " - `mod`: The integer modulus\n", - " - `exp_bitsize`: The size of the `exponent` thru-register\n", - " - `x_bitsize`: The size of the `x` right-register \n", - "\n", - "#### Registers\n", - " - `exponent`: The exponent\n", - " - `x [right]`: The output register containing the result of the exponentiation \n", - "\n", - "#### References\n", - " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Gidney and Ekerå. 2019.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4a1b8a2a", - "metadata": { - "cq.autogen": "ModExp.bloq_doc.py" - }, - "outputs": [], - "source": [ - "from qualtran.bloqs.factoring import ModExp" - ] - }, - { - "cell_type": "markdown", - "id": "902ec939", - "metadata": { - "cq.autogen": "ModExp.example_instances.md" - }, - "source": [ - "### Example Instances" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c5aac47", - "metadata": { - "cq.autogen": "ModExp.modexp_small" - }, - "outputs": [], - "source": [ - "modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41d94c81", - "metadata": { - "cq.autogen": "ModExp.modexp" - }, - "outputs": [], - "source": [ - "modexp = ModExp.make_for_shor(big_n=13 * 17, g=9)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c0c6c06", - "metadata": { - "cq.autogen": "ModExp.modexp_symb" - }, - "outputs": [], - "source": [ - "g, N, n_e, n_x = sympy.symbols('g N n_e, n_x')\n", - "modexp_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x)" - ] - }, - { - "cell_type": "markdown", - "id": "a55a51df", - "metadata": { - "cq.autogen": "ModExp.graphical_signature.md" - }, - "source": [ - "#### Graphical Signature" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ce953a6d", - "metadata": { - "cq.autogen": "ModExp.graphical_signature.py" - }, - "outputs": [], - "source": [ - "from qualtran.drawing import show_bloqs\n", - "show_bloqs([modexp_symb, modexp_small, modexp],\n", - " ['`modexp_symb`', '`modexp_small`', '`modexp`'])" - ] - }, - { - "cell_type": "markdown", - "id": "83ba4b54", - "metadata": {}, - "source": [ - "### Decomposition" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a0e53334", - "metadata": {}, - "outputs": [], - "source": [ - "show_bloq(modexp_small.decompose_bloq())" - ] - }, - { - "cell_type": "markdown", - "id": "8662fa01", - "metadata": { - "cq.autogen": "ModExp.call_graph.md" - }, - "source": [ - "### Call Graph" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a7d9f5a", - "metadata": { - "cq.autogen": "ModExp.call_graph.py" - }, - "outputs": [], - "source": [ - "from qualtran.resource_counting.generalizers import ignore_split_join\n", - "modexp_symb_g, modexp_symb_sigma = modexp_symb.call_graph(max_depth=1, generalizer=ignore_split_join)\n", - "show_call_graph(modexp_symb_g)\n", - "show_counts_sigma(modexp_symb_sigma)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/qualtran/bloqs/factoring/rsa/__init__.py b/qualtran/bloqs/factoring/rsa/__init__.py new file mode 100644 index 000000000..d58509847 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# isort:skip_file + +r"""Bloqs for breaking RSA cryptography systems via integer factorization. + +RSA cryptography is a form of public key cryptography based on the difficulty of +factoring the product of two large prime numbers. + +Using RSA, the cryptographic scheme chooses two large prime numbers p, q, their product n, +λ(n) = lcm(p - 1, q - 1) where λ is Carmichael's totient function, an integer e such that +1 < e < λ(n), and finally d as d ≡ e^-1 (mod λ(n)). The public key consists of the modulus n and +the public (or encryption) exponent e. The private key consists of the private (or decryption) +exponent d, which must be kept secret. p, q, and λ(n) must also be kept secret because they can be +used to calculate d. + +Using Shor's algorithm for factoring, we can find p and q (the factors of n) in polynomial time +with a quantum algorithm. + +References: + [RSA (cryptosystem)](https://en.wikipedia.org/wiki/RSA_(cryptosystem)). +""" + +from .rsa_phase_estimate import RSAPhaseEstimate +from .rsa_mod_exp import ModExp diff --git a/qualtran/bloqs/factoring/factoring-via-modexp.ipynb b/qualtran/bloqs/factoring/rsa/factoring-via-modexp.ipynb similarity index 98% rename from qualtran/bloqs/factoring/factoring-via-modexp.ipynb rename to qualtran/bloqs/factoring/rsa/factoring-via-modexp.ipynb index dfea45b9e..08ff940b8 100644 --- a/qualtran/bloqs/factoring/factoring-via-modexp.ipynb +++ b/qualtran/bloqs/factoring/rsa/factoring-via-modexp.ipynb @@ -180,7 +180,7 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.bloqs.factoring.mod_exp import ModExp\n", + "from qualtran.bloqs.factoring.rsa.rsa_mod_exp import ModExp\n", "from qualtran.drawing import show_bloq\n", "\n", "mod_exp = ModExp(base=g, mod=N, exp_bitsize=32, x_bitsize=32)\n", @@ -205,6 +205,7 @@ "metadata": {}, "outputs": [], "source": [ + "from qualtran import QUInt\n", "for e in range(20):\n", " ref = (g ** e) % N\n", " _, bloq_eval = mod_exp.call_classically(exponent=e)\n", @@ -231,7 +232,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/qualtran/bloqs/factoring/rsa/rsa.ipynb b/qualtran/bloqs/factoring/rsa/rsa.ipynb new file mode 100644 index 000000000..5415a4d53 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/rsa.ipynb @@ -0,0 +1,331 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "48ce60bb", + "metadata": { + "cq.autogen": "title_cell" + }, + "source": [ + "# Factoring RSA\n", + "\n", + "Bloqs for breaking RSA cryptography systems via integer factorization.\n", + "\n", + "RSA cryptography is a form of public key cryptography based on the difficulty of\n", + "factoring the product of two large prime numbers.\n", + "\n", + "Using RSA, the cryptographic scheme chooses two large prime numbers p, q, their product n,\n", + "λ(n) = lcm(p - 1, q - 1) where λ is Carmichael's totient function, an integer e such that\n", + "1 < e < λ(n), and finally d as d ≡ e^-1 (mod λ(n)). The public key consists of the modulus n and\n", + "the public (or encryption) exponent e. The private key consists of the private (or decryption)\n", + "exponent d, which must be kept secret. p, q, and λ(n) must also be kept secret because they can be\n", + "used to calculate d.\n", + "\n", + "Using Shor's algorithm for factoring, we can find p and q (the factors of n) in polynomial time\n", + "with a quantum algorithm.\n", + "\n", + "References:\n", + " [RSA (cryptosystem)](https://en.wikipedia.org/wiki/RSA_(cryptosystem))." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d12766dd", + "metadata": { + "cq.autogen": "top_imports" + }, + "outputs": [], + "source": [ + "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", + "from qualtran import QBit, QInt, QUInt, QAny\n", + "from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n", + "from typing import *\n", + "import numpy as np\n", + "import sympy\n", + "import cirq" + ] + }, + { + "cell_type": "markdown", + "id": "881834c3", + "metadata": { + "cq.autogen": "ModExp.bloq_doc.md" + }, + "source": [ + "## `ModExp`\n", + "Perform $b^e \\mod{m}$ for constant `base` $b$, `mod` $m$, and quantum `exponent` $e$.\n", + "\n", + "Modular exponentiation is the main computational primitive for quantum factoring algorithms.\n", + "We follow [GE2019]'s \"reference implementation\" for factoring. See `ModExp.make_for_shor`\n", + "to set the class attributes for a factoring run.\n", + "\n", + "This bloq decomposes into controlled modular exponentiation for each exponent bit.\n", + "\n", + "#### Parameters\n", + " - `base`: The integer base of the exponentiation\n", + " - `mod`: The integer modulus\n", + " - `exp_bitsize`: The size of the `exponent` thru-register\n", + " - `x_bitsize`: The size of the `x` right-register \n", + "\n", + "#### Registers\n", + " - `exponent`: The exponent\n", + " - `x [right]`: The output register containing the result of the exponentiation \n", + "\n", + "#### References\n", + " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Gidney and Ekerå. 2019.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61062a3d", + "metadata": { + "cq.autogen": "ModExp.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.factoring.rsa import ModExp" + ] + }, + { + "cell_type": "markdown", + "id": "6963ad94", + "metadata": { + "cq.autogen": "ModExp.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56d35af9", + "metadata": { + "cq.autogen": "ModExp.modexp_symb" + }, + "outputs": [], + "source": [ + "g, N, n_e, n_x = sympy.symbols('g N n_e, n_x')\n", + "modexp_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d3d90b4", + "metadata": { + "cq.autogen": "ModExp.modexp_small" + }, + "outputs": [], + "source": [ + "modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3a004d3", + "metadata": { + "cq.autogen": "ModExp.modexp" + }, + "outputs": [], + "source": [ + "modexp = ModExp.make_for_shor(big_n=13 * 17, g=9)" + ] + }, + { + "cell_type": "markdown", + "id": "39422b45", + "metadata": { + "cq.autogen": "ModExp.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83a891e2", + "metadata": { + "cq.autogen": "ModExp.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([modexp_small, modexp, modexp_symb],\n", + " ['`modexp_small`', '`modexp`', '`modexp_symb`'])" + ] + }, + { + "cell_type": "markdown", + "id": "9c271392", + "metadata": { + "cq.autogen": "ModExp.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6d55278", + "metadata": { + "cq.autogen": "ModExp.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "modexp_small_g, modexp_small_sigma = modexp_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(modexp_small_g)\n", + "show_counts_sigma(modexp_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "2603abbd", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.bloq_doc.md" + }, + "source": [ + "## `RSAPhaseEstimate`\n", + "Perform a single phase estimation of the decomposition of Modular Exponentiation for the\n", + "given base.\n", + "\n", + "The constructor requires a pre-set base, see the make_for_shor factory method for picking a\n", + "random, valid base\n", + "\n", + "#### Parameters\n", + " - `n`: The bitsize of the modulus N.\n", + " - `mod`: The modulus N; a part of the public key for RSA.\n", + " - `base`: A base for modular exponentiation. \n", + "\n", + "#### References\n", + " - [Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095). Beauregard. 2003. Fig 1.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b838c20", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.factoring.rsa import RSAPhaseEstimate" + ] + }, + { + "cell_type": "markdown", + "id": "20426b03", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f696c5fd", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.rsa_pe" + }, + "outputs": [], + "source": [ + "n, p, g = sympy.symbols('n p g')\n", + "rsa_pe = RSAPhaseEstimate(n=n, mod=p, base=g)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b16e84a5", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.rsa_pe_small" + }, + "outputs": [], + "source": [ + "rsa_pe_small = RSAPhaseEstimate.make_for_shor(big_n=13 * 17, g=9)" + ] + }, + { + "cell_type": "markdown", + "id": "0c0078d5", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b493a30", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([rsa_pe_small, rsa_pe],\n", + " ['`rsa_pe_small`', '`rsa_pe`'])" + ] + }, + { + "cell_type": "markdown", + "id": "441028fa", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f30cb55", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "rsa_pe_small_g, rsa_pe_small_sigma = rsa_pe_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(rsa_pe_small_g)\n", + "show_counts_sigma(rsa_pe_small_sigma)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d72bc625", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.rsa_pe_shor" + }, + "outputs": [], + "source": [ + "rsa_pe_shor = RSAPhaseEstimate.make_for_shor(big_n=13 * 17, g=9)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qualtran/bloqs/factoring/mod_exp.py b/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py similarity index 78% rename from qualtran/bloqs/factoring/mod_exp.py rename to qualtran/bloqs/factoring/rsa/rsa_mod_exp.py index 129711797..fda33f5fc 100644 --- a/qualtran/bloqs/factoring/mod_exp.py +++ b/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import random from functools import cached_property from typing import cast, Dict, Optional, Tuple, Union import attrs +import numpy as np import sympy from attrs import frozen @@ -28,17 +28,19 @@ DecomposeTypeError, QUInt, Register, - Side, Signature, Soquet, SoquetT, ) -from qualtran.bloqs.basic_gates import IntState +from qualtran._infra.registers import Side +from qualtran.bloqs.basic_gates.z_basis import IntState from qualtran.bloqs.mod_arithmetic import CModMulK from qualtran.drawing import Text, WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.resource_counting.generalizers import ignore_split_join +from qualtran.simulation.classical_sim import ClassicalValT from qualtran.symbolics import is_symbolic +from qualtran.symbolics.types import SymbolicInt @frozen @@ -66,10 +68,10 @@ class ModExp(Bloq): Gidney and Ekerå. 2019. """ - base: Union[int, sympy.Expr] - mod: Union[int, sympy.Expr] - exp_bitsize: Union[int, sympy.Expr] - x_bitsize: Union[int, sympy.Expr] + base: 'SymbolicInt' + mod: 'SymbolicInt' + exp_bitsize: 'SymbolicInt' + x_bitsize: 'SymbolicInt' def __attrs_post_init__(self): if not is_symbolic(self.base, self.mod): @@ -85,33 +87,47 @@ def signature(self) -> 'Signature': ) @classmethod - def make_for_shor(cls, big_n: int, g: Optional[int] = None): + def make_for_shor( + cls, + big_n: 'SymbolicInt', + g: Optional['SymbolicInt'] = None, + rs: Optional[np.random.RandomState] = None, + ): """Factory method that sets up the modular exponentiation for a factoring run. Args: big_n: The large composite number N. Used to set `mod`. Its bitsize is used to set `x_bitsize` and `exp_bitsize`. g: Optional base of the exponentiation. If `None`, we pick a random base. + rs: Optional random state which can be seeded to make base generation deterministic. """ - if isinstance(big_n, sympy.Expr): + if is_symbolic(big_n): little_n = sympy.ceiling(sympy.log(big_n, 2)) else: little_n = int(math.ceil(math.log2(big_n))) if g is None: - g = random.randint(2, big_n) + if is_symbolic(big_n): + g = sympy.symbols('g') + else: + if rs is None: + rs = np.random.RandomState() + while True: + g = rs.randint(2, int(big_n)) + if math.gcd(g, int(big_n)) == 1: + break return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n) - def _CtrlModMul(self, k: Union[int, sympy.Expr]): + def _CtrlModMul(self, k: 'SymbolicInt'): """Helper method to return a `CModMulK` with attributes forwarded.""" return CModMulK(QUInt(self.x_bitsize), k=k, mod=self.mod) def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[str, 'SoquetT']: - if isinstance(self.exp_bitsize, sympy.Expr): - raise DecomposeTypeError("`exp_bitsize` must be a concrete value.") + if is_symbolic(self.exp_bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `exp_bitsize`.") + # https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method x = bb.add(IntState(val=1, bitsize=self.x_bitsize)) exponent = bb.split(exponent) - # https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method base = self.base % self.mod for j in range(self.exp_bitsize - 1, 0 - 1, -1): exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x) @@ -121,9 +137,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': k = ssa.new_symbol('k') - return {IntState(val=1, bitsize=self.x_bitsize): 1, self._CtrlModMul(k=k): self.exp_bitsize} + return {self._CtrlModMul(k=k): self.exp_bitsize, IntState(val=1, bitsize=self.x_bitsize): 1} - def on_classical_vals(self, exponent: int): + def on_classical_vals(self, exponent) -> Dict[str, Union['ClassicalValT', sympy.Expr]]: return {'exponent': exponent, 'x': (self.base**exponent) % self.mod} def wire_symbol( @@ -163,4 +179,4 @@ def _modexp_symb() -> ModExp: return modexp_symb -_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_symb, _modexp_small, _modexp)) +_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb)) diff --git a/qualtran/bloqs/factoring/mod_exp_test.py b/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py similarity index 83% rename from qualtran/bloqs/factoring/mod_exp_test.py rename to qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py index d0b48c4af..e4cabd8c4 100644 --- a/qualtran/bloqs/factoring/mod_exp_test.py +++ b/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py @@ -21,7 +21,7 @@ from qualtran import Bloq from qualtran.bloqs.bookkeeping import Join, Split -from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_symb, ModExp +from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp from qualtran.bloqs.mod_arithmetic import CModMulK from qualtran.drawing import Text from qualtran.resource_counting import SympySymbolAllocator @@ -40,17 +40,13 @@ def test_mod_exp_consistent_classical(): # Choose an exponent in a range. Set exp_bitsize=ne bit enough to fit. exponent = rs.randint(1, 2**n) - ne = 2 * n - # Choose a base smaller than mod. - base = rs.randint(1, mod) - while np.gcd(base, mod) != 1: - base = rs.randint(1, mod) - - bloq = ModExp(base=base, exp_bitsize=ne, x_bitsize=n, mod=mod) + bloq = ModExp.make_for_shor(big_n=mod, rs=rs) ret1 = bloq.call_classically(exponent=exponent) ret2 = bloq.decompose_bloq().call_classically(exponent=exponent) - assert ret1 == ret2 + assert len(ret1) == len(ret2) + for i in range(len(ret1)): + np.testing.assert_array_equal(ret1[i], ret2[i]) def test_modexp_symb_manual(): @@ -93,19 +89,11 @@ def test_mod_exp_t_complexity(): assert tcomp.t > 0 -def test_modexp(bloq_autotester): - bloq_autotester(_modexp) - - -def test_modexp_symb(bloq_autotester): - bloq_autotester(_modexp_symb) +@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small]) +def test_modexp(bloq_autotester, bloq): + bloq_autotester(bloq) @pytest.mark.notebook def test_intro_notebook(): execute_notebook('factoring-via-modexp') - - -@pytest.mark.notebook -def test_notebook(): - execute_notebook('mod_exp') diff --git a/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py new file mode 100644 index 000000000..611af9de4 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py @@ -0,0 +1,150 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import cached_property +from typing import Dict, Optional + +import attrs +import numpy as np +import sympy +from attrs import frozen + +from qualtran import ( + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + DecomposeTypeError, + QUInt, + Signature, + SoquetT, +) +from qualtran.bloqs.basic_gates import IntState, PlusState +from qualtran.bloqs.bookkeeping import Free +from qualtran.bloqs.factoring._factoring_shims import MeasureQFT +from qualtran.bloqs.mod_arithmetic.mod_multiplication import CModMulK +from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.symbolics import is_symbolic, SymbolicInt + + +@frozen +class RSAPhaseEstimate(Bloq): + """Perform a single phase estimation of the decomposition of Modular Exponentiation for the + given base. + + The constructor requires a pre-set base, see the make_for_shor factory method for picking a + random, valid base + + Args: + n: The bitsize of the modulus N. + mod: The modulus N; a part of the public key for RSA. + base: A base for modular exponentiation. + + References: + [Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095). + Beauregard. 2003. Fig 1. + """ + + n: 'SymbolicInt' + mod: 'SymbolicInt' + base: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature([]) + + @classmethod + def make_for_shor( + cls, + big_n: 'SymbolicInt', + g: Optional['SymbolicInt'] = None, + rs: Optional[np.random.RandomState] = None, + ): + """Factory method that sets up the modular exponentiation for a factoring run. + + Args: + big_n: The large composite number N. Used to set `mod`. Its bitsize is used + to set `x_bitsize` and `exp_bitsize`. + g: Optional base of the exponentiation. If `None`, we pick a random base. + rs: Optional random state which can be seeded to make base generation deterministic. + """ + if is_symbolic(big_n): + little_n = sympy.ceiling(sympy.log(big_n, 2)) + else: + little_n = int(math.ceil(math.log2(big_n))) + if g is None: + if is_symbolic(big_n): + g = sympy.symbols('g') + else: + if rs is None: + rs = np.random.RandomState() + while True: + g = rs.randint(2, int(big_n)) + if math.gcd(g, int(big_n)) == 1: + break + return cls(base=g, mod=big_n, n=little_n) + + def __attrs_post_init__(self): + if not is_symbolic(self.n, self.mod): + assert self.n == int(math.ceil(math.log2(self.mod))) + + def _CtrlModMul(self, k: 'SymbolicInt'): + """Helper method to return a `CModMulK` with attributes forwarded.""" + return CModMulK(QUInt(self.n), k=k, mod=self.mod) + + def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + exponent = [bb.add(PlusState()) for _ in range(2 * self.n)] + x = bb.add(IntState(val=1, bitsize=self.n)) + + base = self.base % self.mod + for j in range((2 * self.n) - 1, 0 - 1, -1): + exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x) + base = (base * base) % self.mod + + bb.add(MeasureQFT(n=2 * self.n), x=exponent) + bb.add(Free(QUInt(self.n), dirty=True), reg=x) + return {} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + k = ssa.new_symbol('k') + return {MeasureQFT(n=self.n): 1, self._CtrlModMul(k=k): 2 * self.n} + + +_K = sympy.Symbol('k_exp') + + +def _generalize_k(b: Bloq) -> Optional[Bloq]: + if isinstance(b, CModMulK): + return attrs.evolve(b, k=_K) + + return b + + +@bloq_example +def _rsa_pe() -> RSAPhaseEstimate: + n, p, g = sympy.symbols('n p g') + rsa_pe = RSAPhaseEstimate(n=n, mod=p, base=g) + return rsa_pe + + +@bloq_example +def _rsa_pe_small() -> RSAPhaseEstimate: + rsa_pe_small = RSAPhaseEstimate.make_for_shor(big_n=13 * 17, g=9) + return rsa_pe_small + + +_RSA_PE_BLOQ_DOC = BloqDocSpec(bloq_cls=RSAPhaseEstimate, examples=[_rsa_pe_small, _rsa_pe]) diff --git a/qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py new file mode 100644 index 000000000..4ddedf964 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py @@ -0,0 +1,27 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import qualtran.testing as qlt_testing +from qualtran.bloqs.factoring.rsa.rsa_phase_estimate import _rsa_pe, _rsa_pe_small + + +@pytest.mark.parametrize('bloq', [_rsa_pe_small, _rsa_pe]) +def test_rsa_pe(bloq_autotester, bloq): + bloq_autotester(bloq) + + +def test_notebook(): + qlt_testing.execute_notebook('rsa') diff --git a/qualtran/cirq_interop/_bloq_to_cirq_test.py b/qualtran/cirq_interop/_bloq_to_cirq_test.py index 95dae1cef..87a76c014 100644 --- a/qualtran/cirq_interop/_bloq_to_cirq_test.py +++ b/qualtran/cirq_interop/_bloq_to_cirq_test.py @@ -21,7 +21,7 @@ from qualtran import Bloq, BloqBuilder, ConnectionT, Signature, Soquet, SoquetT from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.basic_gates import Toffoli, XGate, YGate -from qualtran.bloqs.factoring import ModExp +from qualtran.bloqs.factoring.rsa import ModExp from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd from qualtran.bloqs.state_preparation import PrepareUniformSuperposition from qualtran.cirq_interop._bloq_to_cirq import BloqAsCirqGate, CirqQuregT diff --git a/qualtran/cirq_interop/cirq_interop.ipynb b/qualtran/cirq_interop/cirq_interop.ipynb index 8a2854df8..991f95bd8 100644 --- a/qualtran/cirq_interop/cirq_interop.ipynb +++ b/qualtran/cirq_interop/cirq_interop.ipynb @@ -471,7 +471,7 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.bloqs.factoring.mod_exp import ModExp\n", + "from qualtran.bloqs.factoring.rsa import ModExp\n", "from qualtran.drawing import show_bloq\n", "from qualtran.drawing import get_musical_score_data, draw_musical_score\n", "N = 13*17\n", diff --git a/qualtran/serialization/bloq_test.py b/qualtran/serialization/bloq_test.py index 218206d7f..d411f09e3 100644 --- a/qualtran/serialization/bloq_test.py +++ b/qualtran/serialization/bloq_test.py @@ -23,7 +23,7 @@ from qualtran import Bloq, Signature from qualtran._infra.composite_bloq_test import TestTwoCNOT -from qualtran.bloqs.factoring.mod_exp import ModExp +from qualtran.bloqs.factoring.rsa.rsa_mod_exp import ModExp from qualtran.cirq_interop import CirqGateAsBloq from qualtran.cirq_interop._cirq_to_bloq_test import TestCNOT as TestCNOTCirq from qualtran.protos import registers_pb2 diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index f53ea457b..4aa73158d 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -97,7 +97,8 @@ import qualtran.bloqs.data_loading.qroam_clean import qualtran.bloqs.data_loading.qrom import qualtran.bloqs.data_loading.select_swap_qrom -import qualtran.bloqs.factoring.mod_exp +import qualtran.bloqs.factoring._factoring_shims +import qualtran.bloqs.factoring.rsa import qualtran.bloqs.for_testing.atom import qualtran.bloqs.for_testing.casting import qualtran.bloqs.for_testing.interior_alloc @@ -342,7 +343,9 @@ "qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK": qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK, "qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul, "qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul, - "qualtran.bloqs.factoring.mod_exp.ModExp": qualtran.bloqs.factoring.mod_exp.ModExp, + "qualtran.bloqs.factoring._factoring_shims.MeasureQFT": qualtran.bloqs.factoring._factoring_shims.MeasureQFT, + "qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate": qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate, + "qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp": qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp, "qualtran.bloqs.for_testing.atom.TestAtom": qualtran.bloqs.for_testing.atom.TestAtom, "qualtran.bloqs.for_testing.atom.TestGWRAtom": qualtran.bloqs.for_testing.atom.TestGWRAtom, "qualtran.bloqs.for_testing.atom.TestTwoBitOp": qualtran.bloqs.for_testing.atom.TestTwoBitOp, diff --git a/qualtran/surface_code/msft_resource_estimator_interop.ipynb b/qualtran/surface_code/msft_resource_estimator_interop.ipynb index af6337271..2af07a495 100644 --- a/qualtran/surface_code/msft_resource_estimator_interop.ipynb +++ b/qualtran/surface_code/msft_resource_estimator_interop.ipynb @@ -27,7 +27,7 @@ }, "outputs": [], "source": [ - "from qualtran.bloqs.factoring.mod_exp import ModExp\n", + "from qualtran.bloqs.factoring.rsa import ModExp\n", "from qualtran.drawing import show_bloq\n", "\n", "N = 13*17 # integer to factor\n", From 7344830491f7cbac8e43c9ec106be37a39cb94e8 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 18 Oct 2024 23:51:10 +0100 Subject: [PATCH 7/8] Create linear half comparison bloqs (#1408) This PR creates all versions of half a comparison operation in $n$ toffoli complexity. The uncomputation of these bloqs has zero cost. The half comparason operation is needed in the modular inversion bloq. I also modify the OutOfPlaceAdder to allow it to not compute the extra And when it's not needed --- This PR also shows how we can implement all versions of comparison by implementing just one. for example creating a half greater than bloq that use logarithmic depth is enough to implement all the others. --- .../qualtran_dev_tools/notebook_specs.py | 4 + qualtran/bloqs/arithmetic/__init__.py | 4 + qualtran/bloqs/arithmetic/addition.ipynb | 6 +- qualtran/bloqs/arithmetic/addition.py | 28 +- qualtran/bloqs/arithmetic/comparison.ipynb | 452 +++++++++++++++++ qualtran/bloqs/arithmetic/comparison.py | 470 ++++++++++++++++++ qualtran/bloqs/arithmetic/comparison_test.py | 103 ++++ qualtran/serialization/resolver_dict.py | 4 + 8 files changed, 1063 insertions(+), 8 deletions(-) diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index 7c66311fe..6dc00babe 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -435,6 +435,10 @@ qualtran.bloqs.arithmetic.comparison._SQ_CMP_DOC, qualtran.bloqs.arithmetic.comparison._LEQ_DOC, qualtran.bloqs.arithmetic.comparison._CLinearDepthGreaterThan_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_GREATERTHAN_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_GREATERTHANEQUAL_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_LESSTHAN_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_LESSTHANEQUAL_DOC, ], ), NotebookSpecV2( diff --git a/qualtran/bloqs/arithmetic/__init__.py b/qualtran/bloqs/arithmetic/__init__.py index 533d0ee0c..59ca8a5af 100644 --- a/qualtran/bloqs/arithmetic/__init__.py +++ b/qualtran/bloqs/arithmetic/__init__.py @@ -23,6 +23,10 @@ GreaterThanConstant, LessThanConstant, LessThanEqual, + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, SingleQubitCompare, ) from qualtran.bloqs.arithmetic.controlled_addition import CAdd diff --git a/qualtran/bloqs/arithmetic/addition.ipynb b/qualtran/bloqs/arithmetic/addition.ipynb index c9d271e14..4cd90a386 100644 --- a/qualtran/bloqs/arithmetic/addition.ipynb +++ b/qualtran/bloqs/arithmetic/addition.ipynb @@ -186,12 +186,14 @@ "using $4n - 4 T$ gates. Uncomputation requires 0 T-gates.\n", "\n", "#### Parameters\n", - " - `bitsize`: Number of bits used to represent each input integer. The allocated output register is of size `bitsize+1` so it has enough space to hold the sum of `a+b`. \n", + " - `bitsize`: Number of bits used to represent each input integer. The allocated output register is of size `bitsize+1` so it has enough space to hold the sum of `a+b`.\n", + " - `is_adjoint`: Whether this is compute or uncompute version.\n", + " - `include_most_significant_bit`: Whether to add an extra most significant (i.e. carry) bit. \n", "\n", "#### Registers\n", " - `a`: A bitsize-sized input register (register a above).\n", " - `b`: A bitsize-sized input register (register b above).\n", - " - `c`: A bitize+1-sized LEFT/RIGHT register depending on whether the gate adjoint or not. \n", + " - `c`: The LEFT/RIGHT register depending on whether the gate adjoint or not. This register size is either bitsize or bitsize+1 depending on the value of `include_most_significant_bit`. \n", "\n", "#### References\n", " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" diff --git a/qualtran/bloqs/arithmetic/addition.py b/qualtran/bloqs/arithmetic/addition.py index f0b8de6b9..7b208bd36 100644 --- a/qualtran/bloqs/arithmetic/addition.py +++ b/qualtran/bloqs/arithmetic/addition.py @@ -260,11 +260,15 @@ class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[m Args: bitsize: Number of bits used to represent each input integer. The allocated output register is of size `bitsize+1` so it has enough space to hold the sum of `a+b`. + is_adjoint: Whether this is compute or uncompute version. + include_most_significant_bit: Whether to add an extra most significant (i.e. carry) bit. Registers: a: A bitsize-sized input register (register a above). b: A bitsize-sized input register (register b above). - c: A bitize+1-sized LEFT/RIGHT register depending on whether the gate adjoint or not. + c: The LEFT/RIGHT register depending on whether the gate adjoint or not. + This register size is either bitsize or bitsize+1 depending on + the value of `include_most_significant_bit`. References: [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648) @@ -272,6 +276,11 @@ class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[m bitsize: 'SymbolicInt' is_adjoint: bool = False + include_most_significant_bit: bool = True + + @property + def out_bitsize(self): + return self.bitsize + (1 if self.include_most_significant_bit else 0) @property def signature(self): @@ -280,14 +289,14 @@ def signature(self): [ Register('a', QUInt(self.bitsize)), Register('b', QUInt(self.bitsize)), - Register('c', QUInt(self.bitsize + 1), side=side), + Register('c', QUInt(self.out_bitsize), side=side), ] ) def registers(self) -> Sequence[Union[int, Sequence[int]]]: if not isinstance(self.bitsize, int): raise ValueError(f'Symbolic bitsize {self.bitsize} not supported') - return [2] * self.bitsize, [2] * self.bitsize, [2] * (self.bitsize + 1) + return [2] * self.bitsize, [2] * self.bitsize, [2] * self.out_bitsize def apply(self, a: int, b: int, c: int) -> Tuple[int, int, int]: return a, b, c + a + b @@ -307,7 +316,7 @@ def on_classical_vals( return { 'a': a, 'b': b, - 'c': add_ints(int(a), int(b), num_bits=self.bitsize + 1, is_signed=False), + 'c': add_ints(int(a), int(b), num_bits=self.out_bitsize, is_signed=False), } def with_registers(self, *new_registers: Union[int, Sequence[int]]): @@ -328,12 +337,19 @@ def decompose_from_registers( cirq.CX(a[i], c[i + 1]), cirq.CX(b[i], c[i]), ] - for i in range(self.bitsize) + for i in range(self.out_bitsize - 1) ] + if not self.include_most_significant_bit: + # Update c[-1] as c[-1] ^= a[-1]^b[-1] + i = self.bitsize - 1 + optree.append([cirq.CX(a[i], c[i]), cirq.CX(b[i], c[i])]) return cirq.inverse(optree) if self.is_adjoint else optree def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - return {And(uncompute=self.is_adjoint): self.bitsize, CNOT(): 5 * self.bitsize} + return { + And(uncompute=self.is_adjoint): self.out_bitsize - 1, + CNOT(): 5 * (self.bitsize - 1) + 2 + (3 if self.include_most_significant_bit else 0), + } def __pow__(self, power: int): if power == 1: diff --git a/qualtran/bloqs/arithmetic/comparison.ipynb b/qualtran/bloqs/arithmetic/comparison.ipynb index 4a0c8ad6d..5183ede00 100644 --- a/qualtran/bloqs/arithmetic/comparison.ipynb +++ b/qualtran/bloqs/arithmetic/comparison.ipynb @@ -1020,6 +1020,458 @@ "show_call_graph(equals_g)\n", "show_counts_sigma(equals_sigma)" ] + }, + { + "cell_type": "markdown", + "id": "e2db7c5d", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfGreaterThan`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{b-a}\\ket{a>b}$ using $n$ And gates.\n", + "\n", + "This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a > b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "758a6e35", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfGreaterThan" + ] + }, + { + "cell_type": "markdown", + "id": "3ea5a7bc", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26e4245f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.lineardepthhalfgreaterthan_small" + }, + "outputs": [], + "source": [ + "lineardepthhalfgreaterthan_small = LinearDepthHalfGreaterThan(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "29abac9f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d065b007", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalfgreaterthan_small],\n", + " ['`lineardepthhalfgreaterthan_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "d67e1888", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32952025", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalfgreaterthan_small_g, lineardepthhalfgreaterthan_small_sigma = lineardepthhalfgreaterthan_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalfgreaterthan_small_g)\n", + "show_counts_sigma(lineardepthhalfgreaterthan_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "9c39992e", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfGreaterThanEqual`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{a-b}\\ket{a \\geq b}$ using $n$ And gates.\n", + "\n", + "This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a >= b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58e6973f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfGreaterThanEqual" + ] + }, + { + "cell_type": "markdown", + "id": "32dae58a", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eded8868", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.lineardepthhalfgreaterthanequal_small" + }, + "outputs": [], + "source": [ + "lineardepthhalfgreaterthanequal_small = LinearDepthHalfGreaterThanEqual(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "0edbf55b", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5326975b", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalfgreaterthanequal_small],\n", + " ['`lineardepthhalfgreaterthanequal_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "eecfbf65", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abfbe19f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalfgreaterthanequal_small_g, lineardepthhalfgreaterthanequal_small_sigma = lineardepthhalfgreaterthanequal_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalfgreaterthanequal_small_g)\n", + "show_counts_sigma(lineardepthhalfgreaterthanequal_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "19744b75", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfLessThan`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{a-b}\\ket{a b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a < b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1eec63a1", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfLessThan" + ] + }, + { + "cell_type": "markdown", + "id": "75142163", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4759ae6d", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.lineardepthhalflessthan_small" + }, + "outputs": [], + "source": [ + "lineardepthhalflessthan_small = LinearDepthHalfLessThan(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "903efca4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbbefc84", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalflessthan_small],\n", + " ['`lineardepthhalflessthan_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "8862065f", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9be045c4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalflessthan_small_g, lineardepthhalflessthan_small_sigma = lineardepthhalflessthan_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalflessthan_small_g)\n", + "show_counts_sigma(lineardepthhalflessthan_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "8a868719", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfLessThanEqual`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{b-a}\\ket{a \\leq b}$ using $n$ And gates.\n", + "\n", + "This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a <= b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b4c9b03", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfLessThanEqual" + ] + }, + { + "cell_type": "markdown", + "id": "bae993de", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6610fd4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.lineardepthhalflessthanequal_small" + }, + "outputs": [], + "source": [ + "lineardepthhalflessthanequal_small = LinearDepthHalfLessThanEqual(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "72de4f8e", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdd8d4c4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalflessthanequal_small],\n", + " ['`lineardepthhalflessthanequal_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "973be9d4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9efd6db6", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalflessthanequal_small_g, lineardepthhalflessthanequal_small_sigma = lineardepthhalflessthanequal_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalflessthanequal_small_g)\n", + "show_counts_sigma(lineardepthhalflessthanequal_small_sigma)" + ] } ], "metadata": { diff --git a/qualtran/bloqs/arithmetic/comparison.py b/qualtran/bloqs/arithmetic/comparison.py index c546aee77..b28269ae4 100644 --- a/qualtran/bloqs/arithmetic/comparison.py +++ b/qualtran/bloqs/arithmetic/comparison.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc from collections import defaultdict from functools import cached_property from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union @@ -51,6 +52,7 @@ from qualtran.drawing import WireSymbol from qualtran.drawing.musical_score import Circle, Text, TextBox from qualtran.resource_counting.generalizers import ignore_split_join +from qualtran.simulation.classical_sim import add_ints from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt if TYPE_CHECKING: @@ -1228,3 +1230,471 @@ def _clineardepthgreaterthan_example() -> CLinearDepthGreaterThan: _CLinearDepthGreaterThan_DOC = BloqDocSpec( bloq_cls=CLinearDepthGreaterThan, examples=[_clineardepthgreaterthan_example] ) + + +@frozen +class _HalfLinearDepthGreaterThan(Bloq): + """A concrete implementation of half-circuit for greater than. + + This bloq can be returned by the _HalfComparisonBase._half_greater_than_bloq abstract property. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a > b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + dtype: Union[QInt, QUInt, QMontgomeryUInt] + uncompute: bool = False + + @cached_property + def signature(self) -> Signature: + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('a', self.dtype), + Register('b', self.dtype), + Register('c', QUInt(bitsize=self.dtype.bitsize + 1), side=side), + Register('target', QBit(), side=side), + ] + ) + + def adjoint(self) -> '_HalfLinearDepthGreaterThan': + return attrs.evolve(self, uncompute=self.uncompute ^ True) + + def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']: + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=b) + else: + a = bb.join(np.concatenate([[bb.allocate(1)], bb.split(a)])) + b = bb.join(np.concatenate([[bb.allocate(1)], bb.split(b)])) + + dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1) + b = bb.add(BitwiseNot(dtype), x=b) # b := -b-1 + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=b) + a, b, c = bb.add( + OutOfPlaceAdder(self.dtype.bitsize + 1, include_most_significant_bit=False), a=a, b=b + ) # c := a - b - 1 + c = bb.add(BitwiseNot(QUInt(dtype.bitsize)), x=c) # c := b - a + + # Update `target` + c_arr = bb.split(c) + target = bb.allocate(1) + c_arr[0], target = bb.add(CNOT(), ctrl=c_arr[0], target=target) + c = bb.join(c_arr) + + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=b) + b = bb.add(BitwiseNot(dtype), x=b) + + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=b) + else: + a_arr = bb.split(a) + a = bb.join(a_arr[1:]) + b_arr = bb.split(b) + b = bb.join(b_arr[1:]) + bb.free(a_arr[0]) + bb.free(b_arr[0]) + return {'a': a, 'b': b, 'c': c, 'target': target} + + def _uncompute( + self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet', c: 'Soquet', target: 'Soquet' + ) -> Dict[str, 'SoquetT']: + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=b) + else: + a = bb.join(np.concatenate([[bb.allocate(1)], bb.split(a)])) + b = bb.join(np.concatenate([[bb.allocate(1)], bb.split(b)])) + + dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1) + b = bb.add(BitwiseNot(dtype), x=b) # b := -b-1 + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=b) + + c_arr = bb.split(c) + c_arr[0], target = bb.add(CNOT(), ctrl=c_arr[0], target=target) + c = bb.join(c_arr) + + c = bb.add(BitwiseNot(QUInt(dtype.bitsize)), x=c) + a, b = bb.add( + OutOfPlaceAdder(self.dtype.bitsize + 1, include_most_significant_bit=False).adjoint(), + a=a, + b=b, + c=c, + ) + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=b) + b = bb.add(BitwiseNot(dtype), x=b) + + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=b) + else: + a_arr = bb.split(a) + a = bb.join(a_arr[1:]) + b_arr = bb.split(b) + b = bb.join(b_arr[1:]) + bb.free(a_arr[0]) + bb.free(b_arr[0]) + bb.free(target) + return {'a': a, 'b': b} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + a: 'Soquet', + b: 'Soquet', + c: Optional['Soquet'] = None, + target: Optional['Soquet'] = None, + ) -> Dict[str, 'SoquetT']: + if self.uncompute: + # Uncompute + assert c is not None + assert target is not None + return self._uncompute(bb, a, b, c, target) + else: + assert c is None + assert target is None + return self._compute(bb, a, b) + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1) + counts: 'BloqCountDictT' + if isinstance(self.dtype, QUInt): + counts = {BitwiseNot(dtype): 3} + else: + counts = {BitwiseNot(dtype): 2, BitwiseNot(QUInt(dtype.bitsize)): 1} + + counts[CNOT()] = 1 + + adder = OutOfPlaceAdder(self.dtype.bitsize + 1, include_most_significant_bit=False) + if self.uncompute: + adder = adder.adjoint() + counts[adder] = 1 + + return counts + + +@frozen +class _HalfComparisonBase(Bloq): + """Parent class for the 4 comparison operations (>, >=, <, <=). + + The four comparison operations can be implemented by implementing only one of them + and computing the others either by reversing the input order, flipping the result or both. + + The choice made is to build the four opertions around greater than. Namely the greater than + bloq returned by `._half_greater_than_bloq`; By changing this property we can change + change the properties of the constructed circuit (e.g. complexity, depth, ..etc). + + For example _LinearDepthHalfComparisonBase sets the property to a linear depth construction, + other implementations can set the property to a log depth construction. + + Args: + dtype: dtype of the two integers a and b. + _op_symbol: The symbol of the comparison operation. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a > b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + dtype: Union[QInt, QUInt, QMontgomeryUInt] + _op_symbol: str = attrs.field( + default='>', validator=lambda _, __, s: s in ('>', '<', '>=', '<='), repr=False + ) + uncompute: bool = False + + @cached_property + def signature(self) -> Signature: + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('a', self.dtype), + Register('b', self.dtype), + Register('c', QUInt(bitsize=self.dtype.bitsize + 1), side=side), + Register('target', QBit(), side=side), + ] + ) + + def adjoint(self) -> '_HalfComparisonBase': + return attrs.evolve(self, uncompute=self.uncompute ^ True) + + @cached_property + @abc.abstractmethod + def _half_greater_than_bloq(self) -> Bloq: + raise NotImplementedError() + + def _classical_comparison( + self, a: 'ClassicalValT', b: 'ClassicalValT' + ) -> Union[bool, np.bool_, NDArray[np.bool_]]: + if self._op_symbol == '>': + return a > b + elif self._op_symbol == '<': + return a < b + elif self._op_symbol == '>=': + return a >= b + else: + return a <= b + + def on_classical_vals( + self, + a: 'ClassicalValT', + b: 'ClassicalValT', + c: Optional['ClassicalValT'] = None, + target: Optional['ClassicalValT'] = None, + ) -> Dict[str, 'ClassicalValT']: + if self.uncompute: + assert c == add_ints( + int(a), + int(b), + num_bits=int(self.dtype.bitsize), + is_signed=isinstance(self.dtype, QInt), + ) + assert target == self._classical_comparison(a, b) + return {'a': a, 'b': b} + if self._op_symbol in ('>', '<='): + c = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) + else: + c = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) + return {'a': a, 'b': b, 'c': c, 'target': int(self._classical_comparison(a, b))} + + def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']: + if self._op_symbol in ('>', '<='): + a, b, c, target = bb.add_from(self._half_greater_than_bloq, a=a, b=b) # type: ignore + else: + b, a, c, target = bb.add_from(self._half_greater_than_bloq, a=b, b=a) # type: ignore + + if self._op_symbol in ('<=', '>='): + target = bb.add(XGate(), q=target) + + return {'a': a, 'b': b, 'c': c, 'target': target} + + def _uncompute( + self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet', c: 'Soquet', target: 'Soquet' + ) -> Dict[str, 'SoquetT']: + if self._op_symbol in ('<=', '>='): + target = bb.add(XGate(), q=target) + + if self._op_symbol in ('>', '<='): + a, b = bb.add_from(self._half_greater_than_bloq.adjoint(), a=a, b=b, c=c, target=target) # type: ignore + else: + a, b = bb.add_from(self._half_greater_than_bloq.adjoint(), a=b, b=a, c=c, target=target) # type: ignore + + return {'a': a, 'b': b} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + a: 'Soquet', + b: 'Soquet', + c: Optional['Soquet'] = None, + target: Optional['Soquet'] = None, + ) -> Dict[str, 'SoquetT']: + if self.uncompute: + assert c is not None + assert target is not None + return self._uncompute(bb, a, b, c, target) + else: + assert c is None + assert target is None + return self._compute(bb, a, b) + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + extra_ops = {} + if isinstance(self.dtype, QInt): + extra_ops = { + SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)): 2, + SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(): 2, + } + if self._op_symbol in ('>=', '<='): + extra_ops[XGate()] = 1 + adder = self._half_greater_than_bloq + if self.uncompute: + adder = adder.adjoint() + adder_call_graph = adder.build_call_graph(ssa) + assert isinstance(adder_call_graph, dict) + counts: defaultdict['Bloq', Union[int, sympy.Expr]] = defaultdict(lambda: 0) + counts.update(adder_call_graph) + for k, v in extra_ops.items(): + counts[k] += v + return counts + + +@frozen +class _LinearDepthHalfComparisonBase(_HalfComparisonBase): + """A wrapper around _HalfComparisonBase that sets ._half_greater_than_bloq property to a construction with linear depth.""" + + @cached_property + def _half_greater_than_bloq(self) -> Bloq: + return _HalfLinearDepthGreaterThan(self.dtype, uncompute=False) + + +@frozen +class LinearDepthHalfGreaterThan(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{b-a}\ket{a>b}$ using $n$ And gates. + + This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a > b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='>', repr=False, init=False) + + +@frozen +class LinearDepthHalfGreaterThanEqual(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{a-b}\ket{a \geq b}$ using $n$ And gates. + + This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a >= b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='>=', repr=False, init=False) + + +@frozen +class LinearDepthHalfLessThan(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{a-b}\ket{a b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a < b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='<', repr=False, init=False) + + +@frozen +class LinearDepthHalfLessThanEqual(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{b-a}\ket{a \leq b}$ using $n$ And gates. + + This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a <= b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='<=', repr=False, init=False) + + +@bloq_example +def _lineardepthhalfgreaterthan_small() -> LinearDepthHalfGreaterThan: + lineardepthhalfgreaterthan_small = LinearDepthHalfGreaterThan(QUInt(3)) + return lineardepthhalfgreaterthan_small + + +@bloq_example +def _lineardepthhalflessthan_small() -> LinearDepthHalfLessThan: + lineardepthhalflessthan_small = LinearDepthHalfLessThan(QUInt(3)) + return lineardepthhalflessthan_small + + +@bloq_example +def _lineardepthhalfgreaterthanequal_small() -> LinearDepthHalfGreaterThanEqual: + lineardepthhalfgreaterthanequal_small = LinearDepthHalfGreaterThanEqual(QUInt(3)) + return lineardepthhalfgreaterthanequal_small + + +@bloq_example +def _lineardepthhalflessthanequal_small() -> LinearDepthHalfLessThanEqual: + lineardepthhalflessthanequal_small = LinearDepthHalfLessThanEqual(QUInt(3)) + return lineardepthhalflessthanequal_small + + +_LINEAR_DEPTH_HALF_GREATERTHAN_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfGreaterThan, examples=[_lineardepthhalfgreaterthan_small] +) + + +_LINEAR_DEPTH_HALF_LESSTHAN_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfLessThan, examples=[_lineardepthhalflessthan_small] +) + + +_LINEAR_DEPTH_HALF_GREATERTHANEQUAL_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfGreaterThanEqual, examples=[_lineardepthhalfgreaterthanequal_small] +) + + +_LINEAR_DEPTH_HALF_LESSTHANEQUAL_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfLessThanEqual, examples=[_lineardepthhalflessthanequal_small] +) diff --git a/qualtran/bloqs/arithmetic/comparison_test.py b/qualtran/bloqs/arithmetic/comparison_test.py index f755d1eb7..9876537e8 100644 --- a/qualtran/bloqs/arithmetic/comparison_test.py +++ b/qualtran/bloqs/arithmetic/comparison_test.py @@ -28,6 +28,10 @@ _greater_than, _gt_k, _leq_symb, + _lineardepthhalfgreaterthan_small, + _lineardepthhalfgreaterthanequal_small, + _lineardepthhalflessthan_small, + _lineardepthhalflessthanequal_small, _lt_k_symb, BiQubitsMixer, CLinearDepthGreaterThan, @@ -38,10 +42,15 @@ LessThanConstant, LessThanEqual, LinearDepthGreaterThan, + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, SingleQubitCompare, ) from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim +from qualtran.resource_counting import get_cost_value, QECGatesCost from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join @@ -398,3 +407,97 @@ def test_clineardepthgreaterthan_tcomplexity(ctrl, dtype): c = CLinearDepthGreaterThan(dtype(n), ctrl).t_complexity() assert c.t == 4 * (n + 2) assert c.rotations == 0 + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +@pytest.mark.parametrize('bitsize', range(2, 5)) +@pytest.mark.parametrize('uncompute', [True, False]) +def test_linear_half_comparison_decomposition(comp_cls, dtype, bitsize, uncompute): + b = comp_cls(dtype(bitsize), uncompute) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +@pytest.mark.parametrize('bitsize', range(2, 5)) +@pytest.mark.parametrize('uncompute', [True, False]) +def test_linear_half_comparison_bloq_counts(comp_cls, dtype, bitsize, uncompute): + b = comp_cls(dtype(bitsize), uncompute) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +@pytest.mark.parametrize('bitsize', range(2, 5)) +def test_linear_half_comparison_classical_action(comp_cls, dtype, bitsize): + b = comp_cls(dtype(bitsize)) + qlt_testing.assert_consistent_classical_action( + b, a=dtype(bitsize).get_classical_domain(), b=dtype(bitsize).get_classical_domain() + ) + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +def test_linear_half_comparison_symbolic_complexity(comp_cls, dtype): + n = sympy.Symbol('n') + b = comp_cls(dtype(n)) + + cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() + + assert cost['n_t'] == 0 + assert cost['n_ccz'] == n + + # uncomputation has zero cost. + cost = get_cost_value(b.adjoint(), QECGatesCost()).total_t_and_ccz_count() + + assert cost['n_t'] == 0 + assert cost['n_ccz'] == 0 + + +def test_lineardepthhalfgreaterthan_small(bloq_autotester): + bloq_autotester(_lineardepthhalfgreaterthan_small) + + +def test_lineardepthhalflessthan_small(bloq_autotester): + bloq_autotester(_lineardepthhalflessthan_small) + + +def test_lineardepthhalfgreaterthanequal_small(bloq_autotester): + bloq_autotester(_lineardepthhalfgreaterthanequal_small) + + +def test_lineardepthhalflessthanequal_small(bloq_autotester): + bloq_autotester(_lineardepthhalflessthanequal_small) diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index 4aa73158d..bc3be1a65 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -177,6 +177,10 @@ "qualtran.bloqs.arithmetic.comparison.LinearDepthGreaterThan": qualtran.bloqs.arithmetic.comparison.LinearDepthGreaterThan, "qualtran.bloqs.arithmetic.comparison.SingleQubitCompare": qualtran.bloqs.arithmetic.comparison.SingleQubitCompare, "qualtran.bloqs.arithmetic.comparison.CLinearDepthGreaterThan": qualtran.bloqs.arithmetic.comparison.CLinearDepthGreaterThan, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThan": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThan, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThan": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThan, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThanEqual": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThanEqual, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThanEqual": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThanEqual, "qualtran.bloqs.arithmetic.controlled_add_or_subtract.ControlledAddOrSubtract": qualtran.bloqs.arithmetic.controlled_add_or_subtract.ControlledAddOrSubtract, "qualtran.bloqs.arithmetic.conversions.contiguous_index.ToContiguousIndex": qualtran.bloqs.arithmetic.conversions.contiguous_index.ToContiguousIndex, "qualtran.bloqs.arithmetic.conversions.ones_complement_to_twos_complement.SignedIntegerToTwosComplement": qualtran.bloqs.arithmetic.conversions.ones_complement_to_twos_complement.SignedIntegerToTwosComplement, From e3aeee0d5e8f72304b2c8a1124196954e764ce0a Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Wed, 23 Oct 2024 14:27:30 -0700 Subject: [PATCH 8/8] Add ECAdd() Bloq (#1425) * Initial commit of ec add waiting on equals to be merged. * Working on tests for ECAdd * ECAdd implementation and tests * remove modmul typo * Fix mypy errors * Better bugfix for ModAdd * Change mod inv classical impl to use monttgomery inv * Fix pytest error * Fix some comments * ECAdd lots of testing * Add comments about bugs to be fixed * Reduce complexity by keeping intermediate values mod p * Stash qmontgomery tests * Address comments * Fix montgomery prod/inv calculations + pylint/mypy --------- Co-authored-by: Noureldin Co-authored-by: Matthew Harrigan --- qualtran/_infra/data_types.py | 44 +- qualtran/_infra/data_types_test.py | 26 + qualtran/bloqs/arithmetic/_shims.py | 22 +- qualtran/bloqs/factoring/ecc/ec_add.ipynb | 36 +- qualtran/bloqs/factoring/ecc/ec_add.py | 1075 ++++++++++++++++- qualtran/bloqs/factoring/ecc/ec_add_test.py | 412 ++++++- qualtran/bloqs/factoring/ecc/ec_point.py | 3 +- qualtran/bloqs/factoring/ecc/ec_point_test.py | 1 + qualtran/bloqs/mod_arithmetic/_shims.py | 55 +- .../mod_arithmetic/mod_multiplication.py | 3 +- qualtran/serialization/resolver_dict.py | 8 + 11 files changed, 1639 insertions(+), 46 deletions(-) diff --git a/qualtran/_infra/data_types.py b/qualtran/_infra/data_types.py index e21eb209f..ee918d506 100644 --- a/qualtran/_infra/data_types.py +++ b/qualtran/_infra/data_types.py @@ -772,9 +772,14 @@ class QMontgomeryUInt(QDType): bitsize: The number of qubits used to represent the integer. References: - [Montgomery modular multiplication](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication) + [Montgomery modular multiplication](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication). + + [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126133 Cat Qubits](https://arxiv.org/abs/2302.06639). + Gouzien et al. 2023. + We follow Montgomery form as described in the above paper; namely, r = 2^bitsize. """ + # TODO(https://github.com/quantumlib/Qualtran/issues/1471): Add modulus p as a class member. bitsize: SymbolicInt @property @@ -810,6 +815,43 @@ def assert_valid_classical_val_array( if np.any(val_array >= 2**self.bitsize): raise ValueError(f"Too-large classical values encountered in {debug_str}") + def montgomery_inverse(self, xm: int, p: int) -> int: + """Returns the modular inverse of an integer in montgomery form. + + Args: + xm: An integer in montgomery form. + p: The modulus of the finite field. + """ + return ((pow(xm, -1, p)) * pow(2, 2 * self.bitsize, p)) % p + + def montgomery_product(self, xm: int, ym: int, p: int) -> int: + """Returns the modular product of two integers in montgomery form. + + Args: + xm: The first montgomery form integer for the product. + ym: The second montgomery form integer for the product. + p: The modulus of the finite field. + """ + return (xm * ym * pow(2, -self.bitsize, p)) % p + + def montgomery_to_uint(self, xm: int, p: int) -> int: + """Converts an integer in montgomery form to a normal form integer. + + Args: + xm: An integer in montgomery form. + p: The modulus of the finite field. + """ + return (xm * pow(2, -self.bitsize, p)) % p + + def uint_to_montgomery(self, x: int, p: int) -> int: + """Converts an integer into montgomery form. + + Args: + x: An integer. + p: The modulus of the finite field. + """ + return (x * pow(2, int(self.bitsize), p)) % p + @attrs.frozen class QGF(QDType): diff --git a/qualtran/_infra/data_types_test.py b/qualtran/_infra/data_types_test.py index 10347b702..65252c5cb 100644 --- a/qualtran/_infra/data_types_test.py +++ b/qualtran/_infra/data_types_test.py @@ -135,6 +135,32 @@ def test_qmontgomeryuint(): assert is_symbolic(QMontgomeryUInt(sympy.Symbol('x'))) +@pytest.mark.parametrize('p', [13, 17, 29]) +@pytest.mark.parametrize('val', [1, 5, 7, 9]) +def test_qmontgomeryuint_operations(val, p): + qmontgomeryuint_8 = QMontgomeryUInt(8) + # Convert value to montgomery form and get the modular inverse. + val_m = qmontgomeryuint_8.uint_to_montgomery(val, p) + mod_inv = qmontgomeryuint_8.montgomery_inverse(val_m, p) + + # Calculate the product in montgomery form and convert back to normal form for assertion. + assert ( + qmontgomeryuint_8.montgomery_to_uint( + qmontgomeryuint_8.montgomery_product(val_m, mod_inv, p), p + ) + == 1 + ) + + +@pytest.mark.parametrize('p', [13, 17, 29]) +@pytest.mark.parametrize('val', [1, 5, 7, 9]) +def test_qmontgomeryuint_conversions(val, p): + qmontgomeryuint_8 = QMontgomeryUInt(8) + assert val == qmontgomeryuint_8.montgomery_to_uint( + qmontgomeryuint_8.uint_to_montgomery(val, p), p + ) + + def test_qgf(): qgf_256 = QGF(characteristic=2, degree=8) assert str(qgf_256) == 'QGF(2**8)' diff --git a/qualtran/bloqs/arithmetic/_shims.py b/qualtran/bloqs/arithmetic/_shims.py index 0d40daba7..d75e3d72f 100644 --- a/qualtran/bloqs/arithmetic/_shims.py +++ b/qualtran/bloqs/arithmetic/_shims.py @@ -22,8 +22,11 @@ from attrs import frozen -from qualtran import Bloq, QBit, QUInt, Register, Signature +from qualtran import Bloq, QBit, QMontgomeryUInt, QUInt, Register, Signature +from qualtran.bloqs.arithmetic.bitwise import BitwiseNot +from qualtran.bloqs.arithmetic.controlled_addition import CAdd from qualtran.bloqs.basic_gates import Toffoli +from qualtran.bloqs.basic_gates.swap import TwoBitCSwap from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -39,6 +42,20 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: return {Toffoli(): self.n - 2} +@frozen +class CSub(Bloq): + n: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [Register('ctrl', QBit()), Register('x', QUInt(self.n)), Register('y', QUInt(self.n))] + ) + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return {CAdd(QMontgomeryUInt(self.n)): 1, BitwiseNot(QMontgomeryUInt(self.n)): 3} + + @frozen class Lt(Bloq): n: int @@ -62,3 +79,6 @@ class CHalf(Bloq): @cached_property def signature(self) -> 'Signature': return Signature([Register('ctrl', QBit()), Register('x', QUInt(self.n))]) + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return {TwoBitCSwap(): self.n} diff --git a/qualtran/bloqs/factoring/ecc/ec_add.ipynb b/qualtran/bloqs/factoring/ecc/ec_add.ipynb index 543458c8c..cbc279bbe 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.ipynb +++ b/qualtran/bloqs/factoring/ecc/ec_add.ipynb @@ -41,16 +41,22 @@ "This takes elliptic curve points given by (a, b) and (x, y)\n", "and outputs the sum (x_r, y_r) in the second pair of registers.\n", "\n", + "Because the decomposition of this Bloq is complex, we split it into six separate parts\n", + "corresponding to the parts described in figure 10 of the Litinski paper cited below. We follow\n", + "the signature from figure 5 and break down the further decompositions based on the steps in\n", + "figure 10.\n", + "\n", "#### Parameters\n", " - `n`: The bitsize of the two registers storing the elliptic curve point\n", - " - `mod`: The modulus of the field in which we do the addition. \n", + " - `mod`: The modulus of the field in which we do the addition.\n", + " - `window_size`: The number of bits in the ModMult window. \n", "\n", "#### Registers\n", - " - `a`: The x component of the first input elliptic curve point of bitsize `n`.\n", - " - `b`: The y component of the first input elliptic curve point of bitsize `n`.\n", - " - `x`: The x component of the second input elliptic curve point of bitsize `n`, which will contain the x component of the resultant curve point.\n", - " - `y`: The y component of the second input elliptic curve point of bitsize `n`, which will contain the y component of the resultant curve point.\n", - " - `lam`: The precomputed lambda slope used in the addition operation. \n", + " - `a`: The x component of the first input elliptic curve point of bitsize `n` in montgomery form.\n", + " - `b`: The y component of the first input elliptic curve point of bitsize `n` in montgomery form.\n", + " - `x`: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the x component of the resultant curve point.\n", + " - `y`: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the y component of the resultant curve point.\n", + " - `lam_r`: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. \n", "\n", "#### References\n", " - [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). Litinski. 2023. Fig 5.\n" @@ -91,6 +97,18 @@ "ec_add = ECAdd(n, mod=p)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "170da165", + "metadata": { + "cq.autogen": "ECAdd.ec_add_small" + }, + "outputs": [], + "source": [ + "ec_add_small = ECAdd(5, mod=7)" + ] + }, { "cell_type": "markdown", "id": "39210af4", @@ -111,8 +129,8 @@ "outputs": [], "source": [ "from qualtran.drawing import show_bloqs\n", - "show_bloqs([ec_add],\n", - " ['`ec_add`'])" + "show_bloqs([ec_add, ec_add_small],\n", + " ['`ec_add`', '`ec_add_small`'])" ] }, { @@ -157,7 +175,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/qualtran/bloqs/factoring/ecc/ec_add.py b/qualtran/bloqs/factoring/ecc/ec_add.py index 74a57706a..5f8216777 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.py +++ b/qualtran/bloqs/factoring/ecc/ec_add.py @@ -12,12 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cached_property +from typing import Dict, Union +import numpy as np import sympy from attrs import frozen -from qualtran import Bloq, bloq_example, BloqDocSpec, QUInt, Register, Signature -from qualtran.bloqs.arithmetic._shims import MultiCToffoli +from qualtran import ( + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + DecomposeTypeError, + QBit, + QMontgomeryUInt, + Register, + Side, + Signature, + Soquet, + SoquetT, +) +from qualtran.bloqs.arithmetic.comparison import Equals +from qualtran.bloqs.basic_gates import CNOT, IntState, Toffoli, ZeroState +from qualtran.bloqs.bookkeeping import Free +from qualtran.bloqs.mcmt import MultiAnd, MultiControlX, MultiTargetCNOT from qualtran.bloqs.mod_arithmetic import ( CModAdd, CModNeg, @@ -30,6 +48,925 @@ ) from qualtran.bloqs.mod_arithmetic._shims import ModInv from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.simulation.classical_sim import ClassicalValT +from qualtran.symbolics.types import HasLength, is_symbolic + +from .ec_point import ECPoint + + +@frozen +class _ECAddStepOne(Bloq): + r"""Performs step one of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + + Registers: + f1: Flag to set if a = x. + f2: Flag to set if b = -y. + f3: Flag to set if (a, b) = (0, 0). + f4: Flag to set if (x, y) = (0, 0). + ctrl: Flag to set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit(), side=Side.RIGHT), + Register('f2', QBit(), side=Side.RIGHT), + Register('f3', QBit(), side=Side.RIGHT), + Register('f4', QBit(), side=Side.RIGHT), + Register('ctrl', QBit(), side=Side.RIGHT), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, a: 'ClassicalValT', b: 'ClassicalValT', x: 'ClassicalValT', y: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + f1 = int(a == x) + f2 = int(b == (-y % self.mod)) + f3 = int(a == b == 0) + f4 = int(x == y == 0) + ctrl = int(f2 == f3 == f4 == 0) + return { + 'f1': f1, + 'f2': f2, + 'f3': f3, + 'f4': f4, + 'ctrl': ctrl, + 'a': a, + 'b': b, + 'x': x, + 'y': y, + } + + def build_composite_bloq( + self, bb: 'BloqBuilder', a: Soquet, b: Soquet, x: Soquet, y: Soquet + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initialize control flags to 0. + f1 = bb.add(ZeroState()) + f2 = bb.add(ZeroState()) + f3 = bb.add(ZeroState()) + f4 = bb.add(ZeroState()) + ctrl = bb.add(ZeroState()) + + # Set flag 1 if a = x. + a, x, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=a, y=x, target=f1) + + # Set flag 2 if b = -y. + y = bb.add(ModNeg(QMontgomeryUInt(self.n), mod=self.mod), x=y) + b, y, f2 = bb.add(Equals(QMontgomeryUInt(self.n)), x=b, y=y, target=f2) + y = bb.add(ModNeg(QMontgomeryUInt(self.n), mod=self.mod), x=y) + + # Set flag 3 if (a, b) == (0, 0). + ab_arr = np.concatenate([bb.split(a), bb.split(b)]) + ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) + ab_arr = np.split(ab_arr, 2) + a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Set flag 4 if (x, y) == (0, 0). + xy_arr = np.concatenate([bb.split(x), bb.split(y)]) + xy_arr, f4 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=xy_arr, target=f4) + xy_arr = np.split(xy_arr, 2) + x = bb.join(xy_arr[0], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Set ctrl flag if f2, f3, f4 are set. + f_ctrls = [f2, f3, f4] + f_ctrls, ctrl = bb.add(MultiControlX(cvs=[0] * 3), controls=f_ctrls, target=ctrl) + f2 = f_ctrls[0] + f3 = f_ctrls[1] + f4 = f_ctrls[2] + + # Return the output registers. + return { + 'f1': f1, + 'f2': f2, + 'f3': f3, + 'f4': f4, + 'ctrl': ctrl, + 'a': a, + 'b': b, + 'x': x, + 'y': y, + } + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + cvs: Union[list[int], HasLength] + if isinstance(self.n, int): + cvs = [0] * 2 * self.n + else: + cvs = HasLength(2 * self.n) + return { + Equals(QMontgomeryUInt(self.n)): 2, + ModNeg(QMontgomeryUInt(self.n), mod=self.mod): 2, + MultiControlX(cvs=cvs): 2, + MultiControlX(cvs=[0] * 3): 1, + } + + +@frozen +class _ECAddStepTwo(Bloq): + r"""Performs step two of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + f1: Flag set if a = x. + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit()), + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n), side=Side.RIGHT), + Register('lam_r', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + f1: 'ClassicalValT', + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam_r: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + x = (x - a) % self.mod + if ctrl == 1: + y = (y - b) % self.mod + if f1 == 1: + lam = lam_r + f1 = 0 + else: + lam = QMontgomeryUInt(self.n).montgomery_product( + int(y), QMontgomeryUInt(self.n).montgomery_inverse(int(x), self.mod), self.mod + ) + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit + # which flips f1 when lam and lam_r are equal. + if lam == lam_r: + f1 = (f1 + 1) % 2 + else: + lam = 0 + return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + f1: Soquet, + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam_r: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initalize lambda to 0. + lam = bb.add(IntState(bitsize=self.n, val=0)) + + # Perform modular subtraction so that x = (x - a) % p. + a, x = bb.add(ModSub(QMontgomeryUInt(self.n), mod=self.mod), x=a, y=x) + + # Perform controlled modular subtraction so that y = (y - b) % p iff ctrl = 1. + ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) + + # Perform modular inversion s.t. x = (x - a)^-1 % p. + x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + + # Perform modular multiplication z4 = (y / x) % p. + x, y, z4, z3, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=y, + ) + + # If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p. + z4_split = bb.split(z4) + lam_split = bb.split(lam) + for i in range(self.n): + ctrls = [f1, ctrl, z4_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[0, 1, 1]), controls=ctrls, target=lam_split[i] + ) + f1 = ctrls[0] + ctrl = ctrls[1] + z4_split[i] = ctrls[2] + z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) + + # If ctrl = 1 and x = a: lam = lam_r. + lam_r_split = bb.split(lam_r) + for i in range(self.n): + ctrls = [f1, ctrl, lam_r_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i] + ) + f1 = ctrls[0] + ctrl = ctrls[1] + lam_r_split[i] = ctrls[2] + lam_r = bb.join(lam_r_split, dtype=QMontgomeryUInt(self.n)) + lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) + + # If lam = lam_r: return f1 = 0. (If not we will flip f1 to 0 at the end iff x_r = y_r = 0). + lam, lam_r, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=lam, y=lam_r, target=f1) + + # Uncompute the modular multiplication then the modular inversion. + x, y = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=y, + target=z4, + qrom_indices=z3, + reduced=reduced, + ) + x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + + # Return the output registers. + return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + Equals(QMontgomeryUInt(self.n)): 1, + ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModInv(n=self.n, mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + MultiControlX(cvs=[0, 1, 1]): self.n, + MultiControlX(cvs=[1, 1, 1]): self.n, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + ModInv(n=self.n, mod=self.mod).adjoint(): 1, + } + + +@frozen +class _ECAddStepThree(Bloq): + r"""Performs step three of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 1: + x = (x + 3 * a) % self.mod + y = 0 + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Store (x - a) * lam % p in z1 (= (y - b) % p). + x, lam, z1, z2, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=lam, + ) + + # If ctrl: subtract z1 from y (= 0). + ctrl, z1, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=z1, y=y) + + # Uncompute original multiplication. + x, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=lam, + target=z1, + qrom_indices=z2, + reduced=reduced, + ) + + # z1 = a. + z1 = bb.add(IntState(bitsize=self.n, val=0)) + a_split = bb.split(a) + z1_split = bb.split(z1) + for i in range(self.n): + a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) + a = bb.join(a_split, QMontgomeryUInt(self.n)) + z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) + + # z1 = (3 * a) % p. + z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod), x=z1) + a, z1 = bb.add(ModAdd(self.n, mod=self.mod), x=a, y=z1) + + # If ctrl: x = (x + 2 * a) % p. + ctrl, z1, x = bb.add(CModAdd(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=z1, y=x) + + # Uncompute z1. + a, z1 = bb.add(ModAdd(self.n, mod=self.mod).adjoint(), x=a, y=z1) + z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(), x=z1) + a_split = bb.split(a) + z1_split = bb.split(z1) + for i in range(self.n): + a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) + a = bb.join(a_split, QMontgomeryUInt(self.n)) + z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) + bb.add(Free(QMontgomeryUInt(self.n)), reg=z1) + + # Return the output registers. + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + CNOT(): 2 * self.n, + ModDbl(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModAdd(self.n, mod=self.mod): 1, + CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModAdd(self.n, mod=self.mod).adjoint(): 1, + ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(): 1, + } + + +@frozen +class _ECAddStepFour(Bloq): + r"""Performs step four of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + x = ( + x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), self.mod) + ) % self.mod + if lam > 0: + y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), self.mod) + return {'x': x, 'y': y, 'lam': lam} + + def build_composite_bloq( + self, bb: 'BloqBuilder', x: Soquet, y: Soquet, lam: Soquet + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initialize z4 = lam. + z4 = bb.add(IntState(bitsize=self.n, val=0)) + lam_split = bb.split(lam) + z4_split = bb.split(z4) + for i in range(self.n): + lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) + lam = bb.join(lam_split, QMontgomeryUInt(self.n)) + z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) + + # z3 = lam * lam % p. + z4, lam, z3, z2, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=z4, + y=lam, + ) + + # x = a - x_r % p. + z3, x = bb.add(ModSub(QMontgomeryUInt(self.n), mod=self.mod), x=z3, y=x) + + # Uncompute the multiplication and initialization of z4. + z4, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=z4, + y=lam, + target=z3, + qrom_indices=z2, + reduced=reduced, + ) + lam_split = bb.split(lam) + z4_split = bb.split(z4) + for i in range(self.n): + lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) + lam = bb.join(lam_split, QMontgomeryUInt(self.n)) + z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) + bb.add(Free(QMontgomeryUInt(self.n)), reg=z4) + + # z3 = lam * x % p. + x, lam, z3, z4, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=lam, + ) + + # y = y_r + b % p. + z3_split = bb.split(z3) + y_split = bb.split(y) + for i in range(self.n): + z3_split[i], y_split[i] = bb.add(CNOT(), ctrl=z3_split[i], target=y_split[i]) + z3 = bb.join(z3_split, QMontgomeryUInt(self.n)) + y = bb.join(y_split, QMontgomeryUInt(self.n)) + + # Uncompute multiplication. + x, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=lam, + target=z3, + qrom_indices=z4, + reduced=reduced, + ) + + # Return the output registers. + return {'x': x, 'y': y, 'lam': lam} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 2, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 2, + CNOT(): 3 * self.n, + } + + +@frozen +class _ECAddStepFive(Bloq): + r"""Performs step five of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n), side=Side.LEFT), + ] + ) + + def on_classical_vals( + self, + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 1: + x = (a - x) % self.mod + y = (y - b) % self.mod + else: + x = (x + a) % self.mod + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # x = x ^ -1 % p. + x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + + # z4 = x * y % p. + x, y, z4, z3, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=y, + ) + + # If ctrl: lam = 0. + z4_split = bb.split(z4) + lam_split = bb.split(lam) + for i in range(self.n): + ctrls = [ctrl, z4_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[1, 1]), controls=ctrls, target=lam_split[i] + ) + ctrl = ctrls[0] + z4_split[i] = ctrls[1] + z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) + lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit where lambda + # is not set to 0 before being freed. + bb.add(Free(QMontgomeryUInt(self.n), dirty=True), reg=lam) + + # Uncompute multiplication and inverse. + x, y = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=y, + target=z4, + qrom_indices=z3, + reduced=reduced, + ) + x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + + # If ctrl: x = x_r - a % p. + ctrl, x = bb.add(CModNeg(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=x) + + # Add a to x (x = x_r). + a, x = bb.add(ModAdd(self.n, mod=self.mod), x=a, y=x) + + # If ctrl: subtract b from y (y = y_r). + ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) + + # Return the output registers. + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModInv(n=self.n, mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + ModInv(n=self.n, mod=self.mod).adjoint(): 1, + ModAdd(self.n, mod=self.mod): 1, + MultiControlX(cvs=[1, 1]): self.n, + CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1, + } + + +@frozen +class _ECAddStepSix(Bloq): + r"""Performs step six of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + + Registers: + f1: Flag to set if a = x. + f2: Flag to set if b = -y. + f3: Flag to set if (a, b) = (0, 0). + f4: Flag to set if (x, y) = (0, 0). + ctrl: Flag to set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit(), side=Side.LEFT), + Register('f2', QBit(), side=Side.LEFT), + Register('f3', QBit(), side=Side.LEFT), + Register('f4', QBit(), side=Side.LEFT), + Register('ctrl', QBit(), side=Side.LEFT), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + f1: 'ClassicalValT', + f2: 'ClassicalValT', + f3: 'ClassicalValT', + f4: 'ClassicalValT', + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if f4 == 1: + x = a + y = b + if f1 and f2: + x = 0 + y = 0 + return {'a': a, 'b': b, 'x': x, 'y': y} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + f1: Soquet, + f2: Soquet, + f3: Soquet, + f4: Soquet, + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Unset control if f2, f3, and f4 flags are set. + f_ctrls = [f2, f3, f4] + f_ctrls, ctrl = bb.add(MultiControlX(cvs=[0] * 3), controls=f_ctrls, target=ctrl) + f2 = f_ctrls[0] + f3 = f_ctrls[1] + f4 = f_ctrls[2] + + # Set (x, y) to (a, b) if f4 is set. + a_split = bb.split(a) + x_split = bb.split(x) + for i in range(self.n): + toff_ctrl = [f4, a_split[i]] + toff_ctrl, x_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=x_split[i]) + f4 = toff_ctrl[0] + a_split[i] = toff_ctrl[1] + a = bb.join(a_split, QMontgomeryUInt(self.n)) + x = bb.join(x_split, QMontgomeryUInt(self.n)) + b_split = bb.split(b) + y_split = bb.split(y) + for i in range(self.n): + toff_ctrl = [f4, b_split[i]] + toff_ctrl, y_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=y_split[i]) + f4 = toff_ctrl[0] + b_split[i] = toff_ctrl[1] + b = bb.join(b_split, QMontgomeryUInt(self.n)) + y = bb.join(y_split, QMontgomeryUInt(self.n)) + + # Unset f4 if (x, y) = (a, b). + ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n)) + xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) + ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) + ab_split = bb.split(ab) + a = bb.join(ab_split[: self.n], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_split[self.n :], dtype=QMontgomeryUInt(self.n)) + xy_split = bb.split(xy) + x = bb.join(xy_split[: self.n], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_split[self.n :], dtype=QMontgomeryUInt(self.n)) + + # Unset f3 if (a, b) = (0, 0). + ab_arr = np.concatenate([bb.split(a), bb.split(b)]) + ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) + ab_arr = np.split(ab_arr, 2) + a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) + + # If f1 and f2 are set, subtract a from x and add b to y. + ancilla = bb.add(ZeroState()) + toff_ctrl = [f1, f2] + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + ancilla, a, x = bb.add( + CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ancilla, x=a, y=x + ) + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + f1 = toff_ctrl[0] + f2 = toff_ctrl[1] + bb.add(Free(QBit()), reg=ancilla) + ancilla = bb.add(ZeroState()) + toff_ctrl = [f1, f2] + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + ancilla, b, y = bb.add( + CModAdd(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ancilla, x=b, y=y + ) + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + f1 = toff_ctrl[0] + f2 = toff_ctrl[1] + bb.add(Free(QBit()), reg=ancilla) + + # Unset f1 and f2 if (x, y) = (0, 0). + xy_arr = np.concatenate([bb.split(x), bb.split(y)]) + xy_arr, junk, out = bb.add(MultiAnd(cvs=[0] * 2 * self.n), ctrl=xy_arr) + targets = bb.join(np.array([f1, f2])) + out, targets = bb.add(MultiTargetCNOT(2), control=out, targets=targets) + targets = bb.split(targets) + f1 = targets[0] + f2 = targets[1] + xy_arr = bb.add( + MultiAnd(cvs=[0] * 2 * self.n).adjoint(), ctrl=xy_arr, junk=junk, target=out + ) + xy_arr = np.split(xy_arr, 2) + x = bb.join(xy_arr[0], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Free all ancilla qubits in the zero state. + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bugs in circuit where f1, + # f2, and f4 are freed before being set to 0. + bb.add(Free(QBit(), dirty=True), reg=f1) + bb.add(Free(QBit(), dirty=True), reg=f2) + bb.add(Free(QBit()), reg=f3) + bb.add(Free(QBit(), dirty=True), reg=f4) + bb.add(Free(QBit()), reg=ctrl) + + # Return the output registers. + return {'a': a, 'b': b, 'x': x, 'y': y} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + cvs: Union[list[int], HasLength] + if isinstance(self.n, int): + cvs = [0] * 2 * self.n + else: + cvs = HasLength(2 * self.n) + return { + MultiControlX(cvs=cvs): 1, + MultiControlX(cvs=[0] * 3): 1, + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, + Toffoli(): 2 * self.n + 4, + Equals(QMontgomeryUInt(2 * self.n)): 1, + MultiAnd(cvs=cvs): 1, + MultiTargetCNOT(2): 1, + MultiAnd(cvs=cvs).adjoint(): 1, + } @frozen @@ -39,18 +976,24 @@ class ECAdd(Bloq): This takes elliptic curve points given by (a, b) and (x, y) and outputs the sum (x_r, y_r) in the second pair of registers. + Because the decomposition of this Bloq is complex, we split it into six separate parts + corresponding to the parts described in figure 10 of the Litinski paper cited below. We follow + the signature from figure 5 and break down the further decompositions based on the steps in + figure 10. + Args: n: The bitsize of the two registers storing the elliptic curve point mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. Registers: - a: The x component of the first input elliptic curve point of bitsize `n`. - b: The y component of the first input elliptic curve point of bitsize `n`. - x: The x component of the second input elliptic curve point of bitsize `n`, which + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the x component of the resultant curve point. - y: The y component of the second input elliptic curve point of bitsize `n`, which + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the y component of the resultant curve point. - lam: The precomputed lambda slope used in the addition operation. + lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. References: [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). @@ -59,32 +1002,108 @@ class ECAdd(Bloq): n: int mod: int + window_size: int = 1 @cached_property def signature(self) -> 'Signature': return Signature( [ - Register('a', QUInt(self.n)), - Register('b', QUInt(self.n)), - Register('x', QUInt(self.n)), - Register('y', QUInt(self.n)), - Register('lam', QUInt(self.n)), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam_r', QMontgomeryUInt(self.n)), ] ) - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: - # litinksi + def build_composite_bloq( + self, bb: 'BloqBuilder', a: Soquet, b: Soquet, x: Soquet, y: Soquet, lam_r: Soquet + ) -> Dict[str, 'SoquetT']: + f1, f2, f3, f4, ctrl, a, b, x, y = bb.add( + _ECAddStepOne(n=self.n, mod=self.mod), a=a, b=b, x=x, y=y + ) + f1, ctrl, a, b, x, y, lam, lam_r = bb.add( + _ECAddStepTwo(n=self.n, mod=self.mod, window_size=self.window_size), + f1=f1, + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam_r=lam_r, + ) + ctrl, a, b, x, y, lam = bb.add( + _ECAddStepThree(n=self.n, mod=self.mod, window_size=self.window_size), + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam=lam, + ) + x, y, lam = bb.add( + _ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size), x=x, y=y, lam=lam + ) + ctrl, a, b, x, y = bb.add( + _ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size), + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam=lam, + ) + a, b, x, y = bb.add( + _ECAddStepSix(n=self.n, mod=self.mod), + f1=f1, + f2=f2, + f3=f3, + f4=f4, + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + ) + + return {'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r} + + def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]: + curve_a = ( + QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, self.mod) + * 2 + * QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod) + - (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod) ** 2) + ) % self.mod + p1 = ECPoint( + QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod), + mod=self.mod, + curve_a=curve_a, + ) + p2 = ECPoint( + QMontgomeryUInt(self.n).montgomery_to_uint(x, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(y, self.mod), + mod=self.mod, + curve_a=curve_a, + ) + result = p1 + p2 + return { + 'a': a, + 'b': b, + 'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, self.mod), + 'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, self.mod), + 'lam_r': lam_r, + } + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { - MultiCToffoli(n=self.n): 18, - ModAdd(bitsize=self.n, mod=self.mod): 3, - CModAdd(QUInt(self.n), mod=self.mod): 2, - ModSub(QUInt(self.n), mod=self.mod): 2, - CModSub(QUInt(self.n), mod=self.mod): 4, - ModNeg(QUInt(self.n), mod=self.mod): 2, - CModNeg(QUInt(self.n), mod=self.mod): 1, - ModDbl(QUInt(self.n), mod=self.mod): 2, - DirtyOutOfPlaceMontgomeryModMul(bitsize=self.n, window_size=4, mod=self.mod): 10, - ModInv(n=self.n, mod=self.mod): 4, + _ECAddStepOne(n=self.n, mod=self.mod): 1, + _ECAddStepTwo(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepThree(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepSix(n=self.n, mod=self.mod): 1, } @@ -95,4 +1114,10 @@ def _ec_add() -> ECAdd: return ec_add -_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add]) +@bloq_example +def _ec_add_small() -> ECAdd: + ec_add_small = ECAdd(5, mod=7) + return ec_add_small + + +_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add, _ec_add_small]) diff --git a/qualtran/bloqs/factoring/ecc/ec_add_test.py b/qualtran/bloqs/factoring/ecc/ec_add_test.py index 44a8b77e1..5295316f1 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_add_test.py @@ -12,13 +12,423 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest +import sympy + import qualtran.testing as qlt_testing -from qualtran.bloqs.factoring.ecc.ec_add import _ec_add +from qualtran._infra.data_types import QMontgomeryUInt +from qualtran.bloqs.factoring.ecc.ec_add import ( + _ec_add, + _ec_add_small, + _ECAddStepFive, + _ECAddStepFour, + _ECAddStepOne, + _ECAddStepSix, + _ECAddStepThree, + _ECAddStepTwo, + ECAdd, +) +from qualtran.resource_counting._bloq_counts import QECGatesCost +from qualtran.resource_counting._costing import get_cost_value +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join + + +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 8) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize('a,b', [(15, 13), (2, 10)]) +@pytest.mark.parametrize('x,y', [(15, 13), (0, 0)]) +def test_ec_add_steps_classical_fast(n, m, a, b, x, y): + p = 17 + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + + a = QMontgomeryUInt(n).uint_to_montgomery(a, p) + b = QMontgomeryUInt(n).uint_to_montgomery(b, p) + x = QMontgomeryUInt(n).uint_to_montgomery(x, p) + y = QMontgomeryUInt(n).uint_to_montgomery(y, p) + lam_r = QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p + + bloq = _ECAddStepOne(n=n, mod=p) + ret1 = bloq.call_classically(a=a, b=b, x=x, y=y) + ret2 = bloq.decompose_bloq().call_classically(a=a, b=b, x=x, y=y) + assert ret1 == ret2 + + step_1 = _ECAddStepOne(n=n, mod=p).on_classical_vals(a=a, b=b, x=x, y=y) + bloq = _ECAddStepTwo(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + assert ret1 == ret2 + + step_2 = _ECAddStepTwo(n=n, mod=p, window_size=m).on_classical_vals( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + bloq = _ECAddStepThree(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + assert ret1 == ret2 + + step_3 = _ECAddStepThree(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + bloq = _ECAddStepFour(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + ret2 = bloq.decompose_bloq().call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + assert ret1 == ret2 + + step_4 = _ECAddStepFour(n=n, mod=p, window_size=m).on_classical_vals( + x=step_3['x'], y=step_3['y'], lam=step_3['lam'] + ) + bloq = _ECAddStepFive(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + assert ret1 == ret2 + + step_5 = _ECAddStepFive(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + bloq = _ECAddStepSix(n=n, mod=p) + ret1 = bloq.call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + assert ret1 == ret2 + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize( + 'a,b', + [ + (15, 13), + (2, 10), + (8, 3), + (12, 1), + (6, 6), + (5, 8), + (10, 15), + (1, 12), + (3, 0), + (1, 5), + (10, 2), + (0, 0), + ], +) +@pytest.mark.parametrize('x,y', [(15, 13), (5, 8), (10, 15), (1, 12), (3, 0), (1, 5), (10, 2)]) +def test_ec_add_steps_classical(n, m, a, b, x, y): + p = 17 + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + + a = QMontgomeryUInt(n).uint_to_montgomery(a, p) + b = QMontgomeryUInt(n).uint_to_montgomery(b, p) + x = QMontgomeryUInt(n).uint_to_montgomery(x, p) + y = QMontgomeryUInt(n).uint_to_montgomery(y, p) + lam_r = QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p + + bloq = _ECAddStepOne(n=n, mod=p) + ret1 = bloq.call_classically(a=a, b=b, x=x, y=y) + ret2 = bloq.decompose_bloq().call_classically(a=a, b=b, x=x, y=y) + assert ret1 == ret2 + + step_1 = _ECAddStepOne(n=n, mod=p).on_classical_vals(a=a, b=b, x=x, y=y) + bloq = _ECAddStepTwo(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + assert ret1 == ret2 + + step_2 = _ECAddStepTwo(n=n, mod=p, window_size=m).on_classical_vals( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + bloq = _ECAddStepThree(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + assert ret1 == ret2 + + step_3 = _ECAddStepThree(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + bloq = _ECAddStepFour(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + ret2 = bloq.decompose_bloq().call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + assert ret1 == ret2 + + step_4 = _ECAddStepFour(n=n, mod=p, window_size=m).on_classical_vals( + x=step_3['x'], y=step_3['y'], lam=step_3['lam'] + ) + bloq = _ECAddStepFive(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + assert ret1 == ret2 + + step_5 = _ECAddStepFive(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + bloq = _ECAddStepSix(n=n, mod=p) + ret1 = bloq.call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + assert ret1 == ret2 + + +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 8) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize('a,b', [(15, 13), (2, 10)]) +@pytest.mark.parametrize('x,y', [(15, 13), (0, 0)]) +def test_ec_add_classical_fast(n, m, a, b, x, y): + p = 17 + bloq = ECAdd(n=n, mod=p, window_size=m) + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + ret1 = bloq.call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + ret2 = bloq.decompose_bloq().call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + assert ret1 == ret2 + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize( + 'a,b', + [ + (15, 13), + (2, 10), + (8, 3), + (12, 1), + (6, 6), + (5, 8), + (10, 15), + (1, 12), + (3, 0), + (1, 5), + (10, 2), + (0, 0), + ], +) +@pytest.mark.parametrize('x,y', [(15, 13), (5, 8), (10, 15), (1, 12), (3, 0), (1, 5), (10, 2)]) +def test_ec_add_classical(n, m, a, b, x, y): + p = 17 + bloq = ECAdd(n=n, mod=p, window_size=m) + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + ret1 = bloq.call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + ret2 = bloq.decompose_bloq().call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + assert ret1 == ret2 + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize( + ['n', 'window_size'], + [ + (n, window_size) + for n in range(5, 8) + for window_size in range(1, n + 1) + if n % window_size == 0 + ], +) +def test_ec_add_decomposition(n, window_size, p): + b = ECAdd(n=n, window_size=window_size, mod=p) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize( + ['n', 'window_size'], + [ + (n, window_size) + for n in range(5, 8) + for window_size in range(1, n + 1) + if n % window_size == 0 + ], +) +def test_ec_add_bloq_counts(n, window_size, p): + b = ECAdd(n=n, window_size=window_size, mod=p) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +def test_ec_add_symbolic_cost(): + n, m, p = sympy.symbols('n m p', integer=True) + + # In Litinski 2023 https://arxiv.org/abs/2306.08585 a window size of 4 is used. + # The cost function generally has floor/ceil division that disappear for bitsize=0 mod 4. + # This is why instead of using bitsize=n directly, we use bitsize=4*m=n. + b = ECAdd(n=4 * m, window_size=4, mod=p) + cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() + assert cost['n_t'] == 0 + + # Litinski 2023 https://arxiv.org/abs/2306.08585 + # Based on the counts from Figures 3, 5, and 8 the toffoli count for ECAdd is 126.5n^2 + 189n. + # The following formula is 126.5n^2 + 175.5n - 35. We account for the discrepancy in the + # coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, n extra toffolis + # in ModNeg, and 2n extra toffolis to do n 3-controlled toffolis in step 2. The expression is + # written with rationals because sympy comparison fails with floats. + assert isinstance(cost['n_ccz'], sympy.Expr) + assert ( + cost['n_ccz'].subs(m, n / 4).expand() + == sympy.Rational(253, 2) * n**2 + sympy.Rational(351, 2) * n - 35 + ) def test_ec_add(bloq_autotester): bloq_autotester(_ec_add) +def test_ec_add_small(bloq_autotester): + bloq_autotester(_ec_add_small) + + def test_notebook(): qlt_testing.execute_notebook('ec_add') diff --git a/qualtran/bloqs/factoring/ecc/ec_point.py b/qualtran/bloqs/factoring/ecc/ec_point.py index 968ebe0b5..c17ea5957 100644 --- a/qualtran/bloqs/factoring/ecc/ec_point.py +++ b/qualtran/bloqs/factoring/ecc/ec_point.py @@ -69,7 +69,8 @@ def __add__(self, other): return ECPoint(xr, yr, mod=self.mod, curve_a=self.curve_a) def __mul__(self, other): - assert other > 0, other + if other == 0: + return ECPoint.inf(mod=self.mod, curve_a=self.curve_a) x = self for _ in range(other - 1): x = x + self diff --git a/qualtran/bloqs/factoring/ecc/ec_point_test.py b/qualtran/bloqs/factoring/ecc/ec_point_test.py index f65981c80..1ac1b59bd 100644 --- a/qualtran/bloqs/factoring/ecc/ec_point_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_point_test.py @@ -21,6 +21,7 @@ def test_ec_point_overrides(): assert 1 * p == p assert 2 * p == (p + p) assert 3 * p == (p + p + p) + assert 0 * p == ECPoint.inf(mod=17, curve_a=0) def test_ec_point_addition(): diff --git a/qualtran/bloqs/mod_arithmetic/_shims.py b/qualtran/bloqs/mod_arithmetic/_shims.py index 65b61b4ca..7242f6afa 100644 --- a/qualtran/bloqs/mod_arithmetic/_shims.py +++ b/qualtran/bloqs/mod_arithmetic/_shims.py @@ -24,14 +24,17 @@ from functools import cached_property from typing import Dict, Optional, Tuple, TYPE_CHECKING +import attrs from attrs import frozen -from qualtran import Bloq, QUInt, Register, Signature -from qualtran.bloqs.arithmetic import Add, AddK, Negate, Subtract -from qualtran.bloqs.arithmetic._shims import CHalf, Lt, MultiCToffoli +from qualtran import Bloq, QUInt, Register, Side, Signature +from qualtran.bloqs.arithmetic import AddK, Negate +from qualtran.bloqs.arithmetic._shims import CHalf, CSub, Lt, MultiCToffoli +from qualtran.bloqs.arithmetic.controlled_addition import CAdd from qualtran.bloqs.basic_gates import CNOT, CSwap, Swap, Toffoli from qualtran.bloqs.mod_arithmetic.mod_multiplication import ModDbl from qualtran.drawing import Text, TextBox, WireSymbol +from qualtran.simulation.classical_sim import ClassicalValT if TYPE_CHECKING: from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -57,9 +60,9 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': (CNOT(), 2), (Lt(self.n), 1), (CSwap(self.n), 2), - (Subtract(QUInt(self.n)), 1), - (Add(QUInt(self.n)), 1), - (CNOT(), 1), + (CSub(self.n), 1), + (CAdd(QUInt(self.n)), 1), + (CNOT(), 2), (ModDbl(QUInt(self.n), self.mod), 1), (CHalf(self.n), 1), (CSwap(self.n), 2), @@ -88,10 +91,21 @@ def wire_symbol( class ModInv(Bloq): n: int mod: int + uncompute: bool = False @cached_property def signature(self) -> 'Signature': - return Signature([Register('x', QUInt(self.n)), Register('out', QUInt(self.n))]) + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('x', QUInt(self.n)), + Register('garbage1', QUInt(self.n), side=side), + Register('garbage2', QUInt(self.n), side=side), + ] + ) + + def adjoint(self) -> 'ModInv': + return attrs.evolve(self, uncompute=self.uncompute ^ True) def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': # Roetteler @@ -103,6 +117,29 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': Swap(self.n): 1, } + def on_classical_vals( + self, + x: 'ClassicalValT', + garbage1: Optional['ClassicalValT'] = None, + garbage2: Optional['ClassicalValT'] = None, + ) -> Dict[str, ClassicalValT]: + # TODO(https://github.com/quantumlib/Qualtran/issues/1443): Hacky classical simulation just + # to confirm correctness of ECAdd circuit. + if self.uncompute: + assert garbage1 is not None + assert garbage2 is not None + return {'x': garbage1} + assert garbage1 is None + assert garbage2 is None + + # Store the original x in the garbage registers for the uncompute simulation. + garbage1 = x + garbage2 = x + + x = pow(int(x), self.mod - 2, mod=self.mod) * pow(2, 2 * self.n, self.mod) % self.mod + + return {'x': x, 'garbage1': garbage1, 'garbage2': garbage2} + def wire_symbol( self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() ) -> 'WireSymbol': @@ -112,4 +149,8 @@ def wire_symbol( return TextBox('x') elif reg.name == 'out': return TextBox('$x^{-1}$') + elif reg.name == 'garbage1': + return TextBox('garbage1') + elif reg.name == 'garbage2': + return TextBox('garbage2') raise ValueError(f'Unrecognized register name {reg.name}') diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index 95f74edc7..9f289add1 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -650,7 +650,8 @@ def on_classical_vals( raise ValueError(f'classical action is not supported for {self}') if self.uncompute: assert ( - target is not None and target == (x * y * pow(2, self.bitsize, self.mod)) % self.mod + target is not None + and target == (x * y * pow(2, self.bitsize * (self.mod - 2), self.mod)) % self.mod ) assert qrom_indices is not None assert reduced is not None diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index bc3be1a65..347c78b61 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -98,6 +98,7 @@ import qualtran.bloqs.data_loading.qrom import qualtran.bloqs.data_loading.select_swap_qrom import qualtran.bloqs.factoring._factoring_shims +import qualtran.bloqs.factoring.ecc.ec_add import qualtran.bloqs.factoring.rsa import qualtran.bloqs.for_testing.atom import qualtran.bloqs.for_testing.casting @@ -348,6 +349,13 @@ "qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul, "qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul, "qualtran.bloqs.factoring._factoring_shims.MeasureQFT": qualtran.bloqs.factoring._factoring_shims.MeasureQFT, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepThree": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepThree, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFour": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFour, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFive": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFive, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepSix": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepSix, + "qualtran.bloqs.factoring.ecc.ec_add.ECAdd": qualtran.bloqs.factoring.ecc.ec_add.ECAdd, "qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate": qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate, "qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp": qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp, "qualtran.bloqs.for_testing.atom.TestAtom": qualtran.bloqs.for_testing.atom.TestAtom,