From 0e9dced7cc3a38fd27c4d7412381e7c90c155671 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Wed, 30 Oct 2024 23:28:29 +0100 Subject: [PATCH] add test, fix arrow --- narwhals/_arrow/series.py | 18 ++-- narwhals/_arrow/utils.py | 23 +++--- narwhals/_pandas_like/series.py | 2 +- tests/expr_and_series/rolling_mean_test.py | 95 ++++++++++++++++++++++ 4 files changed, 120 insertions(+), 18 deletions(-) create mode 100644 tests/expr_and_series/rolling_mean_test.py diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index cd33ee0ab..63660a6f7 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -715,7 +715,7 @@ def clip( def to_arrow(self: Self) -> pa.Array: return self._native_series.combine_chunks() - def mode(self: Self) -> Self: + def mode(self: Self) -> ArrowSeries: plx = self.__narwhals_namespace__() col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) return self.value_counts(name=col_token, normalize=False).filter( @@ -730,22 +730,26 @@ def rolling_mean( min_periods: int | None, center: bool, ) -> Self: - import pyarrow as pa - import pyarrow.compute as pc + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import native_series = self._native_series + + def weighted_mean(arr: pa.array, weights: pa.array) -> pa.scalar: + return pc.divide(pc.sum(pc.multiply(arr, weights)), pc.sum(weights)) + result = pa.chunked_array( [ - [ - pc.mean(v) if v is not None else None - for v in _rolling( + list( + _rolling( native_series, window_size=window_size, weights=weights, min_periods=min_periods, center=center, + aggregate_function=weighted_mean, ) - ] + ) ] ) return self._from_native_series(result) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 5311eedd5..0012860d1 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Generator from typing import Sequence @@ -430,36 +431,38 @@ def _rolling( *, min_periods: int | None, center: bool, + aggregate_function: Callable[[pa.array, pa.array], pa.scalar], ) -> Generator[pa.array | None, None, None]: - import pyarrow as pa - import pyarrow.compute as pc + import numpy as np # ignore-banned-import + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import # Default min_periods to window_size if not provided if min_periods is None: min_periods = window_size # Convert weights to a pyarrow array for elementwise operations if given - weights = pa.array(weights) if weights else pa.scalar(1) + weights_: pa.array = ( + pa.array(weights) if weights else pa.array(np.full(window_size, 1.0)) + ) # Flatten the chunked array to work with it as a contiguous array flat_array = array.combine_chunks() size = len(flat_array) # Calculate rolling mean by slicing the flat array for each position split_points = ( - (max(0, i - window_size // 2), min(size, i + window_size // 2 + 1)) + (max(0, i - window_size // 2), min(size, i + window_size // 2)) if center else (max(0, i - window_size + 1), i + 1) for i in range(size) ) for start, end in split_points: - weighted_window = pc.drop_null( - pc.multiply(flat_array.slice(start, end - start), weights) - ) - - num_valid = len(weighted_window) + valid_window = pc.drop_null(flat_array.slice(start, end - start)) + num_valid = len(valid_window) if num_valid >= min_periods: - yield weighted_window + valid_weights = weights_.slice(0, num_valid) + yield aggregate_function(valid_window, valid_weights) else: yield None diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index c0b1e765d..82cf2c4f6 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -693,7 +693,7 @@ def rolling_mean( ) -> Self: if weights is not None: msg = ( - "`weights` argument is not supported in `rolling_meanr` for " + "`weights` argument is not supported in `rolling_mean` for " f"{self._implementation} backend." ) raise NotImplementedError(msg) diff --git a/tests/expr_and_series/rolling_mean_test.py b/tests/expr_and_series/rolling_mean_test.py new file mode 100644 index 000000000..d99086f25 --- /dev/null +++ b/tests/expr_and_series/rolling_mean_test.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import re +from contextlib import nullcontext as does_not_raise + +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import Constructor +from tests.utils import ConstructorEager +from tests.utils import assert_equal_data + +data = {"a": [None, 1, 2, None, 4, 6, 11]} +data_weighted = {"a": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]} + + +def test_rolling_mean_expr_no_weights( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if "dask" in str(constructor): + # TODO(FBruzzesi): Dask is raising the following error: + # NotImplementedError: Partition size is less than overlapping window size. + # Try using ``df.repartition`` to increase the partition size. + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select( + x1=nw.col("a").rolling_mean(window_size=3), + x2=nw.col("a").rolling_mean(window_size=3, min_periods=1), + x3=nw.col("a").rolling_mean(window_size=2, min_periods=1), + x4=nw.col("a").rolling_mean(window_size=2, min_periods=1, center=True), + ) + expected = { + "x1": [float("nan")] * 6 + [7], + "x2": [float("nan"), 1.0, 1.5, 1.5, 3.0, 5.0, 7.0], + "x3": [float("nan"), 1.0, 1.5, 2.0, 4.0, 5.0, 8.5], + "x4": [float("nan"), 1.0, 1.5, 2.0, 4.0, 5.0, 8.5], + } + assert_equal_data(result, expected) + + +def test_rolling_mean_series_no_weights(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager(data), eager_only=True) + + result = df.select( + x1=df["a"].rolling_mean(window_size=3), + x2=df["a"].rolling_mean(window_size=3, min_periods=1), + x3=df["a"].rolling_mean(window_size=2, min_periods=1), + x4=df["a"].rolling_mean(window_size=2, min_periods=1, center=True), + ) + expected = { + "x1": [float("nan")] * 6 + [7], + "x2": [float("nan"), 1.0, 1.5, 1.5, 3.0, 5.0, 7.0], + "x3": [float("nan"), 1.0, 1.5, 2.0, 4.0, 5.0, 8.5], + "x4": [float("nan"), 1.0, 1.5, 2.0, 4.0, 5.0, 8.5], + } + assert_equal_data(result, expected) + + +def test_weighted_rolling_mean_expr(constructor: Constructor) -> None: + context = ( + pytest.raises( + NotImplementedError, + match=re.escape("`weights` argument is not supported in `rolling_mean`"), + ) + if "pandas" in str(constructor) or "dask" in str(constructor) + else does_not_raise() + ) + df = nw.from_native(constructor(data_weighted)) + + with context: + result = df.select( + x=nw.col("a").rolling_mean(window_size=2, weights=[0.25, 0.75]), + ) + expected = {"x": [float("nan"), 1.75, 2.75, 3.75, 4.75, 5.75]} + assert_equal_data(result, expected) + + +def test_weighted_rolling_mean_series(constructor_eager: ConstructorEager) -> None: + context = ( + pytest.raises( + NotImplementedError, + match=re.escape("`weights` argument is not supported in `rolling_mean`"), + ) + if "pandas" in str(constructor_eager) + else does_not_raise() + ) + df = nw.from_native(constructor_eager(data_weighted), eager_only=True) + + with context: + result = df.select( + x=df["a"].rolling_mean(window_size=2, weights=[0.25, 0.75]), + ) + expected = {"x": [float("nan"), 1.75, 2.75, 3.75, 4.75, 5.75]} + assert_equal_data(result, expected)