From 81e60787b6cb3304b36f4df8792aa500387deb48 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Sun, 1 Dec 2024 10:53:07 +0100 Subject: [PATCH] chore: polars type hinting (#1467) * chore: polars typing * less Any * rm unused typevar --- narwhals/_polars/dataframe.py | 123 +++++++++++++++++++------------ narwhals/_polars/expr.py | 94 +++++++++++------------ narwhals/_polars/group_by.py | 23 ++++-- narwhals/_polars/namespace.py | 44 +++++------ narwhals/_polars/series.py | 135 ++++++++++++++++++---------------- narwhals/_polars/utils.py | 38 +++++++++- 6 files changed, 267 insertions(+), 190 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 231fd9a4e..1d4cc3cd1 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -4,6 +4,7 @@ from typing import Any from typing import Literal from typing import Sequence +from typing import overload from narwhals._polars.namespace import PolarsNamespace from narwhals._polars.utils import convert_str_slice_to_int_slice @@ -17,31 +18,37 @@ if TYPE_CHECKING: from types import ModuleType + from typing import TypeVar import numpy as np import polars as pl from typing_extensions import Self + from narwhals._polars.group_by import PolarsGroupBy + from narwhals._polars.group_by import PolarsLazyGroupBy + from narwhals._polars.series import PolarsSeries from narwhals.dtypes import DType from narwhals.typing import DTypes + T = TypeVar("T") + class PolarsDataFrame: def __init__( - self, df: pl.DataFrame, *, backend_version: tuple[int, ...], dtypes: DTypes + self: Self, df: pl.DataFrame, *, backend_version: tuple[int, ...], dtypes: DTypes ) -> None: self._native_frame = df self._backend_version = backend_version self._implementation = Implementation.POLARS self._dtypes = dtypes - def __repr__(self) -> str: # pragma: no cover + def __repr__(self: Self) -> str: # pragma: no cover return "PolarsDataFrame" - def __narwhals_dataframe__(self) -> Self: + def __narwhals_dataframe__(self: Self) -> Self: return self - def __narwhals_namespace__(self) -> PolarsNamespace: + def __narwhals_namespace__(self: Self) -> PolarsNamespace: return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __native_namespace__(self: Self) -> ModuleType: @@ -51,17 +58,28 @@ def __native_namespace__(self: Self) -> ModuleType: msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) - def _change_dtypes(self, dtypes: DTypes) -> Self: + def _change_dtypes(self: Self, dtypes: DTypes) -> Self: return self.__class__( self._native_frame, backend_version=self._backend_version, dtypes=dtypes ) - def _from_native_frame(self, df: Any) -> Self: + def _from_native_frame(self: Self, df: pl.DataFrame) -> Self: return self.__class__( df, backend_version=self._backend_version, dtypes=self._dtypes ) - def _from_native_object(self, obj: Any) -> Any: + @overload + def _from_native_object(self: Self, obj: pl.Series) -> PolarsSeries: ... + + @overload + def _from_native_object(self: Self, obj: pl.DataFrame) -> Self: ... + + @overload + def _from_native_object(self: Self, obj: T) -> T: ... + + def _from_native_object( + self: Self, obj: pl.Series | pl.DataFrame | T + ) -> Self | PolarsSeries | T: import polars as pl # ignore-banned-import() if isinstance(obj, pl.Series): @@ -75,16 +93,7 @@ def _from_native_object(self, obj: Any) -> Any: # scalar return obj - def __getattr__(self, attr: str) -> Any: - if attr == "collect": # pragma: no cover - raise AttributeError - if attr == "schema": - schema = self._native_frame.schema - return { - name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version) - for name, dtype in schema.items() - } - + def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: import polars as pl # ignore-banned-import() @@ -107,7 +116,9 @@ def func(*args: Any, **kwargs: Any) -> Any: return func - def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.ndarray: + def __array__( + self: Self, dtype: Any | None = None, copy: bool | None = None + ) -> np.ndarray: if self._backend_version < (0, 20, 28) and copy is not None: msg = "`copy` in `__array__` is only supported for Polars>=0.20.28" raise NotImplementedError(msg) @@ -115,7 +126,7 @@ def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.nd return self._native_frame.__array__(dtype) return self._native_frame.__array__(dtype) - def collect_schema(self) -> dict[str, DType]: + def collect_schema(self: Self) -> dict[str, DType]: if self._backend_version < (1,): return { name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version) @@ -128,10 +139,10 @@ def collect_schema(self) -> dict[str, DType]: } @property - def shape(self) -> tuple[int, int]: + def shape(self: Self) -> tuple[int, int]: return self._native_frame.shape - def __getitem__(self, item: Any) -> Any: + def __getitem__(self: Self, item: Any) -> Any: if self._backend_version > (0, 20, 30): return self._from_native_object(self._native_frame.__getitem__(item)) else: # pragma: no cover @@ -191,7 +202,7 @@ def __getitem__(self, item: Any) -> Any: ) return self._from_native_object(result) - def get_column(self, name: str) -> Any: + def get_column(self: Self, name: str) -> PolarsSeries: from narwhals._polars.series import PolarsSeries return PolarsSeries( @@ -200,21 +211,37 @@ def get_column(self, name: str) -> Any: dtypes=self._dtypes, ) - def is_empty(self) -> bool: + def is_empty(self: Self) -> bool: return len(self._native_frame) == 0 @property - def columns(self) -> list[str]: + def columns(self: Self) -> list[str]: return self._native_frame.columns - def lazy(self) -> PolarsLazyFrame: + @property + def schema(self: Self) -> dict[str, DType]: + schema = self._native_frame.schema + return { + name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version) + for name, dtype in schema.items() + } + + def lazy(self: Self) -> PolarsLazyFrame: return PolarsLazyFrame( self._native_frame.lazy(), backend_version=self._backend_version, dtypes=self._dtypes, ) - def to_dict(self, *, as_series: bool) -> Any: + @overload + def to_dict(self: Self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ... + + @overload + def to_dict(self: Self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + + def to_dict( + self: Self, *, as_series: bool + ) -> dict[str, PolarsSeries] | dict[str, list[Any]]: df = self._native_frame if as_series: @@ -229,12 +256,12 @@ def to_dict(self, *, as_series: bool) -> Any: else: return df.to_dict(as_series=False) - def group_by(self, *by: str, drop_null_keys: bool) -> Any: + def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsGroupBy: from narwhals._polars.group_by import PolarsGroupBy return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys) - def with_row_index(self, name: str) -> Any: + def with_row_index(self: Self, name: str) -> Self: if self._backend_version < (0, 20, 4): return self._from_native_frame(self._native_frame.with_row_count(name)) return self._from_native_frame(self._native_frame.with_row_index(name)) @@ -271,15 +298,15 @@ def pivot( self: Self, on: str | list[str], *, - index: str | list[str] | None = None, - values: str | list[str] | None = None, + index: str | list[str] | None, + values: str | list[str] | None, aggregate_function: Literal[ "min", "max", "first", "last", "sum", "mean", "median", "len" ] - | None = None, - maintain_order: bool = True, - sort_columns: bool = False, - separator: str = "_", + | None, + maintain_order: bool, + sort_columns: bool, + separator: str, ) -> Self: if self._backend_version < (1, 0, 0): # pragma: no cover msg = "`pivot` is only supported for Polars>=1.0.0" @@ -293,25 +320,25 @@ def pivot( sort_columns=sort_columns, separator=separator, ) - return self._from_native_object(result) # type: ignore[no-any-return] + return self._from_native_object(result) class PolarsLazyFrame: def __init__( - self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], dtypes: DTypes + self: Self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], dtypes: DTypes ) -> None: self._native_frame = df self._backend_version = backend_version self._implementation = Implementation.POLARS self._dtypes = dtypes - def __repr__(self) -> str: # pragma: no cover + def __repr__(self: Self) -> str: # pragma: no cover return "PolarsLazyFrame" - def __narwhals_lazyframe__(self) -> Self: + def __narwhals_lazyframe__(self: Self) -> Self: return self - def __narwhals_namespace__(self) -> PolarsNamespace: + def __narwhals_namespace__(self: Self) -> PolarsNamespace: return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __native_namespace__(self: Self) -> ModuleType: @@ -321,17 +348,17 @@ def __native_namespace__(self: Self) -> ModuleType: msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) - def _from_native_frame(self, df: Any) -> Self: + def _from_native_frame(self: Self, df: pl.LazyFrame) -> Self: return self.__class__( df, backend_version=self._backend_version, dtypes=self._dtypes ) - def _change_dtypes(self, dtypes: DTypes) -> Self: + def _change_dtypes(self: Self, dtypes: DTypes) -> Self: return self.__class__( self._native_frame, backend_version=self._backend_version, dtypes=dtypes ) - def __getattr__(self, attr: str) -> Any: + def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: import polars as pl # ignore-banned-import @@ -354,18 +381,18 @@ def func(*args: Any, **kwargs: Any) -> Any: return func @property - def columns(self) -> list[str]: + def columns(self: Self) -> list[str]: return self._native_frame.columns @property - def schema(self) -> dict[str, Any]: + def schema(self: Self) -> dict[str, DType]: schema = self._native_frame.schema return { name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version) for name, dtype in schema.items() } - def collect_schema(self) -> dict[str, DType]: + def collect_schema(self: Self) -> dict[str, DType]: if self._backend_version < (1,): return { name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version) @@ -377,7 +404,7 @@ def collect_schema(self) -> dict[str, DType]: for name, dtype in self._native_frame.collect_schema().items() } - def collect(self) -> PolarsDataFrame: + def collect(self: Self) -> PolarsDataFrame: import polars as pl # ignore-banned-import try: @@ -391,12 +418,12 @@ def collect(self) -> PolarsDataFrame: dtypes=self._dtypes, ) - def group_by(self, *by: str, drop_null_keys: bool) -> Any: + def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsLazyGroupBy: from narwhals._polars.group_by import PolarsLazyGroupBy return PolarsLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys) - def with_row_index(self, name: str) -> Any: + def with_row_index(self: Self, name: str) -> Self: if self._backend_version < (0, 20, 4): return self._from_native_frame(self._native_frame.with_row_count(name)) return self._from_native_frame(self._native_frame.with_row_index(name)) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 7c30d87b5..d992e33ba 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -20,22 +20,22 @@ class PolarsExpr: def __init__( - self, expr: pl.Expr, dtypes: DTypes, backend_version: tuple[int, ...] + self: Self, expr: pl.Expr, dtypes: DTypes, backend_version: tuple[int, ...] ) -> None: self._native_expr = expr self._implementation = Implementation.POLARS self._dtypes = dtypes self._backend_version = backend_version - def __repr__(self) -> str: # pragma: no cover + def __repr__(self: Self) -> str: # pragma: no cover return "PolarsExpr" - def _from_native_expr(self, expr: pl.Expr) -> Self: + def _from_native_expr(self: Self, expr: pl.Expr) -> Self: return self.__class__( expr, dtypes=self._dtypes, backend_version=self._backend_version ) - def __getattr__(self, attr: str) -> Any: + def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._from_native_expr( @@ -44,7 +44,7 @@ def func(*args: Any, **kwargs: Any) -> Any: return func - def cast(self, dtype: DType) -> Self: + def cast(self: Self, dtype: DType) -> Self: expr = self._native_expr dtype_pl = narwhals_to_native_dtype(dtype, self._dtypes) return self._from_native_expr(expr.cast(dtype_pl)) @@ -52,13 +52,13 @@ def cast(self, dtype: DType) -> Self: def ewm_mean( self: Self, *, - com: float | None = None, - span: float | None = None, - half_life: float | None = None, - alpha: float | None = None, - adjust: bool = True, - min_periods: int = 1, - ignore_nulls: bool = False, + com: float | None, + span: float | None, + half_life: float | None, + alpha: float | None, + adjust: bool, + min_periods: int, + ignore_nulls: bool, ) -> Self: if self._backend_version < (1,): # pragma: no cover msg = "`ewm_mean` not implemented for polars older than 1.0" @@ -78,8 +78,8 @@ def ewm_mean( def map_batches( self, - function: Callable[[Any], Self], - return_dtype: DType | None = None, + function: Callable[..., Self], + return_dtype: DType | None, ) -> Self: if return_dtype is not None: return_dtype_pl = narwhals_to_native_dtype(return_dtype, self._dtypes) @@ -90,7 +90,7 @@ def map_batches( return self._from_native_expr(self._native_expr.map_batches(function)) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None + self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: expr = self._native_expr return_dtype_pl = ( @@ -103,55 +103,55 @@ def replace_strict( expr.replace_strict(old, new, return_dtype=return_dtype_pl) ) - def __eq__(self, other: object) -> Self: # type: ignore[override] - return self._from_native_expr(self._native_expr.__eq__(extract_native(other))) + def __eq__(self: Self, other: object) -> Self: # type: ignore[override] + return self._from_native_expr(self._native_expr.__eq__(extract_native(other))) # type: ignore[operator] - def __ne__(self, other: object) -> Self: # type: ignore[override] - return self._from_native_expr(self._native_expr.__ne__(extract_native(other))) + def __ne__(self: Self, other: object) -> Self: # type: ignore[override] + return self._from_native_expr(self._native_expr.__ne__(extract_native(other))) # type: ignore[operator] - def __ge__(self, other: Any) -> Self: + def __ge__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__ge__(extract_native(other))) - def __gt__(self, other: Any) -> Self: + def __gt__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__gt__(extract_native(other))) - def __le__(self, other: Any) -> Self: + def __le__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__le__(extract_native(other))) - def __lt__(self, other: Any) -> Self: + def __lt__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__lt__(extract_native(other))) - def __and__(self, other: PolarsExpr | bool | Any) -> Self: - return self._from_native_expr(self._native_expr.__and__(extract_native(other))) + def __and__(self: Self, other: PolarsExpr | bool | Any) -> Self: + return self._from_native_expr(self._native_expr.__and__(extract_native(other))) # type: ignore[operator] - def __or__(self, other: PolarsExpr | bool | Any) -> Self: - return self._from_native_expr(self._native_expr.__or__(extract_native(other))) + def __or__(self: Self, other: PolarsExpr | bool | Any) -> Self: + return self._from_native_expr(self._native_expr.__or__(extract_native(other))) # type: ignore[operator] - def __add__(self, other: Any) -> Self: + def __add__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__add__(extract_native(other))) - def __radd__(self, other: Any) -> Self: + def __radd__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__radd__(extract_native(other))) - def __sub__(self, other: Any) -> Self: + def __sub__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__sub__(extract_native(other))) - def __rsub__(self, other: Any) -> Self: + def __rsub__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__rsub__(extract_native(other))) - def __mul__(self, other: Any) -> Self: + def __mul__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__mul__(extract_native(other))) - def __rmul__(self, other: Any) -> Self: + def __rmul__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__rmul__(extract_native(other))) - def __pow__(self, other: Any) -> Self: + def __pow__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__pow__(extract_native(other))) - def __rpow__(self, other: Any) -> Self: + def __rpow__(self: Self, other: Any) -> Self: return self._from_native_expr(self._native_expr.__rpow__(extract_native(other))) - def __invert__(self) -> Self: + def __invert__(self: Self) -> Self: return self._from_native_expr(self._native_expr.__invert__()) def cum_count(self: Self, *, reverse: bool) -> Self: @@ -164,15 +164,15 @@ def cum_count(self: Self, *, reverse: bool) -> Self: return self._from_native_expr(result) @property - def dt(self) -> PolarsExprDateTimeNamespace: + def dt(self: Self) -> PolarsExprDateTimeNamespace: return PolarsExprDateTimeNamespace(self) @property - def str(self) -> PolarsExprStringNamespace: + def str(self: Self) -> PolarsExprStringNamespace: return PolarsExprStringNamespace(self) @property - def cat(self) -> PolarsExprCatNamespace: + def cat(self: Self) -> PolarsExprCatNamespace: return PolarsExprCatNamespace(self) @property @@ -181,10 +181,10 @@ def name(self: Self) -> PolarsExprNameNamespace: class PolarsExprDateTimeNamespace: - def __init__(self, expr: PolarsExpr) -> None: + def __init__(self: Self, expr: PolarsExpr) -> None: self._expr = expr - def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._expr._from_native_expr( @@ -195,10 +195,10 @@ def func(*args: Any, **kwargs: Any) -> PolarsExpr: class PolarsExprStringNamespace: - def __init__(self, expr: PolarsExpr) -> None: + def __init__(self: Self, expr: PolarsExpr) -> None: self._expr = expr - def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._expr._from_native_expr( @@ -209,10 +209,10 @@ def func(*args: Any, **kwargs: Any) -> PolarsExpr: class PolarsExprCatNamespace: - def __init__(self, expr: PolarsExpr) -> None: + def __init__(self: Self, expr: PolarsExpr) -> None: self._expr = expr - def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._expr._from_native_expr( @@ -223,10 +223,10 @@ def func(*args: Any, **kwargs: Any) -> PolarsExpr: class PolarsExprNameNamespace: - def __init__(self, expr: PolarsExpr) -> None: + def __init__(self: Self, expr: PolarsExpr) -> None: self._expr = expr - def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: + def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._expr._from_native_expr( diff --git a/narwhals/_polars/group_by.py b/narwhals/_polars/group_by.py index aa69db37f..5bb1b58fc 100644 --- a/narwhals/_polars/group_by.py +++ b/narwhals/_polars/group_by.py @@ -1,17 +1,22 @@ from __future__ import annotations from typing import TYPE_CHECKING -from typing import Any +from typing import Iterator from narwhals._polars.utils import extract_args_kwargs if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame + from narwhals._polars.expr import PolarsExpr class PolarsGroupBy: - def __init__(self, df: Any, keys: list[str], *, drop_null_keys: bool) -> None: + def __init__( + self: Self, df: PolarsDataFrame, keys: list[str], *, drop_null_keys: bool + ) -> None: self._compliant_frame = df self.keys = keys if drop_null_keys: @@ -19,19 +24,21 @@ def __init__(self, df: Any, keys: list[str], *, drop_null_keys: bool) -> None: else: self._grouped = df._native_frame.group_by(keys) - def agg(self, *aggs: Any, **named_aggs: Any) -> PolarsDataFrame: + def agg(self: Self, *aggs: PolarsExpr, **named_aggs: PolarsExpr) -> PolarsDataFrame: aggs, named_aggs = extract_args_kwargs(aggs, named_aggs) # type: ignore[assignment] - return self._compliant_frame._from_native_frame( # type: ignore[no-any-return] + return self._compliant_frame._from_native_frame( self._grouped.agg(*aggs, **named_aggs), ) - def __iter__(self) -> Any: + def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]: for key, df in self._grouped: yield tuple(key), self._compliant_frame._from_native_frame(df) class PolarsLazyGroupBy: - def __init__(self, df: Any, keys: list[str], *, drop_null_keys: bool) -> None: + def __init__( + self: Self, df: PolarsLazyFrame, keys: list[str], *, drop_null_keys: bool + ) -> None: self._compliant_frame = df self.keys = keys if drop_null_keys: @@ -39,8 +46,8 @@ def __init__(self, df: Any, keys: list[str], *, drop_null_keys: bool) -> None: else: self._grouped = df._native_frame.group_by(keys) - def agg(self, *aggs: Any, **named_aggs: Any) -> PolarsLazyFrame: + def agg(self: Self, *aggs: PolarsExpr, **named_aggs: PolarsExpr) -> PolarsLazyFrame: aggs, named_aggs = extract_args_kwargs(aggs, named_aggs) # type: ignore[assignment] - return self._compliant_frame._from_native_frame( # type: ignore[no-any-return] + return self._compliant_frame._from_native_frame( self._grouped.agg(*aggs, **named_aggs), ) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 19d2693d5..d15a4553e 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -13,6 +13,8 @@ from narwhals.utils import Implementation if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.expr import PolarsExpr @@ -22,12 +24,12 @@ class PolarsNamespace: - def __init__(self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None: + def __init__(self: Self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None: self._backend_version = backend_version self._implementation = Implementation.POLARS self._dtypes = dtypes - def __getattr__(self, attr: str) -> Any: + def __getattr__(self: Self, attr: str) -> Any: import polars as pl # ignore-banned-import from narwhals._polars.expr import PolarsExpr @@ -42,7 +44,7 @@ def func(*args: Any, **kwargs: Any) -> Any: return func - def nth(self, *indices: int) -> PolarsExpr: + def nth(self: Self, *indices: int) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -54,7 +56,7 @@ def nth(self, *indices: int) -> PolarsExpr: pl.nth(*indices), dtypes=self._dtypes, backend_version=self._backend_version ) - def len(self) -> PolarsExpr: + def len(self: Self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -71,7 +73,7 @@ def len(self) -> PolarsExpr: @overload def concat( - self, + self: Self, items: Sequence[PolarsDataFrame], *, how: Literal["vertical", "horizontal", "diagonal"], @@ -79,14 +81,14 @@ def concat( @overload def concat( - self, + self: Self, items: Sequence[PolarsLazyFrame], *, how: Literal["vertical", "horizontal", "diagonal"], ) -> PolarsLazyFrame: ... def concat( - self, + self: Self, items: Sequence[PolarsDataFrame] | Sequence[PolarsLazyFrame], *, how: Literal["vertical", "horizontal", "diagonal"], @@ -106,7 +108,7 @@ def concat( result, backend_version=items[0]._backend_version, dtypes=items[0]._dtypes ) - def lit(self, value: Any, dtype: DType | None = None) -> PolarsExpr: + def lit(self: Self, value: Any, dtype: DType | None = None) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -121,7 +123,7 @@ def lit(self, value: Any, dtype: DType | None = None) -> PolarsExpr: pl.lit(value), dtypes=self._dtypes, backend_version=self._backend_version ) - def mean(self, *column_names: str) -> PolarsExpr: + def mean(self: Self, *column_names: str) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -138,7 +140,7 @@ def mean(self, *column_names: str) -> PolarsExpr: backend_version=self._backend_version, ) - def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr: + def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -159,7 +161,7 @@ def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr: backend_version=self._backend_version, ) - def median(self, *column_names: str) -> PolarsExpr: + def median(self: Self, *column_names: str) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -174,8 +176,8 @@ def concat_str( self, exprs: Iterable[IntoPolarsExpr], *more_exprs: IntoPolarsExpr, - separator: str = "", - ignore_nulls: bool = False, + separator: str, + ignore_nulls: bool, ) -> PolarsExpr: import polars as pl # ignore-banned-import() @@ -230,16 +232,16 @@ def concat_str( ) @property - def selectors(self) -> PolarsSelectors: + def selectors(self: Self) -> PolarsSelectors: return PolarsSelectors(self._dtypes, backend_version=self._backend_version) class PolarsSelectors: - def __init__(self, dtypes: DTypes, backend_version: tuple[int, ...]) -> None: + def __init__(self: Self, dtypes: DTypes, backend_version: tuple[int, ...]) -> None: self._dtypes = dtypes self._backend_version = backend_version - def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr: + def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -252,7 +254,7 @@ def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr: backend_version=self._backend_version, ) - def numeric(self) -> PolarsExpr: + def numeric(self: Self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -263,7 +265,7 @@ def numeric(self) -> PolarsExpr: backend_version=self._backend_version, ) - def boolean(self) -> PolarsExpr: + def boolean(self: Self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -274,7 +276,7 @@ def boolean(self) -> PolarsExpr: backend_version=self._backend_version, ) - def string(self) -> PolarsExpr: + def string(self: Self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -285,7 +287,7 @@ def string(self) -> PolarsExpr: backend_version=self._backend_version, ) - def categorical(self) -> PolarsExpr: + def categorical(self: Self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr @@ -296,7 +298,7 @@ def categorical(self) -> PolarsExpr: backend_version=self._backend_version, ) - def all(self) -> PolarsExpr: + def all(self: Self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 492913227..a91d5e9cd 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -7,10 +7,13 @@ from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import extract_native +from narwhals._polars.utils import narwhals_to_native_dtype +from narwhals._polars.utils import native_to_narwhals_dtype from narwhals.utils import Implementation if TYPE_CHECKING: from types import ModuleType + from typing import TypeVar import numpy as np import polars as pl @@ -20,23 +23,22 @@ from narwhals.dtypes import DType from narwhals.typing import DTypes -from narwhals._polars.utils import narwhals_to_native_dtype -from narwhals._polars.utils import native_to_narwhals_dtype + T = TypeVar("T") class PolarsSeries: def __init__( - self, series: Any, *, backend_version: tuple[int, ...], dtypes: DTypes + self: Self, series: pl.Series, *, backend_version: tuple[int, ...], dtypes: DTypes ) -> None: self._native_series: pl.Series = series self._backend_version = backend_version self._implementation = Implementation.POLARS self._dtypes = dtypes - def __repr__(self) -> str: # pragma: no cover + def __repr__(self: Self) -> str: # pragma: no cover return "PolarsSeries" - def __narwhals_series__(self) -> Self: + def __narwhals_series__(self: Self) -> Self: return self def __native_namespace__(self: Self) -> ModuleType: @@ -46,17 +48,28 @@ def __native_namespace__(self: Self) -> ModuleType: msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover raise AssertionError(msg) - def _change_dtypes(self, dtypes: DTypes) -> Self: + def _change_dtypes(self: Self, dtypes: DTypes) -> Self: return self.__class__( self._native_series, backend_version=self._backend_version, dtypes=dtypes ) - def _from_native_series(self, series: Any) -> Self: + def _from_native_series(self: Self, series: pl.Series) -> Self: return self.__class__( series, backend_version=self._backend_version, dtypes=self._dtypes ) - def _from_native_object(self, series: Any) -> Any: + @overload + def _from_native_object(self: Self, series: pl.Series) -> Self: ... + + @overload + def _from_native_object(self: Self, series: pl.DataFrame) -> PolarsDataFrame: ... + + @overload + def _from_native_object(self: Self, series: T) -> T: ... + + def _from_native_object( + self: Self, series: pl.Series | pl.DataFrame | T + ) -> Self | PolarsDataFrame | T: import polars as pl # ignore-banned-import() if isinstance(series, pl.Series): @@ -70,7 +83,7 @@ def _from_native_object(self, series: Any) -> Any: # scalar return series - def __getattr__(self, attr: str) -> Any: + def __getattr__(self: Self, attr: str) -> Any: if attr == "as_py": # pragma: no cover raise AttributeError @@ -82,15 +95,15 @@ def func(*args: Any, **kwargs: Any) -> Any: return func - def __len__(self) -> int: + def __len__(self: Self) -> int: return len(self._native_series) @property - def shape(self) -> tuple[int]: + def shape(self: Self) -> tuple[int]: return (len(self),) @property - def name(self) -> str: + def name(self: Self) -> str: return self._native_series.name @property @@ -100,21 +113,21 @@ def dtype(self: Self) -> DType: ) @overload - def __getitem__(self, item: int) -> Any: ... + def __getitem__(self: Self, item: int) -> Any: ... @overload - def __getitem__(self, item: slice | Sequence[int]) -> Self: ... + def __getitem__(self: Self, item: slice | Sequence[int]) -> Self: ... - def __getitem__(self, item: int | slice | Sequence[int]) -> Any | Self: + def __getitem__(self: Self, item: int | slice | Sequence[int]) -> Any | Self: return self._from_native_object(self._native_series.__getitem__(item)) - def cast(self, dtype: DType) -> Self: + def cast(self: Self, dtype: DType) -> Self: ser = self._native_series dtype_pl = narwhals_to_native_dtype(dtype, self._dtypes) return self._from_native_series(ser.cast(dtype_pl)) def replace_strict( - self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None + self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: ser = self._native_series dtype = ( @@ -125,83 +138,83 @@ def replace_strict( raise NotImplementedError(msg) return self._from_native_series(ser.replace_strict(old, new, return_dtype=dtype)) - def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray: + def __array__(self: Self, dtype: Any, copy: bool | None) -> np.ndarray: if self._backend_version < (0, 20, 29): return self._native_series.__array__(dtype=dtype) return self._native_series.__array__(dtype=dtype, copy=copy) - def __eq__(self, other: object) -> Self: # type: ignore[override] + def __eq__(self: Self, other: object) -> Self: # type: ignore[override] return self._from_native_series(self._native_series.__eq__(extract_native(other))) - def __ne__(self, other: object) -> Self: # type: ignore[override] + def __ne__(self: Self, other: object) -> Self: # type: ignore[override] return self._from_native_series(self._native_series.__ne__(extract_native(other))) - def __ge__(self, other: Any) -> Self: + def __ge__(self: Self, other: Any) -> Self: return self._from_native_series(self._native_series.__ge__(extract_native(other))) - def __gt__(self, other: Any) -> Self: + def __gt__(self: Self, other: Any) -> Self: return self._from_native_series(self._native_series.__gt__(extract_native(other))) - def __le__(self, other: Any) -> Self: + def __le__(self: Self, other: Any) -> Self: return self._from_native_series(self._native_series.__le__(extract_native(other))) - def __lt__(self, other: Any) -> Self: + def __lt__(self: Self, other: Any) -> Self: return self._from_native_series(self._native_series.__lt__(extract_native(other))) - def __and__(self, other: PolarsSeries | bool | Any) -> Self: + def __and__(self: Self, other: PolarsSeries | bool | Any) -> Self: return self._from_native_series( self._native_series.__and__(extract_native(other)) ) - def __or__(self, other: PolarsSeries | bool | Any) -> Self: + def __or__(self: Self, other: PolarsSeries | bool | Any) -> Self: return self._from_native_series(self._native_series.__or__(extract_native(other))) - def __add__(self, other: PolarsSeries | Any) -> Self: + def __add__(self: Self, other: PolarsSeries | Any) -> Self: return self._from_native_series( self._native_series.__add__(extract_native(other)) ) - def __radd__(self, other: PolarsSeries | Any) -> Self: + def __radd__(self: Self, other: PolarsSeries | Any) -> Self: return self._from_native_series( self._native_series.__radd__(extract_native(other)) ) - def __sub__(self, other: PolarsSeries | Any) -> Self: + def __sub__(self: Self, other: PolarsSeries | Any) -> Self: return self._from_native_series( self._native_series.__sub__(extract_native(other)) ) - def __rsub__(self, other: PolarsSeries | Any) -> Self: + def __rsub__(self: Self, other: PolarsSeries | Any) -> Self: return self._from_native_series( self._native_series.__rsub__(extract_native(other)) ) - def __mul__(self, other: PolarsSeries | Any) -> Self: + def __mul__(self: Self, other: PolarsSeries | Any) -> Self: return self._from_native_series( self._native_series.__mul__(extract_native(other)) ) - def __rmul__(self, other: PolarsSeries | Any) -> Self: + def __rmul__(self: Self, other: PolarsSeries | Any) -> Self: return self._from_native_series( self._native_series.__rmul__(extract_native(other)) ) - def __pow__(self, other: PolarsSeries | Any) -> Self: + def __pow__(self: Self, other: PolarsSeries | Any) -> Self: return self._from_native_series( self._native_series.__pow__(extract_native(other)) ) - def __rpow__(self, other: PolarsSeries | Any) -> Self: + def __rpow__(self: Self, other: PolarsSeries | Any) -> Self: result = self._native_series.__rpow__(extract_native(other)) - if self._backend_version < (16, 1): + if self._backend_version < (1, 16, 1): # Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071 result = result.alias(self.name) return self._from_native_series(result) - def __invert__(self) -> Self: + def __invert__(self: Self) -> Self: return self._from_native_series(self._native_series.__invert__()) - def median(self) -> Any: + def median(self: Self) -> Any: from narwhals.exceptions import InvalidOperationError if not self.dtype.is_numeric(): @@ -210,9 +223,7 @@ def median(self) -> Any: return self._native_series.median() - def to_dummies( - self: Self, *, separator: str = "_", drop_first: bool = False - ) -> PolarsDataFrame: + def to_dummies(self: Self, *, separator: str, drop_first: bool) -> PolarsDataFrame: import polars as pl # ignore-banned-import from narwhals._polars.dataframe import PolarsDataFrame @@ -237,13 +248,13 @@ def to_dummies( def ewm_mean( self: Self, *, - com: float | None = None, - span: float | None = None, - half_life: float | None = None, - alpha: float | None = None, - adjust: bool = True, - min_periods: int = 1, - ignore_nulls: bool = False, + com: float | None, + span: float | None, + half_life: float | None, + alpha: float | None, + adjust: bool, + min_periods: int, + ignore_nulls: bool, ) -> Self: if self._backend_version < (1,): # pragma: no cover msg = "`ewm_mean` not implemented for polars older than 1.0" @@ -261,7 +272,7 @@ def ewm_mean( ) ) - def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: + def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: if self._backend_version < (0, 20, 6): result = self._native_series.sort(descending=descending) @@ -277,7 +288,7 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: return self._from_native_series(result) - def scatter(self, indices: int | Sequence[int], values: Any) -> Self: + def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: values = extract_native(values) s = self._native_series.clone() s.scatter(indices, values) @@ -286,10 +297,10 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self: def value_counts( self: Self, *, - sort: bool = False, - parallel: bool = False, - name: str | None = None, - normalize: bool = False, + sort: bool, + parallel: bool, + name: str | None, + normalize: bool, ) -> PolarsDataFrame: from narwhals._polars.dataframe import PolarsDataFrame @@ -327,23 +338,23 @@ def cum_count(self: Self, *, reverse: bool) -> Self: return self._from_native_series(result) @property - def dt(self) -> PolarsSeriesDateTimeNamespace: + def dt(self: Self) -> PolarsSeriesDateTimeNamespace: return PolarsSeriesDateTimeNamespace(self) @property - def str(self) -> PolarsSeriesStringNamespace: + def str(self: Self) -> PolarsSeriesStringNamespace: return PolarsSeriesStringNamespace(self) @property - def cat(self) -> PolarsSeriesCatNamespace: + def cat(self: Self) -> PolarsSeriesCatNamespace: return PolarsSeriesCatNamespace(self) class PolarsSeriesDateTimeNamespace: - def __init__(self, series: PolarsSeries) -> None: + def __init__(self: Self, series: PolarsSeries) -> None: self._series = series - def __getattr__(self, attr: str) -> Any: + def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._series._from_native_series( @@ -354,10 +365,10 @@ def func(*args: Any, **kwargs: Any) -> Any: class PolarsSeriesStringNamespace: - def __init__(self, series: PolarsSeries) -> None: + def __init__(self: Self, series: PolarsSeries) -> None: self._series = series - def __getattr__(self, attr: str) -> Any: + def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._series._from_native_series( @@ -368,10 +379,10 @@ def func(*args: Any, **kwargs: Any) -> Any: class PolarsSeriesCatNamespace: - def __init__(self, series: PolarsSeries) -> None: + def __init__(self: Self, series: PolarsSeries) -> None: self._series = series - def __getattr__(self, attr: str) -> Any: + def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._series._from_native_series( diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index f562c5841..704b3a4d4 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -3,15 +3,45 @@ from typing import TYPE_CHECKING from typing import Any from typing import Literal +from typing import TypeVar +from typing import overload if TYPE_CHECKING: import polars as pl + from narwhals._polars.dataframe import PolarsDataFrame + from narwhals._polars.dataframe import PolarsLazyFrame + from narwhals._polars.expr import PolarsExpr + from narwhals._polars.series import PolarsSeries from narwhals.dtypes import DType from narwhals.typing import DTypes + T = TypeVar("T") + + +@overload +def extract_native(obj: PolarsDataFrame) -> pl.DataFrame: ... + + +@overload +def extract_native(obj: PolarsLazyFrame) -> pl.LazyFrame: ... + + +@overload +def extract_native(obj: PolarsSeries) -> pl.Series: ... + + +@overload +def extract_native(obj: PolarsExpr) -> pl.Expr: ... + + +@overload +def extract_native(obj: T) -> T: ... + -def extract_native(obj: Any) -> Any: +def extract_native( + obj: PolarsDataFrame | PolarsLazyFrame | PolarsSeries | PolarsExpr | T, +) -> pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr | T: from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.expr import PolarsExpr @@ -27,9 +57,9 @@ def extract_native(obj: Any) -> Any: def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, Any]]: - args = [extract_native(arg) for arg in args] - kwargs = {k: extract_native(v) for k, v in kwargs.items()} - return args, kwargs + return [extract_native(arg) for arg in args], { + k: extract_native(v) for k, v in kwargs.items() + } def native_to_narwhals_dtype(