diff --git a/narwhals/pandas_like/group_by.py b/narwhals/pandas_like/group_by.py index be0495796..9253683e2 100644 --- a/narwhals/pandas_like/group_by.py +++ b/narwhals/pandas_like/group_by.py @@ -58,7 +58,9 @@ def agg( raise ValueError(msg) output_names.extend(expr._output_names) - if implementation == "pandas" and not os.environ.get("NARWHALS_FORCE_GENERIC"): + if implementation in ("pandas", "modin") and not os.environ.get( + "NARWHALS_FORCE_GENERIC" + ): return agg_pandas( grouped, exprs, @@ -174,7 +176,7 @@ def agg_generic( # noqa: PLR0913 to_remove: list[int] = [] for i, expr in enumerate(exprs): if is_simple_aggregation(expr): - dfs.append(evaluate_simple_aggregation(expr, grouped)) + dfs.append(evaluate_simple_aggregation(expr, grouped, group_by_keys)) to_remove.append(i) exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove] diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index c916263d8..9486e94fb 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -217,7 +217,7 @@ def is_simple_aggregation(expr: PandasExpr) -> bool: ) -def evaluate_simple_aggregation(expr: PandasExpr, grouped: Any) -> Any: +def evaluate_simple_aggregation(expr: PandasExpr, grouped: Any, keys: list[str]) -> Any: """ Use fastpath for simple aggregations if possible. @@ -232,7 +232,14 @@ def evaluate_simple_aggregation(expr: PandasExpr, grouped: Any) -> Any: Returns naive DataFrame. """ if expr._depth == 0: - return grouped.size()["size"].rename(expr._output_names[0]) # type: ignore[index] + # e.g. agg(pl.len()) + df = getattr(grouped, expr._function_name.replace("len", "size"))() + df = ( + df.drop(columns=keys) + if len(df.shape) > 1 + else df.reset_index(drop=True).to_frame("size") + ) + return df.rename(columns={"size": expr._output_names[0]}) # type: ignore[index] if expr._root_names is None or expr._output_names is None: msg = "Expected expr to have root_names and output_names set, but they are None. Please report a bug." raise AssertionError(msg)