From 219ec0e7fbff37d7387d25e93510b55a8782e2bf Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 8 Oct 2024 13:40:40 +0100 Subject: [PATCH] Expunge NamedColumn (#16962) Everything in the expression evaluation now operates on columns without names. DataFrame construction takes either a mapping from string-valued names to columns, or a sequence of pairs of names and columns. This removes some duplicate code in the NamedColumn class (by removing it) where we had to fight the inheritance hierarchy. - Closes #16272 Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Matthew Murray (https://github.com/Matt711) URL: https://github.com/rapidsai/cudf/pull/16962 --- .../cudf_polars/containers/__init__.py | 4 +- .../cudf_polars/containers/column.py | 110 ++++++---------- .../cudf_polars/containers/dataframe.py | 111 ++++++++--------- python/cudf_polars/cudf_polars/dsl/expr.py | 19 ++- python/cudf_polars/cudf_polars/dsl/ir.py | 117 +++++++++--------- python/cudf_polars/docs/overview.md | 18 +-- .../tests/containers/test_column.py | 9 +- .../tests/containers/test_dataframe.py | 39 +++--- .../tests/expressions/test_sort.py | 2 +- .../cudf_polars/tests/utils/test_broadcast.py | 20 +-- 10 files changed, 209 insertions(+), 240 deletions(-) diff --git a/python/cudf_polars/cudf_polars/containers/__init__.py b/python/cudf_polars/cudf_polars/containers/__init__.py index 06bb08953f1..3b1eff4a0d0 100644 --- a/python/cudf_polars/cudf_polars/containers/__init__.py +++ b/python/cudf_polars/cudf_polars/containers/__init__.py @@ -5,7 +5,7 @@ from __future__ import annotations -__all__: list[str] = ["DataFrame", "Column", "NamedColumn"] +__all__: list[str] = ["DataFrame", "Column"] -from cudf_polars.containers.column import Column, NamedColumn +from cudf_polars.containers.column import Column from cudf_polars.containers.dataframe import DataFrame diff --git a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index 3fe3e5557cb..00186098e54 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ b/python/cudf_polars/cudf_polars/containers/column.py @@ -15,7 +15,7 @@ import polars as pl -__all__: list[str] = ["Column", "NamedColumn"] +__all__: list[str] = ["Column"] class Column: @@ -26,6 +26,9 @@ class Column: order: plc.types.Order null_order: plc.types.NullOrder is_scalar: bool + # Optional name, only ever set by evaluation of NamedExpr nodes + # The internal evaluation should not care about the name. + name: str | None def __init__( self, @@ -34,14 +37,12 @@ def __init__( is_sorted: plc.types.Sorted = plc.types.Sorted.NO, order: plc.types.Order = plc.types.Order.ASCENDING, null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE, + name: str | None = None, ): self.obj = column self.is_scalar = self.obj.size() == 1 - if self.obj.size() <= 1: - is_sorted = plc.types.Sorted.YES - self.is_sorted = is_sorted - self.order = order - self.null_order = null_order + self.name = name + self.set_sorted(is_sorted=is_sorted, order=order, null_order=null_order) @functools.cached_property def obj_scalar(self) -> plc.Scalar: @@ -63,9 +64,26 @@ def obj_scalar(self) -> plc.Scalar: ) return plc.copying.get_element(self.obj, 0) + def rename(self, name: str | None, /) -> Self: + """ + Return a shallow copy with a new name. + + Parameters + ---------- + name + New name + + Returns + ------- + Shallow copy of self with new name set. + """ + new = self.copy() + new.name = name + return new + def sorted_like(self, like: Column, /) -> Self: """ - Copy sortedness properties from a column onto self. + Return a shallow copy with sortedness from like. Parameters ---------- @@ -74,20 +92,23 @@ def sorted_like(self, like: Column, /) -> Self: Returns ------- - Self with metadata set. + Shallow copy of self with metadata set. See Also -------- set_sorted, copy_metadata """ - return self.set_sorted( - is_sorted=like.is_sorted, order=like.order, null_order=like.null_order + return type(self)( + self.obj, + name=self.name, + is_sorted=like.is_sorted, + order=like.order, + null_order=like.null_order, ) - # TODO: Return Column once #16272 is fixed. - def astype(self, dtype: plc.DataType) -> plc.Column: + def astype(self, dtype: plc.DataType) -> Column: """ - Return the backing column as the requested dtype. + Cast the column to as the requested dtype. Parameters ---------- @@ -109,8 +130,10 @@ def astype(self, dtype: plc.DataType) -> plc.Column: the current one. """ if self.obj.type() != dtype: - return plc.unary.cast(self.obj, dtype) - return self.obj + return Column(plc.unary.cast(self.obj, dtype), name=self.name).sorted_like( + self + ) + return self def copy_metadata(self, from_: pl.Series, /) -> Self: """ @@ -129,6 +152,7 @@ def copy_metadata(self, from_: pl.Series, /) -> Self: -------- set_sorted, sorted_like """ + self.name = from_.name if len(from_) <= 1: return self ascending = from_.flags["SORTED_ASC"] @@ -192,6 +216,7 @@ def copy(self) -> Self: is_sorted=self.is_sorted, order=self.order, null_order=self.null_order, + name=self.name, ) def mask_nans(self) -> Self: @@ -217,58 +242,3 @@ def nan_count(self) -> int: ) ).as_py() return 0 - - -class NamedColumn(Column): - """A column with a name.""" - - name: str - - def __init__( - self, - column: plc.Column, - name: str, - *, - is_sorted: plc.types.Sorted = plc.types.Sorted.NO, - order: plc.types.Order = plc.types.Order.ASCENDING, - null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE, - ) -> None: - super().__init__( - column, is_sorted=is_sorted, order=order, null_order=null_order - ) - self.name = name - - def copy(self, *, new_name: str | None = None) -> Self: - """ - A shallow copy of the column. - - Parameters - ---------- - new_name - Optional new name for the copied column. - - Returns - ------- - New column sharing data with self. - """ - return type(self)( - self.obj, - self.name if new_name is None else new_name, - is_sorted=self.is_sorted, - order=self.order, - null_order=self.null_order, - ) - - def mask_nans(self) -> Self: - """Return a shallow copy of self with nans masked out.""" - # Annoying, the inheritance is not right (can't call the - # super-type mask_nans), but will sort that by refactoring - # later. - if plc.traits.is_floating_point(self.obj.type()): - old_count = self.obj.null_count() - mask, new_count = plc.transform.nans_to_nulls(self.obj) - result = type(self)(self.obj.with_mask(mask, new_count), self.name) - if old_count == new_count: - return result.sorted_like(self) - return result - return self.copy() diff --git a/python/cudf_polars/cudf_polars/containers/dataframe.py b/python/cudf_polars/cudf_polars/containers/dataframe.py index f3e3862d0cc..2c195f6637c 100644 --- a/python/cudf_polars/cudf_polars/containers/dataframe.py +++ b/python/cudf_polars/cudf_polars/containers/dataframe.py @@ -5,43 +5,50 @@ from __future__ import annotations -import itertools from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import pyarrow as pa import pylibcudf as plc import polars as pl -from cudf_polars.containers.column import NamedColumn +from cudf_polars.containers import Column from cudf_polars.utils import dtypes if TYPE_CHECKING: - from collections.abc import Mapping, Sequence, Set + from collections.abc import Iterable, Mapping, Sequence, Set from typing_extensions import Self - from cudf_polars.containers import Column - __all__: list[str] = ["DataFrame"] +# Pacify the type checker. DataFrame init asserts that all the columns +# have a string name, so let's narrow the type. +class NamedColumn(Column): + name: str + + class DataFrame: """A representation of a dataframe.""" - columns: list[NamedColumn] + column_map: dict[str, Column] table: plc.Table + columns: list[NamedColumn] - def __init__(self, columns: Sequence[NamedColumn]) -> None: - self.columns = list(columns) - self._column_map = {c.name: c for c in self.columns} - self.table = plc.Table([c.obj for c in columns]) + def __init__(self, columns: Iterable[Column]) -> None: + columns = list(columns) + if any(c.name is None for c in columns): + raise ValueError("All columns must have a name") + self.columns = [cast(NamedColumn, c) for c in columns] + self.column_map = {c.name: c for c in self.columns} + self.table = plc.Table([c.obj for c in self.columns]) def copy(self) -> Self: """Return a shallow copy of self.""" - return type(self)([c.copy() for c in self.columns]) + return type(self)(c.copy() for c in self.columns) def to_polars(self) -> pl.DataFrame: """Convert to a polars DataFrame.""" @@ -51,42 +58,38 @@ def to_polars(self) -> pl.DataFrame: # https://github.com/pola-rs/polars/issues/11632 # To guarantee we produce correct names, we therefore # serialise with names we control and rename with that map. - name_map = {f"column_{i}": c.name for i, c in enumerate(self.columns)} + name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)} table: pa.Table = plc.interop.to_arrow( self.table, [plc.interop.ColumnMetadata(name=name) for name in name_map], ) df: pl.DataFrame = pl.from_arrow(table) return df.rename(name_map).with_columns( - *( - pl.col(c.name).set_sorted( - descending=c.order == plc.types.Order.DESCENDING - ) - if c.is_sorted - else pl.col(c.name) - for c in self.columns - ) + pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING) + if c.is_sorted + else pl.col(c.name) + for c in self.columns ) @cached_property def column_names_set(self) -> frozenset[str]: """Return the column names as a set.""" - return frozenset(c.name for c in self.columns) + return frozenset(self.column_map) @cached_property def column_names(self) -> list[str]: """Return a list of the column names.""" - return [c.name for c in self.columns] + return list(self.column_map) @cached_property def num_columns(self) -> int: """Number of columns.""" - return len(self.columns) + return len(self.column_map) @cached_property def num_rows(self) -> int: """Number of rows.""" - return 0 if len(self.columns) == 0 else self.table.num_rows() + return self.table.num_rows() if self.column_map else 0 @classmethod def from_polars(cls, df: pl.DataFrame) -> Self: @@ -111,12 +114,8 @@ def from_polars(cls, df: pl.DataFrame) -> Self: # No-op if the schema is unchanged. d_table = plc.interop.from_arrow(table.cast(schema)) return cls( - [ - NamedColumn(column, h_col.name).copy_metadata(h_col) - for column, h_col in zip( - d_table.columns(), df.iter_columns(), strict=True - ) - ] + Column(column).copy_metadata(h_col) + for column, h_col in zip(d_table.columns(), df.iter_columns(), strict=True) ) @classmethod @@ -144,17 +143,14 @@ def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self: if table.num_columns() != len(names): raise ValueError("Mismatching name and table length.") return cls( - [ - NamedColumn(c, name) - for c, name in zip(table.columns(), names, strict=True) - ] + Column(c, name=name) for c, name in zip(table.columns(), names, strict=True) ) def sorted_like( self, like: DataFrame, /, *, subset: Set[str] | None = None ) -> Self: """ - Copy sortedness from a dataframe onto self. + Return a shallow copy with sortedness copied from like. Parameters ---------- @@ -165,7 +161,7 @@ def sorted_like( Returns ------- - Self with metadata set. + Shallow copy of self with metadata set. Raises ------ @@ -175,13 +171,12 @@ def sorted_like( if like.column_names != self.column_names: raise ValueError("Can only copy from identically named frame") subset = self.column_names_set if subset is None else subset - self.columns = [ + return type(self)( c.sorted_like(other) if c.name in subset else c for c, other in zip(self.columns, like.columns, strict=True) - ] - return self + ) - def with_columns(self, columns: Sequence[NamedColumn]) -> Self: + def with_columns(self, columns: Iterable[Column], *, replace_only=False) -> Self: """ Return a new dataframe with extra columns. @@ -189,6 +184,8 @@ def with_columns(self, columns: Sequence[NamedColumn]) -> Self: ---------- columns Columns to add + replace_only + If true, then only replacements are allowed (matching by name). Returns ------- @@ -196,36 +193,30 @@ def with_columns(self, columns: Sequence[NamedColumn]) -> Self: Notes ----- - If column names overlap, newer names replace older ones. + If column names overlap, newer names replace older ones, and + appear in the same order as the original frame. """ - columns = list( - {c.name: c for c in itertools.chain(self.columns, columns)}.values() - ) - return type(self)(columns) + new = {c.name: c for c in columns} + if replace_only and not self.column_names_set.issuperset(new.keys()): + raise ValueError("Cannot replace with non-existing names") + return type(self)((self.column_map | new).values()) def discard_columns(self, names: Set[str]) -> Self: """Drop columns by name.""" - return type(self)([c for c in self.columns if c.name not in names]) + return type(self)(column for column in self.columns if column.name not in names) def select(self, names: Sequence[str]) -> Self: """Select columns by name returning DataFrame.""" - want = set(names) - if not want.issubset(self.column_names_set): - raise ValueError("Can't select missing names") - return type(self)([self._column_map[name] for name in names]) - - def replace_columns(self, *columns: NamedColumn) -> Self: - """Return a new dataframe with columns replaced by name.""" - new = {c.name: c for c in columns} - if not set(new).issubset(self.column_names_set): - raise ValueError("Cannot replace with non-existing names") - return type(self)([new.get(c.name, c) for c in self.columns]) + try: + return type(self)(self.column_map[name] for name in names) + except KeyError as e: + raise ValueError("Can't select missing names") from e def rename_columns(self, mapping: Mapping[str, str]) -> Self: """Rename some columns.""" - return type(self)([c.copy(new_name=mapping.get(c.name)) for c in self.columns]) + return type(self)(c.rename(mapping.get(c.name, c.name)) for c in self.columns) - def select_columns(self, names: Set[str]) -> list[NamedColumn]: + def select_columns(self, names: Set[str]) -> list[Column]: """Select columns by name.""" return [c for c in self.columns if c.name in names] diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index a418560b31c..f7775ceb238 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -27,7 +27,7 @@ from polars.exceptions import InvalidOperationError from polars.polars import _expr_nodes as pl_expr -from cudf_polars.containers import Column, NamedColumn +from cudf_polars.containers import Column from cudf_polars.utils import dtypes, sorting if TYPE_CHECKING: @@ -313,7 +313,7 @@ def evaluate( *, context: ExecutionContext = ExecutionContext.FRAME, mapping: Mapping[Expr, Column] | None = None, - ) -> NamedColumn: + ) -> Column: """ Evaluate this expression given a dataframe for context. @@ -328,20 +328,15 @@ def evaluate( Returns ------- - NamedColumn attaching a name to an evaluated Column + Evaluated Column with name attached. See Also -------- :meth:`Expr.evaluate` for details, this function just adds the name to a column produced from an expression. """ - obj = self.value.evaluate(df, context=context, mapping=mapping) - return NamedColumn( - obj.obj, - self.name, - is_sorted=obj.is_sorted, - order=obj.order, - null_order=obj.null_order, + return self.value.evaluate(df, context=context, mapping=mapping).rename( + self.name ) def collect_agg(self, *, depth: int) -> AggInfo: @@ -428,7 +423,9 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" - return df._column_map[self.name] + # Deliberately remove the name here so that we guarantee + # evaluation of the IR produces names. + return df.column_map[self.name].rename(None) def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 1c61075be22..e319c363a23 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -26,7 +26,7 @@ import polars as pl import cudf_polars.dsl.expr as expr -from cudf_polars.containers import DataFrame, NamedColumn +from cudf_polars.containers import Column, DataFrame from cudf_polars.utils import dtypes, sorting if TYPE_CHECKING: @@ -57,9 +57,7 @@ ] -def broadcast( - *columns: NamedColumn, target_length: int | None = None -) -> list[NamedColumn]: +def broadcast(*columns: Column, target_length: int | None = None) -> list[Column]: """ Broadcast a sequence of columns to a common length. @@ -112,12 +110,12 @@ def broadcast( return [ column if column.obj.size() != 1 - else NamedColumn( + else Column( plc.Column.from_scalar(column.obj_scalar, nrows), - column.name, is_sorted=plc.types.Sorted.YES, order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE, + name=column.name, ) for column in columns ] @@ -385,15 +383,17 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: init = plc.interop.from_arrow( pa.scalar(offset, type=plc.interop.to_arrow(dtype)) ) - index = NamedColumn( + index = Column( plc.filling.sequence(df.num_rows, init, step), - name, is_sorted=plc.types.Sorted.YES, order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.AFTER, + name=name, ) df = DataFrame([index, *df.columns]) - assert all(c.obj.type() == self.schema[c.name] for c in df.columns) + assert all( + c.obj.type() == self.schema[name] for name, c in df.column_map.items() + ) if self.predicate is None: return df else: @@ -588,15 +588,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: requests.append(plc.groupby.GroupByRequest(col, [req])) replacements.append(rep) group_keys, raw_tables = grouper.aggregate(requests) - # TODO: names - raw_columns: list[NamedColumn] = [] + raw_columns: list[Column] = [] for i, table in enumerate(raw_tables): (column,) = table.columns() - raw_columns.append(NamedColumn(column, f"tmp{i}")) + raw_columns.append(Column(column, name=f"tmp{i}")) mapping = dict(zip(replacements, raw_columns, strict=True)) result_keys = [ - NamedColumn(gk, k.name) - for gk, k in zip(group_keys.columns(), keys, strict=True) + Column(grouped_key, name=key.name) + for key, grouped_key in zip(keys, group_keys.columns(), strict=True) ] result_subs = DataFrame(raw_columns) results = [ @@ -639,8 +638,8 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: plc.copying.OutOfBoundsPolicy.DONT_CHECK, ) broadcasted = [ - NamedColumn(reordered, b.name) - for reordered, b in zip( + Column(reordered, name=old.name) + for reordered, old in zip( ordered_table.columns(), broadcasted, strict=True ) ] @@ -787,20 +786,20 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: # result, not the gather maps columns = plc.join.cross_join(left.table, right.table).columns() left_cols = [ - NamedColumn(new, old.name).sorted_like(old) + Column(new, name=old.name).sorted_like(old) for new, old in zip( columns[: left.num_columns], left.columns, strict=True ) ] right_cols = [ - NamedColumn( + Column( new, - old.name - if old.name not in left.column_names_set - else f"{old.name}{suffix}", + name=name + if name not in left.column_names_set + else f"{name}{suffix}", ) - for new, old in zip( - columns[left.num_columns :], right.columns, strict=True + for new, name in zip( + columns[left.num_columns :], right.column_names, strict=True ) ] return DataFrame([*left_cols, *right_cols]) @@ -838,18 +837,19 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: plc.copying.gather(right.table, rg, right_policy), right.column_names ) if coalesce and how != "inner": - left = left.replace_columns( - *( - NamedColumn( + left = left.with_columns( + ( + Column( plc.replace.replace_nulls(left_col.obj, right_col.obj), - left_col.name, + name=left_col.name, ) for left_col, right_col in zip( left.select_columns(left_on.column_names_set), right.select_columns(right_on.column_names_set), strict=True, ) - ) + ), + replace_only=True, ) right = right.discard_columns(right_on.column_names_set) if how == "right": @@ -931,9 +931,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: df = self.df.evaluate(cache=cache) if self.subset is None: indices = list(range(df.num_columns)) + keys_sorted = all(c.is_sorted for c in df.column_map.values()) else: indices = [i for i, k in enumerate(df.column_names) if k in self.subset] - keys_sorted = all(df.columns[i].is_sorted for i in indices) + keys_sorted = all(df.column_map[name].is_sorted for name in self.subset) if keys_sorted: table = plc.stream_compaction.unique( df.table, @@ -954,10 +955,11 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: plc.types.NullEquality.EQUAL, plc.types.NanEquality.ALL_EQUAL, ) + # TODO: Is this sortedness setting correct result = DataFrame( [ - NamedColumn(c, old.name).sorted_like(old) - for c, old in zip(table.columns(), df.columns, strict=True) + Column(new, name=old.name).sorted_like(old) + for new, old in zip(table.columns(), df.columns, strict=True) ] ) if keys_sorted or self.stable: @@ -1008,30 +1010,30 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: sort_keys = broadcast( *(k.evaluate(df) for k in self.by), target_length=df.num_rows ) - names = {c.name: i for i, c in enumerate(df.columns)} # TODO: More robust identification here. - keys_in_result = [ - i - for k in sort_keys - if (i := names.get(k.name)) is not None and k.obj is df.columns[i].obj - ] + keys_in_result = { + k.name: i + for i, k in enumerate(sort_keys) + if k.name in df.column_map and k.obj is df.column_map[k.name].obj + } table = self.do_sort( df.table, plc.Table([k.obj for k in sort_keys]), self.order, self.null_order, ) - columns = [ - NamedColumn(c, old.name) - for c, old in zip(table.columns(), df.columns, strict=True) - ] - # If a sort key is in the result table, set the sortedness property - for k, i in enumerate(keys_in_result): - columns[i] = columns[i].set_sorted( - is_sorted=plc.types.Sorted.YES, - order=self.order[k], - null_order=self.null_order[k], - ) + columns: list[Column] = [] + for name, c in zip(df.column_map, table.columns(), strict=True): + column = Column(c, name=name) + # If a sort key is in the result table, set the sortedness property + if name in keys_in_result: + i = keys_in_result[name] + column = column.set_sorted( + is_sorted=plc.types.Sorted.YES, + order=self.order[i], + null_order=self.null_order[i], + ) + columns.append(column) return DataFrame(columns).slice(self.zlice) @@ -1080,7 +1082,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: df = self.df.evaluate(cache=cache) # This can reorder things. columns = broadcast( - *df.select(list(self.schema.keys())).columns, target_length=df.num_rows + *(df.column_map[name] for name in self.schema), target_length=df.num_rows ) return DataFrame(columns) @@ -1125,7 +1127,7 @@ def __post_init__(self) -> None: old, new, _ = self.options # TODO: perhaps polars should validate renaming in the IR? if len(new) != len(set(new)) or ( - set(new) & (set(self.df.schema.keys() - set(old))) + set(new) & (set(self.df.schema.keys()) - set(old)) ): raise NotImplementedError("Duplicate new names in rename.") elif self.name == "unpivot": @@ -1170,7 +1172,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: npiv = len(pivotees) df = self.df.evaluate(cache=cache) index_columns = [ - NamedColumn(col, name) + Column(col, name=name) for col, name in zip( plc.reshape.tile(df.select(indices).table, npiv).columns(), indices, @@ -1191,13 +1193,16 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: df.num_rows, ).columns() value_column = plc.concatenate.concatenate( - [c.astype(self.schema[value_name]) for c in df.select(pivotees).columns] + [ + df.column_map[pivotee].astype(self.schema[value_name]).obj + for pivotee in pivotees + ] ) return DataFrame( [ *index_columns, - NamedColumn(variable_column, variable_name), - NamedColumn(value_column, value_name), + Column(variable_column, name=variable_name), + Column(value_column, name=value_name), ] ) else: @@ -1278,6 +1283,4 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: ) for df in dfs ] - return DataFrame( - list(itertools.chain.from_iterable(df.columns for df in dfs)), - ) + return DataFrame(itertools.chain.from_iterable(df.columns for df in dfs)) diff --git a/python/cudf_polars/docs/overview.md b/python/cudf_polars/docs/overview.md index bff44af1468..7837a275f20 100644 --- a/python/cudf_polars/docs/overview.md +++ b/python/cudf_polars/docs/overview.md @@ -201,21 +201,21 @@ the logical plan in any case, so is reasonably natural. # Containers Containers should be constructed as relatively lightweight objects -around their pylibcudf counterparts. We have four (in +around their pylibcudf counterparts. We have three (in `cudf_polars/containers/`): 1. `Scalar` (a wrapper around a pylibcudf `Scalar`) 2. `Column` (a wrapper around a pylibcudf `Column`) -3. `NamedColumn` (a `Column` with an additional name) -4. `DataFrame` (a wrapper around a pylibcudf `Table`) +3. `DataFrame` (a wrapper around a pylibcudf `Table`) The interfaces offered by these are somewhat in flux, but broadly -speaking, a `DataFrame` is just a list of `NamedColumn`s which each -hold a `Column` plus a string `name`. `NamedColumn`s are only ever -constructed via `NamedExpr`s, which are the top-level expression node -that lives inside an `IR` node. This means that the expression -evaluator never has to concern itself with column names: columns are -only ever decorated with names when constructing a `DataFrame`. +speaking, a `DataFrame` is just a mapping from string `name`s to +`Column`s, and thus also holds a pylibcudf `Table`. Names are only +attached to `Column`s and hence inserted into `DataFrames` via +`NamedExpr`s, which are the top-level expression nodes that live +inside an `IR` node. This means that the expression evaluator never +has to concern itself with column names: columns are only ever +decorated with names when constructing a `DataFrame`. The columns keep track of metadata (for example, whether or not they are sorted). We could imagine tracking more metadata, like minimum and diff --git a/python/cudf_polars/tests/containers/test_column.py b/python/cudf_polars/tests/containers/test_column.py index 19919877f84..1f26ab1af9f 100644 --- a/python/cudf_polars/tests/containers/test_column.py +++ b/python/cudf_polars/tests/containers/test_column.py @@ -3,13 +3,11 @@ from __future__ import annotations -from functools import partial - import pyarrow import pylibcudf as plc import pytest -from cudf_polars.containers import Column, NamedColumn +from cudf_polars.containers import Column def test_non_scalar_access_raises(): @@ -55,11 +53,10 @@ def test_shallow_copy(): @pytest.mark.parametrize("typeid", [plc.TypeId.INT8, plc.TypeId.FLOAT32]) -@pytest.mark.parametrize("constructor", [Column, partial(NamedColumn, name="name")]) -def test_mask_nans(typeid, constructor): +def test_mask_nans(typeid): dtype = plc.DataType(typeid) values = pyarrow.array([0, 0, 0], type=plc.interop.to_arrow(dtype)) - column = constructor(plc.interop.from_arrow(values)) + column = Column(plc.interop.from_arrow(values)) masked = column.mask_nans() assert column.obj.null_count() == masked.obj.null_count() diff --git a/python/cudf_polars/tests/containers/test_dataframe.py b/python/cudf_polars/tests/containers/test_dataframe.py index 39fb44d55a5..5c68fb8f0aa 100644 --- a/python/cudf_polars/tests/containers/test_dataframe.py +++ b/python/cudf_polars/tests/containers/test_dataframe.py @@ -8,18 +8,18 @@ import polars as pl -from cudf_polars.containers import DataFrame, NamedColumn +from cudf_polars.containers import Column, DataFrame from cudf_polars.testing.asserts import assert_gpu_result_equal def test_select_missing_raises(): df = DataFrame( [ - NamedColumn( + Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID ), - "a", + name="a", ) ] ) @@ -30,17 +30,17 @@ def test_select_missing_raises(): def test_replace_missing_raises(): df = DataFrame( [ - NamedColumn( + Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID ), - "a", + name="a", ) ] ) - replacement = df.columns[0].copy(new_name="b") + replacement = df.column_map["a"].copy().rename("b") with pytest.raises(ValueError): - df.replace_columns(replacement) + df.with_columns([replacement], replace_only=True) def test_from_table_wrong_names(): @@ -55,14 +55,23 @@ def test_from_table_wrong_names(): DataFrame.from_table(table, ["a", "b"]) +def test_unnamed_column_raise(): + payload = plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 0, plc.MaskState.ALL_VALID + ) + + with pytest.raises(ValueError): + DataFrame([Column(payload, name="a"), Column(payload)]) + + def test_sorted_like_raises_mismatching_names(): df = DataFrame( [ - NamedColumn( + Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID ), - "a", + name="a", ) ] ) @@ -72,11 +81,11 @@ def test_sorted_like_raises_mismatching_names(): def test_shallow_copy(): - column = NamedColumn( + column = Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID ), - "a", + name="a", ) column.set_sorted( is_sorted=plc.types.Sorted.YES, @@ -85,13 +94,13 @@ def test_shallow_copy(): ) df = DataFrame([column]) copy = df.copy() - copy.columns[0].set_sorted( + copy.column_map["a"].set_sorted( is_sorted=plc.types.Sorted.NO, order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.AFTER, ) - assert df.columns[0].is_sorted == plc.types.Sorted.YES - assert copy.columns[0].is_sorted == plc.types.Sorted.NO + assert df.column_map["a"].is_sorted == plc.types.Sorted.YES + assert copy.column_map["a"].is_sorted == plc.types.Sorted.NO def test_sorted_flags_preserved_empty(): @@ -100,7 +109,7 @@ def test_sorted_flags_preserved_empty(): gf = DataFrame.from_polars(df) - (a,) = gf.columns + a = gf.column_map["a"] assert a.is_sorted == plc.types.Sorted.YES diff --git a/python/cudf_polars/tests/expressions/test_sort.py b/python/cudf_polars/tests/expressions/test_sort.py index 76c7648813a..2a37683478b 100644 --- a/python/cudf_polars/tests/expressions/test_sort.py +++ b/python/cudf_polars/tests/expressions/test_sort.py @@ -69,7 +69,7 @@ def test_setsorted(descending, nulls_last, with_nulls): df = translate_ir(q._ldf.visit()).evaluate(cache={}) - (a,) = df.columns + a = df.column_map["a"] assert a.is_sorted == plc.types.Sorted.YES null_order = ( diff --git a/python/cudf_polars/tests/utils/test_broadcast.py b/python/cudf_polars/tests/utils/test_broadcast.py index 35aaef44e1f..e7770bfadac 100644 --- a/python/cudf_polars/tests/utils/test_broadcast.py +++ b/python/cudf_polars/tests/utils/test_broadcast.py @@ -6,34 +6,35 @@ import pylibcudf as plc import pytest -from cudf_polars.containers import NamedColumn +from cudf_polars.containers import Column from cudf_polars.dsl.ir import broadcast @pytest.mark.parametrize("target", [4, None]) def test_broadcast_all_scalar(target): columns = [ - NamedColumn( + Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), 1, plc.MaskState.ALL_VALID ), - f"col{i}", + name=f"col{i}", ) for i in range(3) ] result = broadcast(*columns, target_length=target) expected = 1 if target is None else target + assert [c.name for c in result] == [f"col{i}" for i in range(3)] assert all(column.obj.size() == expected for column in result) def test_invalid_target_length(): columns = [ - NamedColumn( + Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), 4, plc.MaskState.ALL_VALID ), - f"col{i}", + name=f"col{i}", ) for i in range(3) ] @@ -43,11 +44,11 @@ def test_invalid_target_length(): def test_broadcast_mismatching_column_lengths(): columns = [ - NamedColumn( + Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), i + 1, plc.MaskState.ALL_VALID ), - f"col{i}", + name=f"col{i}", ) for i in range(3) ] @@ -58,16 +59,17 @@ def test_broadcast_mismatching_column_lengths(): @pytest.mark.parametrize("nrows", [0, 5]) def test_broadcast_with_scalars(nrows): columns = [ - NamedColumn( + Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), nrows if i == 0 else 1, plc.MaskState.ALL_VALID, ), - f"col{i}", + name=f"col{i}", ) for i in range(3) ] result = broadcast(*columns) + assert [c.name for c in result] == [f"col{i}" for i in range(3)] assert all(column.obj.size() == nrows for column in result)