Skip to content

Commit

Permalink
Annotate ColumnAccessor._data labels as Hashable (#16623)
Browse files Browse the repository at this point in the history
The motivating change here is that since we store a dictionary of columns in `ColumnAccessor`, the labels should be `collections.abc.Hashable` and therefore we can type methods that select by key with this annotation.

This led to a mypy-typing-validation cascade that made me type the output of `def as_column(...) -> ColumnBase` which also lead to typing validation in several other files.

Namely there no logic changes here.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #16623
  • Loading branch information
mroeschke authored Aug 22, 2024
1 parent 81d71fc commit e4e867a
Show file tree
Hide file tree
Showing 11 changed files with 105 additions and 69 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/column.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class Column:
@property
def mask_ptr(self) -> int: ...
def set_base_mask(self, value: Buffer | None) -> None: ...
def set_mask(self, value: Buffer | None) -> Self: ...
def set_mask(self, value: ColumnBase | Buffer | None) -> Self: ...
@property
def null_count(self) -> int: ...
@property
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/_internals/timezones.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _read_tzfile_as_columns(

# this happens for UTC-like zones
min_date = np.int64(np.iinfo("int64").min + 1).astype("M8[s]")
return (as_column([min_date]), as_column([np.timedelta64(0, "s")]))
return (as_column([min_date]), as_column([np.timedelta64(0, "s")])) # type: ignore[return-value]
return tuple(transition_times_and_offsets) # type: ignore[return-value]


Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,9 +984,9 @@ def find_and_replace(
)
replacement_col = catmap._data["index"].astype(replaced.codes.dtype)

replaced = column.as_column(replaced.codes)
replaced_codes = column.as_column(replaced.codes)
output = libcudf.replace.replace(
replaced, to_replace_col, replacement_col
replaced_codes, to_replace_col, replacement_col
)

result = column.build_categorical_column(
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def _validate_fillna_value(
raise TypeError(
"Cannot set a categorical with non-categorical data"
)
fill_value = fill_value._set_categories(
fill_value = cast(CategoricalColumn, fill_value)._set_categories(
self.categories,
)
return fill_value.codes.astype(self.codes.dtype)
Expand Down
22 changes: 15 additions & 7 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def __setitem__(self, key: Any, value: Any):
"""

# Normalize value to scalar/column
value_normalized = (
value_normalized: cudf.Scalar | ColumnBase = (
cudf.Scalar(value, dtype=self.dtype)
if is_scalar(value)
else as_column(value, dtype=self.dtype)
Expand Down Expand Up @@ -609,9 +609,12 @@ def _scatter_by_slice(
)

# step != 1, create a scatter map with arange
scatter_map = as_column(
rng,
dtype=cudf.dtype(np.int32),
scatter_map = cast(
cudf.core.column.NumericalColumn,
as_column(
rng,
dtype=cudf.dtype(np.int32),
),
)

return self._scatter_by_column(scatter_map, value)
Expand Down Expand Up @@ -1111,11 +1114,16 @@ def argsort(
if (ascending and self.is_monotonic_increasing) or (
not ascending and self.is_monotonic_decreasing
):
return as_column(range(len(self)))
return cast(
cudf.core.column.NumericalColumn, as_column(range(len(self)))
)
elif (ascending and self.is_monotonic_decreasing) or (
not ascending and self.is_monotonic_increasing
):
return as_column(range(len(self) - 1, -1, -1))
return cast(
cudf.core.column.NumericalColumn,
as_column(range(len(self) - 1, -1, -1)),
)
else:
return libcudf.sort.order_by(
[self], [ascending], na_position, stable=True
Expand Down Expand Up @@ -1752,7 +1760,7 @@ def as_column(
nan_as_null: bool | None = None,
dtype: Dtype | None = None,
length: int | None = None,
):
) -> ColumnBase:
"""Create a Column from an arbitrary object
Parameters
Expand Down
7 changes: 5 additions & 2 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ def from_sequences(
offset += len(data)
offset_vals.append(offset)

offset_col = column.as_column(offset_vals, dtype=size_type_dtype)
offset_col = cast(
NumericalColumn,
column.as_column(offset_vals, dtype=size_type_dtype),
)

# Build ListColumn
res = cls(
Expand Down Expand Up @@ -338,7 +341,7 @@ def __init__(self, parent: ParentType):

def get(
self,
index: int,
index: int | ColumnLike,
default: ScalarLike | ColumnLike | None = None,
) -> ParentType:
"""
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __setitem__(self, key: Any, value: Any):
"""

# Normalize value to scalar/column
device_value = (
device_value: cudf.Scalar | ColumnBase = (
cudf.Scalar(
value,
dtype=self.dtype
Expand Down Expand Up @@ -552,7 +552,7 @@ def _validate_fillna_value(
) -> cudf.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if is_scalar(fill_value):
cudf_obj = cudf.Scalar(fill_value)
cudf_obj: cudf.Scalar | ColumnBase = cudf.Scalar(fill_value)
if not as_column(cudf_obj).can_cast_safely(self.dtype):
raise TypeError(
f"Cannot safely cast non-equivalent "
Expand Down
8 changes: 5 additions & 3 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,11 +776,13 @@ def contains(
# TODO: we silently ignore the `regex=` flag here
if case is False:
input_column = libstrings.to_lower(self._column)
pat = libstrings.to_lower(column.as_column(pat, dtype="str"))
col_pat = libstrings.to_lower(
column.as_column(pat, dtype="str")
)
else:
input_column = self._column
pat = column.as_column(pat, dtype="str")
result_col = libstrings.contains_multiple(input_column, pat)
col_pat = column.as_column(pat, dtype="str")
result_col = libstrings.contains_multiple(input_column, col_pat)
return self._return_or_inplace(result_col)

def like(self, pat: str, esc: str | None = None) -> SeriesOrIndex:
Expand Down
Loading

0 comments on commit e4e867a

Please sign in to comment.