diff --git a/qualtran/bloqs/arithmetic.py b/qualtran/bloqs/arithmetic.py index 5c2bd5bd1..c47406b7b 100644 --- a/qualtran/bloqs/arithmetic.py +++ b/qualtran/bloqs/arithmetic.py @@ -13,15 +13,29 @@ # limitations under the License. from functools import cached_property -from typing import Dict, Optional, Set, Tuple, TYPE_CHECKING, Union - +from typing import ( + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import cirq import sympy -from attrs import frozen -from cirq_ft.algos.arithmetic_gates import LessThanEqualGate, LessThanGate +from attrs import field, frozen +from numpy.typing import NDArray -from qualtran import Bloq, Register, Side, Signature +from qualtran import Bloq, GateWithRegisters, Register, Side, Signature +from qualtran.bloqs.and_bloq import And, MultiAnd from qualtran.bloqs.basic_gates import TGate from qualtran.bloqs.util_bloqs import ArbitraryClifford +from qualtran.cirq_interop.bit_tools import iter_bits from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity if TYPE_CHECKING: @@ -30,14 +44,427 @@ @frozen -class Add(Bloq): +class LessThanConstant(GateWithRegisters, cirq.ArithmeticGate): + """Applies U_a|x>|z> = |x> |z ^ (x < a)>""" + + bitsize: int + less_than_val: int + + @cached_property + def signature(self) -> Signature: + return Signature.build(x=self.bitsize, target=1) + + def pretty_name(self) -> str: + return f'x < {self.less_than_val}' + + def registers(self) -> Sequence[Union[int, Sequence[int]]]: + return [2] * self.bitsize, self.less_than_val, [2] + + def with_registers(self, *new_registers) -> "LessThanConstant": + return LessThanConstant(len(new_registers[0]), new_registers[1]) + + def apply(self, *register_vals: int) -> Union[int, Iterable[int]]: + input_val, less_than_val, target_register_val = register_vals + return input_val, less_than_val, target_register_val ^ (input_val < less_than_val) + + def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: + wire_symbols = ["In(x)"] * self.bitsize + wire_symbols += [f'+(x < {self.less_than_val})'] + return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) + + def __pow__(self, power: int): + if power in [1, -1]: + return self + return NotImplemented # pragma: no cover + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + """Decomposes the gate into 4N And and And† operations for a T complexity of 4N. + + The decomposition proceeds from the most significant qubit -bit 0- to the least significant + qubit while maintaining whether the qubit sequence is equal to the current prefix of the + `_val` or not. + + The bare-bone logic is: + 1. if ith bit of `_val` is 1 then: + - qubit sequence < `_val` iff they are equal so far and the current qubit is 0. + 2. update `are_equal`: `are_equal := are_equal and (ith bit == ith qubit).` + + This logic is implemented using $n$ `And` & `And†` operations and n+1 clean ancilla where + - one ancilla `are_equal` contains the equality informaiton + - ancilla[i] contain whether the qubits[:i+1] != (i+1)th prefix of `_val` + """ + qubits, (target,) = quregs['x'], quregs['target'] + # Trivial case, self._val is larger than any value the registers could represent + if self.less_than_val >= 2**self.bitsize: + yield cirq.X(target) + return + adjoint = [] + + (are_equal,) = context.qubit_manager.qalloc(1) + + # Initially our belief is that the numbers are equal. + yield cirq.X(are_equal) + adjoint.append(cirq.X(are_equal)) + + # Scan from left to right. + # `are_equal` contains whether the numbers are equal so far. + ancilla = context.qubit_manager.qalloc(self.bitsize) + for b, q, a in zip(iter_bits(self.less_than_val, self.bitsize), qubits, ancilla): + if b: + yield cirq.X(q) + adjoint.append(cirq.X(q)) + + # ancilla[i] = are_equal so far and (q_i != _val[i]). + # = equivalent to: Is the current prefix of qubits < prefix of `_val`? + yield And().on(q, are_equal, a) + adjoint.append(And(adjoint=True).on(q, are_equal, a)) + + # target ^= is the current prefix of the qubit sequence < current prefix of `_val` + yield cirq.CNOT(a, target) + + # If `a=1` (i.e. the current prefixes aren't equal) this means that + # `are_equal` is currently = 1 and q[i] != _val[i] so we need to flip `are_equal`. + yield cirq.CNOT(a, are_equal) + adjoint.append(cirq.CNOT(a, are_equal)) + else: + # ancilla[i] = are_equal so far and (q = 1). + yield And().on(q, are_equal, a) + adjoint.append(And(adjoint=True).on(q, are_equal, a)) + + # if `a=1` then we need to flip `are_equal` since this means that are_equal=1, + # b_i=0, q_i=1 => current prefixes are not equal so we need to flip `are_equal`. + yield cirq.CNOT(a, are_equal) + adjoint.append(cirq.CNOT(a, are_equal)) + + yield from reversed(adjoint) + + def _has_unitary_(self): + return True + + def _t_complexity_(self) -> TComplexity: + n = self.bitsize + if self.less_than_val >= 2**n: + return TComplexity(clifford=1) + return TComplexity(t=4 * n, clifford=15 * n + 3 * bin(self.less_than_val).count("1") + 2) + + +@frozen +class BiQubitsMixer(GateWithRegisters): + """Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + + This gates mixes the values in a way that preserves the result of comparison. + The signature being compared are 2-qubit signature where + x = 2*x_msb + x_lsb + y = 2*y_msb + y_lsb + The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb' + are the final values of x_lsb' and y_lsb'. + + Note that the ancilla qubits are used to reduce the T-count and the user + should clean the qubits at a later point in time with the adjoint gate. + See: https://github.com/quantumlib/Cirq/pull/6313 and + https://github.com/quantumlib/Qualtran/issues/389 + """ # pylint: disable=line-too-long + + adjoint: bool = False + + @cached_property + def signature(self) -> Signature: + one_side = Side.RIGHT if not self.adjoint else Side.LEFT + return Signature( + [Register('x', 2), Register('y', 2), Register('ancilla', 3, side=one_side)] + ) + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla'] + x_msb, x_lsb = x + y_msb, y_lsb = y + + def _cswap(control: cirq.Qid, a: cirq.Qid, b: cirq.Qid, aux: cirq.Qid) -> cirq.OP_TREE: + """A CSWAP with 4T complexity and whose adjoint has 0T complexity. + + A controlled SWAP that swaps `a` and `b` based on `control`. + It uses an extra qubit `aux` so that its adjoint would have + a T complexity of zero. + """ + yield cirq.CNOT(a, b) + yield And(adjoint=self.adjoint).on(control, b, aux) + yield cirq.CNOT(aux, a) + yield cirq.CNOT(a, b) + + def _decomposition(): + # computes the difference of x - y where + # x = 2*x_msb + x_lsb + # y = 2*y_msb + y_lsb + # And stores the result in x_lsb and y_lsb such that + # sign(x - y) = sign(x_lsb - y_lsb) + # This decomposition uses 3 ancilla qubits in order to have a + # T complexity of 8. + yield cirq.X(ancilla[0]) + yield cirq.CNOT(y_msb, x_msb) + yield cirq.CNOT(y_lsb, x_lsb) + yield from _cswap(x_msb, x_lsb, ancilla[0], ancilla[1]) + yield from _cswap(x_msb, y_msb, y_lsb, ancilla[2]) + yield cirq.CNOT(y_lsb, x_lsb) + + if self.adjoint: + yield from reversed(tuple(cirq.flatten_to_ops(_decomposition()))) + else: + yield from _decomposition() + + def __pow__(self, power: int) -> cirq.Gate: + if power == 1: + return self + if power == -1: + return BiQubitsMixer(adjoint=not self.adjoint) + return NotImplemented # pragma: no cover + + def _t_complexity_(self) -> TComplexity: + if self.adjoint: + return TComplexity(clifford=18) + return TComplexity(t=8, clifford=28) + + def _has_unitary_(self): + return not self.adjoint + + +@frozen +class SingleQubitCompare(GateWithRegisters): + """Applies U|a>|b>|0>|0> = |a> |a=b> |(a |(a>b)> + + Source: (FIG. 3) in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + """ # pylint: disable=line-too-long + + adjoint: bool = False + + @cached_property + def signature(self) -> Signature: + one_side = Side.RIGHT if not self.adjoint else Side.LEFT + return Signature( + [ + Register('a', 1), + Register('b', 1), + Register('less_than', 1, side=one_side), + Register('greater_than', 1, side=one_side), + ] + ) + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + a = quregs['a'] + b = quregs['b'] + less_than = quregs['less_than'] + greater_than = quregs['greater_than'] + + def _decomposition() -> Iterator[cirq.Operation]: + yield And(0, 1, adjoint=self.adjoint).on(*a, *b, *less_than) + yield cirq.CNOT(*less_than, *greater_than) + yield cirq.CNOT(*b, *greater_than) + yield cirq.CNOT(*a, *b) + yield cirq.CNOT(*a, *greater_than) + yield cirq.X(*b) + + if self.adjoint: + yield from reversed(tuple(_decomposition())) + else: + yield from _decomposition() + + def __pow__(self, power: int) -> cirq.Gate: + if not isinstance(power, int): + raise ValueError('SingleQubitCompare is only defined for integer powers.') + if power % 2 == 0: + return cirq.IdentityGate(4) + if power < 0: + return SingleQubitCompare(adjoint=not self.adjoint) + return self + + def _t_complexity_(self) -> TComplexity: + if self.adjoint: + return TComplexity(clifford=11) + return TComplexity(t=4, clifford=16) + + +def _equality_with_zero( + context: cirq.DecompositionContext, qubits: Sequence[cirq.Qid], z: cirq.Qid +) -> cirq.OP_TREE: + if len(qubits) == 1: + (q,) = qubits + yield cirq.X(q) + yield cirq.CNOT(q, z) + return + if len(qubits) == 2: + yield And(0, 0).on(*qubits, z) + else: + ancilla = context.qubit_manager.qalloc(len(qubits) - 2) + yield MultiAnd(cvs=[0] * len(qubits)).on(*qubits, *ancilla, z) + + +@frozen +class LessThanEqual(GateWithRegisters, cirq.ArithmeticGate): + """Applies U|x>|y>|z> = |x>|y> |z ^ (x <= y)>""" + + x_bitsize: int + y_bitsize: int + + @cached_property + def signature(self) -> 'Signature': + return Signature.build(x=self.x_bitsize, y=self.y_bitsize, target=1) + + def registers(self) -> Sequence[Union[int, Sequence[int]]]: + return [2] * self.x_bitsize, [2] * self.y_bitsize, [2] + + def with_registers(self, *new_registers) -> "LessThanEqual": + return LessThanEqual(len(new_registers[0]), len(new_registers[1])) + + def apply(self, *register_vals: int) -> Union[int, int, Iterable[int]]: + x_val, y_val, target_val = register_vals + return x_val, y_val, target_val ^ (x_val <= y_val) + + def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: + wire_symbols = ["In(x)"] * self.x_bitsize + wire_symbols += ["In(y)"] * self.y_bitsize + wire_symbols += ['+(x <= y)'] + return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) + + def __pow__(self, power: int): + if power in [1, -1]: + return self + return NotImplemented # pragma: no cover + + def _decompose_via_tree( + self, context: cirq.DecompositionContext, X: Sequence[cirq.Qid], Y: Sequence[cirq.Qid] + ) -> cirq.OP_TREE: + """Returns comparison oracle from https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + + This decomposition follows the tree structure of (FIG. 2) + """ # pylint: disable=line-too-long + if len(X) == 1: + return + if len(X) == 2: + yield BiQubitsMixer().on_registers(x=X, y=Y, ancilla=context.qubit_manager.qalloc(3)) + return + + m = len(X) // 2 + yield self._decompose_via_tree(context, X[:m], Y[:m]) + yield self._decompose_via_tree(context, X[m:], Y[m:]) + yield BiQubitsMixer().on_registers( + x=(X[m - 1], X[-1]), y=(Y[m - 1], Y[-1]), ancilla=context.qubit_manager.qalloc(3) + ) + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + """Decomposes the gate in a T-complexity optimal way. + + The construction can be broken in 4 parts: + 1. In case of differing bitsizes then a multicontrol And Gate + - Section III.A. https://arxiv.org/abs/1805.03662) is used to check whether + the extra prefix is equal to zero: + - result stored in: `prefix_equality` qubit. + 2. The tree structure (FIG. 2) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf + followed by a SingleQubitCompare to compute the result of comparison of + the suffixes of equal length: + - result stored in: `less_than` and `greater_than` with equality in qubits[-2] + 3. The results from the previous two steps are combined to update the target qubit. + 4. The adjoint of the previous operations is added to restore the input qubits + to their original state and clean the ancilla qubits. + """ # pylint: disable=line-too-long + lhs, rhs, (target,) = quregs['x'], quregs['y'], quregs['target'] + + n = min(len(lhs), len(rhs)) + + prefix_equality = None + adjoint: List[cirq.Operation] = [] + + # if one of the registers is longer than the other store equality with |0--0> + # into `prefix_equality` using d = |len(P) - len(Q)| And operations => 4d T. + if len(lhs) != len(rhs): + (prefix_equality,) = context.qubit_manager.qalloc(1) + if len(lhs) > len(rhs): + for op in cirq.flatten_to_ops( + _equality_with_zero(context, lhs[:-n], prefix_equality) + ): + yield op + adjoint.append(cirq.inverse(op)) + else: + for op in cirq.flatten_to_ops( + _equality_with_zero(context, rhs[:-n], prefix_equality) + ): + yield op + adjoint.append(cirq.inverse(op)) + + yield cirq.X(target), cirq.CNOT(prefix_equality, target) + + # compare the remaining suffix of P and Q + lhs = lhs[-n:] + rhs = rhs[-n:] + for op in cirq.flatten_to_ops(self._decompose_via_tree(context, lhs, rhs)): + yield op + adjoint.append(cirq.inverse(op)) + + less_than, greater_than = context.qubit_manager.qalloc(2) + yield SingleQubitCompare().on_registers( + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than + ) + adjoint.append( + SingleQubitCompare(adjoint=True).on_registers( + a=lhs[-1], b=rhs[-1], less_than=less_than, greater_than=greater_than + ) + ) + + if prefix_equality is None: + yield cirq.X(target) + yield cirq.CNOT(greater_than, target) + else: + (less_than_or_equal,) = context.qubit_manager.qalloc(1) + yield And(1, 0).on(prefix_equality, greater_than, less_than_or_equal) + adjoint.append( + And(1, 0, adjoint=True).on(prefix_equality, greater_than, less_than_or_equal) + ) + + yield cirq.CNOT(less_than_or_equal, target) + + yield from reversed(adjoint) + + def _t_complexity_(self) -> TComplexity: + n = min(self.x_bitsize, self.y_bitsize) + d = max(self.x_bitsize, self.y_bitsize) - n + is_second_longer = self.y_bitsize > self.x_bitsize + if d == 0: + # When both registers are of the same size the T complexity is + # 8n - 4 same as in https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf. pylint: disable=line-too-long + return TComplexity(t=8 * n - 4, clifford=46 * n - 17) + else: + # When the registers differ in size and `n` is the size of the smaller one and + # `d` is the difference in size. The T complexity is the sum of the tree + # decomposition as before giving 8n + O(1) and the T complexity of an `And` gate + # over `d` registers giving 4d + O(1) totaling 8n + 4d + O(1). + # From the decomposition we get that the constant is -4 as well as the clifford counts. + if d == 1: + return TComplexity(t=8 * n, clifford=46 * n + 3 + 2 * is_second_longer) + else: + return TComplexity( + t=8 * n + 4 * d - 4, clifford=46 * n + 17 * d - 14 + 2 * is_second_longer + ) + + def _has_unitary_(self): + return True + + +@frozen +class Add(GateWithRegisters, cirq.ArithmeticGate): r"""An n-bit addition gate. Implements $U|a\rangle|b\rangle \rightarrow |a\rangle|a+b\rangle$ using $4n - 4 T$ gates. Args: bitsize: Number of bits used to represent each integer. Must be large - enough to hold the result in the output register of a + b. + enough to hold the result in the output register of a + b, or else it simply + drops the most significant bits. Registers: a: A bitsize-sized input register (register a above). @@ -53,9 +480,65 @@ class Add(Bloq): def signature(self): return Signature.build(a=self.bitsize, b=self.bitsize) + def registers(self) -> Sequence[Union[int, Sequence[int]]]: + return [2] * self.bitsize, [2] * self.bitsize + + def with_registers(self, *new_registers) -> 'Add': + return Add(len(new_registers[0])) + + def apply(self, *register_values: int) -> Union[int, Iterable[int]]: + p, q = register_values + return p, p + q + def pretty_name(self) -> str: return "a + b" + def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: + wire_symbols = ["In(x)"] * self.bitsize + wire_symbols += ["In(y)/Out(x+y)"] * self.bitsize + return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) + + def _has_unitary_(self): + return True + + def _left_building_block(self, inp, out, anc, depth): + if depth == self.bitsize - 1: + return + else: + yield cirq.CX(anc[depth - 1], inp[depth]) + yield cirq.CX(anc[depth - 1], out[depth]) + yield And().on(inp[depth], out[depth], anc[depth]) + yield cirq.CX(anc[depth - 1], anc[depth]) + yield from self._left_building_block(inp, out, anc, depth + 1) + + def _right_building_block(self, inp, out, anc, depth): + if depth == 0: + return + else: + yield cirq.CX(anc[depth - 1], anc[depth]) + yield And(adjoint=True).on(inp[depth], out[depth], anc[depth]) + yield cirq.CX(anc[depth - 1], inp[depth]) + yield cirq.CX(inp[depth], out[depth]) + yield from self._right_building_block(inp, out, anc, depth - 1) + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + input_bits = quregs['a'] + output_bits = quregs['b'] + ancillas = context.qubit_manager.qalloc(self.bitsize - 1) + # Start off the addition by anding into the ancilla + yield And().on(input_bits[0], output_bits[0], ancillas[0]) + # Left part of Fig.2 + yield from self._left_building_block(input_bits, output_bits, ancillas, 1) + yield cirq.CX(ancillas[-1], output_bits[-1]) + yield cirq.CX(input_bits[-1], output_bits[-1]) + # right part of Fig.2 + yield from self._right_building_block(input_bits, output_bits, ancillas, self.bitsize - 2) + yield And(adjoint=True).on(input_bits[0], output_bits[0], ancillas[0]) + yield cirq.CX(input_bits[0], output_bits[0]) + context.qubit_manager.qfree(ancillas) + def t_complexity(self): num_clifford = (self.bitsize - 2) * 19 + 16 num_t_gates = 4 * self.bitsize - 4 @@ -107,6 +590,72 @@ def bloq_counts(self, ssa: Optional['SympySymbolAllocator'] = None) -> Set[Tuple return {(1, Add(self.bitsize)), (self.bitsize, ArbitraryClifford(n=2))} +@frozen(auto_attribs=True) +class AddConstantMod(GateWithRegisters, cirq.ArithmeticGate): + """Applies U_{M}_{add}|x> = |(x + add) % M> if x < M else |x>. + + Applies modular addition to input register `|x>` given parameters `mod` and `add_val` s.t. + 1) If integer `x` < `mod`: output is `|(x + add) % M>` + 2) If integer `x` >= `mod`: output is `|x>`. + + This condition is needed to ensure that the mapping of all input basis states (i.e. input + states |0>, |1>, ..., |2 ** bitsize - 1) to corresponding output states is bijective and thus + the gate is reversible. + + Also supports controlled version of the gate by specifying a per qubit control value as a tuple + of integers passed as `cvs`. + """ + + bitsize: int + mod: int = field() + add_val: int = 1 + cvs: Tuple[int, ...] = field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) + + @mod.validator + def _validate_mod(self, attribute, value): + if not 1 <= value <= 2**self.bitsize: + raise ValueError(f"mod: {value} must be between [1, {2 ** self.bitsize}].") + + @cached_property + def signature(self) -> Signature: + if self.cvs: + return Signature.build(ctrl=len(self.cvs), x=self.bitsize) + return Signature.build(x=self.bitsize) + + def registers(self) -> Sequence[Union[int, Sequence[int]]]: + add_reg = (2,) * self.bitsize + control_reg = (2,) * len(self.cvs) + return (control_reg, add_reg) if control_reg else (add_reg,) + + def with_registers(self, *new_registers: Union[int, Sequence[int]]) -> "AddMod": + raise NotImplementedError() + + def apply(self, *args) -> Union[int, Iterable[int]]: + target_val = args[-1] + if target_val < self.mod: + new_target_val = (target_val + self.add_val) % self.mod + else: + new_target_val = target_val + if self.cvs and args[0] != int(''.join(str(x) for x in self.cvs), 2): + new_target_val = target_val + ret = (args[0], new_target_val) if self.cvs else (new_target_val,) + return ret + + def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: + wire_symbols = ['@' if b else '@(0)' for b in self.cvs] + wire_symbols += [f"Add_{self.add_val}_Mod_{self.mod}"] * self.bitsize + return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) + + def __pow__(self, power: int) -> 'AddConstantMod': + return AddConstantMod(self.bitsize, self.mod, add_val=self.add_val * power, cvs=self.cvs) + + def _t_complexity_(self) -> TComplexity: + # Rough cost as given in https://arxiv.org/abs/1905.09749 + return 5 * Add(self.bitsize).t_complexity() + + @frozen class Square(Bloq): r"""Square an n-bit binary number. @@ -448,7 +997,7 @@ def pretty_name(self) -> str: return "a gt b" def t_complexity(self) -> 'TComplexity': - return t_complexity(LessThanEqualGate(self.bitsize, self.bitsize)) + return t_complexity(LessThanEqual(self.bitsize, self.bitsize)) def bloq_counts( self, ssa: Optional['SympySymbolAllocator'] = None @@ -488,7 +1037,7 @@ def signature(self) -> Signature: return Signature.build(x=self.bitsize, target=1) def t_complexity(self) -> TComplexity: - return t_complexity(LessThanGate(self.bitsize, val=self.val)) + return t_complexity(LessThanConstant(self.bitsize, val=self.val)) def bloq_counts( self, ssa: Optional['SympySymbolAllocator'] = None @@ -577,7 +1126,7 @@ def on_classical_vals( ) -> Dict[str, 'ClassicalValT']: return {'mu': mu, 'nu': nu, 's': nu * (nu + 1) // 2 + mu} - def t_complexity(self) -> 'cirq_ft.TComplexity': + def t_complexity(self) -> 'TComplexity': num_toffoli = self.bitsize**2 + self.bitsize - 1 return TComplexity(t=4 * num_toffoli) diff --git a/qualtran/bloqs/arithmetic_test.py b/qualtran/bloqs/arithmetic_test.py index 0cf0b16d1..ae54379cb 100644 --- a/qualtran/bloqs/arithmetic_test.py +++ b/qualtran/bloqs/arithmetic_test.py @@ -11,13 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools + +import cirq +import numpy as np +import pytest from qualtran import BloqBuilder, Register +from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.arithmetic import ( Add, + AddConstantMod, EqualsAConstant, GreaterThan, GreaterThanConstant, + LessThanConstant, + LessThanEqual, MultiplyTwoReals, OutOfPlaceAdder, Product, @@ -27,7 +36,12 @@ SumOfSquares, ToContiguousIndex, ) -from qualtran.testing import execute_notebook +from qualtran.cirq_interop.bit_tools import iter_bits, iter_bits_twos_complement +from qualtran.cirq_interop.testing import ( + assert_circuit_inp_out_cirqsim, + assert_decompose_is_consistent_with_t_complexity, +) +from qualtran.testing import assert_valid_bloq_decomposition, execute_notebook def _make_add(): @@ -96,6 +110,240 @@ def _make_square_real_number(): return SquareRealNumber(bitsize=10) +def identity_map(n: int): + """Returns a dict of size `2**n` mapping each integer in range [0, 2**n) to itself.""" + return {i: i for i in range(2**n)} + + +def test_less_than_gate(): + qubits = cirq.LineQubit.range(4) + gate = LessThanConstant(3, 5) + op = gate.on(*qubits) + circuit = cirq.Circuit(op) + basis_map = { + 0b_000_0: 0b_000_1, + 0b_000_1: 0b_000_0, + 0b_001_0: 0b_001_1, + 0b_001_1: 0b_001_0, + 0b_010_0: 0b_010_1, + 0b_010_1: 0b_010_0, + 0b_011_0: 0b_011_1, + 0b_011_1: 0b_011_0, + 0b_100_0: 0b_100_1, + 0b_100_1: 0b_100_0, + 0b_101_0: 0b_101_0, + 0b_101_1: 0b_101_1, + 0b_110_0: 0b_110_0, + 0b_110_1: 0b_110_1, + 0b_111_0: 0b_111_0, + 0b_111_1: 0b_111_1, + } + cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit) + circuit += op**-1 + cirq.testing.assert_equivalent_computational_basis_map(identity_map(len(qubits)), circuit) + gate2 = LessThanConstant(4, 10) + assert gate.with_registers(*gate2.registers()) == gate2 + assert cirq.circuit_diagram_info(gate).wire_symbols == ("In(x)",) * 3 + ("+(x < 5)",) + assert (gate**1 is gate) and (gate**-1 is gate) + assert gate.__pow__(2) is NotImplemented + + +@pytest.mark.parametrize("bits", [*range(8)]) +@pytest.mark.parametrize("val", [3, 5, 7, 8, 9]) +def test_decompose_less_than_gate(bits: int, val: int): + qubit_states = list(iter_bits(bits, 3)) + circuit = cirq.Circuit( + cirq.decompose_once( + LessThanConstant(3, val).on_registers(x=cirq.LineQubit.range(3), target=cirq.q(4)) + ) + ) + if val < 8: + initial_state = [0] * 4 + qubit_states + [0] + output_state = [0] * 4 + qubit_states + [int(bits < val)] + else: + # When val >= 2**number_qubits the decomposition doesn't create any ancilla since the + # answer is always 1. + initial_state = [0] + output_state = [1] + assert_circuit_inp_out_cirqsim( + circuit, sorted(circuit.all_qubits()), initial_state, output_state + ) + + +@pytest.mark.parametrize("n", [*range(2, 5)]) +@pytest.mark.parametrize("val", [3, 4, 5, 7, 8, 9]) +def test_less_than_consistent_protocols(n: int, val: int): + g = LessThanConstant(n, val) + assert_decompose_is_consistent_with_t_complexity(g) + # Test the unitary is self-inverse + u = cirq.unitary(g) + np.testing.assert_allclose(u @ u, np.eye(2 ** (n + 1))) + assert_valid_bloq_decomposition(g) + + +def test_multi_in_less_equal_than_gate(): + qubits = cirq.LineQubit.range(7) + op = LessThanEqual(3, 3).on_registers(x=qubits[:3], y=qubits[3:6], target=qubits[-1]) + circuit = cirq.Circuit(op) + basis_map = {} + for in1, in2 in itertools.product(range(2**3), repeat=2): + for target_reg_val in range(2): + target_bin = bin(target_reg_val)[2:] + in1_bin = format(in1, '03b') + in2_bin = format(in2, '03b') + out_bin = bin(target_reg_val ^ (in1 <= in2))[2:] + true_out_int = target_reg_val ^ (in1 <= in2) + input_int = int(in1_bin + in2_bin + target_bin, 2) + output_int = int(in1_bin + in2_bin + out_bin, 2) + assert true_out_int == int(out_bin, 2) + basis_map[input_int] = output_int + + cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit) + circuit += op**-1 + cirq.testing.assert_equivalent_computational_basis_map(identity_map(len(qubits)), circuit) + + +@pytest.mark.parametrize("x_bitsize", [*range(1, 5)]) +@pytest.mark.parametrize("y_bitsize", [*range(1, 5)]) +def test_less_than_equal_consistent_protocols(x_bitsize: int, y_bitsize: int): + g = LessThanEqual(x_bitsize, y_bitsize) + assert_decompose_is_consistent_with_t_complexity(g) + assert_valid_bloq_decomposition(g) + + # Decomposition works even when context is None. + qubits = cirq.LineQid.range(x_bitsize + y_bitsize + 1, dimension=2) + assert cirq.Circuit(g._decompose_with_context_(qubits=qubits)) == cirq.Circuit( + cirq.decompose_once( + g.on(*qubits), context=cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) + ) + ) + + # Test the unitary is self-inverse + assert g**-1 is g + u = cirq.unitary(g) + np.testing.assert_allclose(u @ u, np.eye(2 ** (x_bitsize + y_bitsize + 1))) + # Test diagrams + expected_wire_symbols = ("In(x)",) * x_bitsize + ("In(y)",) * y_bitsize + ("+(x <= y)",) + assert cirq.circuit_diagram_info(g).wire_symbols == expected_wire_symbols + # Test with_registers + assert g.with_registers([2] * 4, [2] * 5, [2]) == LessThanEqual(4, 5) + + +@pytest.mark.parametrize('a,b,num_bits', itertools.product(range(4), range(4), range(3, 5))) +def test_add_decomposition(a: int, b: int, num_bits: int): + num_anc = num_bits - 1 + gate = Add(num_bits) + qubits = cirq.LineQubit.range(2 * num_bits) + op = gate.on_registers(a=qubits[:num_bits], b=qubits[num_bits:]) + greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True) + context = cirq.DecompositionContext(greedy_mm) + circuit = cirq.Circuit(cirq.decompose_once(op, context=context)) + ancillas = sorted(circuit.all_qubits())[-num_anc:] + initial_state = [0] * (2 * num_bits + num_anc) + initial_state[:num_bits] = list(iter_bits(a, num_bits))[::-1] + initial_state[num_bits : 2 * num_bits] = list(iter_bits(b, num_bits))[::-1] + final_state = [0] * (2 * num_bits + num_bits - 1) + final_state[:num_bits] = list(iter_bits(a, num_bits))[::-1] + final_state[num_bits : 2 * num_bits] = list(iter_bits(a + b, num_bits))[::-1] + assert_circuit_inp_out_cirqsim(circuit, qubits + ancillas, initial_state, final_state) + # Test diagrams + expected_wire_symbols = ("In(x)",) * num_bits + ("In(y)/Out(x+y)",) * num_bits + assert cirq.circuit_diagram_info(gate).wire_symbols == expected_wire_symbols + # Test with_registers + assert gate.with_registers([2] * 6, [2] * 6) == Add(6) + + +def test_add_truncated(): + num_bits = 3 + num_anc = num_bits - 1 + gate = Add(num_bits) + qubits = cirq.LineQubit.range(2 * num_bits) + circuit = cirq.Circuit(cirq.decompose_once(gate.on(*qubits))) + ancillas = sorted(circuit.all_qubits() - frozenset(qubits)) + assert len(ancillas) == num_anc + all_qubits = qubits + ancillas + # Corresponds to 2^2 + 2^2 (4 + 4 = 8 = 2^3 (needs num_bits = 4 to work properly)) + initial_state = [0, 0, 1, 0, 0, 1, 0, 0] + # Should be 1000 (or 0001 below) but bit falls off the end + final_state = [0, 0, 1, 0, 0, 0, 0, 0] + # increasing number of bits yields correct value + assert_circuit_inp_out_cirqsim(circuit, all_qubits, initial_state, final_state) + + num_bits = 4 + num_anc = num_bits - 1 + gate = Add(num_bits) + qubits = cirq.LineQubit.range(2 * num_bits) + greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True) + context = cirq.DecompositionContext(greedy_mm) + circuit = cirq.Circuit(cirq.decompose_once(gate.on(*qubits), context=context)) + ancillas = sorted(circuit.all_qubits() - frozenset(qubits)) + assert len(ancillas) == num_anc + all_qubits = qubits + ancillas + initial_state = [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0] + final_state = [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0] + assert_circuit_inp_out_cirqsim(circuit, all_qubits, initial_state, final_state) + + num_bits = 3 + num_anc = num_bits - 1 + gate = Add(num_bits) + qubits = cirq.LineQubit.range(2 * num_bits) + greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True) + context = cirq.DecompositionContext(greedy_mm) + circuit = cirq.Circuit(cirq.decompose_once(gate.on(*qubits), context=context)) + ancillas = sorted(circuit.all_qubits() - frozenset(qubits)) + assert len(ancillas) == num_anc + all_qubits = qubits + ancillas + # Corresponds to 2^2 + (2^2 + 2^1 + 2^0) (4 + 7 = 11 = 1011 (need num_bits=4 to work properly)) + initial_state = [0, 0, 1, 1, 1, 1, 0, 0] + # Should be 1011 (or 1101 below) but last two bits are lost + final_state = [0, 0, 1, 1, 1, 0, 0, 0] + assert_circuit_inp_out_cirqsim(circuit, all_qubits, initial_state, final_state) + + +@pytest.mark.parametrize('a,b,num_bits', itertools.product(range(4), range(4), range(3, 5))) +def test_subtract(a, b, num_bits): + num_anc = num_bits - 1 + gate = Add(num_bits) + qubits = cirq.LineQubit.range(2 * num_bits) + greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True) + context = cirq.DecompositionContext(greedy_mm) + circuit = cirq.Circuit(cirq.decompose_once(gate.on(*qubits), context=context)) + ancillas = sorted(circuit.all_qubits())[-num_anc:] + initial_state = [0] * (2 * num_bits + num_anc) + initial_state[:num_bits] = list(iter_bits_twos_complement(a, num_bits))[::-1] + initial_state[num_bits : 2 * num_bits] = list(iter_bits_twos_complement(-b, num_bits))[::-1] + final_state = [0] * (2 * num_bits + num_bits - 1) + final_state[:num_bits] = list(iter_bits_twos_complement(a, num_bits))[::-1] + final_state[num_bits : 2 * num_bits] = list(iter_bits_twos_complement(a - b, num_bits))[::-1] + all_qubits = qubits + ancillas + assert_circuit_inp_out_cirqsim(circuit, all_qubits, initial_state, final_state) + + +@pytest.mark.parametrize("n", [*range(3, 10)]) +def test_addition_gate_t_complexity(n: int): + g = Add(n) + assert_decompose_is_consistent_with_t_complexity(g) + assert_valid_bloq_decomposition(g) + + +@pytest.mark.parametrize('a,b', itertools.product(range(2**3), repeat=2)) +def test_add_no_decompose(a, b): + num_bits = 5 + qubits = cirq.LineQubit.range(2 * num_bits) + op = Add(num_bits).on(*qubits) + circuit = cirq.Circuit(op) + basis_map = {} + a_bin = format(a, f'0{num_bits}b') + b_bin = format(b, f'0{num_bits}b') + out_bin = format(a + b, f'0{num_bits}b') + true_out_int = a + b + input_int = int(a_bin + b_bin, 2) + output_int = int(a_bin + out_bin, 2) + assert true_out_int == int(out_bin, 2) + basis_map[input_int] = output_int + cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit) + + def test_add(): bb = BloqBuilder() bitsize = 4 @@ -105,6 +353,47 @@ def test_add(): cbloq = bb.finalize(a=a, b=b) +@pytest.mark.parametrize('bitsize', [3]) +@pytest.mark.parametrize('mod', [5, 8]) +@pytest.mark.parametrize('add_val', [1, 2]) +@pytest.mark.parametrize('cvs', [[], [0, 1], [1, 0], [1, 1]]) +def test_add_mod_n(bitsize, mod, add_val, cvs): + gate = AddConstantMod(bitsize, mod, add_val=add_val, cvs=cvs) + basis_map = {} + num_cvs = len(cvs) + for x in range(2**bitsize): + y = (x + add_val) % mod if x < mod else x + if not num_cvs: + basis_map[x] = y + continue + for cb in range(2**num_cvs): + inp = f'0b_{cb:0{num_cvs}b}_{x:0{bitsize}b}' + if tuple(int(x) for x in f'{cb:0{num_cvs}b}') == tuple(cvs): + out = f'0b_{cb:0{num_cvs}b}_{y:0{bitsize}b}' + basis_map[int(inp, 2)] = int(out, 2) + else: + basis_map[int(inp, 2)] = int(inp, 2) + + op = gate.on_registers(**get_named_qubits(gate.signature)) + circuit = cirq.Circuit(op) + cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit) + circuit += op**-1 + cirq.testing.assert_equivalent_computational_basis_map(identity_map(gate.num_qubits()), circuit) + + +def test_add_mod_n_protocols(): + with pytest.raises(ValueError, match="must be between"): + _ = AddConstantMod(3, 10) + add_one = AddConstantMod(3, 5, 1) + add_two = AddConstantMod(3, 5, 2, cvs=[1, 0]) + + assert add_one == AddConstantMod(3, 5, 1) + assert add_one != add_two + assert hash(add_one) != hash(add_two) + assert add_two.cvs == (1, 0) + assert cirq.circuit_diagram_info(add_two).wire_symbols == ('@', '@(0)') + ('Add_2_Mod_5',) * 3 + + def test_out_of_place_adder(): bb = BloqBuilder() bitsize = 4