From 6079bb7fb2f5e3c581756d3e9fbd0d00d7362a1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Mon, 23 Dec 2024 18:45:20 +0100 Subject: [PATCH] wip --- narwhals/dtypes.py | 12 ++++++++++++ narwhals/series.py | 2 ++ tests/expr_and_series/cast_test.py | 16 ++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 57ee762eb..233590df8 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -24,6 +24,18 @@ def _validate_dtype(dtype: DType | type[DType]) -> None: raise TypeError(msg) +def _validate_datetime( + dtype: DType | type[DType], series_dtype: DType | type[DType] +) -> None: + from narwhals.dtypes import Datetime + + if isinstance_or_issubclass(series_dtype, Datetime) and not ( + isinstance_or_issubclass(dtype, Datetime) + ): + msg = f"Expected to cast to Narwhals Datetime, got: {dtype}.\n\n" + raise TypeError(msg) + + class DType: def __repr__(self) -> str: # pragma: no cover return self.__class__.__qualname__ diff --git a/narwhals/series.py b/narwhals/series.py index f37943969..dcd7a5819 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -12,6 +12,7 @@ from typing import overload from narwhals.dependencies import is_numpy_scalar +from narwhals.dtypes import _validate_datetime from narwhals.dtypes import _validate_dtype from narwhals.typing import IntoSeriesT from narwhals.utils import _validate_rolling_arguments @@ -630,6 +631,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: ] """ _validate_dtype(dtype) + _validate_datetime(dtype, self.dtype) return self._from_compliant_series(self._compliant_series.cast(dtype)) def to_frame(self) -> DataFrame[Any]: diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 992ea5f54..5c5d2d862 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -7,6 +7,7 @@ import pandas as pd import polars as pl +import pyarrow as pa import pytest import narwhals.stable.v1 as nw @@ -226,3 +227,18 @@ def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) with pytest.raises(TypeError, match="Expected Narwhals dtype, got:"): df.select(nw.col("a").cast(dtype)) + + +def test_raise_datetime_to_numeric() -> None: + data = [datetime(year=2020, month=1, day=1, hour=0, second=0, minute=0)] + s_pd = nw.from_native(pd.Series(data), series_only=True) + with pytest.raises(TypeError, match="Expected to cast to Narwhals Datetime, got: "): + s_pd.cast(nw.Int64) + + s_pa = nw.from_native(pa.chunked_array([data]), series_only=True) + with pytest.raises(TypeError, match="Expected to cast to Narwhals Datetime, got: "): + s_pa.cast(nw.Int64) + + s_pl = nw.from_native(pl.Series(data), series_only=True) + with pytest.raises(TypeError, match="Expected to cast to Narwhals Datetime, got: "): + s_pl.cast(nw.Int64)