Skip to content

Commit

Permalink
Remove pandas shim and use result_type
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Jan 26, 2024
1 parent 80090d7 commit 481ea9c
Showing 1 changed file with 1 addition and 29 deletions.
30 changes: 1 addition & 29 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,34 +507,6 @@ def get_allowed_combinations_for_operator(dtype_l, dtype_r, op):
raise error


def np_find_common_type(*dtypes: np.dtype) -> np.dtype:
"""
np.find_common_type implementation pre-1.25 deprecation using np.result_type
https://github.com/pandas-dev/pandas/pull/49569#issuecomment-1308300065
Parameters
----------
dtypes : np.dtypes
Returns
-------
np.dtype
"""
# TODO: possibly raise the TypeError. Coercing to np.dtype("O") (string)
# might not make sense in cudf
try:
common_dtype = np.result_type(*dtypes)
if common_dtype.kind in "mMSU":
# NumPy promotion currently (1.25) misbehaves for for times and strings,
# so fall back to object (find_common_dtype did unless there
# was only one dtype)
common_dtype = np.dtype("O")

except TypeError:
common_dtype = np.dtype("O")
return common_dtype


def find_common_type(dtypes):
"""
Wrapper over np.find_common_type to handle special cases
Expand Down Expand Up @@ -642,7 +614,7 @@ def find_common_type(dtypes):
dtypes = dtypes - td_dtypes
dtypes.add(np.result_type(*td_dtypes))

common_dtype = np_find_common_type(*dtypes)
common_dtype = np.result_type(*dtypes)
if common_dtype == np.dtype("float16"):
return cudf.dtype("float32")
return cudf.dtype(common_dtype)
Expand Down

0 comments on commit 481ea9c

Please sign in to comment.