Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: narwhals_to_native_dtype raise if polars dtype is passed #1307

Merged
merged 3 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Sequence

from narwhals.dependencies import get_polars
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,6 +77,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pl.DataType.__class__ is a workaround since pl.DataTypeClass now raises a DeprecationWarning.

I think this is the best way to check for both pl.String() and pl.String: if "polars" in str(type(dtype)) was definitely not enough for both

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, thanks!

):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

import pyarrow as pa # ignore-banned-import

if isinstance_or_issubclass(dtype, dtypes.Float64):
Expand Down
12 changes: 12 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any

from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version
Expand Down Expand Up @@ -85,6 +86,17 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

if isinstance_or_issubclass(dtype, dtypes.Float64):
return "float64"
if isinstance_or_issubclass(dtype, dtypes.Float32):
Expand Down
5 changes: 4 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation
from narwhals.utils import isinstance_or_issubclass

Expand Down Expand Up @@ -339,7 +340,9 @@ def narwhals_to_native_dtype( # noqa: PLR0915
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Any:
if "polars" in str(type(dtype)):
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from narwhals.dtypes import DType
from narwhals.typing import DTypes

from narwhals.dependencies import get_polars
from narwhals.utils import parse_version


Expand Down Expand Up @@ -94,6 +95,17 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.Int64` instead of `nw.Int64`?"
)
raise TypeError(msg)

import polars as pl # ignore-banned-import()

if dtype == dtypes.Float64:
Expand Down Expand Up @@ -132,7 +144,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime):
dt_time_unit = getattr(dtype, "time_unit", "us")
dt_time_zone = getattr(dtype, "time_zone", None)
return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type]
return pl.Datetime(dt_time_unit, dt_time_zone)
if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration):
du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
return pl.Duration(time_unit=du_time_unit)
Expand Down
5 changes: 1 addition & 4 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Se
"""
return function(self, *args, **kwargs)

def cast(
self,
dtype: Any,
) -> Self:
def cast(self: Self, dtype: DType | type[DType]) -> Self:
"""
Redefine an object's data type.

Expand Down
5 changes: 1 addition & 4 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,7 @@ def name(self) -> str:
"""
return self._compliant_series.name # type: ignore[no-any-return]

def cast(
self,
dtype: Any,
) -> Self:
def cast(self: Self, dtype: DType | type[DType]) -> Self:
"""
Cast between data types.

Expand Down
11 changes: 10 additions & 1 deletion tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -191,7 +193,7 @@ class Banana:
pass

with pytest.raises(AssertionError, match=r"Unknown dtype"):
df.select(nw.col("a").cast(Banana))
df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type]


def test_cast_datetime_tz_aware(
Expand Down Expand Up @@ -222,3 +224,10 @@ def test_cast_datetime_tz_aware(
.str.slice(offset=0, length=19)
)
assert_equal_data(result, expected)


@pytest.mark.parametrize("dtype", [pl.String, pl.String()])
def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None:
    """Check that casting to a Polars dtype raises a ``TypeError``.

    Parametrized over both the dtype class (``pl.String``) and an instance
    (``pl.String()``); selecting the cast expression is expected to fail
    with the "Expected Narwhals object, got:" message.
    """
    native_frame = constructor({"a": [1, 2, 3], "b": [4, 5, 6]})
    frame = nw.from_native(native_frame)
    expected_message = "Expected Narwhals object, got:"
    # Build the cast expression inside the context manager so the error is
    # caught no matter which of the two calls raises it.
    with pytest.raises(TypeError, match=expected_message):
        frame.select(nw.col("a").cast(dtype))
2 changes: 1 addition & 1 deletion tests/frame/invalid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_invalid() -> None:
with pytest.raises(TypeError, match="Perhaps you:"):
df.select([pl.col("a")]) # type: ignore[list-item]
with pytest.raises(TypeError, match="Perhaps you:"):
df.select([nw.col("a").cast(pl.Int64)])
df.select([nw.col("a").cast(pl.Int64)]) # type: ignore[arg-type]


def test_native_vs_non_native() -> None:
Expand Down
Loading