Skip to content

Commit

Permalink
chore: rename internal dtype functions (#780)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 13, 2024
1 parent 0cff9d4 commit a75b726
Show file tree
Hide file tree
Showing 21 changed files with 54 additions and 135 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
publish-to-pypi:
name: >-
Publish Python 🐍 distribution 📦 to PyPI
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
if: startsWith(github.ref, 'refs/tags/v') # only publish to PyPI on pushes of version tags (tags starting with `v`)
needs:
- build
runs-on: ubuntu-latest
Expand Down
6 changes: 4 additions & 2 deletions docs/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ Make sure that, in addition to the public Narwhals API, you also define:
from `narwhals.DataFrame`
- `DataFrame.__narwhals_namespace__`: return an object which implements public top-level
functions from `narwhals` (e.g. `narwhals.col`, `narwhals.concat`, ...)
- `DataFrame.__native_namespace__`: return a native namespace object which must have a
`from_dict` method
- `LazyFrame.__narwhals_lazyframe__`: return an object which implements public methods
from `narwhals.LazyFrame`
- `LazyFrame.__narwhals_namespace__`: return an object which implements public top-level
functions from `narwhals` (e.g. `narwhals.col`, `narwhals.concat`, ...)
- `LazyFrame.__native_namespace__`: return a native namespace object which must have a
`from_dict` method
- `Series.__narwhals_series__`: return an object which implements public methods
from `narwhals.Series`
- `Series.__narwhals_namespace__`: return an object which implements public top-level
functions from `narwhals` (e.g. `narwhals.col`, `narwhals.concat`, ...)

If your library doesn't distinguish between lazy and eager, then it's OK for your dataframe
object to implement both `__narwhals_dataframe__` and `__narwhals_lazyframe__`. In fact,
Expand Down
4 changes: 1 addition & 3 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,7 @@ def filter(
self,
*predicates: IntoArrowExpr,
) -> Self:
from narwhals._arrow.namespace import ArrowNamespace

plx = ArrowNamespace(backend_version=self._backend_version)
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
Expand Down
4 changes: 1 addition & 3 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,7 @@ def len(self) -> Self:
return reuse_series_implementation(self, "len", returns_scalar=True)

def filter(self, *predicates: Any) -> Self:
from narwhals._arrow.namespace import ArrowNamespace

plx = ArrowNamespace(backend_version=self._backend_version)
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
return reuse_series_implementation(self, "filter", other=expr)

Expand Down
10 changes: 2 additions & 8 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from narwhals._arrow.utils import cast_for_truediv
from narwhals._arrow.utils import floordiv_compat
from narwhals._arrow.utils import reverse_translate_dtype
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_column_comparand
from narwhals.dependencies import get_numpy
Expand All @@ -23,7 +23,6 @@
from typing_extensions import Self

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
from narwhals.dtypes import DType


Expand Down Expand Up @@ -265,11 +264,6 @@ def n_unique(self) -> int:
def __native_namespace__(self) -> Any: # pragma: no cover
return get_pyarrow()

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version)

@property
def name(self) -> str:
return self._name
Expand Down Expand Up @@ -369,7 +363,7 @@ def is_null(self) -> Self:
def cast(self, dtype: DType) -> Self:
pc = get_pyarrow_compute()
ser = self._native_series
dtype = reverse_translate_dtype(dtype)
dtype = narwhals_to_native_dtype(dtype)
return self._from_native_series(pc.cast(ser, dtype))

def null_count(self: Self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
raise AssertionError


def reverse_translate_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
from narwhals import dtypes

pa = get_pyarrow()
Expand Down
1 change: 1 addition & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
def __narwhals_expr__(self) -> None: ...

def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover
# Unused, just for compatibility with PandasLikeExpr
from narwhals._dask.namespace import DaskNamespace

return DaskNamespace(backend_version=self._backend_version)
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def col(self, *column_names: str) -> DaskExpr:
)

def lit(self, value: Any, dtype: dtypes.DType | None) -> DaskExpr:
# TODO @FBruzzesi: cast to dtype once `reverse_translate_dtype` is implemented.
# It should be enough to add `.astype(reverse_translate_dtype(dtype))`
# TODO @FBruzzesi: cast to dtype once `narwhals_to_native_dtype` is implemented.
# It should be enough to add `.astype(narwhals_to_native_dtype(dtype))`
return DaskExpr(
lambda df: [df._native_dataframe.assign(lit=value).loc[:, "lit"]],
depth=0,
Expand Down
4 changes: 1 addition & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,7 @@ def filter(
self,
*predicates: IntoPandasLikeExpr,
) -> Self:
from narwhals._pandas_like.namespace import PandasLikeNamespace

plx = PandasLikeNamespace(self._implementation, self._backend_version)
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
Expand Down
4 changes: 1 addition & 3 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def arg_true(self) -> Self:
return reuse_series_implementation(self, "arg_true")

def filter(self, *predicates: Any) -> Self:
from narwhals._pandas_like.namespace import PandasLikeNamespace

plx = PandasLikeNamespace(self._implementation, self._backend_version)
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
return reuse_series_implementation(self, "filter", other=expr)

Expand Down
10 changes: 2 additions & 8 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from typing import overload

from narwhals._pandas_like.utils import int_dtype_mapper
from narwhals._pandas_like.utils import narwhals_to_native_dtype
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import reverse_translate_dtype
from narwhals._pandas_like.utils import to_datetime
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_column_comparand
Expand All @@ -25,7 +25,6 @@
from typing_extensions import Self

from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals.dtypes import DType

PANDAS_TO_NUMPY_DTYPE_NO_MISSING = {
Expand Down Expand Up @@ -99,11 +98,6 @@ def __init__(
else:
self._use_copy_false = False

def __narwhals_namespace__(self) -> PandasLikeNamespace:
from narwhals._pandas_like.namespace import PandasLikeNamespace

return PandasLikeNamespace(self._implementation, self._backend_version)

def __native_namespace__(self) -> Any:
if self._implementation is Implementation.PANDAS:
return get_pandas()
Expand Down Expand Up @@ -181,7 +175,7 @@ def cast(
dtype: Any,
) -> Self:
ser = self._native_series
dtype = reverse_translate_dtype(dtype, ser.dtype, self._implementation)
dtype = narwhals_to_native_dtype(dtype, ser.dtype, self._implementation)
return self._from_native_series(ser.astype(dtype))

def item(self: Self, index: int | None = None) -> Any:
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,20 @@ def get_dtype_backend(dtype: Any, implementation: Implementation) -> str:
return "numpy"


def reverse_translate_dtype( # noqa: PLR0915
def narwhals_to_native_dtype( # noqa: PLR0915
dtype: DType | type[DType], starting_dtype: Any, implementation: Implementation
) -> Any:
from narwhals import dtypes

if "polars" in str(type(dtype)):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

dtype_backend = get_dtype_backend(starting_dtype, implementation)
if isinstance_or_issubclass(dtype, dtypes.Float64):
if dtype_backend == "pyarrow-nullable":
Expand Down Expand Up @@ -413,6 +422,9 @@ def reverse_translate_dtype( # noqa: PLR0915
return "date32[pyarrow]"
msg = "Date dtype only supported for pyarrow-backed data types in pandas"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Enum):
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
11 changes: 2 additions & 9 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from typing import TYPE_CHECKING
from typing import Any

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import reverse_translate_dtype
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.utils import Implementation

if TYPE_CHECKING:
Expand All @@ -23,12 +22,6 @@ def __init__(self, expr: Any) -> None:
def __repr__(self) -> str: # pragma: no cover
return "PolarsExpr"

def __narwhals_expr__(self) -> Self: # pragma: no cover
return self

def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
return PolarsNamespace(backend_version=self._backend_version)

def _from_native_expr(self, expr: Any) -> Self:
return self.__class__(expr)

Expand All @@ -43,7 +36,7 @@ def func(*args: Any, **kwargs: Any) -> Any:

def cast(self, dtype: DType) -> Self:
expr = self._native_expr
dtype = reverse_translate_dtype(dtype)
dtype = narwhals_to_native_dtype(dtype)
return self._from_native_expr(expr.cast(dtype))

def __eq__(self, other: object) -> Self: # type: ignore[override]
Expand Down
6 changes: 3 additions & 3 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from narwhals import dtypes
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import reverse_translate_dtype
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation

Expand Down Expand Up @@ -82,7 +82,7 @@ def lit(self, value: Any, dtype: dtypes.DType | None = None) -> PolarsExpr:

pl = get_polars()
if dtype is not None:
return PolarsExpr(pl.lit(value, dtype=reverse_translate_dtype(dtype)))
return PolarsExpr(pl.lit(value, dtype=narwhals_to_native_dtype(dtype)))
return PolarsExpr(pl.lit(value))

def mean(self, *column_names: str) -> Any:
Expand All @@ -102,7 +102,7 @@ def by_dtype(self, dtypes: Iterable[dtypes.DType]) -> PolarsExpr:

pl = get_polars()
return PolarsExpr(
pl.selectors.by_dtype([reverse_translate_dtype(dtype) for dtype in dtypes])
pl.selectors.by_dtype([narwhals_to_native_dtype(dtype) for dtype in dtypes])
)

def numeric(self) -> PolarsExpr:
Expand Down
8 changes: 2 additions & 6 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals.dtypes import DType

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import reverse_translate_dtype
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals._polars.utils import translate_dtype

PL = get_polars()
Expand All @@ -39,9 +38,6 @@ def __narwhals_series__(self) -> Self:
def __native_namespace__(self) -> Any:
return get_polars()

def __narwhals_namespace__(self) -> PolarsNamespace:
return PolarsNamespace(backend_version=self._backend_version)

def _from_native_series(self, series: Any) -> Self:
return self.__class__(series, backend_version=self._backend_version)

Expand Down Expand Up @@ -94,7 +90,7 @@ def __getitem__(self, item: int | slice | Sequence[int]) -> Any | Self:

def cast(self, dtype: DType) -> Self:
ser = self._native_series
dtype = reverse_translate_dtype(dtype)
dtype = narwhals_to_native_dtype(dtype)
return self._from_native_series(ser.cast(dtype))

def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
Expand Down
7 changes: 4 additions & 3 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
return dtypes.Unknown()


def reverse_translate_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
pl = get_polars()
from narwhals import dtypes

Expand Down Expand Up @@ -100,8 +100,9 @@ def reverse_translate_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
return pl.Object()
if dtype == dtypes.Categorical:
return pl.Categorical()
if dtype == dtypes.Enum: # pragma: no cover
return pl.Enum()
if dtype == dtypes.Enum:
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if dtype == dtypes.Datetime:
return pl.Datetime()
if dtype == dtypes.Duration:
Expand Down
53 changes: 2 additions & 51 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
from typing_extensions import Self
Expand All @@ -18,6 +15,8 @@ def is_numeric(cls: type[Self]) -> bool:
return issubclass(cls, NumericType)

def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
from narwhals.utils import isinstance_or_issubclass

return isinstance_or_issubclass(other, type(self))

def __hash__(self) -> int:
Expand Down Expand Up @@ -85,51 +84,3 @@ class Enum(DType): ...


class Date(TemporalType): ...


def translate_dtype(plx: Any, dtype: DType) -> Any:
if "polars" in str(type(dtype)):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)
if dtype == Float64:
return plx.Float64
if dtype == Float32:
return plx.Float32
if dtype == Int64:
return plx.Int64
if dtype == Int32:
return plx.Int32
if dtype == Int16:
return plx.Int16
if dtype == Int8:
return plx.Int8
if dtype == UInt64:
return plx.UInt64
if dtype == UInt32:
return plx.UInt32
if dtype == UInt16:
return plx.UInt16
if dtype == UInt8:
return plx.UInt8
if dtype == String:
return plx.String
if dtype == Boolean:
return plx.Boolean
if dtype == Categorical:
return plx.Categorical
if dtype == Enum:
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if dtype == Datetime:
return plx.Datetime
if dtype == Duration:
return plx.Duration
if dtype == Date:
return plx.Date
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
Loading

0 comments on commit a75b726

Please sign in to comment.