Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unary iteration bloq from Cirq-FT #399

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qualtran/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# Internal imports: none
# External:
# - numpy: multiplying bitsizes, making cirq quregs
from ._infra.registers import Register, Signature, Side
from ._infra.registers import Register, SelectionRegister, Signature, Side

# Internal imports: none
# External imports: none
Expand Down
65 changes: 65 additions & 0 deletions qualtran/_infra/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,71 @@ def total_bits(self) -> int:
return self.bitsize * int(np.product(self.shape))


@frozen
class SelectionRegister(Register):
"""Register used to represent SELECT register for various LCU methods.

`SelectionRegister` extends the `Register` class to store the iteration length
corresponding to that register along with its size.

LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range
of values stored as a superposition over the `SELECT` register. Such (nested) coherent
for-loops can be represented using a `Tuple[SelectionRegister, ...]` where the i'th entry
stores the bitsize and iteration length of i'th nested for-loop.

One useful feature when processing such nested for-loops is to flatten out a composite index,
represented by a tuple of indices (i, j, ...), one for each selection register into a single
integer that can be used to index a flat target register. An example of such a mapping
function is described in Eq.45 of https://arxiv.org/abs/1805.03662. A general version of this
mapping function can be implemented using `numpy.ravel_multi_index` and `numpy.unravel_index`.

For example:
1) We can flatten a 2D for-loop as follows
>>> import numpy as np
>>> N, M = 10, 20
>>> flat_indices = set()
>>> for x in range(N):
... for y in range(M):
... flat_idx = x * M + y
... assert np.ravel_multi_index((x, y), (N, M)) == flat_idx
... assert np.unravel_index(flat_idx, (N, M)) == (x, y)
... flat_indices.add(flat_idx)
>>> assert len(flat_indices) == N * M

2) Similarly, we can flatten a 3D for-loop as follows
>>> import numpy as np
>>> N, M, L = 10, 20, 30
>>> flat_indices = set()
>>> for x in range(N):
... for y in range(M):
... for z in range(L):
... flat_idx = x * M * L + y * L + z
... assert np.ravel_multi_index((x, y, z), (N, M, L)) == flat_idx
... assert np.unravel_index(flat_idx, (N, M, L)) == (x, y, z)
... flat_indices.add(flat_idx)
>>> assert len(flat_indices) == N * M * L
"""

name: str
bitsize: int
iteration_length: int = field()
shape: Tuple[int, ...] = field(
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
)
side: Side = Side.THRU

@iteration_length.default
def _default_iteration_length(self):
return 2**self.bitsize

@iteration_length.validator
def validate_iteration_length(self, attribute, value):
if len(self.shape) != 0:
raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}')
if not (0 <= value <= 2**self.bitsize):
raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]')


def _dedupe(kv_iter: Iterable[Tuple[str, Register]]) -> Dict[str, Register]:
"""Construct a dictionary, but check that there are no duplicate keys."""
# throw ValueError if duplicate keys are provided.
Expand Down
42 changes: 41 additions & 1 deletion qualtran/_infra/registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

import cirq
import numpy as np
import pytest

from qualtran import Register, Side, Signature
from qualtran import Register, SelectionRegister, Side, Signature


def test_register():
Expand All @@ -37,6 +38,45 @@ def test_multidim_register():
assert r.total_bits() == 2 * 3


@pytest.mark.parametrize('n, N, m, M', [(4, 10, 5, 19), (4, 16, 5, 32)])
def test_selection_registers_indexing(n, N, m, M):
regs = [SelectionRegister('x', n, N), SelectionRegister('y', m, M)]
for x in range(regs[0].iteration_length):
for y in range(regs[1].iteration_length):
assert np.ravel_multi_index((x, y), (N, M)) == x * M + y
assert np.unravel_index(x * M + y, (N, M)) == (x, y)

assert np.prod(tuple(reg.iteration_length for reg in regs)) == N * M


def test_selection_registers_consistent():
with pytest.raises(ValueError, match="iteration length must be in "):
_ = SelectionRegister('a', 3, 10)

with pytest.raises(ValueError, match="should be flat"):
_ = SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5)

selection_reg = Signature(
[
SelectionRegister('n', bitsize=3, iteration_length=5),
SelectionRegister('m', bitsize=4, iteration_length=12),
]
)
assert selection_reg[0] == SelectionRegister('n', 3, 5)
assert selection_reg[1] == SelectionRegister('m', 4, 12)
assert selection_reg[:1] == tuple([SelectionRegister('n', 3, 5)])


def test_registers_getitem_raises():
g = Signature.build(a=4, b=3, c=2)
with pytest.raises(TypeError, match="indices must be integers or slices"):
_ = g[2.5]

selection_reg = Signature([SelectionRegister('n', bitsize=3, iteration_length=5)])
with pytest.raises(TypeError, match='indices must be integers or slices'):
_ = selection_reg[2.5]


def test_signature():
r1 = Register("r1", 5)
r2 = Register("r2", 2)
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/and_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ def on_classical_vals(self, ctrl: NDArray[np.uint8]) -> Dict[str, NDArray[np.uin
junk, target = accumulate_and[1:-1], accumulate_and[-1]
return {'ctrl': ctrl, 'junk': junk, 'target': target}

def __pow__(self, power: int) -> "And":
def __pow__(self, power: int) -> "MultiAnd":
if power == 1:
return self
if power == -1:
return And(self.cvs, adjoint=self.adjoint ^ True)
return MultiAnd(self.cvs, adjoint=self.adjoint ^ True)
return NotImplemented # pragma: no cover

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
Expand Down
Loading
Loading