Skip to content

Commit

Permalink
Add Equals() bloq (#1411)
Browse files Browse the repository at this point in the history
  • Loading branch information
fpapa250 authored Sep 24, 2024
1 parent 21ef526 commit a1df089
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 2 deletions.
1 change: 1 addition & 0 deletions dev_tools/autogenerate-bloqs-notebooks-v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@
qualtran.bloqs.arithmetic.comparison._LT_K_DOC,
qualtran.bloqs.arithmetic.comparison._GREATER_THAN_DOC,
qualtran.bloqs.arithmetic.comparison._GREATER_THAN_K_DOC,
qualtran.bloqs.arithmetic.comparison._EQUALS_DOC,
qualtran.bloqs.arithmetic.comparison._EQUALS_K_DOC,
qualtran.bloqs.arithmetic.comparison._BI_QUBITS_MIXER_DOC,
qualtran.bloqs.arithmetic.comparison._SQ_CMP_DOC,
Expand Down
1 change: 1 addition & 0 deletions qualtran/bloqs/arithmetic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from qualtran.bloqs.arithmetic.comparison import (
BiQubitsMixer,
CLinearDepthGreaterThan,
Equals,
EqualsAConstant,
GreaterThan,
GreaterThanConstant,
Expand Down
102 changes: 102 additions & 0 deletions qualtran/bloqs/arithmetic/comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,108 @@
"show_call_graph(clineardepthgreaterthan_example_g)\n",
"show_counts_sigma(clineardepthgreaterthan_example_sigma)"
]
},
{
"cell_type": "markdown",
"id": "126df51e",
"metadata": {
"cq.autogen": "Equals.bloq_doc.md"
},
"source": [
"## `Equals`\n",
"Implements |x>|y>|t> => |x>|y>|t ⨁ (x = y)> using $n-1$ Toffoli gates.\n",
"\n",
"#### Parameters\n",
" - `dtype`: Data type of the input registers `x` and `y`. \n",
"\n",
"#### Registers\n",
" - `x`: First input register.\n",
" - `y`: Second input register.\n",
" - `target`: Register to hold result of comparison.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef979281",
"metadata": {
"cq.autogen": "Equals.bloq_doc.py"
},
"outputs": [],
"source": [
"from qualtran.bloqs.arithmetic import Equals"
]
},
{
"cell_type": "markdown",
"id": "2ceb3c0c",
"metadata": {
"cq.autogen": "Equals.example_instances.md"
},
"source": [
"### Example Instances"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b311ac0f",
"metadata": {
"cq.autogen": "Equals.equals"
},
"outputs": [],
"source": [
"equals = Equals(QUInt(4))"
]
},
{
"cell_type": "markdown",
"id": "cc0f1db8",
"metadata": {
"cq.autogen": "Equals.graphical_signature.md"
},
"source": [
"#### Graphical Signature"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "088efd0f",
"metadata": {
"cq.autogen": "Equals.graphical_signature.py"
},
"outputs": [],
"source": [
"from qualtran.drawing import show_bloqs\n",
"show_bloqs([equals],\n",
" ['`equals`'])"
]
},
{
"cell_type": "markdown",
"id": "9f8b5abc",
"metadata": {
"cq.autogen": "Equals.call_graph.md"
},
"source": [
"### Call Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7474f07a",
"metadata": {
"cq.autogen": "Equals.call_graph.py"
},
"outputs": [],
"source": [
"from qualtran.resource_counting.generalizers import ignore_split_join\n",
"equals_g, equals_sigma = equals.call_graph(max_depth=1, generalizer=ignore_split_join)\n",
"show_call_graph(equals_g)\n",
"show_counts_sigma(equals_sigma)"
]
}
],
"metadata": {
Expand Down
83 changes: 82 additions & 1 deletion qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
GateWithRegisters,
QAny,
QBit,
QDType,
QInt,
QMontgomeryUInt,
QUInt,
Expand All @@ -41,7 +42,7 @@
SoquetT,
)
from qualtran.bloqs.arithmetic.addition import OutOfPlaceAdder
from qualtran.bloqs.arithmetic.bitwise import BitwiseNot
from qualtran.bloqs.arithmetic.bitwise import BitwiseNot, Xor
from qualtran.bloqs.arithmetic.conversions.sign_extension import SignExtend
from qualtran.bloqs.basic_gates import CNOT, XGate
from qualtran.bloqs.bookkeeping import Cast
Expand Down Expand Up @@ -947,6 +948,86 @@ def _gt_k() -> GreaterThanConstant:
_GREATER_THAN_K_DOC = BloqDocSpec(bloq_cls=GreaterThanConstant, examples=[_gt_k])


@frozen
class Equals(Bloq):
r"""Implements |x>|y>|t> => |x>|y>|t ⨁ (x = y)> using $n-1$ Toffoli gates.
Args:
dtype: Data type of the input registers `x` and `y`.
Registers:
x: First input register.
y: Second input register.
target: Register to hold result of comparison.
"""

dtype: QDType

@cached_property
def signature(self) -> Signature:
return Signature.build_from_dtypes(x=self.dtype, y=self.dtype, target=QBit())

@cached_property
def bitsize(self) -> SymbolicInt:
return self.dtype.num_qubits

def is_symbolic(self):
return is_symbolic(self.dtype)

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> WireSymbol:
if reg is None:
return Text("")
if reg.name == 'x':
return TextBox("In(x)")
if reg.name == 'y':
return TextBox("In(y)")
elif reg.name == 'target':
return TextBox("⨁(x = y)")
raise ValueError(f'Unknown register symbol {reg.name}')

def build_composite_bloq(
self, bb: 'BloqBuilder', x: 'Soquet', y: 'Soquet', target: 'Soquet'
) -> Dict[str, 'SoquetT']:
cvs: Union[list[int], HasLength]
if isinstance(self.bitsize, int):
cvs = [0] * self.bitsize
else:
cvs = HasLength(self.bitsize)

x, y = bb.add(Xor(self.dtype), x=x, y=y)
y_split = bb.split(y)
y_split, target = bb.add(MultiControlX(cvs=cvs), controls=y_split, target=target)
y = bb.join(y_split, self.dtype)
x, y = bb.add(Xor(self.dtype), x=x, y=y)

return {'x': x, 'y': y, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
cvs: Union[list[int], HasLength]
if isinstance(self.bitsize, int):
cvs = [0] * self.bitsize
else:
cvs = HasLength(self.bitsize)
return {
Xor(self.dtype): 2,
MultiControlX(cvs=cvs): 1,
MultiAnd(cvs=cvs).adjoint(): 1,
CNOT(): 1,
}

def on_classical_vals(self, x: int, y: int, target: int) -> Dict[str, 'ClassicalValT']:
return {'x': x, 'y': y, 'target': target ^ (x == y)}


@bloq_example
def _equals() -> Equals:
equals = Equals(QUInt(4))
return equals


_EQUALS_DOC = BloqDocSpec(bloq_cls=Equals, examples=[_equals])


@frozen
class EqualsAConstant(Bloq):
r"""Implements $U_a|x\rangle|z\rangle = |x\rangle |z \oplus (x = a)\rangle$
Expand Down
16 changes: 15 additions & 1 deletion qualtran/bloqs/arithmetic/comparison_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
import sympy

import qualtran.testing as qlt_testing
from qualtran import BloqBuilder, QInt, QMontgomeryUInt, QUInt
from qualtran import BloqBuilder, QBit, QInt, QMontgomeryUInt, QUInt
from qualtran.bloqs.arithmetic.comparison import (
_clineardepthgreaterthan_example,
_eq_k,
_equals,
_greater_than,
_gt_k,
_leq_symb,
_lt_k_symb,
BiQubitsMixer,
CLinearDepthGreaterThan,
Equals,
EqualsAConstant,
GreaterThan,
GreaterThanConstant,
Expand Down Expand Up @@ -63,6 +65,10 @@ def test_leq_symb(bloq_autotester):
bloq_autotester(_leq_symb)


def test_equals(bloq_autotester):
bloq_autotester(_equals)


def test_eq_k(bloq_autotester):
bloq_autotester(_eq_k)

Expand Down Expand Up @@ -296,6 +302,14 @@ def test_greater_than_constant():
)


@pytest.mark.parametrize('dtype', [QBit(), QUInt(2), QInt(3), QMontgomeryUInt(4), QUInt(5)])
def test_classical_equals(dtype):
bloq = Equals(dtype)
qlt_testing.assert_consistent_classical_action(
bloq, x=dtype.get_classical_domain(), y=dtype.get_classical_domain(), target=range(2)
)


def test_equals_a_constant():
bb = BloqBuilder()
bitsize = 5
Expand Down
1 change: 1 addition & 0 deletions qualtran/serialization/resolver_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
"qualtran.bloqs.arithmetic.bitwise.Xor": qualtran.bloqs.arithmetic.bitwise.Xor,
"qualtran.bloqs.arithmetic.bitwise.XorK": qualtran.bloqs.arithmetic.bitwise.XorK,
"qualtran.bloqs.arithmetic.comparison.BiQubitsMixer": qualtran.bloqs.arithmetic.comparison.BiQubitsMixer,
"qualtran.bloqs.arithmetic.comparison.Equals": qualtran.bloqs.arithmetic.comparison.Equals,
"qualtran.bloqs.arithmetic.comparison.EqualsAConstant": qualtran.bloqs.arithmetic.comparison.EqualsAConstant,
"qualtran.bloqs.arithmetic.comparison.GreaterThan": qualtran.bloqs.arithmetic.comparison.GreaterThan,
"qualtran.bloqs.arithmetic.comparison.GreaterThanConstant": qualtran.bloqs.arithmetic.comparison.GreaterThanConstant,
Expand Down

0 comments on commit a1df089

Please sign in to comment.