-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Controlled-Addition implementation (#864)
- Loading branch information
1 parent
a4365b2
commit 28487f6
Showing
2 changed files
with
305 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import itertools | ||
import math | ||
from typing import Any, Dict, Iterator, Set, TYPE_CHECKING, Union | ||
|
||
import cirq | ||
import numpy as np | ||
import sympy | ||
from attrs import field, frozen | ||
from numpy.typing import NDArray | ||
|
||
from qualtran import Bloq, CompositeBloq, QBit, QInt, QUInt, Register, Signature, Soquet, SoquetT | ||
from qualtran._infra.data_types import QMontgomeryUInt | ||
from qualtran.bloqs.basic_gates import CNOT | ||
from qualtran.bloqs.mcmt import MultiControlX | ||
from qualtran.bloqs.mcmt.and_bloq import And | ||
from qualtran.cirq_interop import decompose_from_cirq_style_method | ||
from qualtran.cirq_interop.t_complexity_protocol import TComplexity | ||
|
||
if TYPE_CHECKING: | ||
import quimb.tensor as qtn | ||
|
||
from qualtran.drawing import WireSymbol | ||
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator | ||
from qualtran.simulation.classical_sim import ClassicalValT | ||
|
||
|
||
@frozen | ||
class CAdd(Bloq): | ||
r"""An n-bit controlled-addition gate. | ||
Args: | ||
a_dtype: Quantum datatype used to represent the integer a. | ||
b_dtype: Quantum datatype used to represent the integer b. Must be large | ||
enough to hold the result in the output register of a + b, or else it simply | ||
drops the most significant bits. If not specified, b_dtype is set to a_dtype. | ||
cv: When controlled=0, this bloq is active when the ctrl register is 0. When | ||
controlled=1, this bloq is active when the ctrl register is 1. | ||
Registers: | ||
ctrl: the control bit for the addition | ||
a: A a_dtype.bitsize-sized input register (register a above). | ||
b: A b_dtype.bitsize-sized input/output register (register b above). | ||
References: | ||
[Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648) | ||
""" | ||
|
||
a_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field() | ||
b_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field() | ||
cv: int = field(default=1) | ||
|
||
@b_dtype.default | ||
def b_dtype_default(self): | ||
return self.a_dtype | ||
|
||
@a_dtype.validator | ||
def _a_dtype_validate(self, field, val): | ||
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)): | ||
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.") | ||
if isinstance(val.num_qubits, sympy.Expr): | ||
return | ||
if val.bitsize > self.b_dtype.bitsize: | ||
raise ValueError("a_dtype bitsize must be less than or equal to b_dtype bitsize") | ||
|
||
@b_dtype.validator | ||
def _b_dtype_validate(self, field, val): | ||
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)): | ||
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.") | ||
|
||
@cv.validator | ||
def _controlled_validate(self, field, val): | ||
if val not in (0, 1): | ||
raise ValueError("controlled must be either 0 or 1") | ||
|
||
@property | ||
def signature(self): | ||
return Signature( | ||
[Register("ctrl", QBit()), Register("a", self.a_dtype), Register("b", self.b_dtype)] | ||
) | ||
|
||
def add_my_tensors( | ||
self, | ||
tn: 'qtn.TensorNetwork', | ||
tag: Any, | ||
*, | ||
incoming: Dict[str, 'SoquetT'], | ||
outgoing: Dict[str, 'SoquetT'], | ||
): | ||
import quimb.tensor as qtn | ||
|
||
if isinstance(self.a_dtype, QInt) or isinstance(self.b_dtype, QInt): | ||
raise TypeError("Tensor contraction for addition is only supported for unsigned ints.") | ||
N_a = 2**self.a_dtype.bitsize | ||
N_b = 2**self.b_dtype.bitsize | ||
inds = ( | ||
incoming['ctrl'], | ||
incoming['a'], | ||
incoming['b'], | ||
outgoing['ctrl'], | ||
outgoing['a'], | ||
outgoing['b'], | ||
) | ||
unitary = np.zeros((2, N_a, N_b, 2, N_a, N_b), dtype=np.complex128) | ||
for c, a, b in itertools.product(range(2), range(N_a), range(N_b)): | ||
if c == self.cv: | ||
unitary[c, a, b, c, a, int(math.fmod(a + b, N_b))] = 1 | ||
else: | ||
unitary[c, a, b, c, a, b] = 1 | ||
|
||
tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.short_name(), tag])) | ||
|
||
def decompose_bloq(self) -> 'CompositeBloq': | ||
return decompose_from_cirq_style_method(self) | ||
|
||
def on_classical_vals(self, **kwargs) -> Dict[str, 'ClassicalValT']: | ||
a, b = kwargs['a'], kwargs['b'] | ||
unsigned = isinstance(self.a_dtype, (QUInt, QMontgomeryUInt)) | ||
b_bitsize = self.b_dtype.bitsize | ||
N = 2**b_bitsize if unsigned else 2 ** (b_bitsize - 1) | ||
ctrl = kwargs['ctrl'] | ||
if ctrl != self.cv: | ||
return {'ctrl': ctrl, 'a': a, 'b': b} | ||
else: | ||
return {'ctrl': ctrl, 'a': a, 'b': int(math.fmod(a + b, N))} | ||
|
||
def short_name(self) -> str: | ||
return "a+b" | ||
|
||
def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol': | ||
from qualtran.drawing import directional_text_box | ||
|
||
if soq.reg.name == 'ctrl': | ||
return directional_text_box('ctrl', side=soq.reg.side) | ||
if soq.reg.name == 'a': | ||
return directional_text_box('a', side=soq.reg.side) | ||
elif soq.reg.name == 'b': | ||
return directional_text_box('a+b', side=soq.reg.side) | ||
else: | ||
raise ValueError() | ||
|
||
def _left_building_block(self, inp, out, anc, depth): | ||
if depth == self.b_dtype.bitsize - 1: | ||
return | ||
else: | ||
if depth < 1: | ||
raise ValueError(f"{depth=} is not a positive integer") | ||
if depth < len(inp): | ||
yield CNOT().on(anc[depth - 1], inp[depth]) | ||
control = inp[depth] | ||
else: | ||
# If inp[depth] doesn't exist, we treat it as a |0>, | ||
# and therefore applying CNOT().on(anc[depth - 1], inp[depth]) | ||
# essentially "copies" anc[depth - 1] into inp[depth] | ||
# in the classical basis. So therefore, on future operations, | ||
# we can use anc[depth - 1] in its place. | ||
control = anc[depth - 1] | ||
yield CNOT().on(anc[depth - 1], out[depth]) | ||
yield And().on(control, out[depth], anc[depth]) | ||
yield CNOT().on(anc[depth - 1], anc[depth]) | ||
yield from self._left_building_block(inp, out, anc, depth + 1) | ||
|
||
def _right_building_block(self, inp, out, anc, control, depth): | ||
if depth == 0: | ||
return | ||
yield CNOT().on(anc[depth - 1], anc[depth]) | ||
if depth < len(inp): | ||
yield And().adjoint().on(inp[depth], out[depth], anc[depth]) | ||
yield MultiControlX((1, 1)).on(control, inp[depth], out[depth]) | ||
yield CNOT().on(anc[depth - 1], inp[depth]) | ||
else: | ||
yield And().adjoint().on(anc[depth - 1], out[depth], anc[depth]) | ||
yield MultiControlX((1, 1)).on(control, anc[depth - 1], out[depth]) | ||
yield CNOT().on(anc[depth - 1], out[depth]) | ||
yield from self._right_building_block(inp, out, anc, control, depth - 1) | ||
|
||
def decompose_from_registers( | ||
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var] | ||
) -> Iterator[cirq.OP_TREE]: | ||
# reverse the order of qubits for big endian-ness. | ||
input_bits = quregs['a'][::-1] | ||
output_bits = quregs['b'][::-1] | ||
ancillas = context.qubit_manager.qalloc(self.b_dtype.bitsize - 1)[::-1] | ||
control = quregs['ctrl'][0] | ||
if self.cv == 0: | ||
yield cirq.X(control) | ||
# Start off the addition by anding into the ancilla | ||
yield And().on(input_bits[0], output_bits[0], ancillas[0]) | ||
# Left part of Fig.4 | ||
yield from self._left_building_block(input_bits, output_bits, ancillas, 1) | ||
yield CNOT().on(ancillas[-1], output_bits[-1]) | ||
if len(input_bits) == len(output_bits): | ||
yield MultiControlX((1, 1)).on(control, input_bits[-1], output_bits[-1]) | ||
yield CNOT().on(ancillas[-1], output_bits[-1]) | ||
# right part of Fig.4 | ||
yield from self._right_building_block( | ||
input_bits, output_bits, ancillas, control, self.b_dtype.bitsize - 2 | ||
) | ||
yield And().adjoint().on(input_bits[0], output_bits[0], ancillas[0]) | ||
yield MultiControlX((1, 1)).on(control, input_bits[0], output_bits[0]) | ||
if self.cv == 0: | ||
yield cirq.X(control) | ||
context.qubit_manager.qfree(ancillas) | ||
|
||
def _t_complexity_(self): | ||
n = self.b_dtype.bitsize | ||
num_and = 2 * n - 1 | ||
num_clifford = 33 * (n - 2) + 43 | ||
return TComplexity(t=4 * num_and, clifford=num_clifford) | ||
|
||
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: | ||
n = self.b_dtype.bitsize | ||
n_cnot = (n - 2) * 6 + 2 | ||
return { | ||
(MultiControlX((1, 1)), n), | ||
(And(), n - 1), | ||
(And().adjoint(), n - 1), | ||
(CNOT(), n_cnot), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import cirq | ||
import numpy as np | ||
import pytest | ||
|
||
import qualtran.testing as qlt_testing | ||
from qualtran import QUInt | ||
from qualtran.bloqs.arithmetic.controlled_addition import CAdd | ||
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim | ||
from qualtran.resource_counting.generalizers import ignore_split_join | ||
|
||
|
||
def iter_bits(x, w): | ||
return [int(b) for b in np.binary_repr(x, width=w)] | ||
|
||
|
||
@pytest.mark.parametrize('a', [1, 2]) | ||
@pytest.mark.parametrize('b', [1, 2, 3]) | ||
@pytest.mark.parametrize('num_bits_a', [2, 3]) | ||
@pytest.mark.parametrize('num_bits_b', [5]) | ||
@pytest.mark.parametrize('controlled_on', [0, 1]) | ||
@pytest.mark.parametrize('control', [0, 1]) | ||
def test_controlled_addition(a, b, num_bits_a, num_bits_b, controlled_on, control): | ||
num_anc = num_bits_b - 1 | ||
gate = CAdd(QUInt(num_bits_a), QUInt(num_bits_b), cv=controlled_on) | ||
qubits = cirq.LineQubit.range(num_bits_a + num_bits_b + 1) | ||
op = gate.on_registers(ctrl=qubits[0], a=qubits[1 : num_bits_a + 1], b=qubits[num_bits_a + 1 :]) | ||
greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True) | ||
context = cirq.DecompositionContext(greedy_mm) | ||
circuit = cirq.Circuit(cirq.decompose_once(op, context=context)) | ||
circuit0 = cirq.Circuit(op) | ||
ancillas = sorted(circuit.all_qubits())[-num_anc:] | ||
initial_state = [0] * (num_bits_a + num_bits_b + num_anc + 1) | ||
initial_state[0] = control | ||
initial_state[1 : num_bits_a + 1] = list(iter_bits(a, num_bits_a)) | ||
initial_state[num_bits_a + 1 : num_bits_a + num_bits_b + 1] = list(iter_bits(b, num_bits_b)) | ||
final_state = [0] * (num_bits_a + num_bits_b + num_anc + 1) | ||
final_state[0] = control | ||
final_state[1 : num_bits_a + 1] = list(iter_bits(a, num_bits_a)) | ||
if control == controlled_on: | ||
final_state[num_bits_a + 1 : num_bits_a + num_bits_b + 1] = list( | ||
iter_bits(a + b, num_bits_b) | ||
) | ||
else: | ||
final_state[num_bits_a + 1 : num_bits_a + num_bits_b + 1] = list(iter_bits(b, num_bits_b)) | ||
assert_circuit_inp_out_cirqsim(circuit, qubits + ancillas, initial_state, final_state) | ||
assert_circuit_inp_out_cirqsim( | ||
circuit0, qubits, initial_state[:-num_anc], final_state[:-num_anc] | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("n", [*range(3, 10)]) | ||
def test_addition_gate_counts_controlled(n: int): | ||
add = CAdd(QUInt(n), cv=1) | ||
num_and = 2 * n - 1 | ||
t_count = 4 * num_and | ||
|
||
qlt_testing.assert_valid_bloq_decomposition(add) | ||
assert add.t_complexity() == add.decompose_bloq().t_complexity() | ||
assert add.bloq_counts() == add.decompose_bloq().bloq_counts(generalizer=ignore_split_join) | ||
assert add.t_complexity().t == t_count |