From 2dd44802887b253d20332b1a3cece34658ad6342 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sun, 29 Dec 2024 14:14:32 +0000 Subject: [PATCH] feat: expressify `lower_bound` and `upper_bound` in `is_between` (#1672) --- docs/how_it_works.md | 1 - narwhals/_arrow/namespace.py | 25 --------------- narwhals/_arrow/series.py | 6 ++++ narwhals/_dask/namespace.py | 25 --------------- narwhals/_pandas_like/namespace.py | 41 ------------------------ narwhals/_pandas_like/series.py | 11 ++++++- narwhals/_polars/namespace.py | 22 ------------- narwhals/expr.py | 21 +++++++----- narwhals/series.py | 8 +++-- narwhals/stable/v1/__init__.py | 2 +- tests/expr_and_series/is_between_test.py | 22 ++++++++++--- 11 files changed, 54 insertions(+), 130 deletions(-) diff --git a/docs/how_it_works.md b/docs/how_it_works.md index 70bc54bfe..6a6703581 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -266,7 +266,6 @@ In order to tell whether an aggregation is simple, Narwhals uses the private `_d ```python exec="1" result="python" session="pandas_impl" source="above" print(pn.col("a").mean()) print((pn.col("a") + 1).mean()) -print(pn.mean("a")) ``` For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 884bc8f08..ea37ae762 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -348,31 +348,6 @@ def concat( result_table, backend_version=self._backend_version, version=self._version ) - def sum(self: Self, *column_names: str) -> ArrowExpr: - return ArrowExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).sum() - - def mean(self: Self, *column_names: str) -> ArrowExpr: - return ArrowExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).mean() - - def median(self: Self, *column_names: str) -> ArrowExpr: - return ArrowExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).median() - - def max(self: Self, *column_names: str) -> ArrowExpr: - return ArrowExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).max() - - def min(self: Self, *column_names: str) -> ArrowExpr: - return ArrowExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).min() - @property def selectors(self: Self) -> ArrowSelectorNamespace: return ArrowSelectorNamespace( diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 2f29ab9db..451247448 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -487,6 +487,12 @@ def is_between( import pyarrow.compute as pc ser = self._native_series + _, lower_bound = broadcast_and_extract_native( + self, lower_bound, self._backend_version + ) + _, upper_bound = broadcast_and_extract_native( + self, upper_bound, self._backend_version + ) if closed == "left": ge = pc.greater_equal(ser, lower_bound) lt = pc.less(ser, upper_bound) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 38cd16a87..126afaae6 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -91,31 +91,6 @@ def convert_if_dtype( kwargs={}, ) - def min(self, *column_names: str) -> DaskExpr: - return DaskExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).min() - - def max(self, *column_names: str) -> DaskExpr: - return DaskExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).max() - - def mean(self, *column_names: str) -> DaskExpr: - return DaskExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).mean() - - def median(self, *column_names: str) -> DaskExpr: - return DaskExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).median() - - def sum(self, *column_names: str) -> DaskExpr: - return DaskExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version - ).sum() - def len(self) -> DaskExpr: import dask.dataframe as dd import pandas as pd diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 0b060708b..7885d7de0 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -169,47 +169,6 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: kwargs={}, ) - # --- reduction --- - def sum(self, *column_names: str) -> PandasLikeExpr: - return PandasLikeExpr.from_column_names( - *column_names, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).sum() - - def mean(self, *column_names: str) -> PandasLikeExpr: - return PandasLikeExpr.from_column_names( - *column_names, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).mean() - - def median(self, *column_names: str) -> PandasLikeExpr: - return PandasLikeExpr.from_column_names( - *column_names, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).median() - - def max(self, *column_names: str) -> PandasLikeExpr: - return PandasLikeExpr.from_column_names( - *column_names, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).max() - - def min(self, *column_names: str) -> PandasLikeExpr: - return PandasLikeExpr.from_column_names( - *column_names, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).min() - def len(self) -> PandasLikeExpr: return PandasLikeExpr( lambda df: [ diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 1d895e147..c9f3d006d 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -263,6 +263,8 @@ def is_between( self, lower_bound: Any, upper_bound: Any, closed: str = "both" ) -> PandasLikeSeries: ser = self._native_series + _, lower_bound = broadcast_align_and_extract_native(self, lower_bound) + _, upper_bound = broadcast_align_and_extract_native(self, upper_bound) if closed == "left": res = ser.ge(lower_bound) & ser.lt(upper_bound) elif closed == "right": @@ -273,7 +275,14 @@ def is_between( res = ser.ge(lower_bound) & ser.le(upper_bound) else: # pragma: no cover raise AssertionError - return self._from_native_series(res) + return self._from_native_series( + rename( + res, + ser.name, + implementation=self._implementation, + backend_version=self._backend_version, + ) + ) def is_in(self, other: Any) -> PandasLikeSeries: ser = self._native_series diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 3e1ea1761..00e005c33 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -128,17 +128,6 @@ def lit(self: Self, value: Any, dtype: DType | None = None) -> PolarsExpr: pl.lit(value), version=self._version, backend_version=self._backend_version ) - def mean(self: Self, *column_names: str) -> PolarsExpr: - import polars as pl - - from narwhals._polars.expr import PolarsExpr - - return PolarsExpr( - pl.mean([*column_names]), # type: ignore[arg-type] - version=self._version, - backend_version=self._backend_version, - ) - def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr: import polars as pl @@ -160,17 +149,6 @@ def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr: backend_version=self._backend_version, ) - def median(self: Self, *column_names: str) -> PolarsExpr: - import polars as pl - - from narwhals._polars.expr import PolarsExpr - - return PolarsExpr( - pl.median([*column_names]), # type: ignore[arg-type] - version=self._version, - backend_version=self._backend_version, - ) - def concat_str( self, exprs: Iterable[IntoPolarsExpr], diff --git a/narwhals/expr.py b/narwhals/expr.py index 777cae615..ba0896db2 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1666,7 +1666,10 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: # --- transform --- def is_between( - self, lower_bound: Any, upper_bound: Any, closed: str = "both" + self, + lower_bound: Any | IntoExpr, + upper_bound: Any | IntoExpr, + closed: str = "both", ) -> Self: """Check if this expression is between the given lower and upper bounds. @@ -1724,7 +1727,9 @@ def is_between( """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_between( - lower_bound, upper_bound, closed + extract_compliant(plx, lower_bound), + extract_compliant(plx, upper_bound), + closed, ) ) @@ -6049,7 +6054,7 @@ def col(*names: str | Iterable[str]) -> Expr: """Creates an expression that references one or more columns by their name(s). Arguments: - names: Name(s) of the columns to use in the aggregation function. + names: Name(s) of the columns to use. Returns: A new expression. @@ -6308,7 +6313,7 @@ def sum(*columns: str) -> Expr: ---- a: [[3]] """ - return Expr(lambda plx: plx.sum(*columns)) + return Expr(lambda plx: plx.col(*columns).sum()) def mean(*columns: str) -> Expr: @@ -6359,7 +6364,7 @@ def mean(*columns: str) -> Expr: ---- a: [[4]] """ - return Expr(lambda plx: plx.mean(*columns)) + return Expr(lambda plx: plx.col(*columns).mean()) def median(*columns: str) -> Expr: @@ -6411,7 +6416,7 @@ def median(*columns: str) -> Expr: ---- a: [[4]] """ - return Expr(lambda plx: plx.median(*columns)) + return Expr(lambda plx: plx.col(*columns).median()) def min(*columns: str) -> Expr: @@ -6462,7 +6467,7 @@ def min(*columns: str) -> Expr: ---- b: [[5]] """ - return Expr(lambda plx: plx.min(*columns)) + return Expr(lambda plx: plx.col(*columns).min()) def max(*columns: str) -> Expr: @@ -6513,7 +6518,7 @@ def max(*columns: str) -> Expr: ---- a: [[2]] """ - return Expr(lambda plx: plx.max(*columns)) + return Expr(lambda plx: plx.col(*columns).max()) def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: diff --git a/narwhals/series.py b/narwhals/series.py index 1868eb396..9f5728390 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2138,7 +2138,7 @@ def fill_null( ) def is_between( - self, lower_bound: Any, upper_bound: Any, closed: str = "both" + self, lower_bound: Any | Self, upper_bound: Any | Self, closed: str = "both" ) -> Self: """Get a boolean mask of the values that are between the given lower/upper bounds. @@ -2189,7 +2189,11 @@ def is_between( ] """ return self._from_compliant_series( - self._compliant_series.is_between(lower_bound, upper_bound, closed=closed) + self._compliant_series.is_between( + self._extract_native(lower_bound), + self._extract_native(upper_bound), + closed=closed, + ) ) def n_unique(self) -> int: diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index b234ad9b2..2b5be2eee 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -2321,7 +2321,7 @@ def col(*names: str | Iterable[str]) -> Expr: """Creates an expression that references one or more columns by their name(s). Arguments: - names: Name(s) of the columns to use in the aggregation function. + names: Name(s) of the columns to use. Returns: A new expression. diff --git a/tests/expr_and_series/is_between_test.py b/tests/expr_and_series/is_between_test.py index 8d08c6fac..57ad545c0 100644 --- a/tests/expr_and_series/is_between_test.py +++ b/tests/expr_and_series/is_between_test.py @@ -7,10 +7,6 @@ from tests.utils import ConstructorEager from tests.utils import assert_equal_data -data = { - "a": [1, 4, 2, 5], -} - @pytest.mark.parametrize( ("closed", "expected"), @@ -22,12 +18,21 @@ ], ) def test_is_between(constructor: Constructor, closed: str, expected: list[bool]) -> None: + data = {"a": [1, 4, 2, 5]} df = nw.from_native(constructor(data)) result = df.select(nw.col("a").is_between(1, 5, closed=closed)) expected_dict = {"a": expected} assert_equal_data(result, expected_dict) +def test_is_between_expressified(constructor: Constructor) -> None: + data = {"a": [1, 4, 2, 5], "b": [0, 5, 2, 4], "c": [9, 9, 9, 9]} + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a").is_between(nw.col("b") * 0.9, nw.col("c") - 1)) + expected_dict = {"a": [True, False, True, True]} + assert_equal_data(result, expected_dict) + + @pytest.mark.parametrize( ("closed", "expected"), [ @@ -40,7 +45,16 @@ def test_is_between(constructor: Constructor, closed: str, expected: list[bool]) def test_is_between_series( constructor_eager: ConstructorEager, closed: str, expected: list[bool] ) -> None: + data = {"a": [1, 4, 2, 5]} df = nw.from_native(constructor_eager(data), eager_only=True) result = df.with_columns(a=df["a"].is_between(1, 5, closed=closed)) expected_dict = {"a": expected} assert_equal_data(result, expected_dict) + + +def test_is_between_expressified_series(constructor_eager: ConstructorEager) -> None: + data = {"a": [1, 4, 2, 5], "b": [0, 5, 2, 4], "c": [9, 9, 9, 9]} + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df["a"].is_between(df["b"], df["c"]).to_frame() + expected_dict = {"a": [True, False, True, True]} + assert_equal_data(result, expected_dict)