diff --git a/python/cudf/cudf/_lib/concat.pyx b/python/cudf/cudf/_lib/concat.pyx index e661059faa3..e6c2d136f0d 100644 --- a/python/cudf/cudf/_lib/concat.pyx +++ b/python/cudf/cudf/_lib/concat.pyx @@ -23,9 +23,9 @@ def concat_columns(object columns): def concat_tables(object tables, bool ignore_index=False): plc_tables = [] for table in tables: - cols = table._data.columns + cols = table._columns if not ignore_index: - cols = table._index._data.columns + cols + cols = table._index._columns + cols plc_tables.append(pylibcudf.Table([c.to_pylibcudf(mode="read") for c in cols])) return data_from_pylibcudf_table( diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index 16182e31c08..49714091f46 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -384,7 +384,7 @@ cdef class _CPackedColumns: p.column_names = input_table._column_names p.column_dtypes = {} - for name, col in input_table._data.items(): + for name, col in input_table._column_labels_and_values: if isinstance(col.dtype, cudf.core.dtypes._BaseDtype): p.column_dtypes[name] = col.dtype diff --git a/python/cudf/cudf/_lib/csv.pyx b/python/cudf/cudf/_lib/csv.pyx index 058e884e08b..9ad96f610b3 100644 --- a/python/cudf/cudf/_lib/csv.pyx +++ b/python/cudf/cudf/_lib/csv.pyx @@ -273,7 +273,7 @@ def read_csv( elif isinstance(dtype, abc.Collection): for index, col_dtype in enumerate(dtype): if isinstance(cudf.dtype(col_dtype), cudf.CategoricalDtype): - col_name = df._data.names[index] + col_name = df._column_names[index] df._data[col_name] = df._data[col_name].astype(col_dtype) if names is not None and len(names) and isinstance(names[0], int): diff --git a/python/cudf/cudf/_lib/io/utils.pyx b/python/cudf/cudf/_lib/io/utils.pyx index b1900138d94..564daefbae2 100644 --- a/python/cudf/cudf/_lib/io/utils.pyx +++ b/python/cudf/cudf/_lib/io/utils.pyx @@ -179,7 +179,7 @@ cdef update_struct_field_names( ): # Deprecated, remove in favor of add_col_struct_names # when a reader is ported to pylibcudf - for i, (name, col) in enumerate(table._data.items()): + for i, (name, col) in enumerate(table._column_labels_and_values): table._data[name] = update_column_struct_field_names( col, schema_info[i] ) diff --git a/python/cudf/cudf/_lib/parquet.pyx b/python/cudf/cudf/_lib/parquet.pyx index e6c9d60b05b..fa2690c7f21 100644 --- a/python/cudf/cudf/_lib/parquet.pyx +++ b/python/cudf/cudf/_lib/parquet.pyx @@ -235,16 +235,16 @@ cdef object _process_metadata(object df, df._index = idx elif set(index_col).issubset(names): index_data = df[index_col] - actual_index_names = list(index_col_names.values()) - if len(index_data._data) == 1: + actual_index_names = iter(index_col_names.values()) + if index_data._num_columns == 1: idx = cudf.Index._from_column( - index_data._data.columns[0], - name=actual_index_names[0] + index_data._columns[0], + name=next(actual_index_names) ) else: idx = cudf.MultiIndex.from_frame( index_data, - names=actual_index_names + names=list(actual_index_names) ) df.drop(columns=index_col, inplace=True) df._index = idx @@ -252,7 +252,7 @@ cdef object _process_metadata(object df, if use_pandas_metadata: df.index.names = index_col - if len(df._data.names) == 0 and column_index_type is not None: + if df._num_columns == 0 and column_index_type is not None: df._data.label_dtype = cudf.dtype(column_index_type) return df diff --git a/python/cudf/cudf/_lib/utils.pyx b/python/cudf/cudf/_lib/utils.pyx index cae28d02ef4..8660cca9322 100644 --- a/python/cudf/cudf/_lib/utils.pyx +++ b/python/cudf/cudf/_lib/utils.pyx @@ -49,9 +49,9 @@ cdef table_view table_view_from_table(tbl, ignore_index=False) except*: If True, don't include the index in the columns. """ return table_view_from_columns( - tbl._index._data.columns + tbl._data.columns + tbl._index._columns + tbl._columns if not ignore_index and tbl._index is not None - else tbl._data.columns + else tbl._columns ) @@ -62,7 +62,7 @@ cpdef generate_pandas_metadata(table, index): index_descriptors = [] columns_to_convert = list(table._columns) # Columns - for name, col in table._data.items(): + for name, col in table._column_labels_and_values: if cudf.get_option("mode.pandas_compatible"): # in pandas-compat mode, non-string column names are stringified. col_names.append(str(name)) diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index ff114474aa4..a6abd63d042 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -1951,7 +1951,7 @@ def drop_duplicates( return self._from_columns_like_self( drop_duplicates( list(self._columns), - keys=range(len(self._data)), + keys=range(len(self._columns)), keep=keep, nulls_are_equal=nulls_are_equal, ), diff --git a/python/cudf/cudf/core/column_accessor.py b/python/cudf/cudf/core/column_accessor.py index 09b0f453692..bc093fdaa9a 100644 --- a/python/cudf/cudf/core/column_accessor.py +++ b/python/cudf/cudf/core/column_accessor.py @@ -151,9 +151,9 @@ def __setitem__(self, key: abc.Hashable, value: ColumnBase) -> None: self.set_by_label(key, value) def __delitem__(self, key: abc.Hashable) -> None: - old_ncols = len(self._data) + old_ncols = len(self) del self._data[key] - new_ncols = len(self._data) + new_ncols = len(self) self._clear_cache(old_ncols, new_ncols) def __len__(self) -> int: @@ -213,7 +213,7 @@ def level_names(self) -> tuple[abc.Hashable, ...]: @property def nlevels(self) -> int: - if len(self._data) == 0: + if len(self) == 0: return 0 if not self.multiindex: return 1 @@ -226,7 +226,7 @@ def name(self) -> abc.Hashable: @cached_property def nrows(self) -> int: - if len(self._data) == 0: + if len(self) == 0: return 0 else: return len(next(iter(self.values()))) @@ -257,9 +257,9 @@ def _clear_cache(self, old_ncols: int, new_ncols: int) -> None: Parameters ---------- old_ncols: int - len(self._data) before self._data was modified + len(self) before self._data was modified new_ncols: int - len(self._data) after self._data was modified + len(self) after self._data was modified """ cached_properties = ("columns", "names", "_grouped_data") for attr in cached_properties: @@ -335,7 +335,7 @@ def insert( if name in self._data: raise ValueError(f"Cannot insert '{name}', already exists") - old_ncols = len(self._data) + old_ncols = len(self) if loc == -1: loc = old_ncols elif not (0 <= loc <= old_ncols): @@ -414,7 +414,7 @@ def get_labels_by_index(self, index: Any) -> tuple: tuple """ if isinstance(index, slice): - start, stop, step = index.indices(len(self._data)) + start, stop, step = index.indices(len(self)) return self.names[start:stop:step] elif pd.api.types.is_integer(index): return (self.names[index],) @@ -526,9 +526,9 @@ def set_by_label(self, key: abc.Hashable, value: ColumnBase) -> None: if len(self) > 0 and len(value) != self.nrows: raise ValueError("All columns must be of equal length") - old_ncols = len(self._data) + old_ncols = len(self) self._data[key] = value - new_ncols = len(self._data) + new_ncols = len(self) self._clear_cache(old_ncols, new_ncols) def _select_by_label_list_like(self, key: tuple) -> Self: @@ -718,12 +718,12 @@ def droplevel(self, level: int) -> None: if level < 0: level += self.nlevels - old_ncols = len(self._data) + old_ncols = len(self) self._data = { _remove_key_level(key, level): value # type: ignore[arg-type] for key, value in self._data.items() } - new_ncols = len(self._data) + new_ncols = len(self) self._level_names = ( self._level_names[:level] + self._level_names[level + 1 :] ) diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index d73ad8225ca..16b0aa95c35 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -176,7 +176,7 @@ def _can_downcast_to_series(self, df, arg): return False @_performance_tracking - def _downcast_to_series(self, df, arg): + def _downcast_to_series(self, df: DataFrame, arg): """ "Downcast" from a DataFrame to a Series based on Pandas indexing rules @@ -203,16 +203,16 @@ def _downcast_to_series(self, df, arg): # take series along the axis: if axis == 1: - return df[df._data.names[0]] + return df[df._column_names[0]] else: if df._num_columns > 0: dtypes = df.dtypes.values.tolist() normalized_dtype = np.result_type(*dtypes) - for name, col in df._data.items(): + for name, col in df._column_labels_and_values: df[name] = col.astype(normalized_dtype) sr = df.T - return sr[sr._data.names[0]] + return sr[sr._column_names[0]] class _DataFrameLocIndexer(_DataFrameIndexer): @@ -258,7 +258,7 @@ def _getitem_tuple_arg(self, arg): and len(arg) > 1 and is_scalar(arg[1]) ): - return result._data.columns[0].element_indexing(0) + return result._columns[0].element_indexing(0) return result else: if isinstance(arg[0], slice): @@ -310,7 +310,7 @@ def _getitem_tuple_arg(self, arg): else: tmp_col_name = str(uuid4()) cantor_name = "_" + "_".join( - map(str, columns_df._data.names) + map(str, columns_df._column_names) ) if columns_df._data.multiindex: # column names must be appropriate length tuples @@ -1412,7 +1412,7 @@ def __setitem__(self, arg, value): else column.column_empty_like( col, masked=True, newsize=length ) - for key, col in self._data.items() + for key, col in self._column_labels_and_values ) self._data = self._data._from_columns_like_self( new_columns, verify=False @@ -1494,8 +1494,8 @@ def __delitem__(self, name): @_performance_tracking def memory_usage(self, index=True, deep=False) -> cudf.Series: - mem_usage = [col.memory_usage for col in self._data.columns] - names = [str(name) for name in self._data.names] + mem_usage = [col.memory_usage for col in self._columns] + names = [str(name) for name in self._column_names] if index: mem_usage.append(self.index.memory_usage()) names.append("Index") @@ -1725,7 +1725,7 @@ def _concat( [] if are_all_range_index or (ignore_index and not empty_has_index) - else list(f.index._data.columns) + else list(f.index._columns) ) + [f._data[name] if name in f._data else None for name in names] for f in objs @@ -1808,7 +1808,7 @@ def _concat( out.index.dtype, cudf.CategoricalDtype ): out = out.set_index(out.index) - for name, col in out._data.items(): + for name, col in out._column_labels_and_values: out._data[name] = col._with_type_metadata( tables[0]._data[name].dtype ) @@ -1831,13 +1831,13 @@ def astype( errors: Literal["raise", "ignore"] = "raise", ): if is_dict_like(dtype): - if len(set(dtype.keys()) - set(self._data.names)) > 0: + if len(set(dtype.keys()) - set(self._column_names)) > 0: raise KeyError( "Only a column name can be used for the " "key in a dtype mappings argument." ) else: - dtype = {cc: dtype for cc in self._data.names} + dtype = {cc: dtype for cc in self._column_names} return super().astype(dtype, copy, errors) def _clean_renderable_dataframe(self, output): @@ -2601,7 +2601,7 @@ def equals(self, other) -> bool: # If all other checks matched, validate names. if ret: for self_name, other_name in zip( - self._data.names, other._data.names + self._column_names, other._column_names ): if self_name != other_name: ret = False @@ -2676,7 +2676,7 @@ def columns(self, columns): ) self._data = ColumnAccessor( - data=dict(zip(pd_columns, self._data.columns)), + data=dict(zip(pd_columns, self._columns)), multiindex=multiindex, level_names=level_names, label_dtype=label_dtype, @@ -2698,7 +2698,7 @@ def _set_columns_like(self, other: ColumnAccessor) -> None: f"got {len(self)} elements" ) self._data = ColumnAccessor( - data=dict(zip(other.names, self._data.columns)), + data=dict(zip(other.names, self._columns)), multiindex=other.multiindex, rangeindex=other.rangeindex, level_names=other.level_names, @@ -2983,7 +2983,7 @@ def set_index( elif isinstance(col, (MultiIndex, pd.MultiIndex)): if isinstance(col, pd.MultiIndex): col = MultiIndex.from_pandas(col) - data_to_add.extend(col._data.columns) + data_to_add.extend(col._columns) names.extend(col.names) elif isinstance( col, (cudf.Series, cudf.Index, pd.Series, pd.Index) @@ -3110,7 +3110,9 @@ def where(self, cond, other=None, inplace=False, axis=None, level=None): ) out = [] - for (name, col), other_col in zip(self._data.items(), other_cols): + for (name, col), other_col in zip( + self._column_labels_and_values, other_cols + ): source_col, other_col = _check_and_cast_columns_with_other( source_col=col, other=other_col, @@ -3314,7 +3316,7 @@ def _insert(self, loc, name, value, nan_as_null=None, ignore_index=True): column.column_empty_like( col_data, masked=True, newsize=length ) - for col_data in self._data.values() + for col_data in self._columns ), verify=False, ) @@ -3664,7 +3666,7 @@ def rename( name: col.find_and_replace( to_replace, vals, is_all_na ) - for name, col in self.index._data.items() + for name, col in self.index._column_labels_and_values } ) except OverflowError: @@ -3686,9 +3688,7 @@ def add_prefix(self, prefix, axis=None): raise NotImplementedError("axis is currently not implemented.") # TODO: Change to deep=False when copy-on-write is default out = self.copy(deep=True) - out.columns = [ - prefix + col_name for col_name in list(self._data.keys()) - ] + out.columns = [prefix + col_name for col_name in self._column_names] return out @_performance_tracking @@ -3697,9 +3697,7 @@ def add_suffix(self, suffix, axis=None): raise NotImplementedError("axis is currently not implemented.") # TODO: Change to deep=False when copy-on-write is default out = self.copy(deep=True) - out.columns = [ - col_name + suffix for col_name in list(self._data.keys()) - ] + out.columns = [col_name + suffix for col_name in self._column_names] return out @_performance_tracking @@ -4805,7 +4803,7 @@ def _func(x): # pragma: no cover # TODO: naive implementation # this could be written as a single kernel result = {} - for name, col in self._data.items(): + for name, col in self._column_labels_and_values: apply_sr = Series._from_column(col) result[name] = apply_sr.apply(_func)._column @@ -5444,7 +5442,7 @@ def to_pandas( out_index = self.index.to_pandas() out_data = { i: col.to_pandas(nullable=nullable, arrow_type=arrow_type) - for i, col in enumerate(self._data.columns) + for i, col in enumerate(self._columns) } out_df = pd.DataFrame(out_data, index=out_index) @@ -5665,14 +5663,16 @@ def to_arrow(self, preserve_index=None) -> pa.Table: index = index._as_int_index() index.name = "__index_level_0__" if isinstance(index, MultiIndex): - index_descr = list(index._data.names) + index_descr = index._column_names index_levels = index.levels else: index_descr = ( index.names if index.name is not None else ("index",) ) data = data.copy(deep=False) - for gen_name, col_name in zip(index_descr, index._data.names): + for gen_name, col_name in zip( + index_descr, index._column_names + ): data._insert( data.shape[1], gen_name, @@ -5681,7 +5681,7 @@ def to_arrow(self, preserve_index=None) -> pa.Table: out = super(DataFrame, data).to_arrow() metadata = pa.pandas_compat.construct_metadata( - columns_to_convert=[self[col] for col in self._data.names], + columns_to_convert=[self[col] for col in self._column_names], df=self, column_names=out.schema.names, index_levels=index_levels, @@ -5724,12 +5724,12 @@ def to_records(self, index=True, column_dtypes=None, index_dtypes=None): "column_dtypes is currently not supported." ) members = [("index", self.index.dtype)] if index else [] - members += [(col, self[col].dtype) for col in self._data.names] + members += list(self._dtypes) dtype = np.dtype(members) ret = np.recarray(len(self), dtype=dtype) if index: ret["index"] = self.index.to_numpy() - for col in self._data.names: + for col in self._column_names: ret[col] = self[col].to_numpy() return ret @@ -6059,7 +6059,7 @@ def quantile( ) if columns is None: - columns = data_df._data.names + columns = set(data_df._column_names) if isinstance(q, numbers.Number): q_is_number = True @@ -6084,7 +6084,7 @@ def quantile( # Ensure that qs is non-scalar so that we always get a column back. interpolation = interpolation or "linear" result = {} - for k in data_df._data.names: + for k in data_df._column_names: if k in columns: ser = data_df[k] res = ser.quantile( @@ -6198,7 +6198,7 @@ def make_false_column_like_self(): if isinstance(values, DataFrame) else {name: values._column for name in self._data} ) - for col, self_col in self._data.items(): + for col, self_col in self._column_labels_and_values: if col in other_cols: other_col = other_cols[col] self_is_cat = isinstance(self_col, CategoricalColumn) @@ -6231,13 +6231,13 @@ def make_false_column_like_self(): else: result[col] = make_false_column_like_self() elif is_dict_like(values): - for name, col in self._data.items(): + for name, col in self._column_labels_and_values: if name in values: result[name] = col.isin(values[name]) else: result[name] = make_false_column_like_self() elif is_list_like(values): - for name, col in self._data.items(): + for name, col in self._column_labels_and_values: result[name] = col.isin(values) else: raise TypeError( @@ -6292,7 +6292,7 @@ def _prepare_for_rowwise_op(self, method, skipna, numeric_only): name: filtered._data[name]._get_mask_as_column() if filtered._data[name].nullable else as_column(True, length=len(filtered._data[name])) - for name in filtered._data.names + for name in filtered._column_names } ) mask = mask.all(axis=1) @@ -6342,7 +6342,7 @@ def count(self, axis=0, numeric_only=False): length = len(self) return Series._from_column( as_column([length - col.null_count for col in self._columns]), - index=cudf.Index(self._data.names), + index=cudf.Index(self._column_names), ) _SUPPORT_AXIS_LOOKUP = { @@ -6409,7 +6409,7 @@ def _reduce( return source._apply_cupy_method_axis_1(op, **kwargs) else: axis_0_results = [] - for col_label, col in source._data.items(): + for col_label, col in source._column_labels_and_values: try: axis_0_results.append(getattr(col, op)(**kwargs)) except AttributeError as err: @@ -6634,7 +6634,7 @@ def _apply_cupy_method_axis_1(self, method, *args, **kwargs): prepared, mask, common_dtype = self._prepare_for_rowwise_op( method, skipna, numeric_only ) - for col in prepared._data.names: + for col in prepared._column_names: if prepared._data[col].nullable: prepared._data[col] = ( prepared._data[col] @@ -6820,7 +6820,7 @@ def select_dtypes(self, include=None, exclude=None): # remove all exclude types inclusion = inclusion - exclude_subtypes - for k, col in self._data.items(): + for k, col in self._column_labels_and_values: infered_type = cudf_dtype_from_pydata_dtype(col.dtype) if infered_type in inclusion: df._insert(len(df._data), k, col) @@ -7192,7 +7192,7 @@ def stack(self, level=-1, dropna=no_default, future_stack=False): # Compute the column indices that serves as the input for # `interleave_columns` column_idx_df = pd.DataFrame( - data=range(len(self._data)), index=named_levels + data=range(self._num_columns), index=named_levels ) column_indices: list[list[int]] = [] @@ -7392,17 +7392,17 @@ def to_struct(self, name=None): ----- Note: a copy of the columns is made. """ - if not all(isinstance(name, str) for name in self._data.names): + if not all(isinstance(name, str) for name in self._column_names): warnings.warn( "DataFrame contains non-string column name(s). Struct column " "requires field name to be string. Non-string column names " "will be casted to string as the field name." ) - fields = {str(name): col.dtype for name, col in self._data.items()} + fields = {str(name): dtype for name, dtype in self._dtypes} col = StructColumn( data=None, dtype=cudf.StructDtype(fields=fields), - children=tuple(col.copy(deep=True) for col in self._data.columns), + children=tuple(col.copy(deep=True) for col in self._columns), size=len(self), offset=0, ) @@ -7984,7 +7984,7 @@ def value_counts( diff = set(subset) - set(self._data) if len(diff) != 0: raise KeyError(f"columns {diff} do not exist") - columns = list(self._data.names) if subset is None else subset + columns = list(self._column_names) if subset is None else subset result = ( self.groupby( by=columns, @@ -8105,7 +8105,7 @@ def func(left, right, output): right._column_names ) elif _is_scalar_or_zero_d_array(right): - for name, col in output._data.items(): + for name, col in output._column_labels_and_values: output._data[name] = col.fillna(value) return output else: @@ -8387,7 +8387,7 @@ def extract_col(df, col): and col not in df.index._data and not isinstance(df.index, MultiIndex) ): - return df.index._data.columns[0] + return df.index._column return df.index._data[col] diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 7b2bc85b13b..98af006f6e5 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -75,8 +75,15 @@ def _columns(self) -> tuple[ColumnBase, ...]: return self._data.columns @property - def _dtypes(self) -> abc.Iterable: - return zip(self._data.names, (col.dtype for col in self._data.columns)) + def _column_labels_and_values( + self, + ) -> abc.Iterable[tuple[abc.Hashable, ColumnBase]]: + return zip(self._column_names, self._columns) + + @property + def _dtypes(self) -> abc.Generator[tuple[abc.Hashable, Dtype], None, None]: + for label, col in self._column_labels_and_values: + yield label, col.dtype @property def ndim(self) -> int: @@ -87,7 +94,7 @@ def serialize(self): # TODO: See if self._data can be serialized outright header = { "type-serialized": pickle.dumps(type(self)), - "column_names": pickle.dumps(tuple(self._data.names)), + "column_names": pickle.dumps(self._column_names), "column_rangeindex": pickle.dumps(self._data.rangeindex), "column_multiindex": pickle.dumps(self._data.multiindex), "column_label_dtype": pickle.dumps(self._data.label_dtype), @@ -156,7 +163,7 @@ def _mimic_inplace( self, result: Self, inplace: bool = False ) -> Self | None: if inplace: - for col in self._data: + for col in self._column_names: if col in result._data: self._data[col]._mimic_inplace( result._data[col], inplace=True @@ -267,7 +274,7 @@ def __len__(self) -> int: def astype(self, dtype: dict[Any, Dtype], copy: bool = False) -> Self: casted = ( col.astype(dtype.get(col_name, col.dtype), copy=copy) - for col_name, col in self._data.items() + for col_name, col in self._column_labels_and_values ) ca = self._data._from_columns_like_self(casted, verify=False) return self._from_data_like_self(ca) @@ -338,9 +345,7 @@ def equals(self, other) -> bool: return all( self_col.equals(other_col, check_dtypes=True) - for self_col, other_col in zip( - self._data.values(), other._data.values() - ) + for self_col, other_col in zip(self._columns, other._columns) ) @_performance_tracking @@ -434,11 +439,9 @@ def to_array( if dtype is None: if ncol == 1: - dtype = next(iter(self._data.values())).dtype + dtype = next(self._dtypes)[1] else: - dtype = find_common_type( - [col.dtype for col in self._data.values()] - ) + dtype = find_common_type([dtype for _, dtype in self._dtypes]) if not isinstance(dtype, numpy.dtype): raise NotImplementedError( @@ -446,12 +449,12 @@ def to_array( ) if self.ndim == 1: - return to_array(self._data.columns[0], dtype) + return to_array(self._columns[0], dtype) else: matrix = module.empty( shape=(len(self), ncol), dtype=dtype, order="F" ) - for i, col in enumerate(self._data.values()): + for i, col in enumerate(self._columns): # TODO: col.values may fail if there is nullable data or an # unsupported dtype. We may want to catch and provide a more # suitable error. @@ -751,7 +754,7 @@ def fillna( filled_columns = [ col.fillna(value[name], method) if name in value else col.copy() - for name, col in self._data.items() + for name, col in self._column_labels_and_values ] return self._mimic_inplace( @@ -988,7 +991,10 @@ def to_arrow(self): index: [[1,2,3]] """ return pa.Table.from_pydict( - {str(name): col.to_arrow() for name, col in self._data.items()} + { + str(name): col.to_arrow() + for name, col in self._column_labels_and_values + } ) @_performance_tracking @@ -1012,7 +1018,9 @@ def _copy_type_metadata(self: Self, other: Self) -> Self: See `ColumnBase._with_type_metadata` for more information. """ - for (name, col), (_, dtype) in zip(self._data.items(), other._dtypes): + for (name, col), (_, dtype) in zip( + self._column_labels_and_values, other._dtypes + ): self._data.set_by_label(name, col._with_type_metadata(dtype)) return self @@ -1422,7 +1430,7 @@ def _split(self, splits): """ return [ self._from_columns_like_self( - libcudf.copying.columns_split([*self._data.columns], splits)[ + libcudf.copying.columns_split(list(self._columns), splits)[ split_idx ], self._column_names, @@ -1432,7 +1440,7 @@ def _split(self, splits): @_performance_tracking def _encode(self): - columns, indices = libcudf.transform.table_encode([*self._columns]) + columns, indices = libcudf.transform.table_encode(list(self._columns)) keys = self._from_columns_like_self(columns) return keys, indices @@ -1578,7 +1586,7 @@ def __neg__(self): col.unary_operator("not") if col.dtype.kind == "b" else -1 * col - for col in self._data.columns + for col in self._columns ) ) ) @@ -1840,9 +1848,7 @@ def __copy__(self): def __invert__(self): """Bitwise invert (~) for integral dtypes, logical NOT for bools.""" return self._from_data_like_self( - self._data._from_columns_like_self( - (~col for col in self._data.columns) - ) + self._data._from_columns_like_self((~col for col in self._columns)) ) @_performance_tracking diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 6424c8af877..cb8cd0cd28b 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -751,10 +751,8 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): ) and not libgroupby._is_all_scan_aggregate(normalized_aggs): # Even with `sort=False`, pandas guarantees that # groupby preserves the order of rows within each group. - left_cols = list( - self.grouping.keys.drop_duplicates()._data.columns - ) - right_cols = list(result_index._data.columns) + left_cols = list(self.grouping.keys.drop_duplicates()._columns) + right_cols = list(result_index._columns) join_keys = [ _match_join_keys(lcol, rcol, "left") for lcol, rcol in zip(left_cols, right_cols) @@ -1483,7 +1481,7 @@ def _post_process_chunk_results( # the column name should be, especially if we applied # a nameless UDF. result = result.to_frame( - name=grouped_values._data.names[0] + name=grouped_values._column_names[0] ) else: index_data = group_keys._data.copy(deep=True) @@ -1632,7 +1630,7 @@ def mult(df): if func in {"sum", "product"}: # For `sum` & `product`, boolean types # will need to result in `int64` type. - for name, col in res._data.items(): + for name, col in res._column_labels_and_values: if col.dtype.kind == "b": res._data[name] = col.astype("int") return res @@ -2715,11 +2713,8 @@ class DataFrameGroupBy(GroupBy, GetAttrGetItemMixin): def _reduce_numeric_only(self, op: str): columns = list( name - for name in self.obj._data.names - if ( - is_numeric_dtype(self.obj._data[name].dtype) - and name not in self.grouping.names - ) + for name, dtype in self.obj._dtypes + if (is_numeric_dtype(dtype) and name not in self.grouping.names) ) return self[columns].agg(op) @@ -3209,7 +3204,7 @@ def values(self) -> cudf.core.frame.Frame: """ # If the key columns are in `obj`, filter them out value_column_names = [ - x for x in self._obj._data.names if x not in self._named_columns + x for x in self._obj._column_names if x not in self._named_columns ] value_columns = self._obj._data.select_by_label(value_column_names) return self._obj.__class__._from_data(value_columns) @@ -3224,8 +3219,8 @@ def _handle_series(self, by): self.names.append(by.name) def _handle_index(self, by): - self._key_columns.extend(by._data.columns) - self.names.extend(by._data.names) + self._key_columns.extend(by._columns) + self.names.extend(by._column_names) def _handle_mapping(self, by): by = cudf.Series(by.values(), index=by.keys()) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index b2bd20c4982..cd07c58c5d9 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -122,13 +122,13 @@ def _lexsorted_equal_range( sort_inds = None sort_vals = idx lower_bound = search_sorted( - [*sort_vals._data.columns], + list(sort_vals._columns), keys, side="left", ascending=sort_vals.is_monotonic_increasing, ).element_indexing(0) upper_bound = search_sorted( - [*sort_vals._data.columns], + list(sort_vals._columns), keys, side="right", ascending=sort_vals.is_monotonic_increasing, @@ -286,6 +286,20 @@ def name(self): def name(self, value): self._name = value + @property + @_performance_tracking + def _column_names(self) -> tuple[Any]: + return (self.name,) + + @property + @_performance_tracking + def _columns(self) -> tuple[ColumnBase]: + return (self._values,) + + @property + def _column_labels_and_values(self) -> Iterable: + return zip(self._column_names, self._columns) + @property # type: ignore @_performance_tracking def start(self) -> int: @@ -1068,7 +1082,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): else: inputs = { name: (col, None, False, None) - for name, col in self._data.items() + for name, col in self._column_labels_and_values } data = self._apply_cupy_ufunc_to_operands( diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index fd6bf37f0e6..810d4ad74e7 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -294,7 +294,7 @@ def _num_rows(self) -> int: @property def _index_names(self) -> tuple[Any, ...]: # TODO: Tuple[str]? - return self.index._data.names + return self.index._column_names @classmethod def _from_data( @@ -307,6 +307,7 @@ def _from_data( raise ValueError( f"index must be None or a cudf.Index not {type(index).__name__}" ) + # out._num_rows requires .index to be defined out._index = RangeIndex(out._data.nrows) if index is None else index return out @@ -882,7 +883,7 @@ def replace( columns_dtype_map=dict(self._dtypes), ) copy_data = [] - for name, col in self._data.items(): + for name, col in self._column_labels_and_values: try: replaced = col.find_and_replace( to_replace_per_column[name], @@ -2703,11 +2704,11 @@ def sort_index( by.extend( filter( lambda n: n not in handled, - self.index._data.names, + self.index._column_names, ) ) else: - by = list(idx._data.names) + by = list(idx._column_names) inds = idx._get_sorted_inds( by=by, ascending=ascending, na_position=na_position @@ -3013,7 +3014,7 @@ def _slice(self, arg: slice, keep_index: bool = True) -> Self: columns_to_slice = [ *( - self.index._data.columns + self.index._columns if keep_index and not has_range_index else [] ), @@ -3210,7 +3211,7 @@ def _empty_like(self, keep_index=True) -> Self: result = self._from_columns_like_self( libcudf.copying.columns_empty_like( [ - *(self.index._data.columns if keep_index else ()), + *(self.index._columns if keep_index else ()), *self._columns, ] ), @@ -3227,7 +3228,7 @@ def _split(self, splits, keep_index=True): columns_split = libcudf.copying.columns_split( [ - *(self.index._data.columns if keep_index else []), + *(self.index._columns if keep_index else []), *self._columns, ], splits, @@ -3763,8 +3764,8 @@ def _reindex( idx_dtype_match = (df.index.nlevels == index.nlevels) and all( _is_same_dtype(left_dtype, right_dtype) for left_dtype, right_dtype in zip( - (col.dtype for col in df.index._data.columns), - (col.dtype for col in index._data.columns), + (dtype for _, dtype in df.index._dtypes), + (dtype for _, dtype in index._dtypes), ) ) @@ -3783,7 +3784,7 @@ def _reindex( (name or 0) if isinstance(self, cudf.Series) else name: col - for name, col in df._data.items() + for name, col in df._column_labels_and_values }, index=df.index, ) @@ -3794,7 +3795,7 @@ def _reindex( index = index if index is not None else df.index if column_names is None: - names = list(df._data.names) + names = list(df._column_names) level_names = self._data.level_names multiindex = self._data.multiindex rangeindex = self._data.rangeindex @@ -3948,7 +3949,7 @@ def round(self, decimals=0, how="half_even"): col.round(decimals[name], how=how) if name in decimals and col.dtype.kind in "fiu" else col.copy(deep=True) - for name, col in self._data.items() + for name, col in self._column_labels_and_values ) return self._from_data_like_self( self._data._from_columns_like_self(cols) @@ -4270,7 +4271,7 @@ def _drop_na_columns(self, how="any", subset=None, thresh=None): else: thresh = len(df) - for name, col in df._data.items(): + for name, col in df._column_labels_and_values: check_col = col.nans_to_nulls() no_threshold_valid_count = ( len(col) - check_col.null_count @@ -4305,7 +4306,7 @@ def _drop_na_rows(self, how="any", subset=None, thresh=None): return self._from_columns_like_self( libcudf.stream_compaction.drop_nulls( - [*self.index._data.columns, *data_columns], + [*self.index._columns, *data_columns], how=how, keys=self._positions_from_column_names( subset, offset_by_index_columns=True @@ -4853,7 +4854,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # This works for Index too inputs = { name: (col, None, False, None) - for name, col in self._data.items() + for name, col in self._column_labels_and_values } index = self.index @@ -4933,7 +4934,7 @@ def repeat(self, repeats, axis=None): """ res = self._from_columns_like_self( Frame._repeat( - [*self.index._data.columns, *self._columns], repeats, axis + [*self.index._columns, *self._columns], repeats, axis ), self._column_names, self._index_names, @@ -6224,7 +6225,7 @@ def _preprocess_subset(self, subset): not np.iterable(subset) or isinstance(subset, str) or isinstance(subset, tuple) - and subset in self._data.names + and subset in self._column_names ): subset = (subset,) diff = set(subset) - set(self._data) @@ -6306,8 +6307,8 @@ def rank( ) numeric_cols = ( name - for name in self._data.names - if _is_non_decimal_numeric_dtype(self._data[name]) + for name, dtype in self._dtypes + if _is_non_decimal_numeric_dtype(dtype) ) source = self._get_columns_by_label(numeric_cols) if source.empty: diff --git a/python/cudf/cudf/core/join/join.py b/python/cudf/cudf/core/join/join.py index b65bc7af832..cfeaca00888 100644 --- a/python/cudf/cudf/core/join/join.py +++ b/python/cudf/cudf/core/join/join.py @@ -140,11 +140,15 @@ def __init__( # right_on. self._using_left_index = bool(left_index) left_on = ( - lhs.index._data.names if left_index else left_on if left_on else on + lhs.index._column_names + if left_index + else left_on + if left_on + else on ) self._using_right_index = bool(right_index) right_on = ( - rhs.index._data.names + rhs.index._column_names if right_index else right_on if right_on @@ -334,18 +338,18 @@ def _merge_results( # All columns from the left table make it into the output. Non-key # columns that share a name with a column in the right table are # suffixed with the provided suffix. - common_names = set(left_result._data.names) & set( - right_result._data.names + common_names = set(left_result._column_names) & set( + right_result._column_names ) cols_to_suffix = common_names - self._key_columns_with_same_name data = { (f"{name}{self.lsuffix}" if name in cols_to_suffix else name): col - for name, col in left_result._data.items() + for name, col in left_result._column_labels_and_values } # The right table follows the same rule as the left table except that # key columns from the right table are removed. - for name, col in right_result._data.items(): + for name, col in right_result._column_labels_and_values: if name in common_names: if name not in self._key_columns_with_same_name: data[f"{name}{self.rsuffix}"] = col @@ -399,7 +403,7 @@ def _sort_result(self, result: cudf.DataFrame) -> cudf.DataFrame: # producing the input result. by: list[Any] = [] if self._using_left_index and self._using_right_index: - by.extend(result.index._data.columns) + by.extend(result.index._columns) if not self._using_left_index: by.extend([result._data[col.name] for col in self._left_keys]) if not self._using_right_index: diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index b86ad38c944..6de3981ba66 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -233,8 +233,8 @@ def names(self, value): # to unexpected behavior in some cases. This is # definitely buggy, but we can't disallow non-unique # names either... - self._data = self._data.__class__( - dict(zip(value, self._data.values())), + self._data = type(self._data)( + dict(zip(value, self._columns)), level_names=self._data.level_names, verify=False, ) @@ -693,19 +693,25 @@ def where(self, cond, other=None, inplace=False): @_performance_tracking def _compute_validity_mask(self, index, row_tuple, max_length): """Computes the valid set of indices of values in the lookup""" - lookup = cudf.DataFrame() + lookup_dict = {} for i, row in enumerate(row_tuple): if isinstance(row, slice) and row == slice(None): continue - lookup[i] = cudf.Series(row) - frame = cudf.DataFrame(dict(enumerate(index._data.columns))) + lookup_dict[i] = row + lookup = cudf.DataFrame(lookup_dict) + frame = cudf.DataFrame._from_data( + ColumnAccessor(dict(enumerate(index._columns)), verify=False) + ) with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) data_table = cudf.concat( [ frame, cudf.DataFrame._from_data( - {"idx": column.as_column(range(len(frame)))} + ColumnAccessor( + {"idx": column.as_column(range(len(frame)))}, + verify=False, + ) ), ], axis=1, @@ -716,7 +722,7 @@ def _compute_validity_mask(self, index, row_tuple, max_length): # TODO: Remove this after merge/join # obtain deterministic ordering. if cudf.get_option("mode.pandas_compatible"): - lookup_order = "_" + "_".join(map(str, lookup._data.names)) + lookup_order = "_" + "_".join(map(str, lookup._column_names)) lookup[lookup_order] = column.as_column(range(len(lookup))) postprocess = operator.methodcaller( "sort_values", by=[lookup_order, "idx"] @@ -784,7 +790,7 @@ def _index_and_downcast(self, result, index, index_key): out_index.insert( out_index._num_columns, k, - cudf.Series._from_column(index._data.columns[k]), + cudf.Series._from_column(index._columns[k]), ) # determine if we should downcast from a DataFrame to a Series @@ -800,19 +806,19 @@ def _index_and_downcast(self, result, index, index_key): ) if need_downcast: result = result.T - return result[result._data.names[0]] + return result[result._column_names[0]] if len(result) == 0 and not slice_access: # Pandas returns an empty Series with a tuple as name # the one expected result column result = cudf.Series._from_data( - {}, name=tuple(col[0] for col in index._data.columns) + {}, name=tuple(col[0] for col in index._columns) ) elif out_index._num_columns == 1: # If there's only one column remaining in the output index, convert # it into an Index and name the final index values according # to that column's name. - *_, last_column = index._data.columns + last_column = index._columns[-1] out_index = cudf.Index._from_column( last_column, name=index.names[-1] ) @@ -894,7 +900,7 @@ def __eq__(self, other): [ self_col.equals(other_col) for self_col, other_col in zip( - self._data.values(), other._data.values() + self._columns, other._columns ) ] ) @@ -1475,10 +1481,10 @@ def swaplevel(self, i=-2, j=-1) -> Self: ('aa', 'b')], ) """ - name_i = self._data.names[i] if isinstance(i, int) else i - name_j = self._data.names[j] if isinstance(j, int) else j + name_i = self._column_names[i] if isinstance(i, int) else i + name_j = self._column_names[j] if isinstance(j, int) else j new_data = {} - for k, v in self._data.items(): + for k, v in self._column_labels_and_values: if k not in (name_i, name_j): new_data[k] = v elif k == name_i: @@ -1916,7 +1922,7 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None): join_keys = [ _match_join_keys(lcol, rcol, "inner") - for lcol, rcol in zip(target._data.columns, self._data.columns) + for lcol, rcol in zip(target._columns, self._columns) ] join_keys = map(list, zip(*join_keys)) scatter_map, indices = libcudf.join.join( @@ -2113,7 +2119,7 @@ def _split_columns_by_levels( lv if isinstance(lv, int) else level_names.index(lv) for lv in levels } - for i, (name, col) in enumerate(zip(self.names, self._data.columns)): + for i, (name, col) in enumerate(zip(self.names, self._columns)): if in_levels and i in level_indices: name = f"level_{i}" if name is None else name yield name, col @@ -2154,9 +2160,7 @@ def _columns_for_reset_index( ) -> Generator[tuple[Any, column.ColumnBase], None, None]: """Return the columns and column names for .reset_index""" if levels is None: - for i, (col, name) in enumerate( - zip(self._data.columns, self.names) - ): + for i, (col, name) in enumerate(zip(self._columns, self.names)): yield f"level_{i}" if name is None else name, col else: yield from self._split_columns_by_levels(levels, in_levels=True) diff --git a/python/cudf/cudf/core/reshape.py b/python/cudf/cudf/core/reshape.py index c951db00c9a..401fef67ee6 100644 --- a/python/cudf/cudf/core/reshape.py +++ b/python/cudf/cudf/core/reshape.py @@ -410,7 +410,7 @@ def concat( result_columns = None if keys_objs is None: for o in objs: - for name, col in o._data.items(): + for name, col in o._column_labels_and_values: if name in result_data: raise NotImplementedError( f"A Column with duplicate name found: {name}, cuDF " @@ -438,7 +438,7 @@ def concat( else: # All levels in the multiindex label must have the same type has_multiple_level_types = ( - len({type(name) for o in objs for name in o._data.keys()}) > 1 + len({type(name) for o in objs for name in o._column_names}) > 1 ) if has_multiple_level_types: raise NotImplementedError( @@ -447,7 +447,7 @@ def concat( "the labels to the same type." ) for k, o in zip(keys_objs, objs): - for name, col in o._data.items(): + for name, col in o._column_labels_and_values: # if only series, then only keep keys_objs as column labels # if the existing column is multiindex, prepend it # to handle cases where dfs and srs are concatenated @@ -843,7 +843,7 @@ def get_dummies( else: result_data = { col_name: col - for col_name, col in data._data.items() + for col_name, col in data._column_labels_and_values if col_name not in columns } @@ -943,7 +943,7 @@ def _merge_sorted( columns = [ [ - *(obj.index._data.columns if not ignore_index else ()), + *(obj.index._columns if not ignore_index else ()), *obj._columns, ] for obj in objs @@ -985,7 +985,7 @@ def as_tuple(x): return x if isinstance(x, tuple) else (x,) nrows = len(index_labels) - for col_label, col in df._data.items(): + for col_label, col in df._column_labels_and_values: names = [ as_tuple(col_label) + as_tuple(name) for name in column_labels ] @@ -1009,7 +1009,7 @@ def as_tuple(x): ca = ColumnAccessor( result, multiindex=True, - level_names=(None,) + columns._data.names, + level_names=(None,) + columns._column_names, verify=False, ) return cudf.DataFrame._from_data( @@ -1087,11 +1087,7 @@ def pivot(data, columns=None, index=no_default, values=no_default): # Create a DataFrame composed of columns from both # columns and index ca = ColumnAccessor( - dict( - enumerate( - itertools.chain(index._data.columns, columns._data.columns) - ) - ), + dict(enumerate(itertools.chain(index._columns, columns._columns))), verify=False, ) columns_index = cudf.DataFrame._from_data(ca) @@ -1560,7 +1556,7 @@ def pivot_table( if values_passed and not values_multi and table._data.multiindex: column_names = table._data.level_names[1:] table_columns = tuple( - map(lambda column: column[1:], table._data.names) + map(lambda column: column[1:], table._column_names) ) table.columns = pd.MultiIndex.from_tuples( tuples=table_columns, names=column_names diff --git a/python/cudf/cudf/core/tools/datetimes.py b/python/cudf/cudf/core/tools/datetimes.py index 7197560b5a4..68f34fa28ff 100644 --- a/python/cudf/cudf/core/tools/datetimes.py +++ b/python/cudf/cudf/core/tools/datetimes.py @@ -186,7 +186,7 @@ def to_datetime( if isinstance(arg, cudf.DataFrame): # we require at least Ymd required = ["year", "month", "day"] - req = list(set(required) - set(arg._data.names)) + req = list(set(required) - set(arg._column_names)) if len(req): err_req = ",".join(req) raise ValueError( @@ -196,7 +196,7 @@ def to_datetime( ) # replace passed column name with values in _unit_map - got_units = {k: get_units(k) for k in arg._data.names} + got_units = {k: get_units(k) for k in arg._column_names} unit_rev = {v: k for k, v in got_units.items()} # keys we don't recognize diff --git a/python/cudf/cudf/core/udf/groupby_utils.py b/python/cudf/cudf/core/udf/groupby_utils.py index 265b87350ae..3af662b62ea 100644 --- a/python/cudf/cudf/core/udf/groupby_utils.py +++ b/python/cudf/cudf/core/udf/groupby_utils.py @@ -210,7 +210,7 @@ def _can_be_jitted(frame, func, args): # See https://github.com/numba/numba/issues/4587 return False - if any(col.has_nulls() for col in frame._data.values()): + if any(col.has_nulls() for col in frame._columns): return False np_field_types = np.dtype( list( diff --git a/python/cudf/cudf/core/udf/utils.py b/python/cudf/cudf/core/udf/utils.py index 6d7362952c9..bfe716f0afc 100644 --- a/python/cudf/cudf/core/udf/utils.py +++ b/python/cudf/cudf/core/udf/utils.py @@ -126,25 +126,23 @@ def _get_udf_return_type(argty, func: Callable, args=()): def _all_dtypes_from_frame(frame, supported_types=JIT_SUPPORTED_TYPES): return { - colname: col.dtype - if str(col.dtype) in supported_types - else np.dtype("O") - for colname, col in frame._data.items() + colname: dtype if str(dtype) in supported_types else np.dtype("O") + for colname, dtype in frame._dtypes } def _supported_dtypes_from_frame(frame, supported_types=JIT_SUPPORTED_TYPES): return { - colname: col.dtype - for colname, col in frame._data.items() - if str(col.dtype) in supported_types + colname: dtype + for colname, dtype in frame._dtypes + if str(dtype) in supported_types } def _supported_cols_from_frame(frame, supported_types=JIT_SUPPORTED_TYPES): return { colname: col - for colname, col in frame._data.items() + for colname, col in frame._column_labels_and_values if str(col.dtype) in supported_types } @@ -232,8 +230,8 @@ def _generate_cache_key(frame, func: Callable, args, suffix="__APPLY_UDF"): *cudautils.make_cache_key( func, tuple(_all_dtypes_from_frame(frame).values()) ), - *(col.mask is None for col in frame._data.values()), - *frame._data.keys(), + *(col.mask is None for col in frame._columns), + *frame._column_names, scalar_argtypes, suffix, ) diff --git a/python/cudf/cudf/io/csv.py b/python/cudf/cudf/io/csv.py index a9c20150930..3dc8915bfd1 100644 --- a/python/cudf/cudf/io/csv.py +++ b/python/cudf/cudf/io/csv.py @@ -186,13 +186,13 @@ def to_csv( "Dataframe doesn't have the labels provided in columns" ) - for col in df._data.columns: - if isinstance(col, cudf.core.column.ListColumn): + for _, dtype in df._dtypes: + if isinstance(dtype, cudf.ListDtype): raise NotImplementedError( "Writing to csv format is not yet supported with " "list columns." ) - elif isinstance(col, cudf.core.column.StructColumn): + elif isinstance(dtype, cudf.StructDtype): raise NotImplementedError( "Writing to csv format is not yet supported with " "Struct columns." @@ -203,12 +203,11 @@ def to_csv( # workaround once following issue is fixed: # https://github.com/rapidsai/cudf/issues/6661 if any( - isinstance(col, cudf.core.column.CategoricalColumn) - for col in df._data.columns + isinstance(dtype, cudf.CategoricalDtype) for _, dtype in df._dtypes ) or isinstance(df.index, cudf.CategoricalIndex): df = df.copy(deep=False) - for col_name, col in df._data.items(): - if isinstance(col, cudf.core.column.CategoricalColumn): + for col_name, col in df._column_labels_and_values: + if isinstance(col.dtype, cudf.CategoricalDtype): df._data[col_name] = col.astype(col.categories.dtype) if isinstance(df.index, cudf.CategoricalIndex): diff --git a/python/cudf/cudf/io/dlpack.py b/python/cudf/cudf/io/dlpack.py index 1347b2cc38f..fe8e446f9c0 100644 --- a/python/cudf/cudf/io/dlpack.py +++ b/python/cudf/cudf/io/dlpack.py @@ -79,13 +79,13 @@ def to_dlpack(cudf_obj): ) if any( - not cudf.api.types._is_non_decimal_numeric_dtype(col.dtype) - for col in gdf._data.columns + not cudf.api.types._is_non_decimal_numeric_dtype(dtype) + for _, dtype in gdf._dtypes ): raise TypeError("non-numeric data not yet supported") dtype = cudf.utils.dtypes.find_common_type( - [col.dtype for col in gdf._data.columns] + [dtype for _, dtype in gdf._dtypes] ) gdf = gdf.astype(dtype) diff --git a/python/cudf/cudf/io/orc.py b/python/cudf/cudf/io/orc.py index fd246c6215f..c54293badbe 100644 --- a/python/cudf/cudf/io/orc.py +++ b/python/cudf/cudf/io/orc.py @@ -396,8 +396,8 @@ def to_orc( ): """{docstring}""" - for col in df._data.columns: - if isinstance(col, cudf.core.column.CategoricalColumn): + for _, dtype in df._dtypes: + if isinstance(dtype, cudf.CategoricalDtype): raise NotImplementedError( "Writing to ORC format is not yet supported with " "Categorical columns." diff --git a/python/cudf/cudf/testing/testing.py b/python/cudf/cudf/testing/testing.py index 31ad24a4664..668e7a77454 100644 --- a/python/cudf/cudf/testing/testing.py +++ b/python/cudf/cudf/testing/testing.py @@ -676,7 +676,7 @@ def assert_frame_equal( if check_like: left, right = left.reindex(index=right.index), right - right = right[list(left._data.names)] + right = right[list(left._column_names)] # index comparison assert_index_equal( diff --git a/python/cudf/cudf/tests/test_multiindex.py b/python/cudf/cudf/tests/test_multiindex.py index b1e095e8853..c41be3e4428 100644 --- a/python/cudf/cudf/tests/test_multiindex.py +++ b/python/cudf/cudf/tests/test_multiindex.py @@ -813,8 +813,8 @@ def test_multiindex_copy_deep(data, copy_on_write, deep): mi1 = gdf.groupby(["Date", "Symbol"]).mean().index mi2 = mi1.copy(deep=deep) - lchildren = [col.children for _, col in mi1._data.items()] - rchildren = [col.children for _, col in mi2._data.items()] + lchildren = [col.children for col in mi1._columns] + rchildren = [col.children for col in mi2._columns] # Flatten lchildren = reduce(operator.add, lchildren) @@ -849,12 +849,8 @@ def test_multiindex_copy_deep(data, copy_on_write, deep): assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) # Assert ._data identity - lptrs = [ - d.base_data.get_ptr(mode="read") for _, d in mi1._data.items() - ] - rptrs = [ - d.base_data.get_ptr(mode="read") for _, d in mi2._data.items() - ] + lptrs = [d.base_data.get_ptr(mode="read") for d in mi1._columns] + rptrs = [d.base_data.get_ptr(mode="read") for d in mi2._columns] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) cudf.set_option("copy_on_write", original_cow_setting)