From 35db018ace8af0399dcaa067c9ca09ec8d0920fc Mon Sep 17 00:00:00 2001 From: Anurudh Peduri <7265746+anurudhp@users.noreply.github.com> Date: Thu, 22 Aug 2024 17:25:56 -0700 Subject: [PATCH] Simplify GQSP call graph (#1328) * simplify GQSP call graph * simplify HamiltonianSimulation call graph * fix notebook test * make hamsim a bloq * address feedback is_zero tests * link issue about tensor vs cirq unitary * format --- .../hamiltonian_simulation_by_gqsp.py | 30 ++++++------- .../hamiltonian_simulation_by_gqsp_test.py | 5 ++- qualtran/bloqs/qsp/generalized_qsp.py | 43 +++++++++---------- qualtran/bloqs/qsp/generalized_qsp_test.py | 21 ++++++--- qualtran/symbolics/__init__.py | 1 + qualtran/symbolics/math_funcs.py | 12 ++++++ qualtran/symbolics/math_funcs_test.py | 15 +++++++ 7 files changed, 82 insertions(+), 45 deletions(-) diff --git a/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp.py b/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp.py index 999638a27..0f8db837c 100644 --- a/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp.py +++ b/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import Counter from functools import cached_property from typing import cast, Dict, Set, Tuple, TYPE_CHECKING, Union @@ -18,7 +19,8 @@ from attrs import field, frozen from numpy.typing import NDArray -from qualtran import bloq_example, BloqDocSpec, GateWithRegisters, Signature, Soquet +from qualtran import Bloq, bloq_example, BloqDocSpec, CtrlSpec, Signature, Soquet +from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate from qualtran.bloqs.qsp.generalized_qsp import GeneralizedQSP from qualtran.bloqs.qubitization.qubitization_walk_operator import QubitizationWalkOperator from qualtran.linalg.polynomial.jacobi_anger_approximations import ( @@ -34,7 +36,7 @@ @frozen -class HamiltonianSimulationByGQSP(GateWithRegisters): +class HamiltonianSimulationByGQSP(Bloq): r"""Hamiltonian simulation using Generalized QSP given a qubitized quantum walk operator. Given the Szegedy Quantum Walk Operator for a Hamiltonian $H$ constructed from SELECT and PREPARE oracles, @@ -161,7 +163,6 @@ def __add_prepare( return gqsp_soqs, prepare_out_soqs def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']: - # TODO open issue: alloc/free does not work with cirq api state_prep_ancilla: Dict[str, 'SoquetT'] = { reg.name: bb.allocate(reg.total_bits()) for reg in self.walk_operator.prepare.junk_registers @@ -182,19 +183,16 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str return soqs def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: - if self.is_symbolic(): - from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate - - d = self.degree - return { - (self.walk_operator.prepare, 1), - (self.walk_operator.prepare.adjoint(), 1), - (self.walk_operator.controlled(control_values=[0]), d), - (self.walk_operator.adjoint().controlled(), d), - (SU2RotationGate.arbitrary(ssa), 2 * d + 1), - } - - return super().build_call_graph(ssa) + counts = Counter[Bloq]() + + d = self.degree + counts[self.walk_operator.prepare] += 1 + counts[self.walk_operator.prepare.adjoint()] += 1 + counts[self.walk_operator.controlled(ctrl_spec=CtrlSpec(cvs=0))] += d + counts[self.walk_operator.adjoint().controlled()] += d + counts[SU2RotationGate.arbitrary(ssa)] += 2 * d + 1 + + return set(counts.items()) @bloq_example diff --git a/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp_test.py b/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp_test.py index d2e74ce4b..2a706d49d 100644 --- a/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp_test.py +++ b/qualtran/bloqs/hamiltonian_simulation/hamiltonian_simulation_by_gqsp_test.py @@ -33,6 +33,7 @@ verify_generalized_qsp, ) from qualtran.bloqs.qubitization.qubitization_walk_operator import QubitizationWalkOperator +from qualtran.cirq_interop import BloqAsCirqGate from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.resource_counting import big_O, BloqCount, get_cost_value from qualtran.symbolics import Shaped @@ -75,7 +76,9 @@ def verify_hamiltonian_simulation_by_gqsp( N = H.shape[0] W_e_iHt = HamiltonianSimulationByGQSP(W, t=t, precision=precision) - result_unitary = cirq.unitary(W_e_iHt) + # TODO This cirq.unitary call is 4-5x faster than tensor_contract. + # https://github.com/quantumlib/Qualtran/issues/1336 + result_unitary = cirq.unitary(BloqAsCirqGate(W_e_iHt)) expected_top_left = scipy.linalg.expm(-1j * H * t) actual_top_left = result_unitary[:N, :N] diff --git a/qualtran/bloqs/qsp/generalized_qsp.py b/qualtran/bloqs/qsp/generalized_qsp.py index 4bfd95b64..e8b71e5b7 100644 --- a/qualtran/bloqs/qsp/generalized_qsp.py +++ b/qualtran/bloqs/qsp/generalized_qsp.py @@ -21,8 +21,10 @@ from numpy.typing import NDArray from qualtran import ( + Bloq, bloq_example, BloqDocSpec, + CtrlSpec, DecomposeTypeError, GateWithRegisters, QBit, @@ -31,7 +33,16 @@ ) from qualtran.bloqs.basic_gates.su2_rotation import SU2RotationGate from qualtran.linalg.polynomial.qsp_testing import assert_is_qsp_polynomial -from qualtran.symbolics import is_symbolic, Shaped, slen, smax, smin, SymbolicFloat, SymbolicInt +from qualtran.symbolics import ( + is_symbolic, + is_zero, + Shaped, + slen, + smax, + smin, + SymbolicFloat, + SymbolicInt, +) if TYPE_CHECKING: import cirq @@ -359,29 +370,17 @@ def is_symbolic(self) -> bool: return is_symbolic(self.P, self.Q, self.negative_power) def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: - if isinstance(self.P, Shaped) or self.is_symbolic(): - degree = slen(self.P) - 1 - - return { - (self.U.controlled(control_values=[0]), smax(0, degree - self.negative_power)), - (self.U.adjoint(), smax(0, self.negative_power - degree)), - (self.U.adjoint().controlled(), smin(degree, self.negative_power)), - (SU2RotationGate.arbitrary(ssa), degree + 1), - } - - degree = len(self.P) - 1 + counts = Counter[Bloq]() - counts: Set['BloqCountT'] = set(Counter(self.signal_rotations).items()) - - if degree > self.negative_power: - counts.add((self.U.controlled(control_values=[0]), degree - self.negative_power)) - elif self.negative_power > degree: - counts.add((self.U.adjoint(), self.negative_power - degree)) - - if isinstance(self.negative_power, int) and self.negative_power > 0: - counts.add((self.U.adjoint().controlled(), min(degree, self.negative_power))) + degree = slen(self.P) - 1 + counts[SU2RotationGate.arbitrary(ssa)] += degree + 1 + counts[self.U.controlled(ctrl_spec=CtrlSpec(cvs=0))] += smax( + 0, degree - self.negative_power + ) + counts[self.U.adjoint()] += smax(0, self.negative_power - degree) + counts[self.U.adjoint().controlled()] += smin(degree, self.negative_power) - return counts + return set((bloq, count) for bloq, count in counts.items() if not is_zero(count)) @bloq_example diff --git a/qualtran/bloqs/qsp/generalized_qsp_test.py b/qualtran/bloqs/qsp/generalized_qsp_test.py index 284e626aa..c64899b19 100644 --- a/qualtran/bloqs/qsp/generalized_qsp_test.py +++ b/qualtran/bloqs/qsp/generalized_qsp_test.py @@ -43,6 +43,7 @@ from qualtran.linalg.testing import assert_matrices_almost_equal from qualtran.resource_counting import SympySymbolAllocator from qualtran.symbolics import Shaped +from qualtran.testing import execute_notebook def test_gqsp_example(bloq_autotester): @@ -198,12 +199,15 @@ def catch_rotations(bloq: Bloq) -> Bloq: _, sigma = gqsp.call_graph(max_depth=1, generalizer=catch_rotations) - assert sigma == { - arbitrary_rotation: degree + 1, - Controlled(U, CtrlSpec(cvs=0)): max(0, degree - negative_power), - Controlled(U.adjoint(), CtrlSpec()): min(degree, negative_power), - U.adjoint(): max(0, negative_power - degree), - } + expected_sigma: dict[Bloq, int] = {arbitrary_rotation: degree + 1} + if degree > negative_power: + expected_sigma[Controlled(U, CtrlSpec(cvs=0))] = degree - negative_power + if negative_power > 0: + expected_sigma[Controlled(U.adjoint(), CtrlSpec())] = min(degree, negative_power) + if negative_power > degree: + expected_sigma[U.adjoint()] = negative_power - degree + + assert sigma == expected_sigma @define(slots=False) @@ -304,3 +308,8 @@ def test_complementary_polynomials_for_jacobi_anger_approximations(t: float, pre list(P), Q, random_state=random_state, rtol=precision ) verify_generalized_qsp(MatrixGate.random(1, random_state=random_state), list(P), Q) + + +@pytest.mark.notebook +def test_notebook(): + execute_notebook('generalized_qsp') diff --git a/qualtran/symbolics/__init__.py b/qualtran/symbolics/__init__.py index 243232041..b09ebba27 100644 --- a/qualtran/symbolics/__init__.py +++ b/qualtran/symbolics/__init__.py @@ -18,6 +18,7 @@ bit_length, ceil, floor, + is_zero, ln, log2, pi, diff --git a/qualtran/symbolics/math_funcs.py b/qualtran/symbolics/math_funcs.py index 6d51133a1..d588386d2 100644 --- a/qualtran/symbolics/math_funcs.py +++ b/qualtran/symbolics/math_funcs.py @@ -313,3 +313,15 @@ def shape(x: Shaped) -> Tuple[SymbolicInt, ...]: def shape(x: Union[np.ndarray, Shaped]): return x.shape + + +def is_zero(x: SymbolicInt) -> bool: + """check if a symbolic integer is zero + + If it returns True, then the value is definitely 0. + If it returns False, then the value is either non-zero, + or could not be symbolically symplified to a zero. + """ + if is_symbolic(x): + return x.equals(0) + return x == 0 diff --git a/qualtran/symbolics/math_funcs_test.py b/qualtran/symbolics/math_funcs_test.py index 729e5c927..993006d83 100644 --- a/qualtran/symbolics/math_funcs_test.py +++ b/qualtran/symbolics/math_funcs_test.py @@ -22,6 +22,7 @@ bit_length, ceil, is_symbolic, + is_zero, log2, sarg, sexp, @@ -137,3 +138,17 @@ def test_shaped(shape: tuple[int, ...]): shaped = Shaped(shape=shape) assert is_symbolic(shaped) assert slen(shaped) == shape[0] + + +def test_is_zero(): + assert is_zero(0) + assert not is_zero(1) + + n = sympy.Symbol("n") + assert not is_zero(n) + assert is_zero(n - n) + assert is_zero(n * 0) + assert is_zero(n * 2 - n - n) + + assert is_zero(sympy.sympify("0")) + assert not is_zero(sympy.sympify("1"))