Skip to content

Commit

Permalink
Some agg simplification, sometimes handle Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Apr 22, 2024
1 parent d86957c commit 17f0d58
Showing 1 changed file with 37 additions and 54 deletions.
91 changes: 37 additions & 54 deletions python/cudf_polars/cudf_polars/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from collections import defaultdict
from functools import singledispatch
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING

import cudf
import cudf._lib as libcudf
Expand Down Expand Up @@ -35,10 +35,16 @@
from cudf_polars.typing import ColumnType, Expr, Visitor


class ExprVisitor(NamedTuple):
class ExprVisitor:
"""Object holding rust visitor and utility methods."""

__slots__ = ("visitor", "in_groupby")
visitor: Visitor
in_groupby: bool

def __init__(self, visitor: Visitor):
self.visitor = visitor
self.in_groupby = False

def __call__(self, node: int, context: DataFrame) -> ColumnType:
"""
Expand Down Expand Up @@ -175,6 +181,8 @@ def _literal(

@evaluate_expr.register
def _sort(expr: expr_nodes.Sort, context: DataFrame, visitor: ExprVisitor):
if visitor.in_groupby:
raise NotImplementedError("sort inside groupby")
to_sort = visitor(expr.expr, context)
(stable, nulls_last, descending) = expr.options
descending, column_order, null_precedence = sort_order(
Expand All @@ -198,6 +206,8 @@ def _sort(expr: expr_nodes.Sort, context: DataFrame, visitor: ExprVisitor):
def _sort_by(
expr: expr_nodes.SortBy, context: DataFrame, visitor: ExprVisitor
):
if visitor.in_groupby:
raise NotImplementedError("sort_by inside groupby")
to_sort = visitor(expr.expr, context)
descending = expr.descending
sort_keys = [visitor(e, context) for e in expr.by]
Expand Down Expand Up @@ -238,6 +248,8 @@ def _gather(expr: expr_nodes.Gather, context: DataFrame, visitor: ExprVisitor):

@evaluate_expr.register
def _filter(expr: expr_nodes.Filter, context: DataFrame, visitor: ExprVisitor):
if visitor.in_groupby:
raise NotImplementedError("filter inside groupby")
result = visitor(expr.input, context)
mask = visitor(expr.by, context)
(column,) = plc.stream_compaction.apply_boolean_mask(
Expand All @@ -262,6 +274,8 @@ def _column(expr: expr_nodes.Column, context: DataFrame, visitor: ExprVisitor):

@evaluate_expr.register
def _agg(expr: expr_nodes.Agg, context: DataFrame, visitor: ExprVisitor):
if visitor.in_groupby:
raise NotImplementedError("nested agg in group_by")
name = expr.name
column = visitor(expr.arguments, context)
# TODO: handle options
Expand Down Expand Up @@ -425,56 +439,6 @@ def _binop(
return plc.binaryop.binary_operation(_as_plc(lop), _as_plc(rop), op, dtype)


# The aggregation helpers need to be shared between plan and expression
# evaluation; they live here to avoid a circular dependency.
# TODO: document approach here properly
def agg_depth(agg, visitor: ExprVisitor) -> int:
    """
    Determine the depth of aggregations in an expression.

    Parameters
    ----------
    agg
        Expression containing aggregations
    visitor
        Callback visitor

    Returns
    -------
    Depth in the expression tree that an aggregation request was observed.

    Raises
    ------
    NotImplementedError
        If an aggregation request is nested inside another aggregation
        request, if the two operands of a binary expression have
        different aggregation depths, or an unhandled expression is seen.
    """
    # Resolve the node handle into a concrete expression object.
    agg = visitor.visitor.view_expression(agg)
    if isinstance(agg, expr_nodes.Column):
        # A bare column reference contains no aggregation.
        return 0
    elif isinstance(agg, expr_nodes.Alias):
        # Aliases are transparent: depth is that of the aliased expression.
        return agg_depth(agg.expr, visitor)
    elif isinstance(agg, expr_nodes.BinaryExpr):
        ldepth = agg_depth(agg.left, visitor)
        rdepth = agg_depth(agg.right, visitor)
        # Raise explicitly rather than assert: asserts vanish under -O,
        # and every other unsupported shape here is reported with
        # NotImplementedError.
        if ldepth != rdepth:
            raise NotImplementedError(
                "Binary expression operands with different aggregation "
                f"depths ({ldepth} != {rdepth}) not supported"
            )
        return ldepth
    elif isinstance(agg, expr_nodes.Len):
        return 1
    elif isinstance(agg, expr_nodes.Agg):
        # TODO: currently only singleton arguments (that's all that
        # the Agg object has right now)
        depth = agg_depth(agg.arguments, visitor)
        if depth >= 1:
            raise NotImplementedError("Nesting aggregations not supported")
        return depth + 1
    else:
        raise NotImplementedError(f"Unhandled agg {agg=}")


# TODO: would love to find a way to avoid the multiple traversals;
# right now, we must run agg_depth first.
def collect_agg(
node: int, context: DataFrame, depth: int, visitor: ExprVisitor
) -> tuple[
Expand All @@ -483,6 +447,17 @@ def collect_agg(
"""
Collect the aggregation requirements of a single aggregation request.
Parameters
----------
node
Node representing aggregation to collect
context
DataFrame context
depth
Depth of the aggregation in tree
visitor
Visitor for translating nodes
Returns
-------
tuple of list of columns, list of (libcudf-agg-name, agg-expression) pairs,
Expand Down Expand Up @@ -551,10 +526,14 @@ def collect_agg(
rcol, rreq = collect_agg(agg.right, context, depth, visitor)
return [*lcol, *rcol], [*lreq, *rreq]
else:
# TODO: Inside an aggregation, this needs to disallow (for now)
# seeing an aggregation request.
# TODO: Ugly non-local method of saying "we're in a groupby, disallow"
visitor.in_groupby = True
column = evaluate_expr(agg, context, visitor)
visitor.in_groupby = False
return [column], [(plc.aggregation.collect_list(), node)]
elif isinstance(agg, expr_nodes.Literal):
# Scalar value, constant across the groups
return [], []
else:
raise NotImplementedError

Expand Down Expand Up @@ -602,6 +581,10 @@ def collect_aggs(
_, column_requests, to_replace = groups.setdefault(
colid, (column, [], defaultdict(list))
)
if request is None:
# Literals, which don't produce requests since they must be
# uniform across the group
continue
# We're only going to ask libcudf for unique aggregation requests
if request not in column_requests:
column_requests.append(request)
Expand Down

0 comments on commit 17f0d58

Please sign in to comment.