Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expressify lower_bound and upper_bound in is_between #1672

Merged
merged 6 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/how_it_works.md
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ In order to tell whether an aggregation is simple, Narwhals uses the private `_d
```python exec="1" result="python" session="pandas_impl" source="above"
print(pn.col("a").mean())
print((pn.col("a") + 1).mean())
print(pn.mean("a"))
```

For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out
Expand Down
25 changes: 0 additions & 25 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,31 +348,6 @@ def concat(
result_table, backend_version=self._backend_version, version=self._version
)

def sum(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).sum()

def mean(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).mean()

def median(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).median()

def max(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).max()

def min(self: Self, *column_names: str) -> ArrowExpr:
return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).min()

Comment on lines -351 to -375
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by - it's redundant to define all of these in the CompliantExprs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice one! I love to see net negative in PRs 👌

@property
def selectors(self: Self) -> ArrowSelectorNamespace:
return ArrowSelectorNamespace(
Expand Down
6 changes: 6 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ def is_between(
import pyarrow.compute as pc

ser = self._native_series
_, lower_bound = broadcast_and_extract_native(
self, lower_bound, self._backend_version
)
_, upper_bound = broadcast_and_extract_native(
self, upper_bound, self._backend_version
)
if closed == "left":
ge = pc.greater_equal(ser, lower_bound)
lt = pc.less(ser, upper_bound)
Expand Down
25 changes: 0 additions & 25 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,31 +91,6 @@ def convert_if_dtype(
kwargs={},
)

def min(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).min()

def max(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).max()

def mean(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).mean()

def median(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).median()

def sum(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
).sum()

def len(self) -> DaskExpr:
import dask.dataframe as dd
import pandas as pd
Expand Down
41 changes: 0 additions & 41 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,47 +169,6 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries:
kwargs={},
)

# --- reduction ---
def sum(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).sum()

def mean(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).mean()

def median(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).median()

def max(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).max()

def min(self, *column_names: str) -> PandasLikeExpr:
return PandasLikeExpr.from_column_names(
*column_names,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
).min()

def len(self) -> PandasLikeExpr:
return PandasLikeExpr(
lambda df: [
Expand Down
11 changes: 10 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
) -> PandasLikeSeries:
ser = self._native_series
_, lower_bound = broadcast_align_and_extract_native(self, lower_bound)
_, upper_bound = broadcast_align_and_extract_native(self, upper_bound)
if closed == "left":
res = ser.ge(lower_bound) & ser.lt(upper_bound)
elif closed == "right":
Expand All @@ -273,7 +275,14 @@ def is_between(
res = ser.ge(lower_bound) & ser.le(upper_bound)
else: # pragma: no cover
raise AssertionError
return self._from_native_series(res)
return self._from_native_series(
rename(
res,
ser.name,
implementation=self._implementation,
backend_version=self._backend_version,
)
)

def is_in(self, other: Any) -> PandasLikeSeries:
ser = self._native_series
Expand Down
22 changes: 0 additions & 22 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,6 @@ def lit(self: Self, value: Any, dtype: DType | None = None) -> PolarsExpr:
pl.lit(value), version=self._version, backend_version=self._backend_version
)

def mean(self: Self, *column_names: str) -> PolarsExpr:
import polars as pl

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(
pl.mean([*column_names]), # type: ignore[arg-type]
version=self._version,
backend_version=self._backend_version,
)

def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr:
import polars as pl

Expand All @@ -160,17 +149,6 @@ def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr:
backend_version=self._backend_version,
)

def median(self: Self, *column_names: str) -> PolarsExpr:
import polars as pl

from narwhals._polars.expr import PolarsExpr

return PolarsExpr(
pl.median([*column_names]), # type: ignore[arg-type]
version=self._version,
backend_version=self._backend_version,
)

def concat_str(
self,
exprs: Iterable[IntoPolarsExpr],
Expand Down
21 changes: 13 additions & 8 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,10 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:

# --- transform ---
def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
self,
lower_bound: Any | IntoExpr,
upper_bound: Any | IntoExpr,
closed: str = "both",
) -> Self:
"""Check if this expression is between the given lower and upper bounds.
Expand Down Expand Up @@ -1724,7 +1727,9 @@ def is_between(
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).is_between(
lower_bound, upper_bound, closed
extract_compliant(plx, lower_bound),
extract_compliant(plx, upper_bound),
closed,
)
)

Expand Down Expand Up @@ -6049,7 +6054,7 @@ def col(*names: str | Iterable[str]) -> Expr:
"""Creates an expression that references one or more columns by their name(s).
Arguments:
names: Name(s) of the columns to use in the aggregation function.
names: Name(s) of the columns to use.
Returns:
A new expression.
Expand Down Expand Up @@ -6308,7 +6313,7 @@ def sum(*columns: str) -> Expr:
----
a: [[3]]
"""
return Expr(lambda plx: plx.sum(*columns))
return Expr(lambda plx: plx.col(*columns).sum())


def mean(*columns: str) -> Expr:
Expand Down Expand Up @@ -6359,7 +6364,7 @@ def mean(*columns: str) -> Expr:
----
a: [[4]]
"""
return Expr(lambda plx: plx.mean(*columns))
return Expr(lambda plx: plx.col(*columns).mean())


def median(*columns: str) -> Expr:
Expand Down Expand Up @@ -6411,7 +6416,7 @@ def median(*columns: str) -> Expr:
----
a: [[4]]
"""
return Expr(lambda plx: plx.median(*columns))
return Expr(lambda plx: plx.col(*columns).median())


def min(*columns: str) -> Expr:
Expand Down Expand Up @@ -6462,7 +6467,7 @@ def min(*columns: str) -> Expr:
----
b: [[5]]
"""
return Expr(lambda plx: plx.min(*columns))
return Expr(lambda plx: plx.col(*columns).min())


def max(*columns: str) -> Expr:
Expand Down Expand Up @@ -6513,7 +6518,7 @@ def max(*columns: str) -> Expr:
----
a: [[2]]
"""
return Expr(lambda plx: plx.max(*columns))
return Expr(lambda plx: plx.col(*columns).max())


def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
Expand Down
8 changes: 6 additions & 2 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2069,7 +2069,7 @@ def fill_null(
)

def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
self, lower_bound: Any | Self, upper_bound: Any | Self, closed: str = "both"
) -> Self:
"""Get a boolean mask of the values that are between the given lower/upper bounds.

Expand Down Expand Up @@ -2119,7 +2119,11 @@ def is_between(
]
"""
return self._from_compliant_series(
self._compliant_series.is_between(lower_bound, upper_bound, closed=closed)
self._compliant_series.is_between(
self._extract_native(lower_bound),
self._extract_native(upper_bound),
closed=closed,
)
)

def n_unique(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2319,7 +2319,7 @@ def col(*names: str | Iterable[str]) -> Expr:
"""Creates an expression that references one or more columns by their name(s).

Arguments:
names: Name(s) of the columns to use in the aggregation function.
names: Name(s) of the columns to use.

Returns:
A new expression.
Expand Down
22 changes: 18 additions & 4 deletions tests/expr_and_series/is_between_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

data = {
"a": [1, 4, 2, 5],
}


@pytest.mark.parametrize(
("closed", "expected"),
Expand All @@ -22,12 +18,21 @@
],
)
def test_is_between(constructor: Constructor, closed: str, expected: list[bool]) -> None:
data = {"a": [1, 4, 2, 5]}
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").is_between(1, 5, closed=closed))
expected_dict = {"a": expected}
assert_equal_data(result, expected_dict)


def test_is_between_expressified(constructor: Constructor) -> None:
data = {"a": [1, 4, 2, 5], "b": [0, 5, 2, 4], "c": [9, 9, 9, 9]}
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").is_between(nw.col("b") * 0.9, nw.col("c") - 1))
expected_dict = {"a": [True, False, True, True]}
assert_equal_data(result, expected_dict)


@pytest.mark.parametrize(
("closed", "expected"),
[
Expand All @@ -40,7 +45,16 @@ def test_is_between(constructor: Constructor, closed: str, expected: list[bool])
def test_is_between_series(
constructor_eager: ConstructorEager, closed: str, expected: list[bool]
) -> None:
data = {"a": [1, 4, 2, 5]}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df.with_columns(a=df["a"].is_between(1, 5, closed=closed))
expected_dict = {"a": expected}
assert_equal_data(result, expected_dict)


def test_is_between_expressified_series(constructor_eager: ConstructorEager) -> None:
data = {"a": [1, 4, 2, 5], "b": [0, 5, 2, 4], "c": [9, 9, 9, 9]}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df["a"].is_between(df["b"], df["c"]).to_frame()
expected_dict = {"a": [True, False, True, True]}
assert_equal_data(result, expected_dict)
Loading