Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 15, 2024
1 parent 68a1084 commit 5f916d5
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 90 deletions.
7 changes: 5 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal
from typing import Sequence

from narwhals.dtypes import to_narwhals_dtype
from narwhals.pandas_like.dataframe import PandasDataFrame
from narwhals.translate import get_pandas
from narwhals.translate import get_polars
Expand Down Expand Up @@ -85,7 +86,6 @@ def _extract_native(self, arg: Any) -> Any:
import polars as pl

return arg._call(pl)
# todo: if it's dtype, translate
return arg

def __repr__(self) -> str: # pragma: no cover
Expand All @@ -104,7 +104,10 @@ def __repr__(self) -> str: # pragma: no cover

@property
def schema(self) -> dict[str, DType]:
return self._dataframe.schema # type: ignore[no-any-return]
return {
k: to_narwhals_dtype(v, self._implementation)
for k, v in self._dataframe.schema.items()
}

@property
def columns(self) -> list[str]:
Expand Down
133 changes: 69 additions & 64 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

from narwhals.pandas_like.utils import isinstance_or_issubclass

Expand Down Expand Up @@ -88,67 +89,71 @@ class Date(TemporalType):
...


# def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
# if isinstance_or_issubclass(dtype, Float64):
# return pl.Float64
# if isinstance_or_issubclass(dtype, Float32):
# return pl.Float32
# if isinstance_or_issubclass(dtype, Int64):
# return pl.Int64
# if isinstance_or_issubclass(dtype, Int32):
# return pl.Int32
# if isinstance_or_issubclass(dtype, Int16):
# return pl.Int16
# if isinstance_or_issubclass(dtype, UInt8):
# return pl.UInt8
# if isinstance_or_issubclass(dtype, UInt64):
# return pl.UInt64
# if isinstance_or_issubclass(dtype, UInt32):
# return pl.UInt32
# if isinstance_or_issubclass(dtype, UInt16):
# return pl.UInt16
# if isinstance_or_issubclass(dtype, UInt8):
# return pl.UInt8
# if isinstance_or_issubclass(dtype, String):
# return pl.String
# if isinstance_or_issubclass(dtype, Boolean):
# return pl.Boolean
# if isinstance_or_issubclass(dtype, Datetime):
# return pl.Datetime
# if isinstance_or_issubclass(dtype, Date):
# return pl.Date
# msg = f"Unknown dtype: {dtype}"
# raise TypeError(msg)


# def translate_dtype(dtype: PolarsDataType) -> Any:
# if dtype == pl.Float64:
# return Float64
# if dtype == pl.Float32:
# return Float32
# if dtype == pl.Int64:
# return Int64
# if dtype == pl.Int32:
# return Int32
# if dtype == pl.Int16:
# return Int16
# if dtype == pl.UInt8:
# return UInt8
# if dtype == pl.UInt64:
# return UInt64
# if dtype == pl.UInt32:
# return UInt32
# if dtype == pl.UInt16:
# return UInt16
# if dtype == pl.UInt8:
# return UInt8
# if dtype == pl.String:
# return String
# if dtype == pl.Boolean:
# return Boolean
# if dtype == pl.Datetime:
# return Datetime
# if dtype == pl.Date:
# return Date
# msg = f"Unknown dtype: {dtype}"
# raise TypeError(msg)
def translate_dtype(plx: Any, dtype: DType) -> Any:
if dtype == Float64:
return plx.Float64
if dtype == Float32:
return plx.Float32
if dtype == Int64:
return plx.Int64
if dtype == Int32:
return plx.Int32
if dtype == Int16:
return plx.Int16
if dtype == UInt8:
return plx.UInt8
if dtype == UInt64:
return plx.UInt64
if dtype == UInt32:
return plx.UInt32
if dtype == UInt16:
return plx.UInt16
if dtype == UInt8:
return plx.UInt8
if dtype == String:
return plx.String
if dtype == Boolean:
return plx.Boolean
if dtype == Datetime:
return plx.Datetime
if dtype == Date:
return plx.Date
msg = f"Unknown dtype: {dtype}"
raise TypeError(msg)


def to_narwhals_dtype(dtype: Any, implementation: str) -> DType:
if implementation != "polars":
return dtype # type: ignore[no-any-return]
import polars as pl

if dtype == pl.Float64:
return Float64()
if dtype == pl.Float32:
return Float32()
if dtype == pl.Int64:
return Int64()
if dtype == pl.Int32:
return Int32()
if dtype == pl.Int16:
return Int16()
if dtype == pl.UInt8:
return UInt8()
if dtype == pl.UInt64:
return UInt64()
if dtype == pl.UInt32:
return UInt32()
if dtype == pl.UInt16:
return UInt16()
if dtype == pl.UInt8:
return UInt8()
if dtype == pl.String:
return String()
if dtype == pl.Boolean:
return Boolean()
if dtype == pl.Datetime:
return Datetime()
if dtype == pl.Date:
return Date()
msg = f"Unexpected dtype, got: {type(dtype)}"
raise TypeError(msg)
14 changes: 3 additions & 11 deletions narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from typing import Callable
from typing import Iterable

from narwhals.dtypes import translate_dtype

if TYPE_CHECKING:
from narwhals.dtypes import DType
from narwhals.typing import IntoExpr


Expand All @@ -16,15 +17,6 @@ def extract_native(expr: Expr, other: Any) -> Any:
return other


def translate_dtype(plx: Any, dtype: DType) -> Any:
from narwhals.dtypes import Int64

if dtype == Int64:
return plx.Int64
msg = f"unrecognised type: {dtype}"
raise TypeError(msg)


class Expr:
def __init__(self, call: Callable[[Any], Any]) -> None:
# callable from namespace to expr
Expand All @@ -39,7 +31,7 @@ def cast(
dtype: Any,
) -> Expr:
return self.__class__(
lambda plx: self._call(plx).cast(translate_dtype(plx, dtype))
lambda plx: self._call(plx).cast(translate_dtype(plx, dtype)),
)

# --- binary ---
Expand Down
11 changes: 9 additions & 2 deletions narwhals/pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,15 @@ def sample(self, n: int, fraction: float, *, with_replacement: bool) -> PandasSe
)

def unique(self) -> PandasSeries:
# in pandas it returns a list...
raise NotImplementedError
if self._implementation != "pandas":
raise NotImplementedError
import pandas as pd

return self._from_series(
pd.Series(
self._series.unique(), dtype=self._series.dtype, name=self._series.name
)
)

def is_nan(self) -> PandasSeries:
ser = self._series
Expand Down
42 changes: 31 additions & 11 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,24 @@ def __init__(
msg = f"Expected pandas or Polars Series, got: {type(series)}"
raise TypeError(msg)

def _extract_native(self, arg: Any) -> Any:
from narwhals.expression import Expr

if self._implementation != "polars":
return arg
if isinstance(arg, Series):
return arg._series
if isinstance(arg, Expr):
import polars as pl

return arg._call(pl)
return arg

def _from_series(self, series: Any) -> Self:
return self.__class__(series, implementation=self._implementation)

def __repr__(self) -> str: # pragma: no cover
header = " Narwhals Series "
header = " Narwhals Series "
length = len(header)
return (
"┌"
Expand All @@ -49,7 +65,7 @@ def __repr__(self) -> str: # pragma: no cover
)

def alias(self, name: str) -> Self:
return self.__class__(self._series.alias(name))
return self._from_series(self._series.alias(name))

@property
def name(self) -> str:
Expand All @@ -64,42 +80,46 @@ def shape(self) -> tuple[int]:
return self._series.shape # type: ignore[no-any-return]

def rename(self, name: str) -> Self:
return self.__class__(self._series.rename(name))
return self._from_series(self._series.rename(name))

def cast(
self,
dtype: Any,
) -> Self:
return self.__class__(self._series.cast(dtype))
return self._from_series(self._series.cast(dtype))

def item(self) -> Any:
return self._series.item()

def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
) -> Series:
return self.__class__(self._series.is_between(lower_bound, upper_bound, closed))
return self._from_series(
self._series.is_between(lower_bound, upper_bound, closed)
)

def is_in(self, other: Any) -> Series:
return self.__class__(self._series.is_in(other))
return self._from_series(self._series.is_in(self._extract_native(other)))

def is_null(self) -> Series:
return self.__class__(self._series.is_null())
return self._from_series(self._series.is_null())

def drop_nulls(self) -> Series:
return self.__class__(self._series.drop_nulls())
return self._from_series(self._series.drop_nulls())

def n_unique(self) -> int:
return self._series.n_unique() # type: ignore[no-any-return]

def unique(self) -> Series:
return self.__class__(self._series.unique())
return self._from_series(self._series.unique())

def zip_with(self, mask: Self, other: Self) -> Self:
raise NotImplementedError
return self._from_series(
self._series.zip_with(self._extract_native(mask), self._extract_native(other))
)

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Series:
return self.__class__(
return self._from_series(
self._series.sample(n, fraction=fraction, with_replacement=with_replacement)
)

Expand Down
3 changes: 3 additions & 0 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@

def to_native(obj: Any) -> Any:
from narwhals.dataframe import DataFrame
from narwhals.series import Series

if isinstance(obj, DataFrame):
return (
obj._dataframe
if obj._implementation == "polars"
else obj._dataframe._dataframe
)
if isinstance(obj, Series):
return obj._series if obj._implementation == "polars" else obj._series._series

msg = f"Expected Narwhals object, got {type(obj)}."
raise TypeError(msg)
Expand Down

0 comments on commit 5f916d5

Please sign in to comment.