feat: add SparkLikeStrNamespace methods #1781

Merged: 3 commits, Jan 10, 2025
Changes from all commits
125 changes: 125 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -480,3 +480,128 @@ def skew(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

return self._from_call(F.skewness, "skew", returns_scalar=True)

@property
def str(self: Self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)


class SparkLikeExprStringNamespace:
def __init__(self: Self, expr: SparkLikeExpr) -> None:
self._compliant_expr = expr

def len_chars(self: Self) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
F.char_length,
"len",
returns_scalar=self._compliant_expr._returns_scalar,
)

def replace_all(
self: Self, pattern: str, value: str, *, literal: bool = False
) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

def func(_input: Column, pattern: str, value: str, *, literal: bool) -> Column:
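# F.replace substitutes `pattern` as a literal substring;
# F.regexp_replace treats it as a regular expression.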
replace_all_func = F.replace if literal else F.regexp_replace
return replace_all_func(_input, F.lit(pattern), F.lit(value))

return self._compliant_expr._from_call(
func,
"replace",
pattern=pattern,
value=value,
literal=literal,
returns_scalar=self._compliant_expr._returns_scalar,
)

def strip_chars(self: Self, characters: str | None) -> SparkLikeExpr:
import string

from pyspark.sql import functions as F # noqa: N812

def func(_input: Column, characters: str | None) -> Column:
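# F.btrim removes any leading and trailing characters that appear in `to_remove`.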
to_remove = characters if characters is not None else string.whitespace
return F.btrim(_input, F.lit(to_remove))

return self._compliant_expr._from_call(
func,
"strip",
characters=characters,
returns_scalar=self._compliant_expr._returns_scalar,
)

def starts_with(self: Self, prefix: str) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
lambda _input, prefix: F.startswith(_input, F.lit(prefix)),
"starts_with",
prefix=prefix,
returns_scalar=self._compliant_expr._returns_scalar,
)

def ends_with(self: Self, suffix: str) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
lambda _input, suffix: F.endswith(_input, F.lit(suffix)),
"ends_with",
suffix=suffix,
returns_scalar=self._compliant_expr._returns_scalar,
)

def contains(self: Self, pattern: str, *, literal: bool) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

def func(_input: Column, pattern: str, *, literal: bool) -> Column:
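# F.contains checks for a literal substring; F.regexp matches the
# pattern as a regular expression.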
contains_func = F.contains if literal else F.regexp
return contains_func(_input, F.lit(pattern))

return self._compliant_expr._from_call(
func,
"contains",
pattern=pattern,
literal=literal,
returns_scalar=self._compliant_expr._returns_scalar,
)

def slice(self: Self, offset: int, length: int | None = None) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

# From the docs: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.substring.html
# The position is not zero-based, but a 1-based index.
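# For example, slice(1, 2) on "fdas" maps to substr(2, 2) == "da", and
# slice(-2, None) maps to substr(4 - 1, 4) == "as" (the length is clamped
# to the end of the string), matching the expectations in slice_test.py.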
def func(_input: Column, offset: int, length: int | None) -> Column:
col_length = F.char_length(_input)

_offset = col_length + F.lit(offset + 1) if offset < 0 else F.lit(offset + 1)
_length = F.lit(length) if length is not None else col_length
return _input.substr(_offset, _length)

return self._compliant_expr._from_call(
func,
"slice",
offset=offset,
length=length,
returns_scalar=self._compliant_expr._returns_scalar,
)

def to_uppercase(self: Self) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
F.upper,
"to_uppercase",
returns_scalar=self._compliant_expr._returns_scalar,
)

def to_lowercase(self: Self) -> SparkLikeExpr:
from pyspark.sql import functions as F # noqa: N812

return self._compliant_expr._from_call(
F.lower,
"to_lowercase",
returns_scalar=self._compliant_expr._returns_scalar,
)
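
For orientation, a minimal usage sketch of how the new namespace surfaces through the narwhals expression API. This assumes a local Spark session and that `nw.from_native` accepts a pyspark DataFrame; the data and column names are illustrative, not taken from this PR:

import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark_df = spark.createDataFrame([("foo",), ("bars",)], ["a"])

df = nw.from_native(spark_df)  # wrapped as a narwhals lazy frame
result = df.select(
    nw.col("a").str.to_uppercase().alias("upper"),
    nw.col("a").str.len_chars().alias("n_chars"),
    nw.col("a").str.starts_with("fo").alias("starts_fo"),
)
result.to_native().show()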
1 change: 1 addition & 0 deletions pyproject.toml
@@ -167,6 +167,7 @@ filterwarnings = [
'ignore: unclosed <socket.socket',
'ignore:.*The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:pyspark',
'ignore:.*distutils Version classes are deprecated. Use packaging.version instead.*:DeprecationWarning:pyspark',
'ignore:.*is_datetime64tz_dtype is deprecated and will be removed in a future version.*:DeprecationWarning:pyspark',

]
xfail_strict = true
21 changes: 7 additions & 14 deletions tests/conftest.py
@@ -158,20 +158,13 @@ def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]:  # pragma: no cover
register(session.stop)

def _constructor(obj: Any) -> IntoFrame:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
r".*is_datetime64tz_dtype is deprecated and will be removed in a future version.*",
module="pyspark",
category=DeprecationWarning,
)
pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
return ( # type: ignore[no-any-return]
session.createDataFrame(pd_df)
.repartition(2)
.orderBy("index")
.drop("index")
)
pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
return ( # type: ignore[no-any-return]
session.createDataFrame(pd_df)
.repartition(2)
.orderBy("index")
.drop("index")
)

return _constructor

16 changes: 3 additions & 13 deletions tests/expr_and_series/str/contains_test.py
@@ -13,7 +13,7 @@
def test_contains_case_insensitive(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "cudf" in str(constructor) or "pyspark" in str(constructor):
if "cudf" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
@@ -40,12 +40,7 @@ def test_contains_series_case_insensitive(
assert_equal_data(result, expected)


def test_contains_case_sensitive(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_contains_case_sensitive(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("pets").str.contains("parrot|Dove").alias("default_match"))
expected = {
@@ -63,12 +58,7 @@ def test_contains_series_case_sensitive(constructor_eager: ConstructorEager) ->
assert_equal_data(result, expected)


def test_contains_literal(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_contains_literal(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.col("pets").str.contains("Parrot|dove").alias("default_match"),
7 changes: 1 addition & 6 deletions tests/expr_and_series/str/head_test.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
@@ -10,10 +8,7 @@
data = {"a": ["foo", "bars"]}


def test_str_head(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_str_head(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").str.head(3))
expected = {
2 changes: 1 addition & 1 deletion tests/expr_and_series/str/len_chars_test.py
@@ -11,7 +11,7 @@


def test_str_len_chars(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").str.len_chars())
4 changes: 1 addition & 3 deletions tests/expr_and_series/str/replace_test.py
@@ -123,9 +123,7 @@ def test_str_replace_all_expr(
literal: bool, # noqa: FBT001
expected: dict[str, list[str]],
) -> None:
if ("pyspark" in str(constructor)) or (
"duckdb" in str(constructor) and literal is False
):
if "duckdb" in str(constructor) and literal is False:
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(
4 changes: 0 additions & 4 deletions tests/expr_and_series/str/slice_test.py
@@ -17,15 +17,11 @@
[(1, 2, {"a": ["da", "df"]}), (-2, None, {"a": ["as", "as"]})],
)
def test_str_slice(
request: pytest.FixtureRequest,
constructor: Constructor,
offset: int,
length: int | None,
expected: Any,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result_frame = df.select(nw.col("a").str.slice(offset, length))
assert_equal_data(result_frame, expected)
12 changes: 2 additions & 10 deletions tests/expr_and_series/str/starts_with_ends_with_test.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
@@ -13,10 +11,7 @@
data = {"a": ["fdas", "edfas"]}


def test_ends_with(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_ends_with(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").str.ends_with("das"))
expected = {
@@ -34,10 +29,7 @@ def test_ends_with_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_starts_with(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_starts_with(constructor: Constructor) -> None:
df = nw.from_native(constructor(data)).lazy()
result = df.select(nw.col("a").str.starts_with("fda"))
expected = {
Expand Down
3 changes: 0 additions & 3 deletions tests/expr_and_series/str/strip_chars_test.py
@@ -20,13 +20,10 @@
],
)
def test_str_strip_chars(
request: pytest.FixtureRequest,
constructor: Constructor,
characters: str | None,
expected: Any,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result_frame = df.select(nw.col("a").str.strip_chars(characters))
assert_equal_data(result_frame, expected)
6 changes: 1 addition & 5 deletions tests/expr_and_series/str/tail_test.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
@@ -10,9 +8,7 @@
data = {"a": ["foo", "bars"]}


def test_str_tail(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_str_tail(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
expected = {"a": ["foo", "ars"]}

6 changes: 0 additions & 6 deletions tests/expr_and_series/str/to_uppercase_to_lowercase_test.py
@@ -30,9 +30,6 @@ def test_str_to_uppercase(
expected: dict[str, list[str]],
request: pytest.FixtureRequest,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

if any("ß" in s for value in data.values() for s in value) & (
constructor.__name__
in (
@@ -113,13 +110,10 @@ def test_str_to_uppercase_series(
],
)
def test_str_to_lowercase(
request: pytest.FixtureRequest,
constructor: Constructor,
data: dict[str, list[str]],
expected: dict[str, list[str]],
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result_frame = df.select(nw.col("a").str.to_lowercase())
assert_equal_data(result_frame, expected)
4 changes: 3 additions & 1 deletion tests/utils.py
@@ -65,7 +65,9 @@ def _sort_dict_by_key(
data_dict: dict[str, list[Any]], key: str
) -> dict[str, list[Any]]: # pragma: no cover
sort_list = data_dict[key]
sorted_indices = sorted(range(len(sort_list)), key=lambda i: sort_list[i])
sorted_indices = sorted(
range(len(sort_list)), key=lambda i: (sort_list[i] is None, sort_list[i])
)
return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()}
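
The tuple key sorts primarily on a "value is None" flag, so None entries land after all real values and are never compared directly against them. A small sketch of the behaviour (the example values are illustrative):

values = ["b", None, "a"]
order = sorted(range(len(values)), key=lambda i: (values[i] is None, values[i]))
# order == [2, 0, 1]: "a", "b", then None. A bare key (lambda i: values[i])
# would raise TypeError when comparing None with a str.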

