Skip to content

Commit

Permalink
fix: narwhals_to_native_dtype raise if polars dtype is passed (#1307)
Browse files Browse the repository at this point in the history
* fix: raise for non-narwhals dtypes

* check for polars specifically

* mypy ignore
  • Loading branch information
FBruzzesi authored Nov 2, 2024
1 parent abd9d4a commit 2235bae
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 12 deletions.
12 changes: 12 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Sequence

from narwhals.dependencies import get_polars
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,6 +77,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

import pyarrow as pa # ignore-banned-import

if isinstance_or_issubclass(dtype, dtypes.Float64):
Expand Down
12 changes: 12 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any

from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version
Expand Down Expand Up @@ -85,6 +86,17 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

if isinstance_or_issubclass(dtype, dtypes.Float64):
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
Expand Down
5 changes: 4 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation
from narwhals.utils import isinstance_or_issubclass

Expand Down Expand Up @@ -376,7 +377,9 @@ def narwhals_to_native_dtype( # noqa: PLR0915
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Any:
if "polars" in str(type(dtype)):
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from narwhals.dtypes import DType
from narwhals.typing import DTypes

from narwhals.dependencies import get_polars
from narwhals.utils import parse_version


Expand Down Expand Up @@ -94,6 +95,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

import polars as pl # ignore-banned-import()

if dtype == dtypes.Float64:
Expand Down Expand Up @@ -132,7 +144,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime):
dt_time_unit = getattr(dtype, "time_unit", "us")
dt_time_zone = getattr(dtype, "time_zone", None)
return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type]
return pl.Datetime(dt_time_unit, dt_time_zone)
if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration):
du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
return pl.Duration(time_unit=du_time_unit)
Expand Down
5 changes: 1 addition & 4 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Se
"""
return function(self, *args, **kwargs)

def cast(
self,
dtype: Any,
) -> Self:
def cast(self: Self, dtype: DType | type[DType]) -> Self:
"""
Redefine an object's data type.
Expand Down
5 changes: 1 addition & 4 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,7 @@ def name(self) -> str:
"""
return self._compliant_series.name # type: ignore[no-any-return]

def cast(
self,
dtype: Any,
) -> Self:
def cast(self: Self, dtype: DType | type[DType]) -> Self:
"""
Cast between data types.
Expand Down
11 changes: 10 additions & 1 deletion tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -191,7 +193,7 @@ class Banana:
pass

with pytest.raises(AssertionError, match=r"Unknown dtype"):
df.select(nw.col("a").cast(Banana))
df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type]


def test_cast_datetime_tz_aware(
Expand Down Expand Up @@ -222,3 +224,10 @@ def test_cast_datetime_tz_aware(
.str.slice(offset=0, length=19)
)
assert_equal_data(result, expected)


@pytest.mark.parametrize("dtype", [pl.String, pl.String()])
def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]}))
with pytest.raises(TypeError, match="Expected Narwhals object, got:"):
df.select(nw.col("a").cast(dtype))
2 changes: 1 addition & 1 deletion tests/frame/invalid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_invalid() -> None:
with pytest.raises(TypeError, match="Perhaps you:"):
df.select([pl.col("a")]) # type: ignore[list-item]
with pytest.raises(TypeError, match="Perhaps you:"):
df.select([nw.col("a").cast(pl.Int64)])
df.select([nw.col("a").cast(pl.Int64)]) # type: ignore[arg-type]


def test_native_vs_non_native() -> None:
Expand Down

0 comments on commit 2235bae

Please sign in to comment.