diff --git a/python/cudf/cudf/_lib/CMakeLists.txt b/python/cudf/cudf/_lib/CMakeLists.txt
index 410fd57691e..da4faabf189 100644
--- a/python/cudf/cudf/_lib/CMakeLists.txt
+++ b/python/cudf/cudf/_lib/CMakeLists.txt
@@ -12,7 +12,7 @@
 # the License.
 # =============================================================================
 
-set(cython_sources column.pyx groupby.pyx scalar.pyx strings_udf.pyx types.pyx utils.pyx)
+set(cython_sources column.pyx scalar.pyx strings_udf.pyx types.pyx utils.pyx)
 
 set(linked_libraries cudf::cudf)
 rapids_cython_create_modules(
diff --git a/python/cudf/cudf/_lib/__init__.py b/python/cudf/cudf/_lib/__init__.py
index 6b5a7814e48..10f9d813ccc 100644
--- a/python/cudf/cudf/_lib/__init__.py
+++ b/python/cudf/cudf/_lib/__init__.py
@@ -1,10 +1,7 @@
 # Copyright (c) 2020-2024, NVIDIA CORPORATION.
 import numpy as np
 
-from . import (
-    groupby,
-    strings_udf,
-)
+from . import strings_udf
 
 MAX_COLUMN_SIZE = np.iinfo(np.int32).max
 MAX_COLUMN_SIZE_STR = "INT32_MAX"
diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx
deleted file mode 100644
index 80a77ef2267..00000000000
--- a/python/cudf/cudf/_lib/groupby.pyx
+++ /dev/null
@@ -1,281 +0,0 @@
-# Copyright (c) 2020-2024, NVIDIA CORPORATION.
-from functools import singledispatch
-
-from pandas.errors import DataError
-
-from cudf.api.types import _is_categorical_dtype, is_string_dtype
-from cudf.core.buffer import acquire_spill_lock
-from cudf.core.dtypes import (
-    CategoricalDtype,
-    DecimalDtype,
-    IntervalDtype,
-    ListDtype,
-    StructDtype,
-)
-
-from cudf._lib.scalar cimport DeviceScalar
-from cudf._lib.utils cimport columns_from_pylibcudf_table
-
-from cudf._lib.scalar import as_device_scalar
-
-import pylibcudf
-
-from cudf.core._internals.aggregation import make_aggregation
-
-# The sets below define the possible aggregations that can be performed on
-# different dtypes. These strings must be elements of the AggregationKind enum.
-# The libcudf infrastructure exists for "COLLECT" support on
-# categoricals, but the dtype support in python does not.
-_CATEGORICAL_AGGS = {"COUNT", "NUNIQUE", "SIZE", "UNIQUE"}
-_STRING_AGGS = {
-    "COLLECT",
-    "COUNT",
-    "MAX",
-    "MIN",
-    "NTH",
-    "NUNIQUE",
-    "SIZE",
-    "UNIQUE",
-}
-_LIST_AGGS = {"COLLECT"}
-_STRUCT_AGGS = {"COLLECT", "CORRELATION", "COVARIANCE"}
-_INTERVAL_AGGS = {"COLLECT"}
-_DECIMAL_AGGS = {
-    "ARGMIN",
-    "ARGMAX",
-    "COLLECT",
-    "COUNT",
-    "MAX",
-    "MIN",
-    "NTH",
-    "NUNIQUE",
-    "SUM",
-}
-
-
-@singledispatch
-def get_valid_aggregation(dtype):
-    if is_string_dtype(dtype):
-        return _STRING_AGGS
-    return "ALL"
-
-
-@get_valid_aggregation.register
-def _(dtype: ListDtype):
-    return _LIST_AGGS
-
-
-@get_valid_aggregation.register
-def _(dtype: CategoricalDtype):
-    return _CATEGORICAL_AGGS
-
-
-@get_valid_aggregation.register
-def _(dtype: ListDtype):
-    return _LIST_AGGS
-
-
-@get_valid_aggregation.register
-def _(dtype: StructDtype):
-    return _STRUCT_AGGS
-
-
-@get_valid_aggregation.register
-def _(dtype: IntervalDtype):
-    return _INTERVAL_AGGS
-
-
-@get_valid_aggregation.register
-def _(dtype: DecimalDtype):
-    return _DECIMAL_AGGS
-
-
-cdef class GroupBy:
-    cdef dict __dict__
-
-    def __init__(self, keys, dropna=True):
-        with acquire_spill_lock() as spill_lock:
-            self._groupby = pylibcudf.groupby.GroupBy(
-                pylibcudf.table.Table([c.to_pylibcudf(mode="read") for c in keys]),
-                pylibcudf.types.NullPolicy.EXCLUDE if dropna
-                else pylibcudf.types.NullPolicy.INCLUDE
-            )
-
-            # We spill lock the columns while this GroupBy instance is alive.
-            self._spill_lock = spill_lock
-
-    def groups(self, list values):
-        """
-        Perform a sort groupby, using the keys used to construct the Groupby as the key
-        columns and ``values`` as the value columns.
-
-        Parameters
-        ----------
-        values: list of Columns
-            The value columns
-
-        Returns
-        -------
-        offsets: list of integers
-            Integer offsets such that offsets[i+1] - offsets[i]
-            represents the size of group `i`.
-        grouped_keys: list of Columns
-            The grouped key columns
-        grouped_values: list of Columns
-            The grouped value columns
-        """
-        offsets, grouped_keys, grouped_values = self._groupby.get_groups(
-            pylibcudf.table.Table([c.to_pylibcudf(mode="read") for c in values])
-            if values else None
-        )
-
-        return (
-            offsets,
-            columns_from_pylibcudf_table(grouped_keys),
-            (
-                columns_from_pylibcudf_table(grouped_values)
-                if grouped_values is not None else []
-            ),
-        )
-
-    def aggregate(self, values, aggregations):
-        """
-        Parameters
-        ----------
-        values : Frame
-        aggregations
-            A dict mapping column names in `Frame` to a list of aggregations
-            to perform on that column
-
-            Each aggregation may be specified as:
-            - a string (e.g., "max")
-            - a lambda/function
-
-        Returns
-        -------
-        Frame of aggregated values
-        """
-        included_aggregations = []
-        column_included = []
-        requests = []
-        for i, (col, aggs) in enumerate(zip(values, aggregations)):
-            valid_aggregations = get_valid_aggregation(col.dtype)
-            included_aggregations_i = []
-            col_aggregations = []
-            for agg in aggs:
-                str_agg = str(agg)
-                if (
-                    is_string_dtype(col)
-                    and agg not in _STRING_AGGS
-                    and
-                    (
-                        str_agg in {"cumsum", "cummin", "cummax"}
-                        or not (
-                            any(a in str_agg for a in {
-                                "count",
-                                "max",
-                                "min",
-                                "first",
-                                "last",
-                                "nunique",
-                                "unique",
-                                "nth"
-                            })
-                            or (agg is list)
-                        )
-                    )
-                ):
-                    raise TypeError(
-                        f"function is not supported for this dtype: {agg}"
-                    )
-                elif (
-                    _is_categorical_dtype(col)
-                    and agg not in _CATEGORICAL_AGGS
-                    and (
-                        str_agg in {"cumsum", "cummin", "cummax"}
-                        or
-                        not (
-                            any(a in str_agg for a in {"count", "max", "min", "unique"})
-                        )
-                    )
-                ):
-                    raise TypeError(
-                        f"{col.dtype} type does not support {agg} operations"
-                    )
-
-                agg_obj = make_aggregation(agg)
-                if valid_aggregations == "ALL" or agg_obj.kind in valid_aggregations:
-                    included_aggregations_i.append((agg, agg_obj.kind))
-                    col_aggregations.append(agg_obj.c_obj)
-            included_aggregations.append(included_aggregations_i)
-            if col_aggregations:
-                requests.append(pylibcudf.groupby.GroupByRequest(
-                    col.to_pylibcudf(mode="read"), col_aggregations
-                ))
-                column_included.append(i)
-
-        if not requests and any(len(v) > 0 for v in aggregations):
-            raise DataError("All requested aggregations are unsupported.")
-
-        keys, results = self._groupby.scan(requests) if \
-            _is_all_scan_aggregate(aggregations) else self._groupby.aggregate(requests)
-
-        result_columns = [[] for _ in range(len(values))]
-        for i, result in zip(column_included, results):
-            result_columns[i] = columns_from_pylibcudf_table(result)
-
-        return result_columns, columns_from_pylibcudf_table(keys), included_aggregations
-
-    def shift(self, list values, int periods, list fill_values):
-        keys, shifts = self._groupby.shift(
-            pylibcudf.table.Table([c.to_pylibcudf(mode="read") for c in values]),
-            [periods] * len(values),
-            [
-                (<DeviceScalar> as_device_scalar(val, dtype=col.dtype)).c_value
-                for val, col in zip(fill_values, values)
-            ],
-        )
-
-        return columns_from_pylibcudf_table(shifts), columns_from_pylibcudf_table(keys)
-
-    def replace_nulls(self, list values, object method):
-        _, replaced = self._groupby.replace_nulls(
-            pylibcudf.table.Table([c.to_pylibcudf(mode="read") for c in values]),
-            [
-                pylibcudf.replace.ReplacePolicy.PRECEDING
-                if method == 'ffill' else pylibcudf.replace.ReplacePolicy.FOLLOWING
-            ] * len(values),
-        )
-
-        return columns_from_pylibcudf_table(replaced)
-
-
-_GROUPBY_SCANS = {"cumcount", "cumsum", "cummin", "cummax", "cumprod", "rank"}
-
-
-def _is_all_scan_aggregate(all_aggs):
-    """
-    Returns true if all are scan aggregations.
-    Raises
-    ------
-    NotImplementedError
-        If both reduction aggregations and scan aggregations are present.
-    """
-
-    def get_name(agg):
-        return agg.__name__ if callable(agg) else agg
-
-    all_scan = all(
-        get_name(agg_name) in _GROUPBY_SCANS for aggs in all_aggs
-        for agg_name in aggs
-    )
-    any_scan = any(
-        get_name(agg_name) in _GROUPBY_SCANS for aggs in all_aggs
-        for agg_name in aggs
-    )
-
-    if not all_scan and any_scan:
-        raise NotImplementedError(
-            "Cannot perform both aggregation and scan in one operation"
-        )
-    return all_scan and any_scan
diff --git a/python/cudf/cudf/core/_internals/aggregation.py b/python/cudf/cudf/core/_internals/aggregation.py
index fe8ea5a947a..1d21d34b1bf 100644
--- a/python/cudf/cudf/core/_internals/aggregation.py
+++ b/python/cudf/cudf/core/_internals/aggregation.py
@@ -29,11 +29,11 @@ class Aggregation:
     def __init__(self, agg: plc.aggregation.Aggregation) -> None:
-        self.c_obj = agg
+        self.plc_obj = agg
 
     @property
     def kind(self) -> str:
-        name = self.c_obj.kind().name
+        name = self.plc_obj.kind().name
         return _agg_name_map.get(name, name)
 
     @classmethod
diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py
index cccafaeba88..75b9070b53f 100644
--- a/python/cudf/cudf/core/column/column.py
+++ b/python/cudf/cudf/core/column/column.py
@@ -1605,7 +1605,7 @@ def scan(self, scan_op: str, inclusive: bool, **kwargs) -> Self:
         return type(self).from_pylibcudf(  # type: ignore[return-value]
             plc.reduce.scan(
                 self.to_pylibcudf(mode="read"),
-                aggregation.make_aggregation(scan_op, kwargs).c_obj,
+                aggregation.make_aggregation(scan_op, kwargs).plc_obj,
                 plc.reduce.ScanType.INCLUSIVE
                 if inclusive
                 else plc.reduce.ScanType.EXCLUSIVE,
@@ -1637,7 +1637,7 @@ def reduce(self, reduction_op: str, dtype=None, **kwargs) -> ScalarLike:
         with acquire_spill_lock():
             plc_scalar = plc.reduce.reduce(
                 self.to_pylibcudf(mode="read"),
-                aggregation.make_aggregation(reduction_op, kwargs).c_obj,
+                aggregation.make_aggregation(reduction_op, kwargs).plc_obj,
                 dtype_to_pylibcudf_type(col_dtype),
             )
             result_col = type(self).from_pylibcudf(
diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py
index 6cd8e11695f..badb5c79a47 100644
--- a/python/cudf/cudf/core/groupby/groupby.py
+++ b/python/cudf/cudf/core/groupby/groupby.py
@@ -4,9 +4,10 @@
 import copy
 import itertools
 import textwrap
+import types
 import warnings
 from collections import abc
-from functools import cached_property
+from functools import cached_property, singledispatch
 from typing import TYPE_CHECKING, Any, Literal
 
 import cupy as cp
@@ -18,17 +19,27 @@
 import cudf
 import cudf.core._internals
 from cudf import _lib as libcudf
-from cudf._lib import groupby as libgroupby
 from cudf._lib.types import size_type_dtype
 from cudf.api.extensions import no_default
-from cudf.api.types import is_list_like, is_numeric_dtype
+from cudf.api.types import (
+    is_list_like,
+    is_numeric_dtype,
+    is_string_dtype,
+)
 from cudf.core._compat import PANDAS_LT_300
-from cudf.core._internals import sorting
+from cudf.core._internals import aggregation, sorting
 from cudf.core.abc import Serializable
 from cudf.core.buffer import acquire_spill_lock
-from cudf.core.column.column import ColumnBase, StructDtype, as_column
+from cudf.core.column.column import ColumnBase, as_column
 from cudf.core.column_accessor import ColumnAccessor
 from cudf.core.copy_types import GatherMap
+from cudf.core.dtypes import (
+    CategoricalDtype,
+    DecimalDtype,
+    IntervalDtype,
+    ListDtype,
+    StructDtype,
+)
 from cudf.core.join._join_helpers import _match_join_keys
 from cudf.core.mixins import Reducible, Scannable
 from cudf.core.multiindex import MultiIndex
@@ -37,7 +48,7 @@
 from cudf.utils.utils import GetAttrGetItemMixin
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable
+    from collections.abc import Generator, Iterable
 
     from cudf._typing import (
         AggType,
@@ -46,6 +57,151 @@
         ScalarLike,
     )
 
+# The sets below define the possible aggregations that can be performed on
+# different dtypes. These strings must be elements of the AggregationKind enum.
+# The libcudf infrastructure exists for "COLLECT" support on
+# categoricals, but the dtype support in python does not.
+_CATEGORICAL_AGGS = {"COUNT", "NUNIQUE", "SIZE", "UNIQUE"}
+_STRING_AGGS = {
+    "COLLECT",
+    "COUNT",
+    "MAX",
+    "MIN",
+    "NTH",
+    "NUNIQUE",
+    "SIZE",
+    "UNIQUE",
+}
+_LIST_AGGS = {"COLLECT"}
+_STRUCT_AGGS = {"COLLECT", "CORRELATION", "COVARIANCE"}
+_INTERVAL_AGGS = {"COLLECT"}
+_DECIMAL_AGGS = {
+    "ARGMIN",
+    "ARGMAX",
+    "COLLECT",
+    "COUNT",
+    "MAX",
+    "MIN",
+    "NTH",
+    "NUNIQUE",
+    "SUM",
+}
+
+
+@singledispatch
+def get_valid_aggregation(dtype):
+    if is_string_dtype(dtype):
+        return _STRING_AGGS
+    return "ALL"
+
+
+@get_valid_aggregation.register
+def _(dtype: ListDtype):
+    return _LIST_AGGS
+
+
+@get_valid_aggregation.register
+def _(dtype: CategoricalDtype):
+    return _CATEGORICAL_AGGS
+
+
+@get_valid_aggregation.register
+def _(dtype: StructDtype):
+    return _STRUCT_AGGS
+
+
+@get_valid_aggregation.register
+def _(dtype: IntervalDtype):
+    return _INTERVAL_AGGS
+
+
+@get_valid_aggregation.register
+def _(dtype: DecimalDtype):
+    return _DECIMAL_AGGS
+
+
+@singledispatch
+def _is_unsupported_agg_for_type(dtype, str_agg: str) -> bool:
+    return False
+
+
+@_is_unsupported_agg_for_type.register
+def _(dtype: np.dtype, str_agg: str) -> bool:
+    # Only object ("O") dtype, i.e. string columns, is restricted here.
+    cumulative_agg = str_agg in {"cumsum", "cummin", "cummax"}
+    basic_agg = any(
+        a in str_agg
+        for a in (
+            "count",
+            "max",
+            "min",
+            "first",
+            "last",
+            "nunique",
+            "unique",
+            "nth",
+        )
+    )
+    return (
+        dtype.kind == "O"
+        and str_agg not in _STRING_AGGS
+        # str(list) == "<class 'list'>"; agg=list means COLLECT, which is valid.
+        and (cumulative_agg or not (basic_agg or str_agg == "<class 'list'>"))
+    )
+
+
+@_is_unsupported_agg_for_type.register
+def _(dtype: CategoricalDtype, str_agg: str) -> bool:
+    cumulative_agg = str_agg in {"cumsum", "cummin", "cummax"}
+    not_basic_agg = not any(
+        a in str_agg for a in ("count", "max", "min", "unique")
+    )
+    return str_agg not in _CATEGORICAL_AGGS and (
+        cumulative_agg or not_basic_agg
+    )
+
+
+def _is_all_scan_aggregate(all_aggs: list[list[str]]) -> bool:
+    """
+    Returns True if all are scan aggregations.
+
+    Raises
+    ------
+    NotImplementedError
+        If both reduction aggregations and scan aggregations are present.
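+
+    For example, ``[["cumsum"], ["rank"]]`` is all-scan, ``[["sum"]]`` is
+    not, and ``[["cumsum", "sum"]]`` mixes the two kinds and raises.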
+ """ + groupby_scans = { + "cumcount", + "cumsum", + "cummin", + "cummax", + "cumprod", + "rank", + } + + def get_name(agg): + return agg.__name__ if callable(agg) else agg + + all_scan = all( + get_name(agg_name) in groupby_scans + for aggs in all_aggs + for agg_name in aggs + ) + any_scan = any( + get_name(agg_name) in groupby_scans + for aggs in all_aggs + for agg_name in aggs + ) + + if not all_scan and any_scan: + raise NotImplementedError( + "Cannot perform both aggregation and scan in one operation" + ) + return all_scan and any_scan + def _deprecate_collect(): warnings.warn( @@ -423,7 +580,7 @@ def indices(self) -> dict[ScalarLike, cp.ndarray]: >>> df.groupby(by=["a"]).indices {10: array([0, 1]), 40: array([2])} """ - offsets, group_keys, (indices,) = self._groupby.groups( + offsets, group_keys, (indices,) = self._groups( [ cudf.core.column.as_column( range(len(self.obj)), dtype=size_type_dtype @@ -582,11 +739,137 @@ def rank(x): return result @cached_property - def _groupby(self): - return libgroupby.GroupBy( - [*self.grouping.keys._columns], dropna=self._dropna + def _groupby(self) -> types.SimpleNamespace: + with acquire_spill_lock() as spill_lock: + plc_groupby = plc.groupby.GroupBy( + plc.Table( + [ + col.to_pylibcudf(mode="read") + for col in self.grouping.keys._columns + ] + ), + plc.types.NullPolicy.EXCLUDE + if self._dropna + else plc.types.NullPolicy.INCLUDE, + ) + # Do we need this because we just check _spill_locks in test_spillable_df_groupby? + return types.SimpleNamespace( + plc_groupby=plc_groupby, _spill_locks=spill_lock + ) + + def _groups( + self, values: Iterable[ColumnBase] + ) -> tuple[list[int], list[ColumnBase], list[ColumnBase]]: + plc_columns = [col.to_pylibcudf(mode="read") for col in values] + if not plc_columns: + plc_table = None + else: + plc_table = plc.Table(plc_columns) + offsets, grouped_keys, grouped_values = ( + self._groupby.plc_groupby.get_groups(plc_table) + ) + + return ( + offsets, + [ColumnBase.from_pylibcudf(col) for col in grouped_keys.columns()], + ( + [ + ColumnBase.from_pylibcudf(col) + for col in grouped_values.columns() + ] + if grouped_values is not None + else [] + ), + ) + + def _aggregate( + self, values: tuple[ColumnBase, ...], aggregations + ) -> tuple[ + list[list[ColumnBase]], + list[ColumnBase], + list[list[tuple[str, str]]], + ]: + included_aggregations = [] + column_included = [] + requests = [] + result_columns: list[list[ColumnBase]] = [] + for i, (col, aggs) in enumerate(zip(values, aggregations)): + valid_aggregations = get_valid_aggregation(col.dtype) + included_aggregations_i = [] + col_aggregations = [] + for agg in aggs: + str_agg = str(agg) + if _is_unsupported_agg_for_type(col.dtype, str_agg): + raise TypeError( + f"{col.dtype} type does not support {agg} operations" + ) + agg_obj = aggregation.make_aggregation(agg) + if ( + valid_aggregations == "ALL" + or agg_obj.kind in valid_aggregations + ): + included_aggregations_i.append((agg, agg_obj.kind)) + col_aggregations.append(agg_obj.plc_obj) + included_aggregations.append(included_aggregations_i) + result_columns.append([]) + if col_aggregations: + requests.append( + plc.groupby.GroupByRequest( + col.to_pylibcudf(mode="read"), col_aggregations + ) + ) + column_included.append(i) + + if not requests and any(len(v) > 0 for v in aggregations): + raise pd.errors.DataError( + "All requested aggregations are unsupported." 
+        keys, results = (
+            self._groupby.plc_groupby.scan(requests)
+            if _is_all_scan_aggregate(aggregations)
+            else self._groupby.plc_groupby.aggregate(requests)
         )
 
+        for i, result in zip(column_included, results):
+            result_columns[i] = [
+                ColumnBase.from_pylibcudf(col) for col in result.columns()
+            ]
+
+        return (
+            result_columns,
+            [ColumnBase.from_pylibcudf(key) for key in keys.columns()],
+            included_aggregations,
+        )
+
+    def _shift(
+        self, values: tuple[ColumnBase, ...], periods: int, fill_values: list
+    ) -> Generator[ColumnBase]:
+        _, shifts = self._groupby.plc_groupby.shift(
+            plc.table.Table([col.to_pylibcudf(mode="read") for col in values]),
+            [periods] * len(values),
+            [
+                cudf.Scalar(val, dtype=col.dtype).device_value.c_value
+                for val, col in zip(fill_values, values)
+            ],
+        )
+        return (ColumnBase.from_pylibcudf(col) for col in shifts.columns())
+
+    def _replace_nulls(
+        self, values: tuple[ColumnBase, ...], method: str
+    ) -> Generator[ColumnBase]:
+        _, replaced = self._groupby.plc_groupby.replace_nulls(
+            plc.Table([col.to_pylibcudf(mode="read") for col in values]),
+            [
+                plc.replace.ReplacePolicy.PRECEDING
+                if method == "ffill"
+                else plc.replace.ReplacePolicy.FOLLOWING
+            ]
+            * len(values),
+        )
+
+        return (ColumnBase.from_pylibcudf(col) for col in replaced.columns())
+
     @_performance_tracking
     def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
@@ -702,7 +985,7 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
             result_columns,
             grouped_key_cols,
             included_aggregations,
-        ) = self._groupby.aggregate(columns, normalized_aggs)
+        ) = self._aggregate(columns, normalized_aggs)
 
         result_index = self.grouping.keys._from_columns_like_self(
             grouped_key_cols,
@@ -761,7 +1044,7 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
         else:
             if cudf.get_option(
                 "mode.pandas_compatible"
-            ) and not libgroupby._is_all_scan_aggregate(normalized_aggs):
+            ) and not _is_all_scan_aggregate(normalized_aggs):
                 # Even with `sort=False`, pandas guarantees that
                 # groupby preserves the order of rows within each group.
                 left_cols = list(self.grouping.keys.drop_duplicates()._columns)
@@ -810,7 +1093,7 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
         if not self._as_index:
             result = result.reset_index()
 
-        if libgroupby._is_all_scan_aggregate(normalized_aggs):
+        if _is_all_scan_aggregate(normalized_aggs):
             # Scan aggregations return rows in original index order
            return self._mimic_pandas_order(result)
 
@@ -920,7 +1203,7 @@ def _head_tail(self, n, *, take_head: bool, preserve_order: bool):
         # Can't use _mimic_pandas_order because we need to
         # subsample the gather map from the full input ordering,
         # rather than permuting the gather map of the output.
-        _, _, (ordering,) = self._groupby.groups(
+        _, _, (ordering,) = self._groups(
             [as_column(range(0, len(self.obj)))]
         )
         # Invert permutation from original order to groups on the
@@ -1312,8 +1595,8 @@ def deserialize(cls, header, frames):
         return cls(obj, grouping, **kwargs)
 
     def _grouped(self, *, include_groups: bool = True):
-        offsets, grouped_key_cols, grouped_value_cols = self._groupby.groups(
-            [*self.obj.index._columns, *self.obj._columns]
+        offsets, grouped_key_cols, grouped_value_cols = self._groups(
+            itertools.chain(self.obj.index._columns, self.obj._columns)
         )
         grouped_keys = cudf.core.index._index_from_data(
             dict(enumerate(grouped_key_cols))
         )
@@ -1945,7 +2228,7 @@ def transform(
                 "Currently, `transform()` supports only aggregations."
             ) from e
         # If the aggregation is a scan, don't broadcast
-        if libgroupby._is_all_scan_aggregate([[func]]):
+        if _is_all_scan_aggregate([[func]]):
             if len(result) != len(self.obj):
                 raise AssertionError(
                     "Unexpected result length for scan transform"
@@ -2409,7 +2692,7 @@ def _scan_fill(self, method: str, limit: int) -> DataFrameOrSeries:
             dict(
                 zip(
                     values._column_names,
-                    self._groupby.replace_nulls([*values._columns], method),
+                    self._replace_nulls(values._columns, method),
                 )
             )
         )
@@ -2513,7 +2796,7 @@ def fillna(
     @_performance_tracking
     def shift(
         self,
-        periods=1,
+        periods: int = 1,
         freq=None,
         axis=0,
         fill_value=None,
@@ -2560,7 +2843,7 @@ def shift(
         if freq is not None:
             raise NotImplementedError("Parameter freq is unsupported.")
 
-        if not axis == 0:
+        if axis != 0:
             raise NotImplementedError("Only axis=0 is supported.")
 
         if suffix is not None:
@@ -2568,20 +2851,18 @@ def shift(
 
         values = self.grouping.values
         if is_list_like(fill_value):
-            if len(fill_value) != len(values._data):
+            if len(fill_value) != values._num_columns:
                 raise ValueError(
                     "Mismatched number of columns and values to fill."
                 )
         else:
-            fill_value = [fill_value] * len(values._data)
+            fill_value = [fill_value] * values._num_columns
 
         result = self.obj.__class__._from_data(
             dict(
                 zip(
                     values._column_names,
-                    self._groupby.shift(
-                        [*values._columns], periods, fill_value
-                    )[0],
+                    self._shift(values._columns, periods, fill_value),
                 )
             )
         )
@@ -2680,9 +2961,7 @@ def _mimic_pandas_order(
         # result coming back from libcudf has null_count few rows than
         # the input, so we must produce an ordering from the full
         # input range.
-        _, _, (ordering,) = self._groupby.groups(
-            [as_column(range(0, len(self.obj)))]
-        )
+        _, _, (ordering,) = self._groups([as_column(range(0, len(self.obj)))])
         if self._dropna and any(
             c.has_nulls(include_nan=True) > 0
             for c in self.grouping._key_columns
diff --git a/python/cudf/cudf/core/window/rolling.py b/python/cudf/cudf/core/window/rolling.py
index a580c35ccbf..2f8a6d9e5e7 100644
--- a/python/cudf/cudf/core/window/rolling.py
+++ b/python/cudf/cudf/core/window/rolling.py
@@ -315,7 +315,7 @@ def _apply_agg_column(self, source_column, agg_name):
                 {"dtype": source_column.dtype}
                 if callable(agg_name)
                 else self.agg_params,
-            ).c_obj,
+            ).plc_obj,
         )
     )
diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index d8a2528230e..db4f3cd3c9f 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -3960,8 +3960,8 @@ def test_group_by_value_counts_with_count_column():
 def test_groupby_internal_groups_empty(gdf):
     # test that we don't segfault when calling the internal
     # .groups() method with an empty list:
-    gb = gdf.groupby("y")._groupby
-    _, _, grouped_vals = gb.groups([])
+    gb = gdf.groupby("y")
+    _, _, grouped_vals = gb._groups([])
     assert grouped_vals == []
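
Note on the pattern above: the validation helpers moved into
cudf/core/groupby/groupby.py dispatch on the dtype instance via
functools.singledispatch. A minimal, self-contained sketch of the idea
(the ListDtype stand-in below is illustrative only, not the real cudf
type):

    from functools import singledispatch

    class ListDtype:  # stand-in for cudf.ListDtype, illustration only
        pass

    _LIST_AGGS = {"COLLECT"}

    @singledispatch
    def get_valid_aggregation(dtype):
        # Fallback for unregistered dtypes: every aggregation is allowed.
        return "ALL"

    @get_valid_aggregation.register
    def _(dtype: ListDtype):
        # List columns only support collecting values into lists.
        return _LIST_AGGS

    print(get_valid_aggregation(3.14))         # "ALL"
    print(get_valid_aggregation(ListDtype()))  # {"COLLECT"}

Registering additional dtypes (Categorical, Struct, Interval, Decimal)
extends the table without touching the fallback, which is why the diff
replaces the per-dtype if/elif checks from groupby.pyx with one
registration per dtype.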