diff --git a/python/cudf_polars/cudf_polars/expressions.py b/python/cudf_polars/cudf_polars/expressions.py index c13e6b6d309..9ccdffdcbae 100644 --- a/python/cudf_polars/cudf_polars/expressions.py +++ b/python/cudf_polars/cudf_polars/expressions.py @@ -30,7 +30,7 @@ ) if TYPE_CHECKING: - from typing_extensions import Self + from collections.abc import Sequence from cudf_polars.typing import ColumnType, Expr, Visitor @@ -46,37 +46,55 @@ def __init__(self, visitor: Visitor): self.visitor = visitor self.in_groupby = False - def __call__(self, node: int, context: DataFrame) -> ColumnType: + def add_expressions( + self, expressions: Sequence[Expr] + ) -> tuple[list[int], int]: """ - Return the evaluation of an expression node in a context. + Add expressions to the expression graph. Parameters ---------- - node - The node to evaluate - context - The dataframe providing context + expressions + List of expressions to add Returns ------- - New column as the evaluation of the expression. + tuple of list of node ids and the total number of node ids in the + expression graph after adding the expressions. """ - return evaluate_expr(self.visitor.view_expression(node), context, self) + return self.visitor.add_expressions(expressions) - def with_replacements(self, mapping: list[tuple[int, Expr]]) -> Self: + def set_mapping(self, mapping: list[int]): """ - Return a new visitor with nodes replaced by new ones. + Set the node mapping for rewiring the expression graph. Parameters ---------- mapping - List of pairs mapping node ids to their replacement expression. + List mapping old expression ids to new ones. + """ + self.visitor.set_expr_mapping(mapping) + + def unset_mapping(self): + """Unset the node mapping.""" + self.visitor.unset_expr_mapping() + + def __call__(self, node: int, context: DataFrame) -> ColumnType: + """ + Return the evaluation of an expression node in a context. + + Parameters + ---------- + node + The node to evaluate + context + The dataframe providing context Returns ------- - New node visitor with replaced expressions. + New column as the evaluation of the expression. """ - return type(self)(self.visitor.replace_expressions(mapping)) + return evaluate_expr(self.visitor.view_expression(node), context, self) @singledispatch @@ -622,8 +640,15 @@ def _post_aggregate( for agg in agg_expr: mapping.append((agg, newcol)) context = DataFrame(context) - v = visitor.with_replacements(mapping) - return [v(agg, context) for agg in aggs] + old_nodes, new_cols = zip(*mapping) + (new_nodes, num_nodes) = visitor.add_expressions(new_cols) + aggmap = list(range(num_nodes)) + for old, new in zip(old_nodes, new_nodes): + aggmap[old] = new + visitor.set_mapping(aggmap) + result = [visitor(agg, context) for agg in aggs] + visitor.unset_mapping() + return result def _rolling(