From db9a04843f6e53779ef41514b1cd0e335b45d11a Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Thu, 21 Nov 2024 09:23:20 +0100 Subject: [PATCH] feat: add `Expr|Series.rolling_mean` method (#1290) * feat: Series.rolling_mean * feat: Expr.rolling_mean * doc api reference * add test, fix arrow * old arrow, xfail for modin * arrow wip * perf arrow --- docs/api-reference/expr.md | 1 + docs/api-reference/series.md | 1 + narwhals/_arrow/expr.py | 15 ++ narwhals/_arrow/series.py | 51 ++++++ narwhals/_dask/expr.py | 26 +++ narwhals/_pandas_like/expr.py | 15 ++ narwhals/_pandas_like/series.py | 12 ++ narwhals/expr.py | 130 +++++++++---- narwhals/series.py | 131 ++++++++++---- narwhals/stable/v1/__init__.py | 201 ++++++++++++++++++++- narwhals/utils.py | 37 ++++ tests/expr_and_series/rolling_mean_test.py | 103 +++++++++++ 12 files changed, 651 insertions(+), 72 deletions(-) create mode 100644 tests/expr_and_series/rolling_mean_test.py diff --git a/docs/api-reference/expr.md b/docs/api-reference/expr.md index 694ae504b..3ca53934b 100644 --- a/docs/api-reference/expr.md +++ b/docs/api-reference/expr.md @@ -45,6 +45,7 @@ - pipe - quantile - replace_strict + - rolling_mean - rolling_sum - round - sample diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md index d0cf7875f..7e7c85230 100644 --- a/docs/api-reference/series.md +++ b/docs/api-reference/series.md @@ -52,6 +52,7 @@ - quantile - rename - replace_strict + - rolling_mean - rolling_sum - round - sample diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 7e43620bd..d9ee5a361 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -468,6 +468,21 @@ def rolling_sum( center=center, ) + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + return reuse_series_implementation( + self, + "rolling_mean", + window_size=window_size, + min_periods=min_periods, + 
center=center, + ) + @property def dt(self: Self) -> ArrowExprDateTimeNamespace: return ArrowExprDateTimeNamespace(self) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 3c41c42a6..4ced6da54 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -922,6 +922,57 @@ def rolling_sum( result = result[offset_left + offset_right :] return result + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import + + min_periods = min_periods if min_periods is not None else window_size + if center: + offset_left = window_size // 2 + offset_right = offset_left - ( + window_size % 2 == 0 + ) # subtract one if window_size is even + + native_series = self._native_series + + pad_left = pa.array([None] * offset_left, type=native_series.type) + pad_right = pa.array([None] * offset_right, type=native_series.type) + padded_arr = self._from_native_series( + pa.concat_arrays([pad_left, native_series.combine_chunks(), pad_right]) + ) + else: + padded_arr = self + + cum_sum = padded_arr.cum_sum(reverse=False).fill_null(strategy="forward") + rolling_sum = ( + cum_sum - cum_sum.shift(window_size).fill_null(0) + if window_size != 0 + else cum_sum + ) + + valid_count = padded_arr.cum_count(reverse=False) + count_in_window = valid_count - valid_count.shift(window_size).fill_null(0) + + result = ( + self._from_native_series( + pc.if_else( + (count_in_window >= min_periods)._native_series, + rolling_sum._native_series, + None, + ) + ) + / count_in_window + ) + if center: + result = result[offset_left + offset_right :] + return result + def __iter__(self: Self) -> Iterator[Any]: yield from self._native_series.__iter__() diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index b70635dc6..92d670908 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -880,6 +880,32 @@ def func( 
returns_scalar=False, ) + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + def func( + _input: dask_expr.Series, + _window: int, + _min_periods: int | None, + _center: bool, # noqa: FBT001 + ) -> dask_expr.Series: + return _input.rolling( + window=_window, min_periods=_min_periods, center=_center + ).mean() + + return self._from_call( + func, + "rolling_mean", + window_size, + min_periods, + center, + returns_scalar=False, + ) + class DaskExprStringNamespace: def __init__(self, expr: DaskExpr) -> None: diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index fa769e790..ebbc05fe5 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -502,6 +502,21 @@ def rolling_sum( center=center, ) + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + return reuse_series_implementation( + self, + "rolling_mean", + window_size=window_size, + min_periods=min_periods, + center=center, + ) + @property def str(self: Self) -> PandasLikeExprStringNamespace: return PandasLikeExprStringNamespace(self) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 9cbdfd6af..0fca8ca4f 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -829,6 +829,18 @@ def rolling_sum( ).sum() return self._from_native_series(result) + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None, + center: bool, + ) -> Self: + result = self._native_series.rolling( + window=window_size, min_periods=min_periods, center=center + ).mean() + return self._from_native_series(result) + def __iter__(self: Self) -> Iterator[Any]: yield from self._native_series.__iter__() diff --git a/narwhals/expr.py b/narwhals/expr.py index 88c551315..6802b4152 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -11,7 +11,7 @@ from typing import TypeVar from 
narwhals.dependencies import is_numpy_array -from narwhals.exceptions import InvalidOperationError +from narwhals.utils import _validate_rolling_arguments from narwhals.utils import flatten if TYPE_CHECKING: @@ -3190,7 +3190,7 @@ def rolling_sum( We define a library agnostic function: - >>> def my_library_agnostic_function(df_native: IntoFrameT) -> IntoFrameT: + >>> def agnostic_rolling_sum(df_native: IntoFrameT) -> IntoFrameT: ... df = nw.from_native(df_native) ... return df.with_columns( ... b=nw.col("a").rolling_sum(window_size=3, min_periods=1) @@ -3198,14 +3198,14 @@ def rolling_sum( We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: - >>> my_library_agnostic_function(df_pd) + >>> agnostic_rolling_sum(df_pd) a b 0 1.0 1.0 1 2.0 3.0 2 NaN 3.0 3 4.0 6.0 - >>> my_library_agnostic_function(df_pl) + >>> agnostic_rolling_sum(df_pl) shape: (4, 2) ┌──────┬─────┐ │ a ┆ b │ @@ -3218,7 +3218,7 @@ def rolling_sum( │ 4.0 ┆ 6.0 │ └──────┴─────┘ - >>> my_library_agnostic_function(df_pa) # doctest:+ELLIPSIS + >>> agnostic_rolling_sum(df_pa) # doctest:+ELLIPSIS pyarrow.Table a: double b: double @@ -3226,38 +3226,104 @@ def rolling_sum( a: [[1,2,null,4]] b: [[1,3,3,6]] """ - if window_size < 1: - msg = "window_size must be greater or equal than 1" - raise ValueError(msg) + window_size, min_periods = _validate_rolling_arguments( + window_size=window_size, min_periods=min_periods + ) - if not isinstance(window_size, int): - _type = window_size.__class__.__name__ - msg = ( - f"argument 'window_size': '{_type}' object cannot be " - "interpreted as an integer" + return self.__class__( + lambda plx: self._call(plx).rolling_sum( + window_size=window_size, + min_periods=min_periods, + center=center, ) - raise TypeError(msg) + ) - if min_periods is not None: - if min_periods < 1: - msg = "min_periods must be greater or equal than 1" - raise ValueError(msg) - - if not isinstance(min_periods, int): - _type = min_periods.__class__.__name__ - msg = ( - 
f"argument 'min_periods': '{_type}' object cannot be " - "interpreted as an integer" - ) - raise TypeError(msg) - if min_periods > window_size: - msg = "`min_periods` must be less or equal than `window_size`" - raise InvalidOperationError(msg) - else: - min_periods = window_size + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling mean (moving mean) over the values. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their mean. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. It must be a + strictly positive integer. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` + center: Set the labels at the center of the window. + + Returns: + A new expression. + + Examples: + >>> import narwhals as nw + >>> from narwhals.typing import IntoFrameT + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = {"a": [1.0, 2.0, None, 4.0]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) + + We define a library agnostic function: + + >>> def agnostic_rolling_mean(df_native: IntoFrameT) -> IntoFrameT: + ... df = nw.from_native(df_native) + ... return df.with_columns( + ... b=nw.col("a").rolling_mean(window_size=3, min_periods=1) + ... 
).to_native() + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `agnostic_rolling_mean`: + + >>> agnostic_rolling_mean(df_pd) + a b + 0 1.0 1.0 + 1 2.0 1.5 + 2 NaN 1.5 + 3 4.0 3.0 + + >>> agnostic_rolling_mean(df_pl) + shape: (4, 2) + ┌──────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ f64 ┆ f64 │ + ╞══════╪═════╡ + │ 1.0 ┆ 1.0 │ + │ 2.0 ┆ 1.5 │ + │ null ┆ 1.5 │ + │ 4.0 ┆ 3.0 │ + └──────┴─────┘ + + >>> agnostic_rolling_mean(df_pa) # doctest:+ELLIPSIS + pyarrow.Table + a: double + b: double + ---- + a: [[1,2,null,4]] + b: [[1,1.5,1.5,3]] + """ + window_size, min_periods = _validate_rolling_arguments( + window_size=window_size, min_periods=min_periods + ) return self.__class__( - lambda plx: self._call(plx).rolling_sum( + lambda plx: self._call(plx).rolling_mean( window_size=window_size, min_periods=min_periods, center=center, diff --git a/narwhals/series.py b/narwhals/series.py index c153fd1c8..ac827303e 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -11,7 +11,7 @@ from typing import TypeVar from typing import overload -from narwhals.exceptions import InvalidOperationError +from narwhals.utils import _validate_rolling_arguments from narwhals.utils import parse_version if TYPE_CHECKING: @@ -3128,20 +3128,20 @@ def rolling_sum( We define a library agnostic function: - >>> def my_library_agnostic_function(s_native: IntoSeriesT) -> IntoSeriesT: + >>> def agnostic_rolling_sum(s_native: IntoSeriesT) -> IntoSeriesT: ... s = nw.from_native(s_native, series_only=True) ... 
return s.rolling_sum(window_size=2).to_native() We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: - >>> my_library_agnostic_function(s_pd) + >>> agnostic_rolling_sum(s_pd) 0 NaN 1 3.0 2 5.0 3 7.0 dtype: float64 - >>> my_library_agnostic_function(s_pl) # doctest:+NORMALIZE_WHITESPACE + >>> agnostic_rolling_sum(s_pl) # doctest:+NORMALIZE_WHITESPACE shape: (4,) Series: '' [f64] [ @@ -3151,7 +3151,7 @@ def rolling_sum( 7.0 ] - >>> my_library_agnostic_function(s_pa) # doctest:+ELLIPSIS + >>> agnostic_rolling_sum(s_pa) # doctest:+ELLIPSIS [ [ @@ -3162,41 +3162,108 @@ def rolling_sum( ] ] """ - if window_size < 1: - msg = "window_size must be greater or equal than 1" - raise ValueError(msg) + window_size, min_periods = _validate_rolling_arguments( + window_size=window_size, min_periods=min_periods + ) - if not isinstance(window_size, int): - _type = window_size.__class__.__name__ - msg = ( - f"argument 'window_size': '{_type}' object cannot be " - "interpreted as an integer" + if len(self) == 0: # pragma: no cover + return self + + return self._from_compliant_series( + self._compliant_series.rolling_sum( + window_size=window_size, + min_periods=min_periods, + center=center, ) - raise TypeError(msg) + ) - if min_periods is not None: - if min_periods < 1: - msg = "min_periods must be greater or equal than 1" - raise ValueError(msg) - - if not isinstance(min_periods, int): - _type = min_periods.__class__.__name__ - msg = ( - f"argument 'min_periods': '{_type}' object cannot be " - "interpreted as an integer" - ) - raise TypeError(msg) - if min_periods > window_size: - msg = "`min_periods` must be less or equal than `window_size`" - raise InvalidOperationError(msg) - else: - min_periods = window_size + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling mean (moving mean) over the values. + + !!! 
warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their mean. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. It must be a + strictly positive integer. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` + center: Set the labels at the center of the window. + + Returns: + A new series. + + Examples: + >>> import narwhals as nw + >>> from narwhals.typing import IntoSeriesT + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = [1.0, 2.0, 3.0, 4.0] + >>> s_pd = pd.Series(data) + >>> s_pl = pl.Series(data) + >>> s_pa = pa.chunked_array([data]) + + We define a library agnostic function: + + >>> def agnostic_rolling_mean(s_native: IntoSeriesT) -> IntoSeriesT: + ... s = nw.from_native(s_native, series_only=True) + ... 
return s.rolling_mean(window_size=2).to_native() + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `agnostic_rolling_mean`: + + >>> agnostic_rolling_mean(s_pd) + 0 NaN + 1 1.5 + 2 2.5 + 3 3.5 + dtype: float64 + + >>> agnostic_rolling_mean(s_pl) # doctest:+NORMALIZE_WHITESPACE + shape: (4,) + Series: '' [f64] + [ + null + 1.5 + 2.5 + 3.5 + ] + + >>> agnostic_rolling_mean(s_pa) # doctest:+ELLIPSIS + + [ + [ + null, + 1.5, + 2.5, + 3.5 + ] + ] + """ + window_size, min_periods = _validate_rolling_arguments( + window_size=window_size, min_periods=min_periods + ) if len(self) == 0: # pragma: no cover return self return self._from_compliant_series( - self._compliant_series.rolling_sum( + self._compliant_series.rolling_mean( window_size=window_size, min_periods=min_periods, center=center, diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 8769d9102..75b968c2d 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -673,20 +673,20 @@ def rolling_sum( We define a library agnostic function: - >>> def my_library_agnostic_function(s_native: IntoSeriesT) -> IntoSeriesT: + >>> def agnostic_rolling_sum(s_native: IntoSeriesT) -> IntoSeriesT: ... s = nw.from_native(s_native, series_only=True) ... 
return s.rolling_sum(window_size=2).to_native() We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: - >>> my_library_agnostic_function(s_pd) + >>> agnostic_rolling_sum(s_pd) 0 NaN 1 3.0 2 5.0 3 7.0 dtype: float64 - >>> my_library_agnostic_function(s_pl) # doctest:+NORMALIZE_WHITESPACE + >>> agnostic_rolling_sum(s_pl) # doctest:+NORMALIZE_WHITESPACE shape: (4,) Series: '' [f64] [ @@ -696,7 +696,7 @@ def rolling_sum( 7.0 ] - >>> my_library_agnostic_function(s_pa) # doctest:+ELLIPSIS + >>> agnostic_rolling_sum(s_pa) # doctest:+ELLIPSIS [ [ @@ -721,6 +721,98 @@ def rolling_sum( center=center, ) + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling mean (moving mean) over the values. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their mean. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. It must be a + strictly positive integer. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` + center: Set the labels at the center of the window. + + Returns: + A new series. 
+ + Examples: + >>> import narwhals as nw + >>> from narwhals.typing import IntoSeriesT + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = [1.0, 2.0, 3.0, 4.0] + >>> s_pd = pd.Series(data) + >>> s_pl = pl.Series(data) + >>> s_pa = pa.chunked_array([data]) + + We define a library agnostic function: + + >>> def agnostic_rolling_mean(s_native: IntoSeriesT) -> IntoSeriesT: + ... s = nw.from_native(s_native, series_only=True) + ... return s.rolling_mean(window_size=2).to_native() + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `agnostic_rolling_mean`: + + >>> agnostic_rolling_mean(s_pd) + 0 NaN + 1 1.5 + 2 2.5 + 3 3.5 + dtype: float64 + + >>> agnostic_rolling_mean(s_pl) # doctest:+NORMALIZE_WHITESPACE + shape: (4,) + Series: '' [f64] + [ + null + 1.5 + 2.5 + 3.5 + ] + + >>> agnostic_rolling_mean(s_pa) # doctest:+ELLIPSIS + + [ + [ + null, + 1.5, + 2.5, + 3.5 + ] + ] + """ + from narwhals.exceptions import NarwhalsUnstableWarning + from narwhals.utils import find_stacklevel + + msg = ( + "`Series.rolling_mean` is being called from the stable API although considered " + "an unstable feature." + ) + warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel()) + return super().rolling_mean( + window_size=window_size, + min_periods=min_periods, + center=center, + ) + class Expr(NwExpr): def _l1_norm(self) -> Self: @@ -877,21 +969,21 @@ def rolling_sum( We define a library agnostic function: >>> @nw.narwhalify - ... def func(df): + ... def agnostic_rolling_sum(df): ... return df.with_columns( ... b=nw.col("a").rolling_sum(window_size=3, min_periods=1) ... 
) We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`: - >>> func(df_pd) + >>> agnostic_rolling_sum(df_pd) a b 0 1.0 1.0 1 2.0 3.0 2 NaN 3.0 3 4.0 6.0 - >>> func(df_pl) + >>> agnostic_rolling_sum(df_pl) shape: (4, 2) ┌──────┬─────┐ │ a ┆ b │ @@ -904,7 +996,7 @@ def rolling_sum( │ 4.0 ┆ 6.0 │ └──────┴─────┘ - >>> func(df_pa) # doctest:+ELLIPSIS + >>> agnostic_rolling_sum(df_pa) # doctest:+ELLIPSIS pyarrow.Table a: double b: double @@ -926,6 +1018,99 @@ def rolling_sum( center=center, ) + def rolling_mean( + self: Self, + window_size: int, + *, + min_periods: int | None = None, + center: bool = False, + ) -> Self: + """Apply a rolling mean (moving mean) over the values. + + !!! warning + This functionality is considered **unstable**. It may be changed at any point + without it being considered a breaking change. + + A window of length `window_size` will traverse the values. The resulting values + will be aggregated to their mean. + + The window at a given row will include the row itself and the `window_size - 1` + elements before it. + + Arguments: + window_size: The length of the window in number of elements. It must be a + strictly positive integer. + min_periods: The number of values in the window that should be non-null before + computing a result. If set to `None` (default), it will be set equal to + `window_size`. If provided, it must be a strictly positive integer, and + less than or equal to `window_size` + center: Set the labels at the center of the window. + + Returns: + A new expression. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = {"a": [1.0, 2.0, None, 4.0]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + >>> df_pa = pa.table(data) + + We define a library agnostic function: + + >>> @nw.narwhalify + ... def agnostic_rolling_mean(df): + ... return df.with_columns( + ... 
b=nw.col("a").rolling_mean(window_size=3, min_periods=1) + ... ) + + We can then pass any supported library such as Pandas, Polars, or PyArrow to `agnostic_rolling_mean`: + + >>> agnostic_rolling_mean(df_pd) + a b + 0 1.0 1.0 + 1 2.0 1.5 + 2 NaN 1.5 + 3 4.0 3.0 + + >>> agnostic_rolling_mean(df_pl) + shape: (4, 2) + ┌──────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ f64 ┆ f64 │ + ╞══════╪═════╡ + │ 1.0 ┆ 1.0 │ + │ 2.0 ┆ 1.5 │ + │ null ┆ 1.5 │ + │ 4.0 ┆ 3.0 │ + └──────┴─────┘ + + >>> agnostic_rolling_mean(df_pa) # doctest:+ELLIPSIS + pyarrow.Table + a: double + b: double + ---- + a: [[1,2,null,4]] + b: [[1,1.5,1.5,3]] + """ + from narwhals.exceptions import NarwhalsUnstableWarning + from narwhals.utils import find_stacklevel + + msg = ( + "`Expr.rolling_mean` is being called from the stable API although considered " + "an unstable feature." + ) + warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel()) + return super().rolling_mean( + window_size=window_size, + min_periods=min_periods, + center=center, + ) + class Schema(NwSchema): """Ordered mapping of column names to their data type. 
diff --git a/narwhals/utils.py b/narwhals/utils.py index 7cc41f78a..d3cb0ddaf 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -28,6 +28,7 @@ from narwhals.dependencies import is_polars_series from narwhals.dependencies import is_pyarrow_chunked_array from narwhals.exceptions import ColumnNotFoundError +from narwhals.exceptions import InvalidOperationError from narwhals.translate import to_native if TYPE_CHECKING: @@ -732,3 +733,39 @@ def validate_strict_and_pass_though( msg = "Cannot pass both `strict` and `pass_through`" raise ValueError(msg) return pass_through + + +def _validate_rolling_arguments( + window_size: int, min_periods: int | None +) -> tuple[int, int]: + if window_size < 1: + msg = "window_size must be greater or equal than 1" + raise ValueError(msg) + + if not isinstance(window_size, int): + _type = window_size.__class__.__name__ + msg = ( + f"argument 'window_size': '{_type}' object cannot be " + "interpreted as an integer" + ) + raise TypeError(msg) + + if min_periods is not None: + if min_periods < 1: + msg = "min_periods must be greater or equal than 1" + raise ValueError(msg) + + if not isinstance(min_periods, int): + _type = min_periods.__class__.__name__ + msg = ( + f"argument 'min_periods': '{_type}' object cannot be " + "interpreted as an integer" + ) + raise TypeError(msg) + if min_periods > window_size: + msg = "`min_periods` must be less or equal than `window_size`" + raise InvalidOperationError(msg) + else: + min_periods = window_size + + return window_size, min_periods diff --git a/tests/expr_and_series/rolling_mean_test.py b/tests/expr_and_series/rolling_mean_test.py new file mode 100644 index 000000000..3ffd8749a --- /dev/null +++ b/tests/expr_and_series/rolling_mean_test.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import random + +import hypothesis.strategies as st +import pandas as pd +import pyarrow as pa +import pytest +from hypothesis import given + +import narwhals.stable.v1 as nw +from tests.utils 
import PANDAS_VERSION +from tests.utils import Constructor +from tests.utils import ConstructorEager +from tests.utils import assert_equal_data + +data = {"a": [None, 1, 2, None, 4, 6, 11]} + +kwargs_and_expected = { + "x1": {"kwargs": {"window_size": 3}, "expected": [float("nan")] * 6 + [7.0]}, + "x2": { + "kwargs": {"window_size": 3, "min_periods": 1}, + "expected": [float("nan"), 1.0, 1.5, 1.5, 3.0, 5.0, 7.0], + }, + "x3": { + "kwargs": {"window_size": 2, "min_periods": 1}, + "expected": [float("nan"), 1.0, 1.5, 2.0, 4.0, 5.0, 8.5], + }, + "x4": { + "kwargs": {"window_size": 5, "min_periods": 1, "center": True}, + "expected": [1.5, 1.5, 7 / 3, 3.25, 5.75, 7.0, 7.0], + }, + "x5": { + "kwargs": {"window_size": 4, "min_periods": 1, "center": True}, + "expected": [1.0, 1.5, 1.5, 7 / 3, 4.0, 7.0, 7.0], + }, +} + + +@pytest.mark.filterwarnings( + "ignore:`Expr.rolling_mean` is being called from the stable API although considered an unstable feature." +) +def test_rolling_mean_expr( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if "dask" in str(constructor): + # TODO(FBruzzesi): Dask is raising the following error: + # NotImplementedError: Partition size is less than overlapping window size. + # Try using ``df.repartition`` to increase the partition size. + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select( + **{ + name: nw.col("a").rolling_mean(**values["kwargs"]) # type: ignore[arg-type] + for name, values in kwargs_and_expected.items() + } + ) + expected = {name: values["expected"] for name, values in kwargs_and_expected.items()} + + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Series.rolling_mean` is being called from the stable API although considered an unstable feature." 
+) +def test_rolling_mean_series(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + + result = df.select( + **{ + name: df["a"].rolling_mean(**values["kwargs"]) # type: ignore[arg-type] + for name, values in kwargs_and_expected.items() + } + ) + expected = {name: values["expected"] for name, values in kwargs_and_expected.items()} + assert_equal_data(result, expected) + + +@given( # type: ignore[misc] + center=st.booleans(), + values=st.lists(st.floats(-10, 10), min_size=3, max_size=10), +) +@pytest.mark.skipif(PANDAS_VERSION < (1,), reason="too old for pyarrow") +@pytest.mark.filterwarnings("ignore:.*:narwhals.exceptions.NarwhalsUnstableWarning") +def test_rolling_mean_hypothesis(center: bool, values: list[float]) -> None: # noqa: FBT001 + s = pd.Series(values) + n_missing = random.randint(0, len(s) - 1) # noqa: S311 + window_size = random.randint(1, len(s)) # noqa: S311 + min_periods = random.randint(1, window_size) # noqa: S311 + mask = random.sample(range(len(s)), n_missing) + s[mask] = None + df = pd.DataFrame({"a": s}) + expected = ( + s.rolling(window=window_size, center=center, min_periods=min_periods) + .mean() + .to_frame("a") + ) + result = nw.from_native(pa.Table.from_pandas(df)).select( + nw.col("a").rolling_mean(window_size, center=center, min_periods=min_periods) + ) + expected_dict = nw.from_native(expected, eager_only=True).to_dict(as_series=False) + assert_equal_data(result, expected_dict)