diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 48b90c118..3278920a5 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -727,21 +727,26 @@ def rank( *, descending: bool, ) -> Self: + if method == "average": + msg = ( + "`rank` with `method='average' is not supported for pyarrow backend. " + "The available methods are {'min', 'max', 'dense', 'ordinal'}." + ) + raise ValueError(msg) + import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import - if method != "average": - sort_keys = "descending" if descending else "ascending" - tiebreaker = "first" if method == "ordinal" else method - native_series = self._native_series - null_mask = pc.is_null(native_series) + sort_keys = "descending" if descending else "ascending" + tiebreaker = "first" if method == "ordinal" else method - rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker) + native_series = self._native_series + null_mask = pc.is_null(native_series) - result = pc.if_else(null_mask, pa.scalar(None), rank) - return self._from_native_series(result) - else: - pass + rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker) + + result = pc.if_else(null_mask, pa.scalar(None), rank) + return self._from_native_series(result) def __iter__(self: Self) -> Iterator[Any]: yield from self._native_series.__iter__() diff --git a/narwhals/expr.py b/narwhals/expr.py index 07e2fc92a..c0c83714f 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -2319,6 +2319,9 @@ def rank( """ Assign ranks to data, dealing with ties appropriately. + Notes: + The resulting dtype may differ between backends. + Arguments: method: The method used to assign ranks to tied elements. The following methods are available (default is 'average'): @@ -2338,61 +2341,54 @@ def rank( descending: Rank in descending order. - Examples - -------- - The 'average' method: + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = {"a": [3, 6, 1, 1, 6]} - >>> df = pl.DataFrame({"a": [3, 6, 1, 1, 6]}) - >>> df.select(pl.col("a").rank()) - shape: (5, 1) - ┌─────┐ - │ a │ - │ --- │ - │ f64 │ - ╞═════╡ - │ 3.0 │ - │ 4.5 │ - │ 1.5 │ - │ 1.5 │ - │ 4.5 │ - └─────┘ + We define a dataframe-agnostic function that computes the dense rank for + the data: - The 'ordinal' method: + >>> @nw.narwhalify + ... def func(df): + ... return df.with_columns(rnk=nw.col("a").rank(method="dense")) - >>> df = pl.DataFrame({"a": [3, 6, 1, 1, 6]}) - >>> df.select(pl.col("a").rank("ordinal")) - shape: (5, 1) - ┌─────┐ - │ a │ - │ --- │ - │ u32 │ - ╞═════╡ - │ 3 │ - │ 4 │ - │ 1 │ - │ 2 │ - │ 5 │ - └─────┘ + We can then pass any supported library such as pandas, Polars, or PyArrow: + + >>> func(pl.DataFrame(data)) + shape: (5, 2) + ┌─────┬─────┐ + │ a ┆ rnk │ + │ --- ┆ --- │ + │ i64 ┆ u32 │ + ╞═════╪═════╡ + │ 3 ┆ 2 │ + │ 6 ┆ 3 │ + │ 1 ┆ 1 │ + │ 1 ┆ 1 │ + │ 6 ┆ 3 │ + └─────┴─────┘ + + >>> func(pd.DataFrame(data)) + a rnk + 0 3 2.0 + 1 6 3.0 + 2 1 1.0 + 3 1 1.0 + 4 6 3.0 + + >>> func(pa.table(data)) + pyarrow.Table + a: int64 + rnk: uint64 + ---- + a: [[3,6,1,1,6]] + rnk: [[2,3,1,1,3]] + """ - Use 'rank' with 'over' to rank within groups: - - >>> df = pl.DataFrame({"a": [1, 1, 2, 2, 2], "b": [6, 7, 5, 14, 11]}) - >>> df.with_columns(pl.col("b").rank().over("a").alias("rank")) - shape: (5, 3) - ┌─────┬─────┬──────┐ - │ a ┆ b ┆ rank │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ f64 │ - ╞═════╪═════╪══════╡ - │ 1 ┆ 6 ┆ 1.0 │ - │ 1 ┆ 7 ┆ 2.0 │ - │ 2 ┆ 5 ┆ 1.0 │ - │ 2 ┆ 14 ┆ 3.0 │ - │ 2 ┆ 11 ┆ 2.0 │ - └─────┴─────┴──────┘ - """ - - supported_rank_methods = {"average", "min", "max", "dense"} + supported_rank_methods = {"average", "min", "max", "dense", "ordinal"} if method not in supported_rank_methods: msg = f"Ranking method must be one of {supported_rank_methods}. Found '{method}'" raise ValueError(msg) diff --git a/narwhals/series.py b/narwhals/series.py index 868a9eb5c..7aa921eec 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2534,6 +2534,9 @@ def rank( """ Assign ranks to data, dealing with ties appropriately. + Notes: + The resulting dtype may differ between backends. + Arguments: method: The method used to assign ranks to tied elements. The following methods are available (default is 'average'): @@ -2554,20 +2557,53 @@ def rank( descending: Rank in descending order. Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> data = [3, 6, 1, 1, 6] + + We define a dataframe-agnostic function that computes the dense rank for + the data: + + >>> @nw.narwhalify + ... def func(s): + ... return s.rank(method="dense") - >>> s = pl.Series("a", [3, 6, 1, 1, 6]) - >>> s.rank() - shape: (5,) - Series: 'a' [f64] - [ - 3.0 - 4.5 - 1.5 - 1.5 - 4.5 - ] - """ - supported_rank_methods = {"average", "min", "max", "dense"} + We can then pass any supported library such as pandas, Polars, or PyArrow: + + >>> func(pl.Series(data)) # doctest:+NORMALIZE_WHITESPACE + shape: (5,) + Series: '' [u32] + [ + 2 + 3 + 1 + 1 + 3 + ] + + >>> func(pd.Series(data)) + 0 2.0 + 1 3.0 + 2 1.0 + 3 1.0 + 4 3.0 + dtype: float64 + + >>> func(pa.chunked_array([data])) # doctest:+ELLIPSIS + + [ + [ + 2, + 3, + 1, + 1, + 3 + ] + ] + """ + supported_rank_methods = {"average", "min", "max", "dense", "ordinal"} if method not in supported_rank_methods: msg = f"Ranking method must be one of {supported_rank_methods}. Found '{method}'" raise ValueError(msg) @@ -3220,7 +3256,7 @@ def to_datetime(self: Self, format: str | None = None) -> T: # noqa: A002 ... def func(s): ... return s.str.to_datetime(format="%Y-%m-%d") - We can then pass any supported library such as pandas, Polars, or PyArrow:: + We can then pass any supported library such as pandas, Polars, or PyArrow: >>> func(s_pd) 0 2020-01-01 diff --git a/tests/expr_and_series/rank_test.py b/tests/expr_and_series/rank_test.py new file mode 100644 index 000000000..090605bf4 --- /dev/null +++ b/tests/expr_and_series/rank_test.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from contextlib import nullcontext as does_not_raise +from typing import Literal + +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import Constructor +from tests.utils import ConstructorEager +from tests.utils import assert_equal_data + +rank_methods = ["average", "min", "max", "dense", "ordinal"] + +data = {"a": [3, 6, 1, 1, None, 6], "b": [1, 1, 2, 1, 2, 2]} + +expected = { + "average": [3.0, 4.5, 1.5, 1.5, float("nan"), 4.5], + "min": [3, 4, 1, 1, float("nan"), 4], + "max": [3, 5, 2, 2, float("nan"), 5], + "dense": [2, 3, 1, 1, float("nan"), 3], + "ordinal": [3, 4, 1, 2, float("nan"), 5], +} + +expected_over = { + "average": [2.0, 3.0, 1.0, 1.0, float("nan"), 2.0], + "min": [2, 3, 1, 1, float("nan"), 2], + "max": [2, 3, 1, 1, float("nan"), 2], + "dense": [2, 3, 1, 1, float("nan"), 2], + "ordinal": [2, 3, 1, 1, float("nan"), 2], +} + + +@pytest.mark.parametrize("method", rank_methods) +def test_rank_expr( + request: pytest.FixtureRequest, + constructor: Constructor, + method: Literal["average", "min", "max", "dense", "ordinal"], +) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + + context = ( + pytest.raises( + ValueError, + match=r"`rank` with `method='average' is not supported for pyarrow backend.", + ) + if "pyarrow_table" in str(constructor) and method == "average" + else does_not_raise() + ) + + with context: + df = nw.from_native(constructor(data)) + + result = df.select(nw.col("a").rank(method=method)) + expected_data = {"a": expected[method]} + assert_equal_data(result, expected_data) + + +@pytest.mark.parametrize("method", rank_methods) +def test_rank_series( + constructor_eager: ConstructorEager, + method: Literal["average", "min", "max", "dense", "ordinal"], +) -> None: + context = ( + pytest.raises( + ValueError, + match=r"`rank` with `method='average' is not supported for pyarrow backend.", + ) + if "pyarrow_table" in str(constructor_eager) and method == "average" + else does_not_raise() + ) + + with context: + df = nw.from_native(constructor_eager(data), eager_only=True) + + result = {"a": df["a"].rank(method=method)} + expected_data = {"a": expected[method]} + assert_equal_data(result, expected_data) + + +@pytest.mark.parametrize("method", rank_methods) +def test_rank_expr_in_over_context( + request: pytest.FixtureRequest, + constructor: Constructor, + method: Literal["average", "min", "max", "dense", "ordinal"], +) -> None: + if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + # Pyarrow raises: + # > pyarrow.lib.ArrowKeyError: No function registered with name: hash_rank + # We can handle that to provide a better error message. + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + + result = df.select(nw.col("a").rank(method=method).over("b")) + expected_data = {"a": expected_over[method]} + assert_equal_data(result, expected_data)