diff --git a/python/pylibraft/pylibraft/common/device_ndarray.py b/python/pylibraft/pylibraft/common/device_ndarray.py index f267e0c644..ae7bb2cabf 100644 --- a/python/pylibraft/pylibraft/common/device_ndarray.py +++ b/python/pylibraft/pylibraft/common/device_ndarray.py @@ -31,7 +31,9 @@ def __init__(self, np_ndarray): Parameters ---------- - ndarray : A numpy.ndarray which will be copied and moved to the device + ndarray : Can be numpy.ndarray, array like or even directly + an __array_interface__. Only case it is a numpy.ndarray its + contents will be copied to the device. Examples -------- @@ -58,11 +60,38 @@ def __init__(self, np_ndarray): raft_array = device_ndarray.empty((100, 50)) torch_tensor = torch.as_tensor(raft_array, device='cuda') """ - self.ndarray_ = np_ndarray + + if type(np_ndarray) is np.ndarray: + # np_ndarray IS an actual numpy.ndarray + self.__array_interface__ = np_ndarray.__array_interface__.copy() + self.ndarray_ = np_ndarray + copy = True + elif hasattr(np_ndarray, "__array_interface__"): + # np_ndarray HAS an __array_interface__ + self.__array_interface__ = np_ndarray.__array_interface__.copy() + self.ndarray_ = np_ndarray + copy = False + elif all( + name in np_ndarray for name in {"typestr", "shape", "version"} + ): + # np_ndarray IS an __array_interface__ + self.__array_interface__ = np_ndarray.copy() + self.ndarray_ = None + copy = False + else: + raise ValueError( + "np_ndarray should be or contain __array_interface__" + ) + order = "C" if self.c_contiguous else "F" - self.device_buffer_ = rmm.DeviceBuffer.to_device( - self.ndarray_.tobytes(order=order) - ) + if copy: + self.device_buffer_ = rmm.DeviceBuffer.to_device( + self.ndarray_.tobytes(order=order) + ) + else: + self.device_buffer_ = rmm.DeviceBuffer( + size=np.prod(self.shape) * self.dtype.itemsize + ) @classmethod def empty(cls, shape, dtype=np.float32, order="C"): @@ -82,7 +111,7 @@ def empty(cls, shape, dtype=np.float32, order="C"): or column-major (Fortran-style) order in memory """ arr = np.empty(shape, dtype=dtype, order=order) - return cls(arr) + return cls(arr.__array_interface__.copy()) @property def c_contiguous(self): @@ -104,7 +133,7 @@ def dtype(self): """ Datatype of the current device_ndarray instance """ - array_interface = self.ndarray_.__array_interface__ + array_interface = self.__array_interface__ return np.dtype(array_interface["typestr"]) @property @@ -112,7 +141,7 @@ def shape(self): """ Shape of the current device_ndarray instance """ - array_interface = self.ndarray_.__array_interface__ + array_interface = self.__array_interface__ return array_interface["shape"] @property @@ -120,7 +149,7 @@ def strides(self): """ Strides of the current device_ndarray instance """ - array_interface = self.ndarray_.__array_interface__ + array_interface = self.__array_interface__ return array_interface.get("strides") @property @@ -131,7 +160,7 @@ def __cuda_array_interface__(self): zero-copy semantics. """ device_cai = self.device_buffer_.__cuda_array_interface__ - host_cai = self.ndarray_.__array_interface__.copy() + host_cai = self.__array_interface__.copy() host_cai["data"] = (device_cai["data"][0], device_cai["data"][1]) return host_cai