diff --git a/qualtran/simulation/classical_sim.py b/qualtran/simulation/classical_sim.py index 74eb5319f..48a326461 100644 --- a/qualtran/simulation/classical_sim.py +++ b/qualtran/simulation/classical_sim.py @@ -14,7 +14,19 @@ """Functionality for the `Bloq.call_classically(...)` protocol.""" import itertools -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) import networkx as nx import numpy as np @@ -34,9 +46,41 @@ ) from qualtran._infra.composite_bloq import _binst_to_cxns +if TYPE_CHECKING: + from qualtran import QDType + ClassicalValT = Union[int, np.integer, NDArray[np.integer]] +def _numpy_dtype_from_qdtype(dtype: 'QDType') -> Type: + from qualtran._infra.data_types import QBit, QInt, QUInt + + if isinstance(dtype, QUInt): + if dtype.bitsize <= 8: + return np.uint8 + elif dtype.bitsize <= 16: + return np.uint16 + elif dtype.bitsize <= 32: + return np.uint32 + elif dtype.bitsize <= 64: + return np.uint64 + + if isinstance(dtype, QInt): + if dtype.bitsize <= 8: + return np.int8 + elif dtype.bitsize <= 16: + return np.int16 + elif dtype.bitsize <= 32: + return np.int32 + elif dtype.bitsize <= 64: + return np.int64 + + if isinstance(dtype, QBit): + return np.uint8 + + return object + + def _get_in_vals( binst: Union[DanglingT, BloqInstance], reg: Register, soq_assign: Dict[Soquet, ClassicalValT] ) -> ClassicalValT: @@ -44,19 +88,7 @@ def _get_in_vals( if not reg.shape: return soq_assign[Soquet(binst, reg)] - if reg.bitsize <= 8: - dtype: Type = np.uint8 - elif reg.bitsize <= 16: - dtype = np.uint16 - elif reg.bitsize <= 32: - dtype = np.uint32 - elif reg.bitsize <= 64: - dtype = np.uint64 - else: - raise NotImplementedError( - "We currently only support up to 64-bit " - "multi-dimensional registers in classical simulation." - ) + dtype: Type = _numpy_dtype_from_qdtype(reg.dtype) arg = np.empty(reg.shape, dtype=dtype) for idx in reg.all_idxs(): diff --git a/qualtran/simulation/classical_sim_test.py b/qualtran/simulation/classical_sim_test.py index fafcd3b0d..947bc7887 100644 --- a/qualtran/simulation/classical_sim_test.py +++ b/qualtran/simulation/classical_sim_test.py @@ -20,7 +20,22 @@ from attrs import frozen from numpy.typing import NDArray -from qualtran import Bloq, BloqBuilder, QAny, QBit, Register, Side, Signature, Soquet +from qualtran import ( + Bloq, + BloqBuilder, + BQUInt, + QAny, + QBit, + QDType, + QFxp, + QInt, + QIntOnesComp, + QUInt, + Register, + Side, + Signature, + Soquet, +) from qualtran.bloqs.basic_gates import CNOT from qualtran.simulation.classical_sim import ( _update_assign_from_vals, @@ -148,3 +163,37 @@ def test_add_ints_signed(n_bits: int): @pytest.mark.notebook def test_notebook(): execute_notebook('classical_sim') + + +@frozen +class TestMultiDimensionalReg(Bloq): + dtype: QDType + n: int + + @property + def signature(self): + return Signature( + [ + Register('x', self.dtype, shape=(self.n,), side=Side.LEFT), + Register('y', self.dtype, shape=(self.n,), side=Side.RIGHT), + ] + ) + + def on_classical_vals(self, x): + return {'y': x} + + +@pytest.mark.parametrize( + 'dtype', [QBit(), QInt(5), QUInt(5), QIntOnesComp(5), BQUInt(5, 20), QFxp(5, 3, signed=True)] +) +def test_multidimensional_classical_sim_for_dtypes(dtype: QDType): + x = [*dtype.get_classical_domain()] + bloq = TestMultiDimensionalReg(dtype, len(x)) + np.testing.assert_equal(bloq.call_classically(x=np.array(x))[0], x) + + +def test_multidimensional_classical_sim_for_large_int(): + dtype = QInt(100) + x = [2**88 - 1, 2**12 - 1, 2**54 - 1, 1 - 2**72, 1 - 2**62] + bloq = TestMultiDimensionalReg(dtype, len(x)) + np.testing.assert_equal(bloq.call_classically(x=np.array(x))[0], x)