
feat: Improve error message when trying to do unsupported group-by aggregation (#1584)
MarcoGorelli authored Dec 13, 2024
1 parent 44cf9e3 commit b4f7b96
Showing 14 changed files with 89 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
@@ -58,4 +58,4 @@ body:
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
- render: shell
+ render: shell
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/doc_issue.yml
@@ -32,4 +32,4 @@ body:
attributes:
value: >
- ### If you'd be interested in opening a pull request to fix this, please let us know!
+ ### If you'd be interested in opening a pull request to fix this, please let us know!
2 changes: 1 addition & 1 deletion .github/dependabot.yml
@@ -9,4 +9,4 @@ updates:
interval: "monthly"
commit-message:
prefix: "skip changelog" # So this PR will not be added to release-drafter
- include: "scope" # List of the updated dependencies in the commit will be added
+ include: "scope" # List of the updated dependencies in the commit will be added
2 changes: 1 addition & 1 deletion .github/workflows/release-drafter.yml
@@ -35,4 +35,4 @@ jobs:
config-name: release-drafter.yml
# disable-autolabeler: true
env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
@@ -95,4 +95,7 @@ repos:
rev: v5.0.0
hooks:
- id: name-tests-test
- exclude: ^tests/utils\.py
+ exclude: ^tests/utils\.py
+ - id: no-commit-to-branch
+ - id: end-of-file-fixer
+ exclude: .svg$
2 changes: 1 addition & 1 deletion docs/css/extra.css
@@ -1,4 +1,4 @@
.md-typeset ol li,
.md-typeset ul li {
margin-bottom: 0.1em !important;
- }
+ }
2 changes: 1 addition & 1 deletion docs/javascripts/extra.js
@@ -64,4 +64,4 @@ function cleanupClipboardText(targetSelector) {
document$.subscribe(function () {
setCopyText();
});


2 changes: 1 addition & 1 deletion docs/javascripts/katex.js
@@ -7,4 +7,4 @@ document$.subscribe(({ body }) => {
{ left: "\\[", right: "\\]", display: true }
],
})
- })
+ })
2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -131,4 +131,4 @@ extra_javascript:

extra_css:
- https://unpkg.com/katex@0/dist/katex.min.css
- - css/extra.css
+ - css/extra.css
19 changes: 14 additions & 5 deletions narwhals/_arrow/group_by.py
@@ -21,11 +21,16 @@
from narwhals._arrow.typing import IntoArrowExpr

POLARS_TO_ARROW_AGGREGATIONS = {
- "len": "count",
"sum": "sum",
"mean": "mean",
"median": "approximate_median",
- "n_unique": "count_distinct",
+ "max": "max",
+ "min": "min",
+ "std": "stddev",
+ "var": "variance",  # currently unused, we don't have `var` yet
+ "len": "count",
+ "n_unique": "count_distinct",
+ "count": "count",
}


@@ -36,7 +41,7 @@ def get_function_name_option(
import pyarrow.compute as pc

function_name_to_options = {
- "count": pc.CountOptions(mode="all"),
+ "count": pc.CountOptions(mode="only_valid"),
"count_distinct": pc.CountOptions(mode="all"),
"stddev": pc.VarianceOptions(ddof=1),
"variance": pc.VarianceOptions(ddof=1),
@@ -131,7 +136,11 @@ def agg_arrow(

all_simple_aggs = True
for expr in exprs:
- if not is_simple_aggregation(expr):
+ if not (
+ is_simple_aggregation(expr)
+ and remove_prefix(expr._function_name, "col->")
+ in POLARS_TO_ARROW_AGGREGATIONS
+ ):
all_simple_aggs = False
break

@@ -185,7 +194,7 @@ def agg_arrow(
return from_dataframe(result_simple)

msg = (
- "Non-trivial complex found.\n\n"
+ "Non-trivial complex aggregation found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"pyarrow table.\n"
"Please rewrite your query such that group-by aggregations "
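For reference, a minimal sketch (assuming `pyarrow` is installed; not part of the diff) of why `"count"` above moves from `CountOptions(mode="all")` to `mode="only_valid"`: Polars-style `count` ignores nulls, which is what `only_valid` computes, whereas `all` also counts nulls (`count_distinct` keeps `mode="all"`).

```python
import pyarrow as pa
import pyarrow.compute as pc

arr = pa.array([1, None, 2])

# Matches the new CountOptions(mode="only_valid"): nulls are excluded.
print(pc.count(arr, mode="only_valid"))  # 2
# The previous mode="all" would have counted the null as well.
print(pc.count(arr, mode="all"))  # 3
```

The new `test_group_by_count` test further down pins this behaviour: with one null in `b`, the expected per-group counts are `[2, 1]`.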
21 changes: 14 additions & 7 deletions narwhals/_dask/group_by.py
@@ -35,8 +35,16 @@ def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:


POLARS_TO_DASK_AGGREGATIONS = {
+ "sum": "sum",
+ "mean": "mean",
+ "median": "median",
+ "max": "max",
+ "min": "min",
+ "std": "std",
+ "var": "var",
"len": "size",
"n_unique": n_unique,
+ "count": "count",
}


@@ -108,7 +116,10 @@ def agg_dask(

all_simple_aggs = True
for expr in exprs:
- if not is_simple_aggregation(expr):
+ if not (
+ is_simple_aggregation(expr)
+ and remove_prefix(expr._function_name, "col->") in POLARS_TO_DASK_AGGREGATIONS
+ ):
all_simple_aggs = False
break

@@ -143,15 +154,11 @@ def agg_dask(

for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
- try:
- result_simple = grouped.agg(**simple_aggregations)
- except ValueError as exc:
- msg = "Failed to aggregated - does your aggregation function return a scalar?"
- raise RuntimeError(msg) from exc
+ result_simple = grouped.agg(**simple_aggregations)
return from_dataframe(result_simple.reset_index())

msg = (
- "Non-trivial complex found.\n\n"
+ "Non-trivial complex aggregation found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"dask dataframe.\n"
"Please rewrite your query such that group-by aggregations "
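A hypothetical, self-contained sketch of the check the Dask backend (and the PyArrow and pandas-like backends) now shares: an expression only counts as a simple aggregation if its function name, minus the `col->` prefix, is a key of the backend's mapping, so an unsupported aggregation reaches the explicit "Non-trivial complex aggregation found" error instead of failing somewhere inside the native `agg` call (hence the removed `try`/`except` above). The helper name and the abridged mapping below are illustrative, not narwhals API.

```python
# Illustrative only: abridged mapping; the real one is POLARS_TO_DASK_AGGREGATIONS above.
SUPPORTED = {"sum": "sum", "mean": "mean", "len": "size", "count": "count"}


def is_supported_aggregation(function_name: str) -> bool:
    # Expression function names look like "col->mean" (inferred from the
    # remove_prefix(expr._function_name, "col->") call in the diff; narwhals
    # uses its own remove_prefix helper rather than str.removeprefix).
    return function_name.removeprefix("col->") in SUPPORTED


assert is_supported_aggregation("col->mean")
assert not is_supported_aggregation("col->shift")  # falls through to the clear error
```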
26 changes: 19 additions & 7 deletions narwhals/_pandas_like/group_by.py
@@ -14,6 +14,7 @@
from narwhals._pandas_like.utils import rename
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals.utils import Implementation
+ from narwhals.utils import find_stacklevel
from narwhals.utils import remove_prefix
from narwhals.utils import tupleify

@@ -23,8 +24,16 @@
from narwhals._pandas_like.typing import IntoPandasLikeExpr

POLARS_TO_PANDAS_AGGREGATIONS = {
+ "sum": "sum",
+ "mean": "mean",
+ "median": "median",
+ "max": "max",
+ "min": "min",
+ "std": "std",
+ "var": "var",
"len": "size",
"n_unique": "nunique",
+ "count": "count",
}


@@ -144,7 +153,7 @@ def agg_pandas( # noqa: PLR0915
"""
all_aggs_are_simple = True
for expr in exprs:
- if not is_simple_aggregation(expr):
+ if not (
+ is_simple_aggregation(expr)
+ and remove_prefix(expr._function_name, "col->")
+ in POLARS_TO_PANDAS_AGGREGATIONS
+ ):
all_aggs_are_simple = False
break

@@ -193,11 +206,7 @@ def agg_pandas( # noqa: PLR0915
simple_aggs[named_agg[0]].append(named_agg[1])
name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name
if simple_aggs:
- try:
- result_simple_aggs = grouped.agg(simple_aggs)
- except AttributeError as exc:
- msg = "Failed to aggregated - does your aggregation function return a scalar?"
- raise RuntimeError(msg) from exc
+ result_simple_aggs = grouped.agg(simple_aggs)
result_simple_aggs.columns = [
f"{a}_{b}" for a, b in result_simple_aggs.columns
]
@@ -264,14 +273,17 @@ def agg_pandas( # noqa: PLR0915
"pandas API. If you can, please rewrite your query such that group-by aggregations "
"are simple (e.g. mean, std, min, max, ...).",
UserWarning,
- stacklevel=2,
+ stacklevel=find_stacklevel(),
)

def func(df: Any) -> Any:
out_group = []
out_names = []
for expr in exprs:
results_keys = expr._call(from_dataframe(df))
+ if not all(len(x) == 1 for x in results_keys):
+ msg = f"Aggregation '{expr._function_name}' failed to aggregate - does your aggregation function return a scalar?"
+ raise ValueError(msg)
for result_keys in results_keys:
out_group.append(result_keys._native_series.iloc[0])
out_names.append(result_keys.name)
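A usage-level sketch of the pandas-like change (assuming pandas and a narwhals version containing this commit): the group-wise fallback now verifies that each expression produced exactly one value per group and raises a targeted `ValueError`, where the removed `try`/`except` used to turn backend failures into a generic `RuntimeError`.

```python
import pandas as pd
import narwhals as nw

df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [1, 1, 2]}))

try:
    # shift() yields one row per input row, not one value per group, so the
    # new per-group length check rejects it (a "Found complex group-by
    # expression" UserWarning is emitted before the fallback runs).
    df.group_by("b").agg(nw.col("a").shift(1))
except ValueError as exc:
    print(exc)  # "... failed to aggregate - does your aggregation function return a scalar?"
```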
2 changes: 1 addition & 1 deletion narwhals/_spark_like/group_by.py
@@ -91,7 +91,7 @@ def agg_pyspark(
for expr in exprs:
if not is_simple_aggregation(expr): # pragma: no cover
msg = (
- "Non-trivial complex found.\n\n"
+ "Non-trivial complex aggregation found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"dask dataframe.\n"
"Please rewrite your query such that group-by aggregations "
33 changes: 29 additions & 4 deletions tests/group_by_test.py
@@ -44,10 +44,10 @@ def test_invalid_group_by_dask() -> None:

df_dask = dd.from_pandas(df_pandas)

- with pytest.raises(ValueError, match=r"Non-trivial complex found"):
+ with pytest.raises(ValueError, match=r"Non-trivial complex aggregation found"):
nw.from_native(df_dask).group_by("a").agg(nw.col("b").mean().min())

- with pytest.raises(RuntimeError, match="does your"):
+ with pytest.raises(ValueError, match="Non-trivial complex aggregation"):
nw.from_native(df_dask).group_by("a").agg(nw.col("b"))

with pytest.raises(
@@ -56,9 +56,10 @@ def test_invalid_group_by_dask() -> None:
nw.from_native(df_dask).group_by("a").agg(nw.all().mean())


+ @pytest.mark.filterwarnings("ignore:Found complex group-by expression:UserWarning")
def test_invalid_group_by() -> None:
df = nw.from_native(df_pandas)
- with pytest.raises(RuntimeError, match="does your"):
+ with pytest.raises(ValueError, match="does your"):
df.group_by("a").agg(nw.col("b"))
with pytest.raises(
ValueError, match=r"Anonymous expressions are not supported in group_by\.agg"
@@ -68,7 +69,7 @@ def test_invalid_group_by() -> None:
ValueError, match=r"Anonymous expressions are not supported in group_by\.agg"
):
nw.from_native(pa.table({"a": [1, 2, 3]})).group_by("a").agg(nw.all().mean())
- with pytest.raises(ValueError, match=r"Non-trivial complex found"):
+ with pytest.raises(ValueError, match=r"Non-trivial complex aggregation found"):
nw.from_native(pa.table({"a": [1, 2, 3]})).group_by("a").agg(
nw.col("b").mean().min()
)
@@ -340,3 +341,27 @@ def test_group_by_categorical(
.sort("x")
)
assert_equal_data(result, data)


+ @pytest.mark.filterwarnings("ignore:Found complex group-by expression:UserWarning")
+ def test_group_by_shift_raises(
+ constructor: Constructor, request: pytest.FixtureRequest
+ ) -> None:
+ if "polars" in str(constructor):
+ # Polars supports all kinds of crazy group-by aggregations, so
+ # we don't check that it errors here.
+ request.applymarker(pytest.mark.xfail)
+ df_native = {"a": [1, 2, 3], "b": [1, 1, 2]}
+ df = nw.from_native(constructor(df_native))
+ with pytest.raises(
+ ValueError, match=".*(failed to aggregate|Non-trivial complex aggregation found)"
+ ):
+ df.group_by("b").agg(nw.col("a").shift(1))
+
+
+ def test_group_by_count(constructor: Constructor) -> None:
+ data = {"a": [1, 1, 1, 2], "b": [1, None, 2, 3]}
+ df = nw.from_native(constructor(data))
+ result = df.group_by("a").agg(nw.col("b").count()).sort("a")
+ expected = {"a": [1, 2], "b": [2, 1]}
+ assert_equal_data(result, expected)
