From c0b6dbe6c38d53529a26e04f84489220b650e12e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 21 Dec 2024 10:16:05 +0000 Subject: [PATCH] perf: avoid merge in pandas groupby --- narwhals/_pandas_like/group_by.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index f9ea86dc0..06190bf3e 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -11,6 +11,7 @@ from narwhals._expression_parsing import is_simple_aggregation from narwhals._expression_parsing import parse_into_exprs +from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_series_from_iterable from narwhals._pandas_like.utils import select_columns_by_name from narwhals.utils import Implementation @@ -236,18 +237,11 @@ def agg_pandas( # noqa: PLR0915 new_names = [new_names[i] for i in index_map] result_simple_aggs.columns = new_names - # Keep inplace=True to avoid making a redundant copy. - # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files - result_simple_aggs.reset_index(inplace=True) # noqa: PD002 - if nunique_aggs: result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique( dropna=False ) result_nunique_aggs.columns = list(nunique_aggs.keys()) - # Keep inplace=True to avoid making a redundant copy. - # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files - result_nunique_aggs.reset_index(inplace=True) # noqa: PD002 if simple_aggs and nunique_aggs: if ( set(result_simple_aggs.columns) @@ -259,7 +253,11 @@ def agg_pandas( # noqa: PLR0915 "that aggregations have unique output names." ) raise ValueError(msg) - result_aggs = result_simple_aggs.merge(result_nunique_aggs, on=keys) + result_aggs = horizontal_concat( + [result_simple_aggs, result_nunique_aggs], + implementation=implementation, + backend_version=backend_version, + ) elif nunique_aggs and not simple_aggs: result_aggs = result_nunique_aggs elif simple_aggs and not nunique_aggs: @@ -269,6 +267,9 @@ def agg_pandas( # noqa: PLR0915 result_aggs = native_namespace.DataFrame( list(grouped.groups.keys()), columns=keys ) + # Keep inplace=True to avoid making a redundant copy. + # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files + result_aggs.reset_index(inplace=True) # noqa: PD002 return from_dataframe( select_columns_by_name( result_aggs, output_names, backend_version, implementation