From 1c63e1ee31a07fb4999d7356919280ba3d528741 Mon Sep 17 00:00:00 2001 From: Matthew Murray <41342305+Matt711@users.noreply.github.com> Date: Thu, 15 Aug 2024 21:51:47 -0400 Subject: [PATCH] Initial investigation into NumPy proxying in `cudf.pandas` (#16286) Apart of #15397. Closes #14537. Creates `ProxyNDarray` which inherits from `np.ndarray`. Authors: - Matthew Murray (https://github.com/Matt711) - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/16286 --- python/cudf/cudf/pandas/_wrappers/numpy.py | 3 +++ python/cudf/cudf/pandas/fast_slow_proxy.py | 20 +++++++++++++++- python/cudf/cudf/pandas/proxy_base.py | 23 +++++++++++++++++++ .../cudf_pandas_tests/test_cudf_pandas.py | 8 +++++++ 4 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 python/cudf/cudf/pandas/proxy_base.py diff --git a/python/cudf/cudf/pandas/_wrappers/numpy.py b/python/cudf/cudf/pandas/_wrappers/numpy.py index 3b012169676..eabea9713f1 100644 --- a/python/cudf/cudf/pandas/_wrappers/numpy.py +++ b/python/cudf/cudf/pandas/_wrappers/numpy.py @@ -14,6 +14,7 @@ make_final_proxy_type, make_intermediate_proxy_type, ) +from ..proxy_base import ProxyNDarrayBase from .common import ( array_interface, array_method, @@ -111,12 +112,14 @@ def wrap_ndarray(cls, arr: cupy.ndarray | numpy.ndarray, constructor): numpy.ndarray, fast_to_slow=cupy.ndarray.get, slow_to_fast=cupy.asarray, + bases=(ProxyNDarrayBase,), additional_attributes={ "__array__": array_method, # So that pa.array(wrapped-numpy-array) works "__arrow_array__": arrow_array_method, "__cuda_array_interface__": cuda_array_interface, "__array_interface__": array_interface, + "__array_ufunc__": _FastSlowAttribute("__array_ufunc__"), # ndarrays are unhashable "__hash__": None, # iter(cupy-array) produces an iterable of zero-dim device diff --git a/python/cudf/cudf/pandas/fast_slow_proxy.py b/python/cudf/cudf/pandas/fast_slow_proxy.py index bb678fd1efe..61aa6310082 100644 --- a/python/cudf/cudf/pandas/fast_slow_proxy.py +++ b/python/cudf/cudf/pandas/fast_slow_proxy.py @@ -19,6 +19,7 @@ from ..options import _env_get_bool from ..testing import assert_eq from .annotation import nvtx +from .proxy_base import ProxyNDarrayBase def call_operator(fn, args, kwargs): @@ -564,7 +565,11 @@ def _fsproxy_wrap(cls, value, func): _FinalProxy subclasses can override this classmethod if they need particular behaviour when wrapped up. """ - proxy = object.__new__(cls) + base_class = _get_proxy_base_class(cls) + if base_class is object: + proxy = base_class.__new__(cls) + else: + proxy = base_class.__new__(cls, value) proxy._fsproxy_wrapped = value return proxy @@ -1193,6 +1198,19 @@ def is_proxy_object(obj: Any) -> bool: return False +def _get_proxy_base_class(cls): + """Returns the proxy base class if one exists""" + for proxy_class in PROXY_BASE_CLASSES: + if proxy_class in cls.__mro__: + return proxy_class + return object + + +PROXY_BASE_CLASSES: set[type] = { + ProxyNDarrayBase, +} + + NUMPY_TYPES: set[str] = set(np.sctypeDict.values()) diff --git a/python/cudf/cudf/pandas/proxy_base.py b/python/cudf/cudf/pandas/proxy_base.py new file mode 100644 index 00000000000..61d9cde127c --- /dev/null +++ b/python/cudf/cudf/pandas/proxy_base.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import cupy as cp +import numpy as np + + +class ProxyNDarrayBase(np.ndarray): + def __new__(cls, arr): + if isinstance(arr, cp.ndarray): + obj = np.asarray(arr.get()).view(cls) + return obj + elif isinstance(arr, np.ndarray): + obj = np.asarray(arr).view(cls) + return obj + else: + raise TypeError( + "Unsupported array type. Must be numpy.ndarray or cupy.ndarray" + ) + + def __array_finalize__(self, obj): + self._fsproxy_wrapped = getattr(obj, "_fsproxy_wrapped", None) diff --git a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py index 6292022d8e4..e5483fff913 100644 --- a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py +++ b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py @@ -1632,3 +1632,11 @@ def test_change_index_name(index): assert s.index.name == name assert df.index.name == name + + +def test_numpy_ndarray_isinstancecheck(series): + s1, s2 = series + arr1 = s1.values + arr2 = s2.values + assert isinstance(arr1, np.ndarray) + assert isinstance(arr2, np.ndarray)