Patch notes (leading text ahead of the first "diff --git" header; patch tools skip it).

Adds an `over(keys)` method to the Spark-like expression class in
narwhals/_spark_like/expr.py. The method raises ValueError when the expression
has no output names, and otherwise wraps every column produced by the
expression in `Window.partitionBy(*keys)`. The pyspark xfail markers are removed
from the three `over` tests, and `test_over_multiple` now sorts its result by
("a", "b") and compares against expected data listed in that order.

NOTE(review): this patch was recovered from a copy whose newlines had been
collapsed. Indentation, blank lines and the two-space gap before `# noqa` were
reconstructed so that each hunk matches its @@ line counts -- confirm against
the upstream files before applying.

diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index d2263c7c9..98a1d2d5b 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -503,6 +503,32 @@ def _n_unique(_input: Column) -> Column:
 
         return self._from_call(_n_unique, "n_unique", returns_scalar=True)
 
+    def over(self: Self, keys: list[str]) -> Self:
+        if self._output_names is None:
+            msg = (
+                "Anonymous expressions are not supported in over.\n"
+                "Instead of `nw.all()`, try using a named expression, such as "
+                "`nw.col('a', 'b')`\n"
+            )
+            raise ValueError(msg)
+
+        def func(df: SparkLikeLazyFrame) -> list[Column]:
+            from pyspark.sql import Window
+
+            return [expr.over(Window.partitionBy(*keys)) for expr in self._call(df)]
+
+        return self.__class__(
+            func,
+            depth=self._depth + 1,
+            function_name=self._function_name + "->over",
+            root_names=self._root_names,
+            output_names=self._output_names,
+            backend_version=self._backend_version,
+            version=self._version,
+            returns_scalar=False,
+            kwargs={**self._kwargs, "keys": keys},
+        )
+
     def is_null(self: Self) -> Self:
         from pyspark.sql import functions as F  # noqa: N812
 
diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py
index 57ab4118f..2ccbe070d 100644
--- a/tests/expr_and_series/over_test.py
+++ b/tests/expr_and_series/over_test.py
@@ -25,7 +25,7 @@
 def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if "dask_lazy_p2" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
@@ -43,25 +43,25 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -
 def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if "dask_lazy_p2" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
     expected = {
         "a": ["a", "a", "b", "b", "b"],
-        "b": [1, 2, 3, 5, 3],
-        "c": [5, 4, 3, 2, 1],
-        "c_min": [5, 4, 1, 2, 1],
+        "b": [1, 2, 3, 3, 5],
+        "c": [5, 4, 3, 1, 2],
+        "c_min": [5, 4, 1, 1, 2],
     }
 
-    result = df.with_columns(c_min=nw.col("c").min().over("a", "b"))
+    result = df.with_columns(c_min=nw.col("c").min().over("a", "b")).sort("a", "b")
     assert_equal_data(result, expected)
 
 
 def test_over_invalid(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if "polars" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))