Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Nov 8, 2024
1 parent ebf4321 commit e60214d
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 75 deletions.
25 changes: 15 additions & 10 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,21 +727,26 @@ def rank(
*,
descending: bool,
) -> Self:
if method == "average":
msg = (
"`rank` with `method='average' is not supported for pyarrow backend. "
"The available methods are {'min', 'max', 'dense', 'ordinal'}."
)
raise ValueError(msg)

import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

if method != "average":
sort_keys = "descending" if descending else "ascending"
tiebreaker = "first" if method == "ordinal" else method
native_series = self._native_series
null_mask = pc.is_null(native_series)
sort_keys = "descending" if descending else "ascending"
tiebreaker = "first" if method == "ordinal" else method

rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker)
native_series = self._native_series
null_mask = pc.is_null(native_series)

result = pc.if_else(null_mask, pa.scalar(None), rank)
return self._from_native_series(result)
else:
pass
rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker)

result = pc.if_else(null_mask, pa.scalar(None), rank)
return self._from_native_series(result)

def __iter__(self: Self) -> Iterator[Any]:
    """Yield the elements of the wrapped native series one at a time."""
    # Delegate straight to the native series' own iterator.
    yield from iter(self._native_series)
Expand Down
98 changes: 47 additions & 51 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,9 @@ def rank(
"""
Assign ranks to data, dealing with ties appropriately.
Notes:
The resulting dtype may differ between backends.
Arguments:
method: The method used to assign ranks to tied elements.
The following methods are available (default is 'average'):
Expand All @@ -2338,61 +2341,54 @@ def rank(
descending: Rank in descending order.
Examples
--------
The 'average' method:
Examples:
>>> import narwhals as nw
>>> import pandas as pd
>>> import polars as pl
>>> import pyarrow as pa
>>> data = {"a": [3, 6, 1, 1, 6]}
>>> df = pl.DataFrame({"a": [3, 6, 1, 1, 6]})
>>> df.select(pl.col("a").rank())
shape: (5, 1)
┌─────┐
│ a │
│ --- │
│ f64 │
╞═════╡
│ 3.0 │
│ 4.5 │
│ 1.5 │
│ 1.5 │
│ 4.5 │
└─────┘
We define a dataframe-agnostic function that computes the dense rank for
the data:
The 'ordinal' method:
>>> @nw.narwhalify
... def func(df):
... return df.with_columns(rnk=nw.col("a").rank(method="dense"))
>>> df = pl.DataFrame({"a": [3, 6, 1, 1, 6]})
>>> df.select(pl.col("a").rank("ordinal"))
shape: (5, 1)
┌─────┐
│ a │
│ --- │
│ u32 │
╞═════╡
│ 3 │
│ 4 │
│ 1 │
│ 2 │
│ 5 │
└─────┘
We can then pass any supported library such as pandas, Polars, or PyArrow:
>>> func(pl.DataFrame(data))
shape: (5, 2)
┌─────┬─────┐
│ a ┆ rnk │
│ --- ┆ --- │
│ i64 ┆ u32 │
╞═════╪═════╡
│ 3 ┆ 2 │
│ 6 ┆ 3 │
│ 1 ┆ 1 │
│ 1 ┆ 1 │
│ 6 ┆ 3 │
└─────┴─────┘
>>> func(pd.DataFrame(data))
a rnk
0 3 2.0
1 6 3.0
2 1 1.0
3 1 1.0
4 6 3.0
>>> func(pa.table(data))
pyarrow.Table
a: int64
rnk: uint64
----
a: [[3,6,1,1,6]]
rnk: [[2,3,1,1,3]]
"""

Use 'rank' with 'over' to rank within groups:
>>> df = pl.DataFrame({"a": [1, 1, 2, 2, 2], "b": [6, 7, 5, 14, 11]})
>>> df.with_columns(pl.col("b").rank().over("a").alias("rank"))
shape: (5, 3)
┌─────┬─────┬──────┐
│ a ┆ b ┆ rank │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f64 │
╞═════╪═════╪══════╡
│ 1 ┆ 6 ┆ 1.0 │
│ 1 ┆ 7 ┆ 2.0 │
│ 2 ┆ 5 ┆ 1.0 │
│ 2 ┆ 14 ┆ 3.0 │
│ 2 ┆ 11 ┆ 2.0 │
└─────┴─────┴──────┘
"""

supported_rank_methods = {"average", "min", "max", "dense"}
supported_rank_methods = {"average", "min", "max", "dense", "ordinal"}
if method not in supported_rank_methods:
msg = f"Ranking method must be one of {supported_rank_methods}. Found '{method}'"
raise ValueError(msg)
Expand Down
64 changes: 50 additions & 14 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2534,6 +2534,9 @@ def rank(
"""
Assign ranks to data, dealing with ties appropriately.
Notes:
The resulting dtype may differ between backends.
Arguments:
method: The method used to assign ranks to tied elements.
The following methods are available (default is 'average'):
Expand All @@ -2554,20 +2557,53 @@ def rank(
descending: Rank in descending order.
Examples:
>>> import narwhals as nw
>>> import pandas as pd
>>> import polars as pl
>>> import pyarrow as pa
>>> data = [3, 6, 1, 1, 6]
We define a dataframe-agnostic function that computes the dense rank for
the data:
>>> @nw.narwhalify
... def func(s):
... return s.rank(method="dense")
>>> s = pl.Series("a", [3, 6, 1, 1, 6])
>>> s.rank()
shape: (5,)
Series: 'a' [f64]
[
3.0
4.5
1.5
1.5
4.5
]
"""
supported_rank_methods = {"average", "min", "max", "dense"}
We can then pass any supported library such as pandas, Polars, or PyArrow:
>>> func(pl.Series(data)) # doctest:+NORMALIZE_WHITESPACE
shape: (5,)
Series: '' [u32]
[
2
3
1
1
3
]
>>> func(pd.Series(data))
0 2.0
1 3.0
2 1.0
3 1.0
4 3.0
dtype: float64
>>> func(pa.chunked_array([data])) # doctest:+ELLIPSIS
<pyarrow.lib.ChunkedArray object at ...>
[
[
2,
3,
1,
1,
3
]
]
"""
supported_rank_methods = {"average", "min", "max", "dense", "ordinal"}
if method not in supported_rank_methods:
msg = f"Ranking method must be one of {supported_rank_methods}. Found '{method}'"
raise ValueError(msg)
Expand Down Expand Up @@ -3220,7 +3256,7 @@ def to_datetime(self: Self, format: str | None = None) -> T: # noqa: A002
... def func(s):
... return s.str.to_datetime(format="%Y-%m-%d")
We can then pass any supported library such as pandas, Polars, or PyArrow::
We can then pass any supported library such as pandas, Polars, or PyArrow:
>>> func(s_pd)
0 2020-01-01
Expand Down
98 changes: 98 additions & 0 deletions tests/expr_and_series/rank_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import Literal

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

# Every ranking method exercised by the parametrized tests in this module.
rank_methods = ["average", "min", "max", "dense", "ordinal"]

# Input frame: column "a" is the one being ranked (it contains ties and one
# null at index 4); column "b" is the grouping key used by the `.over("b")` test.
data = {"a": [3, 6, 1, 1, None, 6], "b": [1, 1, 2, 1, 2, 2]}

# Expected ranks of column "a" over the whole column, keyed by method.
# The null input at index 4 is written as NaN in the expected output
# (presumably `assert_equal_data` treats NaN as "missing" -- confirm in tests.utils).
expected = {
    "average": [3.0, 4.5, 1.5, 1.5, float("nan"), 4.5],
    "min": [3, 4, 1, 1, float("nan"), 4],
    "max": [3, 5, 2, 2, float("nan"), 5],
    "dense": [2, 3, 1, 1, float("nan"), 3],
    "ordinal": [3, 4, 1, 2, float("nan"), 5],
}

# Expected ranks of column "a" computed within each group of column "b".
# Group b == 1 holds the values 3, 6, 1 and group b == 2 holds 1, None, 6, so
# there are no ties inside a group and every method yields the same ranks.
expected_over = {
    "average": [2.0, 3.0, 1.0, 1.0, float("nan"), 2.0],
    "min": [2, 3, 1, 1, float("nan"), 2],
    "max": [2, 3, 1, 1, float("nan"), 2],
    "dense": [2, 3, 1, 1, float("nan"), 2],
    "ordinal": [2, 3, 1, 1, float("nan"), 2],
}


@pytest.mark.parametrize("method", rank_methods)
def test_rank_expr(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    method: Literal["average", "min", "max", "dense", "ordinal"],
) -> None:
    """Check `nw.col("a").rank(method=...)` against `expected` for each backend.

    Dask constructors are marked as expected failures. For a pyarrow table
    with `method="average"` the call is expected to raise a `ValueError`
    instead of producing a result.
    """
    backend = str(constructor)
    if "dask" in backend:
        request.applymarker(pytest.mark.xfail)

    if "pyarrow_table" in backend and method == "average":
        # The pyarrow backend rejects the 'average' method explicitly.
        outcome = pytest.raises(
            ValueError,
            match=r"`rank` with `method='average' is not supported for pyarrow backend.",
        )
    else:
        outcome = does_not_raise()

    with outcome:
        frame = nw.from_native(constructor(data))
        ranked = frame.select(nw.col("a").rank(method=method))
        assert_equal_data(ranked, {"a": expected[method]})


@pytest.mark.parametrize("method", rank_methods)
def test_rank_series(
    constructor_eager: ConstructorEager,
    method: Literal["average", "min", "max", "dense", "ordinal"],
) -> None:
    """Check `Series.rank(method=...)` against `expected` for each eager backend.

    For a pyarrow table with `method="average"` the call is expected to raise
    a `ValueError` instead of producing a result.
    """
    if "pyarrow_table" in str(constructor_eager) and method == "average":
        # The pyarrow backend rejects the 'average' method explicitly.
        outcome = pytest.raises(
            ValueError,
            match=r"`rank` with `method='average' is not supported for pyarrow backend.",
        )
    else:
        outcome = does_not_raise()

    with outcome:
        frame = nw.from_native(constructor_eager(data), eager_only=True)
        ranked = {"a": frame["a"].rank(method=method)}
        assert_equal_data(ranked, {"a": expected[method]})


@pytest.mark.parametrize("method", rank_methods)
def test_rank_expr_in_over_context(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    method: Literal["average", "min", "max", "dense", "ordinal"],
) -> None:
    """Check `rank(...).over("b")` on column "a" against `expected_over`.

    pyarrow table and Dask constructors are marked as expected failures.
    """
    backend = str(constructor)
    if any(name in backend for name in ("pyarrow_table", "dask")):
        # Pyarrow raises:
        # > pyarrow.lib.ArrowKeyError: No function registered with name: hash_rank
        # We can handle that to provide a better error message.
        request.applymarker(pytest.mark.xfail)

    frame = nw.from_native(constructor(data))
    ranked = frame.select(nw.col("a").rank(method=method).over("b"))
    assert_equal_data(ranked, {"a": expected_over[method]})

0 comments on commit e60214d

Please sign in to comment.