Skip to content

Commit

Permalink
Raise MixedTypeError when a column of mixed-dtype is being constructed (#14050)
Browse files Browse the repository at this point in the history

Fixes #14038 

This PR raises a `MixedTypeError` when a column of `object` dtype is constructed from data that is neither strings nor booleans.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #14050
  • Loading branch information
galipremsagar authored Sep 7, 2023
1 parent dc5f500 commit 6945c4f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
19 changes: 14 additions & 5 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,10 +2062,15 @@ def as_column(
)
else:
pyarrow_array = pa.array(arbitrary, from_pandas=nan_as_null)
if arbitrary.dtype == cudf.dtype("object") and isinstance(
pyarrow_array, (pa.DurationArray, pa.TimestampArray)
if (
arbitrary.dtype == cudf.dtype("object")
and cudf.dtype(pyarrow_array.type.to_pandas_dtype())
!= cudf.dtype(arbitrary.dtype)
and not is_bool_dtype(
cudf.dtype(pyarrow_array.type.to_pandas_dtype())
)
):
raise TypeError("Cannot create column with mixed types")
raise MixedTypeError("Cannot create column with mixed types")
if isinstance(pyarrow_array.type, pa.Decimal128Type):
pyarrow_type = cudf.Decimal128Dtype.from_arrow(
pyarrow_array.type
Expand Down Expand Up @@ -2436,8 +2441,12 @@ def as_column(
if (
isinstance(arbitrary, pd.Index)
and arbitrary.dtype == cudf.dtype("object")
and isinstance(
pyarrow_array, (pa.DurationArray, pa.TimestampArray)
and (
cudf.dtype(pyarrow_array.type.to_pandas_dtype())
!= cudf.dtype(arbitrary.dtype)
and not is_bool_dtype(
cudf.dtype(pyarrow_array.type.to_pandas_dtype())
)
)
):
raise MixedTypeError(
Expand Down
3 changes: 2 additions & 1 deletion python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2676,10 +2676,11 @@ def test_scalar_getitem(self, index_values, i):
12,
20,
],
[1, 2, 3, 4],
],
)
def test_index_mixed_dtype_error(data):
pi = pd.Index(data)
pi = pd.Index(data, dtype="object")
with pytest.raises(TypeError):
cudf.Index(pi)

Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2374,11 +2374,11 @@ def test_parquet_writer_list_statistics(tmpdir):
for i, col in enumerate(pd_slice):
stats = pq_file.metadata.row_group(rg).column(i).statistics

actual_min = cudf.Series(pd_slice[col].explode().explode()).min()
actual_min = pd_slice[col].explode().explode().dropna().min()
stats_min = stats.min
assert normalized_equals(actual_min, stats_min)

actual_max = cudf.Series(pd_slice[col].explode().explode()).max()
actual_max = pd_slice[col].explode().explode().dropna().max()
stats_max = stats.max
assert normalized_equals(actual_max, stats_max)

Expand Down
6 changes: 5 additions & 1 deletion python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,11 +2187,15 @@ def test_series_init_error():
)


@pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"])
@pytest.mark.parametrize(
"dtype", ["datetime64[ns]", "timedelta64[ns]", "object", "str"]
)
def test_series_mixed_dtype_error(dtype):
ps = pd.concat([pd.Series([1, 2, 3], dtype=dtype), pd.Series([10, 11])])
with pytest.raises(TypeError):
cudf.Series(ps)
with pytest.raises(TypeError):
cudf.Series(ps.array)


@pytest.mark.parametrize("data", [[True, False, None], [10, 200, 300]])
Expand Down

0 comments on commit 6945c4f

Please sign in to comment.