From 5def702f7732d4de8627a57b67774d16a6520327 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Sat, 21 Dec 2024 10:29:23 +0100 Subject: [PATCH] chore: register kwargs for `CompliantExpr` (#1614) --- narwhals/_arrow/expr.py | 81 ++++++---- narwhals/_arrow/group_by.py | 50 +++--- narwhals/_arrow/namespace.py | 16 ++ narwhals/_arrow/selectors.py | 6 + narwhals/_dask/expr.py | 245 ++++++++++++++++------------- narwhals/_dask/namespace.py | 13 ++ narwhals/_dask/selectors.py | 6 + narwhals/_expression_parsing.py | 22 +-- narwhals/_pandas_like/expr.py | 92 ++++++++--- narwhals/_pandas_like/namespace.py | 16 ++ narwhals/_pandas_like/selectors.py | 6 + narwhals/_spark_like/expr.py | 58 +++++-- narwhals/_spark_like/namespace.py | 2 + narwhals/typing.py | 1 + tests/expr_and_series/over_test.py | 89 +++++------ 15 files changed, 444 insertions(+), 259 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index b960ffa4e..6f1d627f5 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -39,6 +39,7 @@ def __init__( output_names: list[str] | None, backend_version: tuple[int, ...], version: Version, + kwargs: dict[str, Any], ) -> None: self._call = call self._depth = depth @@ -49,6 +50,7 @@ def __init__( self._implementation = Implementation.PYARROW self._backend_version = backend_version self._version = version + self._kwargs = kwargs def __repr__(self: Self) -> str: # pragma: no cover return ( @@ -97,6 +99,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=list(column_names), backend_version=backend_version, version=version, + kwargs={}, ) @classmethod @@ -127,6 +130,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=backend_version, version=version, + kwargs={}, ) def __narwhals_namespace__(self: Self) -> ArrowNamespace: @@ -171,49 +175,49 @@ def __ror__(self: Self, other: ArrowExpr | bool | Any) -> Self: return other.__or__(self) # type: ignore[return-value] def __add__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__add__", other) + return reuse_series_implementation(self, "__add__", other=other) def __radd__(self: Self, other: ArrowExpr | Any) -> Self: other = self.__narwhals_namespace__().lit(other, dtype=None) return other.__add__(self) # type: ignore[return-value] def __sub__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__sub__", other) + return reuse_series_implementation(self, "__sub__", other=other) def __rsub__(self: Self, other: ArrowExpr | Any) -> Self: other = self.__narwhals_namespace__().lit(other, dtype=None) return other.__sub__(self) # type: ignore[return-value] def __mul__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__mul__", other) + return reuse_series_implementation(self, "__mul__", other=other) def __rmul__(self: Self, other: ArrowExpr | Any) -> Self: other = self.__narwhals_namespace__().lit(other, dtype=None) return other.__mul__(self) # type: ignore[return-value] def __pow__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__pow__", other) + return reuse_series_implementation(self, "__pow__", other=other) def __rpow__(self: Self, other: ArrowExpr | Any) -> Self: other = self.__narwhals_namespace__().lit(other, dtype=None) return other.__pow__(self) # type: ignore[return-value] def __floordiv__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__floordiv__", other) + return reuse_series_implementation(self, "__floordiv__", other=other) def __rfloordiv__(self: Self, other: ArrowExpr | Any) -> Self: other = self.__narwhals_namespace__().lit(other, dtype=None) return other.__floordiv__(self) # type: ignore[return-value] def __truediv__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__truediv__", other) + return reuse_series_implementation(self, "__truediv__", other=other) def __rtruediv__(self: Self, other: ArrowExpr | Any) -> Self: other = self.__narwhals_namespace__().lit(other, dtype=None) return other.__truediv__(self) # type: ignore[return-value] def __mod__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__mod__", other) + return reuse_series_implementation(self, "__mod__", other=other) def __rmod__(self: Self, other: ArrowExpr | Any) -> Self: other = self.__narwhals_namespace__().lit(other, dtype=None) @@ -252,7 +256,7 @@ def skew(self: Self) -> Self: return reuse_series_implementation(self, "skew", returns_scalar=True) def cast(self: Self, dtype: DType) -> Self: - return reuse_series_implementation(self, "cast", dtype) + return reuse_series_implementation(self, "cast", dtype=dtype) def abs(self: Self) -> Self: return reuse_series_implementation(self, "abs") @@ -264,7 +268,7 @@ def cum_sum(self: Self, *, reverse: bool) -> Self: return reuse_series_implementation(self, "cum_sum", reverse=reverse) def round(self: Self, decimals: int) -> Self: - return reuse_series_implementation(self, "round", decimals) + return reuse_series_implementation(self, "round", decimals=decimals) def any(self: Self) -> Self: return reuse_series_implementation(self, "any", returns_scalar=True) @@ -291,7 +295,7 @@ def drop_nulls(self: Self) -> Self: return reuse_series_implementation(self, "drop_nulls") def shift(self: Self, n: int) -> Self: - return reuse_series_implementation(self, "shift", n) + return reuse_series_implementation(self, "shift", n=n) def alias(self: Self, name: str) -> Self: # Define this one manually, so that we can @@ -304,6 +308,7 @@ def alias(self: Self, name: str) -> Self: output_names=[name], backend_version=self._backend_version, version=self._version, + kwargs={"name": name}, ) def null_count(self: Self) -> Self: @@ -314,17 +319,21 @@ def is_null(self: Self) -> Self: def is_between(self: Self, lower_bound: Any, upper_bound: Any, closed: str) -> Self: return reuse_series_implementation( - self, "is_between", lower_bound, upper_bound, closed + self, + "is_between", + lower_bound=lower_bound, + upper_bound=upper_bound, + closed=closed, ) def head(self: Self, n: int) -> Self: - return reuse_series_implementation(self, "head", n) + return reuse_series_implementation(self, "head", n=n) def tail(self: Self, n: int) -> Self: - return reuse_series_implementation(self, "tail", n) + return reuse_series_implementation(self, "tail", n=n) def is_in(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "is_in", other) + return reuse_series_implementation(self, "is_in", other=other) def arg_true(self: Self) -> Self: return reuse_series_implementation(self, "arg_true") @@ -375,7 +384,7 @@ def replace_strict( self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: return reuse_series_implementation( - self, "replace_strict", old, new, return_dtype=return_dtype + self, "replace_strict", old=old, new=new, return_dtype=return_dtype ) def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: @@ -389,7 +398,11 @@ def quantile( interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Self: return reuse_series_implementation( - self, "quantile", quantile, interpolation, returns_scalar=True + self, + "quantile", + returns_scalar=True, + quantile=quantile, + interpolation=interpolation, ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -423,6 +436,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=self._output_names, backend_version=self._backend_version, version=self._version, + kwargs={"keys": keys}, ) def mode(self: Self) -> Self: @@ -464,6 +478,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=self._output_names, backend_version=self._backend_version, version=self._version, + kwargs={"function": function, "return_dtype": return_dtype}, ) def is_finite(self: Self) -> Self: @@ -584,22 +599,22 @@ def __init__(self: Self, expr: ArrowExpr) -> None: def to_string(self: Self, format: str) -> ArrowExpr: # noqa: A002 return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "to_string", format + self._compliant_expr, "dt", "to_string", format=format ) def replace_time_zone(self: Self, time_zone: str | None) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "replace_time_zone", time_zone + self._compliant_expr, "dt", "replace_time_zone", time_zone=time_zone ) def convert_time_zone(self: Self, time_zone: str) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "convert_time_zone", time_zone + self._compliant_expr, "dt", "convert_time_zone", time_zone=time_zone ) def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "timestamp", time_unit + self._compliant_expr, "dt", "timestamp", time_unit=time_unit ) def date(self: Self) -> ArrowExpr: @@ -690,8 +705,8 @@ def replace( self._compliant_expr, "str", "replace", - pattern, - value, + pattern=pattern, + value=value, literal=literal, n=n, ) @@ -707,8 +722,8 @@ def replace_all( self._compliant_expr, "str", "replace_all", - pattern, - value, + pattern=pattern, + value=value, literal=literal, ) @@ -717,7 +732,7 @@ def strip_chars(self: Self, characters: str | None) -> ArrowExpr: self._compliant_expr, "str", "strip_chars", - characters, + characters=characters, ) def starts_with(self: Self, prefix: str) -> ArrowExpr: @@ -725,7 +740,7 @@ def starts_with(self: Self, prefix: str) -> ArrowExpr: self._compliant_expr, "str", "starts_with", - prefix, + prefix=prefix, ) def ends_with(self: Self, suffix: str) -> ArrowExpr: @@ -733,17 +748,17 @@ def ends_with(self: Self, suffix: str) -> ArrowExpr: self._compliant_expr, "str", "ends_with", - suffix, + suffix=suffix, ) def contains(self, pattern: str, *, literal: bool) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "contains", pattern, literal=literal + self._compliant_expr, "str", "contains", pattern=pattern, literal=literal ) def slice(self: Self, offset: int, length: int | None) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "slice", offset, length + self._compliant_expr, "str", "slice", offset=offset, length=length ) def to_datetime(self: Self, format: str | None) -> ArrowExpr: # noqa: A002 @@ -751,7 +766,7 @@ def to_datetime(self: Self, format: str | None) -> ArrowExpr: # noqa: A002 self._compliant_expr, "str", "to_datetime", - format, + format=format, ) def to_uppercase(self: Self) -> ArrowExpr: @@ -795,6 +810,7 @@ def keep(self: Self) -> ArrowExpr: output_names=root_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: @@ -821,6 +837,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"function": function}, ) def prefix(self: Self, prefix: str) -> ArrowExpr: @@ -845,6 +862,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"prefix": prefix}, ) def suffix(self: Self, suffix: str) -> ArrowExpr: @@ -870,6 +888,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"suffix": suffix}, ) def to_lowercase(self: Self) -> ArrowExpr: @@ -895,6 +914,7 @@ def to_lowercase(self: Self) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) def to_uppercase(self: Self) -> ArrowExpr: @@ -920,6 +940,7 @@ def to_uppercase(self: Self) -> ArrowExpr: output_names=output_names, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 93835463a..66008eea9 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -22,28 +22,18 @@ from narwhals._arrow.typing import IntoArrowExpr from narwhals.typing import CompliantExpr - -def polars_to_arrow_aggregations() -> ( - dict[str, tuple[str, pc.VarianceOptions | pc.CountOptions | None]] -): - """Map polars compute functions to their pyarrow counterparts and options that help match polars behaviour.""" - import pyarrow.compute as pc - - return { - "sum": ("sum", None), - "mean": ("mean", None), - "median": ("approximate_median", None), - "max": ("max", None), - "min": ("min", None), - "std": ("stddev", pc.VarianceOptions(ddof=1)), - "var": ( - "variance", - pc.VarianceOptions(ddof=1), - ), # currently unused, we don't have `var` yet - "len": ("count", pc.CountOptions(mode="all")), - "n_unique": ("count_distinct", pc.CountOptions(mode="all")), - "count": ("count", pc.CountOptions(mode="only_valid")), - } +POLARS_TO_ARROW_AGGREGATIONS = { + "sum": "sum", + "mean": "mean", + "median": "approximate_median", + "max": "max", + "min": "min", + "std": "stddev", + "var": "variance", + "len": "count", + "n_unique": "count_distinct", + "count": "count", +} class ArrowGroupBy: @@ -132,7 +122,7 @@ def agg_arrow( if not ( is_simple_aggregation(expr) and remove_prefix(expr._function_name, "col->") - in polars_to_arrow_aggregations() + in POLARS_TO_ARROW_AGGREGATIONS ): all_simple_aggs = False break @@ -177,9 +167,17 @@ def agg_arrow( raise AssertionError(msg) function_name = remove_prefix(expr._function_name, "col->") - function_name, option = polars_to_arrow_aggregations().get( - function_name, (function_name, None) - ) + + if function_name in {"std", "var"}: + option = pc.VarianceOptions(ddof=expr._kwargs.get("ddof", 1)) + elif function_name in {"len", "n_unique"}: + option = pc.CountOptions(mode="all") + elif function_name == "count": + option = pc.CountOptions(mode="only_valid") + else: + option = None + + function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name] new_column_names.extend(expr._output_names) expected_pyarrow_column_names.extend( diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index c4da2e824..7dc4db577 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -41,6 +41,7 @@ def _create_expr_from_callable( function_name: str, root_names: list[str] | None, output_names: list[str] | None, + kwargs: dict[str, Any], ) -> ArrowExpr: from narwhals._arrow.expr import ArrowExpr @@ -52,6 +53,7 @@ def _create_expr_from_callable( output_names=output_names, backend_version=self._backend_version, version=self._version, + kwargs=kwargs, ) def _create_expr_from_series(self: Self, series: ArrowSeries) -> ArrowExpr: @@ -65,6 +67,7 @@ def _create_expr_from_series(self: Self, series: ArrowSeries) -> ArrowExpr: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def _create_series_from_scalar( @@ -133,6 +136,7 @@ def len(self: Self) -> ArrowExpr: output_names=["len"], backend_version=self._backend_version, version=self._version, + kwargs={}, ) def all(self: Self) -> ArrowExpr: @@ -155,6 +159,7 @@ def all(self: Self) -> ArrowExpr: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr: @@ -177,6 +182,7 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: output_names=["literal"], backend_version=self._backend_version, version=self._version, + kwargs={}, ) def all_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: @@ -192,6 +198,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="all_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def any_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: @@ -207,6 +214,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="any_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def sum_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: @@ -226,6 +234,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="sum_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def mean_horizontal(self: Self, *exprs: IntoArrowExpr) -> IntoArrowExpr: @@ -253,6 +262,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="mean_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def min_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: @@ -281,6 +291,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="min_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def max_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: @@ -309,6 +320,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="max_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def concat( @@ -420,6 +432,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="concat_str", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"separator": separator, "ignore_nulls": ignore_nulls}, ) @@ -506,6 +519,7 @@ def then(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={"value": value}, ) @@ -520,6 +534,7 @@ def __init__( output_names: list[str] | None, backend_version: tuple[int, ...], version: Version, + kwargs: dict[str, Any], ) -> None: self._backend_version = backend_version self._version = version @@ -528,6 +543,7 @@ def __init__( self._function_name = function_name self._root_names = root_names self._output_names = output_names + self._kwargs = kwargs def otherwise(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 1d0180c4f..7750bdd03 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -38,6 +38,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={"dtypes": dtypes}, ) def numeric(self: Self) -> ArrowSelector: @@ -81,6 +82,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={}, ) @@ -103,6 +105,7 @@ def _to_expr(self: Self) -> ArrowExpr: output_names=self._output_names, backend_version=self._backend_version, version=self._version, + kwargs=self._kwargs, ) def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any: @@ -121,6 +124,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={"other": other}, ) else: return self._to_expr() - other @@ -141,6 +145,7 @@ def call(df: ArrowDataFrame) -> Sequence[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={"other": other}, ) else: return self._to_expr() | other @@ -161,6 +166,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: output_names=None, backend_version=self._backend_version, version=self._version, + kwargs={"other": other}, ) else: return self._to_expr() & other diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 0e8985791..50133ff93 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -46,6 +46,7 @@ def __init__( returns_scalar: bool, backend_version: tuple[int, ...], version: Version, + kwargs: dict[str, Any], ) -> None: self._call = call self._depth = depth @@ -55,6 +56,7 @@ def __init__( self._returns_scalar = returns_scalar self._backend_version = backend_version self._version = version + self._kwargs = kwargs def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]: return self._call(df) @@ -93,6 +95,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=backend_version, version=version, + kwargs={}, ) @classmethod @@ -116,6 +119,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=backend_version, version=version, + kwargs={}, ) def _from_call( @@ -123,17 +127,16 @@ def _from_call( # First argument to `call` should be `dask_expr.Series` call: Callable[..., dask_expr.Series], expr_name: str, - *args: Any, + *, returns_scalar: bool, **kwargs: Any, ) -> Self: def func(df: DaskLazyFrame) -> list[dask_expr.Series]: results = [] inputs = self._call(df) - _args = [maybe_evaluate(df, x) for x in args] _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} for _input in inputs: - result = call(_input, *_args, **_kwargs) + result = call(_input, **_kwargs) if returns_scalar: result = result.to_series() result = result.rename(_input.name) @@ -146,7 +149,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: # and just set it to None. root_names = copy(self._root_names) output_names = self._output_names - for arg in list(args) + list(kwargs.values()): + for arg in list(kwargs.values()): if root_names is not None and isinstance(arg, self.__class__): if arg._root_names is not None: root_names.extend(arg._root_names) @@ -174,6 +177,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=self._returns_scalar or returns_scalar, backend_version=self._backend_version, version=self._version, + kwargs=kwargs, ) def alias(self, name: str) -> Self: @@ -190,13 +194,14 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, + kwargs={"name": name}, ) def __add__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__add__(other), "__add__", - other, + other=other, returns_scalar=False, ) @@ -204,7 +209,7 @@ def __radd__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__radd__(other), "__radd__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -212,7 +217,7 @@ def __sub__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__sub__(other), "__sub__", - other, + other=other, returns_scalar=False, ) @@ -220,7 +225,7 @@ def __rsub__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__rsub__(other), "__rsub__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -228,7 +233,7 @@ def __mul__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__mul__(other), "__mul__", - other, + other=other, returns_scalar=False, ) @@ -236,7 +241,7 @@ def __rmul__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__rmul__(other), "__rmul__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -244,7 +249,7 @@ def __truediv__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__truediv__(other), "__truediv__", - other, + other=other, returns_scalar=False, ) @@ -252,7 +257,7 @@ def __rtruediv__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__rtruediv__(other), "__rtruediv__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -260,7 +265,7 @@ def __floordiv__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__floordiv__(other), "__floordiv__", - other, + other=other, returns_scalar=False, ) @@ -268,7 +273,7 @@ def __rfloordiv__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__rfloordiv__(other), "__rfloordiv__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -276,7 +281,7 @@ def __pow__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__pow__(other), "__pow__", - other, + other=other, returns_scalar=False, ) @@ -284,7 +289,7 @@ def __rpow__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__rpow__(other), "__rpow__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -292,7 +297,7 @@ def __mod__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__mod__(other), "__mod__", - other, + other=other, returns_scalar=False, ) @@ -300,7 +305,7 @@ def __rmod__(self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.__rmod__(other), "__rmod__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -308,7 +313,7 @@ def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override] return self._from_call( lambda _input, other: _input.__eq__(other), "__eq__", - other, + other=other, returns_scalar=False, ) @@ -316,7 +321,7 @@ def __ne__(self, other: DaskExpr) -> Self: # type: ignore[override] return self._from_call( lambda _input, other: _input.__ne__(other), "__ne__", - other, + other=other, returns_scalar=False, ) @@ -324,7 +329,7 @@ def __ge__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__ge__(other), "__ge__", - other, + other=other, returns_scalar=False, ) @@ -332,7 +337,7 @@ def __gt__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__gt__(other), "__gt__", - other, + other=other, returns_scalar=False, ) @@ -340,7 +345,7 @@ def __le__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__le__(other), "__le__", - other, + other=other, returns_scalar=False, ) @@ -348,7 +353,7 @@ def __lt__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__lt__(other), "__lt__", - other, + other=other, returns_scalar=False, ) @@ -356,7 +361,7 @@ def __and__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__and__(other), "__and__", - other, + other=other, returns_scalar=False, ) @@ -364,7 +369,7 @@ def __rand__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__rand__(other), "__rand__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -372,7 +377,7 @@ def __or__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__or__(other), "__or__", - other, + other=other, returns_scalar=False, ) @@ -380,7 +385,7 @@ def __ror__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__ror__(other), "__ror__", - other, + other=other, returns_scalar=False, ).alias("literal") @@ -433,20 +438,26 @@ def max(self) -> Self: ) def std(self, ddof: int) -> Self: - return self._from_call( + expr = self._from_call( lambda _input, ddof: _input.std(ddof=ddof), "std", - ddof, + ddof=ddof, returns_scalar=True, ) + if ddof != 1: + expr._depth += 1 + return expr def var(self, ddof: int) -> Self: - return self._from_call( + expr = self._from_call( lambda _input, ddof: _input.var(ddof=ddof), "var", - ddof, + ddof=ddof, returns_scalar=True, ) + if ddof != 1: + expr._depth += 1 + return expr def skew(self: Self) -> Self: return self._from_call( @@ -459,7 +470,7 @@ def shift(self, n: int) -> Self: return self._from_call( lambda _input, n: _input.shift(n), "shift", - n, + n=n, returns_scalar=False, ) @@ -533,9 +544,9 @@ def is_between( closed, ), "is_between", - lower_bound, - upper_bound, - closed, + lower_bound=lower_bound, + upper_bound=upper_bound, + closed=closed, returns_scalar=False, ) @@ -557,7 +568,7 @@ def round(self, decimals: int) -> Self: return self._from_call( lambda _input, decimals: _input.round(decimals), "round", - decimals, + decimals=decimals, returns_scalar=False, ) @@ -649,9 +660,9 @@ def func( return self._from_call( func, "fillna", - value, - strategy, - limit, + value=value, + strategy=strategy, + limit=limit, returns_scalar=False, ) @@ -661,10 +672,12 @@ def clip( upper_bound: Any | None = None, ) -> Self: return self._from_call( - lambda _input, _lower, _upper: _input.clip(lower=_lower, upper=_upper), + lambda _input, lower_bound, upper_bound: _input.clip( + lower=lower_bound, upper=upper_bound + ), "clip", - lower_bound, - upper_bound, + lower_bound=lower_bound, + upper_bound=upper_bound, returns_scalar=False, ) @@ -703,16 +716,16 @@ def quantile( ) -> Self: if interpolation == "linear": - def func(_input: dask_expr.Series, _quantile: float) -> dask_expr.Series: + def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series: if _input.npartitions > 1: msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions." raise NotImplementedError(msg) - return _input.quantile(q=_quantile, method="dask") # pragma: no cover + return _input.quantile(q=quantile, method="dask") # pragma: no cover return self._from_call( func, "quantile", - quantile, + quantile=quantile, returns_scalar=True, ) else: @@ -781,7 +794,7 @@ def is_in(self: Self, other: Any) -> Self: return self._from_call( lambda _input, other: _input.isin(other), "is_in", - other, + other=other, returns_scalar=False, ) @@ -834,6 +847,7 @@ def func(df: DaskLazyFrame) -> list[Any]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={"keys": keys}, ) def mode(self: Self) -> Self: @@ -863,7 +877,7 @@ def func(_input: Any, dtype: DType | type[DType]) -> Any: return self._from_call( func, "cast", - dtype, + dtype=dtype, returns_scalar=False, ) @@ -885,20 +899,20 @@ def rolling_sum( ) -> Self: def func( _input: dask_expr.Series, - _window: int, - _min_periods: int | None, - _center: bool, # noqa: FBT001 + window_size: int, + min_periods: int | None, + center: bool, # noqa: FBT001 ) -> dask_expr.Series: return _input.rolling( - window=_window, min_periods=_min_periods, center=_center + window=window_size, min_periods=min_periods, center=center ).sum() return self._from_call( func, "rolling_sum", - window_size, - min_periods, - center, + window_size=window_size, + min_periods=min_periods, + center=center, returns_scalar=False, ) @@ -911,20 +925,20 @@ def rolling_mean( ) -> Self: def func( _input: dask_expr.Series, - _window: int, - _min_periods: int | None, - _center: bool, # noqa: FBT001 + window_size: int, + min_periods: int | None, + center: bool, # noqa: FBT001 ) -> dask_expr.Series: return _input.rolling( - window=_window, min_periods=_min_periods, center=_center + window=window_size, min_periods=min_periods, center=center ).mean() return self._from_call( func, "rolling_mean", - window_size, - min_periods, - center, + window_size=window_size, + min_periods=min_periods, + center=center, returns_scalar=False, ) @@ -938,22 +952,22 @@ def rolling_var( ) -> Self: def func( _input: dask_expr.Series, - _window: int, - _min_periods: int | None, - _center: bool, # noqa: FBT001 - _ddof: int, + window_size: int, + min_periods: int | None, + center: bool, # noqa: FBT001 + ddof: int, ) -> dask_expr.Series: return _input.rolling( - window=_window, min_periods=_min_periods, center=_center + window=window_size, min_periods=min_periods, center=center ).var(ddof=ddof) return self._from_call( func, "rolling_var", - window_size, - min_periods, - center, - ddof, + window_size=window_size, + min_periods=min_periods, + center=center, + ddof=ddof, returns_scalar=False, ) @@ -967,22 +981,22 @@ def rolling_std( ) -> Self: def func( _input: dask_expr.Series, - _window: int, - _min_periods: int | None, - _center: bool, # noqa: FBT001 - _ddof: int, + window_size: int, + min_periods: int | None, + center: bool, # noqa: FBT001 + ddof: int, ) -> dask_expr.Series: return _input.rolling( - window=_window, min_periods=_min_periods, center=_center + window=window_size, min_periods=min_periods, center=center ).std(ddof=ddof) return self._from_call( func, "rolling_std", - window_size, - min_periods, - center, - ddof, + window_size=window_size, + min_periods=min_periods, + center=center, + ddof=ddof, returns_scalar=False, ) @@ -1005,14 +1019,14 @@ def replace( n: int = 1, ) -> DaskExpr: return self._compliant_expr._from_call( - lambda _input, _pattern, _value, _literal, _n: _input.str.replace( - _pattern, _value, regex=not _literal, n=_n + lambda _input, pattern, value, literal, n: _input.str.replace( + pattern, value, regex=not literal, n=n ), "replace", - pattern, - value, - literal, - n, + pattern=pattern, + value=value, + literal=literal, + n=n, returns_scalar=False, ) @@ -1024,13 +1038,13 @@ def replace_all( literal: bool = False, ) -> DaskExpr: return self._compliant_expr._from_call( - lambda _input, _pattern, _value, _literal: _input.str.replace( - _pattern, _value, n=-1, regex=not _literal + lambda _input, pattern, value, literal: _input.str.replace( + pattern, value, n=-1, regex=not literal ), "replace", - pattern, - value, - literal, + pattern=pattern, + value=value, + literal=literal, returns_scalar=False, ) @@ -1038,7 +1052,7 @@ def strip_chars(self, characters: str | None = None) -> DaskExpr: return self._compliant_expr._from_call( lambda _input, characters: _input.str.strip(characters), "strip", - characters, + characters=characters, returns_scalar=False, ) @@ -1046,7 +1060,7 @@ def starts_with(self, prefix: str) -> DaskExpr: return self._compliant_expr._from_call( lambda _input, prefix: _input.str.startswith(prefix), "starts_with", - prefix, + prefix=prefix, returns_scalar=False, ) @@ -1054,26 +1068,29 @@ def ends_with(self, suffix: str) -> DaskExpr: return self._compliant_expr._from_call( lambda _input, suffix: _input.str.endswith(suffix), "ends_with", - suffix, + suffix=suffix, returns_scalar=False, ) def contains(self, pattern: str, *, literal: bool = False) -> DaskExpr: return self._compliant_expr._from_call( - lambda _input, pat, regex: _input.str.contains(pat=pat, regex=regex), + lambda _input, pattern, literal: _input.str.contains( + pat=pattern, regex=not literal + ), "contains", - pattern, - not literal, + pattern=pattern, + literal=literal, returns_scalar=False, ) def slice(self, offset: int, length: int | None = None) -> DaskExpr: - stop = offset + length if length else None return self._compliant_expr._from_call( - lambda _input, start, stop: _input.str.slice(start=start, stop=stop), + lambda _input, offset, length: _input.str.slice( + start=offset, stop=offset + length if length else None + ), "slice", - offset, - stop, + offset=offset, + length=length, returns_scalar=False, ) @@ -1081,9 +1098,9 @@ def to_datetime(self: Self, format: str | None) -> DaskExpr: # noqa: A002 import dask.dataframe as dd return self._compliant_expr._from_call( - lambda _input, fmt: dd.to_datetime(_input, format=fmt), + lambda _input, format: dd.to_datetime(_input, format=format), "to_datetime", - format, + format=format, returns_scalar=False, ) @@ -1185,21 +1202,21 @@ def ordinal_day(self) -> DaskExpr: def to_string(self, format: str) -> DaskExpr: # noqa: A002 return self._compliant_expr._from_call( - lambda _input, _format: _input.dt.strftime(_format), + lambda _input, format: _input.dt.strftime(format.replace("%.f", ".%f")), "strftime", - format.replace("%.f", ".%f"), + format=format, returns_scalar=False, ) def replace_time_zone(self, time_zone: str | None) -> DaskExpr: return self._compliant_expr._from_call( - lambda _input, _time_zone: _input.dt.tz_localize(None).dt.tz_localize( - _time_zone + lambda _input, time_zone: _input.dt.tz_localize(None).dt.tz_localize( + time_zone ) - if _time_zone is not None + if time_zone is not None else _input.dt.tz_localize(None), "tz_localize", - time_zone, + time_zone=time_zone, returns_scalar=False, ) @@ -1216,7 +1233,7 @@ def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series: return self._compliant_expr._from_call( func, "tz_convert", - time_zone, + time_zone=time_zone, returns_scalar=False, ) @@ -1250,7 +1267,7 @@ def func( return self._compliant_expr._from_call( func, "datetime", - time_unit, + time_unit=time_unit, returns_scalar=False, ) @@ -1317,6 +1334,7 @@ def keep(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) def map(self: Self, function: Callable[[str], str]) -> DaskExpr: @@ -1344,6 +1362,7 @@ def map(self: Self, function: Callable[[str], str]) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"function": function}, ) def prefix(self: Self, prefix: str) -> DaskExpr: @@ -1369,6 +1388,7 @@ def prefix(self: Self, prefix: str) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"prefix": prefix}, ) def suffix(self: Self, suffix: str) -> DaskExpr: @@ -1395,6 +1415,7 @@ def suffix(self: Self, suffix: str) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"suffix": suffix}, ) def to_lowercase(self: Self) -> DaskExpr: @@ -1421,6 +1442,7 @@ def to_lowercase(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) def to_uppercase(self: Self) -> DaskExpr: @@ -1447,4 +1469,5 @@ def to_uppercase(self: Self) -> DaskExpr: returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index a64734bae..e0870d242 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -52,6 +52,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def col(self, *column_names: str) -> DaskExpr: @@ -87,6 +88,7 @@ def convert_if_dtype( returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def min(self, *column_names: str) -> DaskExpr: @@ -138,6 +140,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=True, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def all_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -156,6 +159,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -174,6 +178,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -192,6 +197,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def concat( @@ -273,6 +279,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -294,6 +301,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -315,6 +323,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def when( @@ -379,6 +388,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) @@ -441,6 +451,7 @@ def then(self, value: DaskExpr | Any) -> DaskThen: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, + kwargs={}, ) @@ -456,6 +467,7 @@ def __init__( returns_scalar: bool, backend_version: tuple[int, ...], version: Version, + kwargs: dict[str, Any], ) -> None: self._backend_version = backend_version self._version = version @@ -465,6 +477,7 @@ def __init__( self._root_names = root_names self._output_names = output_names self._returns_scalar = returns_scalar + self._kwargs = kwargs def otherwise(self, value: DaskExpr | Any) -> DaskExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index d4064353d..2891d84ff 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -38,6 +38,7 @@ def func(df: DaskLazyFrame) -> list[Any]: backend_version=self._backend_version, returns_scalar=False, version=self._version, + kwargs={}, ) def numeric(self: Self) -> DaskSelector: @@ -82,6 +83,7 @@ def func(df: DaskLazyFrame) -> list[Any]: backend_version=self._backend_version, returns_scalar=False, version=self._version, + kwargs={}, ) @@ -105,6 +107,7 @@ def _to_expr(self: Self) -> DaskExpr: backend_version=self._backend_version, returns_scalar=self._returns_scalar, version=self._version, + kwargs={}, ) def __sub__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: @@ -124,6 +127,7 @@ def call(df: DaskLazyFrame) -> list[Any]: backend_version=self._backend_version, returns_scalar=self._returns_scalar, version=self._version, + kwargs={}, ) else: return self._to_expr() - other @@ -145,6 +149,7 @@ def call(df: DaskLazyFrame) -> list[dask_expr.Series]: backend_version=self._backend_version, returns_scalar=self._returns_scalar, version=self._version, + kwargs={}, ) else: return self._to_expr() | other @@ -166,6 +171,7 @@ def call(df: DaskLazyFrame) -> list[Any]: backend_version=self._backend_version, returns_scalar=self._returns_scalar, version=self._version, + kwargs={}, ) else: return self._to_expr() & other diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index e15746f0b..3a9744c9c 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -129,7 +129,7 @@ def parse_into_expr( def reuse_series_implementation( expr: PandasLikeExprT, attr: str, - *args: Any, + *, returns_scalar: bool = False, **kwargs: Any, ) -> PandasLikeExprT: ... @@ -139,7 +139,7 @@ def reuse_series_implementation( def reuse_series_implementation( expr: ArrowExprT, attr: str, - *args: Any, + *, returns_scalar: bool = False, **kwargs: Any, ) -> ArrowExprT: ... @@ -148,7 +148,7 @@ def reuse_series_implementation( def reuse_series_implementation( expr: ArrowExprT | PandasLikeExprT, attr: str, - *args: Any, + *, returns_scalar: bool = False, **kwargs: Any, ) -> ArrowExprT | PandasLikeExprT: @@ -168,7 +168,6 @@ def reuse_series_implementation( plx = expr.__narwhals_namespace__() def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: - _args = [maybe_evaluate_expr(df, arg) for arg in args] # type: ignore[var-annotated] _kwargs = { # type: ignore[var-annotated] arg_name: maybe_evaluate_expr(df, arg_value) for arg_name, arg_value in kwargs.items() @@ -184,11 +183,11 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: out: list[CompliantSeries] = [ plx._create_series_from_scalar( - getattr(series, attr)(*_args, **extra_kwargs, **_kwargs), + getattr(series, attr)(**extra_kwargs, **_kwargs), reference_series=series, # type: ignore[arg-type] ) if returns_scalar - else getattr(series, attr)(*_args, **_kwargs) + else getattr(series, attr)(**_kwargs) for series in expr(df) # type: ignore[arg-type] ] if expr._output_names is not None and ( @@ -208,7 +207,7 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: # and just set it to None. root_names = copy(expr._root_names) output_names = expr._output_names - for arg in list(args) + list(kwargs.values()): + for arg in list(kwargs.values()): if root_names is not None and isinstance(arg, expr.__class__): if arg._root_names is not None: root_names.extend(arg._root_names) @@ -233,22 +232,22 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: function_name=f"{expr._function_name}->{attr}", root_names=root_names, output_names=output_names, + kwargs=kwargs, ) @overload def reuse_series_namespace_implementation( - expr: ArrowExprT, series_namespace: str, attr: str, *args: Any, **kwargs: Any + expr: ArrowExprT, series_namespace: str, attr: str, **kwargs: Any ) -> ArrowExprT: ... @overload def reuse_series_namespace_implementation( - expr: PandasLikeExprT, series_namespace: str, attr: str, *args: Any, **kwargs: Any + expr: PandasLikeExprT, series_namespace: str, attr: str, **kwargs: Any ) -> PandasLikeExprT: ... def reuse_series_namespace_implementation( expr: ArrowExprT | PandasLikeExprT, series_namespace: str, attr: str, - *args: Any, **kwargs: Any, ) -> ArrowExprT | PandasLikeExprT: """Reuse Series implementation for expression. @@ -266,13 +265,14 @@ def reuse_series_namespace_implementation( plx = expr.__narwhals_namespace__() return plx._create_expr_from_callable( # type: ignore[return-value] lambda df: [ - getattr(getattr(series, series_namespace), attr)(*args, **kwargs) + getattr(getattr(series, series_namespace), attr)(**kwargs) for series in expr(df) # type: ignore[arg-type] ], depth=expr._depth + 1, function_name=f"{expr._function_name}->{series_namespace}.{attr}", root_names=expr._root_names, output_names=expr._output_names, + kwargs=kwargs, ) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 8d507fb46..0cf2a3f73 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -48,6 +48,7 @@ def __init__( implementation: Implementation, backend_version: tuple[int, ...], version: Version, + kwargs: dict[str, Any], ) -> None: self._call = call self._depth = depth @@ -57,6 +58,7 @@ def __init__( self._implementation = implementation self._backend_version = backend_version self._version = version + self._kwargs = kwargs def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: return self._call(df) @@ -115,6 +117,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=implementation, backend_version=backend_version, version=version, + kwargs={}, ) @classmethod @@ -145,6 +148,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=implementation, backend_version=backend_version, version=version, + kwargs={}, ) def cast( @@ -259,10 +263,16 @@ def median(self) -> Self: return reuse_series_implementation(self, "median", returns_scalar=True) def std(self, *, ddof: int) -> Self: - return reuse_series_implementation(self, "std", ddof=ddof, returns_scalar=True) + expr = reuse_series_implementation(self, "std", ddof=ddof, returns_scalar=True) + if ddof != 1: + expr._depth += 1 + return expr def var(self, *, ddof: int) -> Self: - return reuse_series_implementation(self, "var", ddof=ddof, returns_scalar=True) + expr = reuse_series_implementation(self, "var", ddof=ddof, returns_scalar=True) + if ddof != 1: + expr._depth += 1 + return expr def skew(self: Self) -> Self: return reuse_series_implementation(self, "skew", returns_scalar=True) @@ -357,7 +367,7 @@ def replace_strict( self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: return reuse_series_implementation( - self, "replace_strict", old, new, return_dtype=return_dtype + self, "replace_strict", old=old, new=new, return_dtype=return_dtype ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -409,6 +419,7 @@ def alias(self, name: str) -> Self: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"name": name}, ) def over(self, keys: list[str]) -> Self: @@ -426,15 +437,25 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) raise ValueError(msg) + reverse = self._kwargs.get("reverse", False) + if reverse: + msg = ( + "Cumulative operation with `reverse=True` is not supported in " + "over context for pandas-like backend." + ) + raise NotImplementedError(msg) + if self._function_name == "col->cum_count": plx = self.__narwhals_namespace__() df = df.with_columns(~plx.col(*self._root_names).is_null()) - res_native = df._native_frame.groupby(list(keys), as_index=False)[ - self._root_names - ].transform( - CUMULATIVE_FUNCTIONS_TO_PANDAS_EQUIVALENT[self._function_name] - ) + res_native = getattr( + df._native_frame.groupby(list(keys), as_index=False)[ + self._root_names + ], + CUMULATIVE_FUNCTIONS_TO_PANDAS_EQUIVALENT[self._function_name], + )(skipna=True) + result_frame = df._from_native_frame( rename( res_native, @@ -470,6 +491,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"keys": keys}, ) def is_duplicated(self) -> Self: @@ -490,17 +512,21 @@ def quantile( interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Self: return reuse_series_implementation( - self, "quantile", quantile, interpolation, returns_scalar=True + self, + "quantile", + quantile=quantile, + interpolation=interpolation, + returns_scalar=True, ) def head(self, n: int) -> Self: - return reuse_series_implementation(self, "head", n) + return reuse_series_implementation(self, "head", n=n) def tail(self, n: int) -> Self: - return reuse_series_implementation(self, "tail", n) + return reuse_series_implementation(self, "tail", n=n) def round(self: Self, decimals: int) -> Self: - return reuse_series_implementation(self, "round", decimals) + return reuse_series_implementation(self, "round", decimals=decimals) def len(self: Self) -> Self: return reuse_series_implementation(self, "len", returns_scalar=True) @@ -542,6 +568,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"function": function, "return_dtype": return_dtype}, ) def is_finite(self: Self) -> Self: @@ -676,7 +703,13 @@ def replace( n: int = 1, ) -> PandasLikeExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "replace", pattern, value, literal=literal, n=n + self._compliant_expr, + "str", + "replace", + pattern=pattern, + value=value, + literal=literal, + n=n, ) def replace_all( @@ -687,7 +720,12 @@ def replace_all( literal: bool = False, ) -> PandasLikeExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "replace_all", pattern, value, literal=literal + self._compliant_expr, + "str", + "replace_all", + pattern=pattern, + value=value, + literal=literal, ) def strip_chars(self, characters: str | None = None) -> PandasLikeExpr: @@ -695,7 +733,7 @@ def strip_chars(self, characters: str | None = None) -> PandasLikeExpr: self._compliant_expr, "str", "strip_chars", - characters, + characters=characters, ) def starts_with(self, prefix: str) -> PandasLikeExpr: @@ -703,7 +741,7 @@ def starts_with(self, prefix: str) -> PandasLikeExpr: self._compliant_expr, "str", "starts_with", - prefix, + prefix=prefix, ) def ends_with(self, suffix: str) -> PandasLikeExpr: @@ -711,7 +749,7 @@ def ends_with(self, suffix: str) -> PandasLikeExpr: self._compliant_expr, "str", "ends_with", - suffix, + suffix=suffix, ) def contains(self, pattern: str, *, literal: bool) -> PandasLikeExpr: @@ -719,13 +757,13 @@ def contains(self, pattern: str, *, literal: bool) -> PandasLikeExpr: self._compliant_expr, "str", "contains", - pattern, + pattern=pattern, literal=literal, ) def slice(self, offset: int, length: int | None = None) -> PandasLikeExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "slice", offset, length + self._compliant_expr, "str", "slice", offset=offset, length=length ) def to_datetime(self: Self, format: str | None) -> PandasLikeExpr: # noqa: A002 @@ -733,7 +771,7 @@ def to_datetime(self: Self, format: str | None) -> PandasLikeExpr: # noqa: A002 self._compliant_expr, "str", "to_datetime", - format, + format=format, ) def to_uppercase(self) -> PandasLikeExpr: @@ -823,22 +861,22 @@ def total_nanoseconds(self) -> PandasLikeExpr: def to_string(self, format: str) -> PandasLikeExpr: # noqa: A002 return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "to_string", format + self._compliant_expr, "dt", "to_string", format=format ) def replace_time_zone(self, time_zone: str | None) -> PandasLikeExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "replace_time_zone", time_zone + self._compliant_expr, "dt", "replace_time_zone", time_zone=time_zone ) def convert_time_zone(self, time_zone: str) -> PandasLikeExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "convert_time_zone", time_zone + self._compliant_expr, "dt", "convert_time_zone", time_zone=time_zone ) def timestamp(self, time_unit: Literal["ns", "us", "ms"] = "us") -> PandasLikeExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "timestamp", time_unit + self._compliant_expr, "dt", "timestamp", time_unit=time_unit ) @@ -869,6 +907,7 @@ def keep(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: @@ -896,6 +935,7 @@ def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"function": function}, ) def prefix(self: Self, prefix: str) -> PandasLikeExpr: @@ -921,6 +961,7 @@ def prefix(self: Self, prefix: str) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"prefix": prefix}, ) def suffix(self: Self, suffix: str) -> PandasLikeExpr: @@ -947,6 +988,7 @@ def suffix(self: Self, suffix: str) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={"suffix": suffix}, ) def to_lowercase(self: Self) -> PandasLikeExpr: @@ -973,6 +1015,7 @@ def to_lowercase(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) def to_uppercase(self: Self) -> PandasLikeExpr: @@ -999,6 +1042,7 @@ def to_uppercase(self: Self) -> PandasLikeExpr: implementation=self._compliant_expr._implementation, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, + kwargs={}, ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index f6918d01b..495173025 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -58,6 +58,7 @@ def _create_expr_from_callable( function_name: str, root_names: list[str] | None, output_names: list[str] | None, + kwargs: dict[str, Any], ) -> PandasLikeExpr: return PandasLikeExpr( func, @@ -68,6 +69,7 @@ def _create_expr_from_callable( implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs=kwargs, ) def _create_series_from_scalar( @@ -92,6 +94,7 @@ def _create_expr_from_series(self, series: PandasLikeSeries) -> PandasLikeExpr: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def _create_compliant_series(self, value: Any) -> PandasLikeSeries: @@ -137,6 +140,7 @@ def all(self) -> PandasLikeExpr: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def lit(self, value: Any, dtype: DType | None) -> PandasLikeExpr: @@ -162,6 +166,7 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={}, ) # --- reduction --- @@ -224,6 +229,7 @@ def len(self) -> PandasLikeExpr: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={}, ) # --- horizontal --- @@ -240,6 +246,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="sum_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def all_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: @@ -255,6 +262,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="all_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def any_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: @@ -270,6 +278,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="any_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def mean_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: @@ -288,6 +297,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="mean_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def min_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: @@ -318,6 +328,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="min_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def max_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: @@ -348,6 +359,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="max_horizontal", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"exprs": exprs}, ) def concat( @@ -462,6 +474,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: function_name="concat_str", root_names=combine_root_names(parsed_exprs), output_names=reduce_output_names(parsed_exprs), + kwargs={"separator": separator, "ignore_nulls": ignore_nulls}, ) @@ -542,6 +555,7 @@ def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"value": value}, ) @@ -557,6 +571,7 @@ def __init__( implementation: Implementation, backend_version: tuple[int, ...], version: Version, + kwargs: dict[str, Any], ) -> None: self._implementation = implementation self._backend_version = backend_version @@ -566,6 +581,7 @@ def __init__( self._function_name = function_name self._root_names = root_names self._output_names = output_names + self._kwargs = kwargs def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 2f775aa6c..7ef666b96 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -40,6 +40,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"dtypes": dtypes}, ) def numeric(self) -> PandasSelector: @@ -84,6 +85,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={}, ) @@ -107,6 +109,7 @@ def _to_expr(self) -> PandasLikeExpr: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs=self._kwargs, ) def __sub__(self, other: PandasSelector | Any) -> PandasSelector | Any: @@ -126,6 +129,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"other": other}, ) else: return self._to_expr() - other @@ -147,6 +151,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"other": other}, ) else: return self._to_expr() | other @@ -168,6 +173,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: implementation=self._implementation, backend_version=self._backend_version, version=self._version, + kwargs={"other": other}, ) else: return self._to_expr() & other diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 6074dff2c..cbd645298 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -1,8 +1,8 @@ from __future__ import annotations -import operator from copy import copy from typing import TYPE_CHECKING +from typing import Any from typing import Callable from typing import Sequence @@ -37,6 +37,7 @@ def __init__( returns_scalar: bool, backend_version: tuple[int, ...], version: Version, + kwargs: dict[str, Any], ) -> None: self._call = call self._depth = depth @@ -46,6 +47,7 @@ def __init__( self._returns_scalar = returns_scalar self._backend_version = backend_version self._version = version + self._kwargs = kwargs def __call__(self, df: SparkLikeLazyFrame) -> Sequence[Column]: return self._call(df) @@ -81,24 +83,24 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: returns_scalar=False, backend_version=backend_version, version=version, + kwargs={}, ) def _from_call( self, call: Callable[..., Column], expr_name: str, - *args: SparkLikeExpr, + *, returns_scalar: bool, - **kwargs: SparkLikeExpr, + **kwargs: Any, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: results = [] inputs = self._call(df) - _args = [maybe_evaluate(df, arg) for arg in args] _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} for _input in inputs: input_col_name = get_column_name(df, _input) - column_result = call(_input, *_args, **_kwargs) + column_result = call(_input, **_kwargs) if not returns_scalar: column_result = column_result.alias(input_col_name) results.append(column_result) @@ -110,7 +112,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: # and just set it to None. root_names = copy(self._root_names) output_names = self._output_names - for arg in list(args) + list(kwargs.values()): + for arg in list(kwargs.values()): if root_names is not None and isinstance(arg, self.__class__): if arg._root_names is not None: root_names.extend(arg._root_names) @@ -138,22 +140,48 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=self._returns_scalar or returns_scalar, backend_version=self._backend_version, version=self._version, + kwargs=kwargs, ) def __add__(self, other: SparkLikeExpr) -> Self: - return self._from_call(operator.add, "__add__", other, returns_scalar=False) + return self._from_call( + lambda _input, other: _input + other, + "__add__", + other=other, + returns_scalar=False, + ) def __sub__(self, other: SparkLikeExpr) -> Self: - return self._from_call(operator.sub, "__sub__", other, returns_scalar=False) + return self._from_call( + lambda _input, other: _input - other, + "__sub__", + other=other, + returns_scalar=False, + ) def __mul__(self, other: SparkLikeExpr) -> Self: - return self._from_call(operator.mul, "__mul__", other, returns_scalar=False) + return self._from_call( + lambda _input, other: _input * other, + "__mul__", + other=other, + returns_scalar=False, + ) def __lt__(self, other: SparkLikeExpr) -> Self: - return self._from_call(operator.lt, "__lt__", other, returns_scalar=False) + return self._from_call( + lambda _input, other: _input < other, + "__lt__", + other=other, + returns_scalar=False, + ) def __gt__(self, other: SparkLikeExpr) -> Self: - return self._from_call(operator.gt, "__gt__", other, returns_scalar=False) + return self._from_call( + lambda _input, other: _input > other, + "__gt__", + other=other, + returns_scalar=False, + ) def alias(self, name: str) -> Self: def _alias(df: SparkLikeLazyFrame) -> list[Column]: @@ -170,6 +198,7 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, + kwargs={"name": name}, ) def count(self) -> Self: @@ -207,7 +236,7 @@ def _min(_input: Column) -> Column: def std(self, ddof: int) -> Self: import numpy as np # ignore-banned-import - def _std(_input: Column) -> Column: # pragma: no cover + def _std(_input: Column, ddof: int) -> Column: # pragma: no cover if self._backend_version < (3, 5) or parse_version(np.__version__) > (2, 0): from pyspark.sql import functions as F # noqa: N812 @@ -221,4 +250,7 @@ def _std(_input: Column) -> Column: # pragma: no cover return stddev(_input, ddof=ddof) - return self._from_call(_std, "std", returns_scalar=True) + expr = self._from_call(_std, "std", returns_scalar=True, ddof=ddof) + if ddof != 1: + expr._depth += 1 + return expr diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 44053e7c6..639523ed0 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -39,6 +39,7 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: @@ -58,6 +59,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: returns_scalar=False, backend_version=self._backend_version, version=self._version, + kwargs={}, ) def col(self, *column_names: str) -> SparkLikeExpr: diff --git a/narwhals/typing.py b/narwhals/typing.py index e6503e1ae..ec440751a 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -72,6 +72,7 @@ class CompliantExpr(Protocol, Generic[CompliantSeriesT_co]): _root_names: list[str] | None _depth: int _function_name: str + _kwargs: dict[str, Any] def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ... def __narwhals_expr__(self) -> None: ... diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index 54593421e..6f78add4a 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -14,6 +14,12 @@ "c": [5, 4, 3, 2, 1], } +data_cum = { + "a": ["a", "a", "b", "b", "b"], + "b": [1, 2, None, 5, 3], + "c": [5, 4, 3, 2, 1], +} + def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask_lazy_p2" in str(constructor): @@ -62,15 +68,16 @@ def test_over_cumsum(request: pytest.FixtureRequest, constructor: Constructor) - if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) + df = nw.from_native(constructor(data_cum)) expected = { "a": ["a", "a", "b", "b", "b"], - "b": [1, 2, 3, 5, 3], + "b": [1, 2, None, 5, 3], "c": [5, 4, 3, 2, 1], - "b_cumsum": [1, 3, 3, 8, 11], + "b_cumsum": [1, 3, None, 5, 8], + "c_cumsum": [5, 9, 3, 5, 6], } - result = df.with_columns(b_cumsum=nw.col("b").cum_sum().over("a")) + result = df.with_columns(nw.col("b", "c").cum_sum().over("a").name.suffix("_cumsum")) assert_equal_data(result, expected) @@ -78,39 +85,18 @@ def test_over_cumcount(request: pytest.FixtureRequest, constructor: Constructor) if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) + df = nw.from_native(constructor(data_cum)) expected = { "a": ["a", "a", "b", "b", "b"], - "b": [1, 2, 3, 5, 3], + "b": [1, 2, None, 5, 3], "c": [5, 4, 3, 2, 1], - "b_cumcount": [1, 2, 1, 2, 3], + "b_cumcount": [1, 2, 0, 1, 2], + "c_cumcount": [1, 2, 1, 2, 3], } - result = df.with_columns(b_cumcount=nw.col("b").cum_count().over("a")) - assert_equal_data(result, expected) - - -def test_over_cumcount_missing_values( - request: pytest.FixtureRequest, constructor: Constructor -) -> None: - if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor): - request.applymarker(pytest.mark.xfail) - - data_with_missing_value = { - "a": ["a", "a", "b", "b", "b"], - "b": [1, 2, 3, 5, None], - "c": [5, 4, 3, 2, 1], - } - - df = nw.from_native(constructor(data_with_missing_value)) - expected = { - "a": ["a", "a", "b", "b", "b"], - "b": [1, 2, 3, 5, None], - "c": [5, 4, 3, 2, 1], - "b_cumcount": [1, 2, 1, 2, 2], - } - - result = df.with_columns(b_cumcount=nw.col("b").cum_count().over("a")) + result = df.with_columns( + nw.col("b", "c").cum_count().over("a").name.suffix("_cumcount") + ) assert_equal_data(result, expected) @@ -119,14 +105,15 @@ def test_over_cummax(request: pytest.FixtureRequest, constructor: Constructor) - request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) + df = nw.from_native(constructor(data_cum)) expected = { "a": ["a", "a", "b", "b", "b"], - "b": [1, 2, 3, 5, 3], + "b": [1, 2, None, 5, 3], "c": [5, 4, 3, 2, 1], - "b_cummax": [1, 2, 3, 5, 5], + "b_cummax": [1, 2, None, 5, 5], + "c_cummax": [5, 5, 3, 3, 3], } - result = df.with_columns(b_cummax=nw.col("b").cum_max().over("a")) + result = df.with_columns(nw.col("b", "c").cum_max().over("a").name.suffix("_cummax")) assert_equal_data(result, expected) @@ -137,15 +124,16 @@ def test_over_cummin(request: pytest.FixtureRequest, constructor: Constructor) - if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) + df = nw.from_native(constructor(data_cum)) expected = { "a": ["a", "a", "b", "b", "b"], - "b": [1, 2, 3, 5, 3], + "b": [1, 2, None, 5, 3], "c": [5, 4, 3, 2, 1], - "b_cummin": [1, 1, 3, 3, 3], + "b_cummin": [1, 1, None, 5, 3], + "c_cummin": [5, 4, 3, 2, 1], } - result = df.with_columns(b_cummin=nw.col("b").cum_min().over("a")) + result = df.with_columns(nw.col("b", "c").cum_min().over("a").name.suffix("_cummin")) assert_equal_data(result, expected) @@ -156,15 +144,18 @@ def test_over_cumprod(request: pytest.FixtureRequest, constructor: Constructor) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) + df = nw.from_native(constructor(data_cum)) expected = { "a": ["a", "a", "b", "b", "b"], - "b": [1, 2, 3, 5, 3], + "b": [1, 2, None, 5, 3], "c": [5, 4, 3, 2, 1], - "b_cumprod": [1, 2, 3, 15, 45], + "b_cumprod": [1, 2, None, 5, 15], + "c_cumprod": [5, 20, 3, 6, 6], } - result = df.with_columns(b_cumprod=nw.col("b").cum_prod().over("a")) + result = df.with_columns( + nw.col("b", "c").cum_prod().over("a").name.suffix("_cumprod") + ) assert_equal_data(result, expected) @@ -172,3 +163,13 @@ def test_over_anonymous() -> None: df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) with pytest.raises(ValueError, match="Anonymous expressions"): nw.from_native(df).select(nw.all().cum_max().over("a")) + + +def test_over_cum_reverse() -> None: + df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + + with pytest.raises( + NotImplementedError, + match=r"Cumulative operation with `reverse=True` is not supported", + ): + nw.from_native(df).select(nw.col("b").cum_max(reverse=True).over("a"))