diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 6f74294d5..56c8b1e50 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -4,6 +4,7 @@ from typing import Any from typing import Sequence +from narwhals.dependencies import get_polars from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: @@ -76,6 +77,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): + msg = ( + f"Expected Narwhals object, got: {type(dtype)}.\n\n" + "Perhaps you:\n" + "- Forgot a `nw.from_native` somewhere?\n" + "- Used `pl.Int64` instead of `nw.Int64`?" + ) + raise TypeError(msg) + import pyarrow as pa # ignore-banned-import if isinstance_or_issubclass(dtype, dtypes.Float64): diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 2ba7cdcbd..cf8f9a3fc 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -4,6 +4,7 @@ from typing import Any from narwhals.dependencies import get_pandas +from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow from narwhals.utils import isinstance_or_issubclass from narwhals.utils import parse_version @@ -85,6 +86,17 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None: def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): + msg = ( + f"Expected Narwhals object, got: {type(dtype)}.\n\n" + "Perhaps you:\n" + "- Forgot a `nw.from_native` somewhere?\n" + "- Used `pl.Int64` instead of `nw.Int64`?" + ) + raise TypeError(msg) + if isinstance_or_issubclass(dtype, dtypes.Float64): return "float64" if isinstance_or_issubclass(dtype, dtypes.Float32): diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 99181bc1e..58123c565 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -10,6 +10,7 @@ from narwhals._arrow.utils import ( native_to_narwhals_dtype as arrow_native_to_narwhals_dtype, ) +from narwhals.dependencies import get_polars from narwhals.utils import Implementation from narwhals.utils import isinstance_or_issubclass @@ -339,7 +340,9 @@ def narwhals_to_native_dtype( # noqa: PLR0915 backend_version: tuple[int, ...], dtypes: DTypes, ) -> Any: - if "polars" in str(type(dtype)): + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): msg = ( f"Expected Narwhals object, got: {type(dtype)}.\n\n" "Perhaps you:\n" diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index ac6ffb2bd..295c03bc3 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -8,6 +8,7 @@ from narwhals.dtypes import DType from narwhals.typing import DTypes +from narwhals.dependencies import get_polars from narwhals.utils import parse_version @@ -94,6 +95,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): + msg = ( + f"Expected Narwhals object, got: {type(dtype)}.\n\n" + "Perhaps you:\n" + "- Forgot a `nw.from_native` somewhere?\n" + "- Used `pl.Int64` instead of `nw.Int64`?" + ) + raise TypeError(msg) + import polars as pl # ignore-banned-import() if dtype == dtypes.Float64: @@ -132,7 +144,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime): dt_time_unit = getattr(dtype, "time_unit", "us") dt_time_zone = getattr(dtype, "time_zone", None) - return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type] + return pl.Datetime(dt_time_unit, dt_time_zone) if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration): du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") return pl.Duration(time_unit=du_time_unit) diff --git a/narwhals/expr.py b/narwhals/expr.py index 6c2d28962..46d44bee3 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -135,10 +135,7 @@ def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Se """ return function(self, *args, **kwargs) - def cast( - self, - dtype: Any, - ) -> Self: + def cast(self: Self, dtype: DType | type[DType]) -> Self: """ Redefine an object's data type. diff --git a/narwhals/series.py b/narwhals/series.py index 6f5223202..add55897e 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -383,10 +383,7 @@ def name(self) -> str: """ return self._compliant_series.name # type: ignore[no-any-return] - def cast( - self, - dtype: Any, - ) -> Self: + def cast(self: Self, dtype: DType | type[DType]) -> Self: """ Cast between data types. diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 11c20d0a7..14e77d68d 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -3,8 +3,10 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +from typing import Any import pandas as pd +import polars as pl import pytest import narwhals.stable.v1 as nw @@ -191,7 +193,7 @@ class Banana: pass with pytest.raises(AssertionError, match=r"Unknown dtype"): - df.select(nw.col("a").cast(Banana)) + df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type] def test_cast_datetime_tz_aware( @@ -222,3 +224,10 @@ def test_cast_datetime_tz_aware( .str.slice(offset=0, length=19) ) assert_equal_data(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.String, pl.String()]) +def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None: + df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) + with pytest.raises(TypeError, match="Expected Narwhals object, got:"): + df.select(nw.col("a").cast(dtype)) diff --git a/tests/frame/invalid_test.py b/tests/frame/invalid_test.py index 7fdf3e5fe..811d04304 100644 --- a/tests/frame/invalid_test.py +++ b/tests/frame/invalid_test.py @@ -20,7 +20,7 @@ def test_invalid() -> None: with pytest.raises(TypeError, match="Perhaps you:"): df.select([pl.col("a")]) # type: ignore[list-item] with pytest.raises(TypeError, match="Perhaps you:"): - df.select([nw.col("a").cast(pl.Int64)]) + df.select([nw.col("a").cast(pl.Int64)]) # type: ignore[arg-type] def test_native_vs_non_native() -> None: