Skip to content

Commit

Permalink
[FEA] Add support for cudf.NamedAgg (#16744)
Browse files Browse the repository at this point in the history
Closes #15118

Authors:
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #16744
  • Loading branch information
Matt711 authored Sep 11, 2024
1 parent c3d323d commit 4cdb1bf
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
ListDtype,
StructDtype,
)
from cudf.core.groupby import Grouper
from cudf.core.groupby import Grouper, NamedAgg
from cudf.core.index import (
BaseIndex,
CategoricalIndex,
Expand Down
5 changes: 3 additions & 2 deletions python/cudf/cudf/core/groupby/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

from cudf.core.groupby.groupby import GroupBy, Grouper
from cudf.core.groupby.groupby import GroupBy, Grouper, NamedAgg

__all__ = [
"GroupBy",
"Grouper",
"NamedAgg",
]
46 changes: 43 additions & 3 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,34 @@ def _is_row_of(chunk, obj):
)


class NamedAgg(pd.NamedAgg):
    """
    Helper for column specific aggregation with control over output column names.

    Subclass of pandas.NamedAgg: instances are ``(column, aggfunc)`` tuples
    and can be used anywhere a pandas.NamedAgg or a plain 2-tuple is accepted.

    Parameters
    ----------
    column : Hashable
        Column label in the DataFrame to apply aggfunc.
    aggfunc : function or str
        Function to apply to the provided column.

    Examples
    --------
    >>> import cudf
    >>> df = cudf.DataFrame({"key": [1, 1, 2], "a": [-1, 0, 1], 1: [10, 11, 12]})
    >>> agg_a = cudf.NamedAgg(column="a", aggfunc="min")
    >>> agg_1 = cudf.NamedAgg(column=1, aggfunc=lambda x: x.mean())
    >>> df.groupby("key").agg(result_a=agg_a, result_1=agg_1)
         result_a  result_1
    key
    1          -1      10.5
    2           1      12.0
    """

    # A thin subclass is used instead of aliasing ``pd.NamedAgg`` and
    # assigning to its ``__doc__``: that assignment would rewrite the
    # docstring of pandas' own class for every user in the process.
    # The empty ``__slots__`` adds no per-instance state of its own.
    __slots__ = ()


groupby_doc_template = textwrap.dedent(
"""Group using a mapper or by a Series of columns.
Expand Down Expand Up @@ -1296,9 +1324,21 @@ def _normalize_aggs(
columns = values._columns
aggs_per_column = (aggs,) * len(columns)
elif not aggs and kwargs:
column_names, aggs_per_column = kwargs.keys(), kwargs.values()
columns = tuple(self.obj._data[x[0]] for x in kwargs.values())
aggs_per_column = tuple(x[1] for x in kwargs.values())
column_names = kwargs.keys()

def _raise_invalid_type(x):
raise TypeError(
f"Invalid keyword argument {x} of type {type(x)} was passed to agg"
)

columns, aggs_per_column = zip(
*(
(self.obj._data[x[0]], x[1])
if isinstance(x, tuple)
else _raise_invalid_type(x)
for x in kwargs.values()
)
)
else:
raise TypeError("Must provide at least one aggregation function.")

Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/tests/groupby/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,19 @@ def test_dataframe_agg(attr, func):
)

assert_eq(agg, pd_agg)

agg = getattr(df.groupby("a"), attr)(
foo=cudf.NamedAgg(column="b", aggfunc=func),
bar=cudf.NamedAgg(column="a", aggfunc=func),
)
pd_agg = getattr(pdf.groupby(["a"]), attr)(
foo=("b", func), bar=("a", func)
)

assert_eq(agg, pd_agg)


def test_dataframe_agg_with_invalid_kwarg():
    """``agg`` rejects keyword values that are not ``(column, aggfunc)`` tuples."""
    # Build the frame outside the ``raises`` block so that only the ``agg``
    # call itself can satisfy the expected ``TypeError``; an error raised
    # during setup must fail the test instead of passing it.
    df = cudf.DataFrame({"a": [1, 2, 1, 2], "b": [0, 0, 0, 0]})
    with pytest.raises(TypeError, match="Invalid keyword argument"):
        df.groupby("a").agg(foo=set())

0 comments on commit 4cdb1bf

Please sign in to comment.