Skip to content

Commit

Permalink
add test, fix arrow
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Oct 30, 2024
1 parent 59edc61 commit 0e9dced
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 18 deletions.
18 changes: 11 additions & 7 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def clip(
def to_arrow(self: Self) -> pa.Array:
    """Return the native series with its chunks combined into one array."""
    native = self._native_series
    return native.combine_chunks()

def mode(self: Self) -> Self:
def mode(self: Self) -> ArrowSeries:
plx = self.__narwhals_namespace__()
col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
return self.value_counts(name=col_token, normalize=False).filter(
Expand All @@ -730,22 +730,26 @@ def rolling_mean(
min_periods: int | None,
center: bool,
) -> Self:
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

native_series = self._native_series

def weighted_mean(arr: pa.array, weights: pa.array) -> pa.scalar:
return pc.divide(pc.sum(pc.multiply(arr, weights)), pc.sum(weights))

result = pa.chunked_array(
[
[
pc.mean(v) if v is not None else None
for v in _rolling(
list(
_rolling(
native_series,
window_size=window_size,
weights=weights,
min_periods=min_periods,
center=center,
aggregate_function=weighted_mean,
)
]
)
]
)
return self._from_native_series(result)
Expand Down
23 changes: 13 additions & 10 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Generator
from typing import Sequence

Expand Down Expand Up @@ -430,36 +431,38 @@ def _rolling(
*,
min_periods: int | None,
center: bool,
aggregate_function: Callable[[pa.array, pa.array], pa.scalar],
) -> Generator[pa.array | None, None, None]:
import pyarrow as pa
import pyarrow.compute as pc
import numpy as np # ignore-banned-import
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

# Default min_periods to window_size if not provided
if min_periods is None:
min_periods = window_size

# Convert weights to a pyarrow array for elementwise operations if given
weights = pa.array(weights) if weights else pa.scalar(1)
weights_: pa.array = (
pa.array(weights) if weights else pa.array(np.full(window_size, 1.0))
)

# Flatten the chunked array to work with it as a contiguous array
flat_array = array.combine_chunks()
size = len(flat_array)
# Calculate rolling mean by slicing the flat array for each position
split_points = (
(max(0, i - window_size // 2), min(size, i + window_size // 2 + 1))
(max(0, i - window_size // 2), min(size, i + window_size // 2))
if center
else (max(0, i - window_size + 1), i + 1)
for i in range(size)
)

for start, end in split_points:
weighted_window = pc.drop_null(
pc.multiply(flat_array.slice(start, end - start), weights)
)

num_valid = len(weighted_window)
valid_window = pc.drop_null(flat_array.slice(start, end - start))
num_valid = len(valid_window)

if num_valid >= min_periods:
yield weighted_window
valid_weights = weights_.slice(0, num_valid)
yield aggregate_function(valid_window, valid_weights)
else:
yield None
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def rolling_mean(
) -> Self:
if weights is not None:
msg = (
"`weights` argument is not supported in `rolling_meanr` for "
"`weights` argument is not supported in `rolling_mean` for "
f"{self._implementation} backend."
)
raise NotImplementedError(msg)
Expand Down
95 changes: 95 additions & 0 deletions tests/expr_and_series/rolling_mean_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

import re
from contextlib import nullcontext as does_not_raise

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

# Column "a" with missing values at positions 0 and 3; used by the unweighted
# rolling-mean tests to exercise the `min_periods` handling.
data = {"a": [None, 1, 2, None, 4, 6, 11]}
# Null-free float column used by the weighted rolling-mean tests.
data_weighted = {"a": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]}


def test_rolling_mean_expr_no_weights(
    request: pytest.FixtureRequest, constructor: Constructor
) -> None:
    """Check unweighted `rolling_mean` through the expression API.

    Four variants are selected from column "a" of `data`: window 3 with the
    default `min_periods`, window 3 with `min_periods=1`, window 2 with
    `min_periods=1`, and the same window-2 setup with `center=True`.
    """
    if "dask" in str(constructor):
        # TODO(FBruzzesi): Dask is raising the following error:
        # NotImplementedError: Partition size is less than overlapping window size.
        # Try using ``df.repartition`` to increase the partition size.
        request.applymarker(pytest.mark.xfail)

    frame = nw.from_native(constructor(data))

    rolling_exprs = {
        "x1": nw.col("a").rolling_mean(window_size=3),
        "x2": nw.col("a").rolling_mean(window_size=3, min_periods=1),
        "x3": nw.col("a").rolling_mean(window_size=2, min_periods=1),
        "x4": nw.col("a").rolling_mean(window_size=2, min_periods=1, center=True),
    }
    result = frame.select(**rolling_exprs)

    nan = float("nan")
    expected = {
        "x1": [nan] * 6 + [7],
        "x2": [nan, 1.0, 1.5, 1.5, 3.0, 5.0, 7.0],
        "x3": [nan, 1.0, 1.5, 2.0, 4.0, 5.0, 8.5],
        "x4": [nan, 1.0, 1.5, 2.0, 4.0, 5.0, 8.5],
    }
    assert_equal_data(result, expected)


def test_rolling_mean_series_no_weights(constructor_eager: ConstructorEager) -> None:
    """Check unweighted `rolling_mean` through the eager Series API.

    Runs four window / `min_periods` / `center` combinations on column "a"
    of `data` and compares each against its expected values.
    """
    frame = nw.from_native(constructor_eager(data), eager_only=True)

    rolled = {
        "x1": frame["a"].rolling_mean(window_size=3),
        "x2": frame["a"].rolling_mean(window_size=3, min_periods=1),
        "x3": frame["a"].rolling_mean(window_size=2, min_periods=1),
        "x4": frame["a"].rolling_mean(window_size=2, min_periods=1, center=True),
    }
    result = frame.select(**rolled)

    nan = float("nan")
    expected = {
        "x1": [nan] * 6 + [7],
        "x2": [nan, 1.0, 1.5, 1.5, 3.0, 5.0, 7.0],
        "x3": [nan, 1.0, 1.5, 2.0, 4.0, 5.0, 8.5],
        "x4": [nan, 1.0, 1.5, 2.0, 4.0, 5.0, 8.5],
    }
    assert_equal_data(result, expected)


def test_weighted_rolling_mean_expr(constructor: Constructor) -> None:
    """Check weighted `rolling_mean` through the expression API.

    Constructors whose name contains "pandas" or "dask" are expected to raise
    `NotImplementedError`; every other constructor must produce the weighted
    means of `data_weighted` for a window of 2 with weights [0.25, 0.75].
    """
    backend = str(constructor)
    if "pandas" in backend or "dask" in backend:
        context = pytest.raises(
            NotImplementedError,
            match=re.escape("`weights` argument is not supported in `rolling_mean`"),
        )
    else:
        context = does_not_raise()

    frame = nw.from_native(constructor(data_weighted))

    # The assertion stays inside the context manager: when the backend raises,
    # `result` is never bound and the comparison must be skipped.
    with context:
        weighted = nw.col("a").rolling_mean(window_size=2, weights=[0.25, 0.75])
        result = frame.select(x=weighted)
        expected = {"x": [float("nan"), 1.75, 2.75, 3.75, 4.75, 5.75]}
        assert_equal_data(result, expected)


def test_weighted_rolling_mean_series(constructor_eager: ConstructorEager) -> None:
    """Check weighted `rolling_mean` through the eager Series API.

    Constructors whose name contains "pandas" are expected to raise
    `NotImplementedError`; every other eager constructor must produce the
    weighted means of `data_weighted` for a window of 2 with weights
    [0.25, 0.75].
    """
    if "pandas" in str(constructor_eager):
        context = pytest.raises(
            NotImplementedError,
            match=re.escape("`weights` argument is not supported in `rolling_mean`"),
        )
    else:
        context = does_not_raise()

    frame = nw.from_native(constructor_eager(data_weighted), eager_only=True)

    # The assertion stays inside the context manager: when the backend raises,
    # `result` is never bound and the comparison must be skipped.
    with context:
        weighted = frame["a"].rolling_mean(window_size=2, weights=[0.25, 0.75])
        result = frame.select(x=weighted)
        expected = {"x": [float("nan"), 1.75, 2.75, 3.75, 4.75, 5.75]}
        assert_equal_data(result, expected)

0 comments on commit 0e9dced

Please sign in to comment.