Skip to content

Commit

Permalink
Support arbitrary dtypes in classical sim for multi dimensional numpy…
Browse files Browse the repository at this point in the history
… arrays (#1418)
  • Loading branch information
tanujkhattar authored Sep 24, 2024
1 parent cee6954 commit 57b6e3d
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 15 deletions.
60 changes: 46 additions & 14 deletions qualtran/simulation/classical_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,19 @@

"""Functionality for the `Bloq.call_classically(...)` protocol."""
import itertools
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

import networkx as nx
import numpy as np
Expand All @@ -34,29 +46,49 @@
)
from qualtran._infra.composite_bloq import _binst_to_cxns

if TYPE_CHECKING:
from qualtran import QDType

ClassicalValT = Union[int, np.integer, NDArray[np.integer]]


def _numpy_dtype_from_qdtype(dtype: 'QDType') -> Type:
from qualtran._infra.data_types import QBit, QInt, QUInt

if isinstance(dtype, QUInt):
if dtype.bitsize <= 8:
return np.uint8
elif dtype.bitsize <= 16:
return np.uint16
elif dtype.bitsize <= 32:
return np.uint32
elif dtype.bitsize <= 64:
return np.uint64

if isinstance(dtype, QInt):
if dtype.bitsize <= 8:
return np.int8
elif dtype.bitsize <= 16:
return np.int16
elif dtype.bitsize <= 32:
return np.int32
elif dtype.bitsize <= 64:
return np.int64

if isinstance(dtype, QBit):
return np.uint8

return object


def _get_in_vals(
binst: Union[DanglingT, BloqInstance], reg: Register, soq_assign: Dict[Soquet, ClassicalValT]
) -> ClassicalValT:
"""Pluck out the correct values from `soq_assign` for `reg` on `binst`."""
if not reg.shape:
return soq_assign[Soquet(binst, reg)]

if reg.bitsize <= 8:
dtype: Type = np.uint8
elif reg.bitsize <= 16:
dtype = np.uint16
elif reg.bitsize <= 32:
dtype = np.uint32
elif reg.bitsize <= 64:
dtype = np.uint64
else:
raise NotImplementedError(
"We currently only support up to 64-bit "
"multi-dimensional registers in classical simulation."
)
dtype: Type = _numpy_dtype_from_qdtype(reg.dtype)

arg = np.empty(reg.shape, dtype=dtype)
for idx in reg.all_idxs():
Expand Down
51 changes: 50 additions & 1 deletion qualtran/simulation/classical_sim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,22 @@
from attrs import frozen
from numpy.typing import NDArray

from qualtran import Bloq, BloqBuilder, QAny, QBit, Register, Side, Signature, Soquet
from qualtran import (
Bloq,
BloqBuilder,
BQUInt,
QAny,
QBit,
QDType,
QFxp,
QInt,
QIntOnesComp,
QUInt,
Register,
Side,
Signature,
Soquet,
)
from qualtran.bloqs.basic_gates import CNOT
from qualtran.simulation.classical_sim import (
_update_assign_from_vals,
Expand Down Expand Up @@ -148,3 +163,37 @@ def test_add_ints_signed(n_bits: int):
@pytest.mark.notebook
def test_notebook():
execute_notebook('classical_sim')


@frozen
class TestMultiDimensionalReg(Bloq):
dtype: QDType
n: int

@property
def signature(self):
return Signature(
[
Register('x', self.dtype, shape=(self.n,), side=Side.LEFT),
Register('y', self.dtype, shape=(self.n,), side=Side.RIGHT),
]
)

def on_classical_vals(self, x):
return {'y': x}


@pytest.mark.parametrize(
'dtype', [QBit(), QInt(5), QUInt(5), QIntOnesComp(5), BQUInt(5, 20), QFxp(5, 3, signed=True)]
)
def test_multidimensional_classical_sim_for_dtypes(dtype: QDType):
x = [*dtype.get_classical_domain()]
bloq = TestMultiDimensionalReg(dtype, len(x))
np.testing.assert_equal(bloq.call_classically(x=np.array(x))[0], x)


def test_multidimensional_classical_sim_for_large_int():
dtype = QInt(100)
x = [2**88 - 1, 2**12 - 1, 2**54 - 1, 1 - 2**72, 1 - 2**62]
bloq = TestMultiDimensionalReg(dtype, len(x))
np.testing.assert_equal(bloq.call_classically(x=np.array(x))[0], x)

0 comments on commit 57b6e3d

Please sign in to comment.