Merge branch 'main' into anti-duck
MarcoGorelli authored Jan 10, 2025
2 parents 4a3664a + 8229282 commit f769605
Showing 39 changed files with 5,479 additions and 5,093 deletions.
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
@@ -322,8 +322,8 @@ def concat_str(
         self,
         exprs: Iterable[IntoDaskExpr],
         *more_exprs: IntoDaskExpr,
-        separator: str = "",
-        ignore_nulls: bool = False,
+        separator: str,
+        ignore_nulls: bool,
     ) -> DaskExpr:
         parsed_exprs = [
             *parse_into_exprs(*exprs, namespace=self),
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/namespace.py
@@ -380,8 +380,8 @@ def concat_str(
         self,
         exprs: Iterable[IntoPandasLikeExpr],
         *more_exprs: IntoPandasLikeExpr,
-        separator: str = "",
-        ignore_nulls: bool = False,
+        separator: str,
+        ignore_nulls: bool,
     ) -> PandasLikeExpr:
         parsed_exprs = [
             *parse_into_exprs(*exprs, namespace=self),
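In both backends the change is the same: separator and ignore_nulls lose their defaults, so internal callers must now pass them explicitly rather than silently inheriting "" and False. A usage sketch (editor's addition, assuming the public narwhals.concat_str keeps its documented defaults and forwards them to the backend namespaces):

    import narwhals as nw

    # The public API still defaults separator="" and ignore_nulls=False;
    # those values are now forwarded explicitly to each backend namespace.
    full_name = nw.concat_str(
        nw.col("first_name"), nw.col("last_name"), separator=" "
    )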
139 changes: 139 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -480,3 +480,142 @@ def skew(self) -> Self:
         from pyspark.sql import functions as F  # noqa: N812
 
         return self._from_call(F.skewness, "skew", returns_scalar=True)
+
+    def n_unique(self: Self) -> Self:
+        from pyspark.sql import functions as F  # noqa: N812
+        from pyspark.sql.types import IntegerType
+
+        def _n_unique(_input: Column) -> Column:
+            return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))
+
+        return self._from_call(_n_unique, "n_unique", returns_scalar=True)
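Why the extra max(isnull(...)) term (editor's note): count_distinct skips nulls, while Polars' n_unique counts null as a distinct value, so the cast-and-max adds exactly 1 whenever the column contains any null. A standalone sketch, assuming a local SparkSession:

    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("a",), ("a",), (None,)], ["x"])
    # count_distinct("x") is 1; max(isnull(...)) contributes 1 because a null
    # is present, giving n_unique == 2, matching Polars semantics.
    df.select(
        (
            F.count_distinct("x") + F.max(F.isnull("x").cast(IntegerType()))
        ).alias("n_unique")
    ).show()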
+
+    def is_null(self: Self) -> Self:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        return self._from_call(F.isnull, "is_null", returns_scalar=self._returns_scalar)
+
+    @property
+    def str(self: Self) -> SparkLikeExprStringNamespace:
+        return SparkLikeExprStringNamespace(self)
+
+
+class SparkLikeExprStringNamespace:
+    def __init__(self: Self, expr: SparkLikeExpr) -> None:
+        self._compliant_expr = expr
+
+    def len_chars(self: Self) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        return self._compliant_expr._from_call(
+            F.char_length,
+            "len",
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
+
+    def replace_all(
+        self: Self, pattern: str, value: str, *, literal: bool = False
+    ) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        def func(_input: Column, pattern: str, value: str, *, literal: bool) -> Column:
+            replace_all_func = F.replace if literal else F.regexp_replace
+            return replace_all_func(_input, F.lit(pattern), F.lit(value))
+
+        return self._compliant_expr._from_call(
+            func,
+            "replace",
+            pattern=pattern,
+            value=value,
+            literal=literal,
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
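The literal flag picks between plain substring replacement (F.replace, Spark 3.5+) and regex replacement (F.regexp_replace). A quick sketch of the difference (editor's addition, reusing spark from the sketch above):

    df = spark.createDataFrame([("abc",)], ["s"])
    # literal=True -> F.replace: "a.c" is a plain substring, so nothing matches.
    df.select(F.replace(F.col("s"), F.lit("a.c"), F.lit("-"))).show()  # abc
    # literal=False -> F.regexp_replace: "a.c" is a regex and matches "abc".
    df.select(F.regexp_replace(F.col("s"), "a.c", "-")).show()  # -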
+
+    def strip_chars(self: Self, characters: str | None) -> SparkLikeExpr:
+        import string
+
+        from pyspark.sql import functions as F  # noqa: N812
+
+        def func(_input: Column, characters: str | None) -> Column:
+            to_remove = characters if characters is not None else string.whitespace
+            return F.btrim(_input, F.lit(to_remove))
+
+        return self._compliant_expr._from_call(
+            func,
+            "strip",
+            characters=characters,
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
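F.btrim treats its second argument as a character set, stripping any of those characters from both ends rather than matching a prefix or suffix string, which is also how Polars' strip_chars behaves; with characters=None the set falls back to string.whitespace. Sketch (editor's addition):

    df = spark.createDataFrame([("xxhello worldxx",)], ["s"])
    # Every leading/trailing character from the set {"x"} is removed.
    df.select(F.btrim(F.col("s"), F.lit("x"))).show()  # hello world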
+
+    def starts_with(self: Self, prefix: str) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        return self._compliant_expr._from_call(
+            lambda _input, prefix: F.startswith(_input, F.lit(prefix)),
+            "starts_with",
+            prefix=prefix,
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
+
+    def ends_with(self: Self, suffix: str) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        return self._compliant_expr._from_call(
+            lambda _input, suffix: F.endswith(_input, F.lit(suffix)),
+            "ends_with",
+            suffix=suffix,
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
+
+    def contains(self: Self, pattern: str, *, literal: bool) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        def func(_input: Column, pattern: str, *, literal: bool) -> Column:
+            contains_func = F.contains if literal else F.regexp
+            return contains_func(_input, F.lit(pattern))
+
+        return self._compliant_expr._from_call(
+            func,
+            "contains",
+            pattern=pattern,
+            literal=literal,
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
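The same literal/regex split as replace_all: F.contains checks for a literal substring, while F.regexp tests a regular-expression match (both Spark 3.5+). Sketch (editor's addition):

    df = spark.createDataFrame([("abc",)], ["s"])
    df.select(F.contains(F.col("s"), F.lit("a.c"))).show()  # false: no literal "a.c"
    df.select(F.regexp(F.col("s"), F.lit("a.c"))).show()    # true: regex matches "abc"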
+
+    def slice(self: Self, offset: int, length: int | None = None) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        # From the docs: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.substring.html
+        # The position is not zero based, but 1 based index.
+        def func(_input: Column, offset: int, length: int | None) -> Column:
+            col_length = F.char_length(_input)
+
+            _offset = col_length + F.lit(offset + 1) if offset < 0 else F.lit(offset + 1)
+            _length = F.lit(length) if length is not None else col_length
+            return _input.substr(_offset, _length)
+
+        return self._compliant_expr._from_call(
+            func,
+            "slice",
+            offset=offset,
+            length=length,
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
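Worked example of the offset arithmetic (editor's note): for slice(-3) on "hello" (length 5), _offset becomes 5 + (-3 + 1) = 3 in 1-based terms and _length defaults to the full column length; substr simply stops at the end of the string, returning "llo":

    df = spark.createDataFrame([("hello",)], ["s"])
    # substr from 1-based position 3 with a generous length yields "llo".
    df.select(F.col("s").substr(F.lit(3), F.char_length(F.col("s")))).show()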
+
+    def to_uppercase(self: Self) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        return self._compliant_expr._from_call(
+            F.upper,
+            "to_uppercase",
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
+
+    def to_lowercase(self: Self) -> SparkLikeExpr:
+        from pyspark.sql import functions as F  # noqa: N812
+
+        return self._compliant_expr._from_call(
+            F.lower,
+            "to_lowercase",
+            returns_scalar=self._compliant_expr._returns_scalar,
+        )
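Taken together, these methods back the nw.col(...).str.* expressions on the PySpark backend. An end-to-end sketch (editor's addition; assumes this narwhals release accepts PySpark DataFrames in narwhals.from_native):

    import narwhals as nw

    df_native = spark.createDataFrame([(" Ada ",)], ["name"])
    result = nw.from_native(df_native).with_columns(
        nw.col("name").str.strip_chars(None).str.to_uppercase()
    )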
6 changes: 1 addition & 5 deletions narwhals/_spark_like/group_by.py
@@ -128,11 +128,7 @@ def agg_pyspark(
         if expr._output_names is None:  # pragma: no cover
             msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
             raise AssertionError(msg)
-
-        function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
-            expr._function_name, expr._function_name
-        )
-        agg_func = get_spark_function(function_name, **expr._kwargs)
+        agg_func = get_spark_function(expr._function_name, **expr._kwargs)
         simple_aggregations.update(
             {output_name: agg_func(keys[0]) for output_name in expr._output_names}
         )
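With the lookup removed here, the Polars-to-PySpark aggregation-name translation via POLARS_TO_PYSPARK_AGGREGATIONS presumably moves inside get_spark_function itself (not shown in this hunk), so agg_pyspark can pass expr._function_name straight through.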
Expand Down