diff --git a/qualtran/__init__.py b/qualtran/__init__.py index 14a904a4e..70ff0e7bf 100644 --- a/qualtran/__init__.py +++ b/qualtran/__init__.py @@ -45,7 +45,7 @@ # Internal imports: none # External: # - numpy: multiplying bitsizes, making cirq quregs -from ._infra.registers import Register, Signature, Side +from ._infra.registers import Register, SelectionRegister, Signature, Side # Internal imports: none # External imports: none diff --git a/qualtran/_infra/registers.py b/qualtran/_infra/registers.py index a83d8c536..cfd79e09b 100644 --- a/qualtran/_infra/registers.py +++ b/qualtran/_infra/registers.py @@ -78,6 +78,71 @@ def total_bits(self) -> int: return self.bitsize * int(np.product(self.shape)) +@frozen +class SelectionRegister(Register): + """Register used to represent SELECT register for various LCU methods. + + `SelectionRegister` extends the `Register` class to store the iteration length + corresponding to that register along with its size. + + LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range + of values stored as a superposition over the `SELECT` register. Such (nested) coherent + for-loops can be represented using a `Tuple[SelectionRegister, ...]` where the i'th entry + stores the bitsize and iteration length of i'th nested for-loop. + + One useful feature when processing such nested for-loops is to flatten out a composite index, + represented by a tuple of indices (i, j, ...), one for each selection register into a single + integer that can be used to index a flat target register. An example of such a mapping + function is described in Eq.45 of https://arxiv.org/abs/1805.03662. A general version of this + mapping function can be implemented using `numpy.ravel_multi_index` and `numpy.unravel_index`. + + For example: + 1) We can flatten a 2D for-loop as follows + >>> import numpy as np + >>> N, M = 10, 20 + >>> flat_indices = set() + >>> for x in range(N): + ... for y in range(M): + ... flat_idx = x * M + y + ... assert np.ravel_multi_index((x, y), (N, M)) == flat_idx + ... assert np.unravel_index(flat_idx, (N, M)) == (x, y) + ... flat_indices.add(flat_idx) + >>> assert len(flat_indices) == N * M + + 2) Similarly, we can flatten a 3D for-loop as follows + >>> import numpy as np + >>> N, M, L = 10, 20, 30 + >>> flat_indices = set() + >>> for x in range(N): + ... for y in range(M): + ... for z in range(L): + ... flat_idx = x * M * L + y * L + z + ... assert np.ravel_multi_index((x, y, z), (N, M, L)) == flat_idx + ... assert np.unravel_index(flat_idx, (N, M, L)) == (x, y, z) + ... flat_indices.add(flat_idx) + >>> assert len(flat_indices) == N * M * L + """ + + name: str + bitsize: int + iteration_length: int = field() + shape: Tuple[int, ...] = field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) + side: Side = Side.THRU + + @iteration_length.default + def _default_iteration_length(self): + return 2**self.bitsize + + @iteration_length.validator + def validate_iteration_length(self, attribute, value): + if len(self.shape) != 0: + raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') + if not (0 <= value <= 2**self.bitsize): + raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]') + + def _dedupe(kv_iter: Iterable[Tuple[str, Register]]) -> Dict[str, Register]: """Construct a dictionary, but check that there are no duplicate keys.""" # throw ValueError if duplicate keys are provided. diff --git a/qualtran/_infra/registers_test.py b/qualtran/_infra/registers_test.py index b01b3e381..8d485e468 100644 --- a/qualtran/_infra/registers_test.py +++ b/qualtran/_infra/registers_test.py @@ -13,9 +13,10 @@ # limitations under the License. import cirq +import numpy as np import pytest -from qualtran import Register, Side, Signature +from qualtran import Register, SelectionRegister, Side, Signature def test_register(): @@ -37,6 +38,45 @@ def test_multidim_register(): assert r.total_bits() == 2 * 3 +@pytest.mark.parametrize('n, N, m, M', [(4, 10, 5, 19), (4, 16, 5, 32)]) +def test_selection_registers_indexing(n, N, m, M): + regs = [SelectionRegister('x', n, N), SelectionRegister('y', m, M)] + for x in range(regs[0].iteration_length): + for y in range(regs[1].iteration_length): + assert np.ravel_multi_index((x, y), (N, M)) == x * M + y + assert np.unravel_index(x * M + y, (N, M)) == (x, y) + + assert np.prod(tuple(reg.iteration_length for reg in regs)) == N * M + + +def test_selection_registers_consistent(): + with pytest.raises(ValueError, match="iteration length must be in "): + _ = SelectionRegister('a', 3, 10) + + with pytest.raises(ValueError, match="should be flat"): + _ = SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5) + + selection_reg = Signature( + [ + SelectionRegister('n', bitsize=3, iteration_length=5), + SelectionRegister('m', bitsize=4, iteration_length=12), + ] + ) + assert selection_reg[0] == SelectionRegister('n', 3, 5) + assert selection_reg[1] == SelectionRegister('m', 4, 12) + assert selection_reg[:1] == tuple([SelectionRegister('n', 3, 5)]) + + +def test_registers_getitem_raises(): + g = Signature.build(a=4, b=3, c=2) + with pytest.raises(TypeError, match="indices must be integers or slices"): + _ = g[2.5] + + selection_reg = Signature([SelectionRegister('n', bitsize=3, iteration_length=5)]) + with pytest.raises(TypeError, match='indices must be integers or slices'): + _ = selection_reg[2.5] + + def test_signature(): r1 = Register("r1", 5) r2 = Register("r2", 2) diff --git a/qualtran/bloqs/and_bloq.py b/qualtran/bloqs/and_bloq.py index 3dd6d7fba..ef0fa1819 100644 --- a/qualtran/bloqs/and_bloq.py +++ b/qualtran/bloqs/and_bloq.py @@ -218,11 +218,11 @@ def on_classical_vals(self, ctrl: NDArray[np.uint8]) -> Dict[str, NDArray[np.uin junk, target = accumulate_and[1:-1], accumulate_and[-1] return {'ctrl': ctrl, 'junk': junk, 'target': target} - def __pow__(self, power: int) -> "And": + def __pow__(self, power: int) -> "MultiAnd": if power == 1: return self if power == -1: - return And(self.cvs, adjoint=self.adjoint ^ True) + return MultiAnd(self.cvs, adjoint=self.adjoint ^ True) return NotImplemented # pragma: no cover def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: diff --git a/qualtran/bloqs/unary_iteration.ipynb b/qualtran/bloqs/unary_iteration.ipynb new file mode 100644 index 000000000..1e9bc1afe --- /dev/null +++ b/qualtran/bloqs/unary_iteration.ipynb @@ -0,0 +1,575 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "e2fa907b", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2023 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "49b5e1e6", + "metadata": {}, + "source": [ + "# Unary Iteration" + ] + }, + { + "cell_type": "markdown", + "id": "fcdb39f2", + "metadata": {}, + "source": [ + "Given an array of potential operations, for example:\n", + "\n", + " ops = [X(i) for i in range(5)]\n", + " \n", + "we would like to select an operation to apply:\n", + "\n", + " n = 4 --> apply ops[4]\n", + " \n", + "If $n$ is a quantum integer, we need to apply the transformation\n", + "\n", + "$$\n", + " |n \\rangle |\\psi\\rangle \\rightarrow |n\\rangle \\, \\mathrm{ops}_n \\cdot |\\psi\\rangle\n", + "$$\n", + "\n", + "The simplest conceptual way to do this is to use a \"total control\" quantum circuit where you introduce a multi-controlled operation for each of the `len(ops)` possible `n` values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0148f529", + "metadata": {}, + "outputs": [], + "source": [ + "import cirq\n", + "from cirq.contrib.svg import SVGCircuit\n", + "import numpy as np\n", + "from typing import *" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32e90969", + "metadata": {}, + "outputs": [], + "source": [ + "import operator\n", + "import cirq._compat\n", + "import itertools" + ] + }, + { + "cell_type": "markdown", + "id": "a6d947da", + "metadata": {}, + "source": [ + "## Total Control\n", + "\n", + "Here, we'll use Sympy's boolean logic to show how total control works. We perform an `And( ... )` for each possible bit pattern. We use an `Xnor` on each selection bit to toggle whether it's a positive or negative control (filled or open circle in quantum circuit diagrams).\n", + "\n", + "In this example, we indeed consider $X_n$ as our potential operations and toggle bits in the `target` register according to the total control." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e61bf03", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "import sympy as S\n", + "import sympy.logic.boolalg as slb\n", + "\n", + "def total_control(selection, target):\n", + " \"\"\"Toggle bits in `target` depending on `selection`.\"\"\"\n", + " print(f\"Selection is {selection}\")\n", + " \n", + " for n, trial in enumerate(itertools.product((0, 1), repeat=len(selection))):\n", + " print(f\"Step {n}, apply total control: {trial}\")\n", + " target[n] ^= slb.And(*[slb.Xnor(s, t) for s, t in zip(selection, trial)])\n", + " \n", + " if target[n] == S.true:\n", + " print(f\" -> At this stage, {n}= and our output bit is set\")\n", + "\n", + " \n", + "selection = [0, 0, 0]\n", + "target = [False]*8\n", + "total_control(selection, target) \n", + "print()\n", + "print(\"Target:\")\n", + "print(target)" + ] + }, + { + "cell_type": "markdown", + "id": "e572a31d", + "metadata": {}, + "source": [ + "Note that our target register shows we have indeed applied $X_\\mathrm{0b010}$. Try changing `selection` to other bit patterns and notice how it changes." + ] + }, + { + "cell_type": "markdown", + "id": "a4a75f61", + "metadata": {}, + "source": [ + "Of course, we don't know what state the selection register will be in. We can use sympy's support for symbolic boolean logic to verify our gadget for all possible selection inputs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5df67d45", + "metadata": {}, + "outputs": [], + "source": [ + "selection = [S.Symbol(f's{i}') for i in range(3)]\n", + "target = [S.false for i in range(2**len(selection)) ]\n", + "total_control(selection, target)\n", + "\n", + "print()\n", + "print(\"Target:\")\n", + "for n, t in enumerate(target):\n", + " print(f'{n}= {t}')\n", + " \n", + "tc_target = target.copy()" + ] + }, + { + "cell_type": "markdown", + "id": "deab0553", + "metadata": {}, + "source": [ + "As expected, the \"not pattern\" (where `~` is boolean not) matches the binary representations of `n`." + ] + }, + { + "cell_type": "markdown", + "id": "81b69e70", + "metadata": {}, + "source": [ + "## Unary Iteration with segment trees\n", + "\n", + "A [segment tree](https://en.wikipedia.org/wiki/Segment_tree) is a data structure that allows logrithmic-time querying of intervals. We use a segment tree where each interval is length 1 and comprises all the `n` integers we may select.\n", + "\n", + "It is defined recursively by dividing the input interval into two half-size intervals until the left limit meets the right limit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab998aa4", + "metadata": {}, + "outputs": [], + "source": [ + "def segtree(ctrl, selection, target, depth, left, right):\n", + " \"\"\"Toggle bits in `target` depending on `selection` using a recursive segment tree.\"\"\"\n", + " print(f'depth={depth} left={left} right={right}', end=' ')\n", + " \n", + " if left == (right - 1):\n", + " # Leaf of the recusion.\n", + " print(f'n={n} ctrl={ctrl}')\n", + " target[left] ^= ctrl\n", + " return \n", + " print()\n", + " \n", + " assert depth < len(selection)\n", + " mid = (left + right) >> 1\n", + " \n", + " # Recurse left interval\n", + " new_ctrl = slb.And(ctrl, slb.Not(selection[depth]))\n", + " segtree(ctrl=new_ctrl, selection=selection, target=target, depth=depth+1, left=left, right=mid)\n", + " \n", + " # Recurse right interval\n", + " new_ctrl = slb.And(ctrl, selection[depth])\n", + " segtree(ctrl=new_ctrl, selection=selection, target=target, depth=depth+1, left=mid, right=right)\n", + " \n", + " # Quantum note:\n", + " # instead of throwing away the first value of `new_ctrl` and re-anding\n", + " # with selection, we can just invert the first one (but only if `ctrl` is active)\n", + " # new_ctrl ^= ctrl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a514ee6", + "metadata": {}, + "outputs": [], + "source": [ + "selection = [S.Symbol(f's{i}') for i in range(3)]\n", + "target = [S.false for i in range(2**len(selection)) ]\n", + "segtree(S.true, selection, target, 0, 0, 2**len(selection))\n", + "\n", + "print()\n", + "print(\"Target:\")\n", + "for n, t in enumerate(target):\n", + " print(f'n={n} {slb.simplify_logic(t)}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23d91438", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{'n':3s} | {'segtree':18s} | {'total control':18s} | same?\")\n", + "for n, (t1, t2) in enumerate(zip(target, tc_target)):\n", + " t1 = slb.simplify_logic(t1)\n", + " print(f'{n:3d} | {str(t1):18s} | {str(t2):18s} | {str(t1==t2)}')" + ] + }, + { + "cell_type": "markdown", + "id": "e39448e6", + "metadata": {}, + "source": [ + "## Quantum Circuit\n", + "\n", + "We can translate the boolean logic to reversible, quantum logic. It is instructive to start from the suboptimal total control quantum circuit for comparison purposes. We can build this as in the sympy boolean-logic case by adding controlled X operations to the target signature, with the controls on the selection signature toggled on or off according to the binary representation of the selection index.\n", + "\n", + "Let us first build a GateWithRegisters object to implement the circuit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b37d717", + "metadata": {}, + "outputs": [], + "source": [ + "import cirq\n", + "from cirq._compat import cached_property\n", + "from qualtran import Signature, GateWithRegisters\n", + "from qualtran.cirq_interop.bit_tools import iter_bits\n", + "\n", + "class TotallyControlledNot(GateWithRegisters):\n", + " \n", + " def __init__(self, selection_bitsize: int, target_bitsize: int, control_bitsize: int = 1):\n", + " self._selection_bitsize = selection_bitsize\n", + " self._target_bitsize = target_bitsize\n", + " self._control_bitsize = control_bitsize\n", + "\n", + " @cached_property\n", + " def signature(self) -> Signature:\n", + " return Signature(\n", + " [\n", + " *Signature.build(control=self._control_bitsize),\n", + " *Signature.build(selection=self._selection_bitsize),\n", + " *Signature.build(target=self._target_bitsize)\n", + " ]\n", + " )\n", + "\n", + " def decompose_from_registers(self, **qubit_regs: Sequence[cirq.Qid]) -> cirq.OP_TREE:\n", + " num_controls = self._control_bitsize + self._selection_bitsize\n", + " for target_bit in range(self._target_bitsize):\n", + " bit_pattern = iter_bits(target_bit, self._selection_bitsize)\n", + " control_values = [1]*self._control_bitsize + list(bit_pattern)\n", + " yield cirq.X.controlled(\n", + " num_controls=num_controls,\n", + " control_values=control_values\n", + " ).on(\n", + " *qubit_regs[\"control\"], \n", + " *qubit_regs[\"selection\"],\n", + " qubit_regs[\"target\"][-(target_bit+1)])\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f7b6758", + "metadata": {}, + "outputs": [], + "source": [ + "import qualtran.cirq_interop.testing as cq_testing\n", + "tc_not = TotallyControlledNot(3, 5)\n", + "tc = cq_testing.GateHelper(tc_not)\n", + "cirq.Circuit((cirq.decompose_once(tc.operation)))\n", + "SVGCircuit(cirq.Circuit(cirq.decompose_once(tc.operation)))" + ] + }, + { + "cell_type": "markdown", + "id": "7b28663a", + "metadata": {}, + "source": [ + "## Tests for Correctness\n", + "\n", + "We can use a full statevector simulation to compare the desired statevector to the one generated by the unary iteration circuit for each basis state." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "574c5058", + "metadata": {}, + "outputs": [], + "source": [ + "selection_bitsize = 3\n", + "target_bitsize = 5\n", + "for n in range(target_bitsize):\n", + " # Initial qubit values\n", + " qubit_vals = {q: 0 for q in tc.all_qubits}\n", + " # All controls 'on' to activate circuit\n", + " qubit_vals.update({c: 1 for c in tc.quregs['control']})\n", + " # Set selection according to `n`\n", + " qubit_vals.update(zip(tc.quregs['selection'], iter_bits(n, selection_bitsize)))\n", + "\n", + " initial_state = [qubit_vals[x] for x in tc.all_qubits]\n", + " final_state = [qubit_vals[x] for x in tc.all_qubits]\n", + " final_state[-(n+1)] = 1\n", + " cq_testing.assert_circuit_inp_out_cirqsim(\n", + " tc.circuit, tc.all_qubits, initial_state, final_state\n", + " )\n", + " print(f'n={n} checked!')" + ] + }, + { + "cell_type": "markdown", + "id": "d76fcf8f", + "metadata": {}, + "source": [ + "## Towards a segment tree \n", + "\n", + "Next let's see how we can reduce the circuit to the observe the tree structure.\n", + "First let's recall what we are trying to do with the controlled not. Given a\n", + "selection integer (say 3 = 011), we want to toggle the bit in the target\n", + "register to on if the qubit 1 and 2 are set to on in the selection register." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3aca2666", + "metadata": {}, + "outputs": [], + "source": [ + "# The selection bits [1-3] are set according to binary representation of the number 3 (011)\n", + "initial_state = [1, 0, 1, 1, 0, 0, 0, 0, 0]\n", + "final_state = [1, 0, 1, 1, 0, 1, 0, 0, 0]\n", + "actual, should_be = cq_testing.get_circuit_inp_out_cirqsim(\n", + " tc.circuit, tc.all_qubits, initial_state, final_state\n", + " )\n", + "print(\"simulated: \", actual)\n", + "print(\"expected : \", should_be)\n" + ] + }, + { + "cell_type": "markdown", + "id": "4640eeed", + "metadata": {}, + "source": [ + "Now what is important to note is that we can remove many repeated controlled operations by using ancilla qubits to flag what part of the circuit we need to apply, this works because we know the bit pattern of nearby integers is very similar. \n", + "\n", + "A circuit demonstrating this for our example is given below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef853ae7", + "metadata": {}, + "outputs": [], + "source": [ + "from qualtran.bloqs.and_bloq import And\n", + "\n", + "selection_bitsize = 2\n", + "target_bitsize = 4\n", + "qubits = cirq.LineQubit(0).range(1 + selection_bitsize * 2 + target_bitsize)\n", + "circuit = cirq.Circuit()\n", + "circuit.append(\n", + " [\n", + " And(1, 0).on(qubits[0], qubits[1], qubits[2]),\n", + " And(1, 0).on(qubits[2], qubits[3], qubits[4]),\n", + " cirq.CX(qubits[4], qubits[8]),\n", + " cirq.CNOT(qubits[2], qubits[4]),\n", + " cirq.CX(qubits[4], qubits[7]),\n", + " And(adjoint=True).on(qubits[2], qubits[3], qubits[4]),\n", + " cirq.CNOT(qubits[0], qubits[2]),\n", + " And(1, 0).on(qubits[2], qubits[3], qubits[4]),\n", + " cirq.CX(qubits[4], qubits[6]),\n", + " cirq.CNOT(qubits[2], qubits[4]),\n", + " cirq.CX(qubits[4], qubits[5]),\n", + " And(adjoint=True).on(qubits[2], qubits[3], qubits[4]),\n", + " And(adjoint=True).on(qubits[0], qubits[1], qubits[2]),\n", + " ]\n", + ")\n", + "\n", + "SVGCircuit(circuit)" + ] + }, + { + "cell_type": "markdown", + "id": "b9d45d52", + "metadata": {}, + "source": [ + "Reading from left to right we first check the control is set to on and the selection qubit is off, if both these conditions are met then the ancilla qubit is now set to 1. The next control checks if the previous condition was met and likewise the second selection index is also off. At this point if both these conditions are met we must be indexing 0 as the first two qubits are set to off (00), otherwise we know that we want to apply X to qubit 1 so we perform a CNOT operation to flip the bit value in the second ancilla qubit, before returning back up the circuit. Now if the left half of the circuit was not applied (i.e. the first selection register was set to 1) then the CNOT between the control qubit and the first ancilla qubit causes the ancilla qubit to toggle on. This triggers the right side of the circuit, which now performs the previously described operations to figure out if the lowest bit is set. Combining these two then yields the expected controlled X operation. \n", + "\n", + "Below we check the circuit is giving the expected behaviour." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83d1287d", + "metadata": {}, + "outputs": [], + "source": [ + "initial_state = [1, 0, 0, 0, 0, 0, 0, 0, 0]\n", + "target_indx = 3\n", + "sel_bits = list(iter_bits(target_indx, selection_bitsize))\n", + "sel_indices = [i for i in range(1, 2*selection_bitsize+1, 2)]\n", + "initial_state[sel_indices[0]] = sel_bits[0]\n", + "initial_state[sel_indices[1]] = sel_bits[1]\n", + "result = cirq.Simulator(dtype=np.complex128).simulate(\n", + " circuit, initial_state=initial_state\n", + ")\n", + "actual = result.dirac_notation(decimals=2)[1:-1]\n", + "print(\"simulated: {}, index set in string {}\".format(actual, len(qubits)-1-target_indx))" + ] + }, + { + "cell_type": "markdown", + "id": "a86e0d42", + "metadata": {}, + "source": [ + "Extending the above idea to larger ranges of integers is relatively straightforward. For example consider the next simplest case of $L=8 = 2^3$. The circuit above takes care of the last two bits and can be duplicated. For the extra bit we just need to add a additional `AND` operations, and a CNOT to switch between the original range `[0,3]` or the new range `[4,7]` depending on whether the new selection register is off or on respectively. This procedure can be repeated and we can begin to notice the recursive tree-like structure. \n", + "\n", + "This structure is just the segtree described previously for boolean logic and this gives is the basic idea of unary iteration, \n", + "which uses `L-1` `AND` operations. Below the `ApplyXToLthQubit` builds the controlled Not operation using the `UnaryIterationGate` as a base class which defines the `decompose_from_registers` method appropriately to recursively construct the unary iteration circuit.\n", + "\n", + "Note below a different ordering of ancilla and selection qubits is taken to what was used in the simpler `L=4` example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cba52b1", + "metadata": {}, + "outputs": [], + "source": [ + "from cirq_ft import Register, SelectionRegister, UnaryIterationGate\n", + "from cirq._compat import cached_property\n", + "\n", + "class ApplyXToLthQubit(UnaryIterationGate):\n", + " def __init__(self, selection_bitsize: int, target_bitsize: int, control_bitsize: int = 1):\n", + " self._selection_bitsize = selection_bitsize\n", + " self._target_bitsize = target_bitsize\n", + " self._control_bitsize = control_bitsize\n", + "\n", + " @cached_property\n", + " def control_registers(self) -> Tuple[Register, ...]:\n", + " return Register('control', self._control_bitsize),\n", + "\n", + " @cached_property\n", + " def selection_registers(self) -> Tuple[SelectionRegister, ...]:\n", + " return SelectionRegister('selection', self._selection_bitsize, self._target_bitsize),\n", + "\n", + " @cached_property\n", + " def target_registers(self) -> Tuple[Register, ...]:\n", + " return Register('target', self._target_bitsize),\n", + "\n", + " def nth_operation(\n", + " self, context, control: cirq.Qid, selection: int, target: Sequence[cirq.Qid]\n", + " ) -> cirq.OP_TREE:\n", + " return cirq.CNOT(control, target[-(selection + 1)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1e4bafa", + "metadata": {}, + "outputs": [], + "source": [ + "import cirq_ft.infra.testing as cq_testing\n", + "selection_bitsize = 3\n", + "target_bitsize = 5\n", + "\n", + "g = cq_testing.GateHelper(\n", + " ApplyXToLthQubit(selection_bitsize, target_bitsize))\n", + "SVGCircuit(cirq.Circuit(cirq.decompose_once(g.operation)))" + ] + }, + { + "cell_type": "markdown", + "id": "13773620", + "metadata": {}, + "source": [ + "## Tests for Correctness\n", + "\n", + "We can use a full statevector simulation to check again that the optimized circuit produces the expected result." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32ae469b", + "metadata": {}, + "outputs": [], + "source": [ + "from cirq_ft.infra.bit_tools import iter_bits\n", + "\n", + "for n in range(target_bitsize):\n", + " # Initial qubit values\n", + " qubit_vals = {q: 0 for q in g.all_qubits}\n", + " # All controls 'on' to activate circuit\n", + " qubit_vals.update({c: 1 for c in g.quregs['control']})\n", + " # Set selection according to `n`\n", + " qubit_vals.update(zip(g.quregs['selection'], iter_bits(n, selection_bitsize)))\n", + "\n", + " initial_state = [qubit_vals[x] for x in g.all_qubits]\n", + " qubit_vals[g.quregs['target'][-(n + 1)]] = 1\n", + " final_state = [qubit_vals[x] for x in g.all_qubits]\n", + " cq_testing.assert_circuit_inp_out_cirqsim(\n", + " g.decomposed_circuit, g.all_qubits, initial_state, final_state\n", + " )\n", + " print(f'n={n} checked!')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/qualtran/bloqs/unary_iteration_bloq.py b/qualtran/bloqs/unary_iteration_bloq.py new file mode 100644 index 000000000..9452288fe --- /dev/null +++ b/qualtran/bloqs/unary_iteration_bloq.py @@ -0,0 +1,453 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Callable, Dict, Iterator, List, Sequence, Tuple + +import cirq +import numpy as np +from cirq._compat import cached_property +from numpy.typing import NDArray + +from qualtran import GateWithRegisters, Register, SelectionRegister, Signature +from qualtran._infra.gate_with_registers import merge_qubits, total_bits +from qualtran.bloqs import and_bloq + + +def _unary_iteration_segtree( + ops: List[cirq.Operation], + control: cirq.Qid, + selection: Sequence[cirq.Qid], + ancilla: Sequence[cirq.Qid], + sl: int, + l: int, + r: int, + l_iter: int, + r_iter: int, + break_early: Callable[[int, int], bool], +) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: + """Constructs a unary iteration circuit by iterating over nodes of an implicit Segment Tree. + + Args: + ops: Operations accumulated so far while traversing the implicit segment tree. The + accumulated ops are yielded and cleared when we reach a leaf node. + control: The control qubit that controls the execution of the entire unary iteration + circuit represented by the current node of the segment tree. + selection: Sequence of selection qubits. The i'th qubit in the list corresponds to the i'th + level in the segment tree.Thus, a total of O(logN) selection qubits are required for a + tree on range `N = (r_iter - l_iter)`. + ancilla: Pre-allocated ancilla qubits to be used for constructing the unary iteration + circuit. + sl: Current depth of the tree. `selection[sl]` gives the selection qubit corresponding to + the current depth. + l: Left index of the range represented by current node of the segment tree. + r: Right index of the range represented by current node of the segment tree. + l_iter: Left index of iteration range over which the segment tree should be constructed. + r_iter: Right index of iteration range over which the segment tree should be constructed. + break_early: For each internal node of the segment tree, `break_early(l, r)` is called to + evaluate whether the unary iteration should terminate early and not recurse in the + subtree of the node representing range `[l, r)`. If True, the internal node is + considered equivalent to a leaf node and the method yields only one tuple + `(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`. + + Yields: + One `Tuple[cirq.OP_TREE, cirq.Qid, int]` for each leaf node in the segment tree. The i'th + yielded element corresponds to the i'th leaf node which represents the `l_iter + i`'th + integer. The tuple corresponds to: + - cirq.OP_TREE: Operations to be inserted in the circuit in between the last leaf node + (or beginning of iteration) to the current leaf node. + - cirq.Qid: The control qubit which can be controlled upon to execute the $U_{l}$ on a + target register when the selection register stores integer $l$. + - int: Integer $l$ which would be stored in the selection register if the control qubit + is set. + """ + if l >= r_iter or l_iter >= r: + # Range corresponding to this node is completely outside of iteration range. + return + if l_iter <= l < r <= r_iter and (l == (r - 1) or break_early(l, r)): + # Reached a leaf node or a "special" internal node; yield the operations. + yield tuple(ops), control, l + ops.clear() + return + assert sl < len(selection) + m = (l + r) >> 1 + if r_iter <= m: + # Yield only left sub-tree. + yield from _unary_iteration_segtree( + ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early + ) + return + if l_iter >= m: + # Yield only right sub-tree + yield from _unary_iteration_segtree( + ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early + ) + return + anc, sq = ancilla[sl], selection[sl] + ops.append(and_bloq.And(1, 0).on(control, sq, anc)) + yield from _unary_iteration_segtree( + ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early + ) + ops.append(cirq.CNOT(control, anc)) + yield from _unary_iteration_segtree( + ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early + ) + ops.append(and_bloq.And(adjoint=True).on(control, sq, anc)) + + +def _unary_iteration_zero_control( + ops: List[cirq.Operation], + selection: Sequence[cirq.Qid], + ancilla: Sequence[cirq.Qid], + l_iter: int, + r_iter: int, + break_early: Callable[[int, int], bool], +) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: + sl, l, r = 0, 0, 2 ** len(selection) + m = (l + r) >> 1 + ops.append(cirq.X(selection[0])) + yield from _unary_iteration_segtree( + ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter, break_early + ) + ops.append(cirq.X(selection[0])) + yield from _unary_iteration_segtree( + ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter, break_early + ) + + +def _unary_iteration_single_control( + ops: List[cirq.Operation], + control: cirq.Qid, + selection: Sequence[cirq.Qid], + ancilla: Sequence[cirq.Qid], + l_iter: int, + r_iter: int, + break_early: Callable[[int, int], bool], +) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: + sl, l, r = 0, 0, 2 ** len(selection) + yield from _unary_iteration_segtree( + ops, control, selection, ancilla, sl, l, r, l_iter, r_iter, break_early + ) + + +def _unary_iteration_multi_controls( + ops: List[cirq.Operation], + controls: Sequence[cirq.Qid], + selection: Sequence[cirq.Qid], + ancilla: Sequence[cirq.Qid], + l_iter: int, + r_iter: int, + break_early: Callable[[int, int], bool], +) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: + num_controls = len(controls) + and_ancilla = ancilla[: num_controls - 2] + and_target = ancilla[num_controls - 2] + if num_controls > 2: + multi_controlled_and = and_bloq.MultiAnd(cvs=(1,) * num_controls).on_registers( + ctrl=np.array(controls).reshape(num_controls, 1), + junk=np.array(and_ancilla).reshape(num_controls - 2, 1), + target=and_target, + ) + else: + multi_controlled_and = and_bloq.And(1, 1).on_registers( + ctrl=np.array(controls).reshape(num_controls, 1), target=and_target + ) + + ops.append(multi_controlled_and) + yield from _unary_iteration_single_control( + ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter, break_early + ) + ops.append(cirq.inverse(multi_controlled_and)) + + +def unary_iteration( + l_iter: int, + r_iter: int, + flanking_ops: List[cirq.Operation], + controls: Sequence[cirq.Qid], + selection: Sequence[cirq.Qid], + qubit_manager: cirq.QubitManager, + break_early: Callable[[int, int], bool] = lambda l, r: False, +) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: + """The method performs unary iteration on `selection` integer in `range(l_iter, r_iter)`. + + Unary iteration is a coherent for loop that can be used to conditionally perform a different + operation on a target register for every integer in the `range(l_iter, r_iter)` stored in the + selection register. + + Users can write multi-dimensional coherent for loops as follows: + + >>> import cirq + >>> from qualtran.bloqs.unary_iteration_gate import unary_iteration + >>> N, M = 5, 7 + >>> target = [[cirq.q(f't({i}, {j})') for j in range(M)] for i in range(N)] + >>> selection = [[cirq.q(f's({i}, {j})') for j in range(3)] for i in range(3)] + >>> circuit = cirq.Circuit() + >>> i_ops = [] + >>> qm = cirq.GreedyQubitManager("ancilla", maximize_reuse=True) + >>> for i_optree, i_ctrl, i in unary_iteration(0, N, i_ops, [], selection[0], qm): + ... circuit.append(i_optree) + ... j_ops = [] + ... for j_optree, j_ctrl, j in unary_iteration(0, M, j_ops, [i_ctrl], selection[1], qm): + ... circuit.append(j_optree) + ... # Conditionally perform operations on target register using `j_ctrl`, `i` & `j`. + ... circuit.append(cirq.CNOT(j_ctrl, target[i][j])) + ... circuit.append(j_ops) + >>> circuit.append(i_ops) + + Note: Unary iteration circuits assume that the selection register stores integers only in the + range `[l, r)` for which the corresponding unary iteration circuit should be built. + + Args: + l_iter: Starting index of the iteration range. + r_iter: Ending index of the iteration range. + flanking_ops: A list of `cirq.Operation`s that represents operations to be inserted in the + circuit before/after the first/last iteration of the unary iteration for loop. Note that + the list is mutated by the function, such that before calling the function, the list + represents operations to be inserted before the first iteration and after the last call + to the function, list represents operations to be inserted at the end of last iteration. + controls: Control register of qubits. + selection: Selection register of qubits. + qubit_manager: A `cirq.QubitManager` to allocate new qubits. + break_early: For each internal node of the segment tree, `break_early(l, r)` is called to + evaluate whether the unary iteration should terminate early and not recurse in the + subtree of the node representing range `[l, r)`. If True, the internal node is + considered equivalent to a leaf node and the method yields only one tuple + `(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`. + + Yields: + (r_iter - l_iter) different tuples, each corresponding to an integer in range + [l_iter, r_iter). + Each returned tuple also corresponds to a unique leaf in the unary iteration tree. + The values of yielded `Tuple[cirq.OP_TREE, cirq.Qid, int]` correspond to: + - cirq.OP_TREE: The op-tree to be inserted in the circuit to get to the current leaf. + - cirq.Qid: Control qubit used to conditionally apply operations on the target conditioned + on the returned integer. + - int: The current integer in the iteration `range(l_iter, r_iter)`. + """ + assert 2 ** len(selection) >= r_iter - l_iter + assert len(selection) > 0 + ancilla = qubit_manager.qalloc(max(0, len(controls) + len(selection) - 1)) + if len(controls) == 0: + yield from _unary_iteration_zero_control( + flanking_ops, selection, ancilla, l_iter, r_iter, break_early + ) + elif len(controls) == 1: + yield from _unary_iteration_single_control( + flanking_ops, controls[0], selection, ancilla, l_iter, r_iter, break_early + ) + else: + yield from _unary_iteration_multi_controls( + flanking_ops, controls, selection, ancilla, l_iter, r_iter, break_early + ) + qubit_manager.qfree(ancilla) + + +class UnaryIterationGate(GateWithRegisters): + """Base class for defining multiplexed gates that can execute a coherent for-loop. + + Unary iteration is a coherent for loop that can be used to conditionally perform a different + operation on a target register for every integer in the `range(l_iter, r_iter)` stored in the + selection register. + + `UnaryIterationGate` leverages the utility method `unary_iteration` to provide + a convenient API for users to define a multi-dimensional multiplexed gate that can execute + indexed operations on a target register depending on the index value stored in a selection + register. + + Note: Unary iteration circuits assume that the selection register stores integers only in the + range `[l, r)` for which the corresponding unary iteration circuit should be built. + + References: + [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity] + (https://arxiv.org/abs/1805.03662). + Babbush et. al. (2018). Section III.A. + """ + + @cached_property + @abc.abstractmethod + def control_registers(self) -> Tuple[Register, ...]: + pass + + @cached_property + @abc.abstractmethod + def selection_registers(self) -> Tuple[SelectionRegister, ...]: + pass + + @cached_property + @abc.abstractmethod + def target_registers(self) -> Tuple[Register, ...]: + pass + + @cached_property + def signature(self) -> Signature: + return Signature( + [*self.control_registers, *self.selection_registers, *self.target_registers] + ) + + @cached_property + def extra_registers(self) -> Tuple[Register, ...]: + return () + + @abc.abstractmethod + def nth_operation( + self, context: cirq.DecompositionContext, control: cirq.Qid, **kwargs + ) -> cirq.OP_TREE: + """Apply nth operation on the target signature when selection signature store `n`. + + The `UnaryIterationGate` class is a mixin that represents a coherent for-loop over + different indices (i.e. selection signature). This method denotes the "body" of the + for-loop, which is executed `self.selection_registers.total_iteration_size` times and each + iteration represents a unique combination of values stored in selection signature. For each + call, the method should return the operations that should be applied to the target + signature, given the values stored in selection signature. + + The derived classes should specify the following arguments as `**kwargs`: + 1) `control: cirq.Qid`: A qubit which can be used as a control to selectively + apply operations when selection register stores specific value. + 2) Register names in `self.selection_registers`: Each argument corresponds to + a selection register and represents the integer value stored in the register. + 3) Register names in `self.target_registers`: Each argument corresponds to a target + register and represents the sequence of qubits that represent the target register. + 4) Register names in `self.extra_regs`: Each argument corresponds to an extra + register and represents the sequence of qubits that represent the extra register. + """ + + def decompose_zero_selection( + self, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type: ignore[type-var] + ) -> cirq.OP_TREE: + """Specify decomposition of the gate when selection register is empty + + By default, if the selection register is empty, the decomposition will raise a + `NotImplementedError`. The derived classes can override this method and specify + a custom decomposition that should be used if the selection register is empty, + i.e. `total_bits(self.selection_registers) == 0`. + + The derived classes should specify the following arguments as `**kwargs`: + 1) Register names in `self.control_registers`: Each argument corresponds to a + control register and represents sequence of qubits that represent the control register. + 2) Register names in `self.target_registers`: Each argument corresponds to a target + register and represents the sequence of qubits that represent the target register. + 3) Register names in `self.extra_regs`: Each argument corresponds to an extra + register and represents the sequence of qubits that represent the extra register. + """ + raise NotImplementedError("Selection register must not be empty.") + + def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int) -> bool: + """Derived classes should override this method to specify an early termination condition. + + For each internal node of the unary iteration segment tree, `break_early(l, r)` is called + to evaluate whether the unary iteration should not recurse in the subtree of the node + representing range `[l, r)`. If True, the internal node is considered equivalent to a leaf + node and thus, `self.nth_operation` will be called for only integer `l` in the range [l, r). + + When the `UnaryIteration` class is constructed using multiple selection signature, i.e. we + wish to perform nested coherent for-loops, a unary iteration segment tree is constructed + corresponding to each nested coherent for-loop. For every such unary iteration segment tree, + the `_break_early` condition is checked by passing the `selection_index_prefix` tuple. + + Args: + selection_index_prefix: To evaluate the early breaking condition for the i'th nested + for-loop, the `selection_index_prefix` contains `i-1` integers corresponding to + the loop variable values for the first `i-1` nested loops. + l: Beginning of range `[l, r)` for internal node of unary iteration segment tree. + r: End (exclusive) of range `[l, r)` for internal node of unary iteration segment tree. + + Returns: + True of the `len(selection_index_prefix)`'th unary iteration should terminate early for + the given parameters. + """ + return False + + def decompose_from_registers( + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] + ) -> cirq.OP_TREE: + if total_bits(self.selection_registers) == 0 or self._break_early( + (), 0, self.selection_registers[0].iteration_length + ): + return self.decompose_zero_selection(context=context, **quregs) + + num_loops = len(self.selection_registers) + target_regs = {reg.name: quregs[reg.name] for reg in self.target_registers} + extra_regs = {reg.name: quregs[reg.name] for reg in self.extra_registers} + + def unary_iteration_loops( + nested_depth: int, + selection_reg_name_to_val: Dict[str, int], + controls: Sequence[cirq.Qid], + ) -> Iterator[cirq.OP_TREE]: + """Recursively write any number of nested coherent for-loops using unary iteration. + + This helper method is useful to write `num_loops` number of nested coherent for-loops by + recursively calling this method `num_loops` times. The ith recursive call of this method + has `nested_depth=i` and represents the body of ith nested for-loop. + + Args: + nested_depth: Integer between `[0, num_loops]` representing the nest-level of + for-loop for which this method implements the body. + selection_reg_name_to_val: A dictionary containing `nested_depth` elements mapping + the selection integer names (i.e. loop variables) to corresponding values; + for each of the `nested_depth` parent for-loops written before. + controls: Control qubits that should be used to conditionally activate the body of + this for-loop. + + Returns: + `cirq.OP_TREE` implementing `num_loops` nested coherent for-loops, with operations + returned by `self.nth_operation` applied conditionally to the target register based + on values of selection signature. + """ + if nested_depth == num_loops: + yield self.nth_operation( + context=context, + control=controls[0], + **selection_reg_name_to_val, + **target_regs, + **extra_regs, + ) + return + # Use recursion to write `num_loops` nested loops using unary_iteration(). + ops: List[cirq.Operation] = [] + selection_index_prefix = tuple(selection_reg_name_to_val.values()) + ith_for_loop = unary_iteration( + l_iter=0, + r_iter=self.selection_registers[nested_depth].iteration_length, + flanking_ops=ops, + controls=controls, + selection=[*quregs[self.selection_registers[nested_depth].name]], + qubit_manager=context.qubit_manager, + break_early=lambda l, r: self._break_early(selection_index_prefix, l, r), + ) + for op_tree, control_qid, n in ith_for_loop: + yield op_tree + selection_reg_name_to_val[self.selection_registers[nested_depth].name] = n + yield from unary_iteration_loops( + nested_depth + 1, selection_reg_name_to_val, (control_qid,) + ) + selection_reg_name_to_val.pop(self.selection_registers[nested_depth].name) + yield ops + + return unary_iteration_loops(0, {}, merge_qubits(self.control_registers, **quregs)) + + def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: + """Basic circuit diagram. + + Descendants are encouraged to override this with more descriptive + circuit diagram information. + """ + wire_symbols = ["@"] * total_bits(self.control_registers) + wire_symbols += ["In"] * total_bits(self.selection_registers) + wire_symbols += [self.__class__.__name__] * total_bits(self.target_registers) + return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/qualtran/bloqs/unary_iteration_bloq_test.py b/qualtran/bloqs/unary_iteration_bloq_test.py new file mode 100644 index 000000000..90c3b4491 --- /dev/null +++ b/qualtran/bloqs/unary_iteration_bloq_test.py @@ -0,0 +1,205 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from typing import Sequence, Tuple + +import cirq +import pytest +from cirq._compat import cached_property + +from qualtran import Register, SelectionRegister, Signature +from qualtran._infra.gate_with_registers import get_named_qubits, total_bits +from qualtran.bloqs.unary_iteration_bloq import unary_iteration, UnaryIterationGate +from qualtran.cirq_interop.bit_tools import iter_bits +from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim, GateHelper +from qualtran.testing import assert_valid_bloq_decomposition, execute_notebook + + +class ApplyXToLthQubit(UnaryIterationGate): + def __init__(self, selection_bitsize: int, target_bitsize: int, control_bitsize: int = 1): + self._selection_bitsize = selection_bitsize + self._target_bitsize = target_bitsize + self._control_bitsize = control_bitsize + + @cached_property + def control_registers(self) -> Tuple[Register, ...]: + return (Register('control', self._control_bitsize),) + + @cached_property + def selection_registers(self) -> Tuple[SelectionRegister, ...]: + return (SelectionRegister('selection', self._selection_bitsize, self._target_bitsize),) + + @cached_property + def target_registers(self) -> Tuple[Register, ...]: + return (Register('target', self._target_bitsize),) + + def nth_operation( # type: ignore[override] + self, + context: cirq.DecompositionContext, + control: cirq.Qid, + selection: int, + target: Sequence[cirq.Qid], + ) -> cirq.OP_TREE: + return cirq.CNOT(control, target[-(selection + 1)]) + + +@pytest.mark.parametrize( + "selection_bitsize, target_bitsize, control_bitsize", [(3, 5, 1), (2, 4, 2), (1, 2, 3)] +) +def test_unary_iteration_gate(selection_bitsize, target_bitsize, control_bitsize): + greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True) + gate = ApplyXToLthQubit(selection_bitsize, target_bitsize, control_bitsize) + g = GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) + assert len(g.all_qubits) <= 2 * (selection_bitsize + control_bitsize) + target_bitsize - 1 + + for n in range(target_bitsize): + # Initial qubit values + qubit_vals = {q: 0 for q in g.operation.qubits} + # All controls 'on' to activate circuit + qubit_vals.update({c: 1 for c in g.quregs['control']}) + # Set selection according to `n` + qubit_vals.update(zip(g.quregs['selection'], iter_bits(n, selection_bitsize))) + + initial_state = [qubit_vals[x] for x in g.operation.qubits] + qubit_vals[g.quregs['target'][-(n + 1)]] = 1 + final_state = [qubit_vals[x] for x in g.operation.qubits] + assert_circuit_inp_out_cirqsim(g.circuit, g.operation.qubits, initial_state, final_state) + + +class ApplyXToIJKthQubit(UnaryIterationGate): + def __init__(self, target_shape: Tuple[int, int, int]): + self._target_shape = target_shape + + @cached_property + def control_registers(self) -> Tuple[Register, ...]: + return () + + @cached_property + def selection_registers(self) -> Tuple[SelectionRegister, ...]: + return tuple( + SelectionRegister( + 'ijk'[i], (self._target_shape[i] - 1).bit_length(), self._target_shape[i] + ) + for i in range(3) + ) + + @cached_property + def target_registers(self) -> Tuple[Register, ...]: + return tuple( + Signature.build( + t1=self._target_shape[0], t2=self._target_shape[1], t3=self._target_shape[2] + ) + ) + + def nth_operation( # type: ignore[override] + self, + context: cirq.DecompositionContext, + control: cirq.Qid, + i: int, + j: int, + k: int, + t1: Sequence[cirq.Qid], + t2: Sequence[cirq.Qid], + t3: Sequence[cirq.Qid], + ) -> cirq.OP_TREE: + yield [cirq.CNOT(control, t1[i]), cirq.CNOT(control, t2[j]), cirq.CNOT(control, t3[k])] + + +@pytest.mark.parametrize("target_shape", [(2, 3, 2), (2, 2, 2)]) +def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, int]): + greedy_mm = cirq.GreedyQubitManager(prefix="_a", maximize_reuse=True) + gate = ApplyXToIJKthQubit(target_shape) + g = GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) + assert ( + len(g.all_qubits) <= total_bits(gate.signature) + total_bits(gate.selection_registers) - 1 + ) + + max_i, max_j, max_k = target_shape + i_len, j_len, k_len = tuple(reg.total_bits() for reg in gate.selection_registers) + for i, j, k in itertools.product(range(max_i), range(max_j), range(max_k)): + qubit_vals = {x: 0 for x in g.operation.qubits} + # Initialize selection bits appropriately: + qubit_vals.update(zip(g.quregs['i'], iter_bits(i, i_len))) + qubit_vals.update(zip(g.quregs['j'], iter_bits(j, j_len))) + qubit_vals.update(zip(g.quregs['k'], iter_bits(k, k_len))) + # Construct initial state + initial_state = [qubit_vals[x] for x in g.operation.qubits] + # Build correct statevector with selection_integer bit flipped in the target register: + for reg_name, idx in zip(['t1', 't2', 't3'], [i, j, k]): + qubit_vals[g.quregs[reg_name][idx]] = 1 + final_state = [qubit_vals[x] for x in g.operation.qubits] + assert_circuit_inp_out_cirqsim(g.circuit, g.operation.qubits, initial_state, final_state) + + +def test_unary_iteration_loop(): + n_range, m_range = (3, 5), (6, 8) + selection_registers = [SelectionRegister('n', 3, 5), SelectionRegister('m', 3, 8)] + selection = get_named_qubits(selection_registers) + target = {(n, m): cirq.q(f't({n}, {m})') for n in range(*n_range) for m in range(*m_range)} + qm = cirq.GreedyQubitManager("ancilla", maximize_reuse=True) + circuit = cirq.Circuit() + i_ops = [] + # Build the unary iteration circuit + for i_optree, i_ctrl, i in unary_iteration( + n_range[0], n_range[1], i_ops, [], selection['n'], qm + ): + circuit.append(i_optree) + j_ops = [] + for j_optree, j_ctrl, j in unary_iteration( + m_range[0], m_range[1], j_ops, [i_ctrl], selection['m'], qm + ): + circuit.append(j_optree) + # Conditionally perform operations on target register using `j_ctrl`, `i` & `j`. + circuit.append(cirq.CNOT(j_ctrl, target[(i, j)])) + circuit.append(j_ops) + circuit.append(i_ops) + all_qubits = sorted(circuit.all_qubits()) + + i_len, j_len = 3, 3 + for i, j in itertools.product(range(*n_range), range(*m_range)): + qubit_vals = {x: 0 for x in all_qubits} + # Initialize selection bits appropriately: + qubit_vals.update(zip(selection['n'], iter_bits(i, i_len))) + qubit_vals.update(zip(selection['m'], iter_bits(j, j_len))) + # Construct initial state + initial_state = [qubit_vals[x] for x in all_qubits] + # Build correct statevector with selection_integer bit flipped in the target register: + qubit_vals[target[(i, j)]] = 1 + final_state = [qubit_vals[x] for x in all_qubits] + assert_circuit_inp_out_cirqsim(circuit, all_qubits, initial_state, final_state) + + +def test_unary_iteration_loop_empty_range(): + qm = cirq.SimpleQubitManager() + assert list(unary_iteration(4, 4, [], [], [cirq.q('s')], qm)) == [] + assert list(unary_iteration(4, 3, [], [], [cirq.q('s')], qm)) == [] + + +@pytest.mark.parametrize( + "selection_bitsize, target_bitsize, control_bitsize", [(3, 5, 1), (2, 4, 2), (1, 2, 3)] +) +def test_bloq_has_consistent_decomposition(selection_bitsize, target_bitsize, control_bitsize): + bloq = ApplyXToLthQubit(selection_bitsize, target_bitsize, control_bitsize) + assert_valid_bloq_decomposition(bloq) + + +@pytest.mark.parametrize("target_shape", [(2, 3, 2), (2, 2, 2)]) +def test_multi_dimensional_bloq_has_consistent_decomposition(target_shape: Tuple[int, int, int]): + bloq = ApplyXToIJKthQubit(target_shape) + assert_valid_bloq_decomposition(bloq) + + +def test_notebook(): + execute_notebook('unary_iteration') diff --git a/qualtran/cirq_interop/bit_tools.py b/qualtran/cirq_interop/bit_tools.py new file mode 100644 index 000000000..c92c5c74c --- /dev/null +++ b/qualtran/cirq_interop/bit_tools.py @@ -0,0 +1,98 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterator, Tuple + +import numpy as np + + +def iter_bits(val: int, width: int, *, signed: bool = False) -> Iterator[int]: + """Iterate over the bits in a binary representation of `val`. + + This uses a big-endian convention where the most significant bit + is yielded first. + + Args: + val: The integer value. Its bitsize must fit within `width` + width: The number of output bits. + signed: If True, the most significant bit represents the sign of + the number (ones complement) which is 1 if val < 0 else 0. + Raises: + ValueError: If `val` is negative or if `val.bit_length()` exceeds `width`. + """ + if val.bit_length() + int(val < 0) > width: + raise ValueError(f"{val} exceeds width {width}.") + if val < 0 and not signed: + raise ValueError(f"{val} is negative.") + if signed: + yield 1 if val < 0 else 0 + width -= 1 + for b in f'{abs(val):0{width}b}': + yield int(b) + + +def iter_bits_twos_complement(val: int, width: int) -> Iterator[int]: + """Iterate over the bits in a binary representation of `val`. + + This uses a big-endian convention where the most significant bit + is yielded first. Allows for negative values and represents these using twos + complement. + + Args: + val: The integer value. Its bitsize must fit within `width` + width: The number of output bits. + + Raises: + ValueError: If `val.bit_length()` exceeds `2 * width + 1`. + """ + if (val.bit_length() - 1) // 2 > width: + raise ValueError(f"{val} exceeds width {width}.") + mask = (1 << width) - 1 + for b in f'{val&mask:0{width}b}': + yield int(b) + + +def iter_bits_fixed_point(val: float, width: int, *, signed: bool = False) -> Iterator[int]: + r"""Represent the floating point number -1 <= val <= 1 using `width` bits. + + $$ + val = \sum_{b=0}^{width - 1} val[b] / 2^{1+b} + $$ + + Args: + val: Floating point number in [-1, 1] + width: The number of output bits in fixed point binary representation of `val`. + signed: If True, the most significant bit represents the sign of + the number (ones complement) which is 1 if val < 0 else 0. + + Raises: + ValueError: If val is not between [0, 1] (signed=False) / [-1, 1] (signed=True). + """ + lb = -1 if signed else 0 + assert lb <= val <= 1, f"{val} must be between [{lb}, 1]" + if signed: + yield 1 if val < 0 else 0 + width -= 1 + val = abs(val) + for _ in range(width): + val = val * 2 + out_bit = np.floor(val) + val = val - out_bit + yield int(out_bit) + + +def float_as_fixed_width_int(val: float, width: int) -> Tuple[int, int]: + """Returns a `width` length fixed point binary representation of `val` where -1 <= val <= 1.""" + bits = [*iter_bits_fixed_point(val, width, signed=True)] + return bits[0], int(''.join(str(b) for b in bits[1:]), 2) diff --git a/qualtran/cirq_interop/bit_tools_test.py b/qualtran/cirq_interop/bit_tools_test.py new file mode 100644 index 000000000..911a034d0 --- /dev/null +++ b/qualtran/cirq_interop/bit_tools_test.py @@ -0,0 +1,76 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import random + +import pytest +from cirq_ft.infra.bit_tools import ( + float_as_fixed_width_int, + iter_bits, + iter_bits_fixed_point, + iter_bits_twos_complement, +) + + +def test_iter_bits(): + assert list(iter_bits(0, 2)) == [0, 0] + assert list(iter_bits(0, 3, signed=True)) == [0, 0, 0] + assert list(iter_bits(1, 2)) == [0, 1] + assert list(iter_bits(1, 2, signed=True)) == [0, 1] + assert list(iter_bits(-1, 2, signed=True)) == [1, 1] + assert list(iter_bits(2, 2)) == [1, 0] + assert list(iter_bits(2, 3, signed=True)) == [0, 1, 0] + assert list(iter_bits(-2, 3, signed=True)) == [1, 1, 0] + assert list(iter_bits(3, 2)) == [1, 1] + with pytest.raises(ValueError): + assert list(iter_bits(4, 2)) == [1, 0, 0] + with pytest.raises(ValueError): + _ = list(iter_bits(-3, 4)) + + +def test_iter_bits_twos(): + assert list(iter_bits_twos_complement(0, 4)) == [0, 0, 0, 0] + assert list(iter_bits_twos_complement(1, 4)) == [0, 0, 0, 1] + assert list(iter_bits_twos_complement(-2, 4)) == [1, 1, 1, 0] + assert list(iter_bits_twos_complement(-3, 4)) == [1, 1, 0, 1] + with pytest.raises(ValueError): + _ = list(iter_bits_twos_complement(100, 2)) + + +random.seed(1234) + + +@pytest.mark.parametrize('val', [random.uniform(-1, 1) for _ in range(10)]) +@pytest.mark.parametrize('width', [*range(2, 20, 2)]) +@pytest.mark.parametrize('signed', [True, False]) +def test_iter_bits_fixed_point(val, width, signed): + if (val < 0) and not signed: + with pytest.raises(AssertionError): + _ = [*iter_bits_fixed_point(val, width, signed=signed)] + else: + bits = [*iter_bits_fixed_point(val, width, signed=signed)] + if signed: + sign, bits = bits[0], bits[1:] + assert sign == (1 if val < 0 else 0) + val = abs(val) + approx_val = math.fsum([b * (1 / 2 ** (1 + i)) for i, b in enumerate(bits)]) + unsigned_width = width - 1 if signed else width + assert math.isclose( + val, approx_val, abs_tol=1 / 2**unsigned_width + ), f'{val}:{approx_val}:{width}' + bits_from_int = [ + *iter_bits(float_as_fixed_width_int(val, unsigned_width + 1)[1], unsigned_width) + ] + assert bits == bits_from_int