fix: pyspark group by with kwargs (#1665)
---------

Co-authored-by: Marco Gorelli <[email protected]>
FBruzzesi and MarcoGorelli authored Dec 28, 2024
1 parent 30d0249 commit 74e8f95
Showing 5 changed files with 152 additions and 26 deletions.
38 changes: 24 additions & 14 deletions narwhals/_spark_like/expr.py
@@ -233,24 +233,34 @@ def _min(_input: Column) -> Column:

        return self._from_call(_min, "min", returns_scalar=True)

-    def std(self, ddof: int) -> Self:
+    def std(self: Self, ddof: int) -> Self:
+        from functools import partial

        import numpy as np  # ignore-banned-import

-        def _std(_input: Column, ddof: int) -> Column:  # pragma: no cover
-            if self._backend_version < (3, 5) or parse_version(np.__version__) > (2, 0):
-                from pyspark.sql import functions as F  # noqa: N812
+        from narwhals._spark_like.utils import _std

+        func = partial(
+            _std,
+            ddof=ddof,
+            backend_version=self._backend_version,
+            np_version=parse_version(np.__version__),
+        )

+        return self._from_call(func, "std", returns_scalar=True, ddof=ddof)

-                if ddof == 1:
-                    return F.stddev_samp(_input)
+    def var(self: Self, ddof: int) -> Self:
+        from functools import partial

-                n_rows = F.count(_input)
-                return F.stddev_samp(_input) * F.sqrt((n_rows - 1) / (n_rows - ddof))
+        import numpy as np  # ignore-banned-import

-            from pyspark.pandas.spark.functions import stddev
+        from narwhals._spark_like.utils import _var

-            return stddev(_input, ddof=ddof)
+        func = partial(
+            _var,
+            ddof=ddof,
+            backend_version=self._backend_version,
+            np_version=parse_version(np.__version__),
+        )

-        expr = self._from_call(_std, "std", returns_scalar=True, ddof=ddof)
-        if ddof != 1:
-            expr._depth += 1
-        return expr
+        return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
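
The change above swaps a closure that captured `self` for a `functools.partial` over a shared helper, so `ddof`, the PySpark version, and the NumPy version travel as explicit keyword arguments. A minimal, self-contained sketch of that pattern (the `sample_std` helper below is illustrative, not the library's code):

from functools import partial


def sample_std(values: list[float], ddof: int, backend_version: tuple[int, ...]) -> float:
    # Stand-in for a backend helper: everything it needs arrives as keyword
    # arguments instead of being closed over from an enclosing method.
    mean = sum(values) / len(values)
    return (sum((v - mean) ** 2 for v in values) / (len(values) - ddof)) ** 0.5


# Bind the configuration up front, as the new std/var methods do, then hand the
# resulting one-argument callable to whatever applies it per column.
func = partial(sample_std, ddof=0, backend_version=(3, 5))
print(func([4.0, 5.0, 6.0]))  # ~0.816, the population (ddof=0) standard deviation
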
49 changes: 37 additions & 12 deletions narwhals/_spark_like/group_by.py
@@ -1,13 +1,17 @@
from __future__ import annotations

from copy import copy
+from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Sequence

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
+from narwhals._spark_like.utils import _std
+from narwhals._spark_like.utils import _var
+from narwhals.utils import parse_version
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
@@ -18,10 +22,8 @@
    from narwhals._spark_like.typing import IntoSparkLikeExpr
    from narwhals.typing import CompliantExpr

-POLARS_TO_PYSPARK_AGGREGATIONS = {
-    "len": "count",
-    "std": "stddev",
-}

+POLARS_TO_PYSPARK_AGGREGATIONS = {"len": "count"}


class SparkLikeLazyGroupBy:
@@ -77,7 +79,18 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
        )


-def get_spark_function(function_name: str) -> Column:
+def get_spark_function(
+    function_name: str, backend_version: tuple[int, ...], **kwargs: Any
+) -> Column:
+    if function_name in {"std", "var"}:
+        import numpy as np  # ignore-banned-import

+        return partial(
+            _std if function_name == "std" else _var,
+            ddof=kwargs.get("ddof", 1),
+            backend_version=backend_version,
+            np_version=parse_version(np.__version__),
+        )
    from pyspark.sql import functions as F  # noqa: N812

    return getattr(F, function_name)
@@ -114,9 +127,12 @@ def agg_pyspark(
            function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
                expr._function_name, expr._function_name
            )
-            for output_name in expr._output_names:
-                agg_func = get_spark_function(function_name)
-                simple_aggregations[output_name] = agg_func(keys[0])
+            agg_func = get_spark_function(
+                function_name, backend_version=expr._backend_version, **expr._kwargs
+            )
+            simple_aggregations.update(
+                {output_name: agg_func(keys[0]) for output_name in expr._output_names}
+            )
            continue

        # e.g. agg(nw.mean('a')) # noqa: ERA001
@@ -127,11 +143,20 @@
            raise AssertionError(msg)

        function_name = remove_prefix(expr._function_name, "col->")
-        function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(function_name, function_name)
+        pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
+            function_name, function_name
+        )
+        agg_func = get_spark_function(
+            pyspark_function, backend_version=expr._backend_version, **expr._kwargs
+        )

+        simple_aggregations.update(
+            {
+                output_name: agg_func(root_name)
+                for root_name, output_name in zip(expr._root_names, expr._output_names)
+            }
+        )

-        for root_name, output_name in zip(expr._root_names, expr._output_names):
-            agg_func = get_spark_function(function_name)
-            simple_aggregations[output_name] = agg_func(root_name)
    agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
    try:
        result_simple = grouped.agg(*agg_columns)
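
The group_by path is where the fix centres: `get_spark_function` used to receive only the function name, so keyword arguments such as `ddof` never reached the aggregation and `std` was always translated to `F.stddev` (sample, ddof=1). With `**expr._kwargs` forwarded, the requested `ddof` is now respected. A hedged sketch of the call pattern the new tests exercise (assuming `spark_df` is an existing `pyspark.sql.DataFrame` with columns "a" and "b"; session setup omitted):

import narwhals as nw

lf = nw.from_native(spark_df)

# Before this commit the ddof=0 below was ignored inside group_by/agg; it is
# now forwarded through get_spark_function to the shared _std helper.
result = lf.group_by("a").agg(nw.col("b").std(ddof=0)).sort("a")
print(result.to_native())  # back to a plain PySpark DataFrame
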
40 changes: 40 additions & 0 deletions narwhals/_spark_like/utils.py
@@ -118,3 +118,43 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
            return column_result.over(Window.partitionBy(F.lit(1)))
        return column_result
    return obj
+
+
+def _std(
+    _input: Column,
+    ddof: int,
+    backend_version: tuple[int, ...],
+    np_version: tuple[int, ...],
+) -> Column:
+    if backend_version < (3, 5) or np_version > (2, 0):
+        from pyspark.sql import functions as F  # noqa: N812
+
+        if ddof == 1:
+            return F.stddev_samp(_input)
+
+        n_rows = F.count(_input)
+        return F.stddev_samp(_input) * F.sqrt((n_rows - 1) / (n_rows - ddof))
+
+    from pyspark.pandas.spark.functions import stddev
+
+    return stddev(_input, ddof=ddof)
+
+
+def _var(
+    _input: Column,
+    ddof: int,
+    backend_version: tuple[int, ...],
+    np_version: tuple[int, ...],
+) -> Column:
+    if backend_version < (3, 5) or np_version > (2, 0):
+        from pyspark.sql import functions as F  # noqa: N812
+
+        if ddof == 1:
+            return F.var_samp(_input)
+
+        n_rows = F.count(_input)
+        return F.var_samp(_input) * (n_rows - 1) / (n_rows - ddof)
+
+    from pyspark.pandas.spark.functions import var
+
+    return var(_input, ddof=ddof)
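
Both helpers rescale PySpark's built-in sample statistics rather than recomputing from scratch: `var_samp`/`stddev_samp` divide the squared deviations by n - 1, and multiplying by (n - 1)/(n - ddof) (with a square root in the standard-deviation case) converts that to any requested `ddof`. A quick standalone check of the identity, using only the standard library:

import math
import statistics

xs = [4.0, 5.0, 6.0]
n = len(xs)
mean = sum(xs) / n

for ddof in (0, 1, 2):
    direct = sum((x - mean) ** 2 for x in xs) / (n - ddof)      # variance by definition
    rescaled = statistics.variance(xs) * (n - 1) / (n - ddof)   # var_samp * (n - 1) / (n - ddof)
    assert math.isclose(direct, rescaled)
    assert math.isclose(math.sqrt(direct), statistics.stdev(xs) * math.sqrt((n - 1) / (n - ddof)))
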
1 change: 1 addition & 0 deletions narwhals/typing.py
@@ -68,6 +68,7 @@ def __narwhals_namespace__(self) -> Any: ...

class CompliantExpr(Protocol, Generic[CompliantSeriesT_co]):
    _implementation: Implementation
+    _backend_version: tuple[int, ...]
    _output_names: list[str] | None
    _root_names: list[str] | None
    _depth: int
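
Declaring `_backend_version` on the `CompliantExpr` protocol is what lets `agg_pyspark` read `expr._backend_version` from any backend's expression and still satisfy the type checker. A toy illustration of the structural-typing idea (not the library's actual protocol):

from typing import Protocol


class HasBackendVersion(Protocol):
    # Any object exposing this attribute structurally satisfies the protocol,
    # regardless of which concrete backend class it comes from.
    _backend_version: tuple[int, ...]


def version_string(expr: HasBackendVersion) -> str:
    return ".".join(str(part) for part in expr._backend_version)
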
50 changes: 50 additions & 0 deletions tests/spark_like_test.py
@@ -369,6 +369,27 @@ def test_std(pyspark_constructor: Constructor) -> None:
    assert_equal_data(result, expected)


+# copied from tests/expr_and_series/var_test.py
+def test_var(pyspark_constructor: Constructor) -> None:
+    data = {"a": [1, 3, 2, None], "b": [4, 4, 6, None], "z": [7.0, 8, 9, None]}
+
+    expected_results = {
+        "a_ddof_1": [1.0],
+        "a_ddof_0": [0.6666666666666666],
+        "b_ddof_2": [2.666666666666667],
+        "z_ddof_0": [0.6666666666666666],
+    }
+
+    df = nw.from_native(pyspark_constructor(data))
+    result = df.select(
+        nw.col("a").var(ddof=1).alias("a_ddof_1"),
+        nw.col("a").var(ddof=0).alias("a_ddof_0"),
+        nw.col("b").var(ddof=2).alias("b_ddof_2"),
+        nw.col("z").var(ddof=0).alias("z_ddof_0"),
+    )
+    assert_equal_data(result, expected_results)
+
+
# copied from tests/group_by_test.py
def test_group_by_std(pyspark_constructor: Constructor) -> None:
    data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]}
@@ -441,3 +462,32 @@ def test_group_by_multiple_keys(pyspark_constructor: Constructor) -> None:
        "c_max": [7, 1],
    }
    assert_equal_data(result, expected)
+
+
+# copied from tests/group_by_test.py
+@pytest.mark.parametrize(
+    ("attr", "ddof"),
+    [
+        ("std", 0),
+        ("var", 0),
+        ("std", 2),
+        ("var", 2),
+    ],
+)
+def test_group_by_depth_1_std_var(
+    pyspark_constructor: Constructor,
+    attr: str,
+    ddof: int,
+) -> None:
+    data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
+    _pow = 0.5 if attr == "std" else 1
+    expected = {
+        "a": [1, 2],
+        "b": [
+            (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow,
+            (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow,
+        ],
+    }
+    expr = getattr(nw.col("b"), attr)(ddof=ddof)
+    result = nw.from_native(pyspark_constructor(data)).group_by("a").agg(expr).sort("a")
+    assert_equal_data(result, expected)
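
For reference, the expected values in `test_var` can be checked by hand with the same formula the helpers implement, after dropping the nulls (which is how the PySpark aggregates treat them):

def variance(xs: list[float], ddof: int) -> float:
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / (len(xs) - ddof)


assert abs(variance([1, 3, 2], 1) - 1.0) < 1e-9                   # a_ddof_1
assert abs(variance([1, 3, 2], 0) - 0.6666666666666666) < 1e-9    # a_ddof_0
assert abs(variance([4, 4, 6], 2) - 2.666666666666667) < 1e-9     # b_ddof_2
assert abs(variance([7.0, 8, 9], 0) - 0.6666666666666666) < 1e-9  # z_ddof_0
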
