diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py index 2470b4d50f1..e3f4f04eb85 100644 --- a/python/dask_cudf/dask_cudf/backends.py +++ b/python/dask_cudf/dask_cudf/backends.py @@ -20,11 +20,14 @@ from dask.dataframe.dispatch import ( categorical_dtype_dispatch, concat_dispatch, + from_pyarrow_table_dispatch, group_split_dispatch, grouper_dispatch, hash_object_dispatch, is_categorical_dtype_dispatch, make_meta_dispatch, + pyarrow_schema_dispatch, + to_pyarrow_table_dispatch, tolist_dispatch, union_categoricals_dispatch, ) @@ -317,16 +320,6 @@ def get_grouper_cudf(obj): return cudf.core.groupby.Grouper -try: - from dask.dataframe.dispatch import pyarrow_schema_dispatch - - @pyarrow_schema_dispatch.register((cudf.DataFrame,)) - def get_pyarrow_schema_cudf(obj): - return obj.to_arrow().schema - -except ImportError: - pass - try: try: from dask.array.dispatch import percentile_lookup @@ -378,35 +371,37 @@ def percentile_cudf(a, q, interpolation="linear"): except ImportError: pass -try: - # Requires dask>2023.6.0 - from dask.dataframe.dispatch import ( - from_pyarrow_table_dispatch, - to_pyarrow_table_dispatch, - ) - @to_pyarrow_table_dispatch.register(cudf.DataFrame) - def _cudf_to_table(obj, preserve_index=True, **kwargs): - if kwargs: - warnings.warn( - "Ignoring the following arguments to " - f"`to_pyarrow_table_dispatch`: {list(kwargs)}" - ) - return obj.to_arrow(preserve_index=preserve_index) - - @from_pyarrow_table_dispatch.register(cudf.DataFrame) - def _table_to_cudf(obj, table, self_destruct=None, **kwargs): - # cudf ignores self_destruct. - kwargs.pop("self_destruct", None) - if kwargs: - warnings.warn( - f"Ignoring the following arguments to " - f"`from_pyarrow_table_dispatch`: {list(kwargs)}" - ) - return obj.from_arrow(table) +@pyarrow_schema_dispatch.register((cudf.DataFrame,)) +def _get_pyarrow_schema_cudf(obj, preserve_index=True, **kwargs): + if kwargs: + warnings.warn( + "Ignoring the following arguments to " + f"`pyarrow_schema_dispatch`: {list(kwargs)}" + ) + return meta_nonempty(obj).to_arrow(preserve_index=preserve_index).schema -except ImportError: - pass + +@to_pyarrow_table_dispatch.register(cudf.DataFrame) +def _cudf_to_table(obj, preserve_index=True, **kwargs): + if kwargs: + warnings.warn( + "Ignoring the following arguments to " + f"`to_pyarrow_table_dispatch`: {list(kwargs)}" + ) + return obj.to_arrow(preserve_index=preserve_index) + + +@from_pyarrow_table_dispatch.register(cudf.DataFrame) +def _table_to_cudf(obj, table, self_destruct=None, **kwargs): + # cudf ignores self_destruct. + kwargs.pop("self_destruct", None) + if kwargs: + warnings.warn( + f"Ignoring the following arguments to " + f"`from_pyarrow_table_dispatch`: {list(kwargs)}" + ) + return obj.from_arrow(table) @union_categoricals_dispatch.register((cudf.Series, cudf.BaseIndex)) diff --git a/python/dask_cudf/dask_cudf/tests/test_dispatch.py b/python/dask_cudf/dask_cudf/tests/test_dispatch.py index 22cc0f161e2..cf49b1df4f4 100644 --- a/python/dask_cudf/dask_cudf/tests/test_dispatch.py +++ b/python/dask_cudf/dask_cudf/tests/test_dispatch.py @@ -3,9 +3,7 @@ import numpy as np import pandas as pd import pytest -from packaging import version -import dask from dask.base import tokenize from dask.dataframe import assert_eq from dask.dataframe.methods import is_categorical_dtype @@ -24,10 +22,6 @@ def test_is_categorical_dispatch(): assert is_categorical_dtype(cudf.Index([1, 2, 3], dtype="category")) -@pytest.mark.skipif( - version.parse(dask.__version__) <= version.parse("2023.6.0"), - reason="Pyarrow-conversion dispatch requires dask>2023.6.0", -) def test_pyarrow_conversion_dispatch(): from dask.dataframe.dispatch import ( from_pyarrow_table_dispatch, @@ -79,3 +73,18 @@ def test_deterministic_tokenize(index): df2 = df.set_index(["B", "C"], drop=False) assert tokenize(df) != tokenize(df2) assert tokenize(df2) == tokenize(df2) + + +@pytest.mark.parametrize("preserve_index", [True, False]) +def test_pyarrow_schema_dispatch(preserve_index): + from dask.dataframe.dispatch import ( + pyarrow_schema_dispatch, + to_pyarrow_table_dispatch, + ) + + df = cudf.DataFrame(np.random.randn(10, 3), columns=list("abc")) + df["d"] = cudf.Series(["cat", "dog"] * 5) + table = to_pyarrow_table_dispatch(df, preserve_index=preserve_index) + schema = pyarrow_schema_dispatch(df, preserve_index=preserve_index) + + assert schema.equals(table.schema)