From 9f09ea0f4efbb693d370902600161a0fb41c9761 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Sun, 22 Dec 2024 20:19:56 +0100 Subject: [PATCH] feat: support std and var with ddof !=1 in pandas-like group by (#1645) --------- Co-authored-by: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> --- narwhals/_arrow/expr.py | 18 +++--- narwhals/_arrow/namespace.py | 7 ++- narwhals/_arrow/selectors.py | 6 +- narwhals/_dask/expr.py | 18 +++--- narwhals/_dask/namespace.py | 21 ++++--- narwhals/_expression_parsing.py | 5 +- narwhals/_pandas_like/expr.py | 28 ++++----- narwhals/_pandas_like/group_by.py | 92 ++++++++++++++++++++++-------- narwhals/_pandas_like/namespace.py | 7 ++- narwhals/_pandas_like/selectors.py | 6 +- narwhals/_pandas_like/series.py | 6 +- narwhals/_pandas_like/utils.py | 43 ++++++++++++-- narwhals/_spark_like/expr.py | 2 +- narwhals/_spark_like/namespace.py | 2 +- narwhals/utils.py | 4 +- tests/group_by_test.py | 75 +++++++++++++++++++++++- 16 files changed, 250 insertions(+), 90 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 6f1d627f5..15afcba1c 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -308,7 +308,7 @@ def alias(self: Self, name: str) -> Self: output_names=[name], backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def null_count(self: Self) -> Self: @@ -436,7 +436,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=self._output_names, backend_version=self._backend_version, version=self._version, - kwargs={"keys": keys}, + kwargs={**self._kwargs, "keys": keys}, ) def mode(self: Self) -> Self: @@ -478,7 +478,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=self._output_names, backend_version=self._backend_version, version=self._version, - kwargs={"function": function, "return_dtype": return_dtype}, + kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype}, ) def is_finite(self: Self) -> Self: @@ -810,7 +810,7 @@ def keep(self: Self) -> ArrowExpr: output_names=root_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: @@ -837,7 +837,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"function": function}, + kwargs={**self._compliant_expr._kwargs, "function": function}, ) def prefix(self: Self, prefix: str) -> ArrowExpr: @@ -862,7 +862,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"prefix": prefix}, + kwargs={**self._compliant_expr._kwargs, "prefix": prefix}, ) def suffix(self: Self, suffix: str) -> ArrowExpr: @@ -888,7 +888,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"suffix": suffix}, + kwargs={**self._compliant_expr._kwargs, "suffix": suffix}, ) def to_lowercase(self: Self) -> ArrowExpr: @@ -914,7 +914,7 @@ def to_lowercase(self: Self) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, 
version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def to_uppercase(self: Self) -> ArrowExpr: @@ -940,7 +940,7 @@ def to_uppercase(self: Self) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 7dc4db577..cb59d5b71 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -432,7 +432,12 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="concat_str", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), - kwargs={"separator": separator, "ignore_nulls": ignore_nulls}, + kwargs={ + "exprs": exprs, + "more_exprs": more_exprs, + "separator": separator, + "ignore_nulls": ignore_nulls, + }, ) diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 7750bdd03..48e837ec7 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -124,7 +124,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() - other @@ -145,7 +145,7 @@ def call(df: ArrowDataFrame) -> Sequence[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() | other @@ -166,7 +166,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() & other diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 50133ff93..dde6cd9e6 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -177,7 +177,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=self._returns_scalar or returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs=kwargs, + kwargs={**self._kwargs, **kwargs}, ) def alias(self, name: str) -> Self: @@ -194,7 +194,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def __add__(self, other: Any) -> Self: @@ -847,7 +847,7 @@ def func(df: DaskLazyFrame) -> list[Any]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={"keys": keys}, + kwargs={**self._kwargs, "keys": keys}, ) def mode(self: Self) -> Self: @@ -1334,7 +1334,7 @@ def keep(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def map(self: Self, function: Callable[[str], str]) -> DaskExpr: @@ -1362,7 +1362,7 @@ def map(self: Self, function: Callable[[str], str]) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"function": function}, + kwargs={**self._compliant_expr._kwargs, "function": function}, ) def prefix(self: Self, prefix: str) -> 
DaskExpr: @@ -1388,7 +1388,7 @@ def prefix(self: Self, prefix: str) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"prefix": prefix}, + kwargs={**self._compliant_expr._kwargs, "prefix": prefix}, ) def suffix(self: Self, suffix: str) -> DaskExpr: @@ -1415,7 +1415,7 @@ def suffix(self: Self, suffix: str) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"suffix": suffix}, + kwargs={**self._compliant_expr._kwargs, "suffix": suffix}, ) def to_lowercase(self: Self) -> DaskExpr: @@ -1442,7 +1442,7 @@ def to_lowercase(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def to_uppercase(self: Self) -> DaskExpr: @@ -1469,5 +1469,5 @@ def to_uppercase(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index e0870d242..7c0a6b7eb 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -159,7 +159,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -178,7 +178,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -197,7 +197,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def concat( @@ -279,7 +279,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -301,7 +301,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -323,7 +323,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def when( @@ -388,7 +388,12 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={ + "exprs": exprs, + "more_exprs": more_exprs, + "separator": separator, + "ignore_nulls": ignore_nulls, + }, ) @@ -451,7 +456,7 @@ def then(self, value: DaskExpr | Any) -> DaskThen: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"value": value}, ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 3a9744c9c..4d51eb719 100644 
--- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -225,14 +225,13 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: ): # pragma: no cover msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) - return plx._create_expr_from_callable( # type: ignore[return-value] func, # type: ignore[arg-type] depth=expr._depth + 1, function_name=f"{expr._function_name}->{attr}", root_names=root_names, output_names=output_names, - kwargs=kwargs, + kwargs={**expr._kwargs, **kwargs}, ) @@ -272,7 +271,7 @@ def reuse_series_namespace_implementation( function_name=f"{expr._function_name}->{series_namespace}.{attr}", root_names=expr._root_names, output_names=expr._output_names, - kwargs=kwargs, + kwargs={**expr._kwargs, **kwargs}, ) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 0cf2a3f73..e7e98d991 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -263,16 +263,10 @@ def median(self) -> Self: return reuse_series_implementation(self, "median", returns_scalar=True) def std(self, *, ddof: int) -> Self: - expr = reuse_series_implementation(self, "std", ddof=ddof, returns_scalar=True) - if ddof != 1: - expr._depth += 1 - return expr + return reuse_series_implementation(self, "std", ddof=ddof, returns_scalar=True) def var(self, *, ddof: int) -> Self: - expr = reuse_series_implementation(self, "var", ddof=ddof, returns_scalar=True) - if ddof != 1: - expr._depth += 1 - return expr + return reuse_series_implementation(self, "var", ddof=ddof, returns_scalar=True) def skew(self: Self) -> Self: return reuse_series_implementation(self, "skew", returns_scalar=True) @@ -419,7 +413,7 @@ def alias(self, name: str) -> Self: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def over(self, keys: list[str]) -> Self: @@ -491,7 +485,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"keys": keys}, + kwargs={**self._kwargs, "keys": keys}, ) def is_duplicated(self) -> Self: @@ -568,7 +562,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"function": function, "return_dtype": return_dtype}, + kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype}, ) def is_finite(self: Self) -> Self: @@ -907,7 +901,7 @@ def keep(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: @@ -935,7 +929,7 @@ def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"function": function}, + kwargs={**self._compliant_expr._kwargs, "function": function}, ) def prefix(self: Self, prefix: str) -> PandasLikeExpr: @@ -961,7 +955,7 @@ def prefix(self: Self, prefix: str) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, 
version=self._compliant_expr._version, - kwargs={"prefix": prefix}, + kwargs={**self._compliant_expr._kwargs, "prefix": prefix}, ) def suffix(self: Self, suffix: str) -> PandasLikeExpr: @@ -988,7 +982,7 @@ def suffix(self: Self, suffix: str) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={"suffix": suffix}, + kwargs={**self._compliant_expr._kwargs, "suffix": suffix}, ) def to_lowercase(self: Self) -> PandasLikeExpr: @@ -1015,7 +1009,7 @@ def to_lowercase(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) def to_uppercase(self: Self) -> PandasLikeExpr: @@ -1042,7 +1036,7 @@ def to_uppercase(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, - kwargs={}, + kwargs=self._compliant_expr._kwargs, ) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 9c4ffccbb..27d072f8a 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -14,6 +14,7 @@ from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_series_from_iterable from narwhals._pandas_like.utils import select_columns_by_name +from narwhals._pandas_like.utils import set_columns from narwhals.utils import Implementation from narwhals.utils import find_stacklevel from narwhals.utils import remove_prefix @@ -168,6 +169,15 @@ def agg_pandas( # noqa: PLR0915 # can pass the `dropna` kwargs. 
nunique_aggs: dict[str, str] = {} simple_aggs: dict[str, list[str]] = collections.defaultdict(list) + + # ddof to (root_names, output_names) mapping + std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict( + lambda: ([], []) + ) + var_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict( + lambda: ([], []) + ) + expected_old_names: list[str] = [] new_names: list[str] = [] @@ -199,15 +209,27 @@ def agg_pandas( # noqa: PLR0915 function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( function_name, function_name ) + is_n_unique = function_name == "nunique" + is_std = function_name == "std" + is_var = function_name == "var" + ddof = expr._kwargs.get("ddof", 1) for root_name, output_name in zip(expr._root_names, expr._output_names): if is_n_unique: nunique_aggs[output_name] = root_name + elif is_std and ddof != 1: + std_aggs[ddof][0].append(root_name) + std_aggs[ddof][1].append(output_name) + elif is_var and ddof != 1: + var_aggs[ddof][0].append(root_name) + var_aggs[ddof][1].append(output_name) else: new_names.append(output_name) expected_old_names.append(f"{root_name}_{function_name}") simple_aggs[root_name].append(function_name) + result_aggs = [] + if simple_aggs: result_simple_aggs = grouped.agg(simple_aggs) result_simple_aggs.columns = [ @@ -237,43 +259,67 @@ def agg_pandas( # noqa: PLR0915 new_names = [new_names[i] for i in index_map] result_simple_aggs.columns = new_names + result_aggs.append(result_simple_aggs) + if nunique_aggs: result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique( dropna=False ) result_nunique_aggs.columns = list(nunique_aggs.keys()) - if simple_aggs and nunique_aggs: - if ( - set(result_simple_aggs.columns) - .difference(keys) - .intersection(result_nunique_aggs.columns) - ): - msg = ( - "Got two aggregations with the same output name. Please make sure " - "that aggregations have unique output names." - ) + + result_aggs.append(result_nunique_aggs) + + if std_aggs: + result_aggs.extend( + [ + set_columns( + grouped[std_root_names].std(ddof=ddof), + columns=std_output_names, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (std_root_names, std_output_names) in std_aggs.items() + ] + ) + if var_aggs: + result_aggs.extend( + [ + set_columns( + grouped[var_root_names].var(ddof=ddof), + columns=var_output_names, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (var_root_names, var_output_names) in var_aggs.items() + ] + ) + + if result_aggs: + output_names_counter = collections.Counter( + [c for frame in result_aggs for c in frame] + ) + if any(v > 1 for v in output_names_counter.values()): + msg = "" + for key, value in output_names_counter.items(): + if value > 1: + msg += f"\n- '{key}' {value} times" + else: # pragma: no cover + pass + msg = f"Expected unique output names, got:{msg}" raise ValueError(msg) - result_aggs = horizontal_concat( - [result_simple_aggs, result_nunique_aggs], + result = horizontal_concat( + dfs=result_aggs, implementation=implementation, backend_version=backend_version, ) - elif nunique_aggs and not simple_aggs: - result_aggs = result_nunique_aggs - elif simple_aggs and not nunique_aggs: - result_aggs = result_simple_aggs else: # No aggregation provided - result_aggs = native_namespace.DataFrame( - list(grouped.groups.keys()), columns=keys - ) + result = native_namespace.DataFrame(list(grouped.groups.keys()), columns=keys) # Keep inplace=True to avoid making a redundant copy. 
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files - result_aggs.reset_index(inplace=True) # noqa: PD002 + result.reset_index(inplace=True) # noqa: PD002 return from_dataframe( - select_columns_by_name( - result_aggs, output_names, backend_version, implementation - ) + select_columns_by_name(result, output_names, backend_version, implementation) ) if dataframe_is_empty: diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 495173025..3aa4015a0 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -474,7 +474,12 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="concat_str", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), - kwargs={"separator": separator, "ignore_nulls": ignore_nulls}, + kwargs={ + "exprs": exprs, + "more_exprs": more_exprs, + "separator": separator, + "ignore_nulls": ignore_nulls, + }, ) diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 7ef666b96..e7d7fe18d 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -129,7 +129,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() - other @@ -151,7 +151,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() | other @@ -173,7 +173,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - kwargs={"other": other}, + kwargs={**self._kwargs, "other": other}, ) else: return self._to_expr() & other diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index a8b57c5ce..1d895e147 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -17,7 +17,7 @@ from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import select_columns_by_name -from narwhals._pandas_like.utils import set_axis +from narwhals._pandas_like.utils import set_index from narwhals._pandas_like.utils import to_datetime from narwhals.dependencies import is_numpy_scalar from narwhals.typing import CompliantSeries @@ -211,7 +211,7 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self: # .copy() is necessary in some pre-2.2 versions of pandas to avoid # `values` also getting modified (!) 
_, values = broadcast_align_and_extract_native(self, values) - values = set_axis( + values = set_index( values.copy(), self._native_series.index[indices], implementation=self._implementation, @@ -1423,7 +1423,7 @@ def len(self: Self) -> PandasLikeSeries: self._compliant_series._implementation is Implementation.PANDAS and self._compliant_series._backend_version < (3, 0) ): # pragma: no cover - native_result = set_axis( + native_result = set_index( rename( native_result, native_series.name, diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 2b0da769c..5c523138f 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -138,7 +138,7 @@ def broadcast_align_and_extract_native( if rhs._native_series.index is not lhs_index: return ( lhs._native_series, - set_axis( + set_index( rhs._native_series, lhs_index, implementation=rhs._implementation, @@ -168,7 +168,7 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any: s = other._native_series return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name) if other._native_series.index is not index: - return set_axis( + return set_index( other._native_series, index, implementation=other._implementation, @@ -302,14 +302,17 @@ def native_series_from_iterable( raise TypeError(msg) -def set_axis( +def set_index( obj: T, index: Any, *, implementation: Implementation, backend_version: tuple[int, ...], ) -> T: - """Wrapper around pandas' set_axis so that we can set `copy` / `inplace` based on implementation/version.""" + """Wrapper around pandas' set_axis to set object index. + + We can set `copy` / `inplace` based on implementation/version. + """ if implementation is Implementation.CUDF: # pragma: no cover obj = obj.copy(deep=False) # type: ignore[attr-defined] obj.index = index # type: ignore[attr-defined] @@ -329,6 +332,36 @@ def set_axis( return obj.set_axis(index, axis=0, **kwargs) # type: ignore[attr-defined, no-any-return] +def set_columns( + obj: T, + columns: list[str], + *, + implementation: Implementation, + backend_version: tuple[int, ...], +) -> T: + """Wrapper around pandas' set_axis to set object columns. + + We can set `copy` / `inplace` based on implementation/version. 
+ """ + if implementation is Implementation.CUDF: # pragma: no cover + obj = obj.copy(deep=False) # type: ignore[attr-defined] + obj.columns = columns # type: ignore[attr-defined] + return obj + if implementation is Implementation.PANDAS and ( + backend_version < (1,) + ): # pragma: no cover + kwargs = {"inplace": False} + else: + kwargs = {} + if implementation is Implementation.PANDAS and ( + (1, 5) <= backend_version < (3,) + ): # pragma: no cover + kwargs["copy"] = False + else: # pragma: no cover + pass + return obj.set_axis(columns, axis=1, **kwargs) # type: ignore[attr-defined, no-any-return] + + def rename( obj: T, *args: Any, @@ -654,7 +687,7 @@ def broadcast_series(series: Sequence[PandasLikeSeries]) -> list[Any]: elif s_native.index is not idx: reindexed.append( - set_axis( + set_index( s_native, idx, implementation=s._implementation, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index cbd645298..da25b32e0 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -198,7 +198,7 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs={"name": name}, + kwargs={**self._kwargs, "name": name}, ) def count(self) -> Self: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 639523ed0..d150e7541 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -59,7 +59,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=False, backend_version=self._backend_version, version=self._version, - kwargs={}, + kwargs={"exprs": exprs}, ) def col(self, *column_names: str) -> SparkLikeExpr: diff --git a/narwhals/utils.py b/narwhals/utils.py index bd0c625bb..b6337cb8e 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -548,13 +548,13 @@ def maybe_set_index( df_any._compliant_frame._from_native_frame(native_obj.set_index(keys)) ) elif is_pandas_like_series(native_obj): - from narwhals._pandas_like.utils import set_axis + from narwhals._pandas_like.utils import set_index if column_names: msg = "Cannot set index using column names on a Series" raise ValueError(msg) - native_obj = set_axis( + native_obj = set_index( native_obj, keys, implementation=obj._compliant_series._implementation, # type: ignore[union-attr] diff --git a/tests/group_by_test.py b/tests/group_by_test.py index b7f2d2ea1..a36c90bf8 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -131,6 +131,39 @@ def test_group_by_depth_1_agg( assert_equal_data(result, expected) +@pytest.mark.parametrize( + ("attr", "ddof"), + [ + ("std", 0), + ("var", 0), + ("std", 2), + ("var", 2), + ], +) +def test_group_by_depth_1_std_var( + constructor: Constructor, + attr: str, + ddof: int, + request: pytest.FixtureRequest, +) -> None: + if "dask" in str(constructor): + # Complex aggregation for dask + request.applymarker(pytest.mark.xfail) + + data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} + _pow = 0.5 if attr == "std" else 1 + expected = { + "a": [1, 2], + "b": [ + (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow, + (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow, + ], + } + expr = getattr(nw.col("b"), attr)(ddof=ddof) + result = nw.from_native(constructor(data)).group_by("a").agg(expr).sort("a") + assert_equal_data(result, expected) + + def test_group_by_median(constructor: Constructor) -> None: data = {"a": [1, 1, 1, 2, 2, 2], "b": [5, 4, 6, 7, 3, 2]} result = ( 
@@ -170,7 +203,7 @@ def test_group_by_same_name_twice() -> None: import pandas as pd df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) - with pytest.raises(ValueError, match="two aggregations with the same"): + with pytest.raises(ValueError, match="Expected unique output names"): nw.from_native(df).group_by("a").agg(nw.col("b").sum(), nw.col("b").n_unique()) @@ -380,3 +413,43 @@ def test_double_same_aggregation( result = df.group_by("a").agg(c=nw.col("b").mean(), d=nw.col("b").mean()).sort("a") expected = {"a": [1, 2], "c": [4.5, 6], "d": [4.5, 6]} assert_equal_data(result, expected) + + +def test_all_kind_of_aggs( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if any(x in str(constructor) for x in ("dask", "cudf", "modin_constructor")): + # bugged in dask https://github.com/dask/dask/issues/11612 + # and modin lol https://github.com/modin-project/modin/issues/7414 + # and cudf https://github.com/rapidsai/cudf/issues/17649 + request.applymarker(pytest.mark.xfail) + if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4): + # Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']] + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]})) + result = ( + df.group_by("a") + .agg( + c=nw.col("b").mean(), + d=nw.col("b").mean(), + e=nw.col("b").std(ddof=1), + f=nw.col("b").std(ddof=2), + g=nw.col("b").var(ddof=2), + h=nw.col("b").var(ddof=2), + i=nw.col("b").n_unique(), + ) + .sort("a") + ) + + variance_num = sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) + expected = { + "a": [1, 2], + "c": [5, 10 / 3], + "d": [5, 10 / 3], + "e": [1, (variance_num / (3 - 1)) ** 0.5], + "f": [2**0.5, (variance_num) ** 0.5], # denominator is 1 (=3-2) + "g": [2.0, variance_num], # denominator is 1 (=3-2) + "h": [2.0, variance_num], # denominator is 1 (=3-2) + "i": [3, 2], + } + assert_equal_data(result, expected)
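
As a quick illustration of what this change enables: `std` and `var` with a non-default `ddof` are now handled on the fast grouped-aggregation path for pandas-like backends (the removed `_depth += 1` workaround in `narwhals/_pandas_like/expr.py` previously kept them off it). A minimal sketch, assuming a narwhals build that includes this patch; the API calls mirror those used in the tests above, plus `to_native`:

import pandas as pd

import narwhals as nw

df_native = pd.DataFrame({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]})

result = (
    nw.from_native(df_native)
    .group_by("a")
    .agg(
        b_std0=nw.col("b").std(ddof=0),  # population standard deviation
        b_var2=nw.col("b").var(ddof=2),  # variance with n - 2 in the denominator
    )
    .sort("a")
    .to_native()
)
print(result)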
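
The heart of the change sits in `narwhals/_pandas_like/group_by.py`: std/var requests are bucketed by `ddof` (read back from `expr._kwargs`), each bucket is evaluated with a single vectorised `grouped[...].std(ddof=...)` / `.var(ddof=...)` call, the resulting frames are relabelled to the requested output names, and everything is concatenated horizontally. Below is a simplified, standalone sketch of that strategy in plain pandas — not the narwhals code itself, and the column/output names are made up for illustration:

import collections

import pandas as pd

df = pd.DataFrame({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]})
grouped = df.groupby(["a"], sort=False, as_index=True)

# ddof -> ([root column names], [output column names]), mirroring the patch
std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
    lambda: ([], [])
)
std_aggs[0][0].append("b")
std_aggs[0][1].append("b_std0")
std_aggs[2][0].append("b")
std_aggs[2][1].append("b_std2")

result_aggs = []
for ddof, (root_names, output_names) in std_aggs.items():
    frame = grouped[root_names].std(ddof=ddof)  # one vectorised call per distinct ddof
    frame.columns = output_names  # the patch uses the new set_columns helper for this
    result_aggs.append(frame)

result = pd.concat(result_aggs, axis=1).reset_index()
print(result)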
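
Most of the remaining churn in the diff swaps `kwargs={...}` for `kwargs={**self._kwargs, ...}` (and similarly in `_expression_parsing.py`), so that metadata recorded by an earlier call — here the `ddof` passed to `std`/`var` — is still visible after later calls such as `alias`. The group-by code depends on this via `expr._kwargs.get("ddof", 1)`. A toy illustration with a hypothetical class, not a narwhals type:

class ToyExpr:
    """Hypothetical stand-in for an expression that records call metadata."""

    def __init__(self, kwargs: dict) -> None:
        self._kwargs = kwargs

    def std(self, ddof: int) -> "ToyExpr":
        return ToyExpr({**self._kwargs, "ddof": ddof})

    def alias(self, name: str) -> "ToyExpr":
        # The pre-patch behaviour was the equivalent of ToyExpr({"name": name}),
        # which silently dropped "ddof"; merging preserves it.
        return ToyExpr({**self._kwargs, "name": name})


expr = ToyExpr({}).std(ddof=0).alias("b_std")
assert expr._kwargs.get("ddof", 1) == 0  # the fast grouped path reads it back like this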
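
The new `set_columns` helper in `narwhals/_pandas_like/utils.py` (alongside the renamed `set_index`, formerly `set_axis`) exists so the relabelling step can skip defensive copies where the installed backend allows it. A rough sketch of the idea against plain pandas only — the real helper also branches for cuDF and pre-1.0 pandas, which is omitted here:

import pandas as pd

_PD_VERSION = tuple(int(part) for part in pd.__version__.split(".")[:2])


def set_columns_sketch(obj: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
    # Between pandas 1.5 and 3.0, set_axis accepts copy=False and avoids a copy;
    # outside that range the keyword is either unavailable or unnecessary.
    if (1, 5) <= _PD_VERSION < (3,):
        return obj.set_axis(columns, axis=1, copy=False)
    return obj.set_axis(columns, axis=1)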
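
Finally, the reworked duplicate-output-name check no longer depends on which aggregation kinds were mixed: the columns of every produced frame are counted, and the error lists each clashing name with its multiplicity. The new match string can be exercised directly; this mirrors `test_group_by_same_name_twice` from the diff:

import pandas as pd
import pytest

import narwhals as nw

df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})

with pytest.raises(ValueError, match="Expected unique output names"):
    # Both aggregations would be written to the output column "b".
    nw.from_native(df).group_by("a").agg(nw.col("b").sum(), nw.col("b").n_unique())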