Skip to content

Commit

Permalink
chore: register kwargs for CompliantExpr (#1614)
Browse files: browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Dec 21, 2024
1 parent 9c69d7d commit 5def702
Show file tree
Hide file tree
Showing 15 changed files with 444 additions and 259 deletions.
81 changes: 51 additions & 30 deletions narwhals/_arrow/expr.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
output_names: list[str] | None,
backend_version: tuple[int, ...],
version: Version,
kwargs: dict[str, Any],
) -> None:
self._call = call
self._depth = depth
Expand All @@ -49,6 +50,7 @@ def __init__(
self._implementation = Implementation.PYARROW
self._backend_version = backend_version
self._version = version
self._kwargs = kwargs

def __repr__(self: Self) -> str: # pragma: no cover
return (
Expand Down Expand Up @@ -97,6 +99,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=list(column_names),
backend_version=backend_version,
version=version,
kwargs={},
)

@classmethod
Expand Down Expand Up @@ -127,6 +130,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=None,
backend_version=backend_version,
version=version,
kwargs={},
)

def __narwhals_namespace__(self: Self) -> ArrowNamespace:
Expand Down Expand Up @@ -171,49 +175,49 @@ def __ror__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return other.__or__(self) # type: ignore[return-value]

def __add__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__add__", other)
return reuse_series_implementation(self, "__add__", other=other)

def __radd__(self: Self, other: ArrowExpr | Any) -> Self:
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__add__(self) # type: ignore[return-value]

def __sub__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__sub__", other)
return reuse_series_implementation(self, "__sub__", other=other)

def __rsub__(self: Self, other: ArrowExpr | Any) -> Self:
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__sub__(self) # type: ignore[return-value]

def __mul__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__mul__", other)
return reuse_series_implementation(self, "__mul__", other=other)

def __rmul__(self: Self, other: ArrowExpr | Any) -> Self:
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__mul__(self) # type: ignore[return-value]

def __pow__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__pow__", other)
return reuse_series_implementation(self, "__pow__", other=other)

def __rpow__(self: Self, other: ArrowExpr | Any) -> Self:
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__pow__(self) # type: ignore[return-value]

def __floordiv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__floordiv__", other)
return reuse_series_implementation(self, "__floordiv__", other=other)

def __rfloordiv__(self: Self, other: ArrowExpr | Any) -> Self:
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__floordiv__(self) # type: ignore[return-value]

def __truediv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__truediv__", other)
return reuse_series_implementation(self, "__truediv__", other=other)

def __rtruediv__(self: Self, other: ArrowExpr | Any) -> Self:
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__truediv__(self) # type: ignore[return-value]

def __mod__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__mod__", other)
return reuse_series_implementation(self, "__mod__", other=other)

def __rmod__(self: Self, other: ArrowExpr | Any) -> Self:
other = self.__narwhals_namespace__().lit(other, dtype=None)
Expand Down Expand Up @@ -252,7 +256,7 @@ def skew(self: Self) -> Self:
return reuse_series_implementation(self, "skew", returns_scalar=True)

def cast(self: Self, dtype: DType) -> Self:
return reuse_series_implementation(self, "cast", dtype)
return reuse_series_implementation(self, "cast", dtype=dtype)

def abs(self: Self) -> Self:
return reuse_series_implementation(self, "abs")
Expand All @@ -264,7 +268,7 @@ def cum_sum(self: Self, *, reverse: bool) -> Self:
return reuse_series_implementation(self, "cum_sum", reverse=reverse)

def round(self: Self, decimals: int) -> Self:
return reuse_series_implementation(self, "round", decimals)
return reuse_series_implementation(self, "round", decimals=decimals)

def any(self: Self) -> Self:
return reuse_series_implementation(self, "any", returns_scalar=True)
Expand All @@ -291,7 +295,7 @@ def drop_nulls(self: Self) -> Self:
return reuse_series_implementation(self, "drop_nulls")

def shift(self: Self, n: int) -> Self:
return reuse_series_implementation(self, "shift", n)
return reuse_series_implementation(self, "shift", n=n)

def alias(self: Self, name: str) -> Self:
# Define this one manually, so that we can
Expand All @@ -304,6 +308,7 @@ def alias(self: Self, name: str) -> Self:
output_names=[name],
backend_version=self._backend_version,
version=self._version,
kwargs={"name": name},
)

def null_count(self: Self) -> Self:
Expand All @@ -314,17 +319,21 @@ def is_null(self: Self) -> Self:

def is_between(self: Self, lower_bound: Any, upper_bound: Any, closed: str) -> Self:
return reuse_series_implementation(
self, "is_between", lower_bound, upper_bound, closed
self,
"is_between",
lower_bound=lower_bound,
upper_bound=upper_bound,
closed=closed,
)

def head(self: Self, n: int) -> Self:
return reuse_series_implementation(self, "head", n)
return reuse_series_implementation(self, "head", n=n)

def tail(self: Self, n: int) -> Self:
return reuse_series_implementation(self, "tail", n)
return reuse_series_implementation(self, "tail", n=n)

def is_in(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "is_in", other)
return reuse_series_implementation(self, "is_in", other=other)

def arg_true(self: Self) -> Self:
return reuse_series_implementation(self, "arg_true")
Expand Down Expand Up @@ -375,7 +384,7 @@ def replace_strict(
self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
return reuse_series_implementation(
self, "replace_strict", old, new, return_dtype=return_dtype
self, "replace_strict", old=old, new=new, return_dtype=return_dtype
)

def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self:
Expand All @@ -389,7 +398,11 @@ def quantile(
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
) -> Self:
return reuse_series_implementation(
self, "quantile", quantile, interpolation, returns_scalar=True
self,
"quantile",
returns_scalar=True,
quantile=quantile,
interpolation=interpolation,
)

def gather_every(self: Self, n: int, offset: int = 0) -> Self:
Expand Down Expand Up @@ -423,6 +436,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=self._output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={"keys": keys},
)

def mode(self: Self) -> Self:
Expand Down Expand Up @@ -464,6 +478,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=self._output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={"function": function, "return_dtype": return_dtype},
)

def is_finite(self: Self) -> Self:
Expand Down Expand Up @@ -584,22 +599,22 @@ def __init__(self: Self, expr: ArrowExpr) -> None:

def to_string(self: Self, format: str) -> ArrowExpr: # noqa: A002
return reuse_series_namespace_implementation(
self._compliant_expr, "dt", "to_string", format
self._compliant_expr, "dt", "to_string", format=format
)

def replace_time_zone(self: Self, time_zone: str | None) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr, "dt", "replace_time_zone", time_zone
self._compliant_expr, "dt", "replace_time_zone", time_zone=time_zone
)

def convert_time_zone(self: Self, time_zone: str) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr, "dt", "convert_time_zone", time_zone
self._compliant_expr, "dt", "convert_time_zone", time_zone=time_zone
)

def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr, "dt", "timestamp", time_unit
self._compliant_expr, "dt", "timestamp", time_unit=time_unit
)

def date(self: Self) -> ArrowExpr:
Expand Down Expand Up @@ -690,8 +705,8 @@ def replace(
self._compliant_expr,
"str",
"replace",
pattern,
value,
pattern=pattern,
value=value,
literal=literal,
n=n,
)
Expand All @@ -707,8 +722,8 @@ def replace_all(
self._compliant_expr,
"str",
"replace_all",
pattern,
value,
pattern=pattern,
value=value,
literal=literal,
)

Expand All @@ -717,41 +732,41 @@ def strip_chars(self: Self, characters: str | None) -> ArrowExpr:
self._compliant_expr,
"str",
"strip_chars",
characters,
characters=characters,
)

def starts_with(self: Self, prefix: str) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"starts_with",
prefix,
prefix=prefix,
)

def ends_with(self: Self, suffix: str) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"ends_with",
suffix,
suffix=suffix,
)

def contains(self, pattern: str, *, literal: bool) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr, "str", "contains", pattern, literal=literal
self._compliant_expr, "str", "contains", pattern=pattern, literal=literal
)

def slice(self: Self, offset: int, length: int | None) -> ArrowExpr:
return reuse_series_namespace_implementation(
self._compliant_expr, "str", "slice", offset, length
self._compliant_expr, "str", "slice", offset=offset, length=length
)

def to_datetime(self: Self, format: str | None) -> ArrowExpr: # noqa: A002
return reuse_series_namespace_implementation(
self._compliant_expr,
"str",
"to_datetime",
format,
format=format,
)

def to_uppercase(self: Self) -> ArrowExpr:
Expand Down Expand Up @@ -795,6 +810,7 @@ def keep(self: Self) -> ArrowExpr:
output_names=root_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
)

def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
Expand All @@ -821,6 +837,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"function": function},
)

def prefix(self: Self, prefix: str) -> ArrowExpr:
Expand All @@ -845,6 +862,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"prefix": prefix},
)

def suffix(self: Self, suffix: str) -> ArrowExpr:
Expand All @@ -870,6 +888,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"suffix": suffix},
)

def to_lowercase(self: Self) -> ArrowExpr:
Expand All @@ -895,6 +914,7 @@ def to_lowercase(self: Self) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
)

def to_uppercase(self: Self) -> ArrowExpr:
Expand All @@ -920,6 +940,7 @@ def to_uppercase(self: Self) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
)


Expand Down
Loading

0 comments on commit 5def702

Please sign in to comment.