Skip to content

Commit

Permalink
Controlled-Addition implementation (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
skushnir123 authored Aug 12, 2024
1 parent a4365b2 commit 28487f6
Show file tree
Hide file tree
Showing 2 changed files with 305 additions and 0 deletions.
231 changes: 231 additions & 0 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import math
from typing import Any, Dict, Iterator, Set, TYPE_CHECKING, Union

import cirq
import numpy as np
import sympy
from attrs import field, frozen
from numpy.typing import NDArray

from qualtran import Bloq, CompositeBloq, QBit, QInt, QUInt, Register, Signature, Soquet, SoquetT
from qualtran._infra.data_types import QMontgomeryUInt
from qualtran.bloqs.basic_gates import CNOT
from qualtran.bloqs.mcmt import MultiControlX
from qualtran.bloqs.mcmt.and_bloq import And
from qualtran.cirq_interop import decompose_from_cirq_style_method
from qualtran.cirq_interop.t_complexity_protocol import TComplexity

if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT


@frozen
class CAdd(Bloq):
r"""An n-bit controlled-addition gate.
Args:
a_dtype: Quantum datatype used to represent the integer a.
b_dtype: Quantum datatype used to represent the integer b. Must be large
enough to hold the result in the output register of a + b, or else it simply
drops the most significant bits. If not specified, b_dtype is set to a_dtype.
cv: When controlled=0, this bloq is active when the ctrl register is 0. When
controlled=1, this bloq is active when the ctrl register is 1.
Registers:
ctrl: the control bit for the addition
a: A a_dtype.bitsize-sized input register (register a above).
b: A b_dtype.bitsize-sized input/output register (register b above).
References:
[Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648)
"""

a_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
b_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
cv: int = field(default=1)

@b_dtype.default
def b_dtype_default(self):
return self.a_dtype

@a_dtype.validator
def _a_dtype_validate(self, field, val):
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")
if isinstance(val.num_qubits, sympy.Expr):
return
if val.bitsize > self.b_dtype.bitsize:
raise ValueError("a_dtype bitsize must be less than or equal to b_dtype bitsize")

@b_dtype.validator
def _b_dtype_validate(self, field, val):
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")

@cv.validator
def _controlled_validate(self, field, val):
if val not in (0, 1):
raise ValueError("controlled must be either 0 or 1")

@property
def signature(self):
return Signature(
[Register("ctrl", QBit()), Register("a", self.a_dtype), Register("b", self.b_dtype)]
)

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
import quimb.tensor as qtn

if isinstance(self.a_dtype, QInt) or isinstance(self.b_dtype, QInt):
raise TypeError("Tensor contraction for addition is only supported for unsigned ints.")
N_a = 2**self.a_dtype.bitsize
N_b = 2**self.b_dtype.bitsize
inds = (
incoming['ctrl'],
incoming['a'],
incoming['b'],
outgoing['ctrl'],
outgoing['a'],
outgoing['b'],
)
unitary = np.zeros((2, N_a, N_b, 2, N_a, N_b), dtype=np.complex128)
for c, a, b in itertools.product(range(2), range(N_a), range(N_b)):
if c == self.cv:
unitary[c, a, b, c, a, int(math.fmod(a + b, N_b))] = 1
else:
unitary[c, a, b, c, a, b] = 1

tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.short_name(), tag]))

def decompose_bloq(self) -> 'CompositeBloq':
return decompose_from_cirq_style_method(self)

def on_classical_vals(self, **kwargs) -> Dict[str, 'ClassicalValT']:
a, b = kwargs['a'], kwargs['b']
unsigned = isinstance(self.a_dtype, (QUInt, QMontgomeryUInt))
b_bitsize = self.b_dtype.bitsize
N = 2**b_bitsize if unsigned else 2 ** (b_bitsize - 1)
ctrl = kwargs['ctrl']
if ctrl != self.cv:
return {'ctrl': ctrl, 'a': a, 'b': b}
else:
return {'ctrl': ctrl, 'a': a, 'b': int(math.fmod(a + b, N))}

def short_name(self) -> str:
return "a+b"

def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
from qualtran.drawing import directional_text_box

if soq.reg.name == 'ctrl':
return directional_text_box('ctrl', side=soq.reg.side)
if soq.reg.name == 'a':
return directional_text_box('a', side=soq.reg.side)
elif soq.reg.name == 'b':
return directional_text_box('a+b', side=soq.reg.side)
else:
raise ValueError()

def _left_building_block(self, inp, out, anc, depth):
if depth == self.b_dtype.bitsize - 1:
return
else:
if depth < 1:
raise ValueError(f"{depth=} is not a positive integer")
if depth < len(inp):
yield CNOT().on(anc[depth - 1], inp[depth])
control = inp[depth]
else:
# If inp[depth] doesn't exist, we treat it as a |0>,
# and therefore applying CNOT().on(anc[depth - 1], inp[depth])
# essentially "copies" anc[depth - 1] into inp[depth]
# in the classical basis. So therefore, on future operations,
# we can use anc[depth - 1] in its place.
control = anc[depth - 1]
yield CNOT().on(anc[depth - 1], out[depth])
yield And().on(control, out[depth], anc[depth])
yield CNOT().on(anc[depth - 1], anc[depth])
yield from self._left_building_block(inp, out, anc, depth + 1)

def _right_building_block(self, inp, out, anc, control, depth):
if depth == 0:
return
yield CNOT().on(anc[depth - 1], anc[depth])
if depth < len(inp):
yield And().adjoint().on(inp[depth], out[depth], anc[depth])
yield MultiControlX((1, 1)).on(control, inp[depth], out[depth])
yield CNOT().on(anc[depth - 1], inp[depth])
else:
yield And().adjoint().on(anc[depth - 1], out[depth], anc[depth])
yield MultiControlX((1, 1)).on(control, anc[depth - 1], out[depth])
yield CNOT().on(anc[depth - 1], out[depth])
yield from self._right_building_block(inp, out, anc, control, depth - 1)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> Iterator[cirq.OP_TREE]:
# reverse the order of qubits for big endian-ness.
input_bits = quregs['a'][::-1]
output_bits = quregs['b'][::-1]
ancillas = context.qubit_manager.qalloc(self.b_dtype.bitsize - 1)[::-1]
control = quregs['ctrl'][0]
if self.cv == 0:
yield cirq.X(control)
# Start off the addition by anding into the ancilla
yield And().on(input_bits[0], output_bits[0], ancillas[0])
# Left part of Fig.4
yield from self._left_building_block(input_bits, output_bits, ancillas, 1)
yield CNOT().on(ancillas[-1], output_bits[-1])
if len(input_bits) == len(output_bits):
yield MultiControlX((1, 1)).on(control, input_bits[-1], output_bits[-1])
yield CNOT().on(ancillas[-1], output_bits[-1])
# right part of Fig.4
yield from self._right_building_block(
input_bits, output_bits, ancillas, control, self.b_dtype.bitsize - 2
)
yield And().adjoint().on(input_bits[0], output_bits[0], ancillas[0])
yield MultiControlX((1, 1)).on(control, input_bits[0], output_bits[0])
if self.cv == 0:
yield cirq.X(control)
context.qubit_manager.qfree(ancillas)

def _t_complexity_(self):
n = self.b_dtype.bitsize
num_and = 2 * n - 1
num_clifford = 33 * (n - 2) + 43
return TComplexity(t=4 * num_and, clifford=num_clifford)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
n = self.b_dtype.bitsize
n_cnot = (n - 2) * 6 + 2
return {
(MultiControlX((1, 1)), n),
(And(), n - 1),
(And().adjoint(), n - 1),
(CNOT(), n_cnot),
}
74 changes: 74 additions & 0 deletions qualtran/bloqs/arithmetic/controlled_addition_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq
import numpy as np
import pytest

import qualtran.testing as qlt_testing
from qualtran import QUInt
from qualtran.bloqs.arithmetic.controlled_addition import CAdd
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim
from qualtran.resource_counting.generalizers import ignore_split_join


def iter_bits(x, w):
return [int(b) for b in np.binary_repr(x, width=w)]


@pytest.mark.parametrize('a', [1, 2])
@pytest.mark.parametrize('b', [1, 2, 3])
@pytest.mark.parametrize('num_bits_a', [2, 3])
@pytest.mark.parametrize('num_bits_b', [5])
@pytest.mark.parametrize('controlled_on', [0, 1])
@pytest.mark.parametrize('control', [0, 1])
def test_controlled_addition(a, b, num_bits_a, num_bits_b, controlled_on, control):
num_anc = num_bits_b - 1
gate = CAdd(QUInt(num_bits_a), QUInt(num_bits_b), cv=controlled_on)
qubits = cirq.LineQubit.range(num_bits_a + num_bits_b + 1)
op = gate.on_registers(ctrl=qubits[0], a=qubits[1 : num_bits_a + 1], b=qubits[num_bits_a + 1 :])
greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True)
context = cirq.DecompositionContext(greedy_mm)
circuit = cirq.Circuit(cirq.decompose_once(op, context=context))
circuit0 = cirq.Circuit(op)
ancillas = sorted(circuit.all_qubits())[-num_anc:]
initial_state = [0] * (num_bits_a + num_bits_b + num_anc + 1)
initial_state[0] = control
initial_state[1 : num_bits_a + 1] = list(iter_bits(a, num_bits_a))
initial_state[num_bits_a + 1 : num_bits_a + num_bits_b + 1] = list(iter_bits(b, num_bits_b))
final_state = [0] * (num_bits_a + num_bits_b + num_anc + 1)
final_state[0] = control
final_state[1 : num_bits_a + 1] = list(iter_bits(a, num_bits_a))
if control == controlled_on:
final_state[num_bits_a + 1 : num_bits_a + num_bits_b + 1] = list(
iter_bits(a + b, num_bits_b)
)
else:
final_state[num_bits_a + 1 : num_bits_a + num_bits_b + 1] = list(iter_bits(b, num_bits_b))
assert_circuit_inp_out_cirqsim(circuit, qubits + ancillas, initial_state, final_state)
assert_circuit_inp_out_cirqsim(
circuit0, qubits, initial_state[:-num_anc], final_state[:-num_anc]
)


@pytest.mark.parametrize("n", [*range(3, 10)])
def test_addition_gate_counts_controlled(n: int):
add = CAdd(QUInt(n), cv=1)
num_and = 2 * n - 1
t_count = 4 * num_and

qlt_testing.assert_valid_bloq_decomposition(add)
assert add.t_complexity() == add.decompose_bloq().t_complexity()
assert add.bloq_counts() == add.decompose_bloq().bloq_counts(generalizer=ignore_split_join)
assert add.t_complexity().t == t_count

0 comments on commit 28487f6

Please sign in to comment.