Skip to content

Commit

Permalink
patch: group by n_unique (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Sep 6, 2024
1 parent 4cf94ce commit ad5616a
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 26 deletions.
27 changes: 22 additions & 5 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import IntoArrowExpr

# Map Polars aggregation names to the equivalently-named
# pyarrow.compute aggregation functions.
POLARS_TO_ARROW_AGGREGATIONS = {
    "n_unique": "count_distinct",
    "std": "stddev",
    "var": "variance",  # currently unused, we don't have `var` yet
}


class ArrowGroupBy:
def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None:
Expand Down Expand Up @@ -112,16 +118,27 @@ def agg_arrow(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
if function_name != "len":
if function_name == "len":
simple_aggregations[output_name] = (
(root_name, function_name),
f"{root_name}_{function_name}",
(root_name, "count", pc.CountOptions(mode="all")),
f"{root_name}_count",
)
elif function_name == "count_distinct":
simple_aggregations[output_name] = (
(root_name, "count_distinct", pc.CountOptions(mode="all")),
f"{root_name}_count_distinct",
)
elif function_name == "stddev":
simple_aggregations[output_name] = (
(root_name, "stddev", pc.VarianceOptions(ddof=1)),
f"{root_name}_stddev",
)
else:
simple_aggregations[output_name] = (
(root_name, "count", pc.CountOptions(mode="all")),
f"{root_name}_count",
(root_name, function_name),
f"{root_name}_{function_name}",
)

aggs: list[Any] = []
Expand Down
35 changes: 29 additions & 6 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,33 @@
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
import dask.dataframe as dd
import pandas as pd

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.typing import IntoDaskExpr

POLARS_TO_PANDAS_AGGREGATIONS = {

def n_unique() -> dd.Aggregation:
    """Build a dask custom aggregation replicating Polars' `n_unique`.

    Counts distinct values per group *including* nulls (hence
    `dropna=False`), matching Polars semantics.

    Note: summing per-partition `nunique` results would overcount any
    value that appears in more than one partition of the same group, so
    each chunk instead emits the list of its unique values and the
    distinct count is taken only after all partitions are combined.
    This mirrors the `nunique` example in dask's `Aggregation` docs.
    """
    import dask.dataframe as dd  # ignore-banned-import
    import pandas as pd  # ignore-banned-import

    def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series:
        # Per-partition unique values for each group, as plain lists so
        # the combine step can concatenate them with `sum`.
        return s.apply(lambda x: x.unique().tolist())

    def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series:
        # Concatenate the per-partition lists belonging to each group.
        return s0.obj.groupby(level=list(range(s0.obj.index.nlevels))).sum()

    def finalize(s1: pd.Series) -> pd.Series:
        # Count distinct values across all partitions; going through a
        # pandas Series collapses repeated NaN objects into one null.
        return s1.apply(lambda vals: pd.Series(vals).nunique(dropna=False))

    return dd.Aggregation(
        name="nunique",
        chunk=chunk,
        agg=agg,
        finalize=finalize,
    )


# Map Polars aggregation names to their dask counterparts. A value is
# either a dask method name (str) or a zero-argument callable returning
# a `dd.Aggregation` — called lazily so importing this module does not
# require dask to be installed.
POLARS_TO_DASK_AGGREGATIONS = {
    "len": "size",
    "n_unique": n_unique,
}


Expand Down Expand Up @@ -85,15 +106,15 @@ def agg_dask(
break

if all_simple_aggs:
simple_aggregations: dict[str, tuple[str, str]] = {}
simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {}
for expr in exprs:
if expr._depth == 0:
# e.g. agg(nw.len()) # noqa: ERA001
if expr._output_names is None: # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
for output_name in expr._output_names:
Expand All @@ -108,9 +129,11 @@ def agg_dask(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
function_name, function_name
)
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)

# deal with n_unique case in a "lazy" mode to not depend on dask globally
function_name = function_name() if callable(function_name) else function_name

for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
try:
Expand Down
69 changes: 54 additions & 15 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

# Map Polars aggregation names to pandas groupby method names.
POLARS_TO_PANDAS_AGGREGATIONS = {
    "len": "size",
    "n_unique": "nunique",
}


Expand Down Expand Up @@ -103,7 +104,7 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
yield from ((key, self._from_native_frame(sub_df)) for (key, sub_df) in iterator)


def agg_pandas(
def agg_pandas( # noqa: PLR0915
grouped: Any,
exprs: list[PandasLikeExpr],
keys: list[str],
Expand All @@ -120,13 +121,18 @@ def agg_pandas(
- https://github.com/rapidsai/cudf/issues/15118
- https://github.com/rapidsai/cudf/issues/15084
"""
all_simple_aggs = True
all_aggs_are_simple = True
for expr in exprs:
if not is_simple_aggregation(expr):
all_simple_aggs = False
all_aggs_are_simple = False
break

if all_simple_aggs:
    # dict of {output_name: root_name} for the columns we compute n_unique on.
    # These are handled separately from the other aggregations so that we
    # can pass the `dropna` keyword argument.
nunique_aggs: dict[str, str] = {}

if all_aggs_are_simple:
simple_aggregations: dict[str, tuple[str, str]] = {}
for expr in exprs:
if expr._depth == 0:
Expand Down Expand Up @@ -154,21 +160,54 @@ def agg_pandas(
function_name, function_name
)
for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
if function_name == "nunique":
nunique_aggs[output_name] = root_name
else:
simple_aggregations[output_name] = (root_name, function_name)

aggs = collections.defaultdict(list)
simple_aggs = collections.defaultdict(list)
name_mapping = {}
for output_name, named_agg in simple_aggregations.items():
aggs[named_agg[0]].append(named_agg[1])
simple_aggs[named_agg[0]].append(named_agg[1])
name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name
try:
result_simple = grouped.agg(aggs)
except AttributeError as exc:
msg = "Failed to aggregated - does your aggregation function return a scalar?"
raise RuntimeError(msg) from exc
result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns]
result_simple = result_simple.rename(columns=name_mapping).reset_index()
return from_dataframe(result_simple.loc[:, output_names])
if simple_aggs:
try:
result_simple_aggs = grouped.agg(simple_aggs)
except AttributeError as exc:
msg = "Failed to aggregated - does your aggregation function return a scalar?"
raise RuntimeError(msg) from exc
result_simple_aggs.columns = [
f"{a}_{b}" for a, b in result_simple_aggs.columns
]
result_simple_aggs = result_simple_aggs.rename(
columns=name_mapping
).reset_index()
if nunique_aggs:
result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique(
dropna=False
)
result_nunique_aggs.columns = list(nunique_aggs.keys())
result_nunique_aggs = result_nunique_aggs.reset_index()
if simple_aggs and nunique_aggs:
if (
set(result_simple_aggs.columns)
.difference(keys)
.intersection(result_nunique_aggs.columns)
):
msg = (
"Got two aggregations with the same output name. Please make sure "
"that aggregations have unique output names."
)
raise ValueError(msg)
result_aggs = result_simple_aggs.merge(result_nunique_aggs, on=keys)
elif nunique_aggs and not simple_aggs:
result_aggs = result_nunique_aggs
elif simple_aggs and not nunique_aggs:
result_aggs = result_simple_aggs
else: # pragma: no cover
msg = "Congrats, you entered unreachable code. Please report a bug to https://github.com/narwhals-dev/narwhals/issues."
raise RuntimeError(msg)
return from_dataframe(result_aggs.loc[:, output_names])

if dataframe_is_empty:
# Don't even attempt this, it's way too inconsistent across pandas versions.
Expand Down
51 changes: 51 additions & 0 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,57 @@ def test_group_by_len(constructor: Any) -> None:
compare_dicts(result, expected)


def test_group_by_n_unique(constructor: Any) -> None:
    """Each group in the shared sample data holds one distinct `b` value."""
    frame = nw.from_native(constructor(data))
    result = frame.group_by("a").agg(nw.col("b").n_unique()).sort("a")
    compare_dicts(result, {"a": [1, 3], "b": [1, 1]})


def test_group_by_std(constructor: Any) -> None:
    """Sample std (ddof=1) of [5, 4] and of [3, 2] is 1/sqrt(2) per group."""
    data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]}
    frame = nw.from_native(constructor(data))
    result = frame.group_by("a").agg(nw.col("b").std()).sort("a")
    compare_dicts(result, {"a": [1, 2], "b": [0.707107] * 2})


def test_group_by_n_unique_w_missing(constructor: Any) -> None:
    """`n_unique` counts a null as its own distinct value."""
    data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
    frame = nw.from_native(constructor(data))
    result = (
        frame.group_by("a")
        .agg(
            nw.col("b").n_unique(),
            c_n_unique=nw.col("c").n_unique(),
            c_n_min=nw.col("b").min(),
            d_n_unique=nw.col("d").n_unique(),
        )
        .sort("a")
    )
    compare_dicts(
        result,
        {
            "a": [1, 2],
            "b": [2, 1],
            "c_n_unique": [1, 1],
            "c_n_min": [4, 5],
            "d_n_unique": [1, 1],
        },
    )


def test_group_by_same_name_twice() -> None:
    """Two aggregations writing to the same output column must raise."""
    import pandas as pd

    native = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})
    with pytest.raises(ValueError, match="two aggregations with the same"):
        nw.from_native(native).group_by("a").agg(
            nw.col("b").sum(), nw.col("b").n_unique()
        )


def test_group_by_empty_result_pandas() -> None:
df_any = pd.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})
df = nw.from_native(df_any, eager_only=True)
Expand Down

0 comments on commit ad5616a

Please sign in to comment.