Skip to content

Commit

Permalink
fix: address lit broadcasting and output name of right arithmetic o…
Browse files Browse the repository at this point in the history
…ps (#1424)



---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
AlessandroMiola and MarcoGorelli authored Nov 30, 2024
1 parent ea1a64f commit 943349d
Show file tree
Hide file tree
Showing 17 changed files with 340 additions and 163 deletions.
27 changes: 18 additions & 9 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,55 +151,64 @@ def __and__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__and__", other=other)

def __rand__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__rand__", other=other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__and__(self) # type: ignore[return-value]

def __or__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__or__", other=other)

def __ror__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__ror__", other=other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__or__(self) # type: ignore[return-value]

def __add__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__add__", other)

def __radd__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__radd__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__add__(self) # type: ignore[return-value]

def __sub__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__sub__", other)

def __rsub__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rsub__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__sub__(self) # type: ignore[return-value]

def __mul__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__mul__", other)

def __rmul__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rmul__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__mul__(self) # type: ignore[return-value]

def __pow__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__pow__", other)

def __rpow__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rpow__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__pow__(self) # type: ignore[return-value]

def __floordiv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__floordiv__", other)

def __rfloordiv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rfloordiv__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__floordiv__(self) # type: ignore[return-value]

def __truediv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__truediv__", other)

def __rtruediv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rtruediv__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__truediv__(self) # type: ignore[return-value]

def __mod__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__mod__", other)

def __rmod__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rmod__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__mod__(self) # type: ignore[return-value]

def __invert__(self: Self) -> Self:
return reuse_series_implementation(self, "__invert__")
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
depth=0,
function_name="lit",
root_names=None,
output_names=[_lit_arrow_series.__name__],
output_names=["literal"],
backend_version=self._backend_version,
dtypes=self._dtypes,
)
Expand Down
81 changes: 33 additions & 48 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,130 +96,115 @@ def __len__(self: Self) -> int:
def __eq__(self: Self, other: object) -> Self: # type: ignore[override]
import pyarrow.compute as pc

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.equal(ser, other))

def __ne__(self: Self, other: object) -> Self: # type: ignore[override]
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.not_equal(ser, other))

def __ge__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.greater_equal(ser, other))

def __gt__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.greater(ser, other))

def __le__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.less_equal(ser, other))

def __lt__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.less(ser, other))

def __and__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.and_kleene(ser, other))

def __rand__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.and_kleene(other, ser))

def __or__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.or_kleene(ser, other))

def __ror__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.or_kleene(other, ser))

def __add__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

other = validate_column_comparand(other)
return self._from_native_series(pc.add(self._native_series, other))
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.add(ser, other))

def __radd__(self: Self, other: Any) -> Self:
return self + other # type: ignore[no-any-return]

def __sub__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

other = validate_column_comparand(other)
return self._from_native_series(pc.subtract(self._native_series, other))
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.subtract(ser, other))

def __rsub__(self: Self, other: Any) -> Self:
return (self - other) * (-1) # type: ignore[no-any-return]

def __mul__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

other = validate_column_comparand(other)
return self._from_native_series(pc.multiply(self._native_series, other))
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.multiply(ser, other))

def __rmul__(self: Self, other: Any) -> Self:
return self * other # type: ignore[no-any-return]

def __pow__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.power(ser, other))

def __rpow__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.power(other, ser))

def __floordiv__(self: Self, other: Any) -> Self:
ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(floordiv_compat(ser, other))

def __rfloordiv__(self: Self, other: Any) -> Self:
ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(floordiv_compat(other, ser))

def __truediv__(self: Self, other: Any) -> Self:
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
if not isinstance(other, (pa.Array, pa.ChunkedArray)):
# scalar
other = pa.scalar(other)
Expand All @@ -229,8 +214,7 @@ def __rtruediv__(self: Self, other: Any) -> Self:
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
if not isinstance(other, (pa.Array, pa.ChunkedArray)):
# scalar
other = pa.scalar(other)
Expand All @@ -239,18 +223,16 @@ def __rtruediv__(self: Self, other: Any) -> Self:
def __mod__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
floor_div = (self // other)._native_series
ser, other = validate_column_comparand(self, other, self._backend_version)
res = pc.subtract(ser, pc.multiply(floor_div, other))
return self._from_native_series(res)

def __rmod__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
floor_div = (other // self)._native_series
ser, other = validate_column_comparand(self, other, self._backend_version)
res = pc.subtract(other, pc.multiply(floor_div, ser))
return self._from_native_series(res)

Expand All @@ -264,8 +246,10 @@ def len(self: Self) -> int:

def filter(self: Self, other: Any) -> Self:
if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
other = validate_column_comparand(other)
return self._from_native_series(self._native_series.filter(other))
ser, other = validate_column_comparand(self, other, self._backend_version)
else:
ser = self._native_series
return self._from_native_series(ser.filter(other))

def mean(self: Self) -> int:
import pyarrow.compute as pc # ignore-banned-import()
Expand Down Expand Up @@ -382,16 +366,17 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self:
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

ca = self._native_series
mask = np.zeros(len(ca), dtype=bool)
mask = np.zeros(self.len(), dtype=bool)
mask[indices] = True
if isinstance(values, self.__class__):
values = validate_column_comparand(values)
ser, values = validate_column_comparand(self, values, self._backend_version)
else:
ser = self._native_series
if isinstance(values, pa.ChunkedArray):
values = values.combine_chunks()
if not isinstance(values, pa.Array):
values = pa.array(values)
result = pc.replace_with_mask(ca, mask, values.take(indices))
result = pc.replace_with_mask(ser, mask, values.take(indices))
return self._from_native_series(result)

def to_list(self: Self) -> list[Any]:
Expand Down
56 changes: 39 additions & 17 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> pa.D
raise AssertionError(msg)


def validate_column_comparand(other: Any) -> Any:
def validate_column_comparand(
lhs: ArrowSeries, rhs: Any, backend_version: tuple[int, ...]
) -> tuple[pa.ChunkedArray, Any]:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -140,27 +142,47 @@ def validate_column_comparand(other: Any) -> Any:
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.series import ArrowSeries

if isinstance(other, list):
if len(other) > 1:
if hasattr(other[0], "__narwhals_expr__") or hasattr(
other[0], "__narwhals_series__"
# If `rhs` is the output of an expression evaluation, then it is
# a list of Series. So, we verify that that list is of length-1,
# and take the first (and only) element.
if isinstance(rhs, list):
if len(rhs) > 1:
if hasattr(rhs[0], "__narwhals_expr__") or hasattr(
rhs[0], "__narwhals_series__"
):
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) are not supported in this context"
raise ValueError(msg)
msg = (
f"Expected scalar value, Series, or Expr, got list of : {type(other[0])}"
)
msg = f"Expected scalar value, Series, or Expr, got list of : {type(rhs[0])}"
raise ValueError(msg)
other = other[0]
if isinstance(other, ArrowDataFrame):
rhs = rhs[0]

if isinstance(rhs, ArrowDataFrame):
return NotImplemented
if isinstance(other, ArrowSeries):
if len(other) == 1:

if isinstance(rhs, ArrowSeries):
if len(rhs) == 1:
# broadcast
return other[0]
return other._native_series
return other
return lhs._native_series, rhs[0]
if len(lhs) == 1:
# broadcast
import numpy as np # ignore-banned-import
import pyarrow as pa # ignore-banned-import

fill_value = lhs[0]
if backend_version < (13,) and hasattr(fill_value, "as_py"):
fill_value = fill_value.as_py()
left_result = pa.chunked_array(
[
pa.array(
np.full(shape=rhs.len(), fill_value=fill_value),
type=lhs._native_series.type,
)
]
)
return left_result, rhs._native_series
return lhs._native_series, rhs._native_series
return lhs._native_series, rhs


def validate_dataframe_comparand(
Expand All @@ -179,7 +201,7 @@ def validate_dataframe_comparand(
import pyarrow as pa # ignore-banned-import

value = other._native_series[0]
if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
if backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
return pa.array(np.full(shape=length, fill_value=value))
return other._native_series
Expand Down Expand Up @@ -321,7 +343,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]:
s_native = s._native_series
if is_max_length_gt_1 and length == 1:
value = s_native[0]
if s._backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
if s._backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
reshaped.append(pa.array([value] * max_length, type=s_native.type))
else:
Expand Down
Loading

0 comments on commit 943349d

Please sign in to comment.