feat: SparkLikeNamespace methods #1779

Merged: 7 commits merged on Jan 10, 2025
Changes from 1 commit
feat: spark like namespace functions
FBruzzesi committed Jan 9, 2025
commit 4eaa78df06d34d4437a93ab03a72b79a4b38de4c
9 changes: 9 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -462,3 +462,12 @@ def skew(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.skewness, "skew", returns_scalar=True)

def n_unique(self: Self) -> Self:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import IntegerType

def _n_unique(_input: Column) -> Column:
return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))
Member Author:
Highly inspired by the DuckDB implementation 😂😉


return self._from_call(_n_unique, "n_unique", returns_scalar=self._returns_scalar)
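
For context: the trick above counts distinct non-null values with F.count_distinct and adds 1 whenever a null is present (F.max over the isnull flag), so nulls count as one extra unique value, matching Polars semantics; this mirrors the DuckDB-backend approach the comment alludes to. A minimal standalone sketch of the same idea in plain PySpark (session, data, and column names are illustrative, not part of this PR):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # noqa: N812
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (1,), (2,), (None,)], ["a"])

# count_distinct ignores nulls; max(isnull(...)) is 1 if any null exists, 0 otherwise,
# so the sum reports 3 unique values here: 1, 2 and null.
n_unique = F.count_distinct(F.col("a")) + F.max(F.isnull(F.col("a")).cast(IntegerType()))
df.select(n_unique.alias("n_unique")).show()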
6 changes: 1 addition & 5 deletions narwhals/_spark_like/group_by.py
@@ -128,11 +128,7 @@ def agg_pyspark(
if expr._output_names is None: # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
agg_func = get_spark_function(function_name, **expr._kwargs)
agg_func = get_spark_function(expr._function_name, **expr._kwargs)
Member Author:

I could not manage to make nw.len() work in the group_by context

simple_aggregations.update(
{output_name: agg_func(keys[0]) for output_name in expr._output_names}
)
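
The comment above refers to queries of the following shape; a hedged sketch of the narwhals API being targeted (data is illustrative; with a pandas frame this runs today, while the PySpark backend could not yet translate nw.len() inside group_by/agg at the time of this commit):

import pandas as pd

import narwhals.stable.v1 as nw

# One row count per group, Polars-style; this is the pattern the PySpark
# group_by translation does not handle yet.
df = nw.from_native(pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}))
result = df.group_by("a").agg(nw.len())
print(nw.to_native(result))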
156 changes: 135 additions & 21 deletions narwhals/_spark_like/namespace.py
@@ -3,18 +3,21 @@
import operator
from functools import reduce
from typing import TYPE_CHECKING
from typing import Iterable
from typing import Literal

from narwhals._expression_parsing import combine_root_names
from narwhals._expression_parsing import parse_into_exprs
from narwhals._expression_parsing import reduce_output_names
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.utils import get_column_name
from narwhals.typing import CompliantNamespace

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import DataFrame

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import IntoSparkLikeExpr
from narwhals.dtypes import DType
from narwhals.utils import Version
@@ -43,26 +46,6 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.and_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="all_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def col(self, *column_names: str) -> SparkLikeExpr:
return SparkLikeExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
@@ -90,6 +73,64 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

def len(self) -> SparkLikeExpr:
def func(_: SparkLikeLazyFrame) -> list[Column]:
import pyspark.sql.functions as F # noqa: N812

return [F.count("*").alias("len")]

return SparkLikeExpr( # type: ignore[abstract]
func,
depth=0,
function_name="len",
root_names=None,
output_names=["len"],
returns_scalar=True,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)

def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.and_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="all_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def any_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [reduce(operator.or_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="any_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def sum_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

@@ -116,3 +157,76 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
version=self._version,
kwargs={"exprs": exprs},
)

def mean_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql.types import IntegerType

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [
(
reduce(operator.add, (F.coalesce(col, F.lit(0)) for col in cols))
/ reduce(
operator.add,
(col.isNotNull().cast(IntegerType()) for col in cols),
)
).alias(col_name)
]

return SparkLikeExpr( # type: ignore[abstract]
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="mean_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def concat(
self,
items: Iterable[SparkLikeLazyFrame],
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> SparkLikeLazyFrame:
dfs: list[DataFrame] = [item._native_frame for item in items]
if how == "horizontal":
msg = (
"Horizontal concatenation is not supported for LazyFrame backed by "
"a PySpark DataFrame"
)
raise NotImplementedError(msg)

if how == "vertical":
cols_0 = dfs[0].columns
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not ((len(cols_current) == len(cols_0)) and (cols_current == cols_0)):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)

return SparkLikeLazyFrame(
native_dataframe=reduce(lambda x, y: x.union(y), dfs),
backend_version=self._backend_version,
version=self._version,
)

if how == "diagonal":
return SparkLikeLazyFrame(
native_dataframe=reduce(
lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
),
backend_version=self._backend_version,
version=self._version,
)
raise NotImplementedError
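
Taken together, these namespace additions are what the updated tests below exercise against the PySpark backend. A minimal usage sketch (a hedged example, assuming a local SparkSession and the narwhals version from this PR; data and column names are illustrative):

import narwhals.stable.v1 as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native = spark.createDataFrame(
    [(False, False, 2, 10), (False, True, 4, 20), (True, True, 6, 30)],
    ["a", "b", "x", "y"],
)
lf = nw.from_native(native)  # PySpark DataFrames are wrapped as narwhals LazyFrames

# len() and the horizontal reducers now route through SparkLikeNamespace.
nw.to_native(lf.select(nw.len())).show()
nw.to_native(
    lf.select(
        nw.any_horizontal(nw.col("a"), nw.col("b")),
        nw.mean_horizontal(nw.col("x"), nw.col("y")),
    )
).show()

# concat: how="vertical" maps to DataFrame.union (column names must match),
# how="diagonal" to unionByName(allowMissingColumns=True).
extra = nw.from_native(spark.createDataFrame([(True, True, 8, 40)], ["a", "b", "x", "y"]))
stacked = nw.concat([lf, extra], how="vertical")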
10 changes: 2 additions & 8 deletions tests/expr_and_series/any_horizontal_test.py
@@ -11,11 +11,7 @@

@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_anyh(
request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
@@ -27,9 +23,7 @@ def test_anyh(
assert_equal_data(result, expected)


def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_anyh_all(constructor: Constructor) -> None:
data = {
"a": [False, False, True],
"b": [False, True, True],
5 changes: 1 addition & 4 deletions tests/expr_and_series/len_test.py
@@ -34,10 +34,7 @@ def test_len_chaining(
assert_equal_data(df, expected)


def test_namespace_len(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_namespace_len(constructor: Constructor) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).select(
nw.len(), a=nw.len()
)
2 changes: 1 addition & 1 deletion tests/expr_and_series/mean_horizontal_test.py
@@ -23,7 +23,7 @@ def test_meanh(


def test_meanh_all(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"a": [2, 4, 6], "b": [10, 20, 30]}
df = nw.from_native(constructor(data))
6 changes: 1 addition & 5 deletions tests/expr_and_series/n_unique_test.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
@@ -13,9 +11,7 @@
}


def test_n_unique(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_n_unique(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.all().n_unique())
expected = {"a": [3], "b": [4]}
5 changes: 3 additions & 2 deletions tests/expr_and_series/sort_test.py
@@ -5,6 +5,7 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

@@ -21,9 +22,9 @@
],
)
def test_sort_expr(
constructor_eager: ConstructorEager, descending: Any, nulls_last: Any, expected: Any
constructor: Constructor, descending: Any, nulls_last: Any, expected: Any
) -> None:
df = nw.from_native(constructor_eager(data), eager_only=True)
df = nw.from_native(constructor(data))
result = df.select(
"a",
nw.col("b").sort(descending=descending, nulls_last=nulls_last),
9 changes: 2 additions & 7 deletions tests/frame/concat_test.py
@@ -32,12 +32,7 @@ def test_concat_horizontal(
nw.concat([])


def test_concat_vertical(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_concat_vertical(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_left = (
nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z")
@@ -68,7 +63,7 @@ def test_concat_vertical(
def test_concat_diagonal(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data_1 = {"a": [1, 3], "b": [4, 6]}
data_2 = {"a": [100, 200], "z": ["x", "y"]}
5 changes: 1 addition & 4 deletions tests/group_by_test.py
@@ -276,9 +276,6 @@ def test_key_with_nulls(
# TODO(unassigned): Modin flaky here?
request.applymarker(pytest.mark.skip)

if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

context = (
pytest.raises(NotImplementedError, match="null values")
if ("pandas_constructor" in str(constructor) and PANDAS_VERSION < (1, 1, 0))
@@ -300,7 +297,7 @@ def test_key_with_nulls(
def test_key_with_nulls_ignored(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"b": [4, 5, None], "a": [1, 2, 3]}
result = (