handle dups
FBruzzesi committed Dec 22, 2024
1 parent a104310 commit dfd940c
Showing 2 changed files with 65 additions and 34 deletions.
48 changes: 21 additions & 27 deletions narwhals/_pandas_like/group_by.py
@@ -13,7 +13,6 @@
 from narwhals._expression_parsing import parse_into_exprs
 from narwhals._pandas_like.utils import horizontal_concat
 from narwhals._pandas_like.utils import native_series_from_iterable
-from narwhals._pandas_like.utils import rename
 from narwhals._pandas_like.utils import select_columns_by_name
 from narwhals.utils import Implementation
 from narwhals.utils import find_stacklevel
@@ -169,8 +168,14 @@ def agg_pandas( # noqa: PLR0915
     # can pass the `dropna` kwargs.
     nunique_aggs: dict[str, str] = {}
     simple_aggs: dict[str, list[str]] = collections.defaultdict(list)
-    std_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict)
-    var_aggs: dict[int, dict[str, str]] = collections.defaultdict(dict)
+
+    # ddof to (root_names, output_names) mapping
+    std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
+        lambda: ([], [])
+    )
+    var_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
+        lambda: ([], [])
+    )
 
     expected_old_names: list[str] = []
     new_names: list[str] = []
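
The new accumulators replace `dict[int, dict[str, str]]` (ddof to `{output_name: root_name}`) with a per-ddof pair of parallel lists, so the same root column can back several outputs. A quick standalone sketch of how the accumulator behaves (the appended names here are illustrative, not taken from the library):

```python
import collections

# Per-ddof pair of parallel (root_names, output_names) lists: the same root
# column may appear more than once, which a dict keyed on names cannot express.
std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
    lambda: ([], [])
)

# Two std outputs over the same column "b", with different ddof values:
std_aggs[0][0].append("b")
std_aggs[0][1].append("b_std_ddof0")
std_aggs[2][0].append("b")
std_aggs[2][1].append("b_std_ddof2")

print(dict(std_aggs))
# {0: (['b'], ['b_std_ddof0']), 2: (['b'], ['b_std_ddof2'])}
```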
@@ -212,15 +217,18 @@ def agg_pandas( # noqa: PLR0915
         if is_n_unique:
             nunique_aggs[output_name] = root_name
         elif is_std and ddof != 1:
-            std_aggs[ddof].update({output_name: root_name})
+            std_aggs[ddof][0].append(root_name)
+            std_aggs[ddof][1].append(output_name)
         elif is_var and ddof != 1:
-            var_aggs[ddof].update({output_name: root_name})
+            var_aggs[ddof][0].append(root_name)
+            var_aggs[ddof][1].append(output_name)
         else:
             new_names.append(output_name)
             expected_old_names.append(f"{root_name}_{function_name}")
             simple_aggs[root_name].append(function_name)
 
     result_aggs = []
+
     if simple_aggs:
         result_simple_aggs = grouped.agg(simple_aggs)
         result_simple_aggs.columns = [
@@ -263,33 +271,19 @@ def agg_pandas( # noqa: PLR0915
     if std_aggs:
         result_aggs.extend(
             [
-                rename(
-                    grouped[list(output_to_root_name_mapping.values())].std(
-                        ddof=ddof
-                    ),
-                    # Invert the dict to have root_name: output_name
-                    # TODO(FBruzzesi): Account for duplicates
-                    columns={v: k for k, v in output_to_root_name_mapping.items()},
-                    implementation=implementation,
-                    backend_version=backend_version,
-                )
-                for ddof, output_to_root_name_mapping in std_aggs.items()
+                grouped[std_root_names]
+                .std(ddof=ddof)
+                .set_axis(std_output_names, axis="columns", copy=False)
+                for ddof, (std_root_names, std_output_names) in std_aggs.items()
             ]
         )
     if var_aggs:
         result_aggs.extend(
             [
-                rename(
-                    grouped[list(output_to_root_name_mapping.values())].var(
-                        ddof=ddof
-                    ),
-                    # Invert the dict to have root_name: output_name
-                    # TODO(FBruzzesi): Account for duplicates
-                    columns={v: k for k, v in output_to_root_name_mapping.items()},
-                    implementation=implementation,
-                    backend_version=backend_version,
-                )
-                for ddof, output_to_root_name_mapping in var_aggs.items()
+                grouped[var_root_names]
+                .var(ddof=ddof)
+                .set_axis(var_output_names, axis="columns", copy=False)
+                for ddof, (var_root_names, var_output_names) in var_aggs.items()
            ]
        )
 
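The removed `rename` path inverted each `{output_name: root_name}` dict to `{root_name: output_name}`, so two outputs sharing a root column collapsed into one; that is exactly the duplicate case flagged by the deleted TODO. Positional lists plus `DataFrame.set_axis` sidestep the inversion entirely. A minimal pandas-only sketch of both behaviors (column names are illustrative):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2, 2], "b": [4.0, 5.0, 0.0, 5.0]})
grouped = df.groupby("a")

# Old approach: inverting {output: root} silently drops duplicates.
output_to_root = {"g": "b", "h": "b"}  # two outputs from the same root column
inverted = {v: k for k, v in output_to_root.items()}
print(inverted)  # {'b': 'h'} -- output "g" is lost

# New approach: parallel lists keep one result column per requested output.
root_names = ["b", "b"]
output_names = ["g", "h"]
result = grouped[root_names].var(ddof=0).set_axis(output_names, axis="columns")
print(result)  # columns "g" and "h", both present
```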
51 changes: 44 additions & 7 deletions tests/group_by_test.py
@@ -131,6 +131,43 @@ def test_group_by_depth_1_agg(
     assert_equal_data(result, expected)
 
 
+@pytest.mark.parametrize(
+    ("attr", "ddof"),
+    [
+        ("std", 0),
+        ("var", 0),
+        ("std", 2),
+        ("var", 2),
+    ],
+)
+def test_group_by_depth_1_std_var(
+    constructor: Constructor,
+    attr: str,
+    ddof: int,
+    request: pytest.FixtureRequest,
+) -> None:
+    if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1):
+        # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations
+        request.applymarker(pytest.mark.xfail)
+
+    if "dask" in str(constructor):
+        # Complex aggregation for dask
+        request.applymarker(pytest.mark.xfail)
+
+    data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
+    _pow = 0.5 if attr == "std" else 1
+    expected = {
+        "a": [1, 2],
+        "b": [
+            (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow,
+            (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow,
+        ],
+    }
+    expr = getattr(nw.col("b"), attr)(ddof=ddof)
+    result = nw.from_native(constructor(data)).group_by("a").agg(expr).sort("a")
+    assert_equal_data(result, expected)
+
+
 def test_group_by_median(constructor: Constructor) -> None:
     data = {"a": [1, 1, 1, 2, 2, 2], "b": [5, 4, 6, 7, 3, 2]}
     result = (
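
To make the test's expected values concrete, here is the `ddof=0` case for group `a == 1` worked by hand (plain Python, mirroring the formula inside the test):

```python
# Sample std/var divide the sum of squared deviations by (n - ddof):
values = [4, 5, 6]                           # group a == 1
mean = sum(values) / len(values)             # 5.0
ss = sum((v - mean) ** 2 for v in values)    # 2.0
var_ddof0 = ss / (len(values) - 0)           # 0.666..., divide by n when ddof=0
std_ddof0 = var_ddof0**0.5                   # ~0.8165
print(var_ddof0, std_ddof0)
```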
Expand Down Expand Up @@ -385,9 +422,7 @@ def test_double_same_aggregation(
def test_all_kind_of_aggs(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
from math import sqrt

if any(x in str(constructor) for x in ("dask", "cudf")):
if any(x in str(constructor) for x in ("dask", "cudf", "modin_constructor")):
# bugged in dask https://github.com/dask/dask/issues/11612
# and modin lol https://github.com/modin-project/modin/issues/7414
# and cudf https://github.com/rapidsai/cudf/issues/17649
@@ -403,7 +438,8 @@ def test_all_kind_of_aggs(
             e=nw.col("b").std(ddof=1),
             f=nw.col("b").std(ddof=2),
             g=nw.col("b").var(ddof=2),
-            h=nw.col("b").n_unique(),
+            h=nw.col("b").var(ddof=2),
+            i=nw.col("b").n_unique(),
         )
         .sort("a")
     )
@@ -413,9 +449,10 @@ def test_all_kind_of_aggs(
         "a": [1, 2],
         "c": [5, 10 / 3],
         "d": [5, 10 / 3],
-        "e": [1, sqrt(variance_num / (3 - 1))],
-        "f": [sqrt(2), sqrt(variance_num)],  # denominator is 1 (=3-2)
+        "e": [1, (variance_num / (3 - 1)) ** 0.5],
+        "f": [2**0.5, (variance_num) ** 0.5],  # denominator is 1 (=3-2)
         "g": [2.0, variance_num],  # denominator is 1 (=3-2)
-        "h": [3, 2],
+        "h": [2.0, variance_num],  # denominator is 1 (=3-2)
+        "i": [3, 2],
     }
     assert_equal_data(result, expected)
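
The duplicated `h=nw.col("b").var(ddof=2)` alongside `g` is the regression this commit guards against: before the fix, the inverted-dict rename could emit only one of the two columns. A minimal end-to-end sketch of that call pattern (data values are illustrative):

```python
import pandas as pd
import narwhals as nw

df = pd.DataFrame({"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]})
result = (
    nw.from_native(df)
    .group_by("a")
    .agg(
        g=nw.col("b").var(ddof=2),
        h=nw.col("b").var(ddof=2),  # same aggregation, distinct output name
    )
    .sort("a")
)
print(result.to_native())  # both "g" and "h" columns survive
```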
