Skip to content

Commit

Permalink
Add unary iteration bloq from Cirq-FT (#399)
Browse files Browse the repository at this point in the history
* Add unary iteration bloq from Cirq-FT

* Remove debug print

* Fix pylint
  • Loading branch information
tanujkhattar authored Oct 13, 2023
1 parent cf25d4f commit 2c173fc
Show file tree
Hide file tree
Showing 9 changed files with 1,516 additions and 4 deletions.
2 changes: 1 addition & 1 deletion qualtran/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# Internal imports: none
# External:
# - numpy: multiplying bitsizes, making cirq quregs
from ._infra.registers import Register, Signature, Side
from ._infra.registers import Register, SelectionRegister, Signature, Side

# Internal imports: none
# External imports: none
Expand Down
65 changes: 65 additions & 0 deletions qualtran/_infra/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,71 @@ def total_bits(self) -> int:
return self.bitsize * int(np.product(self.shape))


@frozen
class SelectionRegister(Register):
"""Register used to represent SELECT register for various LCU methods.
`SelectionRegister` extends the `Register` class to store the iteration length
corresponding to that register along with its size.
LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range
of values stored as a superposition over the `SELECT` register. Such (nested) coherent
for-loops can be represented using a `Tuple[SelectionRegister, ...]` where the i'th entry
stores the bitsize and iteration length of i'th nested for-loop.
One useful feature when processing such nested for-loops is to flatten out a composite index,
represented by a tuple of indices (i, j, ...), one for each selection register into a single
integer that can be used to index a flat target register. An example of such a mapping
function is described in Eq.45 of https://arxiv.org/abs/1805.03662. A general version of this
mapping function can be implemented using `numpy.ravel_multi_index` and `numpy.unravel_index`.
For example:
1) We can flatten a 2D for-loop as follows
>>> import numpy as np
>>> N, M = 10, 20
>>> flat_indices = set()
>>> for x in range(N):
... for y in range(M):
... flat_idx = x * M + y
... assert np.ravel_multi_index((x, y), (N, M)) == flat_idx
... assert np.unravel_index(flat_idx, (N, M)) == (x, y)
... flat_indices.add(flat_idx)
>>> assert len(flat_indices) == N * M
2) Similarly, we can flatten a 3D for-loop as follows
>>> import numpy as np
>>> N, M, L = 10, 20, 30
>>> flat_indices = set()
>>> for x in range(N):
... for y in range(M):
... for z in range(L):
... flat_idx = x * M * L + y * L + z
... assert np.ravel_multi_index((x, y, z), (N, M, L)) == flat_idx
... assert np.unravel_index(flat_idx, (N, M, L)) == (x, y, z)
... flat_indices.add(flat_idx)
>>> assert len(flat_indices) == N * M * L
"""

name: str
bitsize: int
iteration_length: int = field()
shape: Tuple[int, ...] = field(
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
)
side: Side = Side.THRU

@iteration_length.default
def _default_iteration_length(self):
return 2**self.bitsize

@iteration_length.validator
def validate_iteration_length(self, attribute, value):
if len(self.shape) != 0:
raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}')
if not (0 <= value <= 2**self.bitsize):
raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]')


def _dedupe(kv_iter: Iterable[Tuple[str, Register]]) -> Dict[str, Register]:
"""Construct a dictionary, but check that there are no duplicate keys."""
# throw ValueError if duplicate keys are provided.
Expand Down
42 changes: 41 additions & 1 deletion qualtran/_infra/registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

import cirq
import numpy as np
import pytest

from qualtran import Register, Side, Signature
from qualtran import Register, SelectionRegister, Side, Signature


def test_register():
Expand All @@ -37,6 +38,45 @@ def test_multidim_register():
assert r.total_bits() == 2 * 3


@pytest.mark.parametrize('n, N, m, M', [(4, 10, 5, 19), (4, 16, 5, 32)])
def test_selection_registers_indexing(n, N, m, M):
regs = [SelectionRegister('x', n, N), SelectionRegister('y', m, M)]
for x in range(regs[0].iteration_length):
for y in range(regs[1].iteration_length):
assert np.ravel_multi_index((x, y), (N, M)) == x * M + y
assert np.unravel_index(x * M + y, (N, M)) == (x, y)

assert np.prod(tuple(reg.iteration_length for reg in regs)) == N * M


def test_selection_registers_consistent():
with pytest.raises(ValueError, match="iteration length must be in "):
_ = SelectionRegister('a', 3, 10)

with pytest.raises(ValueError, match="should be flat"):
_ = SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5)

selection_reg = Signature(
[
SelectionRegister('n', bitsize=3, iteration_length=5),
SelectionRegister('m', bitsize=4, iteration_length=12),
]
)
assert selection_reg[0] == SelectionRegister('n', 3, 5)
assert selection_reg[1] == SelectionRegister('m', 4, 12)
assert selection_reg[:1] == tuple([SelectionRegister('n', 3, 5)])


def test_registers_getitem_raises():
g = Signature.build(a=4, b=3, c=2)
with pytest.raises(TypeError, match="indices must be integers or slices"):
_ = g[2.5]

selection_reg = Signature([SelectionRegister('n', bitsize=3, iteration_length=5)])
with pytest.raises(TypeError, match='indices must be integers or slices'):
_ = selection_reg[2.5]


def test_signature():
r1 = Register("r1", 5)
r2 = Register("r2", 2)
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/and_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ def on_classical_vals(self, ctrl: NDArray[np.uint8]) -> Dict[str, NDArray[np.uin
junk, target = accumulate_and[1:-1], accumulate_and[-1]
return {'ctrl': ctrl, 'junk': junk, 'target': target}

def __pow__(self, power: int) -> "And":
def __pow__(self, power: int) -> "MultiAnd":
if power == 1:
return self
if power == -1:
return And(self.cvs, adjoint=self.adjoint ^ True)
return MultiAnd(self.cvs, adjoint=self.adjoint ^ True)
return NotImplemented # pragma: no cover

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
Expand Down
Loading

0 comments on commit 2c173fc

Please sign in to comment.