Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make proxy NumPy arrays pass isinstance check in cudf.pandas #16286

Merged
merged 25 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
365e1b0
merge conflict
Matt711 Jul 16, 2024
b21a47a
Add __array_finalize__ to ProxyNDarrayBase
Matt711 Jul 16, 2024
f1e6670
Add __array_ufunc__ to proxy array
Matt711 Jul 17, 2024
49dec4a
Merge branch 'branch-24.08' into feature/numpy-proxying
Matt711 Jul 17, 2024
5d4f170
Merge branch 'branch-24.08' of github.com:rapidsai/cudf into feature/…
Matt711 Jul 19, 2024
108895b
ensure __array_ufunc__ returns a real numpy array
Matt711 Jul 19, 2024
51c5b1a
Merge branch 'branch-24.08' into feature/numpy-proxying
Matt711 Jul 19, 2024
67989ad
Merge branch 'branch-24.08' into feature/numpy-proxying
galipremsagar Jul 22, 2024
a11526c
Merge branch 'branch-24.08' into feature/numpy-proxying
Matt711 Jul 24, 2024
e47a9d1
Merge branch 'branch-24.10' of github.com:rapidsai/cudf into feature/…
Matt711 Jul 24, 2024
84ff28e
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Jul 25, 2024
3dccf7d
Merge branch 'branch-24.10' into feature/numpy-proxying
galipremsagar Jul 25, 2024
1d74061
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Jul 30, 2024
ab90aef
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Aug 1, 2024
df0c19d
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Aug 1, 2024
a88ad71
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Aug 2, 2024
fc3b7d3
Merge branch 'branch-24.10' into feature/numpy-proxying
galipremsagar Aug 7, 2024
6d49d35
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Aug 8, 2024
81810b3
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Aug 13, 2024
4bc0d95
Add a test and TODO
Matt711 Aug 13, 2024
6d7d61b
Merge branch 'branch-24.10' into feature/numpy-proxying
Matt711 Aug 13, 2024
b265aaa
Merge branch 'branch-24.10' into feature/numpy-proxying
galipremsagar Aug 14, 2024
f4fe303
Merge branch 'branch-24.10' into feature/numpy-proxying
galipremsagar Aug 15, 2024
32c03c0
Merge branch 'branch-24.10' into feature/numpy-proxying
galipremsagar Aug 15, 2024
23365d2
Address review
Matt711 Aug 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/cudf/cudf/pandas/_wrappers/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
make_final_proxy_type,
make_intermediate_proxy_type,
)
from ..proxy_base import ProxyNDarrayBase
from .common import (
array_interface,
array_method,
Expand Down Expand Up @@ -111,6 +112,7 @@ def wrap_ndarray(cls, arr: cupy.ndarray | numpy.ndarray, constructor):
numpy.ndarray,
fast_to_slow=cupy.ndarray.get,
slow_to_fast=cupy.asarray,
bases=(ProxyNDarrayBase,),
additional_attributes={
"__array__": array_method,
# So that pa.array(wrapped-numpy-array) works
Expand Down
27 changes: 25 additions & 2 deletions python/cudf/cudf/pandas/fast_slow_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..options import _env_get_bool
from ..testing import assert_eq
from .annotation import nvtx
from .proxy_base import ProxyNDarrayBase


def call_operator(fn, args, kwargs):
Expand Down Expand Up @@ -564,8 +565,13 @@ def _fsproxy_wrap(cls, value, func):
_FinalProxy subclasses can override this classmethod if they
need particular behaviour when wrapped up.
"""
proxy = object.__new__(cls)
proxy._fsproxy_wrapped = value
# TODO: Use _has_proxy_base_class to perform the check
mroeschke marked this conversation as resolved.
Show resolved Hide resolved
if np.ndarray in cls.__mro__:
proxy = ProxyNDarrayBase.__new__(cls, value)
proxy._fsproxy_wrapped = value
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
else:
proxy = object.__new__(cls)
proxy._fsproxy_wrapped = value
return proxy

def __reduce__(self):
Expand Down Expand Up @@ -1193,6 +1199,23 @@ def is_proxy_object(obj: Any) -> bool:
return False


def _has_proxy_base_class(cls):
"""Determine if an object is proxy object

Parameters
----------
cls : type
The type to check.

"""
return any(base in cls.__mro__ for base in PROXY_BASE_CLASSES)


PROXY_BASE_CLASSES: set[type] = {
ProxyNDarrayBase,
}


NUMPY_TYPES: set[str] = set(np.sctypeDict.values())


Expand Down
27 changes: 27 additions & 0 deletions python/cudf/cudf/pandas/proxy_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import cupy as cp
import numpy as np


class ProxyNDarrayBase(np.ndarray):
def __new__(cls, arr):
if isinstance(arr, cp.ndarray):
obj = np.asarray(arr.get()).view(cls)
return obj
elif isinstance(arr, np.ndarray):
obj = np.asarray(arr).view(cls)
return obj
else:
raise TypeError(
"Unsupported array type. Must be numpy.ndarray or cupy.ndarray"
)

def __array_finalize__(self, obj):
self._fsproxy_wrapped = getattr(obj, "_fsproxy_wrapped", None)

def __array_ufunc__(self, *args, **kwargs):
mroeschke marked this conversation as resolved.
Show resolved Hide resolved
args = (args[0], args[1], np.asarray(args[2]), np.asarray(args[3]))
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
mroeschke marked this conversation as resolved.
Show resolved Hide resolved
return super().__array_ufunc__(*args, **kwargs)
8 changes: 8 additions & 0 deletions python/cudf/cudf_pandas_tests/test_cudf_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,3 +1632,11 @@ def test_change_index_name(index):

assert s.index.name == name
assert df.index.name == name


def test_numpy_ndarray_isinstancecheck(series):
s1, s2 = series
arr1 = s1.values
arr2 = s2.values
assert isinstance(arr1, np.ndarray)
assert isinstance(arr2, np.ndarray)
Loading