Skip to content

Commit

Permalink
Adapt to updates and start building out tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Apr 30, 2024
1 parent 5af9881 commit 76b644c
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 58 deletions.
103 changes: 72 additions & 31 deletions python/cudf_polars/cudf_polars/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

from collections import defaultdict
from enum import IntEnum, auto
from functools import singledispatch
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -35,16 +36,58 @@
from cudf_polars.typing import ColumnType, Expr, Visitor


class ExecutionContext(IntEnum):
"""Tag for the current execution context."""

GROUPBY = auto()
"Executing inside a group_by expression."
ROLLING = auto()
"Executing inside a rolling expression."
DATAFRAME = auto()
"Executing on the whole dataframe."


class ExprVisitor:
"""Object holding rust visitor and utility methods."""

__slots__ = ("visitor", "in_groupby")
__slots__ = ("visitor", "context", "node_stack")
visitor: Visitor
in_groupby: bool
context: ExecutionContext
node_stack: list[int]

class _with_context:
def __init__(self, context: ExecutionContext, visitor: ExprVisitor):
self.context = context
self.visitor = visitor

def __enter__(self):
self.visitor.context, self.context = (
self.context,
self.visitor.context,
)

def __exit__(self, *args):
self.visitor.context = self.context

def __init__(self, visitor: Visitor):
self.visitor = visitor
self.in_groupby = False
self.context = ExecutionContext.DATAFRAME
self.node_stack = []

def with_context(self, context: ExecutionContext):
"""
Context manager for setting the execution context of the visitor.
Parameters
----------
context
New execution context
Returns
-------
context manager that sets and restores the execution context.
"""
return self._with_context(context, self)

def add_expressions(
self, expressions: Sequence[Expr]
Expand Down Expand Up @@ -94,7 +137,17 @@ def __call__(self, node: int, context: DataFrame) -> ColumnType:
-------
New column as the evaluation of the expression.
"""
return evaluate_expr(self.visitor.view_expression(node), context, self)
self.node_stack.append(node)
result = evaluate_expr(
self.visitor.view_expression(node), context, self
)
self.node_stack.pop()
return result

@property
def dtype(self):
"""Return the datatype of the current expression node."""
return self.visitor.get_dtype(self.node_stack[-1])


@singledispatch
Expand Down Expand Up @@ -310,11 +363,6 @@ def _expr_function(
# TODO: tracking sortedness
(column,) = arguments
return column
# (name,) = data.keys()
# (flag,) = fargs
# return data.set_sorted(
# {name: getattr(DataFrame.IsSorted, flag.upper())}
# )
elif fname in BOOLEAN_FUNCTIONS:
return boolean_function(fname, arguments, fargs)
else:
Expand Down Expand Up @@ -365,46 +413,40 @@ def _literal(

@evaluate_expr.register
def _sort(expr: expr_nodes.Sort, context: DataFrame, visitor: ExprVisitor):
if visitor.in_groupby:
raise NotImplementedError("sort inside groupby")
if visitor.context is not ExecutionContext.DATAFRAME:
raise NotImplementedError("sort inside groupby/rolling")
to_sort = visitor(expr.expr, context)
(stable, nulls_last, descending) = expr.options
descending, column_order, null_precedence = sort_order(
[descending], nulls_last=nulls_last, num_keys=1
)
do_sort = plc.sorting.stable_sort if stable else plc.sorting.sort
result = do_sort(to_sort.to_pylibcudf(), column_order, null_precedence)
(result,) = do_sort(
plc.Table([to_sort]), column_order, null_precedence
).columns()
return result
# TODO: track sortedness
# flag = (
# DataFrame.IsSorted.DESCENDING
# if descending
# else DataFrame.IsSorted.ASCENDING
# )
# return DataFrame.from_pylibcudf(to_sort.names(), result).set_sorted(
# {name: flag}
# )


@evaluate_expr.register
def _sort_by(
expr: expr_nodes.SortBy, context: DataFrame, visitor: ExprVisitor
):
if visitor.in_groupby:
raise NotImplementedError("sort_by inside groupby")
if visitor.context is not ExecutionContext.DATAFRAME:
raise NotImplementedError("sort_by inside groupby/rolling")
to_sort = visitor(expr.expr, context)
descending = expr.descending
sort_keys = [visitor(e, context) for e in expr.by]
# TODO: no stable to sort_by in polars
descending, column_order, null_precedence = sort_order(
descending, nulls_last=True, num_keys=len(sort_keys)
)
return plc.sorting.sort_by_key(
(result,) = plc.sorting.sort_by_key(
plc.Table([to_sort]),
plc.Table(sort_keys),
column_order,
null_precedence,
)
return result


@evaluate_expr.register
Expand Down Expand Up @@ -432,8 +474,8 @@ def _gather(expr: expr_nodes.Gather, context: DataFrame, visitor: ExprVisitor):

@evaluate_expr.register
def _filter(expr: expr_nodes.Filter, context: DataFrame, visitor: ExprVisitor):
if visitor.in_groupby:
raise NotImplementedError("filter inside groupby")
if visitor.context is not ExecutionContext.DATAFRAME:
raise NotImplementedError("filter inside groupby/rolling")
result = visitor(expr.input, context)
mask = visitor(expr.by, context)
(column,) = plc.stream_compaction.apply_boolean_mask(
Expand All @@ -458,8 +500,8 @@ def _column(expr: expr_nodes.Column, context: DataFrame, visitor: ExprVisitor):

@evaluate_expr.register
def _agg(expr: expr_nodes.Agg, context: DataFrame, visitor: ExprVisitor):
if visitor.in_groupby:
raise NotImplementedError("nested agg in group_by")
if visitor.context is not ExecutionContext.DATAFRAME:
raise NotImplementedError("nested agg in groupby/rolling")
name = expr.name
column = visitor(expr.arguments, context)
# TODO: handle options
Expand Down Expand Up @@ -711,9 +753,8 @@ def collect_agg(
return [*lcol, *rcol], [*lreq, *rreq]
else:
# TODO: Ugly non-local method of saying "we're in a groupby, disallow"
visitor.in_groupby = True
column = evaluate_expr(agg, context, visitor)
visitor.in_groupby = False
with visitor.with_context(ExecutionContext.GROUPBY):
column = evaluate_expr(agg, context, visitor)
return [column], [(plc.aggregation.collect_list(), node)]
elif isinstance(agg, expr_nodes.Literal):
# Scalar value, constant across the groups
Expand Down
30 changes: 16 additions & 14 deletions python/cudf_polars/cudf_polars/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,17 @@ def __call__(self, n: int | None = None) -> DataFrame:
Node to evaluate (optional), if not provided uses the internal
visitor's "current" node.
Notes
-----
Side-effectfully modifies the visitor to set the current node.
Returns
-------
New dataframe giving the evaluation of the plan.
"""
if n is None:
node = self.visitor.view_current_node()
else:
node = self.visitor.view_node(n)
if n is not None:
self.visitor.set_node(n)
node = self.visitor.view_current_node()
return _execute_plan(node, self)

def record(self, name: str):
Expand Down Expand Up @@ -200,8 +203,6 @@ def _python_scan(plan: nodes.PythonScan, visitor: PlanVisitor):
with visitor.record("PythonScan"):
(
scan_fn,
schema,
output_schema,
with_columns,
is_pyarrow,
predicate,
Expand All @@ -211,6 +212,8 @@ def _python_scan(plan: nodes.PythonScan, visitor: PlanVisitor):
if is_pyarrow:
raise NotImplementedError("Don't know what to do here")
context = scan_fn(with_columns, predicate, nrows)
if not isinstance(context, DataFrame):
raise TypeError(f"Don't know how to handle a {type(context)}")
if predicate is not None:
mask = visitor.expr_visitor(predicate.node, context)
return context.filter(mask)
Expand All @@ -227,7 +230,7 @@ def _scan(plan: nodes.Scan, visitor: PlanVisitor):
n_rows = options.n_rows
with_columns = options.with_columns
row_index = options.row_index
schema = plan.output_schema
schema = visitor.visitor.get_schema()
# TODO: Send all the options through to the libcudf readers where appropriate
if n_rows is not None:
# TODO: read_csv supports n_rows, but if we have more than one
Expand Down Expand Up @@ -422,7 +425,7 @@ def _join(plan: nodes.Join, visitor: PlanVisitor):
right_on = plc.Table(
[visitor.expr_visitor(e.node, right) for e in plan.right_on]
)
how, join_nulls, zlice, suffix = plan.options
how, join_nulls, zlice, suffix, coalesce_key_columns = plan.options
null_equality = (
plc.types.NullEquality.EQUAL
if join_nulls
Expand All @@ -431,9 +434,7 @@ def _join(plan: nodes.Join, visitor: PlanVisitor):
suffix = "_right" if suffix is None else suffix
if how == "cross":
raise NotImplementedError("cross join not implemented")
coalesce_key_columns = True
if how == "outer":
coalesce_key_columns = False
if how == "outer" and not coalesce_key_columns:
raise NotImplementedError("Non-coalescing outer join")
elif how == "outer_coalesce":
how = "outer"
Expand Down Expand Up @@ -576,7 +577,8 @@ def _sort(plan: nodes.Sort, visitor: PlanVisitor):
sort_keys = [
visitor.expr_visitor(e.node, result) for e in plan.by_column
]
(stable, nulls_last, descending, zlice) = plan.args
(descending, nulls_last, stable) = plan.sort_options
zlice = plan.slice
descending, column_order, null_precedence = sort_order(
descending, nulls_last=nulls_last, num_keys=len(sort_keys)
)
Expand Down Expand Up @@ -632,8 +634,8 @@ def _filter(plan: nodes.Filter, visitor: PlanVisitor):

@_execute_plan.register
def _simple_projection(plan: nodes.SimpleProjection, visitor: PlanVisitor):
schema = visitor.visitor.get_schema()
result = visitor(plan.input)
schema = plan.columns
with visitor.record("simple_projection"):
return DataFrame({name: result[name] for name in schema})

Expand Down Expand Up @@ -708,7 +710,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor):
elif typ == "explode":
context = visitor(plan.input)
with profiler:
column_names, schema = args
(column_names,) = args
if len(column_names) > 1:
# TODO: straightforward, but need to error check
# polars requires that all to-explode columns have the
Expand Down
4 changes: 4 additions & 0 deletions python/cudf_polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ ignore = [
"W191",
]

[tool.ruff.lint.per-file-ignores]
# No need for docstrings on tests
"tests/**.py" = ["D"]

[tool.ruff.lint.pycodestyle]
max-doc-length = 85

Expand Down
14 changes: 1 addition & 13 deletions python/cudf_polars/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,8 @@
import numpy as np
import polars as pl
import pytest
from polars.testing.asserts import assert_frame_equal

from cudf_polars.patch import _WAS_PATCHED

if not _WAS_PATCHED:
# We could also just patch in the test, but this approach provides a canary for
# failures with patching that we might observe in trying this with other tests.
raise RuntimeError("Patch was not applied")
from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture()
Expand Down Expand Up @@ -54,12 +48,6 @@ def ldf(df):
return df.lazy()


def assert_gpu_result_equal(lazydf, **kwargs):
expect = lazydf.collect(use_gpu=False)
got = lazydf.collect(use_gpu=True, cpu_fallback=False)
assert_frame_equal(expect, got, **kwargs)


@pytest.mark.parametrize("dtype", ["int32", "int64", "float32", "float64"])
@pytest.mark.parametrize(
"op", [operator.add, operator.sub, operator.mul, operator.truediv]
Expand Down
25 changes: 25 additions & 0 deletions python/cudf_polars/tests/test_distinct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

import polars as pl
import pytest

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize("subset", [None, ["a"], ["a", "b"], ["b", "c"]])
@pytest.mark.parametrize("keep", ["any", "none", "first", "last"])
@pytest.mark.parametrize(
"maintain_order", [False, True], ids=["unstable", "stable"]
)
def test_distinct(subset, keep, maintain_order):
ldf = pl.DataFrame(
{
"a": [1, 2, 1, 3, 5, None, None],
"b": [1.5, 2.5, None, 1.5, 3, float("nan"), 3],
"c": [True, True, True, True, False, False, True],
}
).lazy()

query = ldf.unique(subset=subset, keep=keep, maintain_order=maintain_order)
assert_gpu_result_equal(query, check_row_order=maintain_order)
42 changes: 42 additions & 0 deletions python/cudf_polars/tests/test_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

import polars as pl
import pytest

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize(
"sort_keys",
[
(pl.col("a"),),
pytest.param(
(pl.col("d").abs(),),
marks=pytest.mark.xfail(reason="abs not yet implemented"),
),
(pl.col("a"), pl.col("d")),
(pl.col("b"),),
],
)
@pytest.mark.parametrize("nulls_last", [False, True])
@pytest.mark.parametrize(
"maintain_order", [False, True], ids=["unstable", "stable"]
)
def test_sort(sort_keys, nulls_last, maintain_order):
ldf = pl.DataFrame(
{
"a": [1, 2, 1, 3, 5, None, None],
"b": [1.5, 2.5, None, 1.5, 3, float("nan"), 3],
"c": [True, True, True, True, False, False, True],
"d": [1, 2, -1, 10, 6, -1, -7],
}
).lazy()

query = ldf.sort(
*sort_keys,
descending=True,
nulls_last=nulls_last,
maintain_order=maintain_order,
)
assert_gpu_result_equal(query, check_row_order=maintain_order)

0 comments on commit 76b644c

Please sign in to comment.