Skip to content

Commit

Permalink
Merge branch 'branch-24.02' into bug/mark_kernels_as_static
Browse files Browse the repository at this point in the history
  • Loading branch information
ttnghia authored Jan 17, 2024
2 parents a57144e + 2bead95 commit 0368028
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 13 deletions.
7 changes: 6 additions & 1 deletion python/dask_cudf/dask_cudf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from dask_cudf import sorting
from dask_cudf.accessors import ListMethods, StructMethods
from dask_cudf.sorting import _get_shuffle_method
from dask_cudf.sorting import _deprecate_shuffle_kwarg, _get_shuffle_method


class _Frame(dd.core._Frame, OperatorMethodMixin):
Expand Down Expand Up @@ -111,6 +111,7 @@ def do_apply_rows(df, func, incols, outcols, kwargs):
do_apply_rows, func, incols, outcols, kwargs, meta=meta
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def merge(self, other, shuffle_method=None, **kwargs):
on = kwargs.pop("on", None)
Expand All @@ -123,6 +124,7 @@ def merge(self, other, shuffle_method=None, **kwargs):
**kwargs,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def join(self, other, shuffle_method=None, **kwargs):
# CuDF doesn't support "right" join yet
Expand All @@ -141,6 +143,7 @@ def join(self, other, shuffle_method=None, **kwargs):
**kwargs,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def set_index(
self,
Expand Down Expand Up @@ -216,6 +219,7 @@ def set_index(
**kwargs,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def sort_values(
self,
Expand Down Expand Up @@ -298,6 +302,7 @@ def var(
else:
return _parallel_var(self, meta, skipna, split_every, out)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def shuffle(self, *args, shuffle_method=None, **kwargs):
"""Wraps dask.dataframe DataFrame.shuffle method"""
Expand Down
20 changes: 14 additions & 6 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import cudf
from cudf.utils.nvtx_annotation import _dask_cudf_nvtx_annotate

from dask_cudf.sorting import _deprecate_shuffle_kwarg

# aggregations that are dask-cudf optimized
OPTIMIZED_AGGS = (
"count",
Expand Down Expand Up @@ -189,8 +191,11 @@ def last(self, split_every=None, split_out=1):
split_out,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg, split_every=None, split_out=1, shuffle_method=None
):
if arg == "size":
return self.size()

Expand All @@ -211,15 +216,15 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
sep=self.sep,
sort=self.sort,
as_index=self.as_index,
shuffle_method=shuffle,
shuffle_method=shuffle_method,
**self.dropna,
)

return super().aggregate(
arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
shuffle_method=shuffle_method,
)


Expand Down Expand Up @@ -330,8 +335,11 @@ def last(self, split_every=None, split_out=1):
split_out,
)[self._slice]

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg, split_every=None, split_out=1, shuffle_method=None
):
if arg == "size":
return self.size()

Expand All @@ -342,14 +350,14 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):

if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
return _make_groupby_agg_call(
self, arg, split_every, split_out, shuffle
self, arg, split_every, split_out, shuffle_method
)[self._slice]

return super().aggregate(
arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
shuffle_method=shuffle_method,
)


Expand Down
28 changes: 28 additions & 0 deletions python/dask_cudf/dask_cudf/sorting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import warnings
from collections.abc import Iterator
from functools import wraps

import cupy
import numpy as np
Expand All @@ -21,6 +23,31 @@
_SHUFFLE_SUPPORT = ("tasks", "p2p") # "disk" not supported


def _deprecate_shuffle_kwarg(func):
@wraps(func)
def wrapper(*args, **kwargs):
old_arg_value = kwargs.pop("shuffle", None)

if old_arg_value is not None:
new_arg_value = old_arg_value
msg = (
"the 'shuffle' keyword is deprecated, "
"use 'shuffle_method' instead."
)

warnings.warn(msg, FutureWarning)
if kwargs.get("shuffle_method") is not None:
msg = (
"Can only specify 'shuffle' "
"or 'shuffle_method', not both."
)
raise TypeError(msg)
kwargs["shuffle_method"] = new_arg_value
return func(*args, **kwargs)

return wrapper


@_dask_cudf_nvtx_annotate
def set_index_post(df, index_name, drop, column_dtype):
df2 = df.set_index(index_name, drop=drop)
Expand Down Expand Up @@ -229,6 +256,7 @@ def quantile_divisions(df, by, npartitions):
return divisions


@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def sort_values(
df,
Expand Down
24 changes: 18 additions & 6 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,26 +834,38 @@ def test_groupby_shuffle():

# Sorted aggregation, single-partition output
# (sort=True, split_out=1)
got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=1)
got = gddf.groupby("a", sort=True).agg(
spec, shuffle_method=True, split_out=1
)
dd.assert_eq(expect, got)

# Sorted aggregation, multi-partition output
# (sort=True, split_out=2)
got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=2)
got = gddf.groupby("a", sort=True).agg(
spec, shuffle_method=True, split_out=2
)
dd.assert_eq(expect, got)

# Un-sorted aggregation, single-partition output
# (sort=False, split_out=1)
got = gddf.groupby("a", sort=False).agg(spec, shuffle=True, split_out=1)
got = gddf.groupby("a", sort=False).agg(
spec, shuffle_method=True, split_out=1
)
dd.assert_eq(expect.sort_index(), got.compute().sort_index())

# Un-sorted aggregation, multi-partition output
# (sort=False, split_out=2)
# NOTE: `shuffle=True` should be default
# NOTE: `shuffle_method=True` should be default
got = gddf.groupby("a", sort=False).agg(spec, split_out=2)
dd.assert_eq(expect, got.compute().sort_index())

# Sorted aggregation fails with split_out>1 when shuffle is False
# (sort=True, split_out=2, shuffle=False)
# (sort=True, split_out=2, shuffle_method=False)
with pytest.raises(ValueError):
gddf.groupby("a", sort=True).agg(spec, shuffle=False, split_out=2)
gddf.groupby("a", sort=True).agg(
spec, shuffle_method=False, split_out=2
)

# Check shuffle kwarg deprecation
with pytest.warns(match="'shuffle' keyword is deprecated"):
gddf.groupby("a", sort=True).agg(spec, shuffle=False)

0 comments on commit 0368028

Please sign in to comment.