From 2ed4b7b2a040533b16240b5166f608e9c18471a6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 28 Dec 2024 15:52:47 +0000 Subject: [PATCH] chore: move some logic to get_spark_function --- narwhals/_spark_like/group_by.py | 38 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index 985dc3168..f07101a7e 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -23,11 +23,7 @@ from narwhals.typing import CompliantExpr -POLARS_TO_PYSPARK_AGGREGATIONS = { - "len": "count", - "std": _std, - "var": _var, -} +POLARS_TO_PYSPARK_AGGREGATIONS = {"len": "count"} class SparkLikeLazyGroupBy: @@ -83,7 +79,18 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame: ) -def get_spark_function(function_name: str) -> Column: +def get_spark_function( + function_name: str, backend_version: tuple[int, ...], **kwargs: Any +) -> Column: + if function_name in {"std", "var"}: + import numpy as np # ignore-banned-import + + return partial( + _std if function_name == "std" else _var, + ddof=kwargs.get("ddof", 1), + backend_version=backend_version, + np_version=parse_version(np.__version__), + ) from pyspark.sql import functions as F # noqa: N812 return getattr(F, function_name) @@ -120,7 +127,9 @@ def agg_pyspark( function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get( expr._function_name, expr._function_name ) - agg_func = get_spark_function(function_name) # type: ignore[arg-type] + agg_func = get_spark_function( + function_name, backend_version=expr._backend_version, **expr._kwargs + ) simple_aggregations.update( {output_name: agg_func(keys[0]) for output_name in expr._output_names} ) @@ -137,18 +146,9 @@ def agg_pyspark( pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get( function_name, function_name ) - - if function_name in {"std", "var"}: - import numpy as np # ignore-banned-import - - agg_func = partial( # type: ignore[misc,operator] - pyspark_function, # type: ignore[arg-type] - ddof=expr._kwargs.get("ddof", 1), - backend_version=expr._backend_version, - np_version=parse_version(np.__version__), - ) - else: - agg_func = get_spark_function(pyspark_function) # type: ignore[arg-type] + agg_func = get_spark_function( + pyspark_function, backend_version=expr._backend_version, **expr._kwargs + ) simple_aggregations.update( {