diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 76816ee0a61..ea964ec7ae7 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -174,7 +174,8 @@ def execute_with_cudf( device = config.device memory_resource = config.memory_resource raise_on_fail = config.config.get("raise_on_fail", False) - if unsupported := (config.config.keys() - {"raise_on_fail"}): + debug_mode = config.config.get("debug_mode", False) + if unsupported := (config.config.keys() - {"raise_on_fail", "debug_mode"}): raise ValueError( f"Engine configuration contains unsupported settings {unsupported}" ) @@ -183,7 +184,10 @@ def execute_with_cudf( nt.set_udf( partial( _callback, - translate_ir(nt), + translate_ir( + nt, + debug_mode=1 if debug_mode else 0, + ), device=device, memory_resource=memory_resource, ) @@ -197,3 +201,5 @@ def execute_with_cudf( ) if raise_on_fail: raise + if debug_mode: + raise diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 1035ef076d9..2d6f188df5d 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -163,8 +163,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: @dataclasses.dataclass class ErrorNode(IR): - def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: - return pl.DataFrame() + error: str @dataclasses.dataclass diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 0c7acfe619d..431773afac4 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -20,7 +20,7 @@ from cudf_polars.dsl import expr, ir from cudf_polars.typing import NodeTraverser -from cudf_polars.utils import dtypes, other +from cudf_polars.utils import dtypes __all__ = ["translate_ir", "translate_named_expr"] @@ -28,11 +28,10 @@ def debug(func): def wrapper(*args, **kwargs): try: - print(args, kwargs) return func(*args, **kwargs) - except NotImplementedError: - if other._env_get_bool("CUDF_POLARS_DEBUG_MODE", default=False): - return ir.ErrorNode(args[0].get_schema()) + except NotImplementedError as e: + if kwargs.get("debug_mode", False): + return ir.ErrorNode(args[0].get_schema(), e) raise return wrapper @@ -77,19 +76,24 @@ def __exit__(self, *args: Any) -> None: @singledispatch def _translate_ir( - node: Any, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: Any, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: - if other._env_get_bool("CUDF_POLARS_DEBUG_MODE", default=False): - return ir.ErrorNode(schema) - raise NotImplementedError( - f"Translation for {type(node).__name__}" - ) # pragma: no cover + e = f"Translation for {type(node).__name__}" + if debug_mode: + return ir.ErrorNode(schema, e) + raise NotImplementedError(e) # pragma: no cover @debug @_translate_ir.register def _( - node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.PythonScan, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: scan_fn, with_columns, source_type, predicate, nrows = node.options options = (scan_fn, with_columns, source_type, nrows) @@ -102,7 +106,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Scan, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: typ, *options = node.scan_type if typ == "ndjson": @@ -140,7 +147,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Cache, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Cache, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: return ir.Cache(schema, node.id_, translate_ir(visitor, n=node.input)) @@ -148,7 +158,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.DataFrameScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.DataFrameScan, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: return ir.DataFrameScan( schema, @@ -163,7 +176,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Select, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Select, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) @@ -174,7 +190,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.GroupBy, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.GroupBy, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) @@ -193,7 +212,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Join, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: # Join key dtypes are dependent on the schema of the left and # right inputs, so these must be translated with the relevant @@ -210,7 +232,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.HStack, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.HStack, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) @@ -221,7 +246,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Reduce, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Reduce, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: # pragma: no cover; polars doesn't emit this node yet with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) @@ -232,7 +260,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Distinct, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: return ir.Distinct( schema, @@ -244,7 +275,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Sort, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Sort, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) @@ -255,7 +289,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Slice, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: return ir.Slice(schema, translate_ir(visitor, n=node.input), node.offset, node.len) @@ -263,7 +300,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Filter, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) @@ -277,6 +317,7 @@ def _( node: pl_ir.SimpleProjection, visitor: NodeTraverser, schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: return ir.Projection(schema, translate_ir(visitor, n=node.input)) @@ -284,7 +325,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.MapFunction, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.MapFunction, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: name, *options = node.function return ir.MapFunction( @@ -299,7 +343,10 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.Union, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: return ir.Union( schema, [translate_ir(visitor, n=n) for n in node.inputs], node.options @@ -309,13 +356,21 @@ def _( @debug @_translate_ir.register def _( - node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType] + node: pl_ir.HConcat, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], + debug_mode: int = 0, ) -> ir.IR: return ir.HConcat(schema, [translate_ir(visitor, n=n) for n in node.inputs]) @debug -def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: +def translate_ir( + visitor: NodeTraverser, + *, + n: int | None = None, + debug_mode: int = 0, +) -> ir.IR: """ Translate a polars-internal IR node to our representation. @@ -326,6 +381,9 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: n Optional node to start traversing from, if not provided uses current polars-internal node. + debug_mode + Optional: If true returns an ErrorNode in the IR that is used to + report unsupported operations in the query Returns ------- @@ -352,7 +410,7 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: polars_schema = visitor.get_schema() node = visitor.view_current_node() schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()} - result = _translate_ir(node, visitor, schema) + result = _translate_ir(node, visitor, schema, debug_mode=debug_mode) if any( isinstance(dtype, pl.Null) for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values())