Skip to content

Commit

Permalink
chore: move some logic to get_spark_function
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Dec 28, 2024
1 parent 9b7954e commit 2ed4b7b
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions narwhals/_spark_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@
from narwhals.typing import CompliantExpr


POLARS_TO_PYSPARK_AGGREGATIONS = {
"len": "count",
"std": _std,
"var": _var,
}
POLARS_TO_PYSPARK_AGGREGATIONS = {"len": "count"}


class SparkLikeLazyGroupBy:
Expand Down Expand Up @@ -83,7 +79,18 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
)


def get_spark_function(function_name: str) -> Column:
def get_spark_function(
    function_name: str, backend_version: tuple[int, ...], **kwargs: Any
) -> Column:
    """Resolve an aggregation name to a callable producing a Spark ``Column``.

    Most names map directly onto ``pyspark.sql.functions``. The exceptions are
    ``"std"`` and ``"var"``, which need extra context (``ddof``, the backend
    version, and the installed numpy version) and are therefore served by the
    project's own ``_std`` / ``_var`` helpers via ``functools.partial``.
    """
    if function_name not in {"std", "var"}:
        # Plain aggregations are looked up straight off pyspark's functions module.
        from pyspark.sql import functions as F  # noqa: N812

        return getattr(F, function_name)

    import numpy as np  # ignore-banned-import

    helper = _std if function_name == "std" else _var
    # Default ddof of 1 matches polars' sample std/var semantics.
    return partial(
        helper,
        ddof=kwargs.get("ddof", 1),
        backend_version=backend_version,
        np_version=parse_version(np.__version__),
    )
Expand Down Expand Up @@ -120,7 +127,9 @@ def agg_pyspark(
function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
agg_func = get_spark_function(function_name) # type: ignore[arg-type]
agg_func = get_spark_function(
function_name, backend_version=expr._backend_version, **expr._kwargs
)
simple_aggregations.update(
{output_name: agg_func(keys[0]) for output_name in expr._output_names}
)
Expand All @@ -137,18 +146,9 @@ def agg_pyspark(
pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
function_name, function_name
)

if function_name in {"std", "var"}:
import numpy as np # ignore-banned-import

agg_func = partial( # type: ignore[misc,operator]
pyspark_function, # type: ignore[arg-type]
ddof=expr._kwargs.get("ddof", 1),
backend_version=expr._backend_version,
np_version=parse_version(np.__version__),
)
else:
agg_func = get_spark_function(pyspark_function) # type: ignore[arg-type]
agg_func = get_spark_function(
pyspark_function, backend_version=expr._backend_version, **expr._kwargs
)

simple_aggregations.update(
{
Expand Down

0 comments on commit 2ed4b7b

Please sign in to comment.