Skip to content

Commit

Permalink
Adapt to more targetted expression rewriting scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Apr 23, 2024
1 parent 17f0d58 commit 11170d7
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions python/cudf_polars/cudf_polars/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

if TYPE_CHECKING:
from typing_extensions import Self
from collections.abc import Sequence

from cudf_polars.typing import ColumnType, Expr, Visitor

Expand All @@ -46,37 +46,55 @@ def __init__(self, visitor: Visitor):
self.visitor = visitor
self.in_groupby = False

def __call__(self, node: int, context: DataFrame) -> ColumnType:
def add_expressions(
self, expressions: Sequence[Expr]
) -> tuple[list[int], int]:
"""
Return the evaluation of an expression node in a context.
Add expressions to the expression graph.
Parameters
----------
node
The node to evaluate
context
The dataframe providing context
expressions
List of expressions to add
Returns
-------
New column as the evaluation of the expression.
tuple of list of node ids and the total number of node ids in the
expression graph after adding the expressions.
"""
return evaluate_expr(self.visitor.view_expression(node), context, self)
return self.visitor.add_expressions(expressions)

def with_replacements(self, mapping: list[tuple[int, Expr]]) -> Self:
def set_mapping(self, mapping: list[int]):
"""
Return a new visitor with nodes replaced by new ones.
Set the node mapping for rewiring the expression graph.
Parameters
----------
mapping
List of pairs mapping node ids to their replacement expression.
List mapping old expression ids to new ones.
"""
self.visitor.set_expr_mapping(mapping)

def unset_mapping(self):
"""Unset the node mapping."""
self.visitor.unset_expr_mapping()

def __call__(self, node: int, context: DataFrame) -> ColumnType:
"""
Return the evaluation of an expression node in a context.
Parameters
----------
node
The node to evaluate
context
The dataframe providing context
Returns
-------
New node visitor with replaced expressions.
New column as the evaluation of the expression.
"""
return type(self)(self.visitor.replace_expressions(mapping))
return evaluate_expr(self.visitor.view_expression(node), context, self)


@singledispatch
Expand Down Expand Up @@ -622,8 +640,15 @@ def _post_aggregate(
for agg in agg_expr:
mapping.append((agg, newcol))
context = DataFrame(context)
v = visitor.with_replacements(mapping)
return [v(agg, context) for agg in aggs]
old_nodes, new_cols = zip(*mapping)
(new_nodes, num_nodes) = visitor.add_expressions(new_cols)
aggmap = list(range(num_nodes))
for old, new in zip(old_nodes, new_nodes):
aggmap[old] = new
visitor.set_mapping(aggmap)
result = [visitor(agg, context) for agg in aggs]
visitor.unset_mapping()
return result


def _rolling(
Expand Down

0 comments on commit 11170d7

Please sign in to comment.