Skip to content

Commit

Permalink
feat: make Series generic (#1412)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
EdAbati and MarcoGorelli authored Nov 29, 2024
1 parent 5b24161 commit 1a386a3
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 41 deletions.
14 changes: 14 additions & 0 deletions docs/backcompat.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ before making any change.

### After `stable.v1`

- Since Narwhals 1.15, `Series` is generic over its native Series type, meaning that you can
write:
```python
import narwhals as nw
import polars as pl

s_pl = pl.Series([1, 2, 3])
s = nw.from_native(s_pl, series_only=True)
# mypy infers `s.to_native()` to be `polars.Series`
reveal_type(s.to_native())
```
Previously, `Series` was not generic, so in the above example
`s.to_native()` would have been inferred as `Any`.

- Since Narwhals 1.13.0, the `strict` parameter in `from_native`, `to_native`, and `narwhalify`
has been deprecated in favour of `pass_through`. This is because several users expressed
confusion/surprise over what `strict=False` did.
Expand Down
28 changes: 15 additions & 13 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class DataFrame(BaseFrame[DataFrameT]):
"""

@property
def _series(self) -> type[Series]:
def _series(self) -> type[Series[Any]]:
from narwhals.series import Series

return Series
Expand Down Expand Up @@ -667,7 +667,7 @@ def shape(self) -> tuple[int, int]:
"""
return self._compliant_frame.shape # type: ignore[no-any-return]

def get_column(self, name: str) -> Series:
def get_column(self, name: str) -> Series[Any]:
"""Get a single column by name.
Notes:
Expand Down Expand Up @@ -721,23 +721,23 @@ def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[slice, str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[slice, int]) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: str) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[str]) -> Self: ...
Expand All @@ -760,7 +760,7 @@ def __getitem__(
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice]
| tuple[slice, slice]
),
) -> Series | Self:
) -> Series[Any] | Self:
"""Extract column or slice of DataFrame.
Arguments:
Expand Down Expand Up @@ -881,14 +881,16 @@ def __contains__(self, key: str) -> bool:
return key in self.columns

@overload
def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series]: ...
def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series[Any]]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
@overload
def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, Series[Any]] | dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool = True
) -> dict[str, Series] | dict[str, list[Any]]:
) -> dict[str, Series[Any]] | dict[str, list[Any]]:
"""Convert DataFrame to a dictionary mapping column name to values.
Arguments:
Expand Down Expand Up @@ -2317,7 +2319,7 @@ def join_asof(
)

# --- descriptive ---
def is_duplicated(self: Self) -> Series:
def is_duplicated(self: Self) -> Series[Any]:
r"""Get a mask of all duplicated rows in this DataFrame.
Returns:
Expand Down Expand Up @@ -2399,7 +2401,7 @@ def is_empty(self: Self) -> bool:
"""
return self._compliant_frame.is_empty() # type: ignore[no-any-return]

def is_unique(self: Self) -> Series:
def is_unique(self: Self) -> Series[Any]:
r"""Get a mask of all unique rows in this DataFrame.
Returns:
Expand Down
7 changes: 4 additions & 3 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.typing import DTypes
from narwhals.typing import IntoSeriesT

class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(
Expand Down Expand Up @@ -200,7 +201,7 @@ def new_series(
dtype: DType | type[DType] | None = None,
*,
native_namespace: ModuleType,
) -> Series:
) -> Series[Any]:
"""Instantiate Narwhals Series from iterable (e.g. list or array).
Arguments:
Expand Down Expand Up @@ -266,7 +267,7 @@ def _new_series_impl(
*,
native_namespace: ModuleType,
dtypes: DTypes,
) -> Series:
) -> Series[Any]:
implementation = Implementation.from_native_namespace(native_namespace)

if implementation is Implementation.POLARS:
Expand Down Expand Up @@ -904,7 +905,7 @@ def show_versions() -> None:


def get_level(
obj: DataFrame[Any] | LazyFrame[Any] | Series,
obj: DataFrame[Any] | LazyFrame[Any] | Series[IntoSeriesT],
) -> Literal["full", "interchange"]:
"""Level of support Narwhals has for current object.
Expand Down
9 changes: 5 additions & 4 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import overload

from narwhals.dtypes import _validate_dtype
from narwhals.typing import IntoSeriesT
from narwhals.utils import _validate_rolling_arguments
from narwhals.utils import parse_version

Expand All @@ -27,7 +28,7 @@
from narwhals.dtypes import DType


class Series:
class Series(Generic[IntoSeriesT]):
"""Narwhals Series, backed by a native series.
The native series might be pandas.Series, polars.Series, ...
Expand Down Expand Up @@ -98,7 +99,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
ca = pa.chunked_array([self.to_arrow()])
return ca.__arrow_c_stream__(requested_schema=requested_schema)

def to_native(self) -> Any:
def to_native(self) -> IntoSeriesT:
"""Convert Narwhals series to native series.
Returns:
Expand Down Expand Up @@ -135,7 +136,7 @@ def to_native(self) -> Any:
3
]
"""
return self._compliant_series._native_series
return self._compliant_series._native_series # type: ignore[no-any-return]

def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
"""Set value(s) at given position(s).
Expand Down Expand Up @@ -3326,7 +3327,7 @@ def cat(self: Self) -> SeriesCatNamespace[Self]:
return SeriesCatNamespace(self)


SeriesT = TypeVar("SeriesT", bound=Series)
SeriesT = TypeVar("SeriesT", bound=Series[Any])


class SeriesCatNamespace(Generic[SeriesT]):
Expand Down
6 changes: 3 additions & 3 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _l1_norm(self: Self) -> Self:
return self.select(all()._l1_norm())


class Series(NwSeries):
class Series(NwSeries[Any]):
"""Narwhals Series, backed by a native series.
The native series might be pandas.Series, polars.Series, ...
Expand Down Expand Up @@ -1153,15 +1153,15 @@ def _stableify(obj: NwDataFrame[IntoFrameT]) -> DataFrame[IntoFrameT]: ...
@overload
def _stableify(obj: NwLazyFrame[IntoFrameT]) -> LazyFrame[IntoFrameT]: ...
@overload
def _stableify(obj: NwSeries) -> Series: ...
def _stableify(obj: NwSeries[Any]) -> Series: ...
@overload
def _stableify(obj: NwExpr) -> Expr: ...
@overload
def _stableify(obj: Any) -> Any: ...


def _stableify(
obj: NwDataFrame[IntoFrameT] | NwLazyFrame[IntoFrameT] | NwSeries | NwExpr | Any,
obj: NwDataFrame[IntoFrameT] | NwLazyFrame[IntoFrameT] | NwSeries[Any] | NwExpr | Any,
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series | Expr | Any:
from narwhals.stable.v1 import dtypes

Expand Down
20 changes: 11 additions & 9 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ def to_native(
narwhals_object: LazyFrame[IntoFrameT], *, pass_through: Literal[False] = ...
) -> IntoFrameT: ...
@overload
def to_native(narwhals_object: Series, *, pass_through: Literal[False] = ...) -> Any: ...
def to_native(
narwhals_object: Series[IntoSeriesT], *, pass_through: Literal[False] = ...
) -> IntoSeriesT: ...
@overload
def to_native(narwhals_object: Any, *, pass_through: bool) -> Any: ...


def to_native(
narwhals_object: DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series,
narwhals_object: DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series[IntoSeriesT],
*,
strict: bool | None = None,
pass_through: bool | None = None,
Expand Down Expand Up @@ -137,7 +139,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[IntoDataFrameT] | Series: ...
) -> DataFrame[IntoDataFrameT] | Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -197,7 +199,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series: ...
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series[IntoSeriesT]: ...


@overload
Expand All @@ -209,7 +211,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[True],
allow_series: None = ...,
) -> Series: ...
) -> Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -269,7 +271,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[Any] | LazyFrame[Any] | Series: ...
) -> DataFrame[Any] | LazyFrame[Any] | Series[Any]: ...


@overload
Expand All @@ -281,7 +283,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[True],
allow_series: None = ...,
) -> Series: ...
) -> Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -318,7 +320,7 @@ def from_native(
eager_or_interchange_only: bool = False,
series_only: bool = False,
allow_series: bool | None = None,
) -> LazyFrame[IntoFrameT] | DataFrame[IntoFrameT] | Series | T:
) -> LazyFrame[IntoFrameT] | DataFrame[IntoFrameT] | Series[IntoSeriesT] | T:
"""Convert `native_object` to Narwhals Dataframe, Lazyframe, or Series.
Arguments:
Expand Down Expand Up @@ -728,7 +730,7 @@ def _from_native_impl( # noqa: PLR0915
return native_object


def get_native_namespace(obj: DataFrame[Any] | LazyFrame[Any] | Series) -> Any:
def get_native_namespace(obj: DataFrame[Any] | LazyFrame[Any] | Series[Any]) -> Any:
"""Get native namespace from object.
Arguments:
Expand Down
4 changes: 2 additions & 2 deletions narwhals/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class DataFrameLike(Protocol):
def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: ...


IntoExpr: TypeAlias = Union["Expr", str, "Series"]
IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"]
"""Anything which can be converted to an expression.
Use this to mean "either a Narwhals expression, or something which can be converted
Expand Down Expand Up @@ -88,7 +88,7 @@ def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: ...
... return df.columns
"""

IntoSeries: TypeAlias = Union["Series", "NativeSeries"]
IntoSeries: TypeAlias = Union["Series[Any]", "NativeSeries"]
"""Anything which can be converted to a Narwhals Series.
Use this if your function can accept an object which can be converted to `nw.Series`
Expand Down
11 changes: 6 additions & 5 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.series import Series
from narwhals.typing import IntoSeriesT

FrameOrSeriesT = TypeVar(
"FrameOrSeriesT", bound=Union[LazyFrame[Any], DataFrame[Any], Series]
"FrameOrSeriesT", bound=Union[LazyFrame[Any], DataFrame[Any], Series[Any]]
)


Expand Down Expand Up @@ -178,7 +179,7 @@ def validate_laziness(items: Iterable[Any]) -> None:


def maybe_align_index(
lhs: FrameOrSeriesT, rhs: Series | DataFrame[Any] | LazyFrame[Any]
lhs: FrameOrSeriesT, rhs: Series[Any] | DataFrame[Any] | LazyFrame[Any]
) -> FrameOrSeriesT:
"""Align `lhs` to the Index of `rhs`, if they're both pandas-like.
Expand Down Expand Up @@ -274,7 +275,7 @@ def _validate_index(index: Any) -> None:
return lhs


def maybe_get_index(obj: DataFrame[Any] | LazyFrame[Any] | Series) -> Any | None:
def maybe_get_index(obj: DataFrame[Any] | LazyFrame[Any] | Series[Any]) -> Any | None:
"""Get the index of a DataFrame or a Series, if it's pandas-like.
Arguments:
Expand Down Expand Up @@ -314,7 +315,7 @@ def maybe_set_index(
obj: FrameOrSeriesT,
column_names: str | list[str] | None = None,
*,
index: Series | list[Series] | None = None,
index: Series[IntoSeriesT] | list[Series[IntoSeriesT]] | None = None,
) -> FrameOrSeriesT:
"""Set the index of a DataFrame or a Series, if it's pandas-like.
Expand Down Expand Up @@ -516,7 +517,7 @@ def maybe_convert_dtypes(
return obj_any # type: ignore[no-any-return]


def is_ordered_categorical(series: Series) -> bool:
def is_ordered_categorical(series: Series[Any]) -> bool:
"""Return whether indices of categories are semantically meaningful.
This is a convenience function to accessing what would otherwise be
Expand Down
5 changes: 3 additions & 2 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

if TYPE_CHECKING:
from narwhals.series import Series
from narwhals.typing import IntoSeriesT


def test_maybe_align_index_pandas() -> None:
Expand Down Expand Up @@ -111,7 +112,7 @@ def test_maybe_set_index_polars_column_names(
],
)
def test_maybe_set_index_pandas_direct_index(
narwhals_index: Series | list[Series] | None,
narwhals_index: Series[IntoSeriesT] | list[Series[IntoSeriesT]] | None,
pandas_index: pd.Series | list[pd.Series] | None,
native_df_or_series: pd.DataFrame | pd.Series,
) -> None:
Expand All @@ -136,7 +137,7 @@ def test_maybe_set_index_pandas_direct_index(
],
)
def test_maybe_set_index_polars_direct_index(
index: Series | list[Series] | None,
index: Series[IntoSeriesT] | list[Series[IntoSeriesT]] | None,
) -> None:
df = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
result = nw.maybe_set_index(df, index=index)
Expand Down

0 comments on commit 1a386a3

Please sign in to comment.