Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permutation Bloq #1110

Merged
merged 18 commits into from
Jul 17, 2024
100 changes: 68 additions & 32 deletions qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import math
from functools import cached_property
from typing import (
cast,
Dict,
Iterable,
Iterator,
Expand Down Expand Up @@ -44,6 +45,7 @@
DecomposeTypeError,
GateWithRegisters,
QBit,
QDType,
QInt,
QMontgomeryUInt,
QUInt,
Expand All @@ -59,12 +61,12 @@
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlX
from qualtran.cirq_interop import decompose_from_cirq_style_method
from qualtran.drawing import directional_text_box, Text
from qualtran.symbolics import is_symbolic, SymbolicInt

if TYPE_CHECKING:
from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import SymbolicInt


@frozen
Expand Down Expand Up @@ -359,6 +361,63 @@ def _cvs_converter(vv):
return tuple(int(v) for v in vv)


@frozen
class XorK(Bloq):
anurudhp marked this conversation as resolved.
Show resolved Hide resolved
r"""Maps |x> to |x \oplus k> for a constant k.

Args:
dtype: Data type of the input register `x`.
k: The classical integer value to be XOR-ed to x.
cvs: A tuple of control values. Each entry specifies whether that control line is a
"positive" control (`cv[i]=1`) or a "negative" control (`cv[i]=0`).

Registers:
x: A quantum register of type `self.dtype` (see above).
ctrls: A sequence of control qubits (only when `cvs` is non-empty).
"""
dtype: QDType
k: SymbolicInt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If k is a SymbolicInt should we restrict dtype to be QUInt / QInt ? Looking at the implementation below, it seems like this would also work for other types when k is not necessarily an integer.

cvs: Tuple[int, ...] = field(converter=_cvs_converter, default=())

@cached_property
def signature(self) -> 'Signature':
return Signature(
((Register('ctrls', QBit(), shape=(len(self.cvs),)),) if len(self.cvs) > 0 else ())
+ (Register('x', self.dtype),)
)

@cached_property
def bitsize(self) -> SymbolicInt:
return self.dtype.num_qubits

def is_symbolic(self):
return is_symbolic(self.k, self.dtype)

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
if self.is_symbolic():
raise DecomposeTypeError(f"cannot decompose symbolic {self}")

xs = bb.split(cast(Soquet, soqs.pop('x')))
ctrls = soqs.pop('ctrls', None)

for i, bit in enumerate(self.dtype.to_bits(self.k)):
if bit == 1:
if len(self.cvs) > 0 and ctrls is not None:
ctrls, xs[i] = bb.add(MultiControlX(cvs=self.cvs), ctrls=ctrls, x=xs[i])
anurudhp marked this conversation as resolved.
Show resolved Hide resolved
else:
xs[i] = bb.add(XGate(), q=xs[i])

soqs['x'] = bb.join(xs)
if ctrls is not None:
soqs['ctrls'] = ctrls
return soqs

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
bit_flip_bloq = MultiControlX(cvs=self.cvs) if len(self.cvs) > 0 else XGate()
num_flips = self.bitsize if self.is_symbolic() else sum(self.dtype.to_bits(self.k))
return {(bit_flip_bloq, num_flips)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't do n MCX gates to do n bit flips when number of controls > 1

Instead, divide it into 3 cases:

  1. Num controls == 0 : Simply do an XGate()
  2. Num controls == 1 : Simply do CNOT() gates
  3. Num controls > 1 : Use a temporary ancilla and do a MultiAnd(cvs=self.cvs) with the temporary ancilla as the target. Then reduce to Case(2) with the temporary ancilla as control.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, I'd suggest we only support Case (1) and (2) and reducing Case (3) to Case (1) is a general strategy used for all multi controlled bloq so we can add a separate Bloq that does this reduction (maybe at a later point)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I simplified this to only support a single control_val, as #1131 will handle the CtrlSpec -> QBit part.



@frozen
class AddK(Bloq):
r"""Takes |x> to |x + k> for a classical integer `k`.
Expand Down Expand Up @@ -431,46 +490,23 @@ def build_composite_bloq(
ctrls = None
k = bb.allocate(dtype=x.reg.dtype)

# Get binary representation of k and split k into separate wires.
k_split = bb.split(k)
if self.signed:
binary_rep = QInt(self.bitsize).to_bits(self.k)
xor_k_bloq = XorK(x.reg.dtype, self.k, self.cvs)
if ctrls is not None:
ctrls, k = bb.add(xor_k_bloq, ctrls=ctrls, x=k)
else:
binary_rep = QUInt(self.bitsize).to_bits(self.k)

# Apply XGates to qubits in k where the bitstring has value 1. Apply CNOTs when the gate is
# controlled.
for i in range(self.bitsize):
if binary_rep[i] == 1:
if len(self.cvs) > 0 and ctrls is not None:
ctrls, k_split[i] = bb.add(
MultiControlX(cvs=self.cvs), ctrls=ctrls, x=k_split[i]
)
else:
k_split[i] = bb.add(XGate(), q=k_split[i])
k = bb.add(xor_k_bloq, x=k)

# Rejoin the qubits representing k for in-place addition.
k = bb.join(k_split, dtype=x.reg.dtype)
if not isinstance(x.reg.dtype, (QInt, QUInt, QMontgomeryUInt)):
raise ValueError(
"Only QInt, QUInt and QMontgomerUInt types are supported for composite addition."
)
k, x = bb.add(Add(x.reg.dtype, x.reg.dtype), a=k, b=x)

# Resplit the k qubits in order to undo the original bit flips to go from the binary
# representation back to the zero state.
k_split = bb.split(k)
for i in range(self.bitsize):
if binary_rep[i] == 1:
if len(self.cvs) > 0 and ctrls is not None:
ctrls, k_split[i] = bb.add(
MultiControlX(cvs=self.cvs), ctrls=ctrls, x=k_split[i]
)
else:
k_split[i] = bb.add(XGate(), q=k_split[i])

if ctrls is not None:
ctrls, k = bb.add(xor_k_bloq, ctrls=ctrls, x=k)
else:
k = bb.add(xor_k_bloq, x=k)
# Free the ancilla qubits.
k = bb.join(k_split, dtype=x.reg.dtype)
bb.free(k)

# Return the output registers.
Expand Down
69 changes: 61 additions & 8 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Bloq,
bloq_example,
BloqDocSpec,
DecomposeTypeError,
GateWithRegisters,
QAny,
QBit,
Expand All @@ -48,12 +49,12 @@
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates import CNOT, TGate, XGate
from qualtran.bloqs.basic_gates import CNOT, XGate
from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlX
from qualtran.drawing import WireSymbol
from qualtran.drawing.musical_score import Text, TextBox
from qualtran.symbolics import is_symbolic, SymbolicInt
from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt

if TYPE_CHECKING:
from qualtran import BloqBuilder
Expand Down Expand Up @@ -926,7 +927,7 @@ def _gt_k() -> GreaterThanConstant:

@frozen
class EqualsAConstant(Bloq):
r"""Implements $U_a|x\rangle = U_a|x\rangle|z\rangle = |x\rangle |z \land (x = a)\rangle$
r"""Implements $U_a|x\rangle|z\rangle = |x\rangle |z \oplus (x = a)\rangle$

The bloq_counts and t_complexity are derived from:
https://qualtran.readthedocs.io/en/latest/bloqs/comparison_gates.html#equality-as-a-special-case
Expand All @@ -940,8 +941,8 @@ class EqualsAConstant(Bloq):
target: Register to hold result of comparison.
"""

bitsize: int
val: int
bitsize: SymbolicInt
val: SymbolicInt

@cached_property
def signature(self) -> Signature:
Expand All @@ -956,10 +957,62 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
return TextBox(f"⨁(x = {self.val})")
raise ValueError(f'Unknown register symbol {reg.name}')

def is_symbolic(self):
return is_symbolic(self.bitsize, self.val)

def build_composite_bloq(
self, bb: 'BloqBuilder', x: 'Soquet', target: 'Soquet'
) -> Dict[str, 'SoquetT']:
anurudhp marked this conversation as resolved.
Show resolved Hide resolved
if self.is_symbolic():
raise DecomposeTypeError(f"cannot decompose symbolic {self}")

bits_k = x.reg.dtype.to_bits(self.val)

if self.bitsize == 1:
# Note: when self.val = 0, this is just a negative-control CNOT.
if self.val == 0:
x = bb.add(XGate(), q=x)
x, target = bb.add(CNOT(), ctrl=x, target=target)
if self.val == 0:
x = bb.add(XGate(), q=x)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a note that this is just a negative-control CNOT so that in the future, if and when we add a cv attribute to CNOT we can replace it

elif self.bitsize == 2:
and_bloq = And(bits_k[0], bits_k[1])

xs = bb.split(x)
xs, and_xs = bb.add(and_bloq, ctrl=xs)
and_xs, target = bb.add(CNOT(), ctrl=and_xs, target=target)
xs = bb.add(and_bloq.adjoint(), ctrl=xs, target=and_xs)
x = bb.join(xs)
else:
multi_and_bloq = MultiAnd(tuple(bits_k))

xs = bb.split(x)
xs, junk, and_xs = bb.add(multi_and_bloq, ctrl=xs)
and_xs, target = bb.add(CNOT(), ctrl=and_xs, target=target)
xs = bb.add(multi_and_bloq.adjoint(), ctrl=xs, junk=junk, target=and_xs)
x = bb.join(xs)

return {'x': x, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
# See: https://github.com/quantumlib/Qualtran/issues/219
# See: https://github.com/quantumlib/Qualtran/issues/217
return {(TGate(), 4 * (self.bitsize - 1))}
if not self.is_symbolic():
return super().build_call_graph(ssa)

op: Bloq
if not is_symbolic(self.bitsize) and self.bitsize <= 2:
if self.bitsize == 1:
op = XGate()
else:
cv = ssa.new_symbol('cv')
op = And(cv, cv)
else:
op = MultiAnd(HasLength(self.bitsize))

bloq_counts: dict[Bloq, int] = defaultdict(lambda: 0)
bloq_counts[op] += 1
bloq_counts[op.adjoint()] += 1
bloq_counts[CNOT()] += 1
return set(bloq_counts.items())
anurudhp marked this conversation as resolved.
Show resolved Hide resolved


def _make_equals_a_constant():
Expand Down
4 changes: 3 additions & 1 deletion qualtran/bloqs/arithmetic/comparison_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,9 @@ def test_equals_a_constant():
qlt_testing.assert_wire_symbols_match_expected(
EqualsAConstant(bitsize, 17), ['In(x)', '⨁(x = 17)']
)
assert t_complexity(EqualsAConstant(bitsize, 17)) == TComplexity(t=4 * (bitsize - 1))
assert t_complexity(EqualsAConstant(bitsize, 17)) == TComplexity(
t=4 * (bitsize - 1), clifford=65
)


@pytest.mark.notebook
Expand Down
Loading
Loading