Skip to content

Commit

Permalink
feat: add .over method for SparkLikeExpr (#1808)
Browse files Browse the repository at this point in the history
feat: add over for SparkLike
  • Loading branch information
FBruzzesi authored Jan 19, 2025
1 parent 0ec3c90 commit c39bd45
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
26 changes: 26 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,32 @@ def _n_unique(_input: Column) -> Column:

return self._from_call(_n_unique, "n_unique", returns_scalar=True)

def over(self: Self, keys: list[str]) -> Self:
    """Apply this expression as a window function partitioned by ``keys``.

    Each column produced by this expression is wrapped with
    ``Column.over(Window.partitionBy(*keys))`` when the returned expression
    is evaluated against a frame.

    Args:
        keys: Names of the columns passed to ``Window.partitionBy``.

    Returns:
        A new instance of the same class whose depth is one greater than
        this expression's, whose function name has ``"->over"`` appended,
        and which is marked as not returning a scalar.

    Raises:
        ValueError: If this expression has no output names (anonymous
            expressions such as ``nw.all()``).
    """
    if self._output_names is None:
        msg = (
            "Anonymous expressions are not supported in over.\n"
            "Instead of `nw.all()`, try using a named expression, such as "
            "`nw.col('a', 'b')`\n"
        )
        raise ValueError(msg)

    # Snapshot the keys now: `func` runs later, so holding on to the caller's
    # list would let a subsequent mutation silently change the partitioning.
    partition_keys = tuple(keys)

    def func(df: SparkLikeLazyFrame) -> list[Column]:
        # NOTE(review): imported locally, presumably so that pyspark is only
        # required once the expression is actually evaluated -- confirm.
        from pyspark.sql import Window

        # One window spec is enough; it is identical for every column.
        window = Window.partitionBy(*partition_keys)
        return [expr.over(window) for expr in self._call(df)]

    return self.__class__(
        func,
        depth=self._depth + 1,
        function_name=self._function_name + "->over",
        root_names=self._root_names,
        output_names=self._output_names,
        backend_version=self._backend_version,
        version=self._version,
        returns_scalar=False,
        # Store an independent list so neither the caller's list nor the
        # closure's snapshot is shared with the recorded kwargs.
        kwargs={**self._kwargs, "keys": list(partition_keys)},
    )

def is_null(self: Self) -> Self:
from pyspark.sql import functions as F # noqa: N812

Expand Down
14 changes: 7 additions & 7 deletions tests/expr_and_series/over_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "dask_lazy_p2" in str(constructor):
request.applymarker(pytest.mark.xfail)
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -43,25 +43,25 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -
def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "dask_lazy_p2" in str(constructor):
request.applymarker(pytest.mark.xfail)
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
expected = {
"a": ["a", "a", "b", "b", "b"],
"b": [1, 2, 3, 5, 3],
"c": [5, 4, 3, 2, 1],
"c_min": [5, 4, 1, 2, 1],
"b": [1, 2, 3, 3, 5],
"c": [5, 4, 3, 1, 2],
"c_min": [5, 4, 1, 1, 2],
}

result = df.with_columns(c_min=nw.col("c").min().over("a", "b"))
result = df.with_columns(c_min=nw.col("c").min().over("a", "b")).sort("a", "b")
assert_equal_data(result, expected)


def test_over_invalid(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "polars" in str(constructor):
request.applymarker(pytest.mark.xfail)
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand Down

0 comments on commit c39bd45

Please sign in to comment.