Skip to content

Commit

Permalink
Make StructColumn.__init__ strict (#16467)
Browse files Browse the repository at this point in the history
This PR makes `StructColumn.__init__` strict, putting restrictions on data, dtype, size and children so these columns cannot be constructed in an invalid state. It also aligns the signature with the base class.

xref #16469

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #16467
  • Loading branch information
mroeschke authored Aug 19, 2024
1 parent f2d13c9 commit 3f6dd14
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 40 deletions.
13 changes: 7 additions & 6 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,22 +1635,23 @@ def build_column(
)
elif isinstance(dtype, IntervalDtype):
return cudf.core.column.IntervalColumn(
data=None,
size=size, # type: ignore[arg-type]
dtype=dtype,
mask=mask,
size=size,
offset=offset,
children=children,
null_count=null_count,
children=children, # type: ignore[arg-type]
)
elif isinstance(dtype, StructDtype):
return cudf.core.column.StructColumn(
data=data,
dtype=dtype,
data=None,
size=size, # type: ignore[arg-type]
offset=offset,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
children=children, # type: ignore[arg-type]
)
elif isinstance(dtype, cudf.Decimal64Dtype):
return cudf.core.column.Decimal64Column(
Expand Down
71 changes: 47 additions & 24 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,46 @@
from cudf.core.dtypes import IntervalDtype

if TYPE_CHECKING:
from typing_extensions import Self

from cudf._typing import ScalarLike
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase


class IntervalColumn(StructColumn):
def __init__(
self,
dtype,
mask=None,
size=None,
offset=0,
null_count=None,
children=(),
data: None,
size: int,
dtype: IntervalDtype,
mask: Buffer | None = None,
offset: int = 0,
null_count: int | None = None,
children: tuple[ColumnBase, ColumnBase] = (), # type: ignore[assignment]
):
if len(children) != 2:
raise ValueError(
"children must be a tuple of two columns (left edges, right edges)."
)
super().__init__(
data=None,
data=data,
size=size,
dtype=dtype,
mask=mask,
size=size,
offset=offset,
null_count=null_count,
children=children,
)

# dtype validation for IntervalColumn. Shares its name with the validator that
# StructColumn.__init__ invokes via `self`, so this version is the one used
# when an IntervalColumn is constructed.
# Returns `dtype` unchanged when it is an IntervalDtype instance (isinstance
# check, so subclasses are accepted); raises ValueError otherwise.
@staticmethod
def _validate_dtype_instance(dtype: IntervalDtype) -> IntervalDtype:
if not isinstance(dtype, IntervalDtype):
raise ValueError("dtype must be a IntervalDtype.")
return dtype

@classmethod
def from_arrow(cls, data):
def from_arrow(cls, data: pa.Array) -> Self:
new_col = super().from_arrow(data.storage)
size = len(data)
dtype = IntervalDtype.from_arrow(data.type)
Expand All @@ -48,16 +62,17 @@ def from_arrow(cls, data):
null_count = data.null_count
children = new_col.children

return IntervalColumn(
return cls(
data=None,
size=size,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
children=children, # type: ignore[arg-type]
)

def to_arrow(self):
def to_arrow(self) -> pa.Array:
typ = self.dtype.to_arrow()
struct_arrow = super().to_arrow()
if len(struct_arrow) == 0:
Expand All @@ -67,30 +82,36 @@ def to_arrow(self):
return pa.ExtensionArray.from_storage(typ, struct_arrow)

@classmethod
def from_struct_column(cls, struct_column: StructColumn, closed="right"):
def from_struct_column(
cls,
struct_column: StructColumn,
closed: Literal["left", "right", "both", "neither"] = "right",
) -> Self:
first_field_name = next(iter(struct_column.dtype.fields.keys()))
return IntervalColumn(
return cls(
data=None,
size=struct_column.size,
dtype=IntervalDtype(
struct_column.dtype.fields[first_field_name], closed
),
mask=struct_column.base_mask,
offset=struct_column.offset,
null_count=struct_column.null_count,
children=struct_column.base_children,
children=struct_column.base_children, # type: ignore[arg-type]
)

def copy(self, deep=True):
def copy(self, deep: bool = True) -> Self:
struct_copy = super().copy(deep=deep)
return IntervalColumn(
return IntervalColumn( # type: ignore[return-value]
data=None,
size=struct_copy.size,
dtype=IntervalDtype(
struct_copy.dtype.fields["left"], self.dtype.closed
),
mask=struct_copy.base_mask,
offset=struct_copy.offset,
null_count=struct_copy.null_count,
children=struct_copy.base_children,
children=struct_copy.base_children, # type: ignore[arg-type]
)

@property
Expand Down Expand Up @@ -138,25 +159,27 @@ def overlaps(other) -> ColumnBase:

def set_closed(
self, closed: Literal["left", "right", "both", "neither"]
) -> IntervalColumn:
return IntervalColumn(
) -> Self:
return IntervalColumn( # type: ignore[return-value]
data=None,
size=self.size,
dtype=IntervalDtype(self.dtype.fields["left"], closed),
mask=self.base_mask,
offset=self.offset,
null_count=self.null_count,
children=self.base_children,
children=self.base_children, # type: ignore[arg-type]
)

def as_interval_column(self, dtype):
def as_interval_column(self, dtype: IntervalDtype) -> Self: # type: ignore[override]
if isinstance(dtype, IntervalDtype):
return IntervalColumn(
return IntervalColumn( # type: ignore[return-value]
data=None,
size=self.size,
dtype=dtype,
mask=self.mask,
offset=self.offset,
null_count=self.null_count,
children=tuple(
children=tuple( # type: ignore[arg-type]
child.astype(dtype.subtype) for child in self.children
),
)
Expand Down
50 changes: 41 additions & 9 deletions python/cudf/cudf/core/column/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from cudf.core.missing import NA

if TYPE_CHECKING:
from typing_extensions import Self

from cudf._typing import Dtype
from cudf.core.buffer import Buffer


class StructColumn(ColumnBase):
Expand All @@ -23,10 +26,39 @@ class StructColumn(ColumnBase):
Every column has n children, where n is
the number of fields in the Struct Dtype.
"""

dtype: StructDtype
# Strict constructor for a struct column.
#   data       -- must be None; any other value raises ValueError.
#   size       -- forwarded unchanged to the base-class constructor.
#   dtype      -- checked by self._validate_dtype_instance before forwarding;
#                 that hook raises ValueError for an unacceptable dtype.
#   mask, offset, null_count, children
#              -- forwarded unchanged to the base-class constructor.
def __init__(
self,
data: None,
size: int,
dtype: StructDtype,
mask: Buffer | None = None,
offset: int = 0,
null_count: int | None = None,
children: tuple[ColumnBase, ...] = (),
):
# Reject a data buffer up front, before any other validation runs.
if data is not None:
raise ValueError("data must be None.")
# Looked up through `self`, so a subclass that defines its own
# _validate_dtype_instance supplies the check used here.
dtype = self._validate_dtype_instance(dtype)
super().__init__(
data=data,
size=size,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
)

# dtype validation called from __init__.
# Returns `dtype` unchanged when its type is exactly StructDtype; raises
# ValueError otherwise. The error message begins with the class name of the
# rejected dtype.
@staticmethod
def _validate_dtype_instance(dtype: StructDtype) -> StructDtype:
# IntervalDtype is a subclass of StructDtype, so compare types exactly
# (an isinstance check would let an IntervalDtype through).
if type(dtype) is not StructDtype:
raise ValueError(
f"{type(dtype).__name__} must be a StructDtype exactly."
)
return dtype

@property
def base_size(self):
Expand All @@ -35,7 +67,7 @@ def base_size(self):
else:
return self.size + self.offset

def to_arrow(self):
def to_arrow(self) -> pa.Array:
children = [
pa.nulls(len(child))
if len(child) == child.null_count
Expand All @@ -50,7 +82,7 @@ def to_arrow(self):
}
)

if self.nullable:
if self.mask is not None:
buffers = (pa.py_buffer(self.mask.memoryview()),)
else:
buffers = (None,)
Expand All @@ -73,7 +105,7 @@ def to_pandas(
return pd.Index(self.to_arrow().tolist(), dtype="object")

@cached_property
def memory_usage(self):
def memory_usage(self) -> int:
n = 0
if self.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(self.size)
Expand All @@ -99,23 +131,23 @@ def __setitem__(self, key, value):
value = cudf.Scalar(value, self.dtype)
super().__setitem__(key, value)

def copy(self, deep=True):
def copy(self, deep: bool = True) -> Self:
# Since struct columns are immutable, both deep and
# shallow copies share the underlying device data and mask.
result = super().copy(deep=False)
if deep:
result = result._rename_fields(self.dtype.fields.keys())
return result

def _rename_fields(self, names):
def _rename_fields(self, names) -> Self:
"""
Return a StructColumn with the same field values as this StructColumn,
but with the field names equal to `names`.
"""
dtype = cudf.core.dtypes.StructDtype(
dtype = StructDtype(
{name: col.dtype for name, col in zip(names, self.children)}
)
return StructColumn(
return StructColumn( # type: ignore[return-value]
data=None,
size=self.size,
dtype=dtype,
Expand Down
6 changes: 5 additions & 1 deletion python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3354,6 +3354,7 @@ def interval_range(
return IntervalIndex(data, closed=closed, name=name)

interval_col = IntervalColumn(
data=None,
dtype=IntervalDtype(left_col.dtype, closed),
size=len(left_col),
children=(left_col, right_col),
Expand Down Expand Up @@ -3425,6 +3426,7 @@ def __init__(
elif isinstance(data.dtype, (pd.IntervalDtype, IntervalDtype)):
data = np.array([], dtype=data.dtype.subtype)
interval_col = IntervalColumn(
None,
dtype=IntervalDtype(data.dtype, closed),
size=len(data),
children=(as_column(data), as_column(data)),
Expand All @@ -3436,12 +3438,13 @@ def __init__(
if copy:
col = col.copy()
interval_col = IntervalColumn(
data=None,
dtype=IntervalDtype(col.dtype.subtype, closed),
mask=col.mask,
size=col.size,
offset=col.offset,
null_count=col.null_count,
children=col.children,
children=col.children, # type: ignore[arg-type]
)

if dtype:
Expand Down Expand Up @@ -3517,6 +3520,7 @@ def from_breaks(
)

interval_col = IntervalColumn(
data=None,
dtype=IntervalDtype(left_col.dtype, closed),
size=len(left_col),
children=(left_col, right_col),
Expand Down

0 comments on commit 3f6dd14

Please sign in to comment.