From 155eddedc0e2b68d203cfbc318172396f4293d98 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:00:57 -1000 Subject: [PATCH] Make Timedelta/DatetimeColumn.__init__ strict (#16464) This PR makes Datetime/TimedeltaColumn.__init__ and its subclasses strict putting restrictions on data, dtype, size and children so these columns cannot be constructed into to an invalid state. It also aligns the signature with the base class. xref https://github.com/rapidsai/cudf/issues/16469 Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/16464 --- python/cudf/cudf/core/column/column.py | 12 ++----- python/cudf/cudf/core/column/datetime.py | 43 ++++++++++++++++------- python/cudf/cudf/core/column/timedelta.py | 17 +++++---- 3 files changed, 44 insertions(+), 28 deletions(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 9785c3e5517..b0e33e8b9ce 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -1592,10 +1592,8 @@ def build_column( children=children, ) elif dtype.type is np.datetime64: - if data is None: - raise TypeError("Must specify data buffer") return cudf.core.column.DatetimeColumn( - data=data, + data=data, # type: ignore[arg-type] dtype=dtype, mask=mask, size=size, @@ -1603,10 +1601,8 @@ def build_column( null_count=null_count, ) elif isinstance(dtype, pd.DatetimeTZDtype): - if data is None: - raise TypeError("Must specify data buffer") return cudf.core.column.datetime.DatetimeTZColumn( - data=data, + data=data, # type: ignore[arg-type] dtype=dtype, mask=mask, size=size, @@ -1614,10 +1610,8 @@ def build_column( null_count=null_count, ) elif dtype.type is np.timedelta64: - if data is None: - raise TypeError("Must specify data buffer") return cudf.core.column.TimeDeltaColumn( - data=data, + data=data, # type: ignore[arg-type] dtype=dtype, mask=mask, size=size, diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index 1dbc94384d3..d0ea4612a1b 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -24,6 +24,7 @@ get_compatible_timezone, get_tz_data, ) +from cudf.core.buffer import Buffer from cudf.core.column import ColumnBase, as_column, column, string from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion from cudf.utils.dtypes import _get_base_dtype @@ -34,10 +35,8 @@ ColumnBinaryOperand, DatetimeLikeScalar, Dtype, - DtypeObj, ScalarLike, ) - from cudf.core.buffer import Buffer from cudf.core.column.numerical import NumericalColumn if PANDAS_GE_220: @@ -207,30 +206,39 @@ class DatetimeColumn(column.ColumnBase): def __init__( self, data: Buffer, - dtype: DtypeObj, + size: int | None, + dtype: np.dtype | pd.DatetimeTZDtype, mask: Buffer | None = None, - size: int | None = None, # TODO: make non-optional offset: int = 0, null_count: int | None = None, + children: tuple = (), ): - dtype = cudf.dtype(dtype) - if dtype.kind != "M": - raise TypeError(f"{self.dtype} is not a supported datetime type") - + if not isinstance(data, Buffer): + raise ValueError("data must be a Buffer.") + dtype = self._validate_dtype_instance(dtype) if data.size % dtype.itemsize: raise ValueError("Buffer size must be divisible by element size") if size is None: size = data.size // dtype.itemsize size = size - offset + if len(children) != 0: + raise ValueError(f"{type(self).__name__} must have no children.") super().__init__( - data, + data=data, size=size, dtype=dtype, mask=mask, offset=offset, null_count=null_count, + children=children, ) + @staticmethod + def _validate_dtype_instance(dtype: np.dtype) -> np.dtype: + if not (isinstance(dtype, np.dtype) and dtype.kind == "M"): + raise ValueError("dtype must be a datetime, numpy dtype") + return dtype + def __contains__(self, item: ScalarLike) -> bool: try: ts = pd.Timestamp(item).as_unit(self.time_unit) @@ -858,21 +866,30 @@ class DatetimeTZColumn(DatetimeColumn): def __init__( self, data: Buffer, + size: int | None, dtype: pd.DatetimeTZDtype, mask: Buffer | None = None, - size: int | None = None, offset: int = 0, null_count: int | None = None, + children: tuple = (), ): super().__init__( data=data, - dtype=_get_base_dtype(dtype), - mask=mask, size=size, + dtype=dtype, + mask=mask, offset=offset, null_count=null_count, + children=children, ) - self._dtype = get_compatible_timezone(dtype) + + @staticmethod + def _validate_dtype_instance( + dtype: pd.DatetimeTZDtype, + ) -> pd.DatetimeTZDtype: + if not isinstance(dtype, pd.DatetimeTZDtype): + raise ValueError("dtype must be a pandas.DatetimeTZDtype") + return get_compatible_timezone(dtype) def to_pandas( self, diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index ba0dc4779bb..6b6f3e517a8 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -75,28 +75,33 @@ class TimeDeltaColumn(ColumnBase): def __init__( self, data: Buffer, - dtype: Dtype, - size: int | None = None, # TODO: make non-optional + size: int | None, + dtype: np.dtype, mask: Buffer | None = None, offset: int = 0, null_count: int | None = None, + children: tuple = (), ): - dtype = cudf.dtype(dtype) - if dtype.kind != "m": - raise TypeError(f"{self.dtype} is not a supported duration type") + if not isinstance(data, Buffer): + raise ValueError("data must be a Buffer.") + if not (isinstance(dtype, np.dtype) and dtype.kind == "m"): + raise ValueError("dtype must be a timedelta numpy dtype.") if data.size % dtype.itemsize: raise ValueError("Buffer size must be divisible by element size") if size is None: size = data.size // dtype.itemsize size = size - offset + if len(children) != 0: + raise ValueError("TimedeltaColumn must have no children.") super().__init__( - data, + data=data, size=size, dtype=dtype, mask=mask, offset=offset, null_count=null_count, + children=children, ) def __contains__(self, item: DatetimeLikeScalar) -> bool: