Skip to content

Commit

Permalink
broken typing
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Feb 21, 2024
1 parent a3b27bf commit 6db674a
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 396 deletions.
8 changes: 4 additions & 4 deletions narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from typing import Iterable
from typing import Literal

from narwhals.pandas_like.utils import evaluate_into_exprs
from narwhals.pandas_like.utils import flatten_str
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import validate_dataframe_comparand
from narwhals.spec import DataFrame as DataFrameT
from narwhals.spec import IntoExpr
from narwhals.spec import LazyFrame as LazyFrameProtocol
from narwhals.spec import Namespace as NamespaceProtocol
from narwhals.utils import evaluate_into_exprs
from narwhals.utils import flatten_str
from narwhals.utils import horizontal_concat
from narwhals.utils import validate_dataframe_comparand

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down
88 changes: 46 additions & 42 deletions narwhals/pandas_like/expr.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

from narwhals.pandas_like.series import Series
from narwhals.pandas_like.utils import register_expression_call
from narwhals.spec import DataFrame as DataFrameT
from narwhals.spec import Expr as ExprT
from narwhals.spec import ExprStringNamespace as ExprStringNamespaceT
from narwhals.spec import LazyFrame as LazyFrameProtocol
from narwhals.spec import Namespace as NamespaceProtocol
from narwhals.spec import Series as SeriesProtocol
from narwhals.utils import register_expression_call

if TYPE_CHECKING:
from typing_extensions import Self


class Expr(ExprT):
Expand Down Expand Up @@ -44,8 +48,8 @@ def __repr__(self) -> str:

@classmethod
def from_column_names(
cls: type[Expr], *column_names: str, implementation: str
) -> ExprT:
cls: type[Self], *column_names: str, implementation: str
) -> Self:
return cls(
lambda df: [
Series(
Expand All @@ -70,124 +74,124 @@ def __expr_namespace__(self) -> NamespaceProtocol:
implementation=self._implementation, # type: ignore[attr-defined]
)

def __eq__(self, other: Expr | Any) -> ExprT: # type: ignore[override]
def __eq__(self, other: Expr | Any) -> Self: # type: ignore[override]
return register_expression_call(self, "__eq__", other)

def __ne__(self, other: Expr | Any) -> ExprT: # type: ignore[override]
def __ne__(self, other: Expr | Any) -> Self: # type: ignore[override]
return register_expression_call(self, "__ne__", other)

def __ge__(self, other: Expr | Any) -> ExprT:
def __ge__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__ge__", other)

def __gt__(self, other: Expr | Any) -> ExprT:
def __gt__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__gt__", other)

def __le__(self, other: Expr | Any) -> ExprT:
def __le__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__le__", other)

def __lt__(self, other: Expr | Any) -> ExprT:
def __lt__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__lt__", other)

def __and__(self, other: Expr | bool | Any) -> ExprT:
def __and__(self, other: Expr | bool | Any) -> Self:
return register_expression_call(self, "__and__", other)

def __rand__(self, other: Any) -> ExprT:
def __rand__(self, other: Any) -> Self:
return register_expression_call(self, "__rand__", other)

def __or__(self, other: Expr | bool | Any) -> ExprT:
def __or__(self, other: Expr | bool | Any) -> Self:
return register_expression_call(self, "__or__", other)

def __ror__(self, other: Any) -> ExprT:
def __ror__(self, other: Any) -> Self:
return register_expression_call(self, "__ror__", other)

def __add__(self, other: Expr | Any) -> ExprT: # type: ignore[override]
def __add__(self, other: Expr | Any) -> Self: # type: ignore[override]
return register_expression_call(self, "__add__", other)

def __radd__(self, other: Any) -> ExprT:
def __radd__(self, other: Any) -> Self:
return register_expression_call(self, "__radd__", other)

def __sub__(self, other: Expr | Any) -> ExprT:
def __sub__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__sub__", other)

def __rsub__(self, other: Any) -> ExprT:
def __rsub__(self, other: Any) -> Self:
return register_expression_call(self, "__rsub__", other)

def __mul__(self, other: Expr | Any) -> ExprT:
def __mul__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__mul__", other)

def __rmul__(self, other: Any) -> ExprT:
def __rmul__(self, other: Any) -> Self:
return self.__mul__(other)

def __truediv__(self, other: Expr | Any) -> ExprT:
def __truediv__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__truediv__", other)

def __rtruediv__(self, other: Any) -> ExprT:
def __rtruediv__(self, other: Any) -> Self:
raise NotImplementedError

def __floordiv__(self, other: Expr | Any) -> ExprT:
def __floordiv__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__floordiv__", other)

def __rfloordiv__(self, other: Any) -> ExprT:
def __rfloordiv__(self, other: Any) -> Self:
raise NotImplementedError

def __pow__(self, other: Expr | Any) -> ExprT:
def __pow__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__pow__", other)

def __rpow__(self, other: Any) -> ExprT: # pragma: no cover
def __rpow__(self, other: Any) -> Self: # pragma: no cover
raise NotImplementedError

def __mod__(self, other: Expr | Any) -> ExprT:
def __mod__(self, other: Expr | Any) -> Self:
return register_expression_call(self, "__mod__", other)

def __rmod__(self, other: Any) -> ExprT: # pragma: no cover
def __rmod__(self, other: Any) -> Self: # pragma: no cover
raise NotImplementedError

# Unary

def __invert__(self) -> ExprT:
def __invert__(self) -> Self:
return register_expression_call(self, "__invert__")

# Reductions

def sum(self) -> ExprT:
def sum(self) -> Self:
return register_expression_call(self, "sum")

def mean(self) -> ExprT:
def mean(self) -> Self:
return register_expression_call(self, "mean")

def max(self) -> ExprT:
def max(self) -> Self:
return register_expression_call(self, "max")

def min(self) -> ExprT:
def min(self) -> Self:
return register_expression_call(self, "min")

# Other
def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
) -> ExprT:
) -> Self:
return register_expression_call(
self, "is_between", lower_bound, upper_bound, closed
)

def is_null(self) -> ExprT:
def is_null(self) -> Self:
return register_expression_call(self, "is_null")

def is_in(self, other: Any) -> ExprT:
def is_in(self, other: Any) -> Self:
return register_expression_call(self, "is_in", other)

def drop_nulls(self) -> ExprT:
def drop_nulls(self) -> Self:
return register_expression_call(self, "drop_nulls")

def n_unique(self) -> ExprT:
def n_unique(self) -> Self:
return register_expression_call(self, "n_unique")

def unique(self) -> ExprT:
def unique(self) -> Self:
return register_expression_call(self, "unique")

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> ExprT:
def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self:
return register_expression_call(self, "sample", n, fraction, with_replacement)

def alias(self, name: str) -> ExprT:
def alias(self, name: str) -> Self:
# Define this one manually, so that we can
# override `output_names` and not increase depth
if self._depth is None:
Expand All @@ -211,7 +215,7 @@ class ExprStringNamespace(ExprStringNamespaceT):
def __init__(self, expr: ExprT) -> None:
self._expr = expr

def ends_with(self, suffix: str) -> ExprT:
def ends_with(self, suffix: str) -> Expr:
# TODO make a register_expression_call for namespaces
return Expr(
lambda df: [
Expand All @@ -229,7 +233,7 @@ def ends_with(self, suffix: str) -> ExprT:
implementation=self._expr._implementation, # type: ignore[attr-defined]
)

def strip_chars(self, characters: str = " ") -> ExprT:
def strip_chars(self, characters: str = " ") -> Expr:
return Expr(
lambda df: [
Series(
Expand Down
12 changes: 6 additions & 6 deletions narwhals/pandas_like/group_by_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from typing import Iterable

from narwhals.pandas_like.dataframe import LazyFrame
from narwhals.pandas_like.utils import dataframe_from_dict
from narwhals.pandas_like.utils import evaluate_simple_aggregation
from narwhals.pandas_like.utils import get_namespace
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import is_simple_aggregation
from narwhals.pandas_like.utils import parse_into_exprs
from narwhals.spec import DataFrame as DataFrameT
from narwhals.spec import GroupBy as GroupByProtocol
from narwhals.spec import IntoExpr
from narwhals.spec import LazyFrame as LazyFrameProtocol
from narwhals.spec import LazyGroupBy as LazyGroupByT
from narwhals.utils import dataframe_from_dict
from narwhals.utils import evaluate_simple_aggregation
from narwhals.utils import get_namespace
from narwhals.utils import horizontal_concat
from narwhals.utils import is_simple_aggregation
from narwhals.utils import parse_into_exprs


class GroupBy(GroupByProtocol):
Expand Down
33 changes: 16 additions & 17 deletions narwhals/pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
from narwhals.pandas_like.dataframe import LazyFrame
from narwhals.pandas_like.expr import Expr
from narwhals.pandas_like.series import Series
from narwhals.pandas_like.utils import flatten_str
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import parse_into_exprs
from narwhals.pandas_like.utils import series_from_iterable
from narwhals.spec import AnyDataFrame
from narwhals.spec import DataFrame as DataFrameT
from narwhals.spec import Expr as ExprT
from narwhals.spec import IntoExpr
from narwhals.spec import LazyFrame as LazyFrameProtocol
from narwhals.spec import Namespace as NamespaceProtocol
from narwhals.spec import Series as SeriesProtocol
from narwhals.utils import flatten_str
from narwhals.utils import horizontal_concat
from narwhals.utils import parse_into_exprs
from narwhals.utils import series_from_iterable


class Namespace(NamespaceProtocol):
Expand All @@ -29,13 +28,13 @@ def __init__(self, *, api_version: str, implementation: str) -> None:
self._implementation = implementation

# --- horizontal reductions
def sum_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> ExprT:
def sum_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return reduce(lambda x, y: x + y, parse_into_exprs(self, *exprs))

def all_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> ExprT:
def all_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return reduce(lambda x, y: x & y, parse_into_exprs(self, *exprs))

def any_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> ExprT:
def any_horizontal(self, *exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return reduce(lambda x, y: x | y, parse_into_exprs(self, *exprs))

def concat(self, items: Iterable[AnyDataFrame], *, how: str) -> AnyDataFrame:
Expand All @@ -62,32 +61,32 @@ def concat(self, items: Iterable[AnyDataFrame], *, how: str) -> AnyDataFrame:
implementation=self._implementation,
)

def col(self, *column_names: str | Iterable[str]) -> ExprT:
def col(self, *column_names: str | Iterable[str]) -> Expr:
return Expr.from_column_names(
*flatten_str(*column_names), implementation=self._implementation
)

def sum(self, *column_names: str) -> ExprT:
def sum(self, *column_names: str) -> Expr:
return Expr.from_column_names(
*column_names, implementation=self._implementation
).sum()

def mean(self, *column_names: str) -> ExprT:
def mean(self, *column_names: str) -> Expr:
return Expr.from_column_names(
*column_names, implementation=self._implementation
).mean()

def max(self, *column_names: str) -> ExprT:
def max(self, *column_names: str) -> Expr:
return Expr.from_column_names(
*column_names, implementation=self._implementation
).max()

def min(self, *column_names: str) -> ExprT:
def min(self, *column_names: str) -> Expr:
return Expr.from_column_names(
*column_names, implementation=self._implementation
).min()

def len(self) -> ExprT:
def len(self) -> Expr:
return Expr(
lambda df: [
Series(
Expand Down Expand Up @@ -116,7 +115,7 @@ def _create_expr_from_callable( # noqa: PLR0913
function_name: str | None,
root_names: list[str] | None,
output_names: list[str] | None,
) -> ExprT:
) -> Expr:
return Expr(
func,
depth=depth,
Expand All @@ -140,7 +139,7 @@ def _create_series_from_scalar(
implementation=self._implementation,
)

def _create_expr_from_series(self, series: SeriesProtocol) -> ExprT:
def _create_expr_from_series(self, series: SeriesProtocol) -> Expr:
return Expr(
lambda _df: [series],
depth=0,
Expand All @@ -150,7 +149,7 @@ def _create_expr_from_series(self, series: SeriesProtocol) -> ExprT:
implementation=self._implementation,
)

def all(self) -> ExprT:
def all(self) -> Expr:
return Expr(
lambda df: [
Series(
Expand Down
8 changes: 5 additions & 3 deletions narwhals/pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from pandas.api.types import is_extension_array_dtype

from narwhals.pandas_like.utils import item
from narwhals.pandas_like.utils import validate_column_comparand
from narwhals.spec import Series as SeriesProtocol
from narwhals.utils import item
from narwhals.utils import validate_column_comparand

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals.pandas_like.namespace import Namespace


Expand Down Expand Up @@ -296,7 +298,7 @@ def sort(
ser.sort_values(ascending=not descending).rename(self.name)
)

def alias(self, name: str) -> Series:
def alias(self, name: str) -> Self:
ser = self.series
return self._from_series(ser.rename(name, copy=False))

Expand Down
2 changes: 1 addition & 1 deletion narwhals/spec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class Expr(Protocol):
def alias(self, name: str) -> Expr:
def alias(self, name: str) -> Self:
...

def __and__(self, other: Any) -> Expr:
Expand Down
Loading

0 comments on commit 6db674a

Please sign in to comment.