Skip to content

Commit

Permalink
feat: Series.rolling_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Oct 30, 2024
1 parent d5feb6f commit 136889e
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 1 deletion.
31 changes: 30 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Sequence
from typing import overload

from narwhals._arrow.utils import _rolling
from narwhals._arrow.utils import cast_for_truediv
from narwhals._arrow.utils import floordiv_compat
from narwhals._arrow.utils import narwhals_to_native_dtype
Expand Down Expand Up @@ -714,13 +715,41 @@ def clip(
def to_arrow(self: Self) -> pa.Array:
    """Return the underlying native series with its chunks combined.

    The result is whatever ``combine_chunks()`` returns on the wrapped
    native series.
    """
    native = self._native_series
    combined = native.combine_chunks()
    return combined

def mode(self: Self) -> ArrowSeries:
def mode(self: Self) -> Self:
    """Return the most frequently occurring value(s) of the series.

    Value counts are computed under a temporary column name, the rows whose
    count equals the maximum count are kept, and the column carrying this
    series' name is returned.
    """
    namespace = self.__narwhals_namespace__()
    # Temporary name for the count column, generated with this series' name
    # passed as the list of existing columns.
    count_name = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    counts = self.value_counts(name=count_name, normalize=False)
    is_most_frequent = namespace.col(count_name) == namespace.col(count_name).max()
    return counts.filter(is_most_frequent)[self.name]

def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute the rolling mean of the series.

    The windows are produced by the ``_rolling`` helper; every argument is
    forwarded to it unchanged. For each window the helper yields, the mean
    is taken with ``pyarrow.compute.mean``; positions for which the helper
    yields ``None`` stay ``None`` in the result.

    Arguments:
        window_size: Forwarded to ``_rolling`` as the window length.
        weights: Forwarded to ``_rolling`` (optional per-position weights).
        min_periods: Forwarded to ``_rolling``.
        center: Forwarded to ``_rolling``.

    Returns:
        A new series wrapping a single-chunk chunked array of the means.
    """
    import pyarrow as pa
    import pyarrow.compute as pc

    windows = _rolling(
        self._native_series,
        window_size=window_size,
        weights=weights,
        min_periods=min_periods,
        center=center,
    )

    means = []
    for window in windows:
        # ``None`` marks a position for which no window was produced.
        means.append(None if window is None else pc.mean(window))

    return self._from_native_series(pa.chunked_array([means]))

def __iter__(self: Self) -> Iterator[Any]:
    """Yield the elements of the wrapped native series one at a time."""
    native_iterator = self._native_series.__iter__()
    yield from native_iterator

Expand Down
43 changes: 43 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Generator
from typing import Sequence

from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -420,3 +421,45 @@ def _parse_time_format(arr: pa.Array) -> str:

matches = pc.extract_regex(arr, pattern=TIME_RE)
return "%H:%M:%S" if pc.all(matches.is_valid()).as_py() else ""


def _rolling(
    array: pa.ChunkedArray,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Generator[pa.Array | None, None, None]:
    """Yield one (optionally weighted) null-free window per element of `array`.

    For every position `i` a window of nominal length `window_size` is taken:

    - `center=False`: the window ends at `i` and covers the `window_size - 1`
      preceding elements.
    - `center=True`: the window starts `window_size // 2` elements before `i`
      and always has nominal length `window_size`, including for even sizes.

    Windows that would extend past either end of the array are truncated.
    When `weights` is given, only the weights belonging to the surviving
    positions are applied, so a truncated window is multiplied by the matching
    slice of `weights` rather than the full-length vector.

    Arguments:
        array: Chunked array to roll over.
        window_size: Nominal number of elements per window.
        weights: Optional per-position multipliers. A falsy value (`None` or
            an empty list) means "no weights".
        min_periods: Minimum number of non-null values a window must hold to
            be yielded. Defaults to `window_size` when `None`.
        center: Whether the window is centred on the current element.

    Yields:
        The window with nulls dropped, or `None` when fewer than
        `min_periods` non-null values remain.

    Raises:
        ValueError: If `weights` is non-empty and its length differs from
            `window_size`.
    """
    import pyarrow as pa
    import pyarrow.compute as pc

    # Default min_periods to window_size if not provided
    if min_periods is None:
        min_periods = window_size

    if weights:
        if len(weights) != window_size:
            msg = (
                "`weights` must have the same length as `window_size`, "
                f"got {len(weights)} and {window_size}"
            )
            raise ValueError(msg)
        weights_array = pa.array(weights)
    else:
        weights_array = None
    # Without weights every element is scaled by 1, i.e. left unchanged.
    unit_weight = pa.scalar(1)

    # Flatten the chunked array to work with it as a contiguous array
    flat_array = array.combine_chunks()
    size = len(flat_array)

    # Distance from the current row back to the first element of its
    # (untruncated) window.
    lookback = window_size // 2 if center else window_size - 1

    for i in range(size):
        nominal_start = i - lookback
        start = max(0, nominal_start)
        end = min(size, nominal_start + window_size)
        length = end - start

        if weights_array is None:
            factor = unit_weight
        else:
            # Keep only the weights of positions that survived truncation:
            # `start - nominal_start` elements were cut off on the left.
            factor = weights_array.slice(start - nominal_start, length)

        weighted_window = pc.drop_null(
            pc.multiply(flat_array.slice(start, length), factor)
        )

        if len(weighted_window) >= min_periods:
            yield weighted_window
        else:
            yield None
19 changes: 19 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,25 @@ def mode(self: Self) -> Self:
result.name = native_series.name
return self._from_native_series(result)

def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None,
    *,
    min_periods: int | None,
    center: bool,
) -> Self:
    """Compute the rolling mean via the native series' ``rolling`` API.

    Arguments:
        window_size: Passed to the native ``rolling`` call as ``window``.
        weights: Must be ``None``; weighted windows are not implemented for
            this backend.
        min_periods: Passed through to the native ``rolling`` call.
        center: Passed through to the native ``rolling`` call.

    Returns:
        A new series wrapping the native rolling mean.

    Raises:
        NotImplementedError: If ``weights`` is not ``None``.
    """
    if weights is None:
        rolled = self._native_series.rolling(
            window=window_size, min_periods=min_periods, center=center
        )
        return self._from_native_series(rolled.mean())

    msg = f"`weights` argument is not supported for {self._implementation} backend"
    raise NotImplementedError(msg)

def __iter__(self: Self) -> Iterator[Any]:
    """Iterate over the values of the wrapped native series."""
    values = self._native_series.__iter__()
    yield from values

Expand Down
80 changes: 80 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2525,6 +2525,86 @@ def mode(self: Self) -> Self:
"""
return self._from_compliant_series(self._compliant_series.mode())

def rolling_mean(
    self: Self,
    window_size: int,
    weights: list[float] | None = None,
    *,
    min_periods: int | None = None,
    center: bool = False,
) -> Self:
    """
    Apply a rolling mean (moving mean) over the values of the series.

    A window of length `window_size` traverses the series. The values that
    fill this window are (optionally) multiplied by the entries of the
    `weights` vector, and the resulting values are aggregated to their mean.

    By default the window at a given row includes the row itself and the
    `window_size - 1` elements before it.

    Arguments:
        window_size: The length of the window in number of elements.
        weights: An optional list with the same length as the window that will be
            multiplied elementwise with the values in the window.
        min_periods: The number of values in the window that should be non-null before
            computing a result. If set to `None` (default), it will be set equal to
            `window_size`.
        center: Set the labels at the center of the window.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa
        >>> data = [100, 200, 300]
        >>> s_pd = pd.Series(name="a", data=data)
        >>> s_pl = pl.Series(name="a", values=data)
        >>> s_pa = pa.chunked_array([data])

        We define a library agnostic function:

        >>> @nw.narwhalify
        ... def func(s):
        ...     return s.rolling_mean(window_size=2)

        We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:

        >>> func(s_pd)
        0      NaN
        1    150.0
        2    250.0
        Name: a, dtype: float64

        >>> func(s_pl)  # doctest:+NORMALIZE_WHITESPACE
        shape: (3,)
        Series: 'a' [f64]
        [
           null
           150.0
           250.0
        ]

        >>> func(s_pa)  # doctest:+ELLIPSIS
        <pyarrow.lib.ChunkedArray object at ...>
        [
          [
            null,
            150,
            250
          ]
        ]
    """
    # Delegate to the backend-specific implementation, then re-wrap.
    rolled = self._compliant_series.rolling_mean(
        window_size=window_size,
        weights=weights,
        min_periods=min_periods,
        center=center,
    )
    return self._from_compliant_series(rolled)

def __iter__(self: Self) -> Iterator[Any]:
    """Iterate over the values of the wrapped compliant series."""
    compliant_iterator = self._compliant_series.__iter__()
    yield from compliant_iterator

Expand Down

0 comments on commit 136889e

Please sign in to comment.