Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar committed Sep 7, 2023
1 parent e81d79e commit fe25539
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 13 deletions.
21 changes: 17 additions & 4 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from cudf.core.column import ColumnBase, column
from cudf.core.column_accessor import ColumnAccessor
from cudf.utils import ioutils
from cudf.utils.dtypes import is_mixed_with_object_dtype
from cudf.utils.dtypes import can_convert_to_column, is_mixed_with_object_dtype
from cudf.utils.utils import _is_same_name


Expand Down Expand Up @@ -608,8 +608,11 @@ def intersection(self, other, sort=False):
(1, 'Blue')],
)
"""
if not can_convert_to_column(other):
raise TypeError("Input must be Index or array-like")

if not isinstance(other, BaseIndex):
other = cudf.Index(other, name=self.name)
other = cudf.Index(other, name=getattr(other, "name", self.name))

if sort not in {None, False}:
raise ValueError(
Expand All @@ -618,9 +621,19 @@ def intersection(self, other, sort=False):
)

if self.equals(other):
dtypes = {self.dtype, other.dtype}
common_dtype = cudf.utils.dtypes.find_common_type(dtypes)
if self.has_duplicates:
return self.unique()._get_reconciled_name_object(other)
return self._get_reconciled_name_object(other)
return (
self.unique()
._get_reconciled_name_object(other)
.astype(common_dtype)
)
return self._get_reconciled_name_object(other).astype(common_dtype)
elif not len(other):
dtypes = {self.dtype, other.dtype}
common_dtype = cudf.utils.dtypes.find_common_type(dtypes)
return other._get_reconciled_name_object(self).astype(common_dtype)

res_name = _get_result_name(self.name, other.name)

Expand Down
23 changes: 21 additions & 2 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,9 @@ def _union(self, other, sort=None):
@_cudf_nvtx_annotate
def _intersection(self, other, sort=False):
if not isinstance(other, RangeIndex):
return super()._intersection(other, sort=sort)
return self._try_reconstruct_range_index(
super()._intersection(other, sort=sort)
)

if not len(self) or not len(other):
return RangeIndex(0)
Expand Down Expand Up @@ -722,7 +724,24 @@ def _intersection(self, other, sort=False):
if sort is None:
new_index = new_index.sort_values()

return new_index
return self._try_reconstruct_range_index(new_index)

def _try_reconstruct_range_index(self, index):
if index.dtype.kind == "f":
return index
# Evenly spaced values can return a
# RangeIndex instead of Int64Index
unique_diffs = (
index.to_frame(name="None").diff()["None"]._column.unique()
)
if len(unique_diffs) == 2 and (
unique_diffs.element_indexing(0) is cudf.NA
and unique_diffs.element_indexing(1) != 0
):
diff = unique_diffs.element_indexing(1)
new_range = range(index[0], index[-1] + diff, diff)
return type(self)(new_range, name=index.name)
return index

def sort_values(
self,
Expand Down
4 changes: 4 additions & 0 deletions python/cudf/cudf/core/join/_join_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def _match_join_keys(
common_type = ltype.categories.dtype
else:
common_type = rtype.categories.dtype
if cudf.get_option(
"mode.pandas_compatible"
) and common_type == cudf.dtype("object"):
common_type = "str"
return lcol.astype(common_type), rcol.astype(common_type)

if is_dtype_equal(ltype, rtype):
Expand Down
26 changes: 19 additions & 7 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2073,25 +2073,37 @@ def test_union_index(idx1, idx2, sort):
(pd.Index([0, 1, 2, 30], name=pd.NA), pd.Index([30, 0, 90, 100])),
(pd.Index([0, 1, 2, 30], name="a"), [90, 100]),
(pd.Index([0, 1, 2, 30]), pd.Index([0, 10, 1.0, 11])),
(pd.Index(["a", "b", "c", "d", "c"]), pd.Index(["a", "c", "z"])),
(
pd.Index(["a", "b", "c", "d", "c"]),
pd.Index(["a", "c", "z"], name="abc"),
),
(
pd.Index(["a", "b", "c", "d", "c"]),
pd.Index(["a", "b", "c", "d", "c"]),
),
(pd.Index([True, False, True, True]), pd.Index([10, 11, 12, 0, 1, 2])),
(pd.Index([True, False, True, True]), pd.Index([True, True])),
(pd.RangeIndex(0, 10, name="a"), pd.Index([5, 6, 7], name="b")),
(pd.Index(["a", "b", "c"], dtype="category"), pd.Index(["a", "b"])),
(pd.Index(["a", "b", "c"], dtype="category"), pd.Index([1, 2, 3])),
(pd.Index([0, 1, 2], dtype="category"), pd.RangeIndex(0, 10)),
(pd.Index(["a", "b", "c"], name="abc"), []),
(pd.Index([], name="abc"), pd.RangeIndex(0, 4)),
(pd.Index([1, 2, 3]), pd.Index([1, 2], dtype="category")),
(pd.Index([]), pd.Index([1, 2], dtype="category")),
],
)
@pytest.mark.parametrize("sort", [None, False])
def test_intersection_index(idx1, idx2, sort):
@pytest.mark.parametrize("pandas_compatible", [True, False])
def test_intersection_index(idx1, idx2, sort, pandas_compatible):
expected = idx1.intersection(idx2, sort=sort)
with cudf.option_context("mode.pandas_compatible", pandas_compatible):
idx1 = cudf.from_pandas(idx1) if isinstance(idx1, pd.Index) else idx1
idx2 = cudf.from_pandas(idx2) if isinstance(idx2, pd.Index) else idx2

idx1 = cudf.from_pandas(idx1) if isinstance(idx1, pd.Index) else idx1
idx2 = cudf.from_pandas(idx2) if isinstance(idx2, pd.Index) else idx2

actual = idx1.intersection(idx2, sort=sort)
actual = idx1.intersection(idx2, sort=sort)

assert_eq(expected, actual, exact=False)
assert_eq(expected, actual, exact=False)


@pytest.mark.parametrize(
Expand Down

0 comments on commit fe25539

Please sign in to comment.