Skip to content

Commit

Permalink
Fix astype from tz-aware type to tz-aware type (#16980)
Browse files Browse the repository at this point in the history
closes #16973

Also matches astype from tz-naive to tz-aware type like pandas

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #16980
  • Loading branch information
mroeschke authored Oct 3, 2024
1 parent 7ae5360 commit 2ec6cb3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,11 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn:
if dtype == self.dtype:
return self
elif isinstance(dtype, pd.DatetimeTZDtype):
raise TypeError(
"Cannot use .astype to convert from timezone-naive dtype to timezone-aware dtype. "
"Use tz_localize instead."
)
return libcudf.unary.cast(self, dtype=dtype)

def as_timedelta_column(self, dtype: Dtype) -> None: # type: ignore[override]
Expand Down Expand Up @@ -940,6 +945,16 @@ def strftime(self, format: str) -> cudf.core.column.StringColumn:
def as_string_column(self) -> cudf.core.column.StringColumn:
return self._local_time.as_string_column()

def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn:
if isinstance(dtype, pd.DatetimeTZDtype) and dtype != self.dtype:
if dtype.unit != self.time_unit:
# TODO: Doesn't check that new unit is valid.
casted = self._with_type_metadata(dtype)
else:
casted = self
return casted.tz_convert(str(dtype.tz))
return super().as_datetime_column(dtype)

def get_dt_field(self, field: str) -> ColumnBase:
return libcudf.datetime.extract_datetime_component(
self._local_time, field
Expand Down
22 changes: 22 additions & 0 deletions python/cudf/cudf/tests/series/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,25 @@ def test_pandas_compatible_non_zoneinfo_raises(klass):
with cudf.option_context("mode.pandas_compatible", True):
with pytest.raises(NotImplementedError):
cudf.from_pandas(pandas_obj)


def test_astype_naive_to_aware_raises():
ser = cudf.Series([datetime.datetime(2020, 1, 1)])
with pytest.raises(TypeError):
ser.astype("datetime64[ns, UTC]")
with pytest.raises(TypeError):
ser.to_pandas().astype("datetime64[ns, UTC]")


@pytest.mark.parametrize("unit", ["ns", "us"])
def test_astype_aware_to_aware(unit):
ser = cudf.Series(
[datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc)]
)
result = ser.astype(f"datetime64[{unit}, US/Pacific]")
expected = ser.to_pandas().astype(f"datetime64[{unit}, US/Pacific]")
zoneinfo_type = pd.DatetimeTZDtype(
expected.dtype.unit, zoneinfo.ZoneInfo(str(expected.dtype.tz))
)
expected = ser.astype(zoneinfo_type)
assert_eq(result, expected)

0 comments on commit 2ec6cb3

Please sign in to comment.