diff --git a/qualtran/_infra/controlled.py b/qualtran/_infra/controlled.py index dca518e93..bcfc1ee6f 100644 --- a/qualtran/_infra/controlled.py +++ b/qualtran/_infra/controlled.py @@ -31,6 +31,7 @@ import numpy as np from numpy.typing import NDArray +from ..symbolics import is_symbolic, prod, Shaped, SymbolicInt from .bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError from .data_types import QBit, QDType from .gate_with_registers import GateWithRegisters @@ -55,18 +56,21 @@ def _cvs_convert( int, np.integer, NDArray[np.integer], + Shaped, Sequence[Union[int, np.integer]], Sequence[Sequence[Union[int, np.integer]]], - Sequence[NDArray[np.integer]], + Sequence[Union[NDArray[np.integer], Shaped]], ] -) -> Tuple[NDArray[np.integer], ...]: +) -> Tuple[Union[NDArray[np.integer], Shaped], ...]: + if isinstance(cvs, Shaped): + return (cvs,) if isinstance(cvs, (int, np.integer)): return (np.array(cvs),) if isinstance(cvs, np.ndarray): return (cvs,) if all(isinstance(cv, (int, np.integer)) for cv in cvs): return (np.asarray(cvs),) - return tuple(np.asarray(cv) for cv in cvs) + return tuple(cv if isinstance(cv, Shaped) else np.asarray(cv) for cv in cvs) @attrs.frozen(eq=False) @@ -115,7 +119,9 @@ class CtrlSpec: qdtypes: Tuple[QDType, ...] = attrs.field( default=QBit(), converter=lambda qt: (qt,) if isinstance(qt, QDType) else tuple(qt) ) - cvs: Tuple[NDArray[np.integer], ...] = attrs.field(default=1, converter=_cvs_convert) + cvs: Tuple[Union[NDArray[np.integer], Shaped], ...] = attrs.field( + default=1, converter=_cvs_convert + ) def __attrs_post_init__(self): assert len(self.qdtypes) == len(self.cvs) @@ -125,19 +131,29 @@ def num_ctrl_reg(self) -> int: return len(self.qdtypes) @cached_property - def shapes(self) -> Tuple[Tuple[int, ...], ...]: + def shapes(self) -> Tuple[Tuple[SymbolicInt, ...], ...]: """Tuple of shapes of control registers represented by this CtrlSpec.""" return tuple(cv.shape for cv in self.cvs) @cached_property - def num_qubits(self) -> int: + def concrete_shapes(self) -> tuple[tuple[int, ...], ...]: + """Tuple of shapes of control registers represented by this CtrlSpec.""" + shapes = self.shapes + if is_symbolic(*shapes): + raise ValueError(f"cannot get concrete shapes: found symbolic {self.shapes}") + return shapes # type: ignore + + @cached_property + def num_qubits(self) -> SymbolicInt: """Total number of qubits required for control registers represented by this CtrlSpec.""" return sum( - dtype.num_qubits * int(np.prod(shape)) - for dtype, shape in zip(self.qdtypes, self.shapes) + dtype.num_qubits * prod(shape) for dtype, shape in zip(self.qdtypes, self.shapes) ) - def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[int, ...]]]: + def is_symbolic(self): + return is_symbolic(*self.qdtypes) or is_symbolic(*self.cvs) + + def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[SymbolicInt, ...]]]: """The data types that serve as input to the 'activation function'. The activation function takes in (quantum) inputs of these types and shapes and determines @@ -165,6 +181,8 @@ def is_active(self, *vals: 'ClassicalValT') -> bool: Returns: True if the specific input values evaluate to `True` for this CtrlSpec. """ + if self.is_symbolic(): + raise ValueError(f"Cannot compute activation for symbolic {self}") if len(vals) != self.num_ctrl_reg: raise ValueError(f"Incorrect number of inputs for {self}: {len(vals)}.") @@ -180,19 +198,31 @@ def is_active(self, *vals: 'ClassicalValT') -> bool: return True def wire_symbol(self, i: int, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': - # Return a circle for bits; a box otherwise. from qualtran.drawing import Circle, TextBox + cvs = self.cvs[i] + + if is_symbolic(cvs): + # control value is not given + return TextBox('ctrl') + + # Return a circle for bits; a box otherwise. + cv = cvs[idx] if reg.bitsize == 1: - cv = self.cvs[i][idx] return Circle(filled=(cv == 1)) - - cv = self.cvs[i][idx] - return TextBox(f'{cv}') + else: + return TextBox(f'{cv}') @cached_property - def _cvs_tuple(self) -> Tuple[int, ...]: - return tuple(cv for cvs in self.cvs for cv in tuple(cvs.reshape(-1))) + def __cvs_tuple(self) -> Tuple[Union[tuple[int, ...], Shaped], ...]: + """Serialize the control values for hashing and equality checking.""" + + def _serialize(cvs) -> Union[tuple[int, ...], Shaped]: + if isinstance(cvs, Shaped): + return cvs + return tuple(cvs.reshape(-1)) + + return tuple(_serialize(cvs) for cvs in self.cvs) def __eq__(self, other: Any) -> bool: if not isinstance(other, CtrlSpec): @@ -201,18 +231,22 @@ def __eq__(self, other: Any) -> bool: return ( other.qdtypes == self.qdtypes and other.shapes == self.shapes - and other._cvs_tuple == self._cvs_tuple + and other.__cvs_tuple == self.__cvs_tuple ) def __hash__(self): - return hash((self.qdtypes, self.shapes, self._cvs_tuple)) + return hash((self.qdtypes, self.shapes, self.__cvs_tuple)) def to_cirq_cv(self) -> 'cirq.SumOfProducts': """Convert CtrlSpec to cirq.SumOfProducts representation of control values.""" import cirq + if self.is_symbolic(): + raise ValueError(f"Cannot convert symbolic {self} to cirq control values.") + cirq_cv = [] for qdtype, cv in zip(self.qdtypes, self.cvs): + assert isinstance(cv, np.ndarray) for idx in Register('', qdtype, cv.shape).all_idxs(): cirq_cv += [*qdtype.to_bits(cv[idx])] return cirq.SumOfProducts([tuple(cirq_cv)]) @@ -256,11 +290,14 @@ def from_cirq_cv( def get_single_ctrl_bit(self) -> ControlBit: """If controlled by a single qubit, return the control bit, otherwise raise""" + if self.is_symbolic(): + raise ValueError(f"cannot get ctrl bit for symbolic {self}") if self.num_qubits != 1: raise ValueError(f"expected a single qubit control, got {self.num_qubits}") (qdtype,) = self.qdtypes (cv,) = self.cvs + assert isinstance(cv, np.ndarray) (idx,) = Register('', qdtype, cv.shape).all_idxs() (control_bit,) = qdtype.to_bits(cv[idx]) diff --git a/qualtran/_infra/controlled_test.py b/qualtran/_infra/controlled_test.py index 77d72432c..fcb9f207c 100644 --- a/qualtran/_infra/controlled_test.py +++ b/qualtran/_infra/controlled_test.py @@ -16,6 +16,7 @@ import attrs import numpy as np import pytest +import sympy import qualtran.testing as qlt_testing from qualtran import ( @@ -24,6 +25,7 @@ CompositeBloq, Controlled, CtrlSpec, + DecomposeTypeError, QBit, QInt, QUInt, @@ -52,6 +54,7 @@ from qualtran.drawing import get_musical_score_data from qualtran.drawing.musical_score import Circle, SoqData, TextBox from qualtran.simulation.tensor import cbloq_to_quimb, get_right_and_left_inds +from qualtran.symbolics import Shaped if TYPE_CHECKING: import cirq @@ -73,8 +76,10 @@ def test_ctrl_spec(): cspec3 = CtrlSpec(QInt(64), cvs=np.int64(234234)) assert cspec3 != cspec1 assert cspec3.qdtypes[0].num_qubits == 64 - assert cspec3.cvs[0] == 234234 - assert cspec3.cvs[0][tuple()] == 234234 + (cvs,) = cspec3.cvs + assert isinstance(cvs, np.ndarray) + assert cvs == 234234 + assert cvs[tuple()] == 234234 def test_ctrl_spec_shape(): @@ -97,7 +102,9 @@ def test_ctrl_spec_to_cirq_cv_roundtrip(): for ctrl_spec in ctrl_specs: assert ctrl_spec.to_cirq_cv() == cirq_cv.expand() - assert CtrlSpec.from_cirq_cv(cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.shapes) + assert CtrlSpec.from_cirq_cv( + cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.concrete_shapes + ) @pytest.mark.parametrize( @@ -120,6 +127,32 @@ def test_ctrl_spec_single_bit_raises(ctrl_spec: CtrlSpec): ctrl_spec.get_single_ctrl_bit() +@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)]) +def test_ctrl_spec_symbolic_cvs(shape: tuple[int, ...]): + ctrl_spec = CtrlSpec(cvs=Shaped(shape)) + assert ctrl_spec.is_symbolic() + assert ctrl_spec.num_qubits == np.prod(shape) + assert ctrl_spec.shapes == (shape,) + + +@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)]) +def test_ctrl_spec_symbolic_dtype(shape: tuple[int, ...]): + n = sympy.Symbol("n") + dtype = QUInt(n) + + ctrl_spec = CtrlSpec(qdtypes=dtype, cvs=Shaped(shape)) + + assert ctrl_spec.is_symbolic() + assert ctrl_spec.num_qubits == n * np.prod(shape) + assert ctrl_spec.shapes == (shape,) + + +def test_ctrl_spec_symbolic_wire_symbol(): + ctrl_spec = CtrlSpec(cvs=Shaped((10,))) + reg = Register('q', QBit()) + assert ctrl_spec.wire_symbol(0, reg) == TextBox('ctrl') + + def _test_cirq_equivalence(bloq: Bloq, gate: 'cirq.Gate'): import cirq @@ -431,11 +464,15 @@ def signature(self) -> 'Signature': return Signature([Register('x', QBit(), shape=(3,), side=Side.RIGHT)]) def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']: + if self.ctrl_spec.is_symbolic(): + raise DecomposeTypeError(f"cannot decompose {self} with symbolic {self.ctrl_spec=}") + one_or_zero = [ZeroState(), OneState()] ctrl_bloq = Controlled(And(*self.and_ctrl), ctrl_spec=self.ctrl_spec) ctrl_soqs = {} for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs): + assert isinstance(cvs, np.ndarray) soqs = np.empty(shape=reg.shape, dtype=object) for idx in reg.all_idxs(): soqs[idx] = bb.add(IntState(val=cvs[idx], bitsize=reg.dtype.num_qubits)) @@ -447,6 +484,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']: out_soqs = np.asarray([*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')]) # type: ignore[misc] for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs): + assert isinstance(cvs, np.ndarray) for idx in reg.all_idxs(): ctrl_soq = np.asarray(ctrl_soqs[reg.name])[idx] bb.add(IntEffect(val=cvs[idx], bitsize=reg.dtype.num_qubits), val=ctrl_soq) diff --git a/qualtran/bloqs/mcmt/ctrl_spec_and.ipynb b/qualtran/bloqs/mcmt/ctrl_spec_and.ipynb index e6feec19b..e96d6fa66 100644 --- a/qualtran/bloqs/mcmt/ctrl_spec_and.ipynb +++ b/qualtran/bloqs/mcmt/ctrl_spec_and.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "0beec492", + "id": "209f4989", "metadata": { "cq.autogen": "title_cell" }, @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d31094e", + "id": "373927c3", "metadata": { "cq.autogen": "top_imports" }, @@ -30,7 +30,7 @@ }, { "cell_type": "markdown", - "id": "12aba876", + "id": "d7dd05fa", "metadata": { "cq.autogen": "CtrlSpecAnd.bloq_doc.md" }, @@ -66,7 +66,7 @@ { "cell_type": "code", "execution_count": null, - "id": "842307af", + "id": "837cc8f9", "metadata": { "cq.autogen": "CtrlSpecAnd.bloq_doc.py" }, @@ -77,7 +77,7 @@ }, { "cell_type": "markdown", - "id": "76f2965f", + "id": "760fb1b0", "metadata": { "cq.autogen": "CtrlSpecAnd.example_instances.md" }, @@ -88,7 +88,7 @@ { "cell_type": "code", "execution_count": null, - "id": "68a43214", + "id": "8486e383", "metadata": { "cq.autogen": "CtrlSpecAnd.ctrl_on_int" }, @@ -102,7 +102,7 @@ { "cell_type": "code", "execution_count": null, - "id": "945fa4a4", + "id": "1510fd98", "metadata": { "cq.autogen": "CtrlSpecAnd.ctrl_on_bits" }, @@ -116,7 +116,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5e6374b2", + "id": "489cf199", "metadata": { "cq.autogen": "CtrlSpecAnd.ctrl_on_nd_bits" }, @@ -132,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2047fccf", + "id": "bc189810", "metadata": { "cq.autogen": "CtrlSpecAnd.ctrl_on_multiple_values" }, @@ -145,9 +145,39 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fa2c166", + "metadata": { + "cq.autogen": "CtrlSpecAnd.ctrl_on_symbolic_cv" + }, + "outputs": [], + "source": [ + "from qualtran import CtrlSpec\n", + "from qualtran.symbolics import Shaped\n", + "\n", + "ctrl_on_symbolic_cv = CtrlSpecAnd(CtrlSpec(cvs=Shaped((2,))))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0d78907", + "metadata": { + "cq.autogen": "CtrlSpecAnd.ctrl_on_symbolic_cv_multi" + }, + "outputs": [], + "source": [ + "from qualtran import CtrlSpec\n", + "from qualtran.symbolics import Shaped\n", + "\n", + "ctrl_on_symbolic_cv_multi = CtrlSpecAnd(CtrlSpec(cvs=Shaped((3,))))" + ] + }, { "cell_type": "markdown", - "id": "55581cc0", + "id": "60a9fcc0", "metadata": { "cq.autogen": "CtrlSpecAnd.graphical_signature.md" }, @@ -158,20 +188,20 @@ { "cell_type": "code", "execution_count": null, - "id": "5ea9e6e5", + "id": "f1b9274c", "metadata": { "cq.autogen": "CtrlSpecAnd.graphical_signature.py" }, "outputs": [], "source": [ "from qualtran.drawing import show_bloqs\n", - "show_bloqs([ctrl_on_int, ctrl_on_bits, ctrl_on_nd_bits, ctrl_on_multiple_values],\n", - " ['`ctrl_on_int`', '`ctrl_on_bits`', '`ctrl_on_nd_bits`', '`ctrl_on_multiple_values`'])" + "show_bloqs([ctrl_on_int, ctrl_on_bits, ctrl_on_nd_bits, ctrl_on_multiple_values, ctrl_on_symbolic_cv, ctrl_on_symbolic_cv_multi],\n", + " ['`ctrl_on_int`', '`ctrl_on_bits`', '`ctrl_on_nd_bits`', '`ctrl_on_multiple_values`', '`ctrl_on_symbolic_cv`', '`ctrl_on_symbolic_cv_multi`'])" ] }, { "cell_type": "markdown", - "id": "3f5bb7d6", + "id": "704a2863", "metadata": { "cq.autogen": "CtrlSpecAnd.call_graph.md" }, @@ -182,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f83e1715", + "id": "1c2ab375", "metadata": { "cq.autogen": "CtrlSpecAnd.call_graph.py" }, diff --git a/qualtran/bloqs/mcmt/ctrl_spec_and.py b/qualtran/bloqs/mcmt/ctrl_spec_and.py index 409bb5307..713a94047 100644 --- a/qualtran/bloqs/mcmt/ctrl_spec_and.py +++ b/qualtran/bloqs/mcmt/ctrl_spec_and.py @@ -14,6 +14,8 @@ from functools import cached_property from typing import Optional, TYPE_CHECKING, Union +import numpy as np +import sympy from attrs import frozen from qualtran import ( @@ -90,7 +92,7 @@ def signature(self) -> Signature: return Signature( [ *self.control_registers, - *self.junk_registers(), + *self.junk_registers, Register('target', QBit(), side=Side.RIGHT), ] ) @@ -102,6 +104,7 @@ def control_registers(self) -> tuple[Register, ...]: for i, (dtype, shape) in enumerate(self.ctrl_spec.activation_function_dtypes()) ) + @cached_property def junk_registers(self) -> tuple[Register, ...]: if not is_symbolic(self.n_ctrl_qubits) and self.n_ctrl_qubits == 2: return () @@ -123,6 +126,7 @@ def _flat_cvs(self) -> Union[tuple[int, ...], HasLength]: flat_cvs: list[int] = [] for reg, cv in zip(self.control_registers, self.ctrl_spec.cvs): + assert isinstance(cv, np.ndarray) flat_cvs.extend(reg.dtype.to_bits_array(cv.ravel()).ravel()) return tuple(flat_cvs) @@ -138,8 +142,10 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str # Compute the single control qubit `target` if self.n_ctrl_qubits == 2: - assert isinstance(self._flat_cvs, tuple) - cv1, cv2 = self._flat_cvs + if isinstance(self._flat_cvs, tuple): + cv1, cv2 = self._flat_cvs + else: + cv1, cv2 = sympy.symbols("cv1, cv2") ctrl_qubits, target = bb.add(And(cv1, cv2), ctrl=ctrl_qubits) junk = None else: @@ -175,8 +181,10 @@ def wire_symbol(self, reg: Optional[Register], idx: tuple[int, ...] = tuple()) - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': if not is_symbolic(self.n_ctrl_qubits) and self.n_ctrl_qubits == 2: - assert isinstance(self._flat_cvs, tuple) - cv1, cv2 = self._flat_cvs + if isinstance(self._flat_cvs, tuple): + cv1, cv2 = self._flat_cvs + else: + cv1, cv2 = sympy.symbols("cv1, cv2") return {And(cv1, cv2): 1} return {MultiAnd(self._flat_cvs): 1} @@ -218,7 +226,42 @@ def _ctrl_on_multiple_values() -> CtrlSpecAnd: return ctrl_on_multiple_values +@bloq_example(generalizer=ignore_split_join) +def _ctrl_on_symbolic_cv() -> CtrlSpecAnd: + from qualtran import CtrlSpec + from qualtran.symbolics import Shaped + + ctrl_on_symbolic_cv = CtrlSpecAnd(CtrlSpec(cvs=Shaped((2,)))) + return ctrl_on_symbolic_cv + + +@bloq_example(generalizer=ignore_split_join) +def _ctrl_on_symbolic_cv_multi() -> CtrlSpecAnd: + from qualtran import CtrlSpec + from qualtran.symbolics import Shaped + + ctrl_on_symbolic_cv_multi = CtrlSpecAnd(CtrlSpec(cvs=Shaped((3,)))) + return ctrl_on_symbolic_cv_multi + + +@bloq_example(generalizer=ignore_split_join) +def _ctrl_on_symbolic_n_ctrls() -> CtrlSpecAnd: + from qualtran import CtrlSpec + from qualtran.symbolics import Shaped + + n = sympy.Symbol("n") + ctrl_on_symbolic_cv_multi = CtrlSpecAnd(CtrlSpec(cvs=Shaped((n,)))) + return ctrl_on_symbolic_cv_multi + + _CTRLSPEC_AND_DOC = BloqDocSpec( bloq_cls=CtrlSpecAnd, - examples=(_ctrl_on_int, _ctrl_on_bits, _ctrl_on_nd_bits, _ctrl_on_multiple_values), + examples=( + _ctrl_on_int, + _ctrl_on_bits, + _ctrl_on_nd_bits, + _ctrl_on_multiple_values, + _ctrl_on_symbolic_cv, + _ctrl_on_symbolic_cv_multi, + ), ) diff --git a/qualtran/bloqs/mcmt/ctrl_spec_and_test.py b/qualtran/bloqs/mcmt/ctrl_spec_and_test.py index eb3ce38d8..f0f8638db 100644 --- a/qualtran/bloqs/mcmt/ctrl_spec_and_test.py +++ b/qualtran/bloqs/mcmt/ctrl_spec_and_test.py @@ -14,19 +14,33 @@ import numpy as np import pytest +import qualtran.testing as qlt_testing from qualtran import CtrlSpec, QUInt from qualtran.bloqs.mcmt.ctrl_spec_and import ( _ctrl_on_bits, _ctrl_on_int, _ctrl_on_multiple_values, _ctrl_on_nd_bits, + _ctrl_on_symbolic_cv, + _ctrl_on_symbolic_cv_multi, + _ctrl_on_symbolic_n_ctrls, CtrlSpecAnd, ) from qualtran.simulation.classical_sim import get_classical_truth_table @pytest.mark.parametrize( - "example", [_ctrl_on_bits, _ctrl_on_nd_bits, _ctrl_on_int, _ctrl_on_multiple_values] + "example", + [ + _ctrl_on_bits, + _ctrl_on_nd_bits, + _ctrl_on_int, + _ctrl_on_multiple_values, + _ctrl_on_symbolic_cv, + _ctrl_on_symbolic_cv_multi, + _ctrl_on_symbolic_n_ctrls, + ], + ids=lambda ex: ex.name, ) def test_examples(bloq_autotester, example): bloq_autotester(example) @@ -51,3 +65,8 @@ def test_truth_table_using_classical_sim(ctrl_spec: CtrlSpec): # check: target bit (last output value) matches `is_active` assert out_vals[-1] == ctrl_spec.is_active(*in_vals) + + +@pytest.mark.notebook +def test_notebook(): + qlt_testing.execute_notebook('ctrl_spec_and') diff --git a/qualtran/conftest.py b/qualtran/conftest.py index ac6137d71..28d79a89f 100644 --- a/qualtran/conftest.py +++ b/qualtran/conftest.py @@ -138,6 +138,9 @@ def assert_bloq_example_serializes_for_pytest(bloq_ex: BloqExample): 'black_box_select', # cannot serialize AutoPartition 'black_box_prepare', # cannot serialize AutoPartition 'kaiser_window_state_symbolic', # Split cannot have a symbolic data type. + 'ctrl_on_symbolic_cv', # cannot serialize Shaped + 'ctrl_on_symbolic_cv_multi', # cannot serialize Shaped + 'ctrl_on_symbolic_n_ctrls', # cannot serialize Shaped ]: pytest.xfail("Skipping serialization test for bloq examples that cannot yet be serialized.") diff --git a/qualtran/serialization/ctrl_spec.py b/qualtran/serialization/ctrl_spec.py index 301faf2fc..30895b51e 100644 --- a/qualtran/serialization/ctrl_spec.py +++ b/qualtran/serialization/ctrl_spec.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from qualtran import CtrlSpec from qualtran.protos import ctrl_spec_pb2 from qualtran.serialization import args, data_types +from qualtran.symbolics import Shaped def ctrl_spec_from_proto(spec: ctrl_spec_pb2.CtrlSpec) -> CtrlSpec: @@ -25,7 +25,12 @@ def ctrl_spec_from_proto(spec: ctrl_spec_pb2.CtrlSpec) -> CtrlSpec: def ctrl_spec_to_proto(spec: CtrlSpec) -> ctrl_spec_pb2.CtrlSpec: + def cvs_to_proto(cvs): + if isinstance(cvs, Shaped): + raise ValueError("cannot serialize Shaped") + return args.ndarray_to_proto(cvs) + return ctrl_spec_pb2.CtrlSpec( qdtypes=[data_types.data_type_to_proto(dtype) for dtype in spec.qdtypes], - cvs=[args.ndarray_to_proto(cvs) for cvs in spec.cvs], + cvs=[cvs_to_proto(cvs) for cvs in spec.cvs], ) diff --git a/qualtran/simulation/tensor/_tensor_data_manipulation.py b/qualtran/simulation/tensor/_tensor_data_manipulation.py index 97f8852e2..5029ae22e 100644 --- a/qualtran/simulation/tensor/_tensor_data_manipulation.py +++ b/qualtran/simulation/tensor/_tensor_data_manipulation.py @@ -68,11 +68,15 @@ def active_space_for_ctrl_spec( Returns a tuple of indices/slices that can be used to address into the ndarray, representing tensor data of shape `tensor_shape_from_signature(signature)`, and access the active subspace. """ + if ctrl_spec.is_symbolic(): + raise ValueError(f"cannot compute active space for symbolic {ctrl_spec=}") + out_ind, inp_ind = tensor_out_inp_shape_from_signature(signature) data_shape = out_ind + inp_ind active_idx: List[Union[int, slice]] = [slice(x) for x in data_shape] ctrl_idx = 0 for cv in ctrl_spec.cvs: + assert isinstance(cv, np.ndarray) for idx in itertools.product(*[range(sh) for sh in cv.shape]): active_idx[ctrl_idx] = int(cv[idx]) active_idx[ctrl_idx + len(out_ind)] = int(cv[idx]) diff --git a/qualtran/symbolics/__init__.py b/qualtran/symbolics/__init__.py index 9bee6b1fe..95cf224f0 100644 --- a/qualtran/symbolics/__init__.py +++ b/qualtran/symbolics/__init__.py @@ -28,8 +28,6 @@ sarg, sconj, sexp, - shape, - slen, smax, smin, ssqrt, @@ -38,7 +36,9 @@ from qualtran.symbolics.types import ( HasLength, is_symbolic, + shape, Shaped, + slen, SymbolicComplex, SymbolicFloat, SymbolicInt, diff --git a/qualtran/symbolics/math_funcs.py b/qualtran/symbolics/math_funcs.py index 19473ad20..5768042de 100644 --- a/qualtran/symbolics/math_funcs.py +++ b/qualtran/symbolics/math_funcs.py @@ -11,19 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Iterable, overload, Sized, Tuple, TypeVar, Union +from typing import cast, Iterable, overload, TypeVar import numpy as np import sympy -from qualtran.symbolics.types import ( - HasLength, - is_symbolic, - Shaped, - SymbolicComplex, - SymbolicFloat, - SymbolicInt, -) +from qualtran.symbolics.types import is_symbolic, SymbolicComplex, SymbolicFloat, SymbolicInt def pi(*args) -> SymbolicFloat: @@ -261,34 +254,6 @@ def sconj(x: SymbolicComplex) -> SymbolicComplex: return sympy.conjugate(x) if is_symbolic(x) else np.conjugate(x) -@overload -def slen(x: Sized) -> int: ... - - -@overload -def slen(x: Union[Shaped, HasLength]) -> sympy.Expr: ... - - -def slen(x: Union[Sized, Shaped, HasLength]) -> SymbolicInt: - if isinstance(x, Shaped): - return x.shape[0] - if isinstance(x, HasLength): - return x.n - return len(x) - - -@overload -def shape(x: np.ndarray) -> Tuple[int, ...]: ... - - -@overload -def shape(x: Shaped) -> Tuple[SymbolicInt, ...]: ... - - -def shape(x: Union[np.ndarray, Shaped]): - return x.shape - - def is_zero(x: SymbolicInt) -> bool: """check if a symbolic integer is zero diff --git a/qualtran/symbolics/math_funcs_test.py b/qualtran/symbolics/math_funcs_test.py index 993006d83..78be4eac0 100644 --- a/qualtran/symbolics/math_funcs_test.py +++ b/qualtran/symbolics/math_funcs_test.py @@ -18,19 +18,7 @@ import sympy from sympy.codegen.cfunctions import log2 as sympy_log2 -from qualtran.symbolics import ( - bit_length, - ceil, - is_symbolic, - is_zero, - log2, - sarg, - sexp, - Shaped, - slen, - smax, - smin, -) +from qualtran.symbolics import bit_length, ceil, is_zero, log2, sarg, sexp, smax, smin def test_log2(): @@ -130,16 +118,6 @@ def test_bit_length_symbolic_simplify(): assert b.subs({N: 2**n}) == n -@pytest.mark.parametrize( - "shape", - [(4,), (1, 2), (1, 2, 3), (sympy.Symbol('n'),), (sympy.Symbol('n'), sympy.Symbol('m'), 100)], -) -def test_shaped(shape: tuple[int, ...]): - shaped = Shaped(shape=shape) - assert is_symbolic(shaped) - assert slen(shaped) == shape[0] - - def test_is_zero(): assert is_zero(0) assert not is_zero(1) diff --git a/qualtran/symbolics/types.py b/qualtran/symbolics/types.py index 18714b6f8..e971db857 100644 --- a/qualtran/symbolics/types.py +++ b/qualtran/symbolics/types.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import overload, TypeVar, Union +from typing import overload, Sized, TypeVar, Union +import numpy as np import sympy from attrs import field, frozen, validators from typing_extensions import TypeIs @@ -31,13 +32,20 @@ class Shaped: """Symbolic value for an object that has a shape. - A Shaped object can be used as a symbolic replacement for any object that has an attribute `shape`, - for example numpy NDArrays. - Each dimension can be either an positive integer value or a sympy expression. + A Shaped object can be used as a symbolic replacement for any object that has an + attribute `shape`, for example numpy `NDArrays`. Each dimension can be either + a positive integer value or a sympy expression. - This is useful to do symbolic analysis of Bloqs whose call graph only depends on the shape of the input, - but not on the actual values. - For example, T-cost of the `QROM` Bloq depends only on the iteration length (shape) and not on actual data values. + For the symbolic variant of a tuple or sequence of values, see `HasLength`. + + This is useful to do symbolic analysis of Bloqs whose call graph only depends on the shape + of the input, but not on the actual values. For example, T-cost of the `QROM` Bloq depends + only on the iteration length (shape) and not on actual data values. In this case, for the + bloq attribute `data`, we can use the type: + + ```py + data: Union[NDArray, Shaped] + ``` """ shape: tuple[SymbolicInt, ...] = field(validator=validators.instance_of(tuple)) @@ -50,6 +58,15 @@ def is_symbolic(self): class HasLength: """Symbolic value for an object that has a length. + This is used as a "symbolic" tuple. The length can either be a positive integer + or a sympy expression. For example, if a bloq attribute is a tuple of ints, + we can use the type: + + ```py + values: Union[tuple, HasLength] + ``` + + For the symbolic variant of a NDArray, see `Shaped`. Note that we cannot override __len__ and return a sympy symbol because Python has special treatment for __len__ and expects you to return a non-negative integers. @@ -63,6 +80,34 @@ def is_symbolic(self): return True +@overload +def slen(x: Sized) -> int: ... + + +@overload +def slen(x: Union[Shaped, HasLength]) -> sympy.Expr: ... + + +def slen(x: Union[Sized, Shaped, HasLength]) -> SymbolicInt: + if isinstance(x, Shaped): + return x.shape[0] + if isinstance(x, HasLength): + return x.n + return len(x) + + +@overload +def shape(x: np.ndarray) -> tuple[int, ...]: ... + + +@overload +def shape(x: Shaped) -> tuple[SymbolicInt, ...]: ... + + +def shape(x: Union[np.ndarray, Shaped]): + return x.shape + + T = TypeVar('T') diff --git a/qualtran/symbolics/types_test.py b/qualtran/symbolics/types_test.py new file mode 100644 index 000000000..8e6656119 --- /dev/null +++ b/qualtran/symbolics/types_test.py @@ -0,0 +1,28 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sympy + +from qualtran.symbolics import is_symbolic, Shaped, slen + + +@pytest.mark.parametrize( + "shape", + [(4,), (1, 2), (1, 2, 3), (sympy.Symbol('n'),), (sympy.Symbol('n'), sympy.Symbol('m'), 100)], +) +def test_shaped(shape: tuple[int, ...]): + shaped = Shaped(shape=shape) + assert is_symbolic(shaped) + assert slen(shaped) == shape[0]