Skip to content

Commit

Permalink
feat: add Expr|Series.rolling_mean method (#1290)
Browse files Browse the repository at this point in the history
* feat: Series.rolling_mean

* feat: Expr.rolling_mean

* doc api reference

* add test, fix arrow

* old arrow, xfail for modin

* arrow wip

* perf arrow
  • Loading branch information
FBruzzesi authored Nov 21, 2024
1 parent f77d137 commit db9a048
Show file tree
Hide file tree
Showing 12 changed files with 651 additions and 72 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
- pipe
- quantile
- replace_strict
- rolling_mean
- rolling_sum
- round
- sample
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
- quantile
- rename
- replace_strict
- rolling_mean
- rolling_sum
- round
- sample
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,21 @@ def rolling_sum(
center=center,
)

def rolling_mean(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Delegate ``rolling_mean`` to ``reuse_series_implementation``.

    Arguments:
        window_size: Forwarded unchanged under the ``window_size`` keyword.
        min_periods: Forwarded unchanged under the ``min_periods`` keyword.
        center: Forwarded unchanged under the ``center`` keyword.

    Returns:
        The value returned by ``reuse_series_implementation`` for the
        ``"rolling_mean"`` method name.
    """
    rolling_kwargs = {
        "window_size": window_size,
        "min_periods": min_periods,
        "center": center,
    }
    return reuse_series_implementation(self, "rolling_mean", **rolling_kwargs)

@property
def dt(self: Self) -> ArrowExprDateTimeNamespace:
    """Return an ``ArrowExprDateTimeNamespace`` constructed from this object."""
    datetime_namespace = ArrowExprDateTimeNamespace(self)
    return datetime_namespace
Expand Down
51 changes: 51 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,57 @@ def rolling_sum(
result = result[offset_left + offset_right :]
return result

def rolling_mean(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute a rolling mean from windowed cumulative sums and counts.

    The windowed total is the difference between ``cum_sum`` and the same
    series shifted by ``window_size``; the windowed count is obtained the
    same way from ``cum_count``. Positions whose windowed count is below
    ``min_periods`` are replaced with null before the division.

    Arguments:
        window_size: Number of positions the cumulative series are shifted
            by to form the window.
        min_periods: Minimum windowed count required to keep a value.
            When ``None``, ``window_size`` is used.
        center: When true, the values are padded with nulls on both sides
            before the computation and the leading padding positions are
            sliced off the result.

    Returns:
        The windowed total divided by the windowed count.
    """
    import pyarrow as pa  # ignore-banned-import
    import pyarrow.compute as pc  # ignore-banned-import

    if min_periods is None:
        min_periods = window_size

    if center:
        offset_left = window_size // 2
        # Subtract one when window_size is even (bool is used as 0/1 here).
        offset_right = offset_left - (window_size % 2 == 0)

        native = self._native_series
        pad_type = native.type
        pieces = [
            pa.array([None] * offset_left, type=pad_type),
            native.combine_chunks(),
            pa.array([None] * offset_right, type=pad_type),
        ]
        padded = self._from_native_series(pa.concat_arrays(pieces))
    else:
        padded = self

    running_total = padded.cum_sum(reverse=False).fill_null(strategy="forward")
    if window_size != 0:
        window_total = running_total - running_total.shift(window_size).fill_null(0)
    else:
        window_total = running_total

    running_count = padded.cum_count(reverse=False)
    window_count = running_count - running_count.shift(window_size).fill_null(0)

    # Null out every position whose window holds fewer than min_periods values.
    has_enough = (window_count >= min_periods)._native_series
    masked_total = self._from_native_series(
        pc.if_else(has_enough, window_total._native_series, None)
    )
    means = masked_total / window_count

    if not center:
        return means
    # Drop the positions introduced by the padding.
    return means[offset_left + offset_right :]

def __iter__(self: Self) -> Iterator[Any]:
    """Yield every item produced by iterating ``self._native_series``."""
    native_iterator = self._native_series.__iter__()
    yield from native_iterator

Expand Down
26 changes: 26 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,32 @@ def func(
returns_scalar=False,
)

def rolling_mean(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Build a rolling-mean expression through ``self._from_call``.

    Arguments:
        window_size: Passed to the native ``rolling`` call as ``window``.
        min_periods: Passed to the native ``rolling`` call as ``min_periods``.
        center: Passed to the native ``rolling`` call as ``center``.

    Returns:
        The value returned by ``self._from_call`` for the ``"rolling_mean"``
        call, registered with ``returns_scalar=False``.
    """

    def func(
        _input: dask_expr.Series,
        _window: int,
        _min_periods: int | None,
        _center: bool,  # noqa: FBT001
    ) -> dask_expr.Series:
        roller = _input.rolling(
            window=_window, min_periods=_min_periods, center=_center
        )
        return roller.mean()

    call_args = (window_size, min_periods, center)
    return self._from_call(func, "rolling_mean", *call_args, returns_scalar=False)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,21 @@ def rolling_sum(
center=center,
)

def rolling_mean(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Delegate ``rolling_mean`` to ``reuse_series_implementation``.

    Arguments:
        window_size: Forwarded unchanged under the ``window_size`` keyword.
        min_periods: Forwarded unchanged under the ``min_periods`` keyword.
        center: Forwarded unchanged under the ``center`` keyword.

    Returns:
        The value returned by ``reuse_series_implementation`` for the
        ``"rolling_mean"`` method name.
    """
    delegated = reuse_series_implementation(
        self,
        "rolling_mean",
        window_size=window_size,
        min_periods=min_periods,
        center=center,
    )
    return delegated

@property
def str(self: Self) -> PandasLikeExprStringNamespace:
    """Return a ``PandasLikeExprStringNamespace`` constructed from this object."""
    string_namespace = PandasLikeExprStringNamespace(self)
    return string_namespace
Expand Down
12 changes: 12 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,18 @@ def rolling_sum(
).sum()
return self._from_native_series(result)

def rolling_mean(
    self: Self,
    window_size: int,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute a rolling mean with the native series' ``rolling(...).mean()``.

    Arguments:
        window_size: Passed to the native ``rolling`` call as ``window``.
        min_periods: Passed to the native ``rolling`` call as ``min_periods``.
        center: Passed to the native ``rolling`` call as ``center``.

    Returns:
        The native result wrapped with ``self._from_native_series``.
    """
    rolling_window = self._native_series.rolling(
        window=window_size, min_periods=min_periods, center=center
    )
    return self._from_native_series(rolling_window.mean())

def __iter__(self: Self) -> Iterator[Any]:
    """Yield every item produced by iterating ``self._native_series``."""
    native_iterator = self._native_series.__iter__()
    yield from native_iterator

Expand Down
130 changes: 98 additions & 32 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import TypeVar

from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import _validate_rolling_arguments
from narwhals.utils import flatten

if TYPE_CHECKING:
Expand Down Expand Up @@ -3190,22 +3190,22 @@ def rolling_sum(
We define a library agnostic function:
>>> def my_library_agnostic_function(df_native: IntoFrameT) -> IntoFrameT:
>>> def agnostic_rolling_sum(df_native: IntoFrameT) -> IntoFrameT:
... df = nw.from_native(df_native)
... return df.with_columns(
... b=nw.col("a").rolling_sum(window_size=3, min_periods=1)
... ).to_native()
We can then pass any supported library such as Pandas, Polars, or PyArrow to `agnostic_rolling_sum`:
>>> my_library_agnostic_function(df_pd)
>>> agnostic_rolling_sum(df_pd)
a b
0 1.0 1.0
1 2.0 3.0
2 NaN 3.0
3 4.0 6.0
>>> my_library_agnostic_function(df_pl)
>>> agnostic_rolling_sum(df_pl)
shape: (4, 2)
┌──────┬─────┐
│ a ┆ b │
Expand All @@ -3218,46 +3218,112 @@ def rolling_sum(
│ 4.0 ┆ 6.0 │
└──────┴─────┘
>>> my_library_agnostic_function(df_pa) # doctest:+ELLIPSIS
>>> agnostic_rolling_sum(df_pa) # doctest:+ELLIPSIS
pyarrow.Table
a: double
b: double
----
a: [[1,2,null,4]]
b: [[1,3,3,6]]
"""
if window_size < 1:
msg = "window_size must be greater or equal than 1"
raise ValueError(msg)
window_size, min_periods = _validate_rolling_arguments(
window_size=window_size, min_periods=min_periods
)

if not isinstance(window_size, int):
_type = window_size.__class__.__name__
msg = (
f"argument 'window_size': '{_type}' object cannot be "
"interpreted as an integer"
return self.__class__(
lambda plx: self._call(plx).rolling_sum(
window_size=window_size,
min_periods=min_periods,
center=center,
)
raise TypeError(msg)
)

if min_periods is not None:
if min_periods < 1:
msg = "min_periods must be greater or equal than 1"
raise ValueError(msg)

if not isinstance(min_periods, int):
_type = min_periods.__class__.__name__
msg = (
f"argument 'min_periods': '{_type}' object cannot be "
"interpreted as an integer"
)
raise TypeError(msg)
if min_periods > window_size:
msg = "`min_periods` must be less or equal than `window_size`"
raise InvalidOperationError(msg)
else:
min_periods = window_size
def rolling_mean(
self: Self,
window_size: int,
*,
min_periods: int | None = None,
center: bool = False,
) -> Self:
"""Apply a rolling mean (moving mean) over the values.
!!! warning
This functionality is considered **unstable**. It may be changed at any point
without it being considered a breaking change.
A window of length `window_size` will traverse the values. The resulting values
will be aggregated to their mean.
The window at a given row will include the row itself and the `window_size - 1`
elements before it.
Arguments:
window_size: The length of the window in number of elements. It must be a
strictly positive integer.
min_periods: The number of values in the window that should be non-null before
computing a result. If set to `None` (default), it will be set equal to
`window_size`. If provided, it must be a strictly positive integer, and
less than or equal to `window_size`.
center: Set the labels at the center of the window.
Returns:
A new expression.
Examples:
>>> import narwhals as nw
>>> from narwhals.typing import IntoFrameT
>>> import pandas as pd
>>> import polars as pl
>>> import pyarrow as pa
>>> data = {"a": [1.0, 2.0, None, 4.0]}
>>> df_pd = pd.DataFrame(data)
>>> df_pl = pl.DataFrame(data)
>>> df_pa = pa.table(data)
We define a library agnostic function:
>>> def agnostic_rolling_mean(df_native: IntoFrameT) -> IntoFrameT:
... df = nw.from_native(df_native)
... return df.with_columns(
... b=nw.col("a").rolling_mean(window_size=3, min_periods=1)
... ).to_native()
We can then pass any supported library such as Pandas, Polars, or PyArrow to `agnostic_rolling_mean`:
>>> agnostic_rolling_mean(df_pd)
a b
0 1.0 1.0
1 2.0 1.5
2 NaN 1.5
3 4.0 3.0
>>> agnostic_rolling_mean(df_pl)
shape: (4, 2)
┌──────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ f64 ┆ f64 │
╞══════╪═════╡
│ 1.0 ┆ 1.0 │
│ 2.0 ┆ 1.5 │
│ null ┆ 1.5 │
│ 4.0 ┆ 3.0 │
└──────┴─────┘
>>> agnostic_rolling_mean(df_pa) # doctest:+ELLIPSIS
pyarrow.Table
a: double
b: double
----
a: [[1,2,null,4]]
b: [[1,1.5,1.5,3]]
"""
window_size, min_periods = _validate_rolling_arguments(
window_size=window_size, min_periods=min_periods
)

return self.__class__(
lambda plx: self._call(plx).rolling_sum(
lambda plx: self._call(plx).rolling_mean(
window_size=window_size,
min_periods=min_periods,
center=center,
Expand Down
Loading

0 comments on commit db9a048

Please sign in to comment.