diff --git a/dev_tools/qualtran_dev_tools/notebook_specs.py b/dev_tools/qualtran_dev_tools/notebook_specs.py index a35c468ce..8289e00ce 100644 --- a/dev_tools/qualtran_dev_tools/notebook_specs.py +++ b/dev_tools/qualtran_dev_tools/notebook_specs.py @@ -37,6 +37,7 @@ import qualtran.bloqs.arithmetic.controlled_add_or_subtract import qualtran.bloqs.arithmetic.controlled_addition import qualtran.bloqs.arithmetic.conversions +import qualtran.bloqs.arithmetic.lists import qualtran.bloqs.arithmetic.multiplication import qualtran.bloqs.arithmetic.negate import qualtran.bloqs.arithmetic.permutation @@ -488,6 +489,15 @@ module=qualtran.bloqs.arithmetic.trigonometric, bloq_specs=[qualtran.bloqs.arithmetic.trigonometric.arcsin._ARCSIN_DOC], ), + NotebookSpecV2( + title='List Functions', + module=qualtran.bloqs.arithmetic.lists, + bloq_specs=[ + qualtran.bloqs.arithmetic.lists.sort_in_place._SORT_IN_PLACE_DOC, + qualtran.bloqs.arithmetic.lists.symmetric_difference._SYMMETRIC_DIFFERENCE_DOC, + qualtran.bloqs.arithmetic.lists.has_duplicates._HAS_DUPLICATES_DOC, + ], + ), ] MOD_ARITHMETIC = [ diff --git a/docs/bloqs/index.rst b/docs/bloqs/index.rst index d827dfaad..f886288fd 100644 --- a/docs/bloqs/index.rst +++ b/docs/bloqs/index.rst @@ -75,6 +75,7 @@ Bloqs Library arithmetic/permutation.ipynb arithmetic/bitwise.ipynb arithmetic/trigonometric/trigonometric.ipynb + arithmetic/lists/lists.ipynb .. toctree:: :maxdepth: 2 diff --git a/qualtran/bloqs/arithmetic/lists/__init__.py b/qualtran/bloqs/arithmetic/lists/__init__.py new file mode 100644 index 000000000..7ed84457f --- /dev/null +++ b/qualtran/bloqs/arithmetic/lists/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .has_duplicates import HasDuplicates +from .sort_in_place import SortInPlace +from .symmetric_difference import SymmetricDifference diff --git a/qualtran/bloqs/arithmetic/lists/has_duplicates.py b/qualtran/bloqs/arithmetic/lists/has_duplicates.py new file mode 100644 index 000000000..4da8d2a61 --- /dev/null +++ b/qualtran/bloqs/arithmetic/lists/has_duplicates.py @@ -0,0 +1,176 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import Counter +from typing import Union + +import attrs +import numpy as np +from attrs import frozen + +from qualtran import ( + AddControlledT, + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + CtrlSpec, + QBit, + QInt, + QUInt, + Register, + Signature, + Soquet, + SoquetT, +) +from qualtran.bloqs.arithmetic import LinearDepthHalfLessThan +from qualtran.bloqs.basic_gates import CNOT, XGate +from qualtran.bloqs.mcmt import MultiControlX +from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.simulation.classical_sim import ClassicalValT +from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt + + +@frozen +class HasDuplicates(Bloq): + r"""Given a sorted list of `l` numbers, check if it contains any duplicates. + + Produces a single qubit which is `1` if there are duplicates, and `0` if all are disjoint. + It compares every adjacent pair, and therefore uses `l - 1` comparisons. + It then uses a single MCX on `l - 1` bits gate to compute the flag. + + Args: + l: number of elements in the list + dtype: type of each element to store `[n]`. + + Registers: + xs: a list of `l` registers of `dtype`. + flag: single qubit. Value is flipped if the input list has duplicates, otherwise stays same. + + References: + [Quartic quantum speedups for planted inference](https://arxiv.org/abs/2406.19378v1) + Lemma 4.12. Eq. 122. + """ + + l: SymbolicInt + dtype: Union[QUInt, QInt] + is_controlled: bool = False + + @property + def signature(self) -> 'Signature': + registers = [Register('xs', self.dtype, shape=(self.l,)), Register('flag', QBit())] + if self.is_controlled: + registers.append(Register('ctrl', QBit())) + return Signature(registers) + + @property + def _le_bloq(self) -> LinearDepthHalfLessThan: + return LinearDepthHalfLessThan(self.dtype) + + def build_composite_bloq( + self, bb: 'BloqBuilder', xs: 'SoquetT', flag: 'Soquet', **extra_soqs: 'SoquetT' + ) -> dict[str, 'SoquetT']: + assert not is_symbolic(self.l) + assert isinstance(xs, np.ndarray) + + cs = [] + oks = [] + if self.is_controlled: + oks = [extra_soqs.pop('ctrl')] + assert not extra_soqs + + for i in range(1, self.l): + xs[i - 1], xs[i], c, ok = bb.add(self._le_bloq, a=xs[i - 1], b=xs[i]) + cs.append(c) + oks.append(ok) + + oks, flag = bb.add(MultiControlX((1,) * len(oks)), controls=np.array(oks), target=flag) + if not self.is_controlled: + flag = bb.add(XGate(), q=flag) + else: + oks[0], flag = bb.add(CNOT(), ctrl=oks[0], target=flag) + + oks = list(oks) + for i in reversed(range(1, self.l)): + xs[i - 1], xs[i] = bb.add( + self._le_bloq.adjoint(), a=xs[i - 1], b=xs[i], c=cs.pop(), target=oks.pop() + ) + + if self.is_controlled: + extra_soqs = {'ctrl': oks.pop()} + assert not oks + + return {'xs': xs, 'flag': flag} | extra_soqs + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: + counts = Counter[Bloq]() + + counts[self._le_bloq] += self.l - 1 + counts[self._le_bloq.adjoint()] += self.l - 1 + + n_ctrls = self.l - (1 if not self.is_controlled else 0) + counts[MultiControlX(HasLength(n_ctrls))] += 1 + + counts[XGate() if not self.is_controlled else CNOT()] += 1 + + return counts + + def on_classical_vals(self, **vals: 'ClassicalValT') -> dict[str, 'ClassicalValT']: + xs = np.asarray(vals['xs']) + assert np.all(xs == np.sort(xs)) + if np.any(xs[:-1] == xs[1:]): + vals['flag'] ^= 1 + return vals + + def adjoint(self) -> 'HasDuplicates': + return self + + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs + + return get_ctrl_system_1bit_cv_from_bloqs( + self, + ctrl_spec, + current_ctrl_bit=1 if self.is_controlled else None, + bloq_with_ctrl=attrs.evolve(self, is_controlled=True), + ctrl_reg_name='ctrl', + ) + + +@bloq_example +def _has_duplicates() -> HasDuplicates: + has_duplicates = HasDuplicates(4, QUInt(3)) + return has_duplicates + + +@bloq_example +def _has_duplicates_symb() -> HasDuplicates: + import sympy + + n = sympy.Symbol("n") + has_duplicates_symb = HasDuplicates(4, QUInt(n)) + return has_duplicates_symb + + +@bloq_example +def _has_duplicates_symb_len() -> HasDuplicates: + import sympy + + l, n = sympy.symbols("l n") + has_duplicates_symb_len = HasDuplicates(l, QUInt(n)) + return has_duplicates_symb_len + + +_HAS_DUPLICATES_DOC = BloqDocSpec( + bloq_cls=HasDuplicates, examples=[_has_duplicates_symb, _has_duplicates] +) diff --git a/qualtran/bloqs/arithmetic/lists/has_duplicates_test.py b/qualtran/bloqs/arithmetic/lists/has_duplicates_test.py new file mode 100644 index 000000000..0a02fa83a --- /dev/null +++ b/qualtran/bloqs/arithmetic/lists/has_duplicates_test.py @@ -0,0 +1,56 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools + +import numpy as np +import pytest + +import qualtran.testing as qlt_testing +from qualtran import QInt, QUInt +from qualtran.bloqs.arithmetic.lists.has_duplicates import ( + _has_duplicates, + _has_duplicates_symb, + _has_duplicates_symb_len, + HasDuplicates, +) + + +@pytest.mark.parametrize( + "bloq_ex", + [_has_duplicates, _has_duplicates_symb, _has_duplicates_symb_len], + ids=lambda b: b.name, +) +def test_examples(bloq_autotester, bloq_ex): + bloq_autotester(bloq_ex) + + +@pytest.mark.parametrize("bloq_ex", [_has_duplicates, _has_duplicates_symb], ids=lambda b: b.name) +def test_counts(bloq_ex): + qlt_testing.assert_equivalent_bloq_counts(bloq_ex()) + + +@pytest.mark.parametrize("l", [2, 3, pytest.param(4, marks=pytest.mark.slow)]) +@pytest.mark.parametrize( + "dtype", [QUInt(2), QInt(2), pytest.param(QUInt(3), marks=pytest.mark.slow)] +) +def test_classical_action(l, dtype): + bloq = HasDuplicates(l, dtype) + cbloq = bloq.decompose_bloq() + + for xs_t in itertools.product(dtype.get_classical_domain(), repeat=l): + xs = np.sort(xs_t) + for flag in [0, 1]: + np.testing.assert_equal( + cbloq.call_classically(xs=xs, flag=flag), bloq.call_classically(xs=xs, flag=flag) + ) diff --git a/qualtran/bloqs/arithmetic/lists/lists.ipynb b/qualtran/bloqs/arithmetic/lists/lists.ipynb new file mode 100644 index 000000000..2251b288d --- /dev/null +++ b/qualtran/bloqs/arithmetic/lists/lists.ipynb @@ -0,0 +1,342 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "46306919", + "metadata": { + "cq.autogen": "title_cell" + }, + "source": [ + "# List Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47761dec", + "metadata": { + "cq.autogen": "top_imports" + }, + "outputs": [], + "source": [ + "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", + "from qualtran import QBit, QInt, QUInt, QAny\n", + "from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n", + "from typing import *\n", + "import numpy as np\n", + "import sympy\n", + "import cirq" + ] + }, + { + "cell_type": "markdown", + "id": "19f11879", + "metadata": { + "cq.autogen": "SortInPlace.bloq_doc.md" + }, + "source": [ + "## `SortInPlace`\n", + "Sort a list of $\\ell$ numbers in place using $\\ell \\log \\ell$ ancilla bits.\n", + "\n", + "Applies the map:\n", + "$$\n", + " |x_1, x_2, \\ldots, x_l\\rangle\n", + " |0^{\\ell \\log \\ell}\\rangle\n", + " \\mapsto\n", + " |x_{\\pi_1}, x_{\\pi_2}, \\ldots, x_{\\pi_\\ell})\\rangle\n", + " |\\pi_1, \\pi_2, \\ldots, \\pi_\\ell\\rangle\n", + "$$\n", + "where $x_{\\pi_1} \\le x_{\\pi_2} \\ldots \\le x_{\\pi_\\ell}$ is the sorted list,\n", + "and the ancilla are entangled.\n", + "\n", + "To apply this, we first use any sorting algorithm to output the sorted list\n", + "in a clean register. And then use the following algorithm from Lemma 4.12 of Ref [1]\n", + "that applies the map:\n", + "\n", + "$$\n", + " |x_1, ..., x_l\\rangle|x_{\\pi(1)}, ..., x_{\\pi(l)})\\rangle\n", + " \\mapsto\n", + " |x_l, ..., x_l\\rangle|\\pi(1), ..., \\pi(l))\\rangle\n", + "$$\n", + "\n", + "where $x_i \\in [n]$ and $\\pi(i) \\in [l]$.\n", + "This second algorithm (Lemma 4.12) has two steps, each with $l^2$ comparisons:\n", + "1. compute `pi(1) ... pi(l)` given `x_1 ... x_l` and `x_{pi(1)} ... x{pi(l)}`.\n", + "1. (un)compute `x_{pi(1)} ... x{pi(l)}` using `pi(1) ... pi(l)` given `x_1 ... x_l`.\n", + "\n", + "#### Parameters\n", + " - `l`: number of elements in the list\n", + " - `dtype`: type of each element to store `[n]`. \n", + "\n", + "#### Registers\n", + " - `input`: the entire input as a single register\n", + " - `ancilla`: the generated (entangled) register storing `pi`. \n", + "\n", + "#### References\n", + " - [Quartic quantum speedups for planted inference](https://arxiv.org/abs/2406.19378v1). Lemma 4.12. Eq. 122.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce69b6e8", + "metadata": { + "cq.autogen": "SortInPlace.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic.lists import SortInPlace" + ] + }, + { + "cell_type": "markdown", + "id": "e84df89b", + "metadata": { + "cq.autogen": "SymmetricDifference.bloq_doc.md" + }, + "source": [ + "## `SymmetricDifference`\n", + "Given two sorted sets $S, T$ of unique elements, compute their symmetric difference.\n", + "\n", + "This accepts an integer `n_diff`, and marks a flag qubit if the symmetric difference\n", + "set is of size exactly `n_diff`. If the flag is marked (1), then the output of `n_diff`\n", + "numbers is the symmetric difference, otherwise it may be arbitrary.\n", + "\n", + "#### Parameters\n", + " - `n_lhs`: number of elements in $S$\n", + " - `n_rhs`: number of elements in $T$\n", + " - `n_diff`: expected number of elements in the difference $S \\Delta T$.\n", + " - `dtype`: type of each element. \n", + "\n", + "#### Registers\n", + " - `S`: list of `n_lhs` numbers.\n", + " - `T`: list of `n_rhs` numbers.\n", + " - `diff`: output register of `n_diff` numbers.\n", + " - `flag`: 1 if there are duplicates, 0 if all are unique. \n", + "\n", + "#### References\n", + " - [Quartic quantum speedups for planted inference](https://arxiv.org/abs/2406.19378v1). Theorem 4.17, proof para 3, page 38.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3a9d9ea", + "metadata": { + "cq.autogen": "SymmetricDifference.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic.lists import SymmetricDifference" + ] + }, + { + "cell_type": "markdown", + "id": "f58e0eba", + "metadata": { + "cq.autogen": "SymmetricDifference.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29fc34f0", + "metadata": { + "cq.autogen": "SymmetricDifference.symm_diff" + }, + "outputs": [], + "source": [ + "dtype = QUInt(4)\n", + "symm_diff = SymmetricDifference(n_lhs=4, n_rhs=2, n_diff=4, dtype=dtype)" + ] + }, + { + "cell_type": "markdown", + "id": "70608811", + "metadata": { + "cq.autogen": "SymmetricDifference.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9df9334c", + "metadata": { + "cq.autogen": "SymmetricDifference.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([symm_diff],\n", + " ['`symm_diff`'])" + ] + }, + { + "cell_type": "markdown", + "id": "476c580a", + "metadata": { + "cq.autogen": "SymmetricDifference.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6e6f812", + "metadata": { + "cq.autogen": "SymmetricDifference.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "symm_diff_g, symm_diff_sigma = symm_diff.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(symm_diff_g)\n", + "show_counts_sigma(symm_diff_sigma)" + ] + }, + { + "cell_type": "markdown", + "id": "15e88a40", + "metadata": { + "cq.autogen": "HasDuplicates.bloq_doc.md" + }, + "source": [ + "## `HasDuplicates`\n", + "Given a sorted list of `l` numbers, check if it contains any duplicates.\n", + "\n", + "Produces a single qubit which is `1` if there are duplicates, and `0` if all are disjoint.\n", + "It compares every adjacent pair, and therefore uses `l - 1` comparisons.\n", + "It then uses a single MCX on `l - 1` bits gate to compute the flag.\n", + "\n", + "#### Parameters\n", + " - `l`: number of elements in the list\n", + " - `dtype`: type of each element to store `[n]`. \n", + "\n", + "#### Registers\n", + " - `xs`: a list of `l` registers of `dtype`.\n", + " - `flag`: single qubit. Value is flipped if the input list has duplicates, otherwise stays same. \n", + "\n", + "#### References\n", + " - [Quartic quantum speedups for planted inference](https://arxiv.org/abs/2406.19378v1). Lemma 4.12. Eq. 122.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "101f6899", + "metadata": { + "cq.autogen": "HasDuplicates.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic.lists import HasDuplicates" + ] + }, + { + "cell_type": "markdown", + "id": "6ceded7e", + "metadata": { + "cq.autogen": "HasDuplicates.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5f5efa7", + "metadata": { + "cq.autogen": "HasDuplicates.has_duplicates_symb" + }, + "outputs": [], + "source": [ + "import sympy\n", + "\n", + "n = sympy.Symbol(\"n\")\n", + "has_duplicates_symb = HasDuplicates(4, QUInt(n))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5412e0d", + "metadata": { + "cq.autogen": "HasDuplicates.has_duplicates" + }, + "outputs": [], + "source": [ + "has_duplicates = HasDuplicates(4, QUInt(3))" + ] + }, + { + "cell_type": "markdown", + "id": "7c1c62b4", + "metadata": { + "cq.autogen": "HasDuplicates.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1198bb5d", + "metadata": { + "cq.autogen": "HasDuplicates.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([has_duplicates_symb, has_duplicates],\n", + " ['`has_duplicates_symb`', '`has_duplicates`'])" + ] + }, + { + "cell_type": "markdown", + "id": "f6b8cde0", + "metadata": { + "cq.autogen": "HasDuplicates.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75afba07", + "metadata": { + "cq.autogen": "HasDuplicates.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "has_duplicates_symb_g, has_duplicates_symb_sigma = has_duplicates_symb.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(has_duplicates_symb_g)\n", + "show_counts_sigma(has_duplicates_symb_sigma)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qualtran/bloqs/arithmetic/lists/sort_in_place.py b/qualtran/bloqs/arithmetic/lists/sort_in_place.py new file mode 100644 index 000000000..d4da3727a --- /dev/null +++ b/qualtran/bloqs/arithmetic/lists/sort_in_place.py @@ -0,0 +1,99 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import Counter + +from attrs import frozen + +from qualtran import Bloq, BloqDocSpec, BQUInt, QDType, Register, Signature +from qualtran.bloqs.arithmetic import Xor +from qualtran.bloqs.arithmetic.sorting import Comparator +from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.symbolics import ceil, log2, SymbolicInt + + +@frozen +class SortInPlace(Bloq): + r"""Sort a list of $\ell$ numbers in place using $\ell \log \ell$ ancilla bits. + + Applies the map: + $$ + |x_1, x_2, \ldots, x_l\rangle + |0^{\ell \log \ell}\rangle + \mapsto + |x_{\pi_1}, x_{\pi_2}, \ldots, x_{\pi_\ell})\rangle + |\pi_1, \pi_2, \ldots, \pi_\ell\rangle + $$ + where $x_{\pi_1} \le x_{\pi_2} \ldots \le x_{\pi_\ell}$ is the sorted list, + and the ancilla are entangled. + + To apply this, we first use any sorting algorithm to output the sorted list + in a clean register. And then use the following algorithm from Lemma 4.12 of Ref [1] + that applies the map: + + $$ + |x_1, ..., x_l\rangle|x_{\pi(1)}, ..., x_{\pi(l)})\rangle + \mapsto + |x_l, ..., x_l\rangle|\pi(1), ..., \pi(l))\rangle + $$ + + where $x_i \in [n]$ and $\pi(i) \in [l]$. + This second algorithm (Lemma 4.12) has two steps, each with $l^2$ comparisons: + 1. compute `pi(1) ... pi(l)` given `x_1 ... x_l` and `x_{pi(1)} ... x{pi(l)}`. + 1. (un)compute `x_{pi(1)} ... x{pi(l)}` using `pi(1) ... pi(l)` given `x_1 ... x_l`. + + Args: + l: number of elements in the list + dtype: type of each element to store `[n]`. + + Registers: + input: the entire input as a single register + ancilla (RIGHT): the generated (entangled) register storing `pi`. + + References: + [Quartic quantum speedups for planted inference](https://arxiv.org/abs/2406.19378v1) + Lemma 4.12. Eq. 122. + """ + + l: SymbolicInt + dtype: QDType + + @property + def signature(self) -> 'Signature': + return Signature( + [ + Register('xs', self.dtype, shape=(self.l,)), + Register('pi', self.index_dtype, shape=(self.l,)), + ] + ) + + @property + def index_dtype(self) -> QDType: + """dtype to represent an index in range `[l]`""" + bitsize = ceil(log2(self.l)) + return BQUInt(bitsize, self.l) + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: + compare = Comparator(self.dtype.num_qubits) + n_ops = 3 * self.l**2 + + counts = Counter[Bloq]() + + counts[compare] += n_ops + counts[compare.adjoint()] += n_ops + counts[Xor(self.dtype)] += n_ops + + return counts + + +_SORT_IN_PLACE_DOC = BloqDocSpec(bloq_cls=SortInPlace, examples=[]) diff --git a/qualtran/bloqs/arithmetic/lists/symmetric_difference.py b/qualtran/bloqs/arithmetic/lists/symmetric_difference.py new file mode 100644 index 000000000..72d13f822 --- /dev/null +++ b/qualtran/bloqs/arithmetic/lists/symmetric_difference.py @@ -0,0 +1,135 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import Counter + +from attrs import frozen + +from qualtran import Bloq, bloq_example, BloqDocSpec, QBit, QDType, QUInt, Register, Signature +from qualtran.bloqs.arithmetic import Equals, EqualsAConstant, HammingWeightCompute, Xor +from qualtran.bloqs.arithmetic.sorting import BitonicMerge +from qualtran.bloqs.basic_gates import CNOT +from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator +from qualtran.symbolics import bit_length, is_symbolic, SymbolicInt + + +@frozen +class SymmetricDifference(Bloq): + r"""Given two sorted sets $S, T$ of unique elements, compute their symmetric difference. + + This accepts an integer `n_diff`, and marks a flag qubit if the symmetric difference + set is of size exactly `n_diff`. If the flag is marked (1), then the output of `n_diff` + numbers is the symmetric difference, otherwise it may be arbitrary. + + Args: + n_lhs: number of elements in $S$ + n_rhs: number of elements in $T$ + n_diff: expected number of elements in the difference $S \Delta T$. + dtype: type of each element. + + Registers: + S: list of `n_lhs` numbers. + T: list of `n_rhs` numbers. + diff: output register of `n_diff` numbers. + flag: 1 if there are duplicates, 0 if all are unique. + + References: + [Quartic quantum speedups for planted inference](https://arxiv.org/abs/2406.19378v1) + Theorem 4.17, proof para 3, page 38. + """ + + n_lhs: SymbolicInt + n_rhs: SymbolicInt + n_diff: SymbolicInt + dtype: QDType + + def __attrs_post_init__(self): + if not is_symbolic(self.n_lhs, self.n_rhs): + assert self.n_lhs >= self.n_rhs, "lhs must be the larger set" + + @property + def signature(self) -> 'Signature': + return Signature( + [ + Register('S', self.dtype, shape=(self.n_lhs,)), + Register('T', self.dtype, shape=(self.n_rhs,)), + Register('diff', self.dtype, shape=(self.n_diff,)), + Register('flag', QBit()), + ] + ) + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: + # the forward pass, i.e. all bloqs that must be uncomputed + counts_forward = Counter[Bloq]() + + # merge the lists + counts_forward[BitonicMerge(self.n_lhs, self.dtype.num_qubits)] += 1 + # compare adjacents + counts_forward[Equals(self.dtype)] += self.n_lhs + self.n_rhs - 1 + # compute number of equal adjacents + counts_forward[HammingWeightCompute(self.n_lhs + self.n_rhs - 1)] += 1 + # check: 2 * n_equal = n_lhs + n_rhs - n_diff + # (note: the above eq holds as we assume all input elements are unique) + counts_forward[ + EqualsAConstant( + bit_length(self.n_lhs + self.n_rhs - 1), + (self.n_lhs + self.n_rhs - self.n_diff) // 2, + ) + ] += 1 + + # all bloqs + counts = Counter[Bloq]() + + # copy the first n_diff numbers and flag + counts[Xor(self.dtype)] += self.n_diff + counts[CNOT()] += 1 + + for bloq, n in counts_forward.items(): + counts[bloq] += n + counts[bloq.adjoint()] += n + + return counts + + +@bloq_example +def _symm_diff() -> SymmetricDifference: + dtype = QUInt(4) + symm_diff = SymmetricDifference(n_lhs=4, n_rhs=2, n_diff=4, dtype=dtype) + return symm_diff + + +@bloq_example +def _symm_diff_symb() -> SymmetricDifference: + import sympy + + from qualtran.symbolics import bit_length + + n, k, c = sympy.symbols("n k c", positive=True, integer=True) + dtype = QUInt(bit_length(n - 1)) + symm_diff_symb = SymmetricDifference(n_lhs=c * k, n_rhs=k, n_diff=c * k, dtype=dtype) + return symm_diff_symb + + +@bloq_example +def _symm_diff_equal_size_symb() -> SymmetricDifference: + import sympy + + from qualtran.symbolics import bit_length + + n, k, c = sympy.symbols("n k c", positive=True, integer=True) + dtype = QUInt(bit_length(n - 1)) + symm_diff_equal_size_symb = SymmetricDifference(n_lhs=c * k, n_rhs=c * k, n_diff=k, dtype=dtype) + return symm_diff_equal_size_symb + + +_SYMMETRIC_DIFFERENCE_DOC = BloqDocSpec(bloq_cls=SymmetricDifference, examples=[_symm_diff]) diff --git a/qualtran/bloqs/arithmetic/lists/symmetric_difference_test.py b/qualtran/bloqs/arithmetic/lists/symmetric_difference_test.py new file mode 100644 index 000000000..0ad5a356b --- /dev/null +++ b/qualtran/bloqs/arithmetic/lists/symmetric_difference_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import ANY + +import pytest + +import qualtran.testing as qlt_testing +from qualtran.bloqs.arithmetic.lists.symmetric_difference import ( + _symm_diff, + _symm_diff_equal_size_symb, + _symm_diff_symb, +) +from qualtran.resource_counting import big_O, GateCounts, get_cost_value, QECGatesCost +from qualtran.symbolics import ceil, log2 + + +@pytest.mark.parametrize("bloq_ex", [_symm_diff, _symm_diff_symb, _symm_diff_equal_size_symb]) +def test_examples(bloq_autotester, bloq_ex): + bloq_autotester(bloq_ex) + + +@pytest.mark.parametrize("bloq_ex", [_symm_diff_symb, _symm_diff_equal_size_symb]) +def test_cost(bloq_ex): + bloq = bloq_ex() + gc = get_cost_value(bloq, QECGatesCost()) + + l, r = bloq.n_lhs, bloq.n_rhs # assumption l >= r + logn = bloq.dtype.num_qubits + logl = ceil(log2(l)) + assert gc == GateCounts( + cswap=2 * l * logn * (logl + 1), + and_bloq=( + 2 * l * (2 * logn + 1) * (logl + 1) + + l + + r + + 2 * ((logn - 1) * (l + r - 1)) + + 2 * ceil(log2(l + r)) + - 4 + ), + clifford=ANY, + measurement=ANY, + ) + + # \tilde{O}(l log n) + # Page 38, Thm 4.17, proof para 3, 3rd last line. + assert gc.total_t_count() in big_O(l * logn * logl**2) + + +@pytest.mark.notebook +def test_notebook(): + qlt_testing.execute_notebook('lists') diff --git a/qualtran/bloqs/arithmetic/sorting.py b/qualtran/bloqs/arithmetic/sorting.py index d18db924e..a6e9937cc 100644 --- a/qualtran/bloqs/arithmetic/sorting.py +++ b/qualtran/bloqs/arithmetic/sorting.py @@ -23,6 +23,7 @@ bloq_example, BloqBuilder, BloqDocSpec, + DecomposeNotImplementedError, DecomposeTypeError, QBit, QUInt, @@ -222,8 +223,6 @@ def __attrs_post_init__(self): k = self.half_length if not is_symbolic(k): assert k >= 1, "length of input lists must be positive" - # TODO(#1090) support non-power-of-two input lengths - assert (k & (k - 1)) == 0, "length of input lists must be a power of 2" @cached_property def signature(self) -> 'Signature': @@ -249,14 +248,16 @@ def is_symbolic(self): def build_composite_bloq( self, bb: 'BloqBuilder', xs: 'SoquetT', ys: 'SoquetT' ) -> dict[str, 'SoquetT']: - if is_symbolic(self.half_length): + k = self.half_length + if is_symbolic(k): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") + if (k & (k - 1)) != 0: + # TODO(#1090) support non-power-of-two input lengths + raise DecomposeNotImplementedError("length of input lists must be a power of 2") assert isinstance(xs, np.ndarray) assert isinstance(ys, np.ndarray) - k = self.half_length - first_round_junk = [] for i in range(k): xs[i], ys[k - 1 - i], anc = bb.add(Comparator(self.bitsize), a=xs[i], b=ys[k - 1 - i]) diff --git a/qualtran/conftest.py b/qualtran/conftest.py index 28d79a89f..439ae6145 100644 --- a/qualtran/conftest.py +++ b/qualtran/conftest.py @@ -141,6 +141,9 @@ def assert_bloq_example_serializes_for_pytest(bloq_ex: BloqExample): 'ctrl_on_symbolic_cv', # cannot serialize Shaped 'ctrl_on_symbolic_cv_multi', # cannot serialize Shaped 'ctrl_on_symbolic_n_ctrls', # cannot serialize Shaped + 'has_duplicates_symb_len', # cannot serialize HasLength + 'symm_diff_symb', # round trip fail: sympy assumptions not serialized + 'symm_diff_equal_size_symb', # round trip fail: sympy assumptions not serialized ]: pytest.xfail("Skipping serialization test for bloq examples that cannot yet be serialized.") diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index cd01026fb..0075dd2b4 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -23,6 +23,7 @@ import qualtran.bloqs.arithmetic.conversions.ones_complement_to_twos_complement import qualtran.bloqs.arithmetic.conversions.sign_extension import qualtran.bloqs.arithmetic.hamming_weight +import qualtran.bloqs.arithmetic.lists import qualtran.bloqs.arithmetic.multiplication import qualtran.bloqs.arithmetic.negate import qualtran.bloqs.arithmetic.permutation @@ -190,6 +191,9 @@ "qualtran.bloqs.arithmetic.conversions.sign_extension.SignExtend": qualtran.bloqs.arithmetic.conversions.sign_extension.SignExtend, "qualtran.bloqs.arithmetic.conversions.sign_extension.SignTruncate": qualtran.bloqs.arithmetic.conversions.sign_extension.SignTruncate, "qualtran.bloqs.arithmetic.hamming_weight.HammingWeightCompute": qualtran.bloqs.arithmetic.hamming_weight.HammingWeightCompute, + "qualtran.bloqs.arithmetic.lists.has_duplicates.HasDuplicates": qualtran.bloqs.arithmetic.lists.has_duplicates.HasDuplicates, + "qualtran.bloqs.arithmetic.lists.sort_in_place.SortInPlace": qualtran.bloqs.arithmetic.lists.sort_in_place.SortInPlace, + "qualtran.bloqs.arithmetic.lists.symmetric_difference.SymmetricDifference": qualtran.bloqs.arithmetic.lists.symmetric_difference.SymmetricDifference, "qualtran.bloqs.arithmetic.multiplication.InvertRealNumber": qualtran.bloqs.arithmetic.multiplication.InvertRealNumber, "qualtran.bloqs.arithmetic.multiplication.MultiplyTwoReals": qualtran.bloqs.arithmetic.multiplication.MultiplyTwoReals, "qualtran.bloqs.arithmetic.multiplication.PlusEqualProduct": qualtran.bloqs.arithmetic.multiplication.PlusEqualProduct,