From 74e8f95cc698535e4ddd226405cc08e774219903 Mon Sep 17 00:00:00 2001
From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
Date: Sat, 28 Dec 2024 16:57:44 +0100
Subject: [PATCH] fix: pyspark group by with kwargs (#1665)

---------

Co-authored-by: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
---
 narwhals/_spark_like/expr.py     | 38 +++++++++++++++---
 narwhals/_spark_like/group_by.py | 49 +++++++++++++++++++++++--------
 narwhals/_spark_like/utils.py    | 40 +++++++++++++++++++++++++
 narwhals/typing.py               |  1 +
 tests/spark_like_test.py         | 50 ++++++++++++++++++++++++++++++++
 5 files changed, 152 insertions(+), 26 deletions(-)

diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index da25b32e0..3d09a2427 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -233,24 +233,34 @@ def _min(_input: Column) -> Column:
 
         return self._from_call(_min, "min", returns_scalar=True)
 
-    def std(self, ddof: int) -> Self:
+    def std(self: Self, ddof: int) -> Self:
+        from functools import partial
+
         import numpy as np  # ignore-banned-import
 
-        def _std(_input: Column, ddof: int) -> Column:  # pragma: no cover
-            if self._backend_version < (3, 5) or parse_version(np.__version__) > (2, 0):
-                from pyspark.sql import functions as F  # noqa: N812
+        from narwhals._spark_like.utils import _std
+
+        func = partial(
+            _std,
+            ddof=ddof,
+            backend_version=self._backend_version,
+            np_version=parse_version(np.__version__),
+        )
+
+        return self._from_call(func, "std", returns_scalar=True, ddof=ddof)
 
-                if ddof == 1:
-                    return F.stddev_samp(_input)
+    def var(self: Self, ddof: int) -> Self:
+        from functools import partial
 
-                n_rows = F.count(_input)
-                return F.stddev_samp(_input) * F.sqrt((n_rows - 1) / (n_rows - ddof))
+        import numpy as np  # ignore-banned-import
 
-            from pyspark.pandas.spark.functions import stddev
+        from narwhals._spark_like.utils import _var
 
-            return stddev(_input, ddof=ddof)
+        func = partial(
+            _var,
+            ddof=ddof,
+            backend_version=self._backend_version,
+            np_version=parse_version(np.__version__),
+        )
 
-        expr = self._from_call(_std, "std", returns_scalar=True, ddof=ddof)
-        if ddof != 1:
-            expr._depth += 1
-        return expr
+        return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py
index ecd9f235d..f07101a7e 100644
--- a/narwhals/_spark_like/group_by.py
+++ b/narwhals/_spark_like/group_by.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from copy import copy
+from functools import partial
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable
@@ -8,6 +9,9 @@
 
 from narwhals._expression_parsing import is_simple_aggregation
 from narwhals._expression_parsing import parse_into_exprs
+from narwhals._spark_like.utils import _std
+from narwhals._spark_like.utils import _var
+from narwhals.utils import parse_version
 from narwhals.utils import remove_prefix
 
 if TYPE_CHECKING:
@@ -18,10 +22,8 @@
     from narwhals._spark_like.typing import IntoSparkLikeExpr
     from narwhals.typing import CompliantExpr
 
-POLARS_TO_PYSPARK_AGGREGATIONS = {
-    "len": "count",
-    "std": "stddev",
-}
+
+POLARS_TO_PYSPARK_AGGREGATIONS = {"len": "count"}
 
 
 class SparkLikeLazyGroupBy:
@@ -77,7 +79,18 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
         )
 
 
-def get_spark_function(function_name: str) -> Column:
+def get_spark_function(
+    function_name: str, backend_version: tuple[int, ...], **kwargs: Any
+) -> Column:
+    if function_name in {"std", "var"}:
+        import numpy as np  # ignore-banned-import
+
+        return partial(
+            _std if function_name == "std" else _var,
+            ddof=kwargs.get("ddof", 1),
+            backend_version=backend_version,
+            np_version=parse_version(np.__version__),
+        )
     from pyspark.sql import functions as F  # noqa: N812
 
     return getattr(F, function_name)
@@ -114,9 +127,12 @@ def agg_pyspark(
             function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
                 expr._function_name, expr._function_name
             )
-            for output_name in expr._output_names:
-                agg_func = get_spark_function(function_name)
-                simple_aggregations[output_name] = agg_func(keys[0])
+            agg_func = get_spark_function(
+                function_name, backend_version=expr._backend_version, **expr._kwargs
+            )
+            simple_aggregations.update(
+                {output_name: agg_func(keys[0]) for output_name in expr._output_names}
+            )
             continue
 
         # e.g. agg(nw.mean('a')) # noqa: ERA001
@@ -127,11 +143,20 @@
             raise AssertionError(msg)
 
         function_name = remove_prefix(expr._function_name, "col->")
-        function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(function_name, function_name)
+        pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
+            function_name, function_name
+        )
+        agg_func = get_spark_function(
+            pyspark_function, backend_version=expr._backend_version, **expr._kwargs
+        )
+
+        simple_aggregations.update(
+            {
+                output_name: agg_func(root_name)
+                for root_name, output_name in zip(expr._root_names, expr._output_names)
+            }
+        )
-        for root_name, output_name in zip(expr._root_names, expr._output_names):
-            agg_func = get_spark_function(function_name)
-            simple_aggregations[output_name] = agg_func(root_name)
 
     agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
     try:
         result_simple = grouped.agg(*agg_columns)
diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py
index 032522475..12892ca0c 100644
--- a/narwhals/_spark_like/utils.py
+++ b/narwhals/_spark_like/utils.py
@@ -118,3 +118,43 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
             return column_result.over(Window.partitionBy(F.lit(1)))
         return column_result
     return obj
+
+
+def _std(
+    _input: Column,
+    ddof: int,
+    backend_version: tuple[int, ...],
+    np_version: tuple[int, ...],
+) -> Column:
+    if backend_version < (3, 5) or np_version > (2, 0):
+        from pyspark.sql import functions as F  # noqa: N812
+
+        if ddof == 1:
+            return F.stddev_samp(_input)
+
+        n_rows = F.count(_input)
+        return F.stddev_samp(_input) * F.sqrt((n_rows - 1) / (n_rows - ddof))
+
+    from pyspark.pandas.spark.functions import stddev
+
+    return stddev(_input, ddof=ddof)
+
+
+def _var(
+    _input: Column,
+    ddof: int,
+    backend_version: tuple[int, ...],
+    np_version: tuple[int, ...],
+) -> Column:
+    if backend_version < (3, 5) or np_version > (2, 0):
+        from pyspark.sql import functions as F  # noqa: N812
+
+        if ddof == 1:
+            return F.var_samp(_input)
+
+        n_rows = F.count(_input)
+        return F.var_samp(_input) * (n_rows - 1) / (n_rows - ddof)
+
+    from pyspark.pandas.spark.functions import var
+
+    return var(_input, ddof=ddof)
diff --git a/narwhals/typing.py b/narwhals/typing.py
index ec440751a..ff29cb57e 100644
--- a/narwhals/typing.py
+++ b/narwhals/typing.py
@@ -68,6 +68,7 @@ def __narwhals_namespace__(self) -> Any: ...
 
 class CompliantExpr(Protocol, Generic[CompliantSeriesT_co]):
     _implementation: Implementation
+    _backend_version: tuple[int, ...]
     _output_names: list[str] | None
     _root_names: list[str] | None
     _depth: int
diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py
index c4eb040c3..0d13edefd 100644
--- a/tests/spark_like_test.py
+++ b/tests/spark_like_test.py
@@ -369,6 +369,27 @@ def test_std(pyspark_constructor: Constructor) -> None:
     assert_equal_data(result, expected)
 
+
+# copied from tests/expr_and_series/var_test.py
+def test_var(pyspark_constructor: Constructor) -> None:
+    data = {"a": [1, 3, 2, None], "b": [4, 4, 6, None], "z": [7.0, 8, 9, None]}
+
+    expected_results = {
+        "a_ddof_1": [1.0],
+        "a_ddof_0": [0.6666666666666666],
+        "b_ddof_2": [2.666666666666667],
+        "z_ddof_0": [0.6666666666666666],
+    }
+
+    df = nw.from_native(pyspark_constructor(data))
+    result = df.select(
+        nw.col("a").var(ddof=1).alias("a_ddof_1"),
+        nw.col("a").var(ddof=0).alias("a_ddof_0"),
+        nw.col("b").var(ddof=2).alias("b_ddof_2"),
+        nw.col("z").var(ddof=0).alias("z_ddof_0"),
+    )
+    assert_equal_data(result, expected_results)
+
 
 # copied from tests/group_by_test.py
 def test_group_by_std(pyspark_constructor: Constructor) -> None:
     data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]}
@@ -441,3 +462,32 @@ def test_group_by_multiple_keys(pyspark_constructor: Constructor) -> None:
         "c_max": [7, 1],
     }
     assert_equal_data(result, expected)
+
+
+# copied from tests/group_by_test.py
+@pytest.mark.parametrize(
+    ("attr", "ddof"),
+    [
+        ("std", 0),
+        ("var", 0),
+        ("std", 2),
+        ("var", 2),
+    ],
+)
+def test_group_by_depth_1_std_var(
+    pyspark_constructor: Constructor,
+    attr: str,
+    ddof: int,
+) -> None:
+    data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
+    _pow = 0.5 if attr == "std" else 1
+    expected = {
+        "a": [1, 2],
+        "b": [
+            (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow,
+            (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow,
+        ],
+    }
+    expr = getattr(nw.col("b"), attr)(ddof=ddof)
+    result = nw.from_native(pyspark_constructor(data)).group_by("a").agg(expr).sort("a")
+    assert_equal_data(result, expected)
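
The `_std` and `_var` helpers introduced in narwhals/_spark_like/utils.py rescale PySpark's sample statistics (`F.stddev_samp` / `F.var_samp`, which are fixed at ddof=1) to an arbitrary ddof: both estimates divide the same sum of squared deviations, so var_ddof = var_samp * (n - 1) / (n - ddof) and std_ddof = std_samp * sqrt((n - 1) / (n - ddof)). A minimal standalone sketch of that identity in plain Python (no PySpark; the helper names below are illustrative, not part of the patch):

import math


def var_direct(values: list[float], ddof: int) -> float:
    # Definition: sum of squared deviations divided by (n - ddof).
    n = len(values)
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / (n - ddof)


def var_rescaled(values: list[float], ddof: int) -> float:
    # Mirror of the patch's approach: start from the ddof=1 estimate
    # (what F.var_samp returns) and rescale by (n - 1) / (n - ddof).
    n = len(values)
    return var_direct(values, ddof=1) * (n - 1) / (n - ddof)


data = [4.0, 5.0, 6.0, 0.0, 5.0, 5.0]
for ddof in (0, 1, 2):
    assert math.isclose(var_rescaled(data, ddof), var_direct(data, ddof))
    # The std correction is the same factor under a square root.
    assert math.isclose(
        math.sqrt(var_direct(data, ddof=1))
        * math.sqrt((len(data) - 1) / (len(data) - ddof)),
        math.sqrt(var_direct(data, ddof)),
    )

This is the same relationship exercised by test_group_by_depth_1_std_var above, whose expected values are computed directly from the sum of squared deviations with the requested ddof.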