Skip to content

Commit

Permalink
Update pyarrow-related dispatch logic in dask_cudf (#14069)
Browse files Browse the repository at this point in the history
Updates `dask_cudf` dispatch logic to avoid breakage from dask/dask#10500.
Also removes stale `try`/`except` logic.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)
  - Ray Douglass (https://github.com/raydouglass)
  - gpuCI (https://github.com/GPUtester)
  - Mike Wendt (https://github.com/mike-wendt)
  - AJ Schmidt (https://github.com/ajschmidt8)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #14069
  • Loading branch information
rjzamora authored Sep 18, 2023
1 parent 3b691f4 commit 4ca568e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 43 deletions.
69 changes: 32 additions & 37 deletions python/dask_cudf/dask_cudf/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
from dask.dataframe.dispatch import (
categorical_dtype_dispatch,
concat_dispatch,
from_pyarrow_table_dispatch,
group_split_dispatch,
grouper_dispatch,
hash_object_dispatch,
is_categorical_dtype_dispatch,
make_meta_dispatch,
pyarrow_schema_dispatch,
to_pyarrow_table_dispatch,
tolist_dispatch,
union_categoricals_dispatch,
)
Expand Down Expand Up @@ -317,16 +320,6 @@ def get_grouper_cudf(obj):
return cudf.core.groupby.Grouper


try:
from dask.dataframe.dispatch import pyarrow_schema_dispatch

@pyarrow_schema_dispatch.register((cudf.DataFrame,))
def get_pyarrow_schema_cudf(obj):
return obj.to_arrow().schema

except ImportError:
pass

try:
try:
from dask.array.dispatch import percentile_lookup
Expand Down Expand Up @@ -378,35 +371,37 @@ def percentile_cudf(a, q, interpolation="linear"):
except ImportError:
pass

try:
# Requires dask>2023.6.0
from dask.dataframe.dispatch import (
from_pyarrow_table_dispatch,
to_pyarrow_table_dispatch,
)

@to_pyarrow_table_dispatch.register(cudf.DataFrame)
def _cudf_to_table(obj, preserve_index=True, **kwargs):
if kwargs:
warnings.warn(
"Ignoring the following arguments to "
f"`to_pyarrow_table_dispatch`: {list(kwargs)}"
)
return obj.to_arrow(preserve_index=preserve_index)

@from_pyarrow_table_dispatch.register(cudf.DataFrame)
def _table_to_cudf(obj, table, self_destruct=None, **kwargs):
# cudf ignores self_destruct.
kwargs.pop("self_destruct", None)
if kwargs:
warnings.warn(
f"Ignoring the following arguments to "
f"`from_pyarrow_table_dispatch`: {list(kwargs)}"
)
return obj.from_arrow(table)
@pyarrow_schema_dispatch.register((cudf.DataFrame,))
def _get_pyarrow_schema_cudf(obj, preserve_index=True, **kwargs):
    """Return the pyarrow schema corresponding to a cudf DataFrame.

    Any extra keyword arguments are not supported by cudf and are
    dropped with a warning.
    """
    unsupported = list(kwargs)
    if unsupported:
        warnings.warn(
            "Ignoring the following arguments to "
            f"`pyarrow_schema_dispatch`: {unsupported}"
        )
    # Derive the schema from a non-empty metadata sample of the frame
    # (via dask's ``meta_nonempty``) converted to arrow.
    sample = meta_nonempty(obj)
    return sample.to_arrow(preserve_index=preserve_index).schema

except ImportError:
pass

@to_pyarrow_table_dispatch.register(cudf.DataFrame)
def _cudf_to_table(obj, preserve_index=True, **kwargs):
    """Convert a cudf DataFrame to a pyarrow Table.

    Keyword arguments other than ``preserve_index`` are not supported
    by cudf and are dropped with a warning.
    """
    unsupported = list(kwargs)
    if unsupported:
        warnings.warn(
            "Ignoring the following arguments to "
            f"`to_pyarrow_table_dispatch`: {unsupported}"
        )
    return obj.to_arrow(preserve_index=preserve_index)


@from_pyarrow_table_dispatch.register(cudf.DataFrame)
def _table_to_cudf(obj, table, self_destruct=None, **kwargs):
    """Convert a pyarrow Table to a cudf DataFrame.

    Parameters
    ----------
    obj : cudf.DataFrame
        Meta object used only to reach ``cudf.DataFrame.from_arrow``.
    table : pyarrow.Table
        Table to convert.
    self_destruct : optional
        Accepted for API compatibility with dask's dispatch signature;
        cudf ignores it.
    **kwargs
        Any other arguments are ignored with a warning.
    """
    # NOTE: ``self_destruct`` is bound to the explicit parameter above,
    # so it can never appear in ``kwargs`` — the previous
    # ``kwargs.pop("self_destruct", None)`` was dead code and has been
    # removed.  cudf simply ignores the value.
    if kwargs:
        warnings.warn(
            "Ignoring the following arguments to "
            f"`from_pyarrow_table_dispatch`: {list(kwargs)}"
        )
    return obj.from_arrow(table)


@union_categoricals_dispatch.register((cudf.Series, cudf.BaseIndex))
Expand Down
21 changes: 15 additions & 6 deletions python/dask_cudf/dask_cudf/tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import numpy as np
import pandas as pd
import pytest
from packaging import version

import dask
from dask.base import tokenize
from dask.dataframe import assert_eq
from dask.dataframe.methods import is_categorical_dtype
Expand All @@ -24,10 +22,6 @@ def test_is_categorical_dispatch():
assert is_categorical_dtype(cudf.Index([1, 2, 3], dtype="category"))


@pytest.mark.skipif(
version.parse(dask.__version__) <= version.parse("2023.6.0"),
reason="Pyarrow-conversion dispatch requires dask>2023.6.0",
)
def test_pyarrow_conversion_dispatch():
from dask.dataframe.dispatch import (
from_pyarrow_table_dispatch,
Expand Down Expand Up @@ -79,3 +73,18 @@ def test_deterministic_tokenize(index):
df2 = df.set_index(["B", "C"], drop=False)
assert tokenize(df) != tokenize(df2)
assert tokenize(df2) == tokenize(df2)


@pytest.mark.parametrize("preserve_index", [True, False])
def test_pyarrow_schema_dispatch(preserve_index):
    # The schema returned by ``pyarrow_schema_dispatch`` should match
    # the schema of the table produced by ``to_pyarrow_table_dispatch``
    # for a frame with both numeric and string columns.
    from dask.dataframe.dispatch import (
        pyarrow_schema_dispatch,
        to_pyarrow_table_dispatch,
    )

    gdf = cudf.DataFrame(np.random.randn(10, 3), columns=list("abc"))
    gdf["d"] = cudf.Series(["cat", "dog"] * 5)

    expected = to_pyarrow_table_dispatch(
        gdf, preserve_index=preserve_index
    ).schema
    result = pyarrow_schema_dispatch(gdf, preserve_index=preserve_index)

    assert result.equals(expected)

0 comments on commit 4ca568e

Please sign in to comment.