Skip to content

Commit

Permalink
Fix empty string column construction (#14052)
Browse files Browse the repository at this point in the history
Fixes #14046 

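This PR fixes the construction of empty string columns, which failed because of a corner case in the way pyarrow constructs arrays: for an empty object-dtype input, `pa.array` returns a `NullArray` instead of a `StringArray`.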

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #14052
  • Loading branch information
galipremsagar authored Sep 7, 2023
1 parent 6945c4f commit c9d8821
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
15 changes: 15 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2438,6 +2438,21 @@ def as_column(
from_pandas=True if nan_as_null is None else nan_as_null,
)

if (
isinstance(pyarrow_array, pa.NullArray)
and pa_type is None
and dtype is None
and getattr(arbitrary, "dtype", None)
== cudf.dtype("object")
):
# pa.array constructor returns a NullArray
# for empty arrays, instead of a StringArray.
# This issue is only specific to this dtype,
# all other dtypes, result in their corresponding
# arrow array creation.
dtype = cudf.dtype("str")
pyarrow_array = pyarrow_array.cast(np_to_pa_dtype(dtype))

if (
isinstance(arbitrary, pd.Index)
and arbitrary.dtype == cudf.dtype("object")
Expand Down
5 changes: 1 addition & 4 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7256,10 +7256,7 @@ def test_dataframe_keys(df):
def test_series_keys(ps):
# Build the cudf counterpart of ``ps`` and compare the ``keys()`` of both.
gds = cudf.from_pandas(ps)

# NOTE(review): this is a flattened diff hunk (header "-7256,10 +7256,7",
# "1 addition & 4 deletions") whose +/- markers and indentation were lost.
# The four-line if/else below is presumably the removed side: it cast the
# pandas keys to float64 before comparing when the series was empty and its
# index was not a RangeIndex -- confirm against the original diff.
if len(ps) == 0 and not isinstance(ps.index, pd.RangeIndex):
assert_eq(ps.keys().astype("float64"), gds.keys())
else:
assert_eq(ps.keys(), gds.keys())
# Presumably the single added line: the keys are now compared directly in
# every case, with no dtype cast.
assert_eq(ps.keys(), gds.keys())


@pytest_unmark_spilling
Expand Down
24 changes: 24 additions & 0 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
as_index,
)
from cudf.testing._utils import (
ALL_TYPES,
FLOAT_TYPES,
NUMERIC_TYPES,
OTHER_TYPES,
Expand Down Expand Up @@ -2703,3 +2704,26 @@ def test_index_getitem_time_duration(dtype):
assert gidx[i] is pidx[i]
else:
assert_eq(gidx[i], pidx[i])


@pytest.mark.parametrize("dtype", ALL_TYPES)
def test_index_empty_from_pandas(request, dtype):
    """Compare an empty pandas Index of ``dtype`` with its cudf conversion.

    The empty index is converted with ``cudf.from_pandas`` and checked
    against the pandas original using ``assert_eq``. When ``PANDAS_GE_200``
    is false, the second-, millisecond- and microsecond-resolution datetime
    and timedelta dtypes are marked as expected failures.
    """
    # Dtypes marked xfail when PANDAS_GE_200 is false.
    pre_pandas2_xfail_dtypes = {
        "datetime64[ms]",
        "datetime64[s]",
        "datetime64[us]",
        "timedelta64[ms]",
        "timedelta64[s]",
        "timedelta64[us]",
    }
    expect_failure = not PANDAS_GE_200 and dtype in pre_pandas2_xfail_dtypes
    xfail_marker = pytest.mark.xfail(
        condition=expect_failure,
        reason="Fixed in pandas-2.0",
    )
    # The marker is attached at run time because its condition depends on
    # the parametrized ``dtype``.
    request.node.add_marker(xfail_marker)

    empty_pandas_index = pd.Index([], dtype=dtype)
    empty_cudf_index = cudf.from_pandas(empty_pandas_index)

    assert_eq(empty_pandas_index, empty_cudf_index)

0 comments on commit c9d8821

Please sign in to comment.