Skip to content

Commit

Permalink
chore: polars type hinting (#1467)
Browse files (browse the repository at this point in the history)
* chore: polars typing

* less Any

* rm unused typevar
  • Loading branch information
FBruzzesi authored Dec 1, 2024
1 parent 12be0ca commit 81e6078
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 190 deletions.
123 changes: 75 additions & 48 deletions narwhals/_polars/dataframe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Literal
from typing import Sequence
from typing import overload

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import convert_str_slice_to_int_slice
Expand All @@ -17,31 +18,37 @@

if TYPE_CHECKING:
from types import ModuleType
from typing import TypeVar

import numpy as np
import polars as pl
from typing_extensions import Self

from narwhals._polars.group_by import PolarsGroupBy
from narwhals._polars.group_by import PolarsLazyGroupBy
from narwhals._polars.series import PolarsSeries
from narwhals.dtypes import DType
from narwhals.typing import DTypes

T = TypeVar("T")


class PolarsDataFrame:
def __init__(
self, df: pl.DataFrame, *, backend_version: tuple[int, ...], dtypes: DTypes
self: Self, df: pl.DataFrame, *, backend_version: tuple[int, ...], dtypes: DTypes
) -> None:
self._native_frame = df
self._backend_version = backend_version
self._implementation = Implementation.POLARS
self._dtypes = dtypes

def __repr__(self) -> str: # pragma: no cover
def __repr__(self: Self) -> str: # pragma: no cover
return "PolarsDataFrame"

def __narwhals_dataframe__(self) -> Self:
def __narwhals_dataframe__(self: Self) -> Self:
return self

def __narwhals_namespace__(self) -> PolarsNamespace:
def __narwhals_namespace__(self: Self) -> PolarsNamespace:
return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

def __native_namespace__(self: Self) -> ModuleType:
Expand All @@ -51,17 +58,28 @@ def __native_namespace__(self: Self) -> ModuleType:
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def _change_dtypes(self, dtypes: DTypes) -> Self:
def _change_dtypes(self: Self, dtypes: DTypes) -> Self:
return self.__class__(
self._native_frame, backend_version=self._backend_version, dtypes=dtypes
)

def _from_native_frame(self, df: Any) -> Self:
def _from_native_frame(self: Self, df: pl.DataFrame) -> Self:
return self.__class__(
df, backend_version=self._backend_version, dtypes=self._dtypes
)

def _from_native_object(self, obj: Any) -> Any:
@overload
def _from_native_object(self: Self, obj: pl.Series) -> PolarsSeries: ...

@overload
def _from_native_object(self: Self, obj: pl.DataFrame) -> Self: ...

@overload
def _from_native_object(self: Self, obj: T) -> T: ...

def _from_native_object(
self: Self, obj: pl.Series | pl.DataFrame | T
) -> Self | PolarsSeries | T:
import polars as pl # ignore-banned-import()

if isinstance(obj, pl.Series):
Expand All @@ -75,16 +93,7 @@ def _from_native_object(self, obj: Any) -> Any:
# scalar
return obj

def __getattr__(self, attr: str) -> Any:
if attr == "collect": # pragma: no cover
raise AttributeError
if attr == "schema":
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in schema.items()
}

def __getattr__(self: Self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
import polars as pl # ignore-banned-import()

Expand All @@ -107,15 +116,17 @@ def func(*args: Any, **kwargs: Any) -> Any:

return func

def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.ndarray:
def __array__(
self: Self, dtype: Any | None = None, copy: bool | None = None
) -> np.ndarray:
if self._backend_version < (0, 20, 28) and copy is not None:
msg = "`copy` in `__array__` is only supported for Polars>=0.20.28"
raise NotImplementedError(msg)
if self._backend_version < (0, 20, 28):
return self._native_frame.__array__(dtype)
return self._native_frame.__array__(dtype)

def collect_schema(self) -> dict[str, DType]:
def collect_schema(self: Self) -> dict[str, DType]:
if self._backend_version < (1,):
return {
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
Expand All @@ -128,10 +139,10 @@ def collect_schema(self) -> dict[str, DType]:
}

@property
def shape(self) -> tuple[int, int]:
def shape(self: Self) -> tuple[int, int]:
return self._native_frame.shape

def __getitem__(self, item: Any) -> Any:
def __getitem__(self: Self, item: Any) -> Any:
if self._backend_version > (0, 20, 30):
return self._from_native_object(self._native_frame.__getitem__(item))
else: # pragma: no cover
Expand Down Expand Up @@ -191,7 +202,7 @@ def __getitem__(self, item: Any) -> Any:
)
return self._from_native_object(result)

def get_column(self, name: str) -> Any:
def get_column(self: Self, name: str) -> PolarsSeries:
from narwhals._polars.series import PolarsSeries

return PolarsSeries(
Expand All @@ -200,21 +211,37 @@ def get_column(self, name: str) -> Any:
dtypes=self._dtypes,
)

def is_empty(self) -> bool:
def is_empty(self: Self) -> bool:
return len(self._native_frame) == 0

@property
def columns(self) -> list[str]:
def columns(self: Self) -> list[str]:
return self._native_frame.columns

def lazy(self) -> PolarsLazyFrame:
@property
def schema(self: Self) -> dict[str, DType]:
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in schema.items()
}

def lazy(self: Self) -> PolarsLazyFrame:
return PolarsLazyFrame(
self._native_frame.lazy(),
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def to_dict(self, *, as_series: bool) -> Any:
@overload
def to_dict(self: Self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...

@overload
def to_dict(self: Self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

def to_dict(
self: Self, *, as_series: bool
) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
df = self._native_frame

if as_series:
Expand All @@ -229,12 +256,12 @@ def to_dict(self, *, as_series: bool) -> Any:
else:
return df.to_dict(as_series=False)

def group_by(self, *by: str, drop_null_keys: bool) -> Any:
def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsGroupBy:
from narwhals._polars.group_by import PolarsGroupBy

return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys)

def with_row_index(self, name: str) -> Any:
def with_row_index(self: Self, name: str) -> Self:
if self._backend_version < (0, 20, 4):
return self._from_native_frame(self._native_frame.with_row_count(name))
return self._from_native_frame(self._native_frame.with_row_index(name))
Expand Down Expand Up @@ -271,15 +298,15 @@ def pivot(
self: Self,
on: str | list[str],
*,
index: str | list[str] | None = None,
values: str | list[str] | None = None,
index: str | list[str] | None,
values: str | list[str] | None,
aggregate_function: Literal[
"min", "max", "first", "last", "sum", "mean", "median", "len"
]
| None = None,
maintain_order: bool = True,
sort_columns: bool = False,
separator: str = "_",
| None,
maintain_order: bool,
sort_columns: bool,
separator: str,
) -> Self:
if self._backend_version < (1, 0, 0): # pragma: no cover
msg = "`pivot` is only supported for Polars>=1.0.0"
Expand All @@ -293,25 +320,25 @@ def pivot(
sort_columns=sort_columns,
separator=separator,
)
return self._from_native_object(result) # type: ignore[no-any-return]
return self._from_native_object(result)


class PolarsLazyFrame:
def __init__(
self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], dtypes: DTypes
self: Self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], dtypes: DTypes
) -> None:
self._native_frame = df
self._backend_version = backend_version
self._implementation = Implementation.POLARS
self._dtypes = dtypes

def __repr__(self) -> str: # pragma: no cover
def __repr__(self: Self) -> str: # pragma: no cover
return "PolarsLazyFrame"

def __narwhals_lazyframe__(self) -> Self:
def __narwhals_lazyframe__(self: Self) -> Self:
return self

def __narwhals_namespace__(self) -> PolarsNamespace:
def __narwhals_namespace__(self: Self) -> PolarsNamespace:
return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

def __native_namespace__(self: Self) -> ModuleType:
Expand All @@ -321,17 +348,17 @@ def __native_namespace__(self: Self) -> ModuleType:
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def _from_native_frame(self, df: Any) -> Self:
def _from_native_frame(self: Self, df: pl.LazyFrame) -> Self:
return self.__class__(
df, backend_version=self._backend_version, dtypes=self._dtypes
)

def _change_dtypes(self, dtypes: DTypes) -> Self:
def _change_dtypes(self: Self, dtypes: DTypes) -> Self:
return self.__class__(
self._native_frame, backend_version=self._backend_version, dtypes=dtypes
)

def __getattr__(self, attr: str) -> Any:
def __getattr__(self: Self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
import polars as pl # ignore-banned-import

Expand All @@ -354,18 +381,18 @@ def func(*args: Any, **kwargs: Any) -> Any:
return func

@property
def columns(self) -> list[str]:
def columns(self: Self) -> list[str]:
return self._native_frame.columns

@property
def schema(self) -> dict[str, Any]:
def schema(self: Self) -> dict[str, DType]:
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
for name, dtype in schema.items()
}

def collect_schema(self) -> dict[str, DType]:
def collect_schema(self: Self) -> dict[str, DType]:
if self._backend_version < (1,):
return {
name: native_to_narwhals_dtype(dtype, self._dtypes, self._backend_version)
Expand All @@ -377,7 +404,7 @@ def collect_schema(self) -> dict[str, DType]:
for name, dtype in self._native_frame.collect_schema().items()
}

def collect(self) -> PolarsDataFrame:
def collect(self: Self) -> PolarsDataFrame:
import polars as pl # ignore-banned-import

try:
Expand All @@ -391,12 +418,12 @@ def collect(self) -> PolarsDataFrame:
dtypes=self._dtypes,
)

def group_by(self, *by: str, drop_null_keys: bool) -> Any:
def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsLazyGroupBy:
from narwhals._polars.group_by import PolarsLazyGroupBy

return PolarsLazyGroupBy(self, list(by), drop_null_keys=drop_null_keys)

def with_row_index(self, name: str) -> Any:
def with_row_index(self: Self, name: str) -> Self:
if self._backend_version < (0, 20, 4):
return self._from_native_frame(self._native_frame.with_row_count(name))
return self._from_native_frame(self._native_frame.with_row_index(name))
Expand Down
Loading

0 comments on commit 81e6078

Please sign in to comment.