Skip to content

Commit

Permalink
Make Timedelta/DatetimeColumn.__init__ strict (#16464)
Browse files Browse the repository at this point in the history
This PR makes Datetime/TimedeltaColumn.__init__ and its subclasses strict putting restrictions on data, dtype, size and children so these columns cannot be constructed into to an invalid state. It also aligns the signature with the base class.

xref #16469

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #16464
  • Loading branch information
mroeschke authored Aug 16, 2024
1 parent 5084135 commit 155edde
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 28 deletions.
12 changes: 3 additions & 9 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,32 +1592,26 @@ def build_column(
children=children,
)
elif dtype.type is np.datetime64:
if data is None:
raise TypeError("Must specify data buffer")
return cudf.core.column.DatetimeColumn(
data=data,
data=data, # type: ignore[arg-type]
dtype=dtype,
mask=mask,
size=size,
offset=offset,
null_count=null_count,
)
elif isinstance(dtype, pd.DatetimeTZDtype):
if data is None:
raise TypeError("Must specify data buffer")
return cudf.core.column.datetime.DatetimeTZColumn(
data=data,
data=data, # type: ignore[arg-type]
dtype=dtype,
mask=mask,
size=size,
offset=offset,
null_count=null_count,
)
elif dtype.type is np.timedelta64:
if data is None:
raise TypeError("Must specify data buffer")
return cudf.core.column.TimeDeltaColumn(
data=data,
data=data, # type: ignore[arg-type]
dtype=dtype,
mask=mask,
size=size,
Expand Down
43 changes: 30 additions & 13 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_compatible_timezone,
get_tz_data,
)
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, as_column, column, string
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.utils.dtypes import _get_base_dtype
Expand All @@ -34,10 +35,8 @@
ColumnBinaryOperand,
DatetimeLikeScalar,
Dtype,
DtypeObj,
ScalarLike,
)
from cudf.core.buffer import Buffer
from cudf.core.column.numerical import NumericalColumn

if PANDAS_GE_220:
Expand Down Expand Up @@ -207,30 +206,39 @@ class DatetimeColumn(column.ColumnBase):
def __init__(
self,
data: Buffer,
dtype: DtypeObj,
size: int | None,
dtype: np.dtype | pd.DatetimeTZDtype,
mask: Buffer | None = None,
size: int | None = None, # TODO: make non-optional
offset: int = 0,
null_count: int | None = None,
children: tuple = (),
):
dtype = cudf.dtype(dtype)
if dtype.kind != "M":
raise TypeError(f"{self.dtype} is not a supported datetime type")

if not isinstance(data, Buffer):
raise ValueError("data must be a Buffer.")
dtype = self._validate_dtype_instance(dtype)
if data.size % dtype.itemsize:
raise ValueError("Buffer size must be divisible by element size")
if size is None:
size = data.size // dtype.itemsize
size = size - offset
if len(children) != 0:
raise ValueError(f"{type(self).__name__} must have no children.")
super().__init__(
data,
data=data,
size=size,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
)

@staticmethod
def _validate_dtype_instance(dtype: np.dtype) -> np.dtype:
if not (isinstance(dtype, np.dtype) and dtype.kind == "M"):
raise ValueError("dtype must be a datetime, numpy dtype")
return dtype

def __contains__(self, item: ScalarLike) -> bool:
try:
ts = pd.Timestamp(item).as_unit(self.time_unit)
Expand Down Expand Up @@ -858,21 +866,30 @@ class DatetimeTZColumn(DatetimeColumn):
def __init__(
self,
data: Buffer,
size: int | None,
dtype: pd.DatetimeTZDtype,
mask: Buffer | None = None,
size: int | None = None,
offset: int = 0,
null_count: int | None = None,
children: tuple = (),
):
super().__init__(
data=data,
dtype=_get_base_dtype(dtype),
mask=mask,
size=size,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
)
self._dtype = get_compatible_timezone(dtype)

@staticmethod
def _validate_dtype_instance(
dtype: pd.DatetimeTZDtype,
) -> pd.DatetimeTZDtype:
if not isinstance(dtype, pd.DatetimeTZDtype):
raise ValueError("dtype must be a pandas.DatetimeTZDtype")
return get_compatible_timezone(dtype)

def to_pandas(
self,
Expand Down
17 changes: 11 additions & 6 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,33 @@ class TimeDeltaColumn(ColumnBase):
def __init__(
self,
data: Buffer,
dtype: Dtype,
size: int | None = None, # TODO: make non-optional
size: int | None,
dtype: np.dtype,
mask: Buffer | None = None,
offset: int = 0,
null_count: int | None = None,
children: tuple = (),
):
dtype = cudf.dtype(dtype)
if dtype.kind != "m":
raise TypeError(f"{self.dtype} is not a supported duration type")
if not isinstance(data, Buffer):
raise ValueError("data must be a Buffer.")
if not (isinstance(dtype, np.dtype) and dtype.kind == "m"):
raise ValueError("dtype must be a timedelta numpy dtype.")

if data.size % dtype.itemsize:
raise ValueError("Buffer size must be divisible by element size")
if size is None:
size = data.size // dtype.itemsize
size = size - offset
if len(children) != 0:
raise ValueError("TimedeltaColumn must have no children.")
super().__init__(
data,
data=data,
size=size,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
)

def __contains__(self, item: DatetimeLikeScalar) -> bool:
Expand Down

0 comments on commit 155edde

Please sign in to comment.