Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
DeaMariaLeon committed Dec 23, 2024
1 parent e112a99 commit 6079bb7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
12 changes: 12 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def _validate_dtype(dtype: DType | type[DType]) -> None:
raise TypeError(msg)


def _validate_datetime(
dtype: DType | type[DType], series_dtype: DType | type[DType]
) -> None:
from narwhals.dtypes import Datetime

if isinstance_or_issubclass(series_dtype, Datetime) and not (
isinstance_or_issubclass(dtype, Datetime)
):
msg = f"Expected to cast to Narwhals Datetime, got: {dtype}.\n\n"
raise TypeError(msg)


class DType:
def __repr__(self) -> str: # pragma: no cover
return self.__class__.__qualname__
Expand Down
2 changes: 2 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import overload

from narwhals.dependencies import is_numpy_scalar
from narwhals.dtypes import _validate_datetime
from narwhals.dtypes import _validate_dtype
from narwhals.typing import IntoSeriesT
from narwhals.utils import _validate_rolling_arguments
Expand Down Expand Up @@ -630,6 +631,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
]
"""
_validate_dtype(dtype)
_validate_datetime(dtype, self.dtype)
return self._from_compliant_series(self._compliant_series.cast(dtype))

def to_frame(self) -> DataFrame[Any]:
Expand Down
16 changes: 16 additions & 0 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -226,3 +227,18 @@ def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None:
df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]}))
with pytest.raises(TypeError, match="Expected Narwhals dtype, got:"):
df.select(nw.col("a").cast(dtype))


def test_raise_datetime_to_numeric() -> None:
data = [datetime(year=2020, month=1, day=1, hour=0, second=0, minute=0)]
s_pd = nw.from_native(pd.Series(data), series_only=True)
with pytest.raises(TypeError, match="Expected to cast to Narwhals Datetime, got: "):
s_pd.cast(nw.Int64)

s_pa = nw.from_native(pa.chunked_array([data]), series_only=True)
with pytest.raises(TypeError, match="Expected to cast to Narwhals Datetime, got: "):
s_pa.cast(nw.Int64)

s_pl = nw.from_native(pl.Series(data), series_only=True)
with pytest.raises(TypeError, match="Expected to cast to Narwhals Datetime, got: "):
s_pl.cast(nw.Int64)

0 comments on commit 6079bb7

Please sign in to comment.