Skip to content

Commit

Permalink
address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Oct 2, 2024
1 parent 763cb32 commit ab91793
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 33 deletions.
10 changes: 8 additions & 2 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def execute_with_cudf(
device = config.device
memory_resource = config.memory_resource
raise_on_fail = config.config.get("raise_on_fail", False)
if unsupported := (config.config.keys() - {"raise_on_fail"}):
debug_mode = config.config.get("debug_mode", False)
if unsupported := (config.config.keys() - {"raise_on_fail", "debug_mode"}):
raise ValueError(
f"Engine configuration contains unsupported settings {unsupported}"
)
Expand All @@ -183,7 +184,10 @@ def execute_with_cudf(
nt.set_udf(
partial(
_callback,
translate_ir(nt),
translate_ir(
nt,
debug_mode=1 if debug_mode else 0,
),
device=device,
memory_resource=memory_resource,
)
Expand All @@ -197,3 +201,5 @@ def execute_with_cudf(
)
if raise_on_fail:
raise
if debug_mode:
raise
3 changes: 1 addition & 2 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:

@dataclasses.dataclass
class ErrorNode(IR):
def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return pl.DataFrame()
error: str


@dataclasses.dataclass
Expand Down
116 changes: 87 additions & 29 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,18 @@

from cudf_polars.dsl import expr, ir
from cudf_polars.typing import NodeTraverser
from cudf_polars.utils import dtypes, other
from cudf_polars.utils import dtypes

__all__ = ["translate_ir", "translate_named_expr"]


def debug(func):
def wrapper(*args, **kwargs):
try:
print(args, kwargs)
return func(*args, **kwargs)
except NotImplementedError:
if other._env_get_bool("CUDF_POLARS_DEBUG_MODE", default=False):
return ir.ErrorNode(args[0].get_schema())
except NotImplementedError as e:
if kwargs.get("debug_mode", False):
return ir.ErrorNode(args[0].get_schema(), e)
raise

return wrapper
Expand Down Expand Up @@ -77,19 +76,24 @@ def __exit__(self, *args: Any) -> None:

@singledispatch
def _translate_ir(
node: Any, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: Any,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
if other._env_get_bool("CUDF_POLARS_DEBUG_MODE", default=False):
return ir.ErrorNode(schema)
raise NotImplementedError(
f"Translation for {type(node).__name__}"
) # pragma: no cover
e = f"Translation for {type(node).__name__}"
if debug_mode:
return ir.ErrorNode(schema, e)
raise NotImplementedError(e) # pragma: no cover


@debug
@_translate_ir.register
def _(
node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.PythonScan,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
scan_fn, with_columns, source_type, predicate, nrows = node.options
options = (scan_fn, with_columns, source_type, nrows)
Expand All @@ -102,7 +106,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Scan,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
typ, *options = node.scan_type
if typ == "ndjson":
Expand Down Expand Up @@ -140,15 +147,21 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Cache, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Cache,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Cache(schema, node.id_, translate_ir(visitor, n=node.input))


@debug
@_translate_ir.register
def _(
node: pl_ir.DataFrameScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.DataFrameScan,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.DataFrameScan(
schema,
Expand All @@ -163,7 +176,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Select, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Select,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
Expand All @@ -174,7 +190,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.GroupBy, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.GroupBy,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
Expand All @@ -193,7 +212,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Join,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
# Join key dtypes are dependent on the schema of the left and
# right inputs, so these must be translated with the relevant
Expand All @@ -210,7 +232,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.HStack, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.HStack,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
Expand All @@ -221,7 +246,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Reduce, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Reduce,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
Expand All @@ -232,7 +260,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Distinct,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Distinct(
schema,
Expand All @@ -244,7 +275,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Sort, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Sort,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
Expand All @@ -255,15 +289,21 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Slice,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Slice(schema, translate_ir(visitor, n=node.input), node.offset, node.len)


@debug
@_translate_ir.register
def _(
node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Filter,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
Expand All @@ -277,14 +317,18 @@ def _(
node: pl_ir.SimpleProjection,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Projection(schema, translate_ir(visitor, n=node.input))


@debug
@_translate_ir.register
def _(
node: pl_ir.MapFunction, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.MapFunction,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
name, *options = node.function
return ir.MapFunction(
Expand All @@ -299,7 +343,10 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.Union,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.Union(
schema, [translate_ir(visitor, n=n) for n in node.inputs], node.options
Expand All @@ -309,13 +356,21 @@ def _(
@debug
@_translate_ir.register
def _(
node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType]
node: pl_ir.HConcat,
visitor: NodeTraverser,
schema: dict[str, plc.DataType],
debug_mode: int = 0,
) -> ir.IR:
return ir.HConcat(schema, [translate_ir(visitor, n=n) for n in node.inputs])


@debug
def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
def translate_ir(
visitor: NodeTraverser,
*,
n: int | None = None,
debug_mode: int = 0,
) -> ir.IR:
"""
Translate a polars-internal IR node to our representation.
Expand All @@ -326,6 +381,9 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
n
Optional node to start traversing from, if not provided uses
current polars-internal node.
debug_mode
Optional: If true returns an ErrorNode in the IR that is used to
report unsupported operations in the query
Returns
-------
Expand All @@ -352,7 +410,7 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
polars_schema = visitor.get_schema()
node = visitor.view_current_node()
schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()}
result = _translate_ir(node, visitor, schema)
result = _translate_ir(node, visitor, schema, debug_mode=debug_mode)
if any(
isinstance(dtype, pl.Null)
for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values())
Expand Down

0 comments on commit ab91793

Please sign in to comment.