Skip to content

Commit

Permalink
Refactor QPE bloqs (#1297)
Browse files Browse the repository at this point in the history
* Refactor QPE bloqs

* More refactoring, fix failing tests

* Fix formatting

* Fix formatting and imports for notebook

* Fix pylint

* Docstrings and bloq auto testing improvements
  • Loading branch information
tanujkhattar authored Aug 21, 2024
1 parent 78d7b23 commit c0acc8e
Show file tree
Hide file tree
Showing 14 changed files with 484 additions and 263 deletions.
3 changes: 2 additions & 1 deletion dev_tools/autogenerate-bloqs-notebooks-v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,8 @@
title='Textbook Quantum Phase Estimation',
module=qualtran.bloqs.phase_estimation.text_book_qpe,
bloq_specs=[
qualtran.bloqs.phase_estimation.text_book_qpe._CC_TEXTBOOK_PHASE_ESTIMATION_DOC
qualtran.bloqs.phase_estimation.qpe_window_state._CC_RECTANGULAR_WINDOW_STATE_DOC,
qualtran.bloqs.phase_estimation.text_book_qpe._CC_TEXTBOOK_PHASE_ESTIMATION_DOC,
],
),
NotebookSpecV2(
Expand Down
1 change: 1 addition & 0 deletions qualtran/bloqs/phase_estimation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
# limitations under the License.

from qualtran.bloqs.phase_estimation.lp_resource_state import LPResourceState
from qualtran.bloqs.phase_estimation.qpe_window_state import RectangularWindowState
from qualtran.bloqs.phase_estimation.qubitization_qpe import QubitizationQPE
from qualtran.bloqs.phase_estimation.text_book_qpe import TextbookQPE
143 changes: 80 additions & 63 deletions qualtran/bloqs/phase_estimation/lp_resource_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,23 @@
# limitations under the License.

"""Resource states proposed by A. Luis and J. Peřina (1996) for optimal phase measurements"""
from collections import Counter
from functools import cached_property
from typing import Iterator, Set, Tuple, TYPE_CHECKING, Union
from typing import Dict, Set, TYPE_CHECKING

import attrs
import cirq
import numpy as np
import sympy
from numpy.typing import NDArray

from qualtran import (
Bloq,
bloq_example,
BloqDocSpec,
GateWithRegisters,
QUInt,
Register,
Side,
Signature,
)
from qualtran.bloqs.basic_gates import CZPowGate, GlobalPhase, Hadamard, OnEach, Ry, Rz, XGate
from qualtran.bloqs.mcmt import MultiControlZ

from qualtran import Bloq, bloq_example, BloqDocSpec, GateWithRegisters, QBit, Signature
from qualtran.bloqs.basic_gates import CZ, Hadamard, OnEach, Ry, Rz, XGate
from qualtran.bloqs.phase_estimation.qpe_window_state import QPEWindowStateBase
from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.symbolics import acos, HasLength, is_symbolic, pi, SymbolicInt
from qualtran.symbolics import acos, ceil, is_symbolic, log2, pi, SymbolicFloat, SymbolicInt

if TYPE_CHECKING:
from qualtran import BloqBuilder, Soquet, SoquetT
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator


Expand Down Expand Up @@ -67,32 +59,33 @@ def signature(self) -> 'Signature':
def pretty_name(self) -> str:
return 'LPRS'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> Iterator[cirq.OP_TREE]:
def build_composite_bloq(
self, bb: 'BloqBuilder', *, m: 'SoquetT', anc: 'Soquet'
) -> Dict[str, 'SoquetT']:
if isinstance(self.bitsize, sympy.Expr):
raise ValueError(f'Symbolic bitsize {self.bitsize} not supported')
q, anc = quregs['m'].tolist()[::-1], quregs['anc']
yield [OnEach(self.bitsize, Hadamard()).on(*q), Hadamard().on(*anc)]
m = bb.add(OnEach(self.bitsize, Hadamard()), q=m)
q = bb.split(m)[::-1]
anc = bb.add(Hadamard(), q=anc)
for i in range(self.bitsize):
rz_angle = -2 * np.pi * (2**i) / (2**self.bitsize + 1)
yield Rz(angle=rz_angle).controlled().on(q[i], *anc)
yield Rz(angle=-2 * np.pi / (2**self.bitsize + 1)).on(*anc)
yield Hadamard().on(*anc)
q[i], anc = bb.add(Rz(angle=rz_angle).controlled(), ctrl=q[i], q=anc)
anc = bb.add(Rz(angle=-2 * np.pi / (2**self.bitsize + 1)), q=anc)
anc = bb.add(Hadamard(), q=anc)
return {'m': bb.join(q[::-1]), 'anc': anc}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
rz_angle = -2 * pi(self.bitsize) / (2**self.bitsize + 1)
ret: Set[Tuple[Bloq, SymbolicInt]] = {
(Rz(angle=rz_angle), 1),
(Hadamard(), 2 + self.bitsize),
}
ret: Counter['Bloq'] = Counter()
ret[Rz(angle=rz_angle)] += 1
ret[OnEach(self.bitsize, Hadamard())] += 1
ret[Hadamard()] += 2
if is_symbolic(self.bitsize):
ret |= {(Rz(angle=rz_angle).controlled(), self.bitsize)}
ret[Rz(angle=rz_angle).controlled()] += self.bitsize
else:
ret |= {
(Rz(angle=rz_angle * (2**i)).controlled(), 1) for i in range(int(self.bitsize))
}
return ret
for i in range(self.bitsize):
ret[Rz(angle=rz_angle * (2**i)).controlled()] += 1
return set(ret.items())

def _t_complexity_(self) -> 'TComplexity':
# Uses self.bitsize controlled-Rz rotations which decomposes into
Expand All @@ -102,7 +95,7 @@ def _t_complexity_(self) -> 'TComplexity':


@attrs.frozen
class LPResourceState(GateWithRegisters):
class LPResourceState(QPEWindowStateBase):
r"""Prepares optimal resource state $\chi_{m}$ proposed by A. Luis and J. Peřina (1996)
Uses a single round of amplitude amplification, as described in Ref 2, to prepare the
Expand All @@ -128,53 +121,77 @@ class LPResourceState(GateWithRegisters):

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('m', QUInt(self.bitsize), side=Side.THRU)])
return Signature([self.m_register])

@classmethod
def from_standard_deviation_eps(cls, eps: SymbolicFloat) -> 'LPResourceState':
r"""Estimate the phase $\phi$ with uncertainty in standard deviation bounded by $\epsilon$.
The standard deviation of phase estimation using optimal resource states scales as the
square of Holevo variance $\tan{\frac{\pi}{2^m}}$.
This bound can be used to estimate the size of the phase register s.t. the estimated phase
has a standard deviation of at-most $\epsilon$. See the class docstring for more details.
def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> Iterator[cirq.OP_TREE]:
"""Use the _LPResourceStateHelper and do a single round of amplitude amplification."""
q = quregs['m'].flatten().tolist()
anc, flag = context.qubit_manager.qalloc(2)
$$
m = \lceil\log_2{\pi/\epsilon}\rceil
$$
Args:
eps: Maximum standard deviation of the estimated phase.
"""
return LPResourceState(ceil(log2(pi(eps) / eps)))

@property
def m_bits(self) -> SymbolicInt:
return self.bitsize

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
qpe_reg = bb.allocate(dtype=self.m_register.dtype)
anc, flag = bb.allocate(dtype=QBit()), bb.allocate(dtype=QBit())

flag_angle = np.arccos(1 / (1 + 2**self.bitsize))

# Prepare initial state
yield Ry(angle=flag_angle).on(flag)
yield LPRSInterimPrep(self.bitsize).on(*q, anc)
flag = bb.add(Ry(angle=flag_angle), q=flag)
qpe_reg, anc = bb.add(LPRSInterimPrep(self.bitsize), m=qpe_reg, anc=anc)

# Reflect around the target state
yield CZPowGate().on(flag, anc)
flag, anc = bb.add(CZ(), q1=flag, q2=anc)

# Reflect around the initial state
yield LPRSInterimPrep(self.bitsize).adjoint().on(*q, anc)
yield Ry(angle=-flag_angle).on(flag)

yield XGate().on(flag)
yield MultiControlZ((0,) * (self.bitsize + 1)).on(*q, anc, flag)
yield XGate().on(flag)
qpe_reg, anc = bb.add(LPRSInterimPrep(self.bitsize).adjoint(), m=qpe_reg, anc=anc)
flag = bb.add(Ry(angle=-flag_angle), q=flag)

flag, anc, qpe_reg = bb.add(
ReflectionUsingPrepare.reflection_around_zero([1, 1, self.bitsize], global_phase=1j),
reg0_=flag,
reg1_=anc,
reg2_=qpe_reg,
)

yield LPRSInterimPrep(self.bitsize).on(*q, anc)
yield Ry(angle=flag_angle).on(flag)
qpe_reg, anc = bb.add(LPRSInterimPrep(self.bitsize), m=qpe_reg, anc=anc)
flag = bb.add(Ry(angle=flag_angle), q=flag)

# Reset ancilla to |0> state.
yield [XGate().on(flag), XGate().on(anc)]
yield GlobalPhase(exponent=0.5).on()
context.qubit_manager.qfree([flag, anc])
flag = bb.add(XGate(), q=flag)
anc = bb.add(XGate(), q=anc)
bb.free(flag)
bb.free(anc)
return {'qpe_reg': qpe_reg}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
flag_angle = acos(1 / (1 + 2**self.bitsize))
cvs: Union[HasLength, Tuple[int, ...]] = (
HasLength(self.bitsize + 1) if is_symbolic(self.bitsize) else (0,) * (self.bitsize + 1)
reflection_bloq: 'Bloq' = ReflectionUsingPrepare.reflection_around_zero(
[1, 1, self.bitsize], global_phase=1j
)
return {
(LPRSInterimPrep(self.bitsize), 2),
(LPRSInterimPrep(self.bitsize).adjoint(), 1),
(Ry(angle=flag_angle), 3),
(MultiControlZ(cvs), 1),
(XGate(), 4),
(GlobalPhase(exponent=0.5), 1),
(CZPowGate(), 1),
(Ry(angle=flag_angle), 2),
(Ry(angle=-1 * flag_angle), 1),
(reflection_bloq, 1),
(XGate(), 2),
(CZ(), 1),
}


Expand Down
62 changes: 13 additions & 49 deletions qualtran/bloqs/phase_estimation/lp_resource_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
LPRSInterimPrep,
)
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.cirq_interop.testing import GateHelper
from qualtran.resource_counting.generalizers import (
generalize_rotation_angle,
ignore_alloc_free,
Expand All @@ -42,99 +41,64 @@ def test_lp_resource_state_auto(bloq_autotester):

def test_lp_resource_state_symb():
bloq = _lp_resource_state_symbolic.make()
assert bloq.t_complexity().t == 4 * bloq.bitsize
assert bloq.t_complexity().t == 4 * bloq.bitsize + 4


def get_interim_resource_state(m: int) -> np.ndarray:
N = 2**m
state_vector = np.zeros(2 * N, dtype=np.complex128)
state_vector[:N] = np.cos(np.pi * (1 + np.arange(N)) / (1 + N))
state_vector[N:] = 1j * np.sin(np.pi * (1 + np.arange(N)) / (1 + N))
return np.sqrt(1 / N) * state_vector
state_vector = np.zeros((N, 2), dtype=np.complex128)
state_vector[:, 0] = np.cos(np.pi * (1 + np.arange(N)) / (1 + N))
state_vector[:, 1] = 1j * np.sin(np.pi * (1 + np.arange(N)) / (1 + N))
return np.sqrt(1 / N) * state_vector.reshape(2 * N)


def get_resource_state(m: int) -> np.ndarray:
N = 2**m
return np.sqrt(2 / (1 + N)) * np.sin(np.pi * (1 + np.arange(N)) / (1 + N))


def test_intermediate_resource_state_cirq_quick():
n = 3
bloq = LPRSInterimPrep(n)
state = GateHelper(bloq).circuit.final_state_vector()
np.testing.assert_allclose(state, get_interim_resource_state(n))


def test_intermediate_resource_state_tensor_quick():
n = 3
bloq = LPRSInterimPrep(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
pytest.xfail("https://github.com/quantumlib/Qualtran/issues/1068")
np.testing.assert_allclose(state_vec, get_interim_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_intermediate_resource_state_cirq(n):
def test_intermediate_resource_state(n):
bloq = LPRSInterimPrep(n)
state = GateHelper(bloq).circuit.final_state_vector()
state = initialize_from_zero(bloq).tensor_contract()
np.testing.assert_allclose(state, get_interim_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_intermediate_resource_state_tensor(n):
bloq = LPRSInterimPrep(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
pytest.xfail("https://github.com/quantumlib/Qualtran/issues/1068")
np.testing.assert_allclose(state_vec, get_interim_resource_state(n))


def test_prepares_resource_state_cirq_quick():
def test_prepares_resource_state_quick():
n = 3
bloq = LPResourceState(n)
state = GateHelper(bloq).circuit.final_state_vector()
state = bloq.tensor_contract()
np.testing.assert_allclose(state, get_resource_state(n))


def test_prepares_resource_state_tensor_quick():
n = 3
bloq = LPResourceState(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
np.testing.assert_allclose(state_vec, get_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_prepares_resource_state_cirq(n):
def test_prepares_resource_state(n):
bloq = LPResourceState(n)
state = GateHelper(bloq).circuit.final_state_vector()
state = bloq.tensor_contract()
np.testing.assert_allclose(state, get_resource_state(n))


@pytest.mark.slow
@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_prepares_resource_state_tensor(n):
bloq = LPResourceState(n)
state_prep = initialize_from_zero(bloq)
state_vec = state_prep.tensor_contract()
np.testing.assert_allclose(state_vec, get_resource_state(n))


@pytest.mark.parametrize('n', [*range(1, 14, 2)])
def test_t_complexity(n):
bloq = LPResourceState(n)
qlt_testing.assert_equivalent_bloq_counts(
bloq, [ignore_split_join, ignore_alloc_free, generalize_rotation_angle]
)
lprs_interim_count = 3 * TComplexity(rotations=2 * n + 1, clifford=2 + 3 * n)
multi_control_pauli_count = TComplexity(t=4 * n, clifford=17 * n + 5)
reflection_using_prepare = TComplexity(t=4 * n + 4, clifford=17 * n + 22)
misc_count = TComplexity(rotations=3, clifford=5)

assert bloq.t_complexity() == (lprs_interim_count + multi_control_pauli_count + misc_count)
assert bloq.t_complexity() == (lprs_interim_count + reflection_using_prepare + misc_count)


@pytest.mark.parametrize('bitsize', [*range(1, 14, 2)])
Expand Down
Loading

0 comments on commit c0acc8e

Please sign in to comment.