From 0220df23d2c58339ef17004e8ae3c63397510d86 Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Thu, 17 Oct 2024 16:18:39 -0700 Subject: [PATCH 1/7] Implement CtrlScaleModAdd and CModAddK bloqs for Modular Exponentiation (#1432) * Add some classical simulation * Implement primitives for ModExp * Fix serialization test error * Change Union -> SymbolicInt * Fix nits * Better symbolic decomposition error messages * Fix merge conflicts * Fixed docstring to be more readable (hopefully) * Address nits --------- Co-authored-by: Matthew Harrigan --- .../qualtran_dev_tools/notebook_specs.py | 2 + qualtran/bloqs/basic_gates/z_basis.py | 5 +- .../bloqs/mod_arithmetic/mod_addition.ipynb | 240 ++++++++++++++++++ qualtran/bloqs/mod_arithmetic/mod_addition.py | 125 ++++++++- .../bloqs/mod_arithmetic/mod_addition_test.py | 81 ++++-- qualtran/serialization/resolver_dict.py | 4 +- 6 files changed, 421 insertions(+), 36 deletions(-) diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index f96b511e4..1b5f219b4 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -493,6 +493,8 @@ qualtran.bloqs.mod_arithmetic.mod_addition._MOD_ADD_DOC, qualtran.bloqs.mod_arithmetic.mod_addition._MOD_ADD_K_DOC, qualtran.bloqs.mod_arithmetic.mod_addition._C_MOD_ADD_DOC, + qualtran.bloqs.mod_arithmetic.mod_addition._C_MOD_ADD_K_DOC, + qualtran.bloqs.mod_arithmetic.mod_addition._CTRL_SCALE_MOD_ADD_DOC, ], ), NotebookSpecV2( diff --git a/qualtran/bloqs/basic_gates/z_basis.py b/qualtran/bloqs/basic_gates/z_basis.py index b0f39b132..40bd02255 100644 --- a/qualtran/bloqs/basic_gates/z_basis.py +++ b/qualtran/bloqs/basic_gates/z_basis.py @@ -42,6 +42,7 @@ ) from qualtran.bloqs.bookkeeping import ArbitraryClifford from qualtran.drawing import Circle, directional_text_box, Text, TextBox, WireSymbol +from qualtran.symbolics import SymbolicInt if TYPE_CHECKING: import cirq @@ -453,7 +454,7 @@ class IntState(_IntVector): val: The register of size `bitsize` which initializes the value `val`. """ - def __init__(self, val: Union[int, sympy.Expr], bitsize: Union[int, sympy.Expr]): + def __init__(self, val: SymbolicInt, bitsize: SymbolicInt): self.__attrs_init__(val=val, bitsize=bitsize, state=True) @@ -478,7 +479,7 @@ class IntEffect(_IntVector): val: The register of size `bitsize` which de-allocates the value `val`. """ - def __init__(self, val: int, bitsize: int): + def __init__(self, val: SymbolicInt, bitsize: SymbolicInt): self.__attrs_init__(val=val, bitsize=bitsize, state=False) diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb b/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb index dab6f57c5..b081513d2 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb +++ b/qualtran/bloqs/mod_arithmetic/mod_addition.ipynb @@ -394,6 +394,246 @@ "show_call_graph(cmodadd_example_g)\n", "show_counts_sigma(cmodadd_example_sigma)" ] + }, + { + "cell_type": "markdown", + "id": "0523961e", + "metadata": { + "cq.autogen": "CModAddK.bloq_doc.md" + }, + "source": [ + "## `CModAddK`\n", + "Perform x += k mod m for constant k, m and quantum x.\n", + "\n", + "#### Parameters\n", + " - `k`: The integer to add to `x`.\n", + " - `mod`: The modulus for the addition.\n", + " - `bitsize`: The bitsize of the `x` register. \n", + "\n", + "#### Registers\n", + " - `ctrl`: The control bit\n", + " - `x`: The register to perform the in-place modular addition. \n", + "\n", + "#### References\n", + " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Gidney and Ekerå 2019. The reference implementation in section 2.2 uses CModAddK, but the circuit that it points to is just ModAdd (not ModAddK). This ModAdd is less efficient than the circuit later introduced in the Litinski paper so we choose to use that since it is more efficient and already implemented in Qualtran.\n", + " - [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). Litinski et al. 2023. This CModAdd circuit uses 2 fewer additions than the implementation referenced in the paper above. Because of this we choose to use this CModAdd bloq instead.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "341211fa", + "metadata": { + "cq.autogen": "CModAddK.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.mod_arithmetic import CModAddK" + ] + }, + { + "cell_type": "markdown", + "id": "6b00aef7", + "metadata": { + "cq.autogen": "CModAddK.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8877f17", + "metadata": { + "cq.autogen": "CModAddK.cmod_add_k" + }, + "outputs": [], + "source": [ + "n, m, k = sympy.symbols('n m k')\n", + "cmod_add_k = CModAddK(bitsize=n, mod=m, k=k)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87fe8b8f", + "metadata": { + "cq.autogen": "CModAddK.cmod_add_k_small" + }, + "outputs": [], + "source": [ + "cmod_add_k_small = CModAddK(bitsize=4, mod=7, k=1)" + ] + }, + { + "cell_type": "markdown", + "id": "f44f0268", + "metadata": { + "cq.autogen": "CModAddK.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7a02d6e", + "metadata": { + "cq.autogen": "CModAddK.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([cmod_add_k, cmod_add_k_small],\n", + " ['`cmod_add_k`', '`cmod_add_k_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "48b87ff2", + "metadata": { + "cq.autogen": "CModAddK.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6e6770b", + "metadata": { + "cq.autogen": "CModAddK.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "cmod_add_k_g, cmod_add_k_sigma = cmod_add_k.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(cmod_add_k_g)\n", + "show_counts_sigma(cmod_add_k_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "21f93349", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.bloq_doc.md" + }, + "source": [ + "## `CtrlScaleModAdd`\n", + "Perform y += x*k mod m for constant k, m and quantum x, y.\n", + "\n", + "#### Parameters\n", + " - `k`: The constant integer to scale `x` before adding into `y`.\n", + " - `mod`: The modulus of the addition\n", + " - `bitsize`: The size of the two registers. \n", + "\n", + "#### Registers\n", + " - `ctrl`: The control bit\n", + " - `x`: The 'source' quantum register containing the integer to be scaled and added to `y`.\n", + " - `y`: The 'destination' quantum register to which the addition will apply. \n", + "\n", + "#### References\n", + " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Construction based on description in section 2.2 paragraph 4. We add n And/And† bloqs because the bloq is controlled, but the construction also involves modular addition controlled on the qubits comprising register x.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ac170e5", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.mod_arithmetic import CtrlScaleModAdd" + ] + }, + { + "cell_type": "markdown", + "id": "58ee7de2", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73c2c6f7", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.ctrl_scale_mod_add" + }, + "outputs": [], + "source": [ + "n, m, k = sympy.symbols('n m k')\n", + "ctrl_scale_mod_add = CtrlScaleModAdd(bitsize=n, mod=m, k=k)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e822d5eb", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.ctrl_scale_mod_add_small" + }, + "outputs": [], + "source": [ + "ctrl_scale_mod_add_small = CtrlScaleModAdd(bitsize=4, mod=7, k=1)" + ] + }, + { + "cell_type": "markdown", + "id": "fe4c8957", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4dc8923", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([ctrl_scale_mod_add, ctrl_scale_mod_add_small],\n", + " ['`ctrl_scale_mod_add`', '`ctrl_scale_mod_add_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "97d6888d", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd734f6b", + "metadata": { + "cq.autogen": "CtrlScaleModAdd.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "ctrl_scale_mod_add_g, ctrl_scale_mod_add_sigma = ctrl_scale_mod_add.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(ctrl_scale_mod_add_g)\n", + "show_counts_sigma(ctrl_scale_mod_add_sigma)" + ] } ], "metadata": { diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition.py b/qualtran/bloqs/mod_arithmetic/mod_addition.py index efe224946..a8186cdfe 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition.py @@ -23,6 +23,7 @@ Bloq, bloq_example, BloqDocSpec, + DecomposeTypeError, GateWithRegisters, QBit, QMontgomeryUInt, @@ -35,8 +36,9 @@ from qualtran.bloqs.arithmetic.addition import Add, AddK from qualtran.bloqs.arithmetic.comparison import CLinearDepthGreaterThan, LinearDepthGreaterThan from qualtran.bloqs.arithmetic.controlled_addition import CAdd -from qualtran.bloqs.basic_gates import XGate +from qualtran.bloqs.basic_gates import IntEffect, IntState, XGate from qualtran.bloqs.bookkeeping import Cast +from qualtran.bloqs.mcmt.and_bloq import And from qualtran.drawing import Circle, Text, TextBox, WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.resource_counting.generalizers import ignore_split_join @@ -89,7 +91,7 @@ def on_classical_vals( def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']: if is_symbolic(self.bitsize): - raise NotImplementedError(f'symbolic decomposition is not supported for {self}') + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") # Allocate ancilla bits for use in addition. junk_bit = bb.allocate(n=1) sign = bb.allocate(n=1) @@ -269,6 +271,19 @@ class CModAddK(Bloq): Registers: ctrl: The control bit x: The register to perform the in-place modular addition. + + References: + [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). + Gidney and Ekerå 2019. + The reference implementation in section 2.2 uses CModAddK, but the circuit that it points + to is just ModAdd (not ModAddK). This ModAdd is less efficient than the circuit later + introduced in the Litinski paper so we choose to use that since it is more efficient and + already implemented in Qualtran. + + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). + Litinski et al. 2023. + This CModAdd circuit uses 2 fewer additions than the implementation referenced in the paper + above. Because of this we choose to use this CModAdd bloq instead. """ k: Union[int, sympy.Expr] @@ -279,12 +294,53 @@ class CModAddK(Bloq): def signature(self) -> 'Signature': return Signature([Register('ctrl', QBit()), Register('x', QUInt(self.bitsize))]) + def build_composite_bloq( + self, bb: 'BloqBuilder', ctrl: 'Soquet', x: 'Soquet' + ) -> Dict[str, 'SoquetT']: + k = bb.add(IntState(bitsize=self.bitsize, val=self.k)) + ctrl, k, x = bb.add(CModAdd(QUInt(self.bitsize), mod=self.mod), ctrl=ctrl, x=k, y=x) + bb.add(IntEffect(bitsize=self.bitsize, val=self.k), val=k) + return {'ctrl': ctrl, 'x': x} + + def on_classical_vals( + self, ctrl: 'ClassicalValT', x: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 0: + return {'ctrl': 0, 'x': x} + + assert ctrl == 1, 'Bad ctrl value.' + x = (x + self.k) % self.mod + return {'ctrl': ctrl, 'x': x} + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - k = ssa.new_symbol('k') - return {AddK(k=k, bitsize=self.bitsize).controlled(): 5} + return {CModAdd(QUInt(self.bitsize), mod=self.mod): 1} - def short_name(self) -> str: - return f'x += {self.k} % {self.mod}' + def wire_symbol( + self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() + ) -> 'WireSymbol': + if reg is None: + return Text(f"mod {self.mod}") + if reg.name == 'ctrl': + return Circle() + if reg.name == 'x': + return TextBox(f'x += {self.k}') + raise ValueError(f"Unknown register {reg}") + + +@bloq_example(generalizer=ignore_split_join) +def _cmod_add_k() -> CModAddK: + n, m, k = sympy.symbols('n m k') + cmod_add_k = CModAddK(bitsize=n, mod=m, k=k) + return cmod_add_k + + +@bloq_example +def _cmod_add_k_small() -> CModAddK: + cmod_add_k_small = CModAddK(bitsize=4, mod=7, k=1) + return cmod_add_k_small + + +_C_MOD_ADD_K_DOC = BloqDocSpec(bloq_cls=CModAddK, examples=[_cmod_add_k, _cmod_add_k_small]) @frozen @@ -300,6 +356,12 @@ class CtrlScaleModAdd(Bloq): ctrl: The control bit x: The 'source' quantum register containing the integer to be scaled and added to `y`. y: The 'destination' quantum register to which the addition will apply. + + References: + [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). + Construction based on description in section 2.2 paragraph 4. We add n And/And† bloqs + because the bloq is controlled, but the construction also involves modular addition + controlled on the qubits comprising register x. """ k: Union[int, sympy.Expr] @@ -316,9 +378,38 @@ def signature(self) -> 'Signature': ] ) + def build_composite_bloq( + self, bb: 'BloqBuilder', ctrl: 'Soquet', x: 'Soquet', y: 'Soquet' + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.") + x_split = bb.split(x) + for i in range(int(self.bitsize)): + and_ctrl = [ctrl, x_split[i]] + and_ctrl, ancilla = bb.add(And(), ctrl=and_ctrl) + ancilla, y = bb.add( + CModAddK( + k=((self.k * 2 ** (self.bitsize - 1 - i)) % self.mod), + bitsize=self.bitsize, + mod=self.mod, + ), + ctrl=ancilla, + x=y, + ) + and_ctrl = bb.add(And().adjoint(), ctrl=and_ctrl, target=ancilla) + ctrl = and_ctrl[0] + x_split[i] = and_ctrl[1] + x = bb.join(x_split, dtype=QUInt(self.bitsize)) + + return {'ctrl': ctrl, 'x': x, 'y': y} + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': k = ssa.new_symbol('k') - return {CModAddK(k=k, bitsize=self.bitsize, mod=self.mod): self.bitsize} + return { + CModAddK(k=k, bitsize=self.bitsize, mod=self.mod): self.bitsize, + And(): self.bitsize, + And().adjoint(): self.bitsize, + } def on_classical_vals( self, ctrl: 'ClassicalValT', x: 'ClassicalValT', y: 'ClassicalValT' @@ -344,6 +435,24 @@ def wire_symbol( raise ValueError(f"Unknown register {reg}") +@bloq_example(generalizer=ignore_split_join) +def _ctrl_scale_mod_add() -> CtrlScaleModAdd: + n, m, k = sympy.symbols('n m k') + ctrl_scale_mod_add = CtrlScaleModAdd(bitsize=n, mod=m, k=k) + return ctrl_scale_mod_add + + +@bloq_example +def _ctrl_scale_mod_add_small() -> CtrlScaleModAdd: + ctrl_scale_mod_add_small = CtrlScaleModAdd(bitsize=4, mod=7, k=1) + return ctrl_scale_mod_add_small + + +_CTRL_SCALE_MOD_ADD_DOC = BloqDocSpec( + bloq_cls=CtrlScaleModAdd, examples=[_ctrl_scale_mod_add, _ctrl_scale_mod_add_small] +) + + @frozen class CModAdd(Bloq): r"""Controlled Modular Addition. @@ -390,6 +499,8 @@ def on_classical_vals( def build_composite_bloq( self, bb: 'BloqBuilder', ctrl, x: Soquet, y: Soquet ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.dtype.bitsize): + raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}') y_arr = bb.split(y) ancilla = bb.allocate(1) x = bb.add(Cast(self.dtype, QUInt(self.dtype.bitsize)), reg=x) diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py index 5b37b355e..455670bb5 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py @@ -18,10 +18,21 @@ import pytest import sympy +import qualtran.testing as qlt_testing from qualtran import QMontgomeryUInt, QUInt from qualtran.bloqs.arithmetic import Add from qualtran.bloqs.mod_arithmetic import CModAdd, CModAddK, CtrlScaleModAdd, ModAdd, ModAddK -from qualtran.bloqs.mod_arithmetic.mod_addition import _cmodadd_example +from qualtran.bloqs.mod_arithmetic.mod_addition import ( + _cmod_add_k, + _cmod_add_k_small, + _cmodadd_example, + _ctrl_scale_mod_add, + _ctrl_scale_mod_add_small, + _mod_add, + _mod_add_k, + _mod_add_k_large, + _mod_add_k_small, +) from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join @@ -33,6 +44,29 @@ ) +@pytest.mark.parametrize( + "bloq", + [ + _mod_add, + _mod_add_k, + _mod_add_k_small, + _mod_add_k_large, + _cmod_add_k, + _cmod_add_k_small, + _ctrl_scale_mod_add, + _ctrl_scale_mod_add_small, + _cmodadd_example, + ], +) +def test_examples(bloq_autotester, bloq): + bloq_autotester(bloq) + + +@pytest.mark.notebook +def test_notebook(): + execute_notebook('mod_addition') + + def identity_map(n: int): """Returns a dict of size `2**n` mapping each integer in range [0, 2**n) to itself.""" return {i: i for i in range(2**n)} @@ -61,22 +95,6 @@ def test_add_mod_n_gate_counts(bitsize): assert bloq.t_complexity() == add_constant_mod_n_ref_t_complexity_(bloq) -def test_ctrl_scale_mod_add(): - bloq = CtrlScaleModAdd(k=123, mod=13 * 17, bitsize=8) - - counts = bloq.bloq_counts() - ((bloq, n),) = counts.items() - assert n == 8 - - -def test_ctrl_mod_add_k(): - bloq = CModAddK(k=123, mod=13 * 17, bitsize=8) - - counts = bloq.bloq_counts() - ((bloq, n),) = counts.items() - assert n == 5 - - @pytest.mark.parametrize('bitsize,p', [(1, 1), (2, 3), (5, 8)]) def test_mod_add_valid_decomp(bitsize, p): bloq = ModAdd(bitsize=bitsize, mod=p) @@ -131,6 +149,26 @@ def test_classical_action_cmodadd_fast(control, bitsize): assert b.call_classically(ctrl=c, x=x, y=y) == cb.call_classically(ctrl=c, x=x, y=y) +@pytest.mark.slow +@pytest.mark.parametrize( + ['prime', 'bitsize', 'k'], + [(p, n, k) for p in (13, 17, 23) for n in range(p.bit_length(), 8) for k in range(1, p)], +) +def test_cscalemodadd_classical_action(bitsize, prime, k): + b = CtrlScaleModAdd(bitsize=bitsize, mod=prime, k=k) + qlt_testing.assert_consistent_classical_action(b, ctrl=(0, 1), x=range(prime), y=range(prime)) + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['prime', 'bitsize', 'k'], + [(p, n, k) for p in (13, 17, 23) for n in range(p.bit_length(), 8) for k in range(1, p)], +) +def test_cmodaddk_classical_action(bitsize, prime, k): + b = CModAddK(bitsize=bitsize, mod=prime, k=k) + qlt_testing.assert_consistent_classical_action(b, ctrl=(0, 1), x=range(prime)) + + @pytest.mark.parametrize('control', range(2)) @pytest.mark.parametrize( ['prime', 'bitsize'], @@ -154,15 +192,6 @@ def test_cmodadd_cost(control, dtype): assert cost.total_t_count() == 4 * n_toffolis -def test_cmodadd_example(bloq_autotester): - bloq_autotester(_cmodadd_example) - - -@pytest.mark.notebook -def test_notebook(): - execute_notebook('mod_addition') - - def test_cmod_add_complexity_vs_ref(): n, k = sympy.symbols('n k', integer=True, positive=True) bloq = CModAdd(QUInt(n), mod=k) diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index dad2907ea..f53ea457b 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -328,8 +328,10 @@ "qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint": qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint, "qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM": qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM, "qualtran.bloqs.mod_arithmetic.CModAddK": qualtran.bloqs.mod_arithmetic.CModAddK, - "qualtran.bloqs.mod_arithmetic.mod_addition.CModAdd": qualtran.bloqs.mod_arithmetic.CModAdd, + "qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd": qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd, + "qualtran.bloqs.mod_arithmetic.mod_addition.CModAdd": qualtran.bloqs.mod_arithmetic.mod_addition.CModAdd, "qualtran.bloqs.mod_arithmetic.mod_addition.ModAddK": qualtran.bloqs.mod_arithmetic.mod_addition.ModAddK, + "qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK": qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK, "qualtran.bloqs.mod_arithmetic.mod_addition.CtrlScaleModAdd": qualtran.bloqs.mod_arithmetic.CtrlScaleModAdd, "qualtran.bloqs.mod_arithmetic.ModAdd": qualtran.bloqs.mod_arithmetic.ModAdd, "qualtran.bloqs.mod_arithmetic.ModSub": qualtran.bloqs.mod_arithmetic.ModSub, From 3117a84efb57a1222f0a865fd29f3043c6aa704b Mon Sep 17 00:00:00 2001 From: Anurudh Peduri <7265746+anurudhp@users.noreply.github.com> Date: Thu, 17 Oct 2024 16:45:20 -0700 Subject: [PATCH 2/7] Default `Bloq.get_ctrl_system`: Use `And` ladder to reduce multiple controls to single control (#1456) * default ctrl system: use And ladder to reduce controls to a single control bit * fix tests * test: `XGate().controlled(...)` --------- Co-authored-by: Matthew Harrigan --- qualtran/_infra/bloq.py | 7 +++- qualtran/_infra/gate_with_registers_test.py | 4 ++- qualtran/bloqs/basic_gates/identity_test.py | 40 +-------------------- qualtran/bloqs/basic_gates/x_basis_test.py | 14 ++++++++ qualtran/bloqs/qsp/generalized_qsp_test.py | 2 +- 5 files changed, 25 insertions(+), 42 deletions(-) diff --git a/qualtran/_infra/bloq.py b/qualtran/_infra/bloq.py index 7ccd20037..ec5e2d735 100644 --- a/qualtran/_infra/bloq.py +++ b/qualtran/_infra/bloq.py @@ -394,7 +394,12 @@ def _my_add_controlled( add_controlled: A function with the signature documented above that the system can use to automatically wire up the new control registers. """ - from qualtran import Controlled + from qualtran import Controlled, CtrlSpec + from qualtran.bloqs.mcmt.controlled_via_and import ControlledViaAnd + + if ctrl_spec != CtrlSpec(): + # reduce controls to a single qubit + return ControlledViaAnd.make_ctrl_system(self, ctrl_spec=ctrl_spec) return Controlled.make_ctrl_system(self, ctrl_spec=ctrl_spec) diff --git a/qualtran/_infra/gate_with_registers_test.py b/qualtran/_infra/gate_with_registers_test.py index a3735a889..31fd6f863 100644 --- a/qualtran/_infra/gate_with_registers_test.py +++ b/qualtran/_infra/gate_with_registers_test.py @@ -151,8 +151,10 @@ def test_gate_with_registers_decompose_from_context_auto_generated(): def test_non_unitary_controlled(): + from qualtran.bloqs.mcmt.controlled_via_and import ControlledViaAnd + bloq = BloqWithDecompose() - assert bloq.controlled(control_values=[0]) == Controlled(bloq, CtrlSpec(cvs=0)) + assert bloq.controlled(control_values=[0]) == ControlledViaAnd(bloq, CtrlSpec(cvs=0)) @pytest.mark.notebook diff --git a/qualtran/bloqs/basic_gates/identity_test.py b/qualtran/bloqs/basic_gates/identity_test.py index ea3a4255e..3a45b8980 100644 --- a/qualtran/bloqs/basic_gates/identity_test.py +++ b/qualtran/bloqs/basic_gates/identity_test.py @@ -15,16 +15,14 @@ import numpy as np import pytest import sympy -from attrs import frozen -from qualtran import Bloq, BloqBuilder, CtrlSpec, QInt, QUInt, Signature, Soquet, SoquetT +from qualtran import BloqBuilder from qualtran.bloqs.basic_gates import OneState from qualtran.bloqs.basic_gates.identity import _identity, _identity_n, _identity_symb, Identity from qualtran.simulation.classical_sim import ( format_classical_truth_table, get_classical_truth_table, ) -from qualtran.symbolics import SymbolicInt from qualtran.testing import execute_notebook @@ -95,42 +93,6 @@ def test_identity_controlled(): assert Identity(n).controlled() == Identity(n + 1) -@frozen -class TestIdentityDecomposition(Bloq): - """helper to test Identity.get_ctrl_system""" - - bitsize: SymbolicInt - - @property - def signature(self) -> 'Signature': - return Signature.build(q=self.bitsize) - - def build_composite_bloq(self, bb: 'BloqBuilder', q: Soquet) -> dict[str, 'SoquetT']: - q = bb.add(Identity(self.bitsize), q=q) - q = bb.add(Identity(self.bitsize), q=q) - return {'q': q} - - -@pytest.mark.parametrize("n", [4, sympy.Symbol("n")]) -@pytest.mark.parametrize( - "ctrl_spec", - [ - CtrlSpec(cvs=(np.array([1, 0, 1]),)), - CtrlSpec(qdtypes=(QUInt(3), QInt(3)), cvs=(np.array(0b010), np.array(0b001))), - ], -) -def test_identity_get_ctrl_system(n: SymbolicInt, ctrl_spec: CtrlSpec): - m = ctrl_spec.num_qubits - - bloq = TestIdentityDecomposition(n) - ctrl_bloq = bloq.controlled(ctrl_spec) - - _ = ctrl_bloq.decompose_bloq() - - _, sigma = ctrl_bloq.call_graph() - assert sigma == {Identity(n + m): 2} - - @pytest.mark.notebook def test_notebook(): execute_notebook('identity') diff --git a/qualtran/bloqs/basic_gates/x_basis_test.py b/qualtran/bloqs/basic_gates/x_basis_test.py index 58851757e..d6b3b670d 100644 --- a/qualtran/bloqs/basic_gates/x_basis_test.py +++ b/qualtran/bloqs/basic_gates/x_basis_test.py @@ -92,3 +92,17 @@ def test_x_truth_table(): 0 -> 1 1 -> 0""" ) + + +def test_controlled_x(): + from qualtran import CtrlSpec, QUInt + from qualtran.bloqs.basic_gates import CNOT + from qualtran.bloqs.mcmt import And + + def _keep_and(b): + return isinstance(b, And) + + n = 8 + bloq = XGate().controlled(CtrlSpec(qdtypes=QUInt(n), cvs=1)) + _, sigma = bloq.call_graph(keep=_keep_and) + assert sigma == {And(): n - 1, CNOT(): 1, And().adjoint(): n - 1, XGate(): 4 * (n - 1)} diff --git a/qualtran/bloqs/qsp/generalized_qsp_test.py b/qualtran/bloqs/qsp/generalized_qsp_test.py index c64899b19..c925da2a8 100644 --- a/qualtran/bloqs/qsp/generalized_qsp_test.py +++ b/qualtran/bloqs/qsp/generalized_qsp_test.py @@ -201,7 +201,7 @@ def catch_rotations(bloq: Bloq) -> Bloq: expected_sigma: dict[Bloq, int] = {arbitrary_rotation: degree + 1} if degree > negative_power: - expected_sigma[Controlled(U, CtrlSpec(cvs=0))] = degree - negative_power + expected_sigma[U.controlled(ctrl_spec=CtrlSpec(cvs=0))] = degree - negative_power if negative_power > 0: expected_sigma[Controlled(U.adjoint(), CtrlSpec())] = min(degree, negative_power) if negative_power > degree: From c227032a41292ad9f8f95619d063fbbcaff9ad47 Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Fri, 18 Oct 2024 13:05:44 -0700 Subject: [PATCH 3/7] Add RSA Phase Estimate Bloq and Move ModExp to rsa/ subdirectory (#1428) * Add rsa files - needs a lot of work just stashing it for now * Made some structure changes RSA * Rework rsa mod exp bloqs to work in a rsa phase estimation circuit * Fix mypy issues * Better symbolic messages * Refactor RSA to have a phase estimation circuit and a classical simulable modular exponentiation circuit * Fix notebook specs merge conflict * remove unecessary x values for classical simulation test * fix nits * Better documentation init * Fix broken link * Fix random issue and cirq interop import of modexp * Fix broken import msft interop * Fix another dependency of ModExp --------- Co-authored-by: Matthew Harrigan --- .../qualtran_dev_tools/notebook_specs.py | 12 +- docs/bloqs/index.rst | 2 +- qualtran/_infra/Bloqs-Tutorial.ipynb | 2 +- qualtran/bloqs/factoring/__init__.py | 4 +- .../_ecc_shims.py => _factoring_shims.py} | 3 +- .../factoring/ecc/ec_phase_estimate_r.py | 2 +- qualtran/bloqs/factoring/mod_add_test.py | 0 qualtran/bloqs/factoring/mod_exp.html | 23 -- qualtran/bloqs/factoring/mod_exp.ipynb | 209 ----------- qualtran/bloqs/factoring/rsa/__init__.py | 37 ++ .../{ => rsa}/factoring-via-modexp.ipynb | 5 +- qualtran/bloqs/factoring/rsa/rsa.ipynb | 331 ++++++++++++++++++ .../{mod_exp.py => rsa/rsa_mod_exp.py} | 50 ++- .../rsa_mod_exp_test.py} | 28 +- .../bloqs/factoring/rsa/rsa_phase_estimate.py | 150 ++++++++ .../factoring/rsa/rsa_phase_estimate_test.py | 27 ++ qualtran/cirq_interop/_bloq_to_cirq_test.py | 2 +- qualtran/cirq_interop/cirq_interop.ipynb | 2 +- qualtran/serialization/bloq_test.py | 2 +- qualtran/serialization/resolver_dict.py | 7 +- .../msft_resource_estimator_interop.ipynb | 2 +- 21 files changed, 611 insertions(+), 289 deletions(-) rename qualtran/bloqs/factoring/{ecc/_ecc_shims.py => _factoring_shims.py} (95%) delete mode 100644 qualtran/bloqs/factoring/mod_add_test.py delete mode 100644 qualtran/bloqs/factoring/mod_exp.html delete mode 100644 qualtran/bloqs/factoring/mod_exp.ipynb create mode 100644 qualtran/bloqs/factoring/rsa/__init__.py rename qualtran/bloqs/factoring/{ => rsa}/factoring-via-modexp.ipynb (98%) create mode 100644 qualtran/bloqs/factoring/rsa/rsa.ipynb rename qualtran/bloqs/factoring/{mod_exp.py => rsa/rsa_mod_exp.py} (78%) rename qualtran/bloqs/factoring/{mod_exp_test.py => rsa/rsa_mod_exp_test.py} (83%) create mode 100644 qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py create mode 100644 qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index 1b5f219b4..7c66311fe 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -83,7 +83,7 @@ import qualtran.bloqs.data_loading.qrom_base import qualtran.bloqs.data_loading.select_swap_qrom import qualtran.bloqs.factoring.ecc -import qualtran.bloqs.factoring.mod_exp +import qualtran.bloqs.factoring.rsa import qualtran.bloqs.gf_arithmetic.gf2_add_k import qualtran.bloqs.gf_arithmetic.gf2_addition import qualtran.bloqs.gf_arithmetic.gf2_inverse @@ -517,10 +517,12 @@ ], ), NotebookSpecV2( - title='Modular Exponentiation', - module=qualtran.bloqs.factoring.mod_exp, - bloq_specs=[qualtran.bloqs.factoring.mod_exp._MODEXP_DOC], - directory=f'{SOURCE_DIR}/bloqs/factoring', + title='Factoring RSA', + module=qualtran.bloqs.factoring.rsa, + bloq_specs=[ + qualtran.bloqs.factoring.rsa.rsa_phase_estimate._RSA_PE_BLOQ_DOC, + qualtran.bloqs.factoring.rsa.rsa_mod_exp._RSA_MODEXP_DOC, + ], ), NotebookSpecV2( title='Elliptic Curve Addition', diff --git a/docs/bloqs/index.rst b/docs/bloqs/index.rst index 0906dd71b..16c591baa 100644 --- a/docs/bloqs/index.rst +++ b/docs/bloqs/index.rst @@ -83,7 +83,7 @@ Bloqs Library mod_arithmetic/mod_addition.ipynb mod_arithmetic/mod_subtraction.ipynb mod_arithmetic/mod_multiplication.ipynb - factoring/mod_exp.ipynb + factoring/rsa/rsa.ipynb factoring/ecc/ec_add.ipynb factoring/ecc/ecc.ipynb diff --git a/qualtran/_infra/Bloqs-Tutorial.ipynb b/qualtran/_infra/Bloqs-Tutorial.ipynb index 219dc48f4..ca64c4755 100644 --- a/qualtran/_infra/Bloqs-Tutorial.ipynb +++ b/qualtran/_infra/Bloqs-Tutorial.ipynb @@ -914,7 +914,7 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.bloqs.factoring import ModExp\n", + "from qualtran.bloqs.factoring.rsa import ModExp\n", "\n", "mod_exp = ModExp(base=8, mod=13*17, exp_bitsize=3, x_bitsize=1024)\n", "show_bloq(mod_exp)" diff --git a/qualtran/bloqs/factoring/__init__.py b/qualtran/bloqs/factoring/__init__.py index 59a92dad8..15780de77 100644 --- a/qualtran/bloqs/factoring/__init__.py +++ b/qualtran/bloqs/factoring/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,5 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .mod_exp import ModExp diff --git a/qualtran/bloqs/factoring/ecc/_ecc_shims.py b/qualtran/bloqs/factoring/_factoring_shims.py similarity index 95% rename from qualtran/bloqs/factoring/ecc/_ecc_shims.py rename to qualtran/bloqs/factoring/_factoring_shims.py index 4b602e73f..896e103b8 100644 --- a/qualtran/bloqs/factoring/ecc/_ecc_shims.py +++ b/qualtran/bloqs/factoring/_factoring_shims.py @@ -19,11 +19,12 @@ from qualtran import Bloq, CompositeBloq, DecomposeTypeError, QBit, Register, Side, Signature from qualtran.drawing import RarrowTextBox, Text, WireSymbol +from qualtran.symbolics import SymbolicInt @frozen class MeasureQFT(Bloq): - n: int + n: 'SymbolicInt' @cached_property def signature(self) -> 'Signature': diff --git a/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py b/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py index f4ddb9d19..ffe03f2f9 100644 --- a/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py +++ b/qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py @@ -33,7 +33,7 @@ from qualtran.bloqs.basic_gates import PlusState from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator -from ._ecc_shims import MeasureQFT +from .._factoring_shims import MeasureQFT from .ec_add_r import ECAddR from .ec_point import ECPoint diff --git a/qualtran/bloqs/factoring/mod_add_test.py b/qualtran/bloqs/factoring/mod_add_test.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/qualtran/bloqs/factoring/mod_exp.html b/qualtran/bloqs/factoring/mod_exp.html deleted file mode 100644 index 9d015ba4c..000000000 --- a/qualtran/bloqs/factoring/mod_exp.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - shor - - - - - - -
- -
- - - - \ No newline at end of file diff --git a/qualtran/bloqs/factoring/mod_exp.ipynb b/qualtran/bloqs/factoring/mod_exp.ipynb deleted file mode 100644 index 77c87aa15..000000000 --- a/qualtran/bloqs/factoring/mod_exp.ipynb +++ /dev/null @@ -1,209 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "5947b041", - "metadata": { - "cq.autogen": "title_cell" - }, - "source": [ - "# Modular Exponentiation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ce6b0d51", - "metadata": { - "cq.autogen": "top_imports" - }, - "outputs": [], - "source": [ - "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", - "from qualtran import QBit, QInt, QUInt, QAny\n", - "from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n", - "from typing import *\n", - "import numpy as np\n", - "import sympy\n", - "import cirq" - ] - }, - { - "cell_type": "markdown", - "id": "2f68374c", - "metadata": { - "cq.autogen": "ModExp.bloq_doc.md" - }, - "source": [ - "## `ModExp`\n", - "Perform $b^e \\mod{m}$ for constant `base` $b$, `mod` $m$, and quantum `exponent` $e$.\n", - "\n", - "Modular exponentiation is the main computational primitive for quantum factoring algorithms.\n", - "We follow [GE2019]'s \"reference implementation\" for factoring. See `ModExp.make_for_shor`\n", - "to set the class attributes for a factoring run.\n", - "\n", - "This bloq decomposes into controlled modular exponentiation for each exponent bit.\n", - "\n", - "#### Parameters\n", - " - `base`: The integer base of the exponentiation\n", - " - `mod`: The integer modulus\n", - " - `exp_bitsize`: The size of the `exponent` thru-register\n", - " - `x_bitsize`: The size of the `x` right-register \n", - "\n", - "#### Registers\n", - " - `exponent`: The exponent\n", - " - `x [right]`: The output register containing the result of the exponentiation \n", - "\n", - "#### References\n", - " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Gidney and Ekerå. 2019.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4a1b8a2a", - "metadata": { - "cq.autogen": "ModExp.bloq_doc.py" - }, - "outputs": [], - "source": [ - "from qualtran.bloqs.factoring import ModExp" - ] - }, - { - "cell_type": "markdown", - "id": "902ec939", - "metadata": { - "cq.autogen": "ModExp.example_instances.md" - }, - "source": [ - "### Example Instances" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c5aac47", - "metadata": { - "cq.autogen": "ModExp.modexp_small" - }, - "outputs": [], - "source": [ - "modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41d94c81", - "metadata": { - "cq.autogen": "ModExp.modexp" - }, - "outputs": [], - "source": [ - "modexp = ModExp.make_for_shor(big_n=13 * 17, g=9)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c0c6c06", - "metadata": { - "cq.autogen": "ModExp.modexp_symb" - }, - "outputs": [], - "source": [ - "g, N, n_e, n_x = sympy.symbols('g N n_e, n_x')\n", - "modexp_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x)" - ] - }, - { - "cell_type": "markdown", - "id": "a55a51df", - "metadata": { - "cq.autogen": "ModExp.graphical_signature.md" - }, - "source": [ - "#### Graphical Signature" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ce953a6d", - "metadata": { - "cq.autogen": "ModExp.graphical_signature.py" - }, - "outputs": [], - "source": [ - "from qualtran.drawing import show_bloqs\n", - "show_bloqs([modexp_symb, modexp_small, modexp],\n", - " ['`modexp_symb`', '`modexp_small`', '`modexp`'])" - ] - }, - { - "cell_type": "markdown", - "id": "83ba4b54", - "metadata": {}, - "source": [ - "### Decomposition" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a0e53334", - "metadata": {}, - "outputs": [], - "source": [ - "show_bloq(modexp_small.decompose_bloq())" - ] - }, - { - "cell_type": "markdown", - "id": "8662fa01", - "metadata": { - "cq.autogen": "ModExp.call_graph.md" - }, - "source": [ - "### Call Graph" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a7d9f5a", - "metadata": { - "cq.autogen": "ModExp.call_graph.py" - }, - "outputs": [], - "source": [ - "from qualtran.resource_counting.generalizers import ignore_split_join\n", - "modexp_symb_g, modexp_symb_sigma = modexp_symb.call_graph(max_depth=1, generalizer=ignore_split_join)\n", - "show_call_graph(modexp_symb_g)\n", - "show_counts_sigma(modexp_symb_sigma)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/qualtran/bloqs/factoring/rsa/__init__.py b/qualtran/bloqs/factoring/rsa/__init__.py new file mode 100644 index 000000000..d58509847 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# isort:skip_file + +r"""Bloqs for breaking RSA cryptography systems via integer factorization. + +RSA cryptography is a form of public key cryptography based on the difficulty of +factoring the product of two large prime numbers. + +Using RSA, the cryptographic scheme chooses two large prime numbers p, q, their product n, +λ(n) = lcm(p - 1, q - 1) where λ is Carmichael's totient function, an integer e such that +1 < e < λ(n), and finally d as d ≡ e^-1 (mod λ(n)). The public key consists of the modulus n and +the public (or encryption) exponent e. The private key consists of the private (or decryption) +exponent d, which must be kept secret. p, q, and λ(n) must also be kept secret because they can be +used to calculate d. + +Using Shor's algorithm for factoring, we can find p and q (the factors of n) in polynomial time +with a quantum algorithm. + +References: + [RSA (cryptosystem)](https://en.wikipedia.org/wiki/RSA_(cryptosystem)). +""" + +from .rsa_phase_estimate import RSAPhaseEstimate +from .rsa_mod_exp import ModExp diff --git a/qualtran/bloqs/factoring/factoring-via-modexp.ipynb b/qualtran/bloqs/factoring/rsa/factoring-via-modexp.ipynb similarity index 98% rename from qualtran/bloqs/factoring/factoring-via-modexp.ipynb rename to qualtran/bloqs/factoring/rsa/factoring-via-modexp.ipynb index dfea45b9e..08ff940b8 100644 --- a/qualtran/bloqs/factoring/factoring-via-modexp.ipynb +++ b/qualtran/bloqs/factoring/rsa/factoring-via-modexp.ipynb @@ -180,7 +180,7 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.bloqs.factoring.mod_exp import ModExp\n", + "from qualtran.bloqs.factoring.rsa.rsa_mod_exp import ModExp\n", "from qualtran.drawing import show_bloq\n", "\n", "mod_exp = ModExp(base=g, mod=N, exp_bitsize=32, x_bitsize=32)\n", @@ -205,6 +205,7 @@ "metadata": {}, "outputs": [], "source": [ + "from qualtran import QUInt\n", "for e in range(20):\n", " ref = (g ** e) % N\n", " _, bloq_eval = mod_exp.call_classically(exponent=e)\n", @@ -231,7 +232,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/qualtran/bloqs/factoring/rsa/rsa.ipynb b/qualtran/bloqs/factoring/rsa/rsa.ipynb new file mode 100644 index 000000000..5415a4d53 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/rsa.ipynb @@ -0,0 +1,331 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "48ce60bb", + "metadata": { + "cq.autogen": "title_cell" + }, + "source": [ + "# Factoring RSA\n", + "\n", + "Bloqs for breaking RSA cryptography systems via integer factorization.\n", + "\n", + "RSA cryptography is a form of public key cryptography based on the difficulty of\n", + "factoring the product of two large prime numbers.\n", + "\n", + "Using RSA, the cryptographic scheme chooses two large prime numbers p, q, their product n,\n", + "λ(n) = lcm(p - 1, q - 1) where λ is Carmichael's totient function, an integer e such that\n", + "1 < e < λ(n), and finally d as d ≡ e^-1 (mod λ(n)). The public key consists of the modulus n and\n", + "the public (or encryption) exponent e. The private key consists of the private (or decryption)\n", + "exponent d, which must be kept secret. p, q, and λ(n) must also be kept secret because they can be\n", + "used to calculate d.\n", + "\n", + "Using Shor's algorithm for factoring, we can find p and q (the factors of n) in polynomial time\n", + "with a quantum algorithm.\n", + "\n", + "References:\n", + " [RSA (cryptosystem)](https://en.wikipedia.org/wiki/RSA_(cryptosystem))." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d12766dd", + "metadata": { + "cq.autogen": "top_imports" + }, + "outputs": [], + "source": [ + "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", + "from qualtran import QBit, QInt, QUInt, QAny\n", + "from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n", + "from typing import *\n", + "import numpy as np\n", + "import sympy\n", + "import cirq" + ] + }, + { + "cell_type": "markdown", + "id": "881834c3", + "metadata": { + "cq.autogen": "ModExp.bloq_doc.md" + }, + "source": [ + "## `ModExp`\n", + "Perform $b^e \\mod{m}$ for constant `base` $b$, `mod` $m$, and quantum `exponent` $e$.\n", + "\n", + "Modular exponentiation is the main computational primitive for quantum factoring algorithms.\n", + "We follow [GE2019]'s \"reference implementation\" for factoring. See `ModExp.make_for_shor`\n", + "to set the class attributes for a factoring run.\n", + "\n", + "This bloq decomposes into controlled modular exponentiation for each exponent bit.\n", + "\n", + "#### Parameters\n", + " - `base`: The integer base of the exponentiation\n", + " - `mod`: The integer modulus\n", + " - `exp_bitsize`: The size of the `exponent` thru-register\n", + " - `x_bitsize`: The size of the `x` right-register \n", + "\n", + "#### Registers\n", + " - `exponent`: The exponent\n", + " - `x [right]`: The output register containing the result of the exponentiation \n", + "\n", + "#### References\n", + " - [How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749). Gidney and Ekerå. 2019.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61062a3d", + "metadata": { + "cq.autogen": "ModExp.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.factoring.rsa import ModExp" + ] + }, + { + "cell_type": "markdown", + "id": "6963ad94", + "metadata": { + "cq.autogen": "ModExp.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56d35af9", + "metadata": { + "cq.autogen": "ModExp.modexp_symb" + }, + "outputs": [], + "source": [ + "g, N, n_e, n_x = sympy.symbols('g N n_e, n_x')\n", + "modexp_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d3d90b4", + "metadata": { + "cq.autogen": "ModExp.modexp_small" + }, + "outputs": [], + "source": [ + "modexp_small = ModExp(base=4, mod=15, exp_bitsize=3, x_bitsize=2048)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3a004d3", + "metadata": { + "cq.autogen": "ModExp.modexp" + }, + "outputs": [], + "source": [ + "modexp = ModExp.make_for_shor(big_n=13 * 17, g=9)" + ] + }, + { + "cell_type": "markdown", + "id": "39422b45", + "metadata": { + "cq.autogen": "ModExp.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83a891e2", + "metadata": { + "cq.autogen": "ModExp.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([modexp_small, modexp, modexp_symb],\n", + " ['`modexp_small`', '`modexp`', '`modexp_symb`'])" + ] + }, + { + "cell_type": "markdown", + "id": "9c271392", + "metadata": { + "cq.autogen": "ModExp.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6d55278", + "metadata": { + "cq.autogen": "ModExp.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "modexp_small_g, modexp_small_sigma = modexp_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(modexp_small_g)\n", + "show_counts_sigma(modexp_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "2603abbd", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.bloq_doc.md" + }, + "source": [ + "## `RSAPhaseEstimate`\n", + "Perform a single phase estimation of the decomposition of Modular Exponentiation for the\n", + "given base.\n", + "\n", + "The constructor requires a pre-set base, see the make_for_shor factory method for picking a\n", + "random, valid base\n", + "\n", + "#### Parameters\n", + " - `n`: The bitsize of the modulus N.\n", + " - `mod`: The modulus N; a part of the public key for RSA.\n", + " - `base`: A base for modular exponentiation. \n", + "\n", + "#### References\n", + " - [Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095). Beauregard. 2003. Fig 1.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b838c20", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.factoring.rsa import RSAPhaseEstimate" + ] + }, + { + "cell_type": "markdown", + "id": "20426b03", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f696c5fd", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.rsa_pe" + }, + "outputs": [], + "source": [ + "n, p, g = sympy.symbols('n p g')\n", + "rsa_pe = RSAPhaseEstimate(n=n, mod=p, base=g)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b16e84a5", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.rsa_pe_small" + }, + "outputs": [], + "source": [ + "rsa_pe_small = RSAPhaseEstimate.make_for_shor(big_n=13 * 17, g=9)" + ] + }, + { + "cell_type": "markdown", + "id": "0c0078d5", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b493a30", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([rsa_pe_small, rsa_pe],\n", + " ['`rsa_pe_small`', '`rsa_pe`'])" + ] + }, + { + "cell_type": "markdown", + "id": "441028fa", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f30cb55", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "rsa_pe_small_g, rsa_pe_small_sigma = rsa_pe_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(rsa_pe_small_g)\n", + "show_counts_sigma(rsa_pe_small_sigma)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d72bc625", + "metadata": { + "cq.autogen": "RSAPhaseEstimate.rsa_pe_shor" + }, + "outputs": [], + "source": [ + "rsa_pe_shor = RSAPhaseEstimate.make_for_shor(big_n=13 * 17, g=9)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qualtran/bloqs/factoring/mod_exp.py b/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py similarity index 78% rename from qualtran/bloqs/factoring/mod_exp.py rename to qualtran/bloqs/factoring/rsa/rsa_mod_exp.py index 129711797..fda33f5fc 100644 --- a/qualtran/bloqs/factoring/mod_exp.py +++ b/qualtran/bloqs/factoring/rsa/rsa_mod_exp.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import random from functools import cached_property from typing import cast, Dict, Optional, Tuple, Union import attrs +import numpy as np import sympy from attrs import frozen @@ -28,17 +28,19 @@ DecomposeTypeError, QUInt, Register, - Side, Signature, Soquet, SoquetT, ) -from qualtran.bloqs.basic_gates import IntState +from qualtran._infra.registers import Side +from qualtran.bloqs.basic_gates.z_basis import IntState from qualtran.bloqs.mod_arithmetic import CModMulK from qualtran.drawing import Text, WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.resource_counting.generalizers import ignore_split_join +from qualtran.simulation.classical_sim import ClassicalValT from qualtran.symbolics import is_symbolic +from qualtran.symbolics.types import SymbolicInt @frozen @@ -66,10 +68,10 @@ class ModExp(Bloq): Gidney and Ekerå. 2019. """ - base: Union[int, sympy.Expr] - mod: Union[int, sympy.Expr] - exp_bitsize: Union[int, sympy.Expr] - x_bitsize: Union[int, sympy.Expr] + base: 'SymbolicInt' + mod: 'SymbolicInt' + exp_bitsize: 'SymbolicInt' + x_bitsize: 'SymbolicInt' def __attrs_post_init__(self): if not is_symbolic(self.base, self.mod): @@ -85,33 +87,47 @@ def signature(self) -> 'Signature': ) @classmethod - def make_for_shor(cls, big_n: int, g: Optional[int] = None): + def make_for_shor( + cls, + big_n: 'SymbolicInt', + g: Optional['SymbolicInt'] = None, + rs: Optional[np.random.RandomState] = None, + ): """Factory method that sets up the modular exponentiation for a factoring run. Args: big_n: The large composite number N. Used to set `mod`. Its bitsize is used to set `x_bitsize` and `exp_bitsize`. g: Optional base of the exponentiation. If `None`, we pick a random base. + rs: Optional random state which can be seeded to make base generation deterministic. """ - if isinstance(big_n, sympy.Expr): + if is_symbolic(big_n): little_n = sympy.ceiling(sympy.log(big_n, 2)) else: little_n = int(math.ceil(math.log2(big_n))) if g is None: - g = random.randint(2, big_n) + if is_symbolic(big_n): + g = sympy.symbols('g') + else: + if rs is None: + rs = np.random.RandomState() + while True: + g = rs.randint(2, int(big_n)) + if math.gcd(g, int(big_n)) == 1: + break return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n) - def _CtrlModMul(self, k: Union[int, sympy.Expr]): + def _CtrlModMul(self, k: 'SymbolicInt'): """Helper method to return a `CModMulK` with attributes forwarded.""" return CModMulK(QUInt(self.x_bitsize), k=k, mod=self.mod) def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[str, 'SoquetT']: - if isinstance(self.exp_bitsize, sympy.Expr): - raise DecomposeTypeError("`exp_bitsize` must be a concrete value.") + if is_symbolic(self.exp_bitsize): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `exp_bitsize`.") + # https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method x = bb.add(IntState(val=1, bitsize=self.x_bitsize)) exponent = bb.split(exponent) - # https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method base = self.base % self.mod for j in range(self.exp_bitsize - 1, 0 - 1, -1): exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x) @@ -121,9 +137,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[st def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': k = ssa.new_symbol('k') - return {IntState(val=1, bitsize=self.x_bitsize): 1, self._CtrlModMul(k=k): self.exp_bitsize} + return {self._CtrlModMul(k=k): self.exp_bitsize, IntState(val=1, bitsize=self.x_bitsize): 1} - def on_classical_vals(self, exponent: int): + def on_classical_vals(self, exponent) -> Dict[str, Union['ClassicalValT', sympy.Expr]]: return {'exponent': exponent, 'x': (self.base**exponent) % self.mod} def wire_symbol( @@ -163,4 +179,4 @@ def _modexp_symb() -> ModExp: return modexp_symb -_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_symb, _modexp_small, _modexp)) +_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb)) diff --git a/qualtran/bloqs/factoring/mod_exp_test.py b/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py similarity index 83% rename from qualtran/bloqs/factoring/mod_exp_test.py rename to qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py index d0b48c4af..e4cabd8c4 100644 --- a/qualtran/bloqs/factoring/mod_exp_test.py +++ b/qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py @@ -21,7 +21,7 @@ from qualtran import Bloq from qualtran.bloqs.bookkeeping import Join, Split -from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_symb, ModExp +from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp from qualtran.bloqs.mod_arithmetic import CModMulK from qualtran.drawing import Text from qualtran.resource_counting import SympySymbolAllocator @@ -40,17 +40,13 @@ def test_mod_exp_consistent_classical(): # Choose an exponent in a range. Set exp_bitsize=ne bit enough to fit. exponent = rs.randint(1, 2**n) - ne = 2 * n - # Choose a base smaller than mod. - base = rs.randint(1, mod) - while np.gcd(base, mod) != 1: - base = rs.randint(1, mod) - - bloq = ModExp(base=base, exp_bitsize=ne, x_bitsize=n, mod=mod) + bloq = ModExp.make_for_shor(big_n=mod, rs=rs) ret1 = bloq.call_classically(exponent=exponent) ret2 = bloq.decompose_bloq().call_classically(exponent=exponent) - assert ret1 == ret2 + assert len(ret1) == len(ret2) + for i in range(len(ret1)): + np.testing.assert_array_equal(ret1[i], ret2[i]) def test_modexp_symb_manual(): @@ -93,19 +89,11 @@ def test_mod_exp_t_complexity(): assert tcomp.t > 0 -def test_modexp(bloq_autotester): - bloq_autotester(_modexp) - - -def test_modexp_symb(bloq_autotester): - bloq_autotester(_modexp_symb) +@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small]) +def test_modexp(bloq_autotester, bloq): + bloq_autotester(bloq) @pytest.mark.notebook def test_intro_notebook(): execute_notebook('factoring-via-modexp') - - -@pytest.mark.notebook -def test_notebook(): - execute_notebook('mod_exp') diff --git a/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py new file mode 100644 index 000000000..611af9de4 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py @@ -0,0 +1,150 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import cached_property +from typing import Dict, Optional + +import attrs +import numpy as np +import sympy +from attrs import frozen + +from qualtran import ( + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + DecomposeTypeError, + QUInt, + Signature, + SoquetT, +) +from qualtran.bloqs.basic_gates import IntState, PlusState +from qualtran.bloqs.bookkeeping import Free +from qualtran.bloqs.factoring._factoring_shims import MeasureQFT +from qualtran.bloqs.mod_arithmetic.mod_multiplication import CModMulK +from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.symbolics import is_symbolic, SymbolicInt + + +@frozen +class RSAPhaseEstimate(Bloq): + """Perform a single phase estimation of the decomposition of Modular Exponentiation for the + given base. + + The constructor requires a pre-set base, see the make_for_shor factory method for picking a + random, valid base + + Args: + n: The bitsize of the modulus N. + mod: The modulus N; a part of the public key for RSA. + base: A base for modular exponentiation. + + References: + [Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095). + Beauregard. 2003. Fig 1. + """ + + n: 'SymbolicInt' + mod: 'SymbolicInt' + base: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature([]) + + @classmethod + def make_for_shor( + cls, + big_n: 'SymbolicInt', + g: Optional['SymbolicInt'] = None, + rs: Optional[np.random.RandomState] = None, + ): + """Factory method that sets up the modular exponentiation for a factoring run. + + Args: + big_n: The large composite number N. Used to set `mod`. Its bitsize is used + to set `x_bitsize` and `exp_bitsize`. + g: Optional base of the exponentiation. If `None`, we pick a random base. + rs: Optional random state which can be seeded to make base generation deterministic. + """ + if is_symbolic(big_n): + little_n = sympy.ceiling(sympy.log(big_n, 2)) + else: + little_n = int(math.ceil(math.log2(big_n))) + if g is None: + if is_symbolic(big_n): + g = sympy.symbols('g') + else: + if rs is None: + rs = np.random.RandomState() + while True: + g = rs.randint(2, int(big_n)) + if math.gcd(g, int(big_n)) == 1: + break + return cls(base=g, mod=big_n, n=little_n) + + def __attrs_post_init__(self): + if not is_symbolic(self.n, self.mod): + assert self.n == int(math.ceil(math.log2(self.mod))) + + def _CtrlModMul(self, k: 'SymbolicInt'): + """Helper method to return a `CModMulK` with attributes forwarded.""" + return CModMulK(QUInt(self.n), k=k, mod=self.mod) + + def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + exponent = [bb.add(PlusState()) for _ in range(2 * self.n)] + x = bb.add(IntState(val=1, bitsize=self.n)) + + base = self.base % self.mod + for j in range((2 * self.n) - 1, 0 - 1, -1): + exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x) + base = (base * base) % self.mod + + bb.add(MeasureQFT(n=2 * self.n), x=exponent) + bb.add(Free(QUInt(self.n), dirty=True), reg=x) + return {} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + k = ssa.new_symbol('k') + return {MeasureQFT(n=self.n): 1, self._CtrlModMul(k=k): 2 * self.n} + + +_K = sympy.Symbol('k_exp') + + +def _generalize_k(b: Bloq) -> Optional[Bloq]: + if isinstance(b, CModMulK): + return attrs.evolve(b, k=_K) + + return b + + +@bloq_example +def _rsa_pe() -> RSAPhaseEstimate: + n, p, g = sympy.symbols('n p g') + rsa_pe = RSAPhaseEstimate(n=n, mod=p, base=g) + return rsa_pe + + +@bloq_example +def _rsa_pe_small() -> RSAPhaseEstimate: + rsa_pe_small = RSAPhaseEstimate.make_for_shor(big_n=13 * 17, g=9) + return rsa_pe_small + + +_RSA_PE_BLOQ_DOC = BloqDocSpec(bloq_cls=RSAPhaseEstimate, examples=[_rsa_pe_small, _rsa_pe]) diff --git a/qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py new file mode 100644 index 000000000..4ddedf964 --- /dev/null +++ b/qualtran/bloqs/factoring/rsa/rsa_phase_estimate_test.py @@ -0,0 +1,27 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import qualtran.testing as qlt_testing +from qualtran.bloqs.factoring.rsa.rsa_phase_estimate import _rsa_pe, _rsa_pe_small + + +@pytest.mark.parametrize('bloq', [_rsa_pe_small, _rsa_pe]) +def test_rsa_pe(bloq_autotester, bloq): + bloq_autotester(bloq) + + +def test_notebook(): + qlt_testing.execute_notebook('rsa') diff --git a/qualtran/cirq_interop/_bloq_to_cirq_test.py b/qualtran/cirq_interop/_bloq_to_cirq_test.py index 95dae1cef..87a76c014 100644 --- a/qualtran/cirq_interop/_bloq_to_cirq_test.py +++ b/qualtran/cirq_interop/_bloq_to_cirq_test.py @@ -21,7 +21,7 @@ from qualtran import Bloq, BloqBuilder, ConnectionT, Signature, Soquet, SoquetT from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.basic_gates import Toffoli, XGate, YGate -from qualtran.bloqs.factoring import ModExp +from qualtran.bloqs.factoring.rsa import ModExp from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd from qualtran.bloqs.state_preparation import PrepareUniformSuperposition from qualtran.cirq_interop._bloq_to_cirq import BloqAsCirqGate, CirqQuregT diff --git a/qualtran/cirq_interop/cirq_interop.ipynb b/qualtran/cirq_interop/cirq_interop.ipynb index 8a2854df8..991f95bd8 100644 --- a/qualtran/cirq_interop/cirq_interop.ipynb +++ b/qualtran/cirq_interop/cirq_interop.ipynb @@ -471,7 +471,7 @@ "metadata": {}, "outputs": [], "source": [ - "from qualtran.bloqs.factoring.mod_exp import ModExp\n", + "from qualtran.bloqs.factoring.rsa import ModExp\n", "from qualtran.drawing import show_bloq\n", "from qualtran.drawing import get_musical_score_data, draw_musical_score\n", "N = 13*17\n", diff --git a/qualtran/serialization/bloq_test.py b/qualtran/serialization/bloq_test.py index 218206d7f..d411f09e3 100644 --- a/qualtran/serialization/bloq_test.py +++ b/qualtran/serialization/bloq_test.py @@ -23,7 +23,7 @@ from qualtran import Bloq, Signature from qualtran._infra.composite_bloq_test import TestTwoCNOT -from qualtran.bloqs.factoring.mod_exp import ModExp +from qualtran.bloqs.factoring.rsa.rsa_mod_exp import ModExp from qualtran.cirq_interop import CirqGateAsBloq from qualtran.cirq_interop._cirq_to_bloq_test import TestCNOT as TestCNOTCirq from qualtran.protos import registers_pb2 diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index f53ea457b..4aa73158d 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -97,7 +97,8 @@ import qualtran.bloqs.data_loading.qroam_clean import qualtran.bloqs.data_loading.qrom import qualtran.bloqs.data_loading.select_swap_qrom -import qualtran.bloqs.factoring.mod_exp +import qualtran.bloqs.factoring._factoring_shims +import qualtran.bloqs.factoring.rsa import qualtran.bloqs.for_testing.atom import qualtran.bloqs.for_testing.casting import qualtran.bloqs.for_testing.interior_alloc @@ -342,7 +343,9 @@ "qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK": qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK, "qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul, "qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul, - "qualtran.bloqs.factoring.mod_exp.ModExp": qualtran.bloqs.factoring.mod_exp.ModExp, + "qualtran.bloqs.factoring._factoring_shims.MeasureQFT": qualtran.bloqs.factoring._factoring_shims.MeasureQFT, + "qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate": qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate, + "qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp": qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp, "qualtran.bloqs.for_testing.atom.TestAtom": qualtran.bloqs.for_testing.atom.TestAtom, "qualtran.bloqs.for_testing.atom.TestGWRAtom": qualtran.bloqs.for_testing.atom.TestGWRAtom, "qualtran.bloqs.for_testing.atom.TestTwoBitOp": qualtran.bloqs.for_testing.atom.TestTwoBitOp, diff --git a/qualtran/surface_code/msft_resource_estimator_interop.ipynb b/qualtran/surface_code/msft_resource_estimator_interop.ipynb index af6337271..2af07a495 100644 --- a/qualtran/surface_code/msft_resource_estimator_interop.ipynb +++ b/qualtran/surface_code/msft_resource_estimator_interop.ipynb @@ -27,7 +27,7 @@ }, "outputs": [], "source": [ - "from qualtran.bloqs.factoring.mod_exp import ModExp\n", + "from qualtran.bloqs.factoring.rsa import ModExp\n", "from qualtran.drawing import show_bloq\n", "\n", "N = 13*17 # integer to factor\n", From 7344830491f7cbac8e43c9ec106be37a39cb94e8 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 18 Oct 2024 23:51:10 +0100 Subject: [PATCH 4/7] Create linear half comparison bloqs (#1408) This PR creates all versions of half a comparison operation in $n$ toffoli complexity. The uncomputation of these bloqs has zero cost. The half comparason operation is needed in the modular inversion bloq. I also modify the OutOfPlaceAdder to allow it to not compute the extra And when it's not needed --- This PR also shows how we can implement all versions of comparison by implementing just one. for example creating a half greater than bloq that use logarithmic depth is enough to implement all the others. --- .../qualtran_dev_tools/notebook_specs.py | 4 + qualtran/bloqs/arithmetic/__init__.py | 4 + qualtran/bloqs/arithmetic/addition.ipynb | 6 +- qualtran/bloqs/arithmetic/addition.py | 28 +- qualtran/bloqs/arithmetic/comparison.ipynb | 452 +++++++++++++++++ qualtran/bloqs/arithmetic/comparison.py | 470 ++++++++++++++++++ qualtran/bloqs/arithmetic/comparison_test.py | 103 ++++ qualtran/serialization/resolver_dict.py | 4 + 8 files changed, 1063 insertions(+), 8 deletions(-) diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index 7c66311fe..6dc00babe 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -435,6 +435,10 @@ qualtran.bloqs.arithmetic.comparison._SQ_CMP_DOC, qualtran.bloqs.arithmetic.comparison._LEQ_DOC, qualtran.bloqs.arithmetic.comparison._CLinearDepthGreaterThan_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_GREATERTHAN_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_GREATERTHANEQUAL_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_LESSTHAN_DOC, + qualtran.bloqs.arithmetic.comparison._LINEAR_DEPTH_HALF_LESSTHANEQUAL_DOC, ], ), NotebookSpecV2( diff --git a/qualtran/bloqs/arithmetic/__init__.py b/qualtran/bloqs/arithmetic/__init__.py index 533d0ee0c..59ca8a5af 100644 --- a/qualtran/bloqs/arithmetic/__init__.py +++ b/qualtran/bloqs/arithmetic/__init__.py @@ -23,6 +23,10 @@ GreaterThanConstant, LessThanConstant, LessThanEqual, + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, SingleQubitCompare, ) from qualtran.bloqs.arithmetic.controlled_addition import CAdd diff --git a/qualtran/bloqs/arithmetic/addition.ipynb b/qualtran/bloqs/arithmetic/addition.ipynb index c9d271e14..4cd90a386 100644 --- a/qualtran/bloqs/arithmetic/addition.ipynb +++ b/qualtran/bloqs/arithmetic/addition.ipynb @@ -186,12 +186,14 @@ "using $4n - 4 T$ gates. Uncomputation requires 0 T-gates.\n", "\n", "#### Parameters\n", - " - `bitsize`: Number of bits used to represent each input integer. The allocated output register is of size `bitsize+1` so it has enough space to hold the sum of `a+b`. \n", + " - `bitsize`: Number of bits used to represent each input integer. The allocated output register is of size `bitsize+1` so it has enough space to hold the sum of `a+b`.\n", + " - `is_adjoint`: Whether this is compute or uncompute version.\n", + " - `include_most_significant_bit`: Whether to add an extra most significant (i.e. carry) bit. \n", "\n", "#### Registers\n", " - `a`: A bitsize-sized input register (register a above).\n", " - `b`: A bitsize-sized input register (register b above).\n", - " - `c`: A bitize+1-sized LEFT/RIGHT register depending on whether the gate adjoint or not. \n", + " - `c`: The LEFT/RIGHT register depending on whether the gate adjoint or not. This register size is either bitsize or bitsize+1 depending on the value of `include_most_significant_bit`. \n", "\n", "#### References\n", " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" diff --git a/qualtran/bloqs/arithmetic/addition.py b/qualtran/bloqs/arithmetic/addition.py index f0b8de6b9..7b208bd36 100644 --- a/qualtran/bloqs/arithmetic/addition.py +++ b/qualtran/bloqs/arithmetic/addition.py @@ -260,11 +260,15 @@ class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[m Args: bitsize: Number of bits used to represent each input integer. The allocated output register is of size `bitsize+1` so it has enough space to hold the sum of `a+b`. + is_adjoint: Whether this is compute or uncompute version. + include_most_significant_bit: Whether to add an extra most significant (i.e. carry) bit. Registers: a: A bitsize-sized input register (register a above). b: A bitsize-sized input register (register b above). - c: A bitize+1-sized LEFT/RIGHT register depending on whether the gate adjoint or not. + c: The LEFT/RIGHT register depending on whether the gate adjoint or not. + This register size is either bitsize or bitsize+1 depending on + the value of `include_most_significant_bit`. References: [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648) @@ -272,6 +276,11 @@ class OutOfPlaceAdder(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[m bitsize: 'SymbolicInt' is_adjoint: bool = False + include_most_significant_bit: bool = True + + @property + def out_bitsize(self): + return self.bitsize + (1 if self.include_most_significant_bit else 0) @property def signature(self): @@ -280,14 +289,14 @@ def signature(self): [ Register('a', QUInt(self.bitsize)), Register('b', QUInt(self.bitsize)), - Register('c', QUInt(self.bitsize + 1), side=side), + Register('c', QUInt(self.out_bitsize), side=side), ] ) def registers(self) -> Sequence[Union[int, Sequence[int]]]: if not isinstance(self.bitsize, int): raise ValueError(f'Symbolic bitsize {self.bitsize} not supported') - return [2] * self.bitsize, [2] * self.bitsize, [2] * (self.bitsize + 1) + return [2] * self.bitsize, [2] * self.bitsize, [2] * self.out_bitsize def apply(self, a: int, b: int, c: int) -> Tuple[int, int, int]: return a, b, c + a + b @@ -307,7 +316,7 @@ def on_classical_vals( return { 'a': a, 'b': b, - 'c': add_ints(int(a), int(b), num_bits=self.bitsize + 1, is_signed=False), + 'c': add_ints(int(a), int(b), num_bits=self.out_bitsize, is_signed=False), } def with_registers(self, *new_registers: Union[int, Sequence[int]]): @@ -328,12 +337,19 @@ def decompose_from_registers( cirq.CX(a[i], c[i + 1]), cirq.CX(b[i], c[i]), ] - for i in range(self.bitsize) + for i in range(self.out_bitsize - 1) ] + if not self.include_most_significant_bit: + # Update c[-1] as c[-1] ^= a[-1]^b[-1] + i = self.bitsize - 1 + optree.append([cirq.CX(a[i], c[i]), cirq.CX(b[i], c[i])]) return cirq.inverse(optree) if self.is_adjoint else optree def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': - return {And(uncompute=self.is_adjoint): self.bitsize, CNOT(): 5 * self.bitsize} + return { + And(uncompute=self.is_adjoint): self.out_bitsize - 1, + CNOT(): 5 * (self.bitsize - 1) + 2 + (3 if self.include_most_significant_bit else 0), + } def __pow__(self, power: int): if power == 1: diff --git a/qualtran/bloqs/arithmetic/comparison.ipynb b/qualtran/bloqs/arithmetic/comparison.ipynb index 4a0c8ad6d..5183ede00 100644 --- a/qualtran/bloqs/arithmetic/comparison.ipynb +++ b/qualtran/bloqs/arithmetic/comparison.ipynb @@ -1020,6 +1020,458 @@ "show_call_graph(equals_g)\n", "show_counts_sigma(equals_sigma)" ] + }, + { + "cell_type": "markdown", + "id": "e2db7c5d", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfGreaterThan`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{b-a}\\ket{a>b}$ using $n$ And gates.\n", + "\n", + "This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a > b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "758a6e35", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfGreaterThan" + ] + }, + { + "cell_type": "markdown", + "id": "3ea5a7bc", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26e4245f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.lineardepthhalfgreaterthan_small" + }, + "outputs": [], + "source": [ + "lineardepthhalfgreaterthan_small = LinearDepthHalfGreaterThan(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "29abac9f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d065b007", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalfgreaterthan_small],\n", + " ['`lineardepthhalfgreaterthan_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "d67e1888", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32952025", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThan.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalfgreaterthan_small_g, lineardepthhalfgreaterthan_small_sigma = lineardepthhalfgreaterthan_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalfgreaterthan_small_g)\n", + "show_counts_sigma(lineardepthhalfgreaterthan_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "9c39992e", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfGreaterThanEqual`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{a-b}\\ket{a \\geq b}$ using $n$ And gates.\n", + "\n", + "This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a >= b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58e6973f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfGreaterThanEqual" + ] + }, + { + "cell_type": "markdown", + "id": "32dae58a", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eded8868", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.lineardepthhalfgreaterthanequal_small" + }, + "outputs": [], + "source": [ + "lineardepthhalfgreaterthanequal_small = LinearDepthHalfGreaterThanEqual(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "0edbf55b", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5326975b", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalfgreaterthanequal_small],\n", + " ['`lineardepthhalfgreaterthanequal_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "eecfbf65", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abfbe19f", + "metadata": { + "cq.autogen": "LinearDepthHalfGreaterThanEqual.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalfgreaterthanequal_small_g, lineardepthhalfgreaterthanequal_small_sigma = lineardepthhalfgreaterthanequal_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalfgreaterthanequal_small_g)\n", + "show_counts_sigma(lineardepthhalfgreaterthanequal_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "19744b75", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfLessThan`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{a-b}\\ket{a b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a < b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1eec63a1", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfLessThan" + ] + }, + { + "cell_type": "markdown", + "id": "75142163", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4759ae6d", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.lineardepthhalflessthan_small" + }, + "outputs": [], + "source": [ + "lineardepthhalflessthan_small = LinearDepthHalfLessThan(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "903efca4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbbefc84", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalflessthan_small],\n", + " ['`lineardepthhalflessthan_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "8862065f", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9be045c4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThan.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalflessthan_small_g, lineardepthhalflessthan_small_sigma = lineardepthhalflessthan_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalflessthan_small_g)\n", + "show_counts_sigma(lineardepthhalflessthan_small_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "8a868719", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.bloq_doc.md" + }, + "source": [ + "## `LinearDepthHalfLessThanEqual`\n", + "Compare two integers while keeping necessary ancillas for zero cost uncomputation.\n", + "\n", + "Implements $\\ket{a}\\ket{b}\\ket{0}\\ket{0} \\rightarrow \\ket{a}\\ket{b}\\ket{b-a}\\ket{a \\leq b}$ using $n$ And gates.\n", + "\n", + "This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We\n", + "implement it by flipping all the bits in b, computing the first half of the addition circuit,\n", + "copying out the carry, and keeping $c$ for the uncomputation.\n", + "\n", + "#### Parameters\n", + " - `dtype`: dtype of the two integers a and b.\n", + " - `uncompute`: whether this bloq uncomputes or computes the comparison. \n", + "\n", + "#### Registers\n", + " - `a`: first input register.\n", + " - `b`: second input register.\n", + " - `c`: ancilla register that will contain $b-a$ and will be used for uncomputation.\n", + " - `target`: A single bit output register to store the result of a <= b. \n", + "\n", + "#### References\n", + " - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b4c9b03", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import LinearDepthHalfLessThanEqual" + ] + }, + { + "cell_type": "markdown", + "id": "bae993de", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6610fd4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.lineardepthhalflessthanequal_small" + }, + "outputs": [], + "source": [ + "lineardepthhalflessthanequal_small = LinearDepthHalfLessThanEqual(QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "72de4f8e", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdd8d4c4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([lineardepthhalflessthanequal_small],\n", + " ['`lineardepthhalflessthanequal_small`'])" + ] + }, + { + "cell_type": "markdown", + "id": "973be9d4", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9efd6db6", + "metadata": { + "cq.autogen": "LinearDepthHalfLessThanEqual.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "lineardepthhalflessthanequal_small_g, lineardepthhalflessthanequal_small_sigma = lineardepthhalflessthanequal_small.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(lineardepthhalflessthanequal_small_g)\n", + "show_counts_sigma(lineardepthhalflessthanequal_small_sigma)" + ] } ], "metadata": { diff --git a/qualtran/bloqs/arithmetic/comparison.py b/qualtran/bloqs/arithmetic/comparison.py index c546aee77..b28269ae4 100644 --- a/qualtran/bloqs/arithmetic/comparison.py +++ b/qualtran/bloqs/arithmetic/comparison.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc from collections import defaultdict from functools import cached_property from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union @@ -51,6 +52,7 @@ from qualtran.drawing import WireSymbol from qualtran.drawing.musical_score import Circle, Text, TextBox from qualtran.resource_counting.generalizers import ignore_split_join +from qualtran.simulation.classical_sim import add_ints from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt if TYPE_CHECKING: @@ -1228,3 +1230,471 @@ def _clineardepthgreaterthan_example() -> CLinearDepthGreaterThan: _CLinearDepthGreaterThan_DOC = BloqDocSpec( bloq_cls=CLinearDepthGreaterThan, examples=[_clineardepthgreaterthan_example] ) + + +@frozen +class _HalfLinearDepthGreaterThan(Bloq): + """A concrete implementation of half-circuit for greater than. + + This bloq can be returned by the _HalfComparisonBase._half_greater_than_bloq abstract property. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a > b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + dtype: Union[QInt, QUInt, QMontgomeryUInt] + uncompute: bool = False + + @cached_property + def signature(self) -> Signature: + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('a', self.dtype), + Register('b', self.dtype), + Register('c', QUInt(bitsize=self.dtype.bitsize + 1), side=side), + Register('target', QBit(), side=side), + ] + ) + + def adjoint(self) -> '_HalfLinearDepthGreaterThan': + return attrs.evolve(self, uncompute=self.uncompute ^ True) + + def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']: + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=b) + else: + a = bb.join(np.concatenate([[bb.allocate(1)], bb.split(a)])) + b = bb.join(np.concatenate([[bb.allocate(1)], bb.split(b)])) + + dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1) + b = bb.add(BitwiseNot(dtype), x=b) # b := -b-1 + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=b) + a, b, c = bb.add( + OutOfPlaceAdder(self.dtype.bitsize + 1, include_most_significant_bit=False), a=a, b=b + ) # c := a - b - 1 + c = bb.add(BitwiseNot(QUInt(dtype.bitsize)), x=c) # c := b - a + + # Update `target` + c_arr = bb.split(c) + target = bb.allocate(1) + c_arr[0], target = bb.add(CNOT(), ctrl=c_arr[0], target=target) + c = bb.join(c_arr) + + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=b) + b = bb.add(BitwiseNot(dtype), x=b) + + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=b) + else: + a_arr = bb.split(a) + a = bb.join(a_arr[1:]) + b_arr = bb.split(b) + b = bb.join(b_arr[1:]) + bb.free(a_arr[0]) + bb.free(b_arr[0]) + return {'a': a, 'b': b, 'c': c, 'target': target} + + def _uncompute( + self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet', c: 'Soquet', target: 'Soquet' + ) -> Dict[str, 'SoquetT']: + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=b) + else: + a = bb.join(np.concatenate([[bb.allocate(1)], bb.split(a)])) + b = bb.join(np.concatenate([[bb.allocate(1)], bb.split(b)])) + + dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1) + b = bb.add(BitwiseNot(dtype), x=b) # b := -b-1 + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)), reg=b) + + c_arr = bb.split(c) + c_arr[0], target = bb.add(CNOT(), ctrl=c_arr[0], target=target) + c = bb.join(c_arr) + + c = bb.add(BitwiseNot(QUInt(dtype.bitsize)), x=c) + a, b = bb.add( + OutOfPlaceAdder(self.dtype.bitsize + 1, include_most_significant_bit=False).adjoint(), + a=a, + b=b, + c=c, + ) + a = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=a) + b = bb.add(Cast(dtype, QUInt(dtype.bitsize)).adjoint(), reg=b) + b = bb.add(BitwiseNot(dtype), x=b) + + if isinstance(self.dtype, QInt): + a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=a) + b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), x=b) + else: + a_arr = bb.split(a) + a = bb.join(a_arr[1:]) + b_arr = bb.split(b) + b = bb.join(b_arr[1:]) + bb.free(a_arr[0]) + bb.free(b_arr[0]) + bb.free(target) + return {'a': a, 'b': b} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + a: 'Soquet', + b: 'Soquet', + c: Optional['Soquet'] = None, + target: Optional['Soquet'] = None, + ) -> Dict[str, 'SoquetT']: + if self.uncompute: + # Uncompute + assert c is not None + assert target is not None + return self._uncompute(bb, a, b, c, target) + else: + assert c is None + assert target is None + return self._compute(bb, a, b) + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1) + counts: 'BloqCountDictT' + if isinstance(self.dtype, QUInt): + counts = {BitwiseNot(dtype): 3} + else: + counts = {BitwiseNot(dtype): 2, BitwiseNot(QUInt(dtype.bitsize)): 1} + + counts[CNOT()] = 1 + + adder = OutOfPlaceAdder(self.dtype.bitsize + 1, include_most_significant_bit=False) + if self.uncompute: + adder = adder.adjoint() + counts[adder] = 1 + + return counts + + +@frozen +class _HalfComparisonBase(Bloq): + """Parent class for the 4 comparison operations (>, >=, <, <=). + + The four comparison operations can be implemented by implementing only one of them + and computing the others either by reversing the input order, flipping the result or both. + + The choice made is to build the four opertions around greater than. Namely the greater than + bloq returned by `._half_greater_than_bloq`; By changing this property we can change + change the properties of the constructed circuit (e.g. complexity, depth, ..etc). + + For example _LinearDepthHalfComparisonBase sets the property to a linear depth construction, + other implementations can set the property to a log depth construction. + + Args: + dtype: dtype of the two integers a and b. + _op_symbol: The symbol of the comparison operation. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a > b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + dtype: Union[QInt, QUInt, QMontgomeryUInt] + _op_symbol: str = attrs.field( + default='>', validator=lambda _, __, s: s in ('>', '<', '>=', '<='), repr=False + ) + uncompute: bool = False + + @cached_property + def signature(self) -> Signature: + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('a', self.dtype), + Register('b', self.dtype), + Register('c', QUInt(bitsize=self.dtype.bitsize + 1), side=side), + Register('target', QBit(), side=side), + ] + ) + + def adjoint(self) -> '_HalfComparisonBase': + return attrs.evolve(self, uncompute=self.uncompute ^ True) + + @cached_property + @abc.abstractmethod + def _half_greater_than_bloq(self) -> Bloq: + raise NotImplementedError() + + def _classical_comparison( + self, a: 'ClassicalValT', b: 'ClassicalValT' + ) -> Union[bool, np.bool_, NDArray[np.bool_]]: + if self._op_symbol == '>': + return a > b + elif self._op_symbol == '<': + return a < b + elif self._op_symbol == '>=': + return a >= b + else: + return a <= b + + def on_classical_vals( + self, + a: 'ClassicalValT', + b: 'ClassicalValT', + c: Optional['ClassicalValT'] = None, + target: Optional['ClassicalValT'] = None, + ) -> Dict[str, 'ClassicalValT']: + if self.uncompute: + assert c == add_ints( + int(a), + int(b), + num_bits=int(self.dtype.bitsize), + is_signed=isinstance(self.dtype, QInt), + ) + assert target == self._classical_comparison(a, b) + return {'a': a, 'b': b} + if self._op_symbol in ('>', '<='): + c = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) + else: + c = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False) + return {'a': a, 'b': b, 'c': c, 'target': int(self._classical_comparison(a, b))} + + def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']: + if self._op_symbol in ('>', '<='): + a, b, c, target = bb.add_from(self._half_greater_than_bloq, a=a, b=b) # type: ignore + else: + b, a, c, target = bb.add_from(self._half_greater_than_bloq, a=b, b=a) # type: ignore + + if self._op_symbol in ('<=', '>='): + target = bb.add(XGate(), q=target) + + return {'a': a, 'b': b, 'c': c, 'target': target} + + def _uncompute( + self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet', c: 'Soquet', target: 'Soquet' + ) -> Dict[str, 'SoquetT']: + if self._op_symbol in ('<=', '>='): + target = bb.add(XGate(), q=target) + + if self._op_symbol in ('>', '<='): + a, b = bb.add_from(self._half_greater_than_bloq.adjoint(), a=a, b=b, c=c, target=target) # type: ignore + else: + a, b = bb.add_from(self._half_greater_than_bloq.adjoint(), a=b, b=a, c=c, target=target) # type: ignore + + return {'a': a, 'b': b} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + a: 'Soquet', + b: 'Soquet', + c: Optional['Soquet'] = None, + target: Optional['Soquet'] = None, + ) -> Dict[str, 'SoquetT']: + if self.uncompute: + assert c is not None + assert target is not None + return self._uncompute(bb, a, b, c, target) + else: + assert c is None + assert target is None + return self._compute(bb, a, b) + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + extra_ops = {} + if isinstance(self.dtype, QInt): + extra_ops = { + SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)): 2, + SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(): 2, + } + if self._op_symbol in ('>=', '<='): + extra_ops[XGate()] = 1 + adder = self._half_greater_than_bloq + if self.uncompute: + adder = adder.adjoint() + adder_call_graph = adder.build_call_graph(ssa) + assert isinstance(adder_call_graph, dict) + counts: defaultdict['Bloq', Union[int, sympy.Expr]] = defaultdict(lambda: 0) + counts.update(adder_call_graph) + for k, v in extra_ops.items(): + counts[k] += v + return counts + + +@frozen +class _LinearDepthHalfComparisonBase(_HalfComparisonBase): + """A wrapper around _HalfComparisonBase that sets ._half_greater_than_bloq property to a construction with linear depth.""" + + @cached_property + def _half_greater_than_bloq(self) -> Bloq: + return _HalfLinearDepthGreaterThan(self.dtype, uncompute=False) + + +@frozen +class LinearDepthHalfGreaterThan(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{b-a}\ket{a>b}$ using $n$ And gates. + + This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a > b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='>', repr=False, init=False) + + +@frozen +class LinearDepthHalfGreaterThanEqual(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{a-b}\ket{a \geq b}$ using $n$ And gates. + + This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a >= b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='>=', repr=False, init=False) + + +@frozen +class LinearDepthHalfLessThan(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{a-b}\ket{a b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a < b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='<', repr=False, init=False) + + +@frozen +class LinearDepthHalfLessThanEqual(_LinearDepthHalfComparisonBase): + r"""Compare two integers while keeping necessary ancillas for zero cost uncomputation. + + Implements $\ket{a}\ket{b}\ket{0}\ket{0} \rightarrow \ket{a}\ket{b}\ket{b-a}\ket{a \leq b}$ using $n$ And gates. + + This comparator relies on the fact that c = (b' + a)' = b - a. If a > b, then b - a < 0. We + implement it by flipping all the bits in b, computing the first half of the addition circuit, + copying out the carry, and keeping $c$ for the uncomputation. + + Args: + dtype: dtype of the two integers a and b. + uncompute: whether this bloq uncomputes or computes the comparison. + + Registers: + a: first input register. + b: second input register. + c: ancilla register that will contain $b-a$ and will be used for uncomputation. + target: A single bit output register to store the result of a <= b. + + References: + [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). + """ + + _op_symbol: str = attrs.field(default='<=', repr=False, init=False) + + +@bloq_example +def _lineardepthhalfgreaterthan_small() -> LinearDepthHalfGreaterThan: + lineardepthhalfgreaterthan_small = LinearDepthHalfGreaterThan(QUInt(3)) + return lineardepthhalfgreaterthan_small + + +@bloq_example +def _lineardepthhalflessthan_small() -> LinearDepthHalfLessThan: + lineardepthhalflessthan_small = LinearDepthHalfLessThan(QUInt(3)) + return lineardepthhalflessthan_small + + +@bloq_example +def _lineardepthhalfgreaterthanequal_small() -> LinearDepthHalfGreaterThanEqual: + lineardepthhalfgreaterthanequal_small = LinearDepthHalfGreaterThanEqual(QUInt(3)) + return lineardepthhalfgreaterthanequal_small + + +@bloq_example +def _lineardepthhalflessthanequal_small() -> LinearDepthHalfLessThanEqual: + lineardepthhalflessthanequal_small = LinearDepthHalfLessThanEqual(QUInt(3)) + return lineardepthhalflessthanequal_small + + +_LINEAR_DEPTH_HALF_GREATERTHAN_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfGreaterThan, examples=[_lineardepthhalfgreaterthan_small] +) + + +_LINEAR_DEPTH_HALF_LESSTHAN_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfLessThan, examples=[_lineardepthhalflessthan_small] +) + + +_LINEAR_DEPTH_HALF_GREATERTHANEQUAL_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfGreaterThanEqual, examples=[_lineardepthhalfgreaterthanequal_small] +) + + +_LINEAR_DEPTH_HALF_LESSTHANEQUAL_DOC = BloqDocSpec( + bloq_cls=LinearDepthHalfLessThanEqual, examples=[_lineardepthhalflessthanequal_small] +) diff --git a/qualtran/bloqs/arithmetic/comparison_test.py b/qualtran/bloqs/arithmetic/comparison_test.py index f755d1eb7..9876537e8 100644 --- a/qualtran/bloqs/arithmetic/comparison_test.py +++ b/qualtran/bloqs/arithmetic/comparison_test.py @@ -28,6 +28,10 @@ _greater_than, _gt_k, _leq_symb, + _lineardepthhalfgreaterthan_small, + _lineardepthhalfgreaterthanequal_small, + _lineardepthhalflessthan_small, + _lineardepthhalflessthanequal_small, _lt_k_symb, BiQubitsMixer, CLinearDepthGreaterThan, @@ -38,10 +42,15 @@ LessThanConstant, LessThanEqual, LinearDepthGreaterThan, + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, SingleQubitCompare, ) from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim +from qualtran.resource_counting import get_cost_value, QECGatesCost from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join @@ -398,3 +407,97 @@ def test_clineardepthgreaterthan_tcomplexity(ctrl, dtype): c = CLinearDepthGreaterThan(dtype(n), ctrl).t_complexity() assert c.t == 4 * (n + 2) assert c.rotations == 0 + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +@pytest.mark.parametrize('bitsize', range(2, 5)) +@pytest.mark.parametrize('uncompute', [True, False]) +def test_linear_half_comparison_decomposition(comp_cls, dtype, bitsize, uncompute): + b = comp_cls(dtype(bitsize), uncompute) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +@pytest.mark.parametrize('bitsize', range(2, 5)) +@pytest.mark.parametrize('uncompute', [True, False]) +def test_linear_half_comparison_bloq_counts(comp_cls, dtype, bitsize, uncompute): + b = comp_cls(dtype(bitsize), uncompute) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +@pytest.mark.parametrize('bitsize', range(2, 5)) +def test_linear_half_comparison_classical_action(comp_cls, dtype, bitsize): + b = comp_cls(dtype(bitsize)) + qlt_testing.assert_consistent_classical_action( + b, a=dtype(bitsize).get_classical_domain(), b=dtype(bitsize).get_classical_domain() + ) + + +@pytest.mark.parametrize( + 'comp_cls', + [ + LinearDepthHalfGreaterThan, + LinearDepthHalfGreaterThanEqual, + LinearDepthHalfLessThan, + LinearDepthHalfLessThanEqual, + ], +) +@pytest.mark.parametrize('dtype', [QInt, QUInt, QMontgomeryUInt]) +def test_linear_half_comparison_symbolic_complexity(comp_cls, dtype): + n = sympy.Symbol('n') + b = comp_cls(dtype(n)) + + cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() + + assert cost['n_t'] == 0 + assert cost['n_ccz'] == n + + # uncomputation has zero cost. + cost = get_cost_value(b.adjoint(), QECGatesCost()).total_t_and_ccz_count() + + assert cost['n_t'] == 0 + assert cost['n_ccz'] == 0 + + +def test_lineardepthhalfgreaterthan_small(bloq_autotester): + bloq_autotester(_lineardepthhalfgreaterthan_small) + + +def test_lineardepthhalflessthan_small(bloq_autotester): + bloq_autotester(_lineardepthhalflessthan_small) + + +def test_lineardepthhalfgreaterthanequal_small(bloq_autotester): + bloq_autotester(_lineardepthhalfgreaterthanequal_small) + + +def test_lineardepthhalflessthanequal_small(bloq_autotester): + bloq_autotester(_lineardepthhalflessthanequal_small) diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index 4aa73158d..bc3be1a65 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -177,6 +177,10 @@ "qualtran.bloqs.arithmetic.comparison.LinearDepthGreaterThan": qualtran.bloqs.arithmetic.comparison.LinearDepthGreaterThan, "qualtran.bloqs.arithmetic.comparison.SingleQubitCompare": qualtran.bloqs.arithmetic.comparison.SingleQubitCompare, "qualtran.bloqs.arithmetic.comparison.CLinearDepthGreaterThan": qualtran.bloqs.arithmetic.comparison.CLinearDepthGreaterThan, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThan": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThan, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThan": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThan, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThanEqual": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfGreaterThanEqual, + "qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThanEqual": qualtran.bloqs.arithmetic.comparison.LinearDepthHalfLessThanEqual, "qualtran.bloqs.arithmetic.controlled_add_or_subtract.ControlledAddOrSubtract": qualtran.bloqs.arithmetic.controlled_add_or_subtract.ControlledAddOrSubtract, "qualtran.bloqs.arithmetic.conversions.contiguous_index.ToContiguousIndex": qualtran.bloqs.arithmetic.conversions.contiguous_index.ToContiguousIndex, "qualtran.bloqs.arithmetic.conversions.ones_complement_to_twos_complement.SignedIntegerToTwosComplement": qualtran.bloqs.arithmetic.conversions.ones_complement_to_twos_complement.SignedIntegerToTwosComplement, From e3aeee0d5e8f72304b2c8a1124196954e764ce0a Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Wed, 23 Oct 2024 14:27:30 -0700 Subject: [PATCH 5/7] Add ECAdd() Bloq (#1425) * Initial commit of ec add waiting on equals to be merged. * Working on tests for ECAdd * ECAdd implementation and tests * remove modmul typo * Fix mypy errors * Better bugfix for ModAdd * Change mod inv classical impl to use monttgomery inv * Fix pytest error * Fix some comments * ECAdd lots of testing * Add comments about bugs to be fixed * Reduce complexity by keeping intermediate values mod p * Stash qmontgomery tests * Address comments * Fix montgomery prod/inv calculations + pylint/mypy --------- Co-authored-by: Noureldin Co-authored-by: Matthew Harrigan --- qualtran/_infra/data_types.py | 44 +- qualtran/_infra/data_types_test.py | 26 + qualtran/bloqs/arithmetic/_shims.py | 22 +- qualtran/bloqs/factoring/ecc/ec_add.ipynb | 36 +- qualtran/bloqs/factoring/ecc/ec_add.py | 1075 ++++++++++++++++- qualtran/bloqs/factoring/ecc/ec_add_test.py | 412 ++++++- qualtran/bloqs/factoring/ecc/ec_point.py | 3 +- qualtran/bloqs/factoring/ecc/ec_point_test.py | 1 + qualtran/bloqs/mod_arithmetic/_shims.py | 55 +- .../mod_arithmetic/mod_multiplication.py | 3 +- qualtran/serialization/resolver_dict.py | 8 + 11 files changed, 1639 insertions(+), 46 deletions(-) diff --git a/qualtran/_infra/data_types.py b/qualtran/_infra/data_types.py index e21eb209f..ee918d506 100644 --- a/qualtran/_infra/data_types.py +++ b/qualtran/_infra/data_types.py @@ -772,9 +772,14 @@ class QMontgomeryUInt(QDType): bitsize: The number of qubits used to represent the integer. References: - [Montgomery modular multiplication](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication) + [Montgomery modular multiplication](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication). + + [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126133 Cat Qubits](https://arxiv.org/abs/2302.06639). + Gouzien et al. 2023. + We follow Montgomery form as described in the above paper; namely, r = 2^bitsize. """ + # TODO(https://github.com/quantumlib/Qualtran/issues/1471): Add modulus p as a class member. bitsize: SymbolicInt @property @@ -810,6 +815,43 @@ def assert_valid_classical_val_array( if np.any(val_array >= 2**self.bitsize): raise ValueError(f"Too-large classical values encountered in {debug_str}") + def montgomery_inverse(self, xm: int, p: int) -> int: + """Returns the modular inverse of an integer in montgomery form. + + Args: + xm: An integer in montgomery form. + p: The modulus of the finite field. + """ + return ((pow(xm, -1, p)) * pow(2, 2 * self.bitsize, p)) % p + + def montgomery_product(self, xm: int, ym: int, p: int) -> int: + """Returns the modular product of two integers in montgomery form. + + Args: + xm: The first montgomery form integer for the product. + ym: The second montgomery form integer for the product. + p: The modulus of the finite field. + """ + return (xm * ym * pow(2, -self.bitsize, p)) % p + + def montgomery_to_uint(self, xm: int, p: int) -> int: + """Converts an integer in montgomery form to a normal form integer. + + Args: + xm: An integer in montgomery form. + p: The modulus of the finite field. + """ + return (xm * pow(2, -self.bitsize, p)) % p + + def uint_to_montgomery(self, x: int, p: int) -> int: + """Converts an integer into montgomery form. + + Args: + x: An integer. + p: The modulus of the finite field. + """ + return (x * pow(2, int(self.bitsize), p)) % p + @attrs.frozen class QGF(QDType): diff --git a/qualtran/_infra/data_types_test.py b/qualtran/_infra/data_types_test.py index 10347b702..65252c5cb 100644 --- a/qualtran/_infra/data_types_test.py +++ b/qualtran/_infra/data_types_test.py @@ -135,6 +135,32 @@ def test_qmontgomeryuint(): assert is_symbolic(QMontgomeryUInt(sympy.Symbol('x'))) +@pytest.mark.parametrize('p', [13, 17, 29]) +@pytest.mark.parametrize('val', [1, 5, 7, 9]) +def test_qmontgomeryuint_operations(val, p): + qmontgomeryuint_8 = QMontgomeryUInt(8) + # Convert value to montgomery form and get the modular inverse. + val_m = qmontgomeryuint_8.uint_to_montgomery(val, p) + mod_inv = qmontgomeryuint_8.montgomery_inverse(val_m, p) + + # Calculate the product in montgomery form and convert back to normal form for assertion. + assert ( + qmontgomeryuint_8.montgomery_to_uint( + qmontgomeryuint_8.montgomery_product(val_m, mod_inv, p), p + ) + == 1 + ) + + +@pytest.mark.parametrize('p', [13, 17, 29]) +@pytest.mark.parametrize('val', [1, 5, 7, 9]) +def test_qmontgomeryuint_conversions(val, p): + qmontgomeryuint_8 = QMontgomeryUInt(8) + assert val == qmontgomeryuint_8.montgomery_to_uint( + qmontgomeryuint_8.uint_to_montgomery(val, p), p + ) + + def test_qgf(): qgf_256 = QGF(characteristic=2, degree=8) assert str(qgf_256) == 'QGF(2**8)' diff --git a/qualtran/bloqs/arithmetic/_shims.py b/qualtran/bloqs/arithmetic/_shims.py index 0d40daba7..d75e3d72f 100644 --- a/qualtran/bloqs/arithmetic/_shims.py +++ b/qualtran/bloqs/arithmetic/_shims.py @@ -22,8 +22,11 @@ from attrs import frozen -from qualtran import Bloq, QBit, QUInt, Register, Signature +from qualtran import Bloq, QBit, QMontgomeryUInt, QUInt, Register, Signature +from qualtran.bloqs.arithmetic.bitwise import BitwiseNot +from qualtran.bloqs.arithmetic.controlled_addition import CAdd from qualtran.bloqs.basic_gates import Toffoli +from qualtran.bloqs.basic_gates.swap import TwoBitCSwap from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -39,6 +42,20 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: return {Toffoli(): self.n - 2} +@frozen +class CSub(Bloq): + n: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [Register('ctrl', QBit()), Register('x', QUInt(self.n)), Register('y', QUInt(self.n))] + ) + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return {CAdd(QMontgomeryUInt(self.n)): 1, BitwiseNot(QMontgomeryUInt(self.n)): 3} + + @frozen class Lt(Bloq): n: int @@ -62,3 +79,6 @@ class CHalf(Bloq): @cached_property def signature(self) -> 'Signature': return Signature([Register('ctrl', QBit()), Register('x', QUInt(self.n))]) + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return {TwoBitCSwap(): self.n} diff --git a/qualtran/bloqs/factoring/ecc/ec_add.ipynb b/qualtran/bloqs/factoring/ecc/ec_add.ipynb index 543458c8c..cbc279bbe 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.ipynb +++ b/qualtran/bloqs/factoring/ecc/ec_add.ipynb @@ -41,16 +41,22 @@ "This takes elliptic curve points given by (a, b) and (x, y)\n", "and outputs the sum (x_r, y_r) in the second pair of registers.\n", "\n", + "Because the decomposition of this Bloq is complex, we split it into six separate parts\n", + "corresponding to the parts described in figure 10 of the Litinski paper cited below. We follow\n", + "the signature from figure 5 and break down the further decompositions based on the steps in\n", + "figure 10.\n", + "\n", "#### Parameters\n", " - `n`: The bitsize of the two registers storing the elliptic curve point\n", - " - `mod`: The modulus of the field in which we do the addition. \n", + " - `mod`: The modulus of the field in which we do the addition.\n", + " - `window_size`: The number of bits in the ModMult window. \n", "\n", "#### Registers\n", - " - `a`: The x component of the first input elliptic curve point of bitsize `n`.\n", - " - `b`: The y component of the first input elliptic curve point of bitsize `n`.\n", - " - `x`: The x component of the second input elliptic curve point of bitsize `n`, which will contain the x component of the resultant curve point.\n", - " - `y`: The y component of the second input elliptic curve point of bitsize `n`, which will contain the y component of the resultant curve point.\n", - " - `lam`: The precomputed lambda slope used in the addition operation. \n", + " - `a`: The x component of the first input elliptic curve point of bitsize `n` in montgomery form.\n", + " - `b`: The y component of the first input elliptic curve point of bitsize `n` in montgomery form.\n", + " - `x`: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the x component of the resultant curve point.\n", + " - `y`: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the y component of the resultant curve point.\n", + " - `lam_r`: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. \n", "\n", "#### References\n", " - [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). Litinski. 2023. Fig 5.\n" @@ -91,6 +97,18 @@ "ec_add = ECAdd(n, mod=p)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "170da165", + "metadata": { + "cq.autogen": "ECAdd.ec_add_small" + }, + "outputs": [], + "source": [ + "ec_add_small = ECAdd(5, mod=7)" + ] + }, { "cell_type": "markdown", "id": "39210af4", @@ -111,8 +129,8 @@ "outputs": [], "source": [ "from qualtran.drawing import show_bloqs\n", - "show_bloqs([ec_add],\n", - " ['`ec_add`'])" + "show_bloqs([ec_add, ec_add_small],\n", + " ['`ec_add`', '`ec_add_small`'])" ] }, { @@ -157,7 +175,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/qualtran/bloqs/factoring/ecc/ec_add.py b/qualtran/bloqs/factoring/ecc/ec_add.py index 74a57706a..5f8216777 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add.py +++ b/qualtran/bloqs/factoring/ecc/ec_add.py @@ -12,12 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cached_property +from typing import Dict, Union +import numpy as np import sympy from attrs import frozen -from qualtran import Bloq, bloq_example, BloqDocSpec, QUInt, Register, Signature -from qualtran.bloqs.arithmetic._shims import MultiCToffoli +from qualtran import ( + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + DecomposeTypeError, + QBit, + QMontgomeryUInt, + Register, + Side, + Signature, + Soquet, + SoquetT, +) +from qualtran.bloqs.arithmetic.comparison import Equals +from qualtran.bloqs.basic_gates import CNOT, IntState, Toffoli, ZeroState +from qualtran.bloqs.bookkeeping import Free +from qualtran.bloqs.mcmt import MultiAnd, MultiControlX, MultiTargetCNOT from qualtran.bloqs.mod_arithmetic import ( CModAdd, CModNeg, @@ -30,6 +48,925 @@ ) from qualtran.bloqs.mod_arithmetic._shims import ModInv from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.simulation.classical_sim import ClassicalValT +from qualtran.symbolics.types import HasLength, is_symbolic + +from .ec_point import ECPoint + + +@frozen +class _ECAddStepOne(Bloq): + r"""Performs step one of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + + Registers: + f1: Flag to set if a = x. + f2: Flag to set if b = -y. + f3: Flag to set if (a, b) = (0, 0). + f4: Flag to set if (x, y) = (0, 0). + ctrl: Flag to set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit(), side=Side.RIGHT), + Register('f2', QBit(), side=Side.RIGHT), + Register('f3', QBit(), side=Side.RIGHT), + Register('f4', QBit(), side=Side.RIGHT), + Register('ctrl', QBit(), side=Side.RIGHT), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, a: 'ClassicalValT', b: 'ClassicalValT', x: 'ClassicalValT', y: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + f1 = int(a == x) + f2 = int(b == (-y % self.mod)) + f3 = int(a == b == 0) + f4 = int(x == y == 0) + ctrl = int(f2 == f3 == f4 == 0) + return { + 'f1': f1, + 'f2': f2, + 'f3': f3, + 'f4': f4, + 'ctrl': ctrl, + 'a': a, + 'b': b, + 'x': x, + 'y': y, + } + + def build_composite_bloq( + self, bb: 'BloqBuilder', a: Soquet, b: Soquet, x: Soquet, y: Soquet + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initialize control flags to 0. + f1 = bb.add(ZeroState()) + f2 = bb.add(ZeroState()) + f3 = bb.add(ZeroState()) + f4 = bb.add(ZeroState()) + ctrl = bb.add(ZeroState()) + + # Set flag 1 if a = x. + a, x, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=a, y=x, target=f1) + + # Set flag 2 if b = -y. + y = bb.add(ModNeg(QMontgomeryUInt(self.n), mod=self.mod), x=y) + b, y, f2 = bb.add(Equals(QMontgomeryUInt(self.n)), x=b, y=y, target=f2) + y = bb.add(ModNeg(QMontgomeryUInt(self.n), mod=self.mod), x=y) + + # Set flag 3 if (a, b) == (0, 0). + ab_arr = np.concatenate([bb.split(a), bb.split(b)]) + ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) + ab_arr = np.split(ab_arr, 2) + a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Set flag 4 if (x, y) == (0, 0). + xy_arr = np.concatenate([bb.split(x), bb.split(y)]) + xy_arr, f4 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=xy_arr, target=f4) + xy_arr = np.split(xy_arr, 2) + x = bb.join(xy_arr[0], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Set ctrl flag if f2, f3, f4 are set. + f_ctrls = [f2, f3, f4] + f_ctrls, ctrl = bb.add(MultiControlX(cvs=[0] * 3), controls=f_ctrls, target=ctrl) + f2 = f_ctrls[0] + f3 = f_ctrls[1] + f4 = f_ctrls[2] + + # Return the output registers. + return { + 'f1': f1, + 'f2': f2, + 'f3': f3, + 'f4': f4, + 'ctrl': ctrl, + 'a': a, + 'b': b, + 'x': x, + 'y': y, + } + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + cvs: Union[list[int], HasLength] + if isinstance(self.n, int): + cvs = [0] * 2 * self.n + else: + cvs = HasLength(2 * self.n) + return { + Equals(QMontgomeryUInt(self.n)): 2, + ModNeg(QMontgomeryUInt(self.n), mod=self.mod): 2, + MultiControlX(cvs=cvs): 2, + MultiControlX(cvs=[0] * 3): 1, + } + + +@frozen +class _ECAddStepTwo(Bloq): + r"""Performs step two of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + f1: Flag set if a = x. + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit()), + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n), side=Side.RIGHT), + Register('lam_r', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + f1: 'ClassicalValT', + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam_r: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + x = (x - a) % self.mod + if ctrl == 1: + y = (y - b) % self.mod + if f1 == 1: + lam = lam_r + f1 = 0 + else: + lam = QMontgomeryUInt(self.n).montgomery_product( + int(y), QMontgomeryUInt(self.n).montgomery_inverse(int(x), self.mod), self.mod + ) + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit + # which flips f1 when lam and lam_r are equal. + if lam == lam_r: + f1 = (f1 + 1) % 2 + else: + lam = 0 + return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + f1: Soquet, + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam_r: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initalize lambda to 0. + lam = bb.add(IntState(bitsize=self.n, val=0)) + + # Perform modular subtraction so that x = (x - a) % p. + a, x = bb.add(ModSub(QMontgomeryUInt(self.n), mod=self.mod), x=a, y=x) + + # Perform controlled modular subtraction so that y = (y - b) % p iff ctrl = 1. + ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) + + # Perform modular inversion s.t. x = (x - a)^-1 % p. + x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + + # Perform modular multiplication z4 = (y / x) % p. + x, y, z4, z3, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=y, + ) + + # If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p. + z4_split = bb.split(z4) + lam_split = bb.split(lam) + for i in range(self.n): + ctrls = [f1, ctrl, z4_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[0, 1, 1]), controls=ctrls, target=lam_split[i] + ) + f1 = ctrls[0] + ctrl = ctrls[1] + z4_split[i] = ctrls[2] + z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) + + # If ctrl = 1 and x = a: lam = lam_r. + lam_r_split = bb.split(lam_r) + for i in range(self.n): + ctrls = [f1, ctrl, lam_r_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i] + ) + f1 = ctrls[0] + ctrl = ctrls[1] + lam_r_split[i] = ctrls[2] + lam_r = bb.join(lam_r_split, dtype=QMontgomeryUInt(self.n)) + lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) + + # If lam = lam_r: return f1 = 0. (If not we will flip f1 to 0 at the end iff x_r = y_r = 0). + lam, lam_r, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=lam, y=lam_r, target=f1) + + # Uncompute the modular multiplication then the modular inversion. + x, y = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=y, + target=z4, + qrom_indices=z3, + reduced=reduced, + ) + x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + + # Return the output registers. + return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + Equals(QMontgomeryUInt(self.n)): 1, + ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModInv(n=self.n, mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + MultiControlX(cvs=[0, 1, 1]): self.n, + MultiControlX(cvs=[1, 1, 1]): self.n, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + ModInv(n=self.n, mod=self.mod).adjoint(): 1, + } + + +@frozen +class _ECAddStepThree(Bloq): + r"""Performs step three of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 1: + x = (x + 3 * a) % self.mod + y = 0 + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Store (x - a) * lam % p in z1 (= (y - b) % p). + x, lam, z1, z2, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=lam, + ) + + # If ctrl: subtract z1 from y (= 0). + ctrl, z1, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=z1, y=y) + + # Uncompute original multiplication. + x, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=lam, + target=z1, + qrom_indices=z2, + reduced=reduced, + ) + + # z1 = a. + z1 = bb.add(IntState(bitsize=self.n, val=0)) + a_split = bb.split(a) + z1_split = bb.split(z1) + for i in range(self.n): + a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) + a = bb.join(a_split, QMontgomeryUInt(self.n)) + z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) + + # z1 = (3 * a) % p. + z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod), x=z1) + a, z1 = bb.add(ModAdd(self.n, mod=self.mod), x=a, y=z1) + + # If ctrl: x = (x + 2 * a) % p. + ctrl, z1, x = bb.add(CModAdd(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=z1, y=x) + + # Uncompute z1. + a, z1 = bb.add(ModAdd(self.n, mod=self.mod).adjoint(), x=a, y=z1) + z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(), x=z1) + a_split = bb.split(a) + z1_split = bb.split(z1) + for i in range(self.n): + a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i]) + a = bb.join(a_split, QMontgomeryUInt(self.n)) + z1 = bb.join(z1_split, QMontgomeryUInt(self.n)) + bb.add(Free(QMontgomeryUInt(self.n)), reg=z1) + + # Return the output registers. + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + CNOT(): 2 * self.n, + ModDbl(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModAdd(self.n, mod=self.mod): 1, + CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModAdd(self.n, mod=self.mod).adjoint(): 1, + ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(): 1, + } + + +@frozen +class _ECAddStepFour(Bloq): + r"""Performs step four of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT' + ) -> Dict[str, 'ClassicalValT']: + x = ( + x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), self.mod) + ) % self.mod + if lam > 0: + y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), self.mod) + return {'x': x, 'y': y, 'lam': lam} + + def build_composite_bloq( + self, bb: 'BloqBuilder', x: Soquet, y: Soquet, lam: Soquet + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Initialize z4 = lam. + z4 = bb.add(IntState(bitsize=self.n, val=0)) + lam_split = bb.split(lam) + z4_split = bb.split(z4) + for i in range(self.n): + lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) + lam = bb.join(lam_split, QMontgomeryUInt(self.n)) + z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) + + # z3 = lam * lam % p. + z4, lam, z3, z2, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=z4, + y=lam, + ) + + # x = a - x_r % p. + z3, x = bb.add(ModSub(QMontgomeryUInt(self.n), mod=self.mod), x=z3, y=x) + + # Uncompute the multiplication and initialization of z4. + z4, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=z4, + y=lam, + target=z3, + qrom_indices=z2, + reduced=reduced, + ) + lam_split = bb.split(lam) + z4_split = bb.split(z4) + for i in range(self.n): + lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i]) + lam = bb.join(lam_split, QMontgomeryUInt(self.n)) + z4 = bb.join(z4_split, QMontgomeryUInt(self.n)) + bb.add(Free(QMontgomeryUInt(self.n)), reg=z4) + + # z3 = lam * x % p. + x, lam, z3, z4, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=lam, + ) + + # y = y_r + b % p. + z3_split = bb.split(z3) + y_split = bb.split(y) + for i in range(self.n): + z3_split[i], y_split[i] = bb.add(CNOT(), ctrl=z3_split[i], target=y_split[i]) + z3 = bb.join(z3_split, QMontgomeryUInt(self.n)) + y = bb.join(y_split, QMontgomeryUInt(self.n)) + + # Uncompute multiplication. + x, lam = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=lam, + target=z3, + qrom_indices=z4, + reduced=reduced, + ) + + # Return the output registers. + return {'x': x, 'y': y, 'lam': lam} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 2, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 2, + CNOT(): 3 * self.n, + } + + +@frozen +class _ECAddStepFive(Bloq): + r"""Performs step five of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. + + Registers: + ctrl: Flag set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + lam: The lambda slope used in the addition operation. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + window_size: int = 1 + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('ctrl', QBit()), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam', QMontgomeryUInt(self.n), side=Side.LEFT), + ] + ) + + def on_classical_vals( + self, + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + lam: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if ctrl == 1: + x = (a - x) % self.mod + y = (y - b) % self.mod + else: + x = (x + a) % self.mod + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + lam: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # x = x ^ -1 % p. + x, z1, z2 = bb.add(ModInv(n=self.n, mod=self.mod), x=x) + + # z4 = x * y % p. + x, y, z4, z3, reduced = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ), + x=x, + y=y, + ) + + # If ctrl: lam = 0. + z4_split = bb.split(z4) + lam_split = bb.split(lam) + for i in range(self.n): + ctrls = [ctrl, z4_split[i]] + ctrls, lam_split[i] = bb.add( + MultiControlX(cvs=[1, 1]), controls=ctrls, target=lam_split[i] + ) + ctrl = ctrls[0] + z4_split[i] = ctrls[1] + z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) + lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit where lambda + # is not set to 0 before being freed. + bb.add(Free(QMontgomeryUInt(self.n), dirty=True), reg=lam) + + # Uncompute multiplication and inverse. + x, y = bb.add( + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(), + x=x, + y=y, + target=z4, + qrom_indices=z3, + reduced=reduced, + ) + x = bb.add(ModInv(n=self.n, mod=self.mod).adjoint(), x=x, garbage1=z1, garbage2=z2) + + # If ctrl: x = x_r - a % p. + ctrl, x = bb.add(CModNeg(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=x) + + # Add a to x (x = x_r). + a, x = bb.add(ModAdd(self.n, mod=self.mod), x=a, y=x) + + # If ctrl: subtract b from y (y = y_r). + ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) + + # Return the output registers. + return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + return { + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + ModInv(n=self.n, mod=self.mod): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ): 1, + DirtyOutOfPlaceMontgomeryModMul( + bitsize=self.n, window_size=self.window_size, mod=self.mod + ).adjoint(): 1, + ModInv(n=self.n, mod=self.mod).adjoint(): 1, + ModAdd(self.n, mod=self.mod): 1, + MultiControlX(cvs=[1, 1]): self.n, + CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1, + } + + +@frozen +class _ECAddStepSix(Bloq): + r"""Performs step six of the ECAdd bloq. + + Args: + n: The bitsize of the two registers storing the elliptic curve point + mod: The modulus of the field in which we do the addition. + + Registers: + f1: Flag to set if a = x. + f2: Flag to set if b = -y. + f3: Flag to set if (a, b) = (0, 0). + f4: Flag to set if (x, y) = (0, 0). + ctrl: Flag to set if neither the input points nor the output point are (0, 0). + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the x component of the resultant curve point. + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which + will contain the y component of the resultant curve point. + + References: + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + Fig 10. + """ + + n: int + mod: int + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('f1', QBit(), side=Side.LEFT), + Register('f2', QBit(), side=Side.LEFT), + Register('f3', QBit(), side=Side.LEFT), + Register('f4', QBit(), side=Side.LEFT), + Register('ctrl', QBit(), side=Side.LEFT), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + ] + ) + + def on_classical_vals( + self, + f1: 'ClassicalValT', + f2: 'ClassicalValT', + f3: 'ClassicalValT', + f4: 'ClassicalValT', + ctrl: 'ClassicalValT', + a: 'ClassicalValT', + b: 'ClassicalValT', + x: 'ClassicalValT', + y: 'ClassicalValT', + ) -> Dict[str, 'ClassicalValT']: + if f4 == 1: + x = a + y = b + if f1 and f2: + x = 0 + y = 0 + return {'a': a, 'b': b, 'x': x, 'y': y} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + f1: Soquet, + f2: Soquet, + f3: Soquet, + f4: Soquet, + ctrl: Soquet, + a: Soquet, + b: Soquet, + x: Soquet, + y: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.n): + raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `n`.") + + # Unset control if f2, f3, and f4 flags are set. + f_ctrls = [f2, f3, f4] + f_ctrls, ctrl = bb.add(MultiControlX(cvs=[0] * 3), controls=f_ctrls, target=ctrl) + f2 = f_ctrls[0] + f3 = f_ctrls[1] + f4 = f_ctrls[2] + + # Set (x, y) to (a, b) if f4 is set. + a_split = bb.split(a) + x_split = bb.split(x) + for i in range(self.n): + toff_ctrl = [f4, a_split[i]] + toff_ctrl, x_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=x_split[i]) + f4 = toff_ctrl[0] + a_split[i] = toff_ctrl[1] + a = bb.join(a_split, QMontgomeryUInt(self.n)) + x = bb.join(x_split, QMontgomeryUInt(self.n)) + b_split = bb.split(b) + y_split = bb.split(y) + for i in range(self.n): + toff_ctrl = [f4, b_split[i]] + toff_ctrl, y_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=y_split[i]) + f4 = toff_ctrl[0] + b_split[i] = toff_ctrl[1] + b = bb.join(b_split, QMontgomeryUInt(self.n)) + y = bb.join(y_split, QMontgomeryUInt(self.n)) + + # Unset f4 if (x, y) = (a, b). + ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n)) + xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) + ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) + ab_split = bb.split(ab) + a = bb.join(ab_split[: self.n], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_split[self.n :], dtype=QMontgomeryUInt(self.n)) + xy_split = bb.split(xy) + x = bb.join(xy_split[: self.n], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_split[self.n :], dtype=QMontgomeryUInt(self.n)) + + # Unset f3 if (a, b) = (0, 0). + ab_arr = np.concatenate([bb.split(a), bb.split(b)]) + ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) + ab_arr = np.split(ab_arr, 2) + a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) + b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) + + # If f1 and f2 are set, subtract a from x and add b to y. + ancilla = bb.add(ZeroState()) + toff_ctrl = [f1, f2] + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + ancilla, a, x = bb.add( + CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ancilla, x=a, y=x + ) + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + f1 = toff_ctrl[0] + f2 = toff_ctrl[1] + bb.add(Free(QBit()), reg=ancilla) + ancilla = bb.add(ZeroState()) + toff_ctrl = [f1, f2] + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + ancilla, b, y = bb.add( + CModAdd(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ancilla, x=b, y=y + ) + toff_ctrl, ancilla = bb.add(Toffoli(), ctrl=toff_ctrl, target=ancilla) + f1 = toff_ctrl[0] + f2 = toff_ctrl[1] + bb.add(Free(QBit()), reg=ancilla) + + # Unset f1 and f2 if (x, y) = (0, 0). + xy_arr = np.concatenate([bb.split(x), bb.split(y)]) + xy_arr, junk, out = bb.add(MultiAnd(cvs=[0] * 2 * self.n), ctrl=xy_arr) + targets = bb.join(np.array([f1, f2])) + out, targets = bb.add(MultiTargetCNOT(2), control=out, targets=targets) + targets = bb.split(targets) + f1 = targets[0] + f2 = targets[1] + xy_arr = bb.add( + MultiAnd(cvs=[0] * 2 * self.n).adjoint(), ctrl=xy_arr, junk=junk, target=out + ) + xy_arr = np.split(xy_arr, 2) + x = bb.join(xy_arr[0], dtype=QMontgomeryUInt(self.n)) + y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) + + # Free all ancilla qubits in the zero state. + # TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bugs in circuit where f1, + # f2, and f4 are freed before being set to 0. + bb.add(Free(QBit(), dirty=True), reg=f1) + bb.add(Free(QBit(), dirty=True), reg=f2) + bb.add(Free(QBit()), reg=f3) + bb.add(Free(QBit(), dirty=True), reg=f4) + bb.add(Free(QBit()), reg=ctrl) + + # Return the output registers. + return {'a': a, 'b': b, 'x': x, 'y': y} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: + cvs: Union[list[int], HasLength] + if isinstance(self.n, int): + cvs = [0] * 2 * self.n + else: + cvs = HasLength(2 * self.n) + return { + MultiControlX(cvs=cvs): 1, + MultiControlX(cvs=[0] * 3): 1, + CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, + CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, + Toffoli(): 2 * self.n + 4, + Equals(QMontgomeryUInt(2 * self.n)): 1, + MultiAnd(cvs=cvs): 1, + MultiTargetCNOT(2): 1, + MultiAnd(cvs=cvs).adjoint(): 1, + } @frozen @@ -39,18 +976,24 @@ class ECAdd(Bloq): This takes elliptic curve points given by (a, b) and (x, y) and outputs the sum (x_r, y_r) in the second pair of registers. + Because the decomposition of this Bloq is complex, we split it into six separate parts + corresponding to the parts described in figure 10 of the Litinski paper cited below. We follow + the signature from figure 5 and break down the further decompositions based on the steps in + figure 10. + Args: n: The bitsize of the two registers storing the elliptic curve point mod: The modulus of the field in which we do the addition. + window_size: The number of bits in the ModMult window. Registers: - a: The x component of the first input elliptic curve point of bitsize `n`. - b: The y component of the first input elliptic curve point of bitsize `n`. - x: The x component of the second input elliptic curve point of bitsize `n`, which + a: The x component of the first input elliptic curve point of bitsize `n` in montgomery form. + b: The y component of the first input elliptic curve point of bitsize `n` in montgomery form. + x: The x component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the x component of the resultant curve point. - y: The y component of the second input elliptic curve point of bitsize `n`, which + y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which will contain the y component of the resultant curve point. - lam: The precomputed lambda slope used in the addition operation. + lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. References: [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). @@ -59,32 +1002,108 @@ class ECAdd(Bloq): n: int mod: int + window_size: int = 1 @cached_property def signature(self) -> 'Signature': return Signature( [ - Register('a', QUInt(self.n)), - Register('b', QUInt(self.n)), - Register('x', QUInt(self.n)), - Register('y', QUInt(self.n)), - Register('lam', QUInt(self.n)), + Register('a', QMontgomeryUInt(self.n)), + Register('b', QMontgomeryUInt(self.n)), + Register('x', QMontgomeryUInt(self.n)), + Register('y', QMontgomeryUInt(self.n)), + Register('lam_r', QMontgomeryUInt(self.n)), ] ) - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: - # litinksi + def build_composite_bloq( + self, bb: 'BloqBuilder', a: Soquet, b: Soquet, x: Soquet, y: Soquet, lam_r: Soquet + ) -> Dict[str, 'SoquetT']: + f1, f2, f3, f4, ctrl, a, b, x, y = bb.add( + _ECAddStepOne(n=self.n, mod=self.mod), a=a, b=b, x=x, y=y + ) + f1, ctrl, a, b, x, y, lam, lam_r = bb.add( + _ECAddStepTwo(n=self.n, mod=self.mod, window_size=self.window_size), + f1=f1, + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam_r=lam_r, + ) + ctrl, a, b, x, y, lam = bb.add( + _ECAddStepThree(n=self.n, mod=self.mod, window_size=self.window_size), + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam=lam, + ) + x, y, lam = bb.add( + _ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size), x=x, y=y, lam=lam + ) + ctrl, a, b, x, y = bb.add( + _ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size), + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + lam=lam, + ) + a, b, x, y = bb.add( + _ECAddStepSix(n=self.n, mod=self.mod), + f1=f1, + f2=f2, + f3=f3, + f4=f4, + ctrl=ctrl, + a=a, + b=b, + x=x, + y=y, + ) + + return {'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r} + + def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]: + curve_a = ( + QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, self.mod) + * 2 + * QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod) + - (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod) ** 2) + ) % self.mod + p1 = ECPoint( + QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod), + mod=self.mod, + curve_a=curve_a, + ) + p2 = ECPoint( + QMontgomeryUInt(self.n).montgomery_to_uint(x, self.mod), + QMontgomeryUInt(self.n).montgomery_to_uint(y, self.mod), + mod=self.mod, + curve_a=curve_a, + ) + result = p1 + p2 + return { + 'a': a, + 'b': b, + 'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, self.mod), + 'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, self.mod), + 'lam_r': lam_r, + } + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': return { - MultiCToffoli(n=self.n): 18, - ModAdd(bitsize=self.n, mod=self.mod): 3, - CModAdd(QUInt(self.n), mod=self.mod): 2, - ModSub(QUInt(self.n), mod=self.mod): 2, - CModSub(QUInt(self.n), mod=self.mod): 4, - ModNeg(QUInt(self.n), mod=self.mod): 2, - CModNeg(QUInt(self.n), mod=self.mod): 1, - ModDbl(QUInt(self.n), mod=self.mod): 2, - DirtyOutOfPlaceMontgomeryModMul(bitsize=self.n, window_size=4, mod=self.mod): 10, - ModInv(n=self.n, mod=self.mod): 4, + _ECAddStepOne(n=self.n, mod=self.mod): 1, + _ECAddStepTwo(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepThree(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size): 1, + _ECAddStepSix(n=self.n, mod=self.mod): 1, } @@ -95,4 +1114,10 @@ def _ec_add() -> ECAdd: return ec_add -_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add]) +@bloq_example +def _ec_add_small() -> ECAdd: + ec_add_small = ECAdd(5, mod=7) + return ec_add_small + + +_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add, _ec_add_small]) diff --git a/qualtran/bloqs/factoring/ecc/ec_add_test.py b/qualtran/bloqs/factoring/ecc/ec_add_test.py index 44a8b77e1..5295316f1 100644 --- a/qualtran/bloqs/factoring/ecc/ec_add_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_add_test.py @@ -12,13 +12,423 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest +import sympy + import qualtran.testing as qlt_testing -from qualtran.bloqs.factoring.ecc.ec_add import _ec_add +from qualtran._infra.data_types import QMontgomeryUInt +from qualtran.bloqs.factoring.ecc.ec_add import ( + _ec_add, + _ec_add_small, + _ECAddStepFive, + _ECAddStepFour, + _ECAddStepOne, + _ECAddStepSix, + _ECAddStepThree, + _ECAddStepTwo, + ECAdd, +) +from qualtran.resource_counting._bloq_counts import QECGatesCost +from qualtran.resource_counting._costing import get_cost_value +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join + + +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 8) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize('a,b', [(15, 13), (2, 10)]) +@pytest.mark.parametrize('x,y', [(15, 13), (0, 0)]) +def test_ec_add_steps_classical_fast(n, m, a, b, x, y): + p = 17 + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + + a = QMontgomeryUInt(n).uint_to_montgomery(a, p) + b = QMontgomeryUInt(n).uint_to_montgomery(b, p) + x = QMontgomeryUInt(n).uint_to_montgomery(x, p) + y = QMontgomeryUInt(n).uint_to_montgomery(y, p) + lam_r = QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p + + bloq = _ECAddStepOne(n=n, mod=p) + ret1 = bloq.call_classically(a=a, b=b, x=x, y=y) + ret2 = bloq.decompose_bloq().call_classically(a=a, b=b, x=x, y=y) + assert ret1 == ret2 + + step_1 = _ECAddStepOne(n=n, mod=p).on_classical_vals(a=a, b=b, x=x, y=y) + bloq = _ECAddStepTwo(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + assert ret1 == ret2 + + step_2 = _ECAddStepTwo(n=n, mod=p, window_size=m).on_classical_vals( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + bloq = _ECAddStepThree(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + assert ret1 == ret2 + + step_3 = _ECAddStepThree(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + bloq = _ECAddStepFour(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + ret2 = bloq.decompose_bloq().call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + assert ret1 == ret2 + + step_4 = _ECAddStepFour(n=n, mod=p, window_size=m).on_classical_vals( + x=step_3['x'], y=step_3['y'], lam=step_3['lam'] + ) + bloq = _ECAddStepFive(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + assert ret1 == ret2 + + step_5 = _ECAddStepFive(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + bloq = _ECAddStepSix(n=n, mod=p) + ret1 = bloq.call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + assert ret1 == ret2 + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize( + 'a,b', + [ + (15, 13), + (2, 10), + (8, 3), + (12, 1), + (6, 6), + (5, 8), + (10, 15), + (1, 12), + (3, 0), + (1, 5), + (10, 2), + (0, 0), + ], +) +@pytest.mark.parametrize('x,y', [(15, 13), (5, 8), (10, 15), (1, 12), (3, 0), (1, 5), (10, 2)]) +def test_ec_add_steps_classical(n, m, a, b, x, y): + p = 17 + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + + a = QMontgomeryUInt(n).uint_to_montgomery(a, p) + b = QMontgomeryUInt(n).uint_to_montgomery(b, p) + x = QMontgomeryUInt(n).uint_to_montgomery(x, p) + y = QMontgomeryUInt(n).uint_to_montgomery(y, p) + lam_r = QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p + + bloq = _ECAddStepOne(n=n, mod=p) + ret1 = bloq.call_classically(a=a, b=b, x=x, y=y) + ret2 = bloq.decompose_bloq().call_classically(a=a, b=b, x=x, y=y) + assert ret1 == ret2 + + step_1 = _ECAddStepOne(n=n, mod=p).on_classical_vals(a=a, b=b, x=x, y=y) + bloq = _ECAddStepTwo(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + assert ret1 == ret2 + + step_2 = _ECAddStepTwo(n=n, mod=p, window_size=m).on_classical_vals( + f1=step_1['f1'], ctrl=step_1['ctrl'], a=a, b=b, x=x, y=y, lam_r=lam_r + ) + bloq = _ECAddStepThree(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + assert ret1 == ret2 + + step_3 = _ECAddStepThree(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_2['ctrl'], + a=step_2['a'], + b=step_2['b'], + x=step_2['x'], + y=step_2['y'], + lam=step_2['lam'], + ) + bloq = _ECAddStepFour(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + ret2 = bloq.decompose_bloq().call_classically(x=step_3['x'], y=step_3['y'], lam=step_3['lam']) + assert ret1 == ret2 + + step_4 = _ECAddStepFour(n=n, mod=p, window_size=m).on_classical_vals( + x=step_3['x'], y=step_3['y'], lam=step_3['lam'] + ) + bloq = _ECAddStepFive(n=n, mod=p, window_size=m) + ret1 = bloq.call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + ret2 = bloq.decompose_bloq().call_classically( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + assert ret1 == ret2 + + step_5 = _ECAddStepFive(n=n, mod=p, window_size=m).on_classical_vals( + ctrl=step_3['ctrl'], + a=step_3['a'], + b=step_3['b'], + x=step_4['x'], + y=step_4['y'], + lam=step_4['lam'], + ) + bloq = _ECAddStepSix(n=n, mod=p) + ret1 = bloq.call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + ret2 = bloq.decompose_bloq().call_classically( + f1=step_2['f1'], + f2=step_1['f2'], + f3=step_1['f3'], + f4=step_1['f4'], + ctrl=step_5['ctrl'], + a=step_5['a'], + b=step_5['b'], + x=step_5['x'], + y=step_5['y'], + ) + assert ret1 == ret2 + + +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 8) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize('a,b', [(15, 13), (2, 10)]) +@pytest.mark.parametrize('x,y', [(15, 13), (0, 0)]) +def test_ec_add_classical_fast(n, m, a, b, x, y): + p = 17 + bloq = ECAdd(n=n, mod=p, window_size=m) + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + ret1 = bloq.call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + ret2 = bloq.decompose_bloq().call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + assert ret1 == ret2 + + +@pytest.mark.slow +@pytest.mark.parametrize( + ['n', 'm'], [(n, m) for n in range(7, 9) for m in range(1, n + 1) if n % m == 0] +) +@pytest.mark.parametrize( + 'a,b', + [ + (15, 13), + (2, 10), + (8, 3), + (12, 1), + (6, 6), + (5, 8), + (10, 15), + (1, 12), + (3, 0), + (1, 5), + (10, 2), + (0, 0), + ], +) +@pytest.mark.parametrize('x,y', [(15, 13), (5, 8), (10, 15), (1, 12), (3, 0), (1, 5), (10, 2)]) +def test_ec_add_classical(n, m, a, b, x, y): + p = 17 + bloq = ECAdd(n=n, mod=p, window_size=m) + lam_num = (3 * a**2) % p + lam_denom = (2 * b) % p + lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p + ret1 = bloq.call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + ret2 = bloq.decompose_bloq().call_classically( + a=QMontgomeryUInt(n).uint_to_montgomery(a, p), + b=QMontgomeryUInt(n).uint_to_montgomery(b, p), + x=QMontgomeryUInt(n).uint_to_montgomery(x, p), + y=QMontgomeryUInt(n).uint_to_montgomery(y, p), + lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p), + ) + assert ret1 == ret2 + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize( + ['n', 'window_size'], + [ + (n, window_size) + for n in range(5, 8) + for window_size in range(1, n + 1) + if n % window_size == 0 + ], +) +def test_ec_add_decomposition(n, window_size, p): + b = ECAdd(n=n, window_size=window_size, mod=p) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize('p', (7, 9, 11)) +@pytest.mark.parametrize( + ['n', 'window_size'], + [ + (n, window_size) + for n in range(5, 8) + for window_size in range(1, n + 1) + if n % window_size == 0 + ], +) +def test_ec_add_bloq_counts(n, window_size, p): + b = ECAdd(n=n, window_size=window_size, mod=p) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +def test_ec_add_symbolic_cost(): + n, m, p = sympy.symbols('n m p', integer=True) + + # In Litinski 2023 https://arxiv.org/abs/2306.08585 a window size of 4 is used. + # The cost function generally has floor/ceil division that disappear for bitsize=0 mod 4. + # This is why instead of using bitsize=n directly, we use bitsize=4*m=n. + b = ECAdd(n=4 * m, window_size=4, mod=p) + cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() + assert cost['n_t'] == 0 + + # Litinski 2023 https://arxiv.org/abs/2306.08585 + # Based on the counts from Figures 3, 5, and 8 the toffoli count for ECAdd is 126.5n^2 + 189n. + # The following formula is 126.5n^2 + 175.5n - 35. We account for the discrepancy in the + # coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, n extra toffolis + # in ModNeg, and 2n extra toffolis to do n 3-controlled toffolis in step 2. The expression is + # written with rationals because sympy comparison fails with floats. + assert isinstance(cost['n_ccz'], sympy.Expr) + assert ( + cost['n_ccz'].subs(m, n / 4).expand() + == sympy.Rational(253, 2) * n**2 + sympy.Rational(351, 2) * n - 35 + ) def test_ec_add(bloq_autotester): bloq_autotester(_ec_add) +def test_ec_add_small(bloq_autotester): + bloq_autotester(_ec_add_small) + + def test_notebook(): qlt_testing.execute_notebook('ec_add') diff --git a/qualtran/bloqs/factoring/ecc/ec_point.py b/qualtran/bloqs/factoring/ecc/ec_point.py index 968ebe0b5..c17ea5957 100644 --- a/qualtran/bloqs/factoring/ecc/ec_point.py +++ b/qualtran/bloqs/factoring/ecc/ec_point.py @@ -69,7 +69,8 @@ def __add__(self, other): return ECPoint(xr, yr, mod=self.mod, curve_a=self.curve_a) def __mul__(self, other): - assert other > 0, other + if other == 0: + return ECPoint.inf(mod=self.mod, curve_a=self.curve_a) x = self for _ in range(other - 1): x = x + self diff --git a/qualtran/bloqs/factoring/ecc/ec_point_test.py b/qualtran/bloqs/factoring/ecc/ec_point_test.py index f65981c80..1ac1b59bd 100644 --- a/qualtran/bloqs/factoring/ecc/ec_point_test.py +++ b/qualtran/bloqs/factoring/ecc/ec_point_test.py @@ -21,6 +21,7 @@ def test_ec_point_overrides(): assert 1 * p == p assert 2 * p == (p + p) assert 3 * p == (p + p + p) + assert 0 * p == ECPoint.inf(mod=17, curve_a=0) def test_ec_point_addition(): diff --git a/qualtran/bloqs/mod_arithmetic/_shims.py b/qualtran/bloqs/mod_arithmetic/_shims.py index 65b61b4ca..7242f6afa 100644 --- a/qualtran/bloqs/mod_arithmetic/_shims.py +++ b/qualtran/bloqs/mod_arithmetic/_shims.py @@ -24,14 +24,17 @@ from functools import cached_property from typing import Dict, Optional, Tuple, TYPE_CHECKING +import attrs from attrs import frozen -from qualtran import Bloq, QUInt, Register, Signature -from qualtran.bloqs.arithmetic import Add, AddK, Negate, Subtract -from qualtran.bloqs.arithmetic._shims import CHalf, Lt, MultiCToffoli +from qualtran import Bloq, QUInt, Register, Side, Signature +from qualtran.bloqs.arithmetic import AddK, Negate +from qualtran.bloqs.arithmetic._shims import CHalf, CSub, Lt, MultiCToffoli +from qualtran.bloqs.arithmetic.controlled_addition import CAdd from qualtran.bloqs.basic_gates import CNOT, CSwap, Swap, Toffoli from qualtran.bloqs.mod_arithmetic.mod_multiplication import ModDbl from qualtran.drawing import Text, TextBox, WireSymbol +from qualtran.simulation.classical_sim import ClassicalValT if TYPE_CHECKING: from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator @@ -57,9 +60,9 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': (CNOT(), 2), (Lt(self.n), 1), (CSwap(self.n), 2), - (Subtract(QUInt(self.n)), 1), - (Add(QUInt(self.n)), 1), - (CNOT(), 1), + (CSub(self.n), 1), + (CAdd(QUInt(self.n)), 1), + (CNOT(), 2), (ModDbl(QUInt(self.n), self.mod), 1), (CHalf(self.n), 1), (CSwap(self.n), 2), @@ -88,10 +91,21 @@ def wire_symbol( class ModInv(Bloq): n: int mod: int + uncompute: bool = False @cached_property def signature(self) -> 'Signature': - return Signature([Register('x', QUInt(self.n)), Register('out', QUInt(self.n))]) + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('x', QUInt(self.n)), + Register('garbage1', QUInt(self.n), side=side), + Register('garbage2', QUInt(self.n), side=side), + ] + ) + + def adjoint(self) -> 'ModInv': + return attrs.evolve(self, uncompute=self.uncompute ^ True) def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': # Roetteler @@ -103,6 +117,29 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': Swap(self.n): 1, } + def on_classical_vals( + self, + x: 'ClassicalValT', + garbage1: Optional['ClassicalValT'] = None, + garbage2: Optional['ClassicalValT'] = None, + ) -> Dict[str, ClassicalValT]: + # TODO(https://github.com/quantumlib/Qualtran/issues/1443): Hacky classical simulation just + # to confirm correctness of ECAdd circuit. + if self.uncompute: + assert garbage1 is not None + assert garbage2 is not None + return {'x': garbage1} + assert garbage1 is None + assert garbage2 is None + + # Store the original x in the garbage registers for the uncompute simulation. + garbage1 = x + garbage2 = x + + x = pow(int(x), self.mod - 2, mod=self.mod) * pow(2, 2 * self.n, self.mod) % self.mod + + return {'x': x, 'garbage1': garbage1, 'garbage2': garbage2} + def wire_symbol( self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple() ) -> 'WireSymbol': @@ -112,4 +149,8 @@ def wire_symbol( return TextBox('x') elif reg.name == 'out': return TextBox('$x^{-1}$') + elif reg.name == 'garbage1': + return TextBox('garbage1') + elif reg.name == 'garbage2': + return TextBox('garbage2') raise ValueError(f'Unrecognized register name {reg.name}') diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index 95f74edc7..9f289add1 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -650,7 +650,8 @@ def on_classical_vals( raise ValueError(f'classical action is not supported for {self}') if self.uncompute: assert ( - target is not None and target == (x * y * pow(2, self.bitsize, self.mod)) % self.mod + target is not None + and target == (x * y * pow(2, self.bitsize * (self.mod - 2), self.mod)) % self.mod ) assert qrom_indices is not None assert reduced is not None diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index bc3be1a65..347c78b61 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -98,6 +98,7 @@ import qualtran.bloqs.data_loading.qrom import qualtran.bloqs.data_loading.select_swap_qrom import qualtran.bloqs.factoring._factoring_shims +import qualtran.bloqs.factoring.ecc.ec_add import qualtran.bloqs.factoring.rsa import qualtran.bloqs.for_testing.atom import qualtran.bloqs.for_testing.casting @@ -348,6 +349,13 @@ "qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul, "qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul, "qualtran.bloqs.factoring._factoring_shims.MeasureQFT": qualtran.bloqs.factoring._factoring_shims.MeasureQFT, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepThree": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepThree, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFour": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFour, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFive": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepFive, + "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepSix": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepSix, + "qualtran.bloqs.factoring.ecc.ec_add.ECAdd": qualtran.bloqs.factoring.ecc.ec_add.ECAdd, "qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate": qualtran.bloqs.factoring.rsa.rsa_phase_estimate.RSAPhaseEstimate, "qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp": qualtran.bloqs.factoring.rsa.rsa_mod_exp.ModExp, "qualtran.bloqs.for_testing.atom.TestAtom": qualtran.bloqs.for_testing.atom.TestAtom, From 7c6715b7e64de12257b3959faa3caff5c3cdfd8b Mon Sep 17 00:00:00 2001 From: Anurudh Peduri <7265746+anurudhp@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:20:03 -0700 Subject: [PATCH 6/7] Improve interface for bloqs with specialized single-qubit-controlled versions (#1451) * get_ctrl_system for bloqs with custom single-qubit-controlled implementations * replace uses of old interface * replace uses of old interface * specialize when `ctrl_spec.num_qubits == 1` * add test example of bloq with a separate controlled bloq * refactor: pass parameters instead of using protocol * `CtrlSpec.get_single_control_bit` to get the correct control bit in the single qubit case. * cleanup * use new single control bit method * test ctrl bit * `control` -> `ctrl` * don't pass `bloq` * use callable * update examples * mypy * add helper method which accepts bloqs instead of a callable * upgrade more usecases * typo * fix bug in adder * rename file to `specialized_ctrl` * rename function to `get_ctrl_system_1bit_cv` * mypy * cleanup design - use a helper bloq that accepts CU to build CCU - do not pass `bloq_without_ctrl` * add exposed helpers with clearer types --------- Co-authored-by: Matthew Harrigan --- qualtran/_infra/controlled.py | 16 ++ qualtran/_infra/controlled_test.py | 20 ++ .../qubitization/select_hubbard.py | 29 +- .../bloqs/chemistry/sparse/select_bloq.py | 31 ++- qualtran/bloqs/chemistry/thc/select_bloq.py | 15 +- .../for_testing/random_select_and_prepare.py | 19 +- qualtran/bloqs/mcmt/controlled_via_and.py | 3 +- qualtran/bloqs/mcmt/specialized_ctrl.py | 255 ++++++++++++++++++ qualtran/bloqs/mcmt/specialized_ctrl_test.py | 203 ++++++++++++++ qualtran/bloqs/multiplexers/apply_lth_bloq.py | 27 +- .../bloqs/multiplexers/select_pauli_lcu.py | 28 +- .../reflections/reflection_using_prepare.py | 27 +- 12 files changed, 650 insertions(+), 23 deletions(-) create mode 100644 qualtran/bloqs/mcmt/specialized_ctrl.py create mode 100644 qualtran/bloqs/mcmt/specialized_ctrl_test.py diff --git a/qualtran/_infra/controlled.py b/qualtran/_infra/controlled.py index e43a2d2b9..f2f946030 100644 --- a/qualtran/_infra/controlled.py +++ b/qualtran/_infra/controlled.py @@ -23,6 +23,7 @@ Sequence, Tuple, TYPE_CHECKING, + TypeAlias, Union, ) @@ -45,6 +46,9 @@ from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT +ControlBit: TypeAlias = int +"""A control bit, either 0 or 1.""" + def _cvs_convert( cvs: Union[ @@ -250,6 +254,18 @@ def from_cirq_cv( bloq_cvs.append(curr_cvs) return CtrlSpec(tuple(qdtypes), tuple(bloq_cvs)) + def get_single_ctrl_bit(self) -> ControlBit: + """If controlled by a single qubit, return the control bit, otherwise raise""" + if self.num_qubits != 1: + raise ValueError(f"expected a single qubit control, got {self.num_qubits}") + + (qdtype,) = self.qdtypes + (cv,) = self.cvs + (idx,) = Register('', qdtype, cv.shape).all_idxs() + (control_bit,) = qdtype.to_bits(cv[idx]) + + return int(control_bit) + class AddControlledT(Protocol): """The signature for the `add_controlled` callback part of `ctrl_system`. diff --git a/qualtran/_infra/controlled_test.py b/qualtran/_infra/controlled_test.py index 34b26def8..77d72432c 100644 --- a/qualtran/_infra/controlled_test.py +++ b/qualtran/_infra/controlled_test.py @@ -100,6 +100,26 @@ def test_ctrl_spec_to_cirq_cv_roundtrip(): assert CtrlSpec.from_cirq_cv(cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.shapes) +@pytest.mark.parametrize( + "ctrl_spec", [CtrlSpec(), CtrlSpec(cvs=[1]), CtrlSpec(cvs=np.atleast_2d([1]))] +) +def test_ctrl_spec_single_bit_one(ctrl_spec: CtrlSpec): + assert ctrl_spec.get_single_ctrl_bit() == 1 + + +@pytest.mark.parametrize( + "ctrl_spec", [CtrlSpec(cvs=0), CtrlSpec(cvs=[0]), CtrlSpec(cvs=np.atleast_2d([0]))] +) +def test_ctrl_spec_single_bit_zero(ctrl_spec: CtrlSpec): + assert ctrl_spec.get_single_ctrl_bit() == 0 + + +@pytest.mark.parametrize("ctrl_spec", [CtrlSpec(cvs=[1, 1]), CtrlSpec(qdtypes=QUInt(2), cvs=0)]) +def test_ctrl_spec_single_bit_raises(ctrl_spec: CtrlSpec): + with pytest.raises(ValueError): + ctrl_spec.get_single_ctrl_bit() + + def _test_cirq_equivalence(bloq: Bloq, gate: 'cirq.Gate'): import cirq diff --git a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py index aa889f91f..f99a56971 100644 --- a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py +++ b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py @@ -20,9 +20,19 @@ import numpy as np from numpy.typing import NDArray -from qualtran import bloq_example, BloqDocSpec, BQUInt, QAny, QBit, Register, Signature +from qualtran import ( + AddControlledT, + Bloq, + bloq_example, + BloqDocSpec, + BQUInt, + CtrlSpec, + QAny, + QBit, + Register, + Signature, +) from qualtran._infra.gate_with_registers import total_bits -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension from qualtran.bloqs.basic_gates import CSwap from qualtran.bloqs.multiplexers.apply_gate_to_lth_target import ApplyGateToLthQubit from qualtran.bloqs.multiplexers.select_base import SelectOracle @@ -30,7 +40,7 @@ @attrs.frozen -class SelectHubbard(SelectOracle, SpecializedSingleQubitControlledExtension): # type: ignore[misc] +class SelectHubbard(SelectOracle): r"""The SELECT operation optimized for the 2D Hubbard model. In contrast to SELECT for an arbitrary chemistry Hamiltonian, we: @@ -180,6 +190,19 @@ def __str__(self): return f'C{s}' return s + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + + return get_ctrl_system_1bit_cv( + self, + ctrl_spec=ctrl_spec, + current_ctrl_bit=self.control_val, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: ( + attrs.evolve(self, control_val=cv), + 'control', + ), + ) + @bloq_example def _sel_hubb() -> SelectHubbard: diff --git a/qualtran/bloqs/chemistry/sparse/select_bloq.py b/qualtran/bloqs/chemistry/sparse/select_bloq.py index 06d9c42c1..1c7fd5595 100644 --- a/qualtran/bloqs/chemistry/sparse/select_bloq.py +++ b/qualtran/bloqs/chemistry/sparse/select_bloq.py @@ -16,11 +16,23 @@ from functools import cached_property from typing import Dict, Optional, Tuple, TYPE_CHECKING +import attrs import cirq from attrs import frozen -from qualtran import bloq_example, BloqBuilder, BloqDocSpec, BQUInt, QAny, QBit, Register, SoquetT -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension +from qualtran import ( + AddControlledT, + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + BQUInt, + CtrlSpec, + QAny, + QBit, + Register, + SoquetT, +) from qualtran.bloqs.basic_gates import SGate from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion @@ -30,7 +42,7 @@ @frozen -class SelectSparse(SpecializedSingleQubitControlledExtension, SelectOracle): # type: ignore[misc] +class SelectSparse(SelectOracle): r"""SELECT oracle for the sparse Hamiltonian. Implements the two applications of Fig. 13. @@ -157,6 +169,19 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': c_maj_y = SelectedMajoranaFermion(sel_pa, target_gate=cirq.Y) return {SGate(): 1, maj_x: 1, c_maj_x: 1, maj_y: 1, c_maj_y: 1} + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + + return get_ctrl_system_1bit_cv( + self, + ctrl_spec=ctrl_spec, + current_ctrl_bit=self.control_val, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: ( + attrs.evolve(self, control_val=cv), + 'control', + ), + ) + @bloq_example def _sel_sparse() -> SelectSparse: diff --git a/qualtran/bloqs/chemistry/thc/select_bloq.py b/qualtran/bloqs/chemistry/thc/select_bloq.py index 5ed430599..a3b825950 100644 --- a/qualtran/bloqs/chemistry/thc/select_bloq.py +++ b/qualtran/bloqs/chemistry/thc/select_bloq.py @@ -20,18 +20,19 @@ from attrs import evolve, frozen from qualtran import ( + AddControlledT, Bloq, bloq_example, BloqBuilder, BloqDocSpec, BQUInt, + CtrlSpec, QAny, QBit, Register, Signature, SoquetT, ) -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension from qualtran.bloqs.basic_gates import CSwap, Toffoli, XGate from qualtran.bloqs.chemistry.black_boxes import ApplyControlledZs from qualtran.bloqs.multiplexers.select_base import SelectOracle @@ -120,7 +121,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': @frozen -class SelectTHC(SpecializedSingleQubitControlledExtension, SelectOracle): # type: ignore[misc] +class SelectTHC(SelectOracle): r"""SELECT for THC Hamiltonian. Args: @@ -313,6 +314,16 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str return out_soqs + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + + return get_ctrl_system_1bit_cv( + self, + ctrl_spec=ctrl_spec, + current_ctrl_bit=self.control_val, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: (evolve(self, control_val=cv), 'control'), + ) + @bloq_example def _thc_sel() -> SelectTHC: diff --git a/qualtran/bloqs/for_testing/random_select_and_prepare.py b/qualtran/bloqs/for_testing/random_select_and_prepare.py index c5eed3569..6e1a607f5 100644 --- a/qualtran/bloqs/for_testing/random_select_and_prepare.py +++ b/qualtran/bloqs/for_testing/random_select_and_prepare.py @@ -14,13 +14,13 @@ from functools import cached_property from typing import Iterator, Optional, Tuple +import attrs import cirq import numpy as np from attrs import frozen from numpy.typing import NDArray -from qualtran import BloqBuilder, BQUInt, QBit, Register, SoquetT -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension +from qualtran import AddControlledT, Bloq, BloqBuilder, BQUInt, CtrlSpec, QBit, Register, SoquetT from qualtran.bloqs.block_encoding.lcu_block_encoding import SelectBlockEncoding from qualtran.bloqs.for_testing.matrix_gate import MatrixGate from qualtran.bloqs.multiplexers.select_base import SelectOracle @@ -84,7 +84,7 @@ def alphas(self): @frozen -class TestPauliSelectOracle(SpecializedSingleQubitControlledExtension, SelectOracle): # type: ignore[misc] +class TestPauliSelectOracle(SelectOracle): # type: ignore[misc] r"""Paulis acting on $m$ qubits, controlled by an $n$-qubit register. Given $2^n$ multi-qubit-Paulis (acting on $m$ qubits) $U_j$, @@ -149,6 +149,19 @@ def decompose_from_registers( op = op.controlled_by(*quregs['control'], control_values=[self.control_val]) yield op + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + + return get_ctrl_system_1bit_cv( + self, + ctrl_spec=ctrl_spec, + current_ctrl_bit=self.control_val, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: ( + attrs.evolve(self, control_val=cv), + 'control', + ), + ) + def random_qubitization_walk_operator( select_bitsize: int, target_bitsize: int, *, random_state: np.random.RandomState diff --git a/qualtran/bloqs/mcmt/controlled_via_and.py b/qualtran/bloqs/mcmt/controlled_via_and.py index b9f7389a0..a0ae58dee 100644 --- a/qualtran/bloqs/mcmt/controlled_via_and.py +++ b/qualtran/bloqs/mcmt/controlled_via_and.py @@ -50,8 +50,7 @@ def _is_single_bit_control(self) -> bool: @cached_property def _single_control_value(self) -> int: - assert self._is_single_bit_control() - return self.ctrl_spec._cvs_tuple[0] + return self.ctrl_spec.get_single_ctrl_bit() def adjoint(self) -> 'ControlledViaAnd': return ControlledViaAnd(self.subbloq.adjoint(), self.ctrl_spec) diff --git a/qualtran/bloqs/mcmt/specialized_ctrl.py b/qualtran/bloqs/mcmt/specialized_ctrl.py new file mode 100644 index 000000000..8c70df3f6 --- /dev/null +++ b/qualtran/bloqs/mcmt/specialized_ctrl.py @@ -0,0 +1,255 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import cached_property +from typing import Callable, cast, Iterable, Optional, Sequence, TYPE_CHECKING + +import attrs +import numpy as np + +from qualtran import Bloq, QBit, Register, Signature + +if TYPE_CHECKING: + from qualtran import AddControlledT, BloqBuilder, CtrlSpec, SoquetT + from qualtran._infra.controlled import ControlBit + from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator + + +@attrs.frozen +class _MultiControlledFromSinglyControlled(Bloq): + """Helper bloq implementing a multi-controlled-U given access to controlled-U. + + This is for internal use only. For reducing multiple controls to a single control, + see :class:`qualtran.bloqs.mcmt.ControlledViaAnd` and + :meth:`qualtran.bloqs.mcmt.specialized_ctrl.get_ctrl_system_1bit_cv`. + + This bloq is used as an intermediate bloq by `get_ctrl_system_1bit_cv` in the + controlled-controlled-bloq case. To cleanly support further controlling this bloq, + the `cvs` attribute accepts a tuple (of at least two controls), and defers to + `ControlledViaAnd` whenever possible, and only extends the `cvs` in the edge cases. + """ + + cvs: tuple[int, ...] + ctrl_bloq: Bloq + ctrl_reg_name: str + + def __attrs_post_init__(self): + assert len(self.cvs) >= 2, f"{self} must have at least 2 controls, got {self.cvs=}" + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [Register(self.ctrl_reg_name, dtype=QBit(), shape=(len(self.cvs),))] + + [reg for reg in self.ctrl_bloq.signature if reg.name != self.ctrl_reg_name] + ) + + @cached_property + def _and_bloq(self) -> Bloq: + from qualtran.bloqs.mcmt import And, MultiAnd + + if len(self.cvs) == 2: + return And(*self.cvs) + else: + return MultiAnd(self.cvs) + + def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str, 'SoquetT']: + and_soqs = bb.add_d(self._and_bloq, ctrl=soqs.pop(self.ctrl_reg_name)) + + soqs |= {self.ctrl_reg_name: and_soqs.pop('target')} + soqs = bb.add_d(self.ctrl_bloq, **soqs) + and_soqs |= {'target': soqs.pop(self.ctrl_reg_name)} + + soqs |= {self.ctrl_reg_name: bb.add(self._and_bloq.adjoint(), **and_soqs)} + + return soqs + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return {self._and_bloq: 1, self.ctrl_bloq: 1, self._and_bloq.adjoint(): 1} + + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: + if ctrl_spec.num_qubits != 1: + return super().get_ctrl_system(ctrl_spec=ctrl_spec) + + ctrl_bloq = attrs.evolve(self, cvs=(ctrl_spec.get_single_ctrl_bit(),) + self.cvs) + + def _adder(bb, ctrl_soqs, in_soqs): + in_soqs[self.ctrl_reg_name] = np.concatenate(ctrl_soqs, in_soqs[self.ctrl_reg_name]) + ctrls, *out_soqs = bb.add_t(ctrl_bloq, **in_soqs) + return ctrls[:1], [*ctrls[1:], *out_soqs] + + return ctrl_bloq, _adder + + def __str__(self): + return f'C[{len(self.cvs)-1}][{self.ctrl_bloq}]' + + +def _get_ctrl_system_1bit_cv( + bloq: 'Bloq', + ctrl_spec: 'CtrlSpec', + *, + current_ctrl_bit: Optional['ControlBit'], + get_ctrl_bloq_and_ctrl_reg_name: Callable[['ControlBit'], Optional[tuple['Bloq', str]]], +) -> tuple['Bloq', 'AddControlledT']: + """Internal method to build the control system for a bloq using single-qubit controlled variants. + + Uses the provided specialized implementation when a singly-controlled variant of the bloq is + requested. When controlled by multiple qubits, the controls are reduced to a single qubit + and the singly-controlled bloq is used. + + The user can provide specializations for the bloq controlled by `1` and (optionally) by `0`. + The specialization for control bit `1` must be provided. + In case a specialization for a control bit `0` is not provided, the default fallback is used + instead, which wraps the bloq using the `Controlled` metabloq. + + Args: + bloq: The current bloq. + ctrl_spec: The control specification + current_ctrl_bit: The control bit of the current bloq, one of `0, 1, None`. + get_ctrl_bloq_and_ctrl_reg_name: A callable that accepts a control bit (`0` or `1`), + and returns the controlled variant of this bloq and the name of the control register. + If the callable returns `None`, then the default fallback is used. + """ + from qualtran import Soquet + from qualtran.bloqs.mcmt import ControlledViaAnd + + def _get_default_fallback(): + return ControlledViaAnd.make_ctrl_system(bloq=bloq, ctrl_spec=ctrl_spec) + + if ctrl_spec.num_qubits != 1: + return _get_default_fallback() + + ctrl_bit = ctrl_spec.get_single_ctrl_bit() + + if current_ctrl_bit is None: + # the easy case: use the controlled bloq + ctrl_bloq_and_ctrl_reg_name = get_ctrl_bloq_and_ctrl_reg_name(ctrl_bit) + if ctrl_bloq_and_ctrl_reg_name is None: + assert ctrl_bit != 1, "invalid usage: controlled-by-1 variant must be provided" + return _get_default_fallback() + + ctrl_bloq, ctrl_reg_name = ctrl_bloq_and_ctrl_reg_name + + def _adder( + bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT'] + ) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]: + (ctrl,) = ctrl_soqs + in_soqs |= {ctrl_reg_name: ctrl} + + out_soqs = bb.add_d(ctrl_bloq, **in_soqs) + + ctrl = out_soqs.pop(ctrl_reg_name) + return [ctrl], out_soqs.values() + + else: + # the difficult case: must combine the two controls into one + ctrl_1_bloq_and_reg_name = get_ctrl_bloq_and_ctrl_reg_name(1) + assert ( + ctrl_1_bloq_and_reg_name is not None + ), "invalid usage: controlled-by-1 variant must be provided" + ctrl_1_bloq, ctrl_reg_name = ctrl_1_bloq_and_reg_name + + ctrl_bloq = _MultiControlledFromSinglyControlled( + cvs=(ctrl_bit, current_ctrl_bit), ctrl_bloq=ctrl_1_bloq, ctrl_reg_name=ctrl_reg_name + ) + + def _adder( + bb: 'BloqBuilder', ctrl_soqs: Sequence['SoquetT'], in_soqs: dict[str, 'SoquetT'] + ) -> tuple[Iterable['SoquetT'], Iterable['SoquetT']]: + # extract the two control bits + (ctrl0,) = ctrl_soqs + ctrl1 = in_soqs.pop(ctrl_reg_name) + + ctrl0 = cast(Soquet, ctrl0) + ctrl1 = cast(Soquet, ctrl1) + + # add the singly controlled bloq + in_soqs |= {ctrl_reg_name: np.array([ctrl0, ctrl1])} + ctrls, *out_soqs = bb.add_t(ctrl_bloq, **in_soqs) + assert isinstance(ctrls, np.ndarray) + ctrl0, ctrl1 = ctrls + + return [ctrl0], [ctrl1, *out_soqs] + + return ctrl_bloq, _adder + + +def get_ctrl_system_1bit_cv( + bloq: 'Bloq', + ctrl_spec: 'CtrlSpec', + *, + current_ctrl_bit: Optional['ControlBit'], + get_ctrl_bloq_and_ctrl_reg_name: Callable[['ControlBit'], tuple['Bloq', str]], +) -> tuple['Bloq', 'AddControlledT']: + """Build the control system for a bloq with specialized single-qubit controlled variants. + + Uses the provided specialized implementation when a singly-controlled variant of the bloq is + requested. When controlled by multiple qubits, the controls are reduced to a single qubit + and the singly-controlled bloq is used. + + The user must provide two specializations for the bloq: controlled by `1` and by `0`. + + When only one specialization (controlled by `1`) is known, use + :meth:`get_ctrl_system_1bit_cv_from_bloqs` instead. + + Args: + bloq: The current bloq. + ctrl_spec: The control specification + current_ctrl_bit: The control bit of the current bloq, one of `0, 1, None`. + get_ctrl_bloq_and_ctrl_reg_name: A callable that accepts a control bit (`0` or `1`), + and returns the controlled variant of this bloq and the name of the control register. + """ + return _get_ctrl_system_1bit_cv( + bloq, + ctrl_spec, + current_ctrl_bit=current_ctrl_bit, + get_ctrl_bloq_and_ctrl_reg_name=get_ctrl_bloq_and_ctrl_reg_name, + ) + + +def get_ctrl_system_1bit_cv_from_bloqs( + bloq: 'Bloq', + ctrl_spec: 'CtrlSpec', + *, + current_ctrl_bit: Optional['ControlBit'], + bloq_with_ctrl: 'Bloq', + ctrl_reg_name: 'str', +) -> tuple['Bloq', 'AddControlledT']: + """Helper to construct the control system given a singly-controlled variant of a bloq. + + Uses the provided specialized implementation when a singly-controlled (by `1`) variant of + the bloq is requested. When controlled by multiple qubits, the controls are reduced to a + single qubit and the singly-controlled bloq is used. + + When specializations for both cases - controlled by `1` and by `0` - are known, use + :meth:`get_ctrl_system_1bit_cv` instead. + + Args: + bloq: The current bloq. + ctrl_spec: The control specification + current_ctrl_bit: The control bit of the current bloq, one of `0, 1, None`. + bloq_with_ctrl: The variant of this bloq controlled by a single qubit in the `1` basis state. + ctrl_reg_name: The name of the control register for the controlled bloq variant(s). + """ + + def get_ctrl_bloq_and_ctrl_reg_name(cv: 'ControlBit') -> Optional[tuple['Bloq', str]]: + if cv == 1: + return bloq_with_ctrl, ctrl_reg_name + else: + return None + + return _get_ctrl_system_1bit_cv( + bloq, + ctrl_spec, + current_ctrl_bit=current_ctrl_bit, + get_ctrl_bloq_and_ctrl_reg_name=get_ctrl_bloq_and_ctrl_reg_name, + ) diff --git a/qualtran/bloqs/mcmt/specialized_ctrl_test.py b/qualtran/bloqs/mcmt/specialized_ctrl_test.py new file mode 100644 index 000000000..01282c1f2 --- /dev/null +++ b/qualtran/bloqs/mcmt/specialized_ctrl_test.py @@ -0,0 +1,203 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +from typing import Optional, Sequence, Tuple +from unittest.mock import ANY + +import attrs +import pytest + +from qualtran import ( + AddControlledT, + Bloq, + BloqBuilder, + CtrlSpec, + QAny, + QBit, + Register, + Signature, + SoquetT, +) +from qualtran.bloqs.mcmt import And +from qualtran.bloqs.mcmt.specialized_ctrl import ( + get_ctrl_system_1bit_cv, + get_ctrl_system_1bit_cv_from_bloqs, +) +from qualtran.resource_counting import CostKey, GateCounts, get_cost_value, QECGatesCost + + +@attrs.frozen +class AtomWithSpecializedControl(Bloq): + cv: Optional[int] = None + ctrl_reg_name: str = 'ctrl' + target_reg_name: str = 'q' + + @property + def signature(self) -> 'Signature': + n_ctrl = 1 if self.cv is not None else 0 + reg_name_map = {self.ctrl_reg_name: n_ctrl, self.target_reg_name: 2} + return Signature.build(**reg_name_map) + + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + return get_ctrl_system_1bit_cv( + self, + ctrl_spec, + current_ctrl_bit=self.cv, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: ( + attrs.evolve(self, cv=cv), + self.ctrl_reg_name, + ), + ) + + @staticmethod + def cost_expr_for_cv(cv: Optional[int]): + import sympy + + c_unctrl = sympy.Symbol("_c_target_") + c_ctrl = sympy.Symbol("_c_ctrl_") + + if cv is None: + return c_unctrl + return c_unctrl + c_ctrl + + def my_static_costs(self, cost_key: 'CostKey'): + if cost_key == QECGatesCost(): + r = self.cost_expr_for_cv(self.cv) + return GateCounts(rotation=r) + + return NotImplemented + + +def ON(n: int = 1) -> CtrlSpec: + return CtrlSpec(cvs=[1] * n) + + +def OFF(n: int = 1) -> CtrlSpec: + return CtrlSpec(cvs=[0] * n) + + +@pytest.mark.parametrize( + 'ctrl_specs', + [ + [ON()], + [OFF()], + [OFF(), OFF()], + [OFF(4)], + [OFF(2), OFF(2)], + [ON(), OFF(5)], + [ON(), ON(), ON()], + [OFF(4), ON(3), OFF(5)], + ], +) +@pytest.mark.parametrize('ctrl_reg_name', ['ctrl', 'control']) +def test_custom_controlled(ctrl_specs: Sequence[CtrlSpec], ctrl_reg_name: str): + bloq: Bloq = AtomWithSpecializedControl(ctrl_reg_name=ctrl_reg_name) + for ctrl_spec in ctrl_specs: + bloq = bloq.controlled(ctrl_spec) + n_ctrls = sum(ctrl_spec.num_qubits for ctrl_spec in ctrl_specs) + + gc = get_cost_value(bloq, QECGatesCost()) + assert gc == GateCounts( + and_bloq=n_ctrls - 1, + rotation=AtomWithSpecializedControl.cost_expr_for_cv(1), + clifford=ANY, + measurement=ANY, + ) + + +@attrs.frozen +class TestAtom(Bloq): + tag: str + + @property + def signature(self) -> 'Signature': + return Signature.build(q=2) + + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + return get_ctrl_system_1bit_cv_from_bloqs( + self, + ctrl_spec, + current_ctrl_bit=None, + bloq_with_ctrl=CTestAtom(self.tag), + ctrl_reg_name='ctrl', + ) + + +@attrs.frozen +class CTestAtom(Bloq): + tag: str + + @property + def signature(self) -> 'Signature': + return Signature.build(ctrl=1, q=2) + + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + return get_ctrl_system_1bit_cv_from_bloqs( + self, ctrl_spec, current_ctrl_bit=1, bloq_with_ctrl=self, ctrl_reg_name='ctrl' + ) + + +def test_bloq_with_controlled_bloq(): + assert TestAtom('g').controlled() == CTestAtom('g') + + def _keep_and(b): + # TODO remove this after https://github.com/quantumlib/Qualtran/issues/1346 is resolved. + return isinstance(b, And) + + ctrl_bloq = CTestAtom('g').controlled() + _, sigma = ctrl_bloq.call_graph(keep=_keep_and) + assert sigma == {And(): 1, CTestAtom('g'): 1, And().adjoint(): 1} + + ctrl_bloq = CTestAtom('n').controlled(CtrlSpec(cvs=0)) + _, sigma = ctrl_bloq.call_graph(keep=_keep_and) + assert sigma == {And(0, 1): 1, CTestAtom('n'): 1, And(0, 1).adjoint(): 1} + + ctrl_bloq = TestAtom('nn').controlled(CtrlSpec(cvs=[0, 0])) + _, sigma = ctrl_bloq.call_graph(keep=_keep_and) + assert sigma == {And(0, 0): 1, CTestAtom('nn'): 1, And(0, 0).adjoint(): 1} + + +@attrs.frozen +class TestBloqWithDecompose(Bloq): + ctrl_reg_name: str + target_reg_name: str + + @property + def signature(self) -> 'Signature': + return Signature( + [Register(self.ctrl_reg_name, QBit()), Register(self.target_reg_name, QAny(2))] + ) + + def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str, 'SoquetT']: + for _ in range(2): + soqs = bb.add_d( + AtomWithSpecializedControl( + cv=1, ctrl_reg_name=self.ctrl_reg_name, target_reg_name=self.target_reg_name + ), + **soqs, + ) + return soqs + + +@pytest.mark.parametrize( + ('ctrl_reg_name', 'target_reg_name'), + [ + (ctrl, targ) + for (ctrl, targ) in itertools.product(['ctrl', 'control', 'a', 'b'], repeat=2) + if ctrl != targ + ], +) +def test_get_ctrl_system(ctrl_reg_name: str, target_reg_name: str): + bloq = TestBloqWithDecompose(ctrl_reg_name, target_reg_name).controlled() + _ = bloq.decompose_bloq().flatten() diff --git a/qualtran/bloqs/multiplexers/apply_lth_bloq.py b/qualtran/bloqs/multiplexers/apply_lth_bloq.py index a0de211a6..5ff1e6c74 100644 --- a/qualtran/bloqs/multiplexers/apply_lth_bloq.py +++ b/qualtran/bloqs/multiplexers/apply_lth_bloq.py @@ -17,12 +17,21 @@ import cirq import numpy as np -from attrs import field, frozen +from attrs import evolve, field, frozen from numpy.typing import NDArray -from qualtran import Bloq, bloq_example, BloqDocSpec, BQUInt, QBit, Register, Side +from qualtran import ( + AddControlledT, + Bloq, + bloq_example, + BloqDocSpec, + BQUInt, + CtrlSpec, + QBit, + Register, + Side, +) from qualtran._infra.gate_with_registers import merge_qubits -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate from qualtran.resource_counting import BloqCountT @@ -30,7 +39,7 @@ @frozen -class ApplyLthBloq(UnaryIterationGate, SpecializedSingleQubitControlledExtension, SelectOracle): # type: ignore[misc] +class ApplyLthBloq(UnaryIterationGate, SelectOracle): # type: ignore[misc] r"""A SELECT operation that executes one of a list of bloqs $U_l$ based on a quantum index: $$ @@ -108,6 +117,16 @@ def nth_operation( target_qubits = merge_qubits(bloq.signature, **targets) return bloq.controlled().on(control, *target_qubits) + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + + return get_ctrl_system_1bit_cv( + self, + ctrl_spec=ctrl_spec, + current_ctrl_bit=self.control_val, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: (evolve(self, control_val=cv), 'control'), + ) + @bloq_example def _apply_lth_bloq() -> ApplyLthBloq: diff --git a/qualtran/bloqs/multiplexers/select_pauli_lcu.py b/qualtran/bloqs/multiplexers/select_pauli_lcu.py index 840512a0b..f94b890ac 100644 --- a/qualtran/bloqs/multiplexers/select_pauli_lcu.py +++ b/qualtran/bloqs/multiplexers/select_pauli_lcu.py @@ -22,8 +22,17 @@ import numpy as np from numpy.typing import NDArray -from qualtran import bloq_example, BloqDocSpec, BQUInt, QAny, QBit, Register -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension +from qualtran import ( + AddControlledT, + Bloq, + bloq_example, + BloqDocSpec, + BQUInt, + CtrlSpec, + QAny, + QBit, + Register, +) from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate from qualtran.resource_counting.generalizers import ( @@ -39,7 +48,7 @@ def _to_tuple(x: Iterable[cirq.DensePauliString]) -> Sequence[cirq.DensePauliStr @attrs.frozen -class SelectPauliLCU(SelectOracle, UnaryIterationGate, SpecializedSingleQubitControlledExtension): # type: ignore[misc] +class SelectPauliLCU(SelectOracle, UnaryIterationGate): # type: ignore[misc] r"""A SELECT bloq for selecting and applying operators from an array of `PauliString`s. $$ @@ -117,6 +126,19 @@ def nth_operation( # type: ignore[override] ps = self.select_unitaries[selection].on(*target) return ps.with_coefficient(np.sign(complex(ps.coefficient).real)).controlled_by(control) + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + + return get_ctrl_system_1bit_cv( + self, + ctrl_spec=ctrl_spec, + current_ctrl_bit=self.control_val, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: ( + attrs.evolve(self, control_val=cv), + 'control', + ), + ) + @bloq_example(generalizer=[cirq_to_bloqs, ignore_split_join, ignore_cliffords]) def _select_pauli_lcu() -> SelectPauliLCU: diff --git a/qualtran/bloqs/reflections/reflection_using_prepare.py b/qualtran/bloqs/reflections/reflection_using_prepare.py index 885f9317f..a3cd4c2b3 100644 --- a/qualtran/bloqs/reflections/reflection_using_prepare.py +++ b/qualtran/bloqs/reflections/reflection_using_prepare.py @@ -20,9 +20,17 @@ import numpy as np from numpy.typing import NDArray -from qualtran import Bloq, bloq_example, BloqDocSpec, CtrlSpec, QBit, Register, Signature +from qualtran import ( + AddControlledT, + Bloq, + bloq_example, + BloqDocSpec, + CtrlSpec, + QBit, + Register, + Signature, +) from qualtran._infra.gate_with_registers import GateWithRegisters, merge_qubits, total_bits -from qualtran._infra.single_qubit_controlled import SpecializedSingleQubitControlledExtension from qualtran.bloqs.basic_gates.global_phase import GlobalPhase from qualtran.bloqs.basic_gates.x_basis import XGate from qualtran.bloqs.mcmt import MultiControlZ @@ -41,7 +49,7 @@ @attrs.frozen(cache_hash=True) -class ReflectionUsingPrepare(GateWithRegisters, SpecializedSingleQubitControlledExtension): # type: ignore[misc] +class ReflectionUsingPrepare(GateWithRegisters): r"""Applies reflection around a state prepared by `prepare_gate` Applies $R_{s, g=1} = g (I - 2|s\rangle\langle s|)$ using $R_{s} = @@ -186,6 +194,19 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': def adjoint(self) -> 'ReflectionUsingPrepare': return self + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv + + return get_ctrl_system_1bit_cv( + self, + ctrl_spec=ctrl_spec, + current_ctrl_bit=self.control_val, + get_ctrl_bloq_and_ctrl_reg_name=lambda cv: ( + attrs.evolve(self, control_val=cv), + 'control', + ), + ) + @bloq_example(generalizer=ignore_split_join) def _refl_using_prep() -> ReflectionUsingPrepare: From b3e0748152b8dbd16e0f165e3e47dfb6555c3283 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Wed, 30 Oct 2024 11:41:52 -0700 Subject: [PATCH 7/7] Create KaliskiModInverse (#1464) * Create KaliskiModInverse * format * change signature * free * use half comp * Add documentation * nit * address comments * type * update classical action * nit --------- Co-authored-by: Matthew Harrigan --- .../qualtran_dev_tools/notebook_specs.py | 5 + docs/bloqs/index.rst | 1 + qualtran/bloqs/mod_arithmetic/__init__.py | 1 + .../bloqs/mod_arithmetic/mod_division.ipynb | 170 +++++ qualtran/bloqs/mod_arithmetic/mod_division.py | 666 ++++++++++++++++++ .../bloqs/mod_arithmetic/mod_division_test.py | 87 +++ .../mod_arithmetic/mod_multiplication.py | 2 +- qualtran/serialization/resolver_dict.py | 3 + 8 files changed, 934 insertions(+), 1 deletion(-) create mode 100644 qualtran/bloqs/mod_arithmetic/mod_division.ipynb create mode 100644 qualtran/bloqs/mod_arithmetic/mod_division.py create mode 100644 qualtran/bloqs/mod_arithmetic/mod_division_test.py diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index 6dc00babe..00bde2cb8 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -520,6 +520,11 @@ qualtran.bloqs.mod_arithmetic.mod_multiplication._DIRTY_OUT_OF_PLACE_MONTGOMERY_MOD_MUL_DOC, ], ), + NotebookSpecV2( + title='Modular Divison', + module=qualtran.bloqs.mod_arithmetic.mod_division, + bloq_specs=[qualtran.bloqs.mod_arithmetic.mod_division._KALISKI_MOD_INVERSE_DOC], + ), NotebookSpecV2( title='Factoring RSA', module=qualtran.bloqs.factoring.rsa, diff --git a/docs/bloqs/index.rst b/docs/bloqs/index.rst index 16c591baa..7d7f1a27f 100644 --- a/docs/bloqs/index.rst +++ b/docs/bloqs/index.rst @@ -83,6 +83,7 @@ Bloqs Library mod_arithmetic/mod_addition.ipynb mod_arithmetic/mod_subtraction.ipynb mod_arithmetic/mod_multiplication.ipynb + mod_arithmetic/mod_division.ipynb factoring/rsa/rsa.ipynb factoring/ecc/ec_add.ipynb factoring/ecc/ecc.ipynb diff --git a/qualtran/bloqs/mod_arithmetic/__init__.py b/qualtran/bloqs/mod_arithmetic/__init__.py index ff0d3a3da..0ddb3fc2b 100644 --- a/qualtran/bloqs/mod_arithmetic/__init__.py +++ b/qualtran/bloqs/mod_arithmetic/__init__.py @@ -14,5 +14,6 @@ from ._shims import ModInv from .mod_addition import CModAdd, CModAddK, CtrlScaleModAdd, ModAdd, ModAddK +from .mod_division import KaliskiModInverse from .mod_multiplication import CModMulK, DirtyOutOfPlaceMontgomeryModMul, ModDbl from .mod_subtraction import CModNeg, CModSub, ModNeg, ModSub diff --git a/qualtran/bloqs/mod_arithmetic/mod_division.ipynb b/qualtran/bloqs/mod_arithmetic/mod_division.ipynb new file mode 100644 index 000000000..fd34e136d --- /dev/null +++ b/qualtran/bloqs/mod_arithmetic/mod_division.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1c5f2b28", + "metadata": { + "cq.autogen": "title_cell" + }, + "source": [ + "# Modular Divison" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8751aa36", + "metadata": { + "cq.autogen": "top_imports" + }, + "outputs": [], + "source": [ + "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", + "from qualtran import QBit, QInt, QUInt, QAny\n", + "from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n", + "from typing import *\n", + "import numpy as np\n", + "import sympy\n", + "import cirq" + ] + }, + { + "cell_type": "markdown", + "id": "d680443c", + "metadata": { + "cq.autogen": "KaliskiModInverse.bloq_doc.md" + }, + "source": [ + "## `KaliskiModInverse`\n", + "Compute modular multiplicative inverse -inplace- of numbers in montgomery form.\n", + "\n", + "Applies the transformation\n", + "$$\n", + " \\ket{x} \\ket{0} \\rightarrow \\ket{x^{-1} 2^{2n} \\mod p} \\ket{\\mathrm{garbage}}\n", + "$$\n", + "\n", + "#### Parameters\n", + " - `bitsize`: size of the number.\n", + " - `mod`: The integer modulus.\n", + " - `uncompute`: whether to compute or uncompute. \n", + "\n", + "#### Registers\n", + " - `x`: The register for which we compute the multiplicative inverse.\n", + " - `m`: A 2*bitsize register of intermediate values needed for uncomputation. \n", + "\n", + "#### References\n", + " - [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126 133 Cat Qubits](https://arxiv.org/abs/2302.06639). Appendix C5.\n", + " - [Improved quantum circuits for elliptic curve discrete logarithms](https://arxiv.org/abs/2001.09580). Fig 7(b)\n", + " - [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585). page 8.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5917d72", + "metadata": { + "cq.autogen": "KaliskiModInverse.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.mod_arithmetic import KaliskiModInverse" + ] + }, + { + "cell_type": "markdown", + "id": "d44329eb", + "metadata": { + "cq.autogen": "KaliskiModInverse.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31a37cf6", + "metadata": { + "cq.autogen": "KaliskiModInverse.kaliskimodinverse_example" + }, + "outputs": [], + "source": [ + "kaliskimodinverse_example = KaliskiModInverse(32, 10**9 + 7)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58c697e6", + "metadata": { + "cq.autogen": "KaliskiModInverse.kaliskimodinverse_symbolic" + }, + "outputs": [], + "source": [ + "n, p = sympy.symbols('n p')\n", + "kaliskimodinverse_symbolic = KaliskiModInverse(n, p)" + ] + }, + { + "cell_type": "markdown", + "id": "9bf1e17c", + "metadata": { + "cq.autogen": "KaliskiModInverse.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eca3a706", + "metadata": { + "cq.autogen": "KaliskiModInverse.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([kaliskimodinverse_example, kaliskimodinverse_symbolic],\n", + " ['`kaliskimodinverse_example`', '`kaliskimodinverse_symbolic`'])" + ] + }, + { + "cell_type": "markdown", + "id": "69fd8906", + "metadata": { + "cq.autogen": "KaliskiModInverse.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15c6fabe", + "metadata": { + "cq.autogen": "KaliskiModInverse.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "kaliskimodinverse_example_g, kaliskimodinverse_example_sigma = kaliskimodinverse_example.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(kaliskimodinverse_example_g)\n", + "show_counts_sigma(kaliskimodinverse_example_sigma)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qualtran/bloqs/mod_arithmetic/mod_division.py b/qualtran/bloqs/mod_arithmetic/mod_division.py new file mode 100644 index 000000000..d5bad9fc1 --- /dev/null +++ b/qualtran/bloqs/mod_arithmetic/mod_division.py @@ -0,0 +1,666 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import cached_property +from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union + +import numpy as np +import sympy +from attrs import frozen + +from qualtran import ( + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + DecomposeTypeError, + QAny, + QBit, + QMontgomeryUInt, + Register, + Side, + Signature, + Soquet, + SoquetT, +) +from qualtran.bloqs.arithmetic.addition import AddK +from qualtran.bloqs.arithmetic.bitwise import BitwiseNot, XorK +from qualtran.bloqs.arithmetic.comparison import LinearDepthHalfGreaterThan +from qualtran.bloqs.arithmetic.controlled_addition import CAdd +from qualtran.bloqs.basic_gates import CNOT, TwoBitCSwap, XGate +from qualtran.bloqs.mcmt import And, MultiAnd +from qualtran.bloqs.mod_arithmetic.mod_multiplication import ModDbl +from qualtran.bloqs.swap_network import CSwapApprox +from qualtran.resource_counting import BloqCountDictT +from qualtran.resource_counting._call_graph import SympySymbolAllocator +from qualtran.symbolics import HasLength, is_symbolic + +if TYPE_CHECKING: + from qualtran.resource_counting import BloqCountDictT + from qualtran.simulation.classical_sim import ClassicalValT + from qualtran.symbolics import SymbolicInt + + +@frozen +class _KaliskiIterationStep1(Bloq): + """The first layer of operations in figure 15 of https://arxiv.org/pdf/2302.06639.""" + + bitsize: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('v', QMontgomeryUInt(self.bitsize)), + Register('m', QBit()), + Register('f', QBit()), + ] + ) + + def on_classical_vals(self, v: int, m: int, f: int) -> Dict[str, 'ClassicalValT']: + m ^= f & (v == 0) + f ^= m + return {'v': v, 'm': m, 'f': f} + + def build_composite_bloq( + self, bb: 'BloqBuilder', v: Soquet, m: Soquet, f: Soquet + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}') + v_arr = bb.split(v) + ctrls = np.concatenate([v_arr, [f]]) + ctrls, junk, target = bb.add(MultiAnd(cvs=[0] * self.bitsize + [1]), ctrl=ctrls) + target, m = bb.add(CNOT(), ctrl=target, target=m) + ctrls = bb.add( + MultiAnd(cvs=[0] * self.bitsize + [1]).adjoint(), ctrl=ctrls, junk=junk, target=target + ) + v_arr = ctrls[:-1] + f = ctrls[-1] + v = bb.join(v_arr) + m, f = bb.add(CNOT(), ctrl=m, target=f) + return {'v': v, 'm': m, 'f': f} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + if is_symbolic(self.bitsize): + cvs: Union[HasLength, List[int]] = HasLength(self.bitsize) + else: + cvs = [0] * int(self.bitsize) + return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 2} + + +@frozen +class _KaliskiIterationStep2(Bloq): + """The second layer of operations in figure 15 of https://arxiv.org/pdf/2302.06639.""" + + bitsize: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('u', QMontgomeryUInt(self.bitsize)), + Register('v', QMontgomeryUInt(self.bitsize)), + Register('b', QBit()), + Register('a', QBit()), + Register('m', QBit()), + Register('f', QBit()), + ] + ) + + def on_classical_vals( + self, u: int, v: int, b: int, a: int, m: int, f: int + ) -> Dict[str, 'ClassicalValT']: + a ^= ((u & 1) == 0) & f + m ^= ((v & 1) == 0) & (a == 0) & f + b ^= a + b ^= m + return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f} + + def build_composite_bloq( + self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet + ) -> Dict[str, 'SoquetT']: + u_arr = bb.split(u) + v_arr = bb.split(v) + + (f, u_arr[-1]), c = bb.add(And(1, 0), ctrl=(f, u_arr[-1])) + c, a = bb.add(CNOT(), ctrl=c, target=a) + f, u_arr[-1] = bb.add(And(1, 0).adjoint(), ctrl=(f, u_arr[-1]), target=c) + + (f, v_arr[-1], a), junk, c = bb.add(MultiAnd(cvs=(1, 0, 0)), ctrl=(f, v_arr[-1], a)) + c, m = bb.add(CNOT(), ctrl=c, target=m) + f, v_arr[-1], a = bb.add( + MultiAnd(cvs=(1, 0, 0)).adjoint(), ctrl=(f, v_arr[-1], a), junk=junk, target=c + ) + + a, b = bb.add(CNOT(), ctrl=a, target=b) + m, b = bb.add(CNOT(), ctrl=m, target=b) + u = bb.join(u_arr) + v = bb.join(v_arr) + return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return { + And(1, 0): 1, + And(1, 0).adjoint(): 1, + CNOT(): 4, + MultiAnd((1, 0, 0)): 1, + MultiAnd((1, 0, 0)).adjoint(): 1, + } + + +@frozen +class _KaliskiIterationStep3(Bloq): + """The third layer of operations in figure 15 of https://arxiv.org/pdf/2302.06639.""" + + bitsize: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('u', QMontgomeryUInt(self.bitsize)), + Register('v', QMontgomeryUInt(self.bitsize)), + Register('b', QBit()), + Register('a', QBit()), + Register('m', QBit()), + Register('f', QBit()), + ] + ) + + def on_classical_vals( + self, u: int, v: int, b: int, a: int, m: int, f: int + ) -> Dict[str, 'ClassicalValT']: + c = (u > v) & (b == 0) & f + a ^= c + m ^= c + return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f} + + def build_composite_bloq( + self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet + ) -> Dict[str, 'SoquetT']: + u, v, junk, greater_than = bb.add( + LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)), a=u, b=v + ) + + (greater_than, f, b), junk, ctrl = bb.add( + MultiAnd(cvs=(1, 1, 0)), ctrl=(greater_than, f, b) + ) + + ctrl, a = bb.add(CNOT(), ctrl=ctrl, target=a) + ctrl, m = bb.add(CNOT(), ctrl=ctrl, target=m) + + greater_than, f, b = bb.add( + MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk, target=ctrl + ) + u, v = bb.add( + LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)).adjoint(), + a=u, + b=v, + c=junk, + target=greater_than, + ) + return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return { + LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)): 1, + LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)).adjoint(): 1, + MultiAnd((1, 1, 0)): 1, + MultiAnd((1, 1, 0)).adjoint(): 1, + CNOT(): 2, + } + + +@frozen +class _KaliskiIterationStep4(Bloq): + """The fourth layer of operations in figure 15 of https://arxiv.org/pdf/2302.06639.""" + + bitsize: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('u', QMontgomeryUInt(self.bitsize)), + Register('v', QMontgomeryUInt(self.bitsize)), + Register('r', QMontgomeryUInt(self.bitsize)), + Register('s', QMontgomeryUInt(self.bitsize)), + Register('a', QBit()), + ] + ) + + def on_classical_vals( + self, u: int, v: int, r: int, s: int, a: int + ) -> Dict[str, 'ClassicalValT']: + if a: + u, v = v, u + r, s = s, r + return {'u': u, 'v': v, 'r': r, 's': s, 'a': a} + + def build_composite_bloq( + self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, a: Soquet + ) -> Dict[str, 'SoquetT']: + # CSwapApprox is a CSWAP with a phase flip. + # Since we are doing two SWAPs the overal phase is correct. + a, u, v = bb.add(CSwapApprox(self.bitsize), ctrl=a, x=u, y=v) + a, r, s = bb.add(CSwapApprox(self.bitsize), ctrl=a, x=r, y=s) + return {'u': u, 'v': v, 'r': r, 's': s, 'a': a} + + def build_call_graph(self, ssa: SympySymbolAllocator) -> 'BloqCountDictT': + return {CSwapApprox(self.bitsize): 2} + + +@frozen +class _KaliskiIterationStep5(Bloq): + """The fifth layer of operations in figure 15 of https://arxiv.org/pdf/2302.06639.""" + + bitsize: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('u', QMontgomeryUInt(self.bitsize)), + Register('v', QMontgomeryUInt(self.bitsize)), + Register('r', QMontgomeryUInt(self.bitsize)), + Register('s', QMontgomeryUInt(self.bitsize)), + Register('b', QBit()), + Register('f', QBit()), + ] + ) + + def on_classical_vals( + self, u: int, v: int, r: int, s: int, b: int, f: int + ) -> Dict[str, 'ClassicalValT']: + if f and b == 0: + v -= u + s += r + return {'u': u, 'v': v, 'r': r, 's': s, 'b': b, 'f': f} + + def build_composite_bloq( + self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, b: Soquet, f: Soquet + ) -> Dict[str, 'SoquetT']: + (f, b), c = bb.add(And(1, 0), ctrl=(f, b)) + v = bb.add(BitwiseNot(QMontgomeryUInt(self.bitsize)), x=v) + c, u, v = bb.add(CAdd(QMontgomeryUInt(self.bitsize)), ctrl=c, a=u, b=v) + v = bb.add(BitwiseNot(QMontgomeryUInt(self.bitsize)), x=v) + c, r, s = bb.add(CAdd(QMontgomeryUInt(self.bitsize)), ctrl=c, a=r, b=s) + f, b = bb.add(And(1, 0).adjoint(), ctrl=(f, b), target=c) + return {'u': u, 'v': v, 'r': r, 's': s, 'b': b, 'f': f} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return { + And(1, 0): 1, + And(1, 0).adjoint(): 1, + BitwiseNot(QMontgomeryUInt(self.bitsize)): 2, + CAdd(QMontgomeryUInt(self.bitsize)): 2, + } + + +@frozen +class _KaliskiIterationStep6(Bloq): + """The sixth layer of operations in figure 15 of https://arxiv.org/pdf/2302.06639.""" + + bitsize: 'SymbolicInt' + mod: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('u', QMontgomeryUInt(self.bitsize)), + Register('v', QMontgomeryUInt(self.bitsize)), + Register('r', QMontgomeryUInt(self.bitsize)), + Register('s', QMontgomeryUInt(self.bitsize)), + Register('b', QBit()), + Register('a', QBit()), + Register('m', QBit()), + Register('f', QBit()), + ] + ) + + def on_classical_vals( + self, u: int, v: int, r: int, s: int, b: int, a: int, m: int, f: int + ) -> Dict[str, 'ClassicalValT']: + b ^= m + b ^= a + if f: + v >>= 1 + r = (2 * r) % self.mod + if a: + r, s = s, r + u, v = v, u + if s % 2 == 0: + a ^= 1 + return {'u': u, 'v': v, 'r': r, 's': s, 'b': b, 'a': a, 'm': m, 'f': f} + + def build_composite_bloq( + self, + bb: 'BloqBuilder', + u: Soquet, + v: Soquet, + r: Soquet, + s: Soquet, + b: Soquet, + a: Soquet, + m: Soquet, + f: Soquet, + ) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize, self.mod): + raise DecomposeTypeError(f'symbolic decomposition is not supported for {self}') + m, b = bb.add(CNOT(), ctrl=m, target=b) + a, b = bb.add(CNOT(), ctrl=a, target=b) + + # Controlled Divison by 2. The control bit is set only iff the number is even so the divison becomes equivalent to a cyclic right shift. + v_arr = bb.split(v) + for i in reversed(range(self.bitsize - 1)): + f, v_arr[i], v_arr[i + 1] = bb.add(TwoBitCSwap(), ctrl=f, x=v_arr[i], y=v_arr[i + 1]) + v = bb.join(v_arr) + + r = bb.add(ModDbl(QMontgomeryUInt(self.bitsize), self.mod), x=r) + + a, u, v = bb.add(CSwapApprox(self.bitsize), ctrl=a, x=u, y=v) + a, r, s = bb.add(CSwapApprox(self.bitsize), ctrl=a, x=r, y=s) + + s_arr = bb.split(s) + s_arr[-1] = bb.add(XGate(), q=s_arr[-1]) + s_arr[-1], a = bb.add(CNOT(), ctrl=s_arr[-1], target=a) + s_arr[-1] = bb.add(XGate(), q=s_arr[-1]) + s = bb.join(s_arr) + + return {'u': u, 'v': v, 'r': r, 's': s, 'b': b, 'a': a, 'm': m, 'f': f} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return { + CNOT(): 4, + XGate(): 2, + ModDbl(QMontgomeryUInt(self.bitsize), self.mod): 1, + CSwapApprox(self.bitsize): 2, + TwoBitCSwap(): self.bitsize - 1, + } + + +@frozen +class _KaliskiIteration(Bloq): + """The single full iteration of Kaliski. see figure 15 of https://arxiv.org/pdf/2302.06639.""" + + bitsize: 'SymbolicInt' + mod: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('u', QMontgomeryUInt(self.bitsize)), + Register('v', QMontgomeryUInt(self.bitsize)), + Register('r', QMontgomeryUInt(self.bitsize)), + Register('s', QMontgomeryUInt(self.bitsize)), + Register('m', QBit()), + Register('f', QBit()), + ] + ) + + def build_composite_bloq( + self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet + ) -> Dict[str, 'SoquetT']: + a = bb.allocate(1) + b = bb.allocate(1) + + v, m, f = bb.add(_KaliskiIterationStep1(self.bitsize), v=v, m=m, f=f) + u, v, b, a, m, f = bb.add( + _KaliskiIterationStep2(self.bitsize), u=u, v=v, b=b, a=a, m=m, f=f + ) + u, v, b, a, m, f = bb.add( + _KaliskiIterationStep3(self.bitsize), u=u, v=v, b=b, a=a, m=m, f=f + ) + u, v, r, s, a = bb.add(_KaliskiIterationStep4(self.bitsize), u=u, v=v, r=r, s=s, a=a) + u, v, r, s, b, f = bb.add( + _KaliskiIterationStep5(self.bitsize), u=u, v=v, r=r, s=s, b=b, f=f + ) + u, v, r, s, b, a, m, f = bb.add( + _KaliskiIterationStep6(self.bitsize, self.mod), u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f + ) + + bb.free(a) + bb.free(b) + return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return { + _KaliskiIterationStep1(self.bitsize): 1, + _KaliskiIterationStep2(self.bitsize): 1, + _KaliskiIterationStep3(self.bitsize): 1, + _KaliskiIterationStep4(self.bitsize): 1, + _KaliskiIterationStep5(self.bitsize): 1, + _KaliskiIterationStep6(self.bitsize, self.mod): 1, + } + + def on_classical_vals( + self, u: int, v: int, r: int, s: int, m: int, f: int + ) -> Dict[str, 'ClassicalValT']: + """This is the Kaliski algorithm as described in Fig7 of https://arxiv.org/pdf/2001.09580. + + The following implementation merges together the pseudocode from Fig7 of https://arxiv.org/pdf/2001.09580 + and the circuit in figure 15 of https://arxiv.org/pdf/2302.06639; This is in order to compute the values + of `f` and `m`. + """ + assert m == 0 + if f == 0: + # When `f = 0` this means that the algorithm is nearly over and that we just need to + # double the value of `r`. + r = (r << 1) % self.mod + elif v == 0: + # `v = 0` is the termination condition of the algorithm and it means that the only + # remaining step is multiplying `r` by 2 raised to the number of remaining iterations. + # Classically this translates into a `r = (r * pow(2, k, p))%p` where k is the number + # of iterations left followed by a break statement. + m = u & 1 + f = 0 + r = (r << 1) % self.mod + else: + m = (u % 2 == 1) & (v % 2 == 0) + # Kaliski iteration as described in Fig7 of https://arxiv.org/pdf/2001.09580. + swap = (u % 2 == 0 and v % 2 == 1) or (u % 2 == 1 and v % 2 == 1 and u > v) + if swap: + u, v = v, u + r, s = s, r + if u % 2 == 1 and v % 2 == 1: + v -= u + s += r + assert v % 2 == 0, f'{u=} {v=} {swap=}' + v >>= 1 + r = (r << 1) % self.mod + if swap: + u, v = v, u + r, s = s, r + return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + + +@frozen +class _KaliskiModInverseImpl(Bloq): + """The full KaliskiIteration algorithm. see C5 https://arxiv.org/pdf/2302.06639""" + + bitsize: 'SymbolicInt' + mod: 'SymbolicInt' + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('u', QMontgomeryUInt(self.bitsize)), + Register('v', QMontgomeryUInt(self.bitsize)), + Register('r', QMontgomeryUInt(self.bitsize)), + Register('s', QMontgomeryUInt(self.bitsize)), + Register('m', QAny(2 * self.bitsize)), + Register('f', QBit()), + ] + ) + + @cached_property + def _kaliski_iteration(self): + return _KaliskiIteration(self.bitsize, self.mod) + + def build_composite_bloq( + self, bb: 'BloqBuilder', u: Soquet, v: Soquet, r: Soquet, s: Soquet, m: Soquet, f: Soquet + ) -> Dict[str, 'SoquetT']: + f = bb.add(XGate(), q=f) + u = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=u) + s = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=s) + + m_arr = bb.split(m) + + for i in range(2 * self.bitsize): + u, v, r, s, m_arr[i], f = bb.add( + self._kaliski_iteration, u=u, v=v, r=r, s=s, m=m_arr[i], f=f + ) + + r = bb.add(BitwiseNot(QMontgomeryUInt(self.bitsize)), x=r) + r = bb.add(AddK(self.bitsize, self.mod + 1, signed=False), x=r) + + u = bb.add(XorK(QMontgomeryUInt(self.bitsize), 1), x=u) + s = bb.add(XorK(QMontgomeryUInt(self.bitsize), self.mod), x=s) + + m = bb.join(m_arr) + return {'u': u, 'v': v, 'r': r, 's': s, 'm': m, 'f': f} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return { + self._kaliski_iteration: 2 * self.bitsize, + BitwiseNot(QMontgomeryUInt(self.bitsize)): 1, + AddK(self.bitsize, self.mod + 1, signed=False): 1, + XGate(): 1, + XorK(QMontgomeryUInt(self.bitsize), self.mod): 2, + XorK(QMontgomeryUInt(self.bitsize), 1): 2, + } + + +@frozen +class KaliskiModInverse(Bloq): + r"""Compute modular multiplicative inverse -inplace- of numbers in montgomery form. + + Applies the transformation + $$ + \ket{x} \ket{0} \rightarrow \ket{x^{-1} 2^{2n} \mod p} \ket{\mathrm{garbage}} + $$ + + Args: + bitsize: size of the number. + mod: The integer modulus. + uncompute: whether to compute or uncompute. + + Registers: + x: The register for which we compute the multiplicative inverse. + m: A 2*bitsize register of intermediate values needed for uncomputation. + + References: + [Performance Analysis of a Repetition Cat Code Architecture: Computing 256-bit Elliptic Curve Logarithm in 9 Hours with 126 133 Cat Qubits](https://arxiv.org/abs/2302.06639) + Appendix C5. + + [Improved quantum circuits for elliptic curve discrete logarithms](https://arxiv.org/abs/2001.09580) + Fig 7(b) + + [How to compute a 256-bit elliptic curve private key with only 50 million Toffoli gates](https://arxiv.org/abs/2306.08585) + page 8. + """ + + bitsize: 'SymbolicInt' + mod: 'SymbolicInt' + uncompute: bool = False + + @cached_property + def signature(self) -> 'Signature': + side = Side.LEFT if self.uncompute else Side.RIGHT + return Signature( + [ + Register('x', QMontgomeryUInt(self.bitsize)), + Register('m', QAny(2 * self.bitsize), side=side), + ] + ) + + def build_composite_bloq( + self, bb: 'BloqBuilder', x: Soquet, m: Optional[Soquet] = None, f: Optional[Soquet] = None + ) -> Dict[str, 'SoquetT']: + u = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) + r = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) + s = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) + f = bb.allocate(1) + + if self.uncompute: + assert m is not None + u, x, r, s, m, f = cast( + Tuple[Soquet, Soquet, Soquet, Soquet, Soquet, Soquet], + bb.add_from( + _KaliskiModInverseImpl(self.bitsize, self.mod).adjoint(), + u=u, + v=r, + r=x, + s=s, + m=m, + f=f, + ), + ) + bb.free(u) + bb.free(r) + bb.free(s) + bb.free(m) + bb.free(f) + return {'x': x} + + m = bb.allocate(2 * self.bitsize) + u, v, x, s, m, f = bb.add_from( + _KaliskiModInverseImpl(self.bitsize, self.mod), u=u, v=x, r=r, s=s, m=m, f=f + ) + + assert isinstance(u, Soquet) + assert isinstance(v, Soquet) + assert isinstance(s, Soquet) + assert isinstance(f, Soquet) + bb.free(u) + bb.free(v) + bb.free(s) + bb.free(f) + return {'x': x, 'm': m} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + return _KaliskiModInverseImpl(self.bitsize, self.mod).build_call_graph(ssa) + + def on_classical_vals(self, x: int, m: int = 0) -> Dict[str, 'ClassicalValT']: + u, v, r, s, f = int(self.mod), x, 0, 1, 1 + iteration = _KaliskiModInverseImpl(self.bitsize, self.mod)._kaliski_iteration + for _ in range(2 * int(self.bitsize)): + u, v, r, s, m_i, f = iteration.call_classically(u=u, v=v, r=r, s=s, m=0, f=f) + m = (m << 1) | m_i + assert u == 1 + assert s == self.mod + assert f == 0 + assert v == 0 + return {'x': self.mod - r, 'm': m} + + +@bloq_example +def _kaliskimodinverse_example() -> KaliskiModInverse: + kaliskimodinverse_example = KaliskiModInverse(32, 10**9 + 7) + return kaliskimodinverse_example + + +@bloq_example +def _kaliskimodinverse_symbolic() -> KaliskiModInverse: + n, p = sympy.symbols('n p') + kaliskimodinverse_symbolic = KaliskiModInverse(n, p) + return kaliskimodinverse_symbolic + + +_KALISKI_MOD_INVERSE_DOC = BloqDocSpec( + bloq_cls=KaliskiModInverse, examples=[_kaliskimodinverse_example, _kaliskimodinverse_symbolic] +) diff --git a/qualtran/bloqs/mod_arithmetic/mod_division_test.py b/qualtran/bloqs/mod_arithmetic/mod_division_test.py new file mode 100644 index 000000000..31c56d394 --- /dev/null +++ b/qualtran/bloqs/mod_arithmetic/mod_division_test.py @@ -0,0 +1,87 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import pytest +import sympy + +import qualtran.testing as qlt_testing +from qualtran import QMontgomeryUInt +from qualtran.bloqs.mod_arithmetic.mod_division import _kaliskimodinverse_example, KaliskiModInverse +from qualtran.resource_counting import get_cost_value, QECGatesCost +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join + + +@pytest.mark.parametrize('bitsize', [5, 6]) +@pytest.mark.parametrize('mod', [3, 5, 7, 11, 13, 15]) +def test_kaliski_mod_inverse_classical_action(bitsize, mod): + blq = KaliskiModInverse(bitsize, mod) + cblq = blq.decompose_bloq() + dtype = QMontgomeryUInt(bitsize) + R = pow(2, bitsize, mod) + for x in range(1, mod): + if math.gcd(x, mod) != 1: + continue + x_montgomery = dtype.uint_to_montgomery(x, mod) + res = blq.call_classically(x=x_montgomery) + assert res == cblq.call_classically(x=x_montgomery) + assert len(res) == 2 + assert res[0] == dtype.montgomery_inverse(x_montgomery, mod) + assert dtype.montgomery_product(int(res[0]), x_montgomery, mod) == R + + +@pytest.mark.parametrize('bitsize', [5, 6]) +@pytest.mark.parametrize('mod', [3, 5, 7, 11, 13, 15]) +def test_kaliski_mod_inverse_decomposition(bitsize, mod): + b = KaliskiModInverse(bitsize, mod) + qlt_testing.assert_valid_bloq_decomposition(b) + + +@pytest.mark.parametrize('bitsize', [5, 6]) +@pytest.mark.parametrize('mod', [3, 5, 7, 11, 13, 15]) +def test_kaliski_mod_bloq_counts(bitsize, mod): + b = KaliskiModInverse(bitsize, mod) + qlt_testing.assert_equivalent_bloq_counts(b, [ignore_alloc_free, ignore_split_join]) + + +def test_kaliski_symbolic_cost(): + n, p = sympy.symbols('n p') + b = KaliskiModInverse(n, p) + cost = get_cost_value(b, QECGatesCost()).total_t_and_ccz_count() + # We have some T gates since we use CSwapApprox instead of n CSWAPs. + total_toff = (cost['n_t'] / 4 + cost['n_ccz']) * sympy.Integer(1) + total_toff = total_toff.expand() + + # The toffoli cost from Litinski https://arxiv.org/abs/2306.08585 is 26n^2 + 2n. + # The cost of Kaliski is 2*n*(cost of an iteration) + (cost of computing $p - x$) + # + # - The cost of of computing $p-x$ in Litinski is 2n (Neg -> Add(p)). In our + # construction this is just $n-1$ (BitwiseNot -> Add(p+1)). + # - The cost of an iteration in Litinski $13n$ since they ignore constants. + # Our construction is exactly the same but we also count the constants + # which amout to $3$. for a total cost of $13n + 3$. + # For example the cost of ModDbl is 2n+1. In their figure 8, they report + # it as just $2n$. ModDbl gets executed within the 2n loop so its contribution + # to the overal cost should be 4n^2 + 2n instead of just 4n^2. + assert total_toff == 26 * n**2 + 7 * n - 1 + + +def test_kaliskimodinverse_example(bloq_autotester): + bloq_autotester(_kaliskimodinverse_example) + + +@pytest.mark.notebook +def test_notebook(): + qlt_testing.execute_notebook('mod_division') diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index 9f289add1..3de2e5b96 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -72,7 +72,7 @@ class ModDbl(Bloq): """ dtype: Union[QUInt, QMontgomeryUInt] - mod: int = attrs.field() + mod: 'SymbolicInt' = attrs.field() @mod.validator def _validate_mod(self, attribute, value): diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index 347c78b61..c9287718a 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -116,6 +116,7 @@ import qualtran.bloqs.mean_estimation.complex_phase_oracle import qualtran.bloqs.mean_estimation.mean_estimation_operator import qualtran.bloqs.mod_arithmetic +import qualtran.bloqs.mod_arithmetic.mod_division import qualtran.bloqs.mod_arithmetic.mod_multiplication import qualtran.bloqs.mod_arithmetic.mod_subtraction import qualtran.bloqs.multiplexers.apply_gate_to_lth_target @@ -348,6 +349,8 @@ "qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK": qualtran.bloqs.mod_arithmetic.mod_multiplication.CModMulK, "qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.DirtyOutOfPlaceMontgomeryModMul, "qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul": qualtran.bloqs.mod_arithmetic.mod_multiplication.SingleWindowModMul, + "qualtran.bloqs.mod_arithmetic.mod_division.KaliskiModInverse": qualtran.bloqs.mod_arithmetic.mod_division.KaliskiModInverse, + "qualtran.bloqs.mod_arithmetic.mod_division._KaliskiIteration": qualtran.bloqs.mod_arithmetic.mod_division._KaliskiIteration, "qualtran.bloqs.factoring._factoring_shims.MeasureQFT": qualtran.bloqs.factoring._factoring_shims.MeasureQFT, "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepOne, "qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo": qualtran.bloqs.factoring.ecc.ec_add._ECAddStepTwo,