From dfd940c32a07f4741a67c09edc8b9a0fd965a406 Mon Sep 17 00:00:00 2001
From: FBruzzesi
Date: Sun, 22 Dec 2024 15:15:06 +0100
Subject: [PATCH] handle dups

---
 narwhals/_pandas_like/group_by.py | 48 +++++++++++++----------
 tests/group_by_test.py            | 51 ++++++++++++++++++++++++++-----
 2 files changed, 65 insertions(+), 34 deletions(-)

diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py
index 292ab5add..685ae262f 100644
--- a/narwhals/_pandas_like/group_by.py
+++ b/narwhals/_pandas_like/group_by.py
@@ -13,7 +13,6 @@
 from narwhals._expression_parsing import parse_into_exprs
 from narwhals._pandas_like.utils import horizontal_concat
 from narwhals._pandas_like.utils import native_series_from_iterable
-from narwhals._pandas_like.utils import rename
 from narwhals._pandas_like.utils import select_columns_by_name
 from narwhals.utils import Implementation
 from narwhals.utils import find_stacklevel
@@ -169,8 +168,14 @@ def agg_pandas(  # noqa: PLR0915
         # can pass the `dropna` kwargs.
         nunique_aggs: dict[str, str] = {}
         simple_aggs: dict[str, list[str]] = collections.defaultdict(list)
-        std_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict)
-        var_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict)
+
+        # ddof to (root_names, output_names) mapping
+        std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
+            lambda: ([], [])
+        )
+        var_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
+            lambda: ([], [])
+        )
 
         expected_old_names: list[str] = []
         new_names: list[str] = []
@@ -212,15 +217,18 @@ def agg_pandas(  # noqa: PLR0915
             if is_n_unique:
                 nunique_aggs[output_name] = root_name
             elif is_std and ddof != 1:
-                std_aggs[ddof].update({output_name: root_name})
+                std_aggs[ddof][0].append(root_name)
+                std_aggs[ddof][1].append(output_name)
             elif is_var and ddof != 1:
-                var_aggs[ddof].update({output_name: root_name})
+                var_aggs[ddof][0].append(root_name)
+                var_aggs[ddof][1].append(output_name)
             else:
                 new_names.append(output_name)
                 expected_old_names.append(f"{root_name}_{function_name}")
                 simple_aggs[root_name].append(function_name)
 
         result_aggs = []
+
         if simple_aggs:
             result_simple_aggs = grouped.agg(simple_aggs)
             result_simple_aggs.columns = [
@@ -263,33 +271,19 @@
         if std_aggs:
             result_aggs.extend(
                 [
-                    rename(
-                        grouped[list(output_to_root_name_mapping.values())].std(
-                            ddof=ddof
-                        ),
-                        # Invert the dict to have root_name: output_name
-                        # TODO(FBruzzesi): Account for duplicates
-                        columns={v: k for k, v in output_to_root_name_mapping.items()},
-                        implementation=implementation,
-                        backend_version=backend_version,
-                    )
-                    for ddof, output_to_root_name_mapping in std_aggs.items()
+                    grouped[std_root_names]
+                    .std(ddof=ddof)
+                    .set_axis(std_output_names, axis="columns", copy=False)
+                    for ddof, (std_root_names, std_output_names) in std_aggs.items()
                 ]
             )
         if var_aggs:
             result_aggs.extend(
                 [
-                    rename(
-                        grouped[list(output_to_root_name_mapping.values())].var(
-                            ddof=ddof
-                        ),
-                        # Invert the dict to have root_name: output_name
-                        # TODO(FBruzzesi): Account for duplicates
-                        columns={v: k for k, v in output_to_root_name_mapping.items()},
-                        implementation=implementation,
-                        backend_version=backend_version,
-                    )
-                    for ddof, output_to_root_name_mapping in var_aggs.items()
+                    grouped[var_root_names]
+                    .var(ddof=ddof)
+                    .set_axis(var_output_names, axis="columns", copy=False)
+                    for ddof, (var_root_names, var_output_names) in var_aggs.items()
                 ]
             )
 
diff --git a/tests/group_by_test.py b/tests/group_by_test.py
index 6843c50a3..01cabcbee 100644
--- a/tests/group_by_test.py
+++ b/tests/group_by_test.py
@@ -131,6 +131,43 @@ def test_group_by_depth_1_agg(
     assert_equal_data(result, expected)
 
 
+@pytest.mark.parametrize(
+    ("attr", "ddof"),
+    [
+        ("std", 0),
+        ("var", 0),
+        ("std", 2),
+        ("var", 2),
+    ],
+)
+def test_group_by_depth_1_std_var(
+    constructor: Constructor,
+    attr: str,
+    ddof: int,
+    request: pytest.FixtureRequest,
+) -> None:
+    if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1):
+        # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations
+        request.applymarker(pytest.mark.xfail)
+
+    if "dask" in str(constructor):
+        # Complex aggregation for dask
+        request.applymarker(pytest.mark.xfail)
+
+    data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
+    _pow = 0.5 if attr == "std" else 1
+    expected = {
+        "a": [1, 2],
+        "b": [
+            (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow,
+            (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow,
+        ],
+    }
+    expr = getattr(nw.col("b"), attr)(ddof=ddof)
+    result = nw.from_native(constructor(data)).group_by("a").agg(expr).sort("a")
+    assert_equal_data(result, expected)
+
+
 def test_group_by_median(constructor: Constructor) -> None:
     data = {"a": [1, 1, 1, 2, 2, 2], "b": [5, 4, 6, 7, 3, 2]}
     result = (
@@ -385,9 +422,7 @@ def test_double_same_aggregation(
 def test_all_kind_of_aggs(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    from math import sqrt
-
-    if any(x in str(constructor) for x in ("dask", "cudf")):
+    if any(x in str(constructor) for x in ("dask", "cudf", "modin_constructor")):
         # bugged in dask https://github.com/dask/dask/issues/11612
         # and modin lol https://github.com/modin-project/modin/issues/7414
         # and cudf https://github.com/rapidsai/cudf/issues/17649
@@ -403,7 +438,8 @@ def test_all_kind_of_aggs(
             e=nw.col("b").std(ddof=1),
             f=nw.col("b").std(ddof=2),
             g=nw.col("b").var(ddof=2),
-            h=nw.col("b").n_unique(),
+            h=nw.col("b").var(ddof=2),
+            i=nw.col("b").n_unique(),
         )
         .sort("a")
     )
@@ -413,9 +449,10 @@ def test_all_kind_of_aggs(
         "a": [1, 2],
         "c": [5, 10 / 3],
         "d": [5, 10 / 3],
-        "e": [1, sqrt(variance_num / (3 - 1))],
-        "f": [sqrt(2), sqrt(variance_num)],  # denominator is 1 (=3-2)
+        "e": [1, (variance_num / (3 - 1)) ** 0.5],
+        "f": [2**0.5, (variance_num) ** 0.5],  # denominator is 1 (=3-2)
         "g": [2.0, variance_num],  # denominator is 1 (=3-2)
-        "h": [3, 2],
+        "h": [2.0, variance_num],  # denominator is 1 (=3-2)
+        "i": [3, 2],
     }
     assert_equal_data(result, expected)
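
For reference, a minimal standalone sketch (not part of the patch) of the technique adopted above: bucket the requested std/var aggregations by ddof so each bucket becomes a single vectorised groupby call, then rename the result columns positionally with set_axis. The positional rename is what handles the duplicate root names that broke the old dict-based rename (the TODO the patch deletes). The DataFrame contents and names below are invented for illustration, and copy=False is omitted since it is only a performance hint.

import collections

import pandas as pd

df = pd.DataFrame({"a": [1, 1, 1, 2, 2, 2], "b": [4.0, 5.0, 6.0, 0.0, 5.0, 5.0]})
grouped = df.groupby("a")

# Requested aggregations as (root_name, output_name, ddof) triples. Note the
# duplicated root name "b": inverting a {output_name: root_name} dict into a
# rename mapping would silently collapse the first two entries.
requests = [("b", "e", 2), ("b", "f", 2), ("b", "g", 0)]

# ddof to (root_names, output_names) mapping, mirroring the patch.
std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
    lambda: ([], [])
)
for root_name, output_name, ddof in requests:
    std_aggs[ddof][0].append(root_name)
    std_aggs[ddof][1].append(output_name)

result_aggs = [
    # One .std(ddof=...) call per bucket; set_axis assigns new labels by
    # position, so duplicate root names map cleanly onto distinct output names.
    grouped[root_names].std(ddof=ddof).set_axis(output_names, axis="columns")
    for ddof, (root_names, output_names) in std_aggs.items()
]
print(pd.concat(result_aggs, axis=1))

Because set_axis relabels by position rather than by name, grouped[["b", "b"]] producing two identically named columns is harmless; this is the same property the new test_all_kind_of_aggs expectations (g and h both var(ddof=2) of "b") exercise.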