From a104310e4cbb06e6d6a5616512d51f58325086cc Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 22 Dec 2024 14:05:31 +0100 Subject: [PATCH 1/6] feat: support std and var with ddof !=1 in pandas-like group by --- narwhals/_arrow/expr.py | 18 +++---- narwhals/_arrow/namespace.py | 7 ++- narwhals/_arrow/selectors.py | 6 +-- narwhals/_dask/expr.py | 18 +++---- narwhals/_dask/namespace.py | 21 +++++--- narwhals/_expression_parsing.py | 5 +- narwhals/_pandas_like/expr.py | 28 ++++------- narwhals/_pandas_like/group_by.py | 80 +++++++++++++++++++++++------- narwhals/_pandas_like/namespace.py | 7 ++- narwhals/_pandas_like/selectors.py | 6 +-- narwhals/_spark_like/expr.py | 2 +- narwhals/_spark_like/namespace.py | 2 +- tests/group_by_test.py | 39 +++++++++++++++ 13 files changed, 164 insertions(+), 75 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 6f1d627f5..15afcba1c 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -308,7 +308,7 @@ def alias(self: Self, name: str) -> Self: output_names=[name], backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def null_count(self: Self) -> Self: @@ -436,7 +436,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=self._output_names, backend_version=self._backend_version, version=self._version, - kwargs={"keys": keys}, + kwargs={**self._kwargs, "keys": keys}, ) def mode(self: Self) -> Self: @@ -478,7 +478,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=self._output_names, backend_version=self._backend_version, version=self._version, - kwargs={"function": function, "return_dtype": return_dtype}, + kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype}, ) def is_finite(self: Self) -> Self: @@ -810,7 +810,7 @@ def keep(self: Self) -> ArrowExpr: output_names=root_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: @@ -837,7 +837,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"function": function}, + kwargs={**self._compliant_expr._kwargs, "function": function}, ) def prefix(self: Self, prefix: str) -> ArrowExpr: @@ -862,7 +862,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"prefix": prefix}, + kwargs={**self._compliant_expr._kwargs, "prefix": prefix}, ) def suffix(self: Self, suffix: str) -> ArrowExpr: @@ -888,7 +888,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"suffix": suffix}, + kwargs={**self._compliant_expr._kwargs, "suffix": suffix}, ) def to_lowercase(self: Self) -> ArrowExpr: @@ -914,7 +914,7 @@ def to_lowercase(self: Self) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def to_uppercase(self: Self) -> ArrowExpr: @@ -940,7 +940,7 @@ def to_uppercase(self: Self) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 7dc4db577..cb59d5b71 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -432,7 +432,12 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="concat_str", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), - kwargs={"separator": separator, "ignore_nulls": ignore_nulls}, + kwargs={ + "exprs": exprs, + "more_exprs": more_exprs, + "separator": separator, + "ignore_nulls": ignore_nulls, + }, ) diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 7750bdd03..48e837ec7 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -124,7 +124,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() - other @@ -145,7 +145,7 @@ def call(df: ArrowDataFrame) -> Sequence[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() | other @@ -166,7 +166,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() & other diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 50133ff93..dde6cd9e6 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -177,7 +177,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=self._returns_scalar or returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs=kwargs, + kwargs={**self._kwargs, **kwargs}, ) def alias(self, name: str) -> Self: @@ -194,7 +194,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def __add__(self, other: Any) -> Self: @@ -847,7 +847,7 @@ def func(df: DaskLazyFrame) -> list[Any]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={"keys": keys}, + kwargs={**self._kwargs, "keys": keys}, ) def mode(self: Self) -> Self: @@ -1334,7 +1334,7 @@ def keep(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def map(self: Self, function: Callable[[str], str]) -> DaskExpr: @@ -1362,7 +1362,7 @@ def map(self: Self, function: Callable[[str], str]) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"function": function}, + kwargs={**self._compliant_expr._kwargs, "function": function}, ) def prefix(self: Self, prefix: str) -> DaskExpr: @@ -1388,7 +1388,7 @@ def prefix(self: Self, prefix: str) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"prefix": prefix}, + kwargs={**self._compliant_expr._kwargs, "prefix": prefix}, ) def suffix(self: Self, suffix: str) -> DaskExpr: @@ -1415,7 +1415,7 @@ def suffix(self: Self, suffix: str) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"suffix": suffix}, + kwargs={**self._compliant_expr._kwargs, "suffix": suffix}, ) def to_lowercase(self: Self) -> DaskExpr: @@ -1442,7 +1442,7 @@ def to_lowercase(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def to_uppercase(self: Self) -> DaskExpr: @@ -1469,5 +1469,5 @@ def to_uppercase(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index e0870d242..7c0a6b7eb 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -159,7 +159,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -178,7 +178,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -197,7 +197,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def concat( @@ -279,7 +279,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -301,7 +301,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -323,7 +323,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def when( @@ -388,7 +388,12 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={ + "exprs": exprs, + "more_exprs": more_exprs, + "separator": separator, + "ignore_nulls": ignore_nulls, + }, ) @@ -451,7 +456,7 @@ def then(self, value: DaskExpr | Any) -> DaskThen: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"value": value}, ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 3a9744c9c..4d51eb719 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -225,14 +225,13 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: ): # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) - return plx._create_expr_from_callable( # type: ignore[return-value] func, # type: ignore[arg-type] depth=expr._depth + 1, function_name=f"{expr._function_name}->{attr}", root_names=root_names, output_names=output_names, - kwargs=kwargs, + kwargs={**expr._kwargs, **kwargs}, ) @@ -272,7 +271,7 @@ def reuse_series_namespace_implementation( function_name=f"{expr._function_name}->{series_namespace}.{attr}", root_names=expr._root_names, output_names=expr._output_names, - kwargs=kwargs, + kwargs={**expr._kwargs, **kwargs}, ) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 0cf2a3f73..e7e98d991 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -263,16 +263,10 @@ def median(self) -> Self: return reuse_series_implementation(self, "median", returns_scalar=True) def std(self, *, ddof: int) -> Self: - expr = reuse_series_implementation(self, "std", ddof=ddof, returns_scalar=True) - if ddof != 1: - expr._depth += 1 - return expr + return reuse_series_implementation(self, "std", ddof=ddof, returns_scalar=True) def var(self, *, ddof: int) -> Self: - expr = reuse_series_implementation(self, "var", ddof=ddof, returns_scalar=True) - if ddof != 1: - expr._depth += 1 - return expr + return reuse_series_implementation(self, "var", ddof=ddof, returns_scalar=True) def skew(self: Self) -> Self: return reuse_series_implementation(self, "skew", returns_scalar=True) @@ -419,7 +413,7 @@ def alias(self, name: str) -> Self: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def over(self, keys: list[str]) -> Self: @@ -491,7 +485,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"keys": keys}, + kwargs={**self._kwargs, "keys": keys}, ) def is_duplicated(self) -> Self: @@ -568,7 +562,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"function": function, "return_dtype": return_dtype}, + kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype}, ) def is_finite(self: Self) -> Self: @@ -907,7 +901,7 @@ def keep(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: @@ -935,7 +929,7 @@ def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"function": function}, + kwargs={**self._compliant_expr._kwargs, "function": function}, ) def prefix(self: Self, prefix: str) -> PandasLikeExpr: @@ -961,7 +955,7 @@ def prefix(self: Self, prefix: str) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"prefix": prefix}, + kwargs={**self._compliant_expr._kwargs, "prefix": prefix}, ) def suffix(self: Self, suffix: str) -> PandasLikeExpr: @@ -988,7 +982,7 @@ def suffix(self: Self, suffix: str) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"suffix": suffix}, + kwargs={**self._compliant_expr._kwargs, "suffix": suffix}, ) def to_lowercase(self: Self) -> PandasLikeExpr: @@ -1015,7 +1009,7 @@ def to_lowercase(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def to_uppercase(self: Self) -> PandasLikeExpr: @@ -1042,7 +1036,7 @@ def to_uppercase(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 9c4ffccbb..292ab5add 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -13,6 +13,7 @@ from narwhals._expression_parsing import parse_into_exprs from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_series_from_iterable +from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import select_columns_by_name from narwhals.utils import Implementation from narwhals.utils import find_stacklevel @@ -168,6 +169,9 @@ def agg_pandas( # noqa: PLR0915 # can pass the `dropna` kwargs. nunique_aggs: dict[str, str] = {} simple_aggs: dict[str, list[str]] = collections.defaultdict(list) + std_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict) + var_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict) + expected_old_names: list[str] = [] new_names: list[str] = [] @@ -199,15 +203,24 @@ def agg_pandas( # noqa: PLR0915 function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( function_name, function_name ) + is_n_unique = function_name == "nunique" + is_std = function_name == "std" + is_var = function_name == "var" + ddof = expr._kwargs.get("ddof", 1) for root_name, output_name in zip(expr._root_names, expr._output_names): if is_n_unique: nunique_aggs[output_name] = root_name + elif is_std and ddof != 1: + std_aggs[ddof].update({output_name: root_name}) + elif is_var and ddof != 1: + var_aggs[ddof].update({output_name: root_name}) else: new_names.append(output_name) expected_old_names.append(f"{root_name}_{function_name}") simple_aggs[root_name].append(function_name) + result_aggs = [] if simple_aggs: result_simple_aggs = grouped.agg(simple_aggs) result_simple_aggs.columns = [ @@ -237,43 +250,72 @@ def agg_pandas( # noqa: PLR0915 new_names = [new_names[i] for i in index_map] result_simple_aggs.columns = new_names + result_aggs.append(result_simple_aggs) + if nunique_aggs: result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique( dropna=False ) result_nunique_aggs.columns = list(nunique_aggs.keys()) - if simple_aggs and nunique_aggs: - if ( - set(result_simple_aggs.columns) - .difference(keys) - .intersection(result_nunique_aggs.columns) - ): + + result_aggs.append(result_nunique_aggs) + + if std_aggs: + result_aggs.extend( + [ + rename( + grouped[list(output_to_root_name_mapping.values())].std( + ddof=ddof + ), + # Invert the dict to have root_name: output_name + # TODO(FBruzzesi): Account for duplicates + columns={v: k for k, v in output_to_root_name_mapping.items()}, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, output_to_root_name_mapping in std_aggs.items() + ] + ) + if var_aggs: + result_aggs.extend( + [ + rename( + grouped[list(output_to_root_name_mapping.values())].var( + ddof=ddof + ), + # Invert the dict to have root_name: output_name + # TODO(FBruzzesi): Account for duplicates + columns={v: k for k, v in output_to_root_name_mapping.items()}, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, output_to_root_name_mapping in var_aggs.items() + ] + ) + + if result_aggs: + output_names_counter = collections.Counter( + [c for frame in result_aggs for c in frame] + ) + if any(v > 1 for v in output_names_counter.values()): msg = ( "Got two aggregations with the same output name. Please make sure " "that aggregations have unique output names." ) raise ValueError(msg) - result_aggs = horizontal_concat( - [result_simple_aggs, result_nunique_aggs], + result = horizontal_concat( + dfs=result_aggs, implementation=implementation, backend_version=backend_version, ) - elif nunique_aggs and not simple_aggs: - result_aggs = result_nunique_aggs - elif simple_aggs and not nunique_aggs: - result_aggs = result_simple_aggs else: # No aggregation provided - result_aggs = native_namespace.DataFrame( - list(grouped.groups.keys()), columns=keys - ) + result = native_namespace.DataFrame(list(grouped.groups.keys()), columns=keys) # Keep inplace=True to avoid making a redundant copy. # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files - result_aggs.reset_index(inplace=True) # noqa: PD002 + result.reset_index(inplace=True) # noqa: PD002 return from_dataframe( - select_columns_by_name( - result_aggs, output_names, backend_version, implementation - ) + select_columns_by_name(result, output_names, backend_version, implementation) ) if dataframe_is_empty: diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 495173025..3aa4015a0 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -474,7 +474,12 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="concat_str", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), - kwargs={"separator": separator, "ignore_nulls": ignore_nulls}, + kwargs={ + "exprs": exprs, + "more_exprs": more_exprs, + "separator": separator, + "ignore_nulls": ignore_nulls, + }, ) diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 7ef666b96..e7d7fe18d 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -129,7 +129,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() - other @@ -151,7 +151,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() | other @@ -173,7 +173,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() & other diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index cbd645298..da25b32e0 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -198,7 +198,7 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def count(self) -> Self: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 639523ed0..d150e7541 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -59,7 +59,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def col(self, *column_names: str) -> SparkLikeExpr: diff --git a/tests/group_by_test.py b/tests/group_by_test.py index b7f2d2ea1..6843c50a3 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -380,3 +380,42 @@ def test_double_same_aggregation( result = df.group_by("a").agg(c=nw.col("b").mean(), d=nw.col("b").mean()).sort("a") expected = {"a": [1, 2], "c": [4.5, 6], "d": [4.5, 6]} assert_equal_data(result, expected) + + +def test_all_kind_of_aggs( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + from math import sqrt + + if any(x in str(constructor) for x in ("dask", "cudf")): + # bugged in dask https://github.com/dask/dask/issues/11612 + # and modin lol https://github.com/modin-project/modin/issues/7414 + # and cudf https://github.com/rapidsai/cudf/issues/17649 + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor) and PANDAS_VERSION < (1,): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]})) + result = ( + df.group_by("a") + .agg( + c=nw.col("b").mean(), + d=nw.col("b").mean(), + e=nw.col("b").std(ddof=1), + f=nw.col("b").std(ddof=2), + g=nw.col("b").var(ddof=2), + h=nw.col("b").n_unique(), + ) + .sort("a") + ) + + variance_num = sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) + expected = { + "a": [1, 2], + "c": [5, 10 / 3], + "d": [5, 10 / 3], + "e": [1, sqrt(variance_num / (3 - 1))], + "f": [sqrt(2), sqrt(variance_num)], # denominator is 1 (=3-2) + "g": [2.0, variance_num], # denominator is 1 (=3-2) + "h": [3, 2], + } + assert_equal_data(result, expected) From dfd940c32a07f4741a67c09edc8b9a0fd965a406 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 22 Dec 2024 15:15:06 +0100 Subject: [PATCH 2/6] handle dups --- narwhals/_pandas_like/group_by.py | 48 +++++++++++++---------------- tests/group_by_test.py | 51 ++++++++++++++++++++++++++----- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 292ab5add..685ae262f 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -13,7 +13,6 @@ from narwhals._expression_parsing import parse_into_exprs from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_series_from_iterable -from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import select_columns_by_name from narwhals.utils import Implementation from narwhals.utils import find_stacklevel @@ -169,8 +168,14 @@ def agg_pandas( # noqa: PLR0915 # can pass the `dropna` kwargs. nunique_aggs: dict[str, str] = {} simple_aggs: dict[str, list[str]] = collections.defaultdict(list) - std_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict) - var_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict) + + # ddof to (root_names, output_names) mapping + std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict( + lambda: ([], []) + ) + var_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict( + lambda: ([], []) + ) expected_old_names: list[str] = [] new_names: list[str] = [] @@ -212,15 +217,18 @@ def agg_pandas( # noqa: PLR0915 if is_n_unique: nunique_aggs[output_name] = root_name elif is_std and ddof != 1: - std_aggs[ddof].update({output_name: root_name}) + std_aggs[ddof][0].append(root_name) + std_aggs[ddof][1].append(output_name) elif is_var and ddof != 1: - var_aggs[ddof].update({output_name: root_name}) + var_aggs[ddof][0].append(root_name) + var_aggs[ddof][1].append(output_name) else: new_names.append(output_name) expected_old_names.append(f"{root_name}_{function_name}") simple_aggs[root_name].append(function_name) result_aggs = [] + if simple_aggs: result_simple_aggs = grouped.agg(simple_aggs) result_simple_aggs.columns = [ @@ -263,33 +271,19 @@ def agg_pandas( # noqa: PLR0915 if std_aggs: result_aggs.extend( [ - rename( - grouped[list(output_to_root_name_mapping.values())].std( - ddof=ddof - ), - # Invert the dict to have root_name: output_name - # TODO(FBruzzesi): Account for duplicates - columns={v: k for k, v in output_to_root_name_mapping.items()}, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, output_to_root_name_mapping in std_aggs.items() + grouped[std_root_names] + .std(ddof=ddof) + .set_axis(std_output_names, axis="columns", copy=False) + for ddof, (std_root_names, std_output_names) in std_aggs.items() ] ) if var_aggs: result_aggs.extend( [ - rename( - grouped[list(output_to_root_name_mapping.values())].var( - ddof=ddof - ), - # Invert the dict to have root_name: output_name - # TODO(FBruzzesi): Account for duplicates - columns={v: k for k, v in output_to_root_name_mapping.items()}, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, output_to_root_name_mapping in var_aggs.items() + grouped[var_root_names] + .var(ddof=ddof) + .set_axis(var_output_names, axis="columns", copy=False) + for ddof, (var_root_names, var_output_names) in var_aggs.items() ] ) diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 6843c50a3..01cabcbee 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -131,6 +131,43 @@ def test_group_by_depth_1_agg( assert_equal_data(result, expected) +@pytest.mark.parametrize( + ("attr", "ddof"), + [ + ("std", 0), + ("var", 0), + ("std", 2), + ("var", 2), + ], +) +def test_group_by_depth_1_std_var( + constructor: Constructor, + attr: str, + ddof: int, + request: pytest.FixtureRequest, +) -> None: + if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1): + # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations" + request.applymarker(pytest.mark.xfail) + + if "dask" in str(constructor): + # Complex aggregation for dask + request.applymarker(pytest.mark.xfail) + + data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} + _pow = 0.5 if attr == "std" else 1 + expected = { + "a": [1, 2], + "b": [ + (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow, + (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow, + ], + } + expr = getattr(nw.col("b"), attr)(ddof=ddof) + result = nw.from_native(constructor(data)).group_by("a").agg(expr).sort("a") + assert_equal_data(result, expected) + + def test_group_by_median(constructor: Constructor) -> None: data = {"a": [1, 1, 1, 2, 2, 2], "b": [5, 4, 6, 7, 3, 2]} result = ( @@ -385,9 +422,7 @@ def test_double_same_aggregation( def test_all_kind_of_aggs( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - from math import sqrt - - if any(x in str(constructor) for x in ("dask", "cudf")): + if any(x in str(constructor) for x in ("dask", "cudf", "modin_constructor")): # bugged in dask https://github.com/dask/dask/issues/11612 # and modin lol https://github.com/modin-project/modin/issues/7414 # and cudf https://github.com/rapidsai/cudf/issues/17649 @@ -403,7 +438,8 @@ def test_all_kind_of_aggs( e=nw.col("b").std(ddof=1), f=nw.col("b").std(ddof=2), g=nw.col("b").var(ddof=2), - h=nw.col("b").n_unique(), + h=nw.col("b").var(ddof=2), + i=nw.col("b").n_unique(), ) .sort("a") ) @@ -413,9 +449,10 @@ def test_all_kind_of_aggs( "a": [1, 2], "c": [5, 10 / 3], "d": [5, 10 / 3], - "e": [1, sqrt(variance_num / (3 - 1))], - "f": [sqrt(2), sqrt(variance_num)], # denominator is 1 (=3-2) + "e": [1, (variance_num / (3 - 1)) ** 0.5], + "f": [2**0.5, (variance_num) ** 0.5], # denominator is 1 (=3-2) "g": [2.0, variance_num], # denominator is 1 (=3-2) - "h": [3, 2], + "h": [2.0, variance_num], # denominator is 1 (=3-2) + "i": [3, 2], } assert_equal_data(result, expected) From 5cc5f46beaebe642371a2a83cb7717f1f42d786d Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 22 Dec 2024 15:37:03 +0100 Subject: [PATCH 3/6] set_columns --- narwhals/_pandas_like/group_by.py | 19 +++++++++----- narwhals/_pandas_like/series.py | 6 ++--- narwhals/_pandas_like/utils.py | 43 +++++++++++++++++++++++++++---- narwhals/utils.py | 4 +-- 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 685ae262f..ec3e09b4c 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -14,6 +14,7 @@ from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_series_from_iterable from narwhals._pandas_like.utils import select_columns_by_name +from narwhals._pandas_like.utils import set_columns from narwhals.utils import Implementation from narwhals.utils import find_stacklevel from narwhals.utils import remove_prefix @@ -271,18 +272,24 @@ def agg_pandas( # noqa: PLR0915 if std_aggs: result_aggs.extend( [ - grouped[std_root_names] - .std(ddof=ddof) - .set_axis(std_output_names, axis="columns", copy=False) + set_columns( + grouped[std_root_names].std(ddof=ddof), + columns=std_output_names, + implementation=implementation, + backend_version=backend_version, + ) for ddof, (std_root_names, std_output_names) in std_aggs.items() ] ) if var_aggs: result_aggs.extend( [ - grouped[var_root_names] - .var(ddof=ddof) - .set_axis(var_output_names, axis="columns", copy=False) + set_columns( + grouped[var_root_names].var(ddof=ddof), + columns=var_output_names, + implementation=implementation, + backend_version=backend_version, + ) for ddof, (var_root_names, var_output_names) in var_aggs.items() ] ) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index a8b57c5ce..1d895e147 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -17,7 +17,7 @@ from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import select_columns_by_name -from narwhals._pandas_like.utils import set_axis +from narwhals._pandas_like.utils import set_index from narwhals._pandas_like.utils import to_datetime from narwhals.dependencies import is_numpy_scalar from narwhals.typing import CompliantSeries @@ -211,7 +211,7 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self: # .copy() is necessary in some pre-2.2 versions of pandas to avoid # `values` also getting modified (!) _, values = broadcast_align_and_extract_native(self, values) - values = set_axis( + values = set_index( values.copy(), self._native_series.index[indices], implementation=self._implementation, @@ -1423,7 +1423,7 @@ def len(self: Self) -> PandasLikeSeries: self._compliant_series._implementation is Implementation.PANDAS and self._compliant_series._backend_version < (3, 0) ): # pragma: no cover - native_result = set_axis( + native_result = set_index( rename( native_result, native_series.name, diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 2b0da769c..5c523138f 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -138,7 +138,7 @@ def broadcast_align_and_extract_native( if rhs._native_series.index is not lhs_index: return ( lhs._native_series, - set_axis( + set_index( rhs._native_series, lhs_index, implementation=rhs._implementation, @@ -168,7 +168,7 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: s = other._native_series return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name) if other._native_series.index is not index: - return set_axis( + return set_index( other._native_series, index, implementation=other._implementation, @@ -302,14 +302,17 @@ def native_series_from_iterable( raise TypeError(msg) -def set_axis( +def set_index( obj: T, index: Any, *, implementation: Implementation, backend_version: tuple[int, ...], ) -> T: - """Wrapper around pandas' set_axis so that we can set `copy` / `inplace` based on implementation/version.""" + """Wrapper around pandas' set_axis to set object index. + + We can set `copy` / `inplace` based on implementation/version. + """ if implementation is Implementation.CUDF: # pragma: no cover obj = obj.copy(deep=False) # type: ignore[attr-defined] obj.index = index # type: ignore[attr-defined] @@ -329,6 +332,36 @@ def set_axis( return obj.set_axis(index, axis=0, **kwargs) # type: ignore[attr-defined, no-any-return] +def set_columns( + obj: T, + columns: list[str], + *, + implementation: Implementation, + backend_version: tuple[int, ...], +) -> T: + """Wrapper around pandas' set_axis to set object columns. + + We can set `copy` / `inplace` based on implementation/version. + """ + if implementation is Implementation.CUDF: # pragma: no cover + obj = obj.copy(deep=False) # type: ignore[attr-defined] + obj.columns = columns # type: ignore[attr-defined] + return obj + if implementation is Implementation.PANDAS and ( + backend_version < (1,) + ): # pragma: no cover + kwargs = {"inplace": False} + else: + kwargs = {} + if implementation is Implementation.PANDAS and ( + (1, 5) <= backend_version < (3,) + ): # pragma: no cover + kwargs["copy"] = False + else: # pragma: no cover + pass + return obj.set_axis(columns, axis=1, **kwargs) # type: ignore[attr-defined, no-any-return] + + def rename( obj: T, *args: Any, @@ -654,7 +687,7 @@ def broadcast_series(series: Sequence[PandasLikeSeries]) -> list[Any]: elif s_native.index is not idx: reindexed.append( - set_axis( + set_index( s_native, idx, implementation=s._implementation, diff --git a/narwhals/utils.py b/narwhals/utils.py index bd0c625bb..b6337cb8e 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -548,13 +548,13 @@ def maybe_set_index( df_any._compliant_frame._from_native_frame(native_obj.set_index(keys)) ) elif is_pandas_like_series(native_obj): - from narwhals._pandas_like.utils import set_axis + from narwhals._pandas_like.utils import set_index if column_names: msg = "Cannot set index using column names on a Series" raise ValueError(msg) - native_obj = set_axis( + native_obj = set_index( native_obj, keys, implementation=obj._compliant_series._implementation, # type: ignore[union-attr] From 638cbc9e7d6e2338c80a004b6c36a62ce37e86da Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 22 Dec 2024 17:10:39 +0000 Subject: [PATCH 4/6] improve error message, remove unnecessary xfail --- narwhals/_pandas_like/group_by.py | 4 ++++ tests/group_by_test.py | 4 ---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index ec3e09b4c..362a109de 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -303,6 +303,10 @@ def agg_pandas( # noqa: PLR0915 "Got two aggregations with the same output name. Please make sure " "that aggregations have unique output names." ) + for key, value in output_names_counter.items(): + if value > 1: + msg += f"\n- '{key}' {value} times" + msg = f"Expected unique output names, got:{msg}" raise ValueError(msg) result = horizontal_concat( dfs=result_aggs, diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 01cabcbee..6429f6eff 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -146,10 +146,6 @@ def test_group_by_depth_1_std_var( ddof: int, request: pytest.FixtureRequest, ) -> None: - if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1): - # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations" - request.applymarker(pytest.mark.xfail) - if "dask" in str(constructor): # Complex aggregation for dask request.applymarker(pytest.mark.xfail) From 3b30876a1c4917abc104c7ae0bf615621c209f15 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 22 Dec 2024 17:23:32 +0000 Subject: [PATCH 5/6] correct xfail --- tests/group_by_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 6429f6eff..c7923f02d 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -423,7 +423,8 @@ def test_all_kind_of_aggs( # and modin lol https://github.com/modin-project/modin/issues/7414 # and cudf https://github.com/rapidsai/cudf/issues/17649 request.applymarker(pytest.mark.xfail) - if "pandas" in str(constructor) and PANDAS_VERSION < (1,): + if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4): + # Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']] request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]})) result = ( From fe72f9d5f7fe1860ed91334c129503525316177b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 22 Dec 2024 19:11:28 +0000 Subject: [PATCH 6/6] fixup --- narwhals/_pandas_like/group_by.py | 7 +++---- tests/group_by_test.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 362a109de..27d072f8a 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -299,13 +299,12 @@ def agg_pandas( # noqa: PLR0915 [c for frame in result_aggs for c in frame] ) if any(v > 1 for v in output_names_counter.values()): - msg = ( - "Got two aggregations with the same output name. Please make sure " - "that aggregations have unique output names." - ) + msg = "" for key, value in output_names_counter.items(): if value > 1: msg += f"\n- '{key}' {value} times" + else: # pragma: no cover + pass msg = f"Expected unique output names, got:{msg}" raise ValueError(msg) result = horizontal_concat( diff --git a/tests/group_by_test.py b/tests/group_by_test.py index c7923f02d..a36c90bf8 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -203,7 +203,7 @@ def test_group_by_same_name_twice() -> None: import pandas as pd df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) - with pytest.raises(ValueError, match="two aggregations with the same"): + with pytest.raises(ValueError, match="Expected unique output names"): nw.from_native(df).group_by("a").agg(nw.col("b").sum(), nw.col("b").n_unique())