Skip to content

Commit

Permalink
feat: add support for .shift(n).over('col') for pandas-like DataFrames (#1627)
Browse files Browse the repository at this point in the history

* First draft

* Fix failing test

* Improve test coverage

* Update narwhals/_pandas_like/expr.py

Accept comment

Co-authored-by: Francesco Bruzzesi <[email protected]>

---------

Co-authored-by: Francesco Bruzzesi <[email protected]>
  • Loading branch information
ClaudioSalvatoreArcidiacono and FBruzzesi authored Dec 26, 2024
1 parent ccf30e2 commit ecde246
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
15 changes: 11 additions & 4 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from narwhals.utils import Implementation
from narwhals.utils import Version

CUMULATIVE_FUNCTIONS_TO_PANDAS_EQUIVALENT = {
MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT = {
"col->cum_sum": "cumsum",
"col->cum_min": "cummin",
"col->cum_max": "cummax",
Expand All @@ -33,6 +33,7 @@
# Pandas cumcount counts nulls while Polars does not
# So, instead of using "cumcount" we use "cumsum" on notna() to get the same result
"col->cum_count": "cumsum",
"col->shift": "shift",
}


Expand Down Expand Up @@ -417,7 +418,7 @@ def alias(self, name: str) -> Self:
)

def over(self, keys: list[str]) -> Self:
if self._function_name in CUMULATIVE_FUNCTIONS_TO_PANDAS_EQUIVALENT:
if self._function_name in MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT:

def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
if (
Expand All @@ -443,12 +444,18 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
plx = self.__narwhals_namespace__()
df = df.with_columns(~plx.col(*self._root_names).is_null())

if self._function_name == "col->shift":
kwargs = {"periods": self._kwargs.get("n", 1)}
else:
# Cumulative operation
kwargs = {"skipna": True}

res_native = getattr(
df._native_frame.groupby(list(keys), as_index=False)[
self._root_names
],
CUMULATIVE_FUNCTIONS_TO_PANDAS_EQUIVALENT[self._function_name],
)(skipna=True)
MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT[self._function_name],
)(**kwargs)

result_frame = df._from_native_frame(
rename(
Expand Down
17 changes: 17 additions & 0 deletions tests/expr_and_series/over_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,23 @@ def test_over_anonymous() -> None:
nw.from_native(df).select(nw.all().cum_max().over("a"))


def test_over_shift(request: pytest.FixtureRequest, constructor: Constructor) -> None:
    """Check ``nw.col("b").shift(2).over("a")`` against a fixed expected frame.

    The expected ``b_shift`` column is null everywhere except the last row,
    which holds ``3``: per the expected data, group ``"a"`` has only two rows
    (so a shift of 2 leaves nothing) while group ``"b"`` has three rows, so
    its third row receives the group's first ``b`` value.

    Args:
        request: pytest request object, used to attach an ``xfail`` marker.
        constructor: fixture that builds a native frame from the module-level
            ``data`` (defined outside this function).
    """
    # These two constructors are marked as expected failures for this test
    # (presumably shift-over is not supported on those backends yet).
    xfail_constructors = ("pyarrow_table_constructor", "dask_lazy_p2_constructor")
    constructor_repr = str(constructor)
    if any(name in constructor_repr for name in xfail_constructors):
        request.applymarker(pytest.mark.xfail)

    expected = {
        "a": ["a", "a", "b", "b", "b"],
        "b": [1, 2, 3, 5, 3],
        "c": [5, 4, 3, 2, 1],
        "b_shift": [None, None, None, None, 3],
    }

    frame = nw.from_native(constructor(data))
    shifted_within_group = nw.col("b").shift(2).over("a")
    result = frame.with_columns(b_shift=shifted_within_group)
    assert_equal_data(result, expected)


def test_over_cum_reverse() -> None:
df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})

Expand Down

0 comments on commit ecde246

Please sign in to comment.