Skip to content

Commit

Permalink
check for polars specifically
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Nov 2, 2024
1 parent 487bfa8 commit 6040d32
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 32 deletions.
5 changes: 4 additions & 1 deletion narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Sequence

from narwhals.dependencies import get_polars
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,7 +77,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if "polars" in str(type(dtype)):
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
Expand Down
5 changes: 4 additions & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any

from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version
Expand Down Expand Up @@ -85,7 +86,9 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if "polars" in str(type(dtype)):
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
Expand Down
5 changes: 4 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation
from narwhals.utils import isinstance_or_issubclass

Expand Down Expand Up @@ -339,7 +340,9 @@ def narwhals_to_native_dtype( # noqa: PLR0915
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Any:
if "polars" in str(type(dtype)):
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
Expand Down
7 changes: 5 additions & 2 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from narwhals.dtypes import DType
from narwhals.typing import DTypes

from narwhals.dependencies import get_polars
from narwhals.utils import parse_version


Expand Down Expand Up @@ -94,7 +95,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:


def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if "polars" in str(type(dtype)):
if (pl := get_polars()) is not None and isinstance(
dtype, (pl.DataType, pl.DataType.__class__)
):
msg = (
f"Expected Narwhals object, got: {type(dtype)}.\n\n"
"Perhaps you:\n"
Expand Down Expand Up @@ -141,7 +144,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime):
dt_time_unit = getattr(dtype, "time_unit", "us")
dt_time_zone = getattr(dtype, "time_zone", None)
return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type]
return pl.Datetime(dt_time_unit, dt_time_zone)
if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration):
du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
return pl.Duration(time_unit=du_time_unit)
Expand Down
13 changes: 0 additions & 13 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,19 +186,6 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
foo: [[1,2,3]]
bar: [[6,7,8]]
"""
# from narwhals.dtypes import DType

# if not (
# isinstance(dtype, DType)
# or (isinstance(dtype, type) and issubclass(dtype, DType))
# ):
# msg = (
# f"Expected Narwhals DType, got: {type(dtype)}.\n\n"
# "Hint: Perhaps you used Polars DataType instance `pl.dtype` instead of "
# "Narwhals DType `nw.dtype`?"
# )
# raise TypeError(msg)

return self.__class__(
lambda plx: self._call(plx).cast(dtype),
)
Expand Down
10 changes: 0 additions & 10 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,16 +420,6 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
1
]
"""
# from narwhals.dtypes import DType

# if not (isinstance(dtype, DType) or dtype == DType()):
# msg = (
# f"Expected Narwhals DType, got: {type(dtype)}.\n\n"
# "Hint: Perhaps you used Polars DataType instance `pl.dtype` instead of "
# "Narwhals DType `nw.dtype`?"
# )
# raise TypeError(msg)

return self._from_compliant_series(self._compliant_series.cast(dtype))

def to_frame(self) -> DataFrame[Any]:
Expand Down
7 changes: 3 additions & 4 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -194,7 +193,7 @@ class Banana:
pass

with pytest.raises(AssertionError, match=r"Unknown dtype"):
df.select(nw.col("a").cast(Banana))
df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type]


def test_cast_datetime_tz_aware(
Expand Down Expand Up @@ -227,8 +226,8 @@ def test_cast_datetime_tz_aware(
assert_equal_data(result, expected)


@pytest.mark.parametrize("dtype", [pl.String, pl.String(), pa.float64(), str])
def test_raise_if_not_narwhals_dtype(constructor: Constructor, dtype: Any) -> None:
@pytest.mark.parametrize("dtype", [pl.String, pl.String()])
def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]}))
with pytest.raises(TypeError, match="Expected Narwhals object, got:"):
df.select(nw.col("a").cast(dtype))

0 comments on commit 6040d32

Please sign in to comment.