Skip to content

Commit

Permalink
fix: group by no aggregation (#944)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Sep 11, 2024
1 parent 90476a2 commit cc1696b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
6 changes: 6 additions & 0 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def agg(
output_names.extend(expr._output_names)

return agg_dask(
self._df,
self._grouped,
exprs,
self._keys,
Expand All @@ -88,6 +89,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame:


def agg_dask(
df: DaskLazyFrame,
grouped: Any,
exprs: list[DaskExpr],
keys: list[str],
Expand All @@ -99,6 +101,10 @@ def agg_dask(
- https://github.com/rapidsai/cudf/issues/15118
- https://github.com/rapidsai/cudf/issues/15084
"""
if not exprs:
# No aggregation provided
return df.select(*keys).unique(subset=keys)

all_simple_aggs = True
for expr in exprs:
if not is_simple_aggregation(expr):
Expand Down
8 changes: 5 additions & 3 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def agg(
dataframe_is_empty=self._df._native_frame.empty,
implementation=implementation,
backend_version=self._df._backend_version,
native_namespace=self._df.__native_namespace__(),
)

def _from_native_frame(self, df: PandasLikeDataFrame) -> PandasLikeDataFrame:
Expand Down Expand Up @@ -114,6 +115,7 @@ def agg_pandas( # noqa: PLR0915
implementation: Any,
backend_version: tuple[int, ...],
dataframe_is_empty: bool,
native_namespace: Any,
) -> PandasLikeDataFrame:
"""
This should be the fastpath, but cuDF is too far behind to use it.
Expand Down Expand Up @@ -204,9 +206,9 @@ def agg_pandas( # noqa: PLR0915
result_aggs = result_nunique_aggs
elif simple_aggs and not nunique_aggs:
result_aggs = result_simple_aggs
else: # pragma: no cover
msg = "Congrats, you entered unreachable code. Please report a bug to https://github.com/narwhals-dev/narwhals/issues."
raise RuntimeError(msg)
else:
# No aggregation provided
result_aggs = native_namespace.DataFrame(grouped.groups.keys(), columns=keys)
return from_dataframe(result_aggs.loc[:, output_names])

if dataframe_is_empty:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,10 @@ def test_key_with_nulls(constructor: Any, request: Any) -> None:
)
expected = {"b": [4.0, 5, float("nan")], "len": [1, 1, 1], "a": [1, 2, 3]}
compare_dicts(result, expected)


def test_no_agg(constructor: Any) -> None:
result = nw.from_native(constructor(data)).group_by(["a", "b"]).agg().sort("a", "b")

expected = {"a": [1, 3], "b": [4, 6]}
compare_dicts(result, expected)

0 comments on commit cc1696b

Please sign in to comment.