diff --git a/python/cudf_polars/cudf_polars/expressions.py b/python/cudf_polars/cudf_polars/expressions.py
index 0b3872bace9..c13e6b6d309 100644
--- a/python/cudf_polars/cudf_polars/expressions.py
+++ b/python/cudf_polars/cudf_polars/expressions.py
@@ -5,7 +5,7 @@

 from collections import defaultdict
 from functools import singledispatch
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING

 import cudf
 import cudf._lib as libcudf
@@ -35,10 +35,16 @@
     from cudf_polars.typing import ColumnType, Expr, Visitor


-class ExprVisitor(NamedTuple):
+class ExprVisitor:
     """Object holding rust visitor and utility methods."""

+    __slots__ = ("visitor", "in_groupby")
     visitor: Visitor
+    in_groupby: bool
+
+    def __init__(self, visitor: Visitor):
+        self.visitor = visitor
+        self.in_groupby = False

     def __call__(self, node: int, context: DataFrame) -> ColumnType:
         """
@@ -175,6 +181,8 @@ def _literal(

 @evaluate_expr.register
 def _sort(expr: expr_nodes.Sort, context: DataFrame, visitor: ExprVisitor):
+    if visitor.in_groupby:
+        raise NotImplementedError("sort inside groupby")
     to_sort = visitor(expr.expr, context)
     (stable, nulls_last, descending) = expr.options
     descending, column_order, null_precedence = sort_order(
@@ -198,6 +206,8 @@ def _sort(expr: expr_nodes.Sort, context: DataFrame, visitor: ExprVisitor):
 def _sort_by(
     expr: expr_nodes.SortBy, context: DataFrame, visitor: ExprVisitor
 ):
+    if visitor.in_groupby:
+        raise NotImplementedError("sort_by inside groupby")
     to_sort = visitor(expr.expr, context)
     descending = expr.descending
     sort_keys = [visitor(e, context) for e in expr.by]
@@ -238,6 +248,8 @@ def _gather(expr: expr_nodes.Gather, context: DataFrame, visitor: ExprVisitor):

 @evaluate_expr.register
 def _filter(expr: expr_nodes.Filter, context: DataFrame, visitor: ExprVisitor):
+    if visitor.in_groupby:
+        raise NotImplementedError("filter inside groupby")
     result = visitor(expr.input, context)
     mask = visitor(expr.by, context)
     (column,) = plc.stream_compaction.apply_boolean_mask(
@@ -262,6 +274,8 @@ def _column(expr: expr_nodes.Column, context: DataFrame, visitor: ExprVisitor):

 @evaluate_expr.register
 def _agg(expr: expr_nodes.Agg, context: DataFrame, visitor: ExprVisitor):
+    if visitor.in_groupby:
+        raise NotImplementedError("nested agg in group_by")
     name = expr.name
     column = visitor(expr.arguments, context)
     # TODO: handle options
@@ -425,56 +439,6 @@ def _binop(
     return plc.binaryop.binary_operation(_as_plc(lop), _as_plc(rop), op, dtype)


-# Aggregations, need to be shared between plan and expression
-# evaluation, but circular dep, so we put them here.
-# TODO: document approach here properly
-def agg_depth(agg, visitor: ExprVisitor) -> int:
-    """
-    Determine the depth of aggregations in an expression.
-
-    Parameters
-    ----------
-    agg
-        Expression containing aggregations
-    visitor
-        Callback visitor
-
-    Returns
-    -------
-    Depth in the expression tree that an aggregation request was observed.
-
-    Raises
-    ------
-    NotImplementedError
-        If an aggregation request is nested inside another aggregation
-        request, or an unhandled expression is seen.
- """ - agg = visitor.visitor.view_expression(agg) - if isinstance(agg, expr_nodes.Column): - return 0 - elif isinstance(agg, expr_nodes.Alias): - return agg_depth(agg.expr, visitor) - elif isinstance(agg, expr_nodes.BinaryExpr): - ldepth = agg_depth(agg.left, visitor) - rdepth = agg_depth(agg.right, visitor) - maxdepth = max(ldepth, rdepth) - assert ldepth == rdepth - return maxdepth - elif isinstance(agg, expr_nodes.Len): - return 1 - elif isinstance(agg, expr_nodes.Agg): - # TODO: currently only singleton arguments (that's all that - # the Agg object has right now) - depth = agg_depth(agg.arguments, visitor) - if depth >= 1: - raise NotImplementedError("Nesting aggregations not supported") - return depth + 1 - else: - raise NotImplementedError(f"Unhandled agg {agg=}") - - -# TODO: would love to find a way to avoid the multiple traversal -# right now, we must run agg_depth first. def collect_agg( node: int, context: DataFrame, depth: int, visitor: ExprVisitor ) -> tuple[ @@ -483,6 +447,17 @@ def collect_agg( """ Collect the aggregation requirements of a single aggregation request. + Parameters + ---------- + node + Node representing aggregation to collect + context + DataFrame context + depth + Depth of the aggregation in tree + visitor + Visitor for translating nodes + Returns ------- tuple of list of columns, list of (libcudf-agg-name, agg-expression) pairs, @@ -551,10 +526,14 @@ def collect_agg( rcol, rreq = collect_agg(agg.right, context, depth, visitor) return [*lcol, *rcol], [*lreq, *rreq] else: - # TODO: Inside an aggregation, this needs to disallow (for now) - # seeing an aggregation request. + # TODO: Ugly non-local method of saying "we're in a groupby, disallow" + visitor.in_groupby = True column = evaluate_expr(agg, context, visitor) + visitor.in_groupby = False return [column], [(plc.aggregation.collect_list(), node)] + elif isinstance(agg, expr_nodes.Literal): + # Scalar value, constant across the groups + return [], [] else: raise NotImplementedError @@ -602,6 +581,10 @@ def collect_aggs( _, column_requests, to_replace = groups.setdefault( colid, (column, [], defaultdict(list)) ) + if request is None: + # Literals, which don't produce requests since they must be + # uniform across the group + continue # We're only going to ask libcudf for unique aggregation requests if request not in column_requests: column_requests.append(request)