diff --git a/qualtran/bloqs/arithmetic.py b/qualtran/bloqs/arithmetic.py index 9882c4f6f..84abbe762 100644 --- a/qualtran/bloqs/arithmetic.py +++ b/qualtran/bloqs/arithmetic.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property +from typing import Dict, Tuple, Union + +import cirq from attrs import frozen +from cirq_ft import LessThanEqualGate as CirqLessThanEqual +from cirq_ft import LessThanGate as CirqLessThanGate from cirq_ft import TComplexity -from qualtran import Bloq, Register, Signature +from qualtran import Bloq, CompositeBloq, Register, Signature +from qualtran.cirq_interop import CirqQuregT, decompose_from_cirq_op @frozen @@ -212,3 +219,74 @@ def t_complexity(self): # See: https://github.com/quantumlib/cirq-qubitization/issues/219 # See: https://github.com/quantumlib/cirq-qubitization/issues/217 return TComplexity(t=8 * self.bitsize) + + +@frozen +class LessThanEqual(Bloq): + r"""Implements $U|x,y,z\rangle = |x, y, z \oplus {x \le y}\rangle$. + + Args: + x_bitsize: bitsize of x register. + y_bitsize: bitsize of y register. + + Registers: + - x, y: Registers to compare against eachother. + - z: Register to hold result of comparison. + """ + + x_bitsize: int + y_bitsize: int + + @cached_property + def signature(self) -> Signature: + return Signature( + [ + Register("x", bitsize=self.x_bitsize), + Register("y", bitsize=self.y_bitsize), + Register("z", bitsize=1), + ] + ) + + def decompose_bloq(self) -> 'CompositeBloq': + return decompose_from_cirq_op(self) + + def as_cirq_op( + self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' + ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: + less_than = CirqLessThanEqual(x_bitsize=self.x_bitsize, y_bitsize=self.y_bitsize) + x = cirq_quregs['x'] + y = cirq_quregs['y'] + z = cirq_quregs['z'] + return (less_than.on(*x, *y, *z), cirq_quregs) + + +@frozen +class LessThanConstant(Bloq): + r"""Implements $U_a|x\rangle = U_a|x\rangle|z\rangle = |x\rangle |z ^ (x < a)\rangle" + + Args: + bitsize: bitsize of x register. + val: integer to compare x against (a above.) + + Registers: + - x: Registers to compare against val. + - z: Register to hold result of comparison. + """ + + bitsize: int + val: int + + @cached_property + def signature(self) -> Signature: + return Signature.build(x=self.bitsize, z=1) + + def decompose_bloq(self) -> 'CompositeBloq': + return decompose_from_cirq_op(self) + + def as_cirq_op( + self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' + ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: + less_than = CirqLessThanGate(bitsize=self.bitsize, less_than_val=self.val) + x = cirq_quregs['x'] + z = cirq_quregs['z'] + return (less_than.on(*x, *z), cirq_quregs) diff --git a/qualtran/bloqs/arithmetic_test.py b/qualtran/bloqs/arithmetic_test.py index 6afdaead6..bb60208cf 100644 --- a/qualtran/bloqs/arithmetic_test.py +++ b/qualtran/bloqs/arithmetic_test.py @@ -12,8 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from cirq_ft.algos import LessThanEqualGate as CirqLessThanEquals +from cirq_ft.algos import LessThanGate as CirqLessThanConstant +from cirq_ft.infra import t_complexity + +import qualtran.testing as qlt_testing from qualtran import BloqBuilder, Register -from qualtran.bloqs.arithmetic import Add, GreaterThan, Product, Square, SumOfSquares +from qualtran.bloqs.arithmetic import ( + Add, + GreaterThan, + LessThanConstant, + LessThanEqual, + Product, + Square, + SumOfSquares, +) from qualtran.testing import execute_notebook @@ -96,5 +109,19 @@ def test_greater_than(): cbloq = bb.finalize(a=q0, b=q1, result=anc) +def test_less_than_equal(): + lte = LessThanEqual(5, 5) + qlt_testing.assert_valid_bloq_decomposition(lte) + cirq_lte = CirqLessThanEquals(5, 5) + assert lte.decompose_bloq().t_complexity() == t_complexity(cirq_lte) + + +def test_less_than_constant(): + ltc = LessThanConstant(5, 7) + qlt_testing.assert_valid_bloq_decomposition(ltc) + cirq_ltc = CirqLessThanConstant(5, 7) + assert ltc.decompose_bloq().t_complexity() == t_complexity(cirq_ltc) + + def test_notebook(): execute_notebook('arithmetic')