feat: consistently return Python scalars from Series reductions for PyArrow (#1471)


---------

Co-authored-by: Francesco Bruzzesi <[email protected]>
MarcoGorelli and FBruzzesi authored Nov 30, 2024
1 parent 635434e commit bfbc34d
Showing 19 changed files with 137 additions and 98 deletions.
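
A minimal sketch of the user-facing change (example data assumed, not taken from the commit): reductions on a PyArrow-backed Series now return plain Python scalars, matching Polars, rather than pyarrow.Scalar objects.

    import pyarrow as pa
    import narwhals as nw

    s = nw.from_native(pa.chunked_array([[1, 6, 8]]), series_only=True)

    # Before this commit these reductions returned pyarrow scalars such as
    # <pyarrow.DoubleScalar: 5.0>; they now return Python scalars.
    print(s.mean())  # 5.0 (Python float)
    print(s.max())   # 8 (Python int)
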
6 changes: 3 additions & 3 deletions narwhals/__init__.py
@@ -86,21 +86,21 @@
"Field",
"Float32",
"Float64",
"Int8",
"Int16",
"Int32",
"Int64",
"Int8",
"LazyFrame",
"List",
"Object",
"Schema",
"Series",
"String",
"Struct",
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"UInt8",
"Unknown",
"all",
"all_horizontal",
@@ -113,8 +113,8 @@
"exceptions",
"from_arrow",
"from_dict",
"from_numpy",
"from_native",
"from_numpy",
"generate_temporary_column_name",
"get_level",
"get_native_namespace",
10 changes: 8 additions & 2 deletions narwhals/_arrow/dataframe.py
@@ -541,6 +541,8 @@ def is_empty(self: Self) -> bool:
return self.shape[0] == 0

def item(self: Self, row: int | None, column: int | str | None) -> Any:
from narwhals._arrow.series import maybe_extract_py_scalar

if row is None and column is None:
if self.shape != (1, 1):
msg = (
@@ -549,14 +551,18 @@ def item(self: Self, row: int | None, column: int | str | None) -> Any:
f" frame has shape {self.shape!r}"
)
raise ValueError(msg)
return self._native_frame[0][0]
return maybe_extract_py_scalar(
self._native_frame[0][0], return_py_scalar=True
)

elif row is None or column is None:
msg = "cannot call `.item()` with only one of `row` or `column`"
raise ValueError(msg)

_col = self.columns.index(column) if isinstance(column, str) else column
return self._native_frame[_col][row]
return maybe_extract_py_scalar(
self._native_frame[_col][row], return_py_scalar=True
)

def rename(self: Self, mapping: dict[str, str]) -> Self:
df = self._native_frame
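
A hedged sketch of the effect on DataFrame.item (example data assumed, not part of the commit): a 1x1 selection, or an explicit row/column lookup, now yields a Python value instead of a pyarrow scalar.

    import pyarrow as pa
    import narwhals as nw

    df = nw.from_native(pa.table({"a": [1], "b": ["x"]}), eager_only=True)

    print(df.select("a").item())  # 1 (Python int, previously <pyarrow.Int64Scalar: 1>)
    print(df.item(0, "b"))        # 'x' (previously <pyarrow.StringScalar: 'x'>)
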
2 changes: 2 additions & 0 deletions narwhals/_arrow/expr.py
@@ -25,6 +25,8 @@


class ArrowExpr:
_implementation: Implementation = Implementation.PYARROW

def __init__(
self: Self,
call: Callable[[ArrowDataFrame], list[ArrowSeries]],
87 changes: 53 additions & 34 deletions narwhals/_arrow/series.py
@@ -14,7 +14,6 @@
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import parse_datetime_format
from narwhals._arrow.utils import validate_column_comparand
from narwhals.translate import to_py_scalar
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name

@@ -32,6 +31,12 @@
from narwhals.typing import DTypes


def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: FBT001
if return_py_scalar:
return getattr(value, "as_py", lambda: value)()
return value


class ArrowSeries:
def __init__(
self: Self,
@@ -241,8 +246,8 @@ def __invert__(self: Self) -> Self:

return self._from_native_series(pc.invert(self._native_series))

def len(self: Self) -> int:
return len(self._native_series)
def len(self: Self, *, _return_py_scalar: bool = True) -> int:
return maybe_extract_py_scalar(len(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def filter(self: Self, other: Any) -> Self:
if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
@@ -251,12 +256,12 @@ def filter(self: Self, other: Any) -> Self:
ser = self._native_series
return self._from_native_series(ser.filter(other))

def mean(self: Self) -> int:
def mean(self: Self, *, _return_py_scalar: bool = True) -> int:
import pyarrow.compute as pc # ignore-banned-import()

return pc.mean(self._native_series) # type: ignore[no-any-return]
return maybe_extract_py_scalar(pc.mean(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def median(self: Self) -> int:
def median(self: Self, *, _return_py_scalar: bool = True) -> int:
import pyarrow.compute as pc # ignore-banned-import()

from narwhals.exceptions import InvalidOperationError
@@ -265,22 +270,24 @@ def median(self: Self) -> int:
msg = "`median` operation not supported for non-numeric input type."
raise InvalidOperationError(msg)

return pc.approximate_median(self._native_series) # type: ignore[no-any-return]
return maybe_extract_py_scalar( # type: ignore[no-any-return]
pc.approximate_median(self._native_series), _return_py_scalar
)

def min(self: Self) -> int:
def min(self: Self, *, _return_py_scalar: bool = True) -> int:
import pyarrow.compute as pc # ignore-banned-import()

return pc.min(self._native_series) # type: ignore[no-any-return]
return maybe_extract_py_scalar(pc.min(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def max(self: Self) -> int:
def max(self: Self, *, _return_py_scalar: bool = True) -> int:
import pyarrow.compute as pc # ignore-banned-import()

return pc.max(self._native_series) # type: ignore[no-any-return]
return maybe_extract_py_scalar(pc.max(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def sum(self: Self) -> int:
def sum(self: Self, *, _return_py_scalar: bool = True) -> int:
import pyarrow.compute as pc # ignore-banned-import()

return pc.sum(self._native_series) # type: ignore[no-any-return]
return maybe_extract_py_scalar(pc.sum(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def drop_nulls(self: Self) -> ArrowSeries:
import pyarrow.compute as pc # ignore-banned-import()
@@ -300,12 +307,14 @@ def shift(self: Self, n: int) -> Self:
result = ca
return self._from_native_series(result)

def std(self: Self, ddof: int) -> float:
def std(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float:
import pyarrow.compute as pc # ignore-banned-import()

return pc.stddev(self._native_series, ddof=ddof) # type: ignore[no-any-return]
return maybe_extract_py_scalar( # type: ignore[no-any-return]
pc.stddev(self._native_series, ddof=ddof), _return_py_scalar
)

def skew(self: Self) -> float | None:
def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
@@ -321,18 +330,22 @@ def skew(self: Self) -> float | None:
m2 = pc.mean(pc.power(m, 2))
m3 = pc.mean(pc.power(m, 3))
# Biased population skewness
return pc.divide(m3, pc.power(m2, 1.5)) # type: ignore[no-any-return]
return maybe_extract_py_scalar( # type: ignore[no-any-return]
pc.divide(m3, pc.power(m2, 1.5)), _return_py_scalar
)

def count(self: Self) -> int:
def count(self: Self, *, _return_py_scalar: bool = True) -> int:
import pyarrow.compute as pc # ignore-banned-import()

return pc.count(self._native_series) # type: ignore[no-any-return]
return maybe_extract_py_scalar(pc.count(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def n_unique(self: Self) -> int:
def n_unique(self: Self, *, _return_py_scalar: bool = True) -> int:
import pyarrow.compute as pc # ignore-banned-import()

unique_values = pc.unique(self._native_series)
return pc.count(unique_values, mode="all") # type: ignore[no-any-return]
return maybe_extract_py_scalar( # type: ignore[no-any-return]
pc.count(unique_values, mode="all"), _return_py_scalar
)

def __native_namespace__(self: Self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
@@ -430,15 +443,15 @@ def diff(self: Self) -> Self:
pc.pairwise_diff(self._native_series.combine_chunks())
)

def any(self: Self) -> bool:
def any(self: Self, *, _return_py_scalar: bool = True) -> bool:
import pyarrow.compute as pc # ignore-banned-import()

return to_py_scalar(pc.any(self._native_series)) # type: ignore[no-any-return]
return maybe_extract_py_scalar(pc.any(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def all(self: Self) -> bool:
def all(self: Self, *, _return_py_scalar: bool = True) -> bool:
import pyarrow.compute as pc # ignore-banned-import()

return to_py_scalar(pc.all(self._native_series)) # type: ignore[no-any-return]
return maybe_extract_py_scalar(pc.all(self._native_series), _return_py_scalar) # type: ignore[no-any-return]

def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
@@ -480,8 +493,8 @@ def cast(self: Self, dtype: DType) -> Self:
dtype = narwhals_to_native_dtype(dtype, self._dtypes)
return self._from_native_series(pc.cast(ser, dtype))

def null_count(self: Self) -> int:
return self._native_series.null_count # type: ignore[no-any-return]
def null_count(self: Self, *, _return_py_scalar: bool = True) -> int:
return maybe_extract_py_scalar(self._native_series.null_count, _return_py_scalar) # type: ignore[no-any-return]

def head(self: Self, n: int) -> Self:
ser = self._native_series
@@ -527,8 +540,8 @@ def item(self: Self, index: int | None = None) -> Any:
f" or an explicit index is provided (Series is of length {len(self)})"
)
raise ValueError(msg)
return self._native_series[0]
return self._native_series[index]
return maybe_extract_py_scalar(self._native_series[0], return_py_scalar=True)
return maybe_extract_py_scalar(self._native_series[index], return_py_scalar=True)

def value_counts(
self: Self,
@@ -718,7 +731,7 @@ def is_sorted(self: Self, *, descending: bool) -> bool:
result = pc.all(pc.greater_equal(ser[:-1], ser[1:]))
else:
result = pc.all(pc.less_equal(ser[:-1], ser[1:]))
return to_py_scalar(result) # type: ignore[no-any-return]
return maybe_extract_py_scalar(result, return_py_scalar=True) # type: ignore[no-any-return]

def unique(self: Self, *, maintain_order: bool) -> ArrowSeries:
# The param `maintain_order` is only here for compatibility with the Polars API
@@ -798,12 +811,15 @@ def quantile(
self: Self,
quantile: float,
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
*,
_return_py_scalar: bool = True,
) -> Any:
import pyarrow.compute as pc # ignore-banned-import()

return pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[
0
]
return maybe_extract_py_scalar(
pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[0],
_return_py_scalar,
)

def gather_every(self: Self, n: int, offset: int = 0) -> Self:
return self._from_native_series(self._native_series[offset::n])
@@ -994,7 +1010,7 @@ def rolling_mean(
return result

def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()
yield from (
maybe_extract_py_scalar(x, return_py_scalar=True)
for x in self._native_series.__iter__()
)

@property
def shape(self: Self) -> tuple[int]:
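
How the new internal helper behaves, shown as a small sketch (values are illustrative; maybe_extract_py_scalar is a private helper, so this is for explanation only): it calls .as_py() on PyArrow scalars when a Python scalar is requested, and the getattr fallback leaves values without an as_py method (for example the plain int returned by len) untouched.

    import pyarrow as pa
    import pyarrow.compute as pc
    from narwhals._arrow.series import maybe_extract_py_scalar

    arr = pa.chunked_array([[1, 2, 3]])

    print(maybe_extract_py_scalar(pc.sum(arr), True))   # 6 (Python int)
    print(maybe_extract_py_scalar(pc.sum(arr), False))  # <pyarrow.Int64Scalar: 6>, left as-is
    print(maybe_extract_py_scalar(len(arr), True))      # 3, no .as_py(), passed through
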
2 changes: 2 additions & 0 deletions narwhals/_dask/expr.py
@@ -29,6 +29,8 @@


class DaskExpr:
_implementation: Implementation = Implementation.DASK

def __init__(
self,
call: Callable[[DaskLazyFrame], list[dask_expr.Series]],
11 changes: 10 additions & 1 deletion narwhals/_expression_parsing.py
@@ -14,6 +14,7 @@

from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidIntoExprError
from narwhals.utils import Implementation

if TYPE_CHECKING:
from narwhals._arrow.dataframe import ArrowDataFrame
@@ -223,9 +224,17 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]:
for arg_name, arg_value in kwargs.items()
}

# For PyArrow.Series, we return Python Scalars (like Polars does) instead of PyArrow Scalars.
# However, when working with expressions, we keep everything PyArrow-native.
extra_kwargs = (
{"_return_py_scalar": False}
if returns_scalar and expr._implementation is Implementation.PYARROW
else {}
)

out: list[CompliantSeries] = [
plx._create_series_from_scalar(
getattr(series, attr)(*_args, **_kwargs),
getattr(series, attr)(*_args, **extra_kwargs, **_kwargs),
reference_series=series, # type: ignore[arg-type]
)
if returns_scalar
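
The practical effect of this branch, sketched with assumed example data: a reduction called directly on a Series returns a Python scalar, while the same reduction used inside an expression is kept PyArrow-native (the check relies on the _implementation attribute added to ArrowExpr in this commit) and is wrapped back into a PyArrow-backed result.

    import pyarrow as pa
    import narwhals as nw

    df = nw.from_native(pa.table({"a": [1, 2, 3]}), eager_only=True)

    print(df["a"].mean())  # 2.0 (Python float)

    # Inside an expression the intermediate stays a pyarrow scalar and is
    # re-wrapped, so the result is still a PyArrow-backed DataFrame.
    print(df.select(nw.col("a").mean()).to_native())  # one-row pyarrow.Table, a = 2.0
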
34 changes: 17 additions & 17 deletions narwhals/dependencies.py
@@ -302,30 +302,30 @@ def is_into_dataframe(native_dataframe: Any) -> bool:


__all__ = [
"get_polars",
"get_pandas",
"get_modin",
"get_cudf",
"get_pyarrow",
"get_numpy",
"get_ibis",
"get_modin",
"get_numpy",
"get_pandas",
"get_polars",
"get_pyarrow",
"is_cudf_dataframe",
"is_cudf_series",
"is_dask_dataframe",
"is_ibis_table",
"is_into_dataframe",
"is_into_series",
"is_modin_dataframe",
"is_modin_series",
"is_numpy_array",
"is_pandas_dataframe",
"is_pandas_series",
"is_pandas_index",
"is_pandas_like_dataframe",
"is_pandas_like_series",
"is_pandas_series",
"is_polars_dataframe",
"is_polars_lazyframe",
"is_polars_series",
"is_modin_dataframe",
"is_modin_series",
"is_cudf_dataframe",
"is_cudf_series",
"is_pyarrow_table",
"is_pyarrow_chunked_array",
"is_numpy_array",
"is_dask_dataframe",
"is_pandas_like_dataframe",
"is_pandas_like_series",
"is_into_dataframe",
"is_into_series",
"is_pyarrow_table",
]
6 changes: 3 additions & 3 deletions narwhals/selectors.py
@@ -271,10 +271,10 @@ def all() -> Expr:


__all__ = [
"all",
"boolean",
"by_dtype",
"categorical",
"numeric",
"boolean",
"string",
"categorical",
"all",
]
2 changes: 1 addition & 1 deletion narwhals/series.py
@@ -656,7 +656,7 @@ def median(self) -> Any:
>>> my_library_agnostic_function(s_pl)
5.0
>>> my_library_agnostic_function(s_pa)
<pyarrow.DoubleScalar: 5.0>
5.0
"""
return self._compliant_series.median()

(Diffs for the remaining 10 changed files are not shown.)
