From a1df089a450c62ca55f957629a884f0d31bbf761 Mon Sep 17 00:00:00 2001 From: Frankie Papa Date: Mon, 23 Sep 2024 23:44:18 -0700 Subject: [PATCH] Add Equals() bloq (#1411) --- dev_tools/autogenerate-bloqs-notebooks-v2.py | 1 + qualtran/bloqs/arithmetic/__init__.py | 1 + qualtran/bloqs/arithmetic/comparison.ipynb | 102 +++++++++++++++++++ qualtran/bloqs/arithmetic/comparison.py | 83 ++++++++++++++- qualtran/bloqs/arithmetic/comparison_test.py | 16 ++- qualtran/serialization/resolver_dict.py | 1 + 6 files changed, 202 insertions(+), 2 deletions(-) diff --git a/dev_tools/autogenerate-bloqs-notebooks-v2.py b/dev_tools/autogenerate-bloqs-notebooks-v2.py index 64eed620c..54e26e140 100644 --- a/dev_tools/autogenerate-bloqs-notebooks-v2.py +++ b/dev_tools/autogenerate-bloqs-notebooks-v2.py @@ -443,6 +443,7 @@ qualtran.bloqs.arithmetic.comparison._LT_K_DOC, qualtran.bloqs.arithmetic.comparison._GREATER_THAN_DOC, qualtran.bloqs.arithmetic.comparison._GREATER_THAN_K_DOC, + qualtran.bloqs.arithmetic.comparison._EQUALS_DOC, qualtran.bloqs.arithmetic.comparison._EQUALS_K_DOC, qualtran.bloqs.arithmetic.comparison._BI_QUBITS_MIXER_DOC, qualtran.bloqs.arithmetic.comparison._SQ_CMP_DOC, diff --git a/qualtran/bloqs/arithmetic/__init__.py b/qualtran/bloqs/arithmetic/__init__.py index 694b36bd2..533d0ee0c 100644 --- a/qualtran/bloqs/arithmetic/__init__.py +++ b/qualtran/bloqs/arithmetic/__init__.py @@ -17,6 +17,7 @@ from qualtran.bloqs.arithmetic.comparison import ( BiQubitsMixer, CLinearDepthGreaterThan, + Equals, EqualsAConstant, GreaterThan, GreaterThanConstant, diff --git a/qualtran/bloqs/arithmetic/comparison.ipynb b/qualtran/bloqs/arithmetic/comparison.ipynb index 7001178f8..537c80f36 100644 --- a/qualtran/bloqs/arithmetic/comparison.ipynb +++ b/qualtran/bloqs/arithmetic/comparison.ipynb @@ -918,6 +918,108 @@ "show_call_graph(clineardepthgreaterthan_example_g)\n", "show_counts_sigma(clineardepthgreaterthan_example_sigma)" ] + }, + { + "cell_type": "markdown", + "id": "126df51e", + "metadata": { + "cq.autogen": "Equals.bloq_doc.md" + }, + "source": [ + "## `Equals`\n", + "Implements |x>|y>|t> => |x>|y>|t ⨁ (x = y)> using $n-1$ Toffoli gates.\n", + "\n", + "#### Parameters\n", + " - `dtype`: Data type of the input registers `x` and `y`. \n", + "\n", + "#### Registers\n", + " - `x`: First input register.\n", + " - `y`: Second input register.\n", + " - `target`: Register to hold result of comparison.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef979281", + "metadata": { + "cq.autogen": "Equals.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import Equals" + ] + }, + { + "cell_type": "markdown", + "id": "2ceb3c0c", + "metadata": { + "cq.autogen": "Equals.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b311ac0f", + "metadata": { + "cq.autogen": "Equals.equals" + }, + "outputs": [], + "source": [ + "equals = Equals(QUInt(4))" + ] + }, + { + "cell_type": "markdown", + "id": "cc0f1db8", + "metadata": { + "cq.autogen": "Equals.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "088efd0f", + "metadata": { + "cq.autogen": "Equals.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([equals],\n", + " ['`equals`'])" + ] + }, + { + "cell_type": "markdown", + "id": "9f8b5abc", + "metadata": { + "cq.autogen": "Equals.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7474f07a", + "metadata": { + "cq.autogen": "Equals.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "equals_g, equals_sigma = equals.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(equals_g)\n", + "show_counts_sigma(equals_sigma)" + ] } ], "metadata": { diff --git a/qualtran/bloqs/arithmetic/comparison.py b/qualtran/bloqs/arithmetic/comparison.py index ff4986262..3c7cc76b2 100644 --- a/qualtran/bloqs/arithmetic/comparison.py +++ b/qualtran/bloqs/arithmetic/comparison.py @@ -31,6 +31,7 @@ GateWithRegisters, QAny, QBit, + QDType, QInt, QMontgomeryUInt, QUInt, @@ -41,7 +42,7 @@ SoquetT, ) from qualtran.bloqs.arithmetic.addition import OutOfPlaceAdder -from qualtran.bloqs.arithmetic.bitwise import BitwiseNot +from qualtran.bloqs.arithmetic.bitwise import BitwiseNot, Xor from qualtran.bloqs.arithmetic.conversions.sign_extension import SignExtend from qualtran.bloqs.basic_gates import CNOT, XGate from qualtran.bloqs.bookkeeping import Cast @@ -947,6 +948,86 @@ def _gt_k() -> GreaterThanConstant: _GREATER_THAN_K_DOC = BloqDocSpec(bloq_cls=GreaterThanConstant, examples=[_gt_k]) +@frozen +class Equals(Bloq): + r"""Implements |x>|y>|t> => |x>|y>|t ⨁ (x = y)> using $n-1$ Toffoli gates. + + Args: + dtype: Data type of the input registers `x` and `y`. + + Registers: + x: First input register. + y: Second input register. + target: Register to hold result of comparison. + """ + + dtype: QDType + + @cached_property + def signature(self) -> Signature: + return Signature.build_from_dtypes(x=self.dtype, y=self.dtype, target=QBit()) + + @cached_property + def bitsize(self) -> SymbolicInt: + return self.dtype.num_qubits + + def is_symbolic(self): + return is_symbolic(self.dtype) + + def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol: + if reg is None: + return Text("") + if reg.name == 'x': + return TextBox("In(x)") + if reg.name == 'y': + return TextBox("In(y)") + elif reg.name == 'target': + return TextBox("⨁(x = y)") + raise ValueError(f'Unknown register symbol {reg.name}') + + def build_composite_bloq( + self, bb: 'BloqBuilder', x: 'Soquet', y: 'Soquet', target: 'Soquet' + ) -> Dict[str, 'SoquetT']: + cvs: Union[list[int], HasLength] + if isinstance(self.bitsize, int): + cvs = [0] * self.bitsize + else: + cvs = HasLength(self.bitsize) + + x, y = bb.add(Xor(self.dtype), x=x, y=y) + y_split = bb.split(y) + y_split, target = bb.add(MultiControlX(cvs=cvs), controls=y_split, target=target) + y = bb.join(y_split, self.dtype) + x, y = bb.add(Xor(self.dtype), x=x, y=y) + + return {'x': x, 'y': y, 'target': target} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': + cvs: Union[list[int], HasLength] + if isinstance(self.bitsize, int): + cvs = [0] * self.bitsize + else: + cvs = HasLength(self.bitsize) + return { + Xor(self.dtype): 2, + MultiControlX(cvs=cvs): 1, + MultiAnd(cvs=cvs).adjoint(): 1, + CNOT(): 1, + } + + def on_classical_vals(self, x: int, y: int, target: int) -> Dict[str, 'ClassicalValT']: + return {'x': x, 'y': y, 'target': target ^ (x == y)} + + +@bloq_example +def _equals() -> Equals: + equals = Equals(QUInt(4)) + return equals + + +_EQUALS_DOC = BloqDocSpec(bloq_cls=Equals, examples=[_equals]) + + @frozen class EqualsAConstant(Bloq): r"""Implements $U_a|x\rangle|z\rangle = |x\rangle |z \oplus (x = a)\rangle$ diff --git a/qualtran/bloqs/arithmetic/comparison_test.py b/qualtran/bloqs/arithmetic/comparison_test.py index 368d45d3c..23267e1d6 100644 --- a/qualtran/bloqs/arithmetic/comparison_test.py +++ b/qualtran/bloqs/arithmetic/comparison_test.py @@ -20,16 +20,18 @@ import sympy import qualtran.testing as qlt_testing -from qualtran import BloqBuilder, QInt, QMontgomeryUInt, QUInt +from qualtran import BloqBuilder, QBit, QInt, QMontgomeryUInt, QUInt from qualtran.bloqs.arithmetic.comparison import ( _clineardepthgreaterthan_example, _eq_k, + _equals, _greater_than, _gt_k, _leq_symb, _lt_k_symb, BiQubitsMixer, CLinearDepthGreaterThan, + Equals, EqualsAConstant, GreaterThan, GreaterThanConstant, @@ -63,6 +65,10 @@ def test_leq_symb(bloq_autotester): bloq_autotester(_leq_symb) +def test_equals(bloq_autotester): + bloq_autotester(_equals) + + def test_eq_k(bloq_autotester): bloq_autotester(_eq_k) @@ -296,6 +302,14 @@ def test_greater_than_constant(): ) +@pytest.mark.parametrize('dtype', [QBit(), QUInt(2), QInt(3), QMontgomeryUInt(4), QUInt(5)]) +def test_classical_equals(dtype): + bloq = Equals(dtype) + qlt_testing.assert_consistent_classical_action( + bloq, x=dtype.get_classical_domain(), y=dtype.get_classical_domain(), target=range(2) + ) + + def test_equals_a_constant(): bb = BloqBuilder() bitsize = 5 diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index 45518ad08..dad2907ea 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -167,6 +167,7 @@ "qualtran.bloqs.arithmetic.bitwise.Xor": qualtran.bloqs.arithmetic.bitwise.Xor, "qualtran.bloqs.arithmetic.bitwise.XorK": qualtran.bloqs.arithmetic.bitwise.XorK, "qualtran.bloqs.arithmetic.comparison.BiQubitsMixer": qualtran.bloqs.arithmetic.comparison.BiQubitsMixer, + "qualtran.bloqs.arithmetic.comparison.Equals": qualtran.bloqs.arithmetic.comparison.Equals, "qualtran.bloqs.arithmetic.comparison.EqualsAConstant": qualtran.bloqs.arithmetic.comparison.EqualsAConstant, "qualtran.bloqs.arithmetic.comparison.GreaterThan": qualtran.bloqs.arithmetic.comparison.GreaterThan, "qualtran.bloqs.arithmetic.comparison.GreaterThanConstant": qualtran.bloqs.arithmetic.comparison.GreaterThanConstant,