From 09c5217a4a54d2ba4cf0b320f0a8d3a84b22606a Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 21 Nov 2024 10:12:03 -0800 Subject: [PATCH] use general StateInfo --- .../cudf_polars/experimental/parallel.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index e8237c1eb4d..1131fb2acf6 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -39,6 +39,17 @@ def __init__(self, count: int): """Protocol for Lowering IR nodes.""" +class StateInfo: + """Bag of arbitrary state information.""" + + def __init__(self, *, parts_info: MutableMapping[IR, PartitionInfo] | None = None): + self.__parts_info = parts_info or {} + + def parts(self, ir: IR) -> PartitionInfo: + """Return partitioning information for an IR node.""" + return self.__parts_info[ir] + + def get_key_name(node: Node) -> str: """Generate the key name for a Node.""" return f"{type(node).__name__.lower()}-{hash(node)}" @@ -87,9 +98,7 @@ def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: @singledispatch -def generate_ir_tasks( - ir: IR, partition_info: MutableMapping[IR, PartitionInfo] -) -> MutableMapping[Any, Any]: +def generate_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]: """ Generate tasks for an IR node. @@ -100,12 +109,10 @@ def generate_ir_tasks( @generate_ir_tasks.register(IR) -def _default_ir_tasks( - ir: IR, partition_info: MutableMapping[IR, PartitionInfo] -) -> MutableMapping[Any, Any]: +def _default_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]: # Single-partition default behavior. # This is used by `generate_ir_tasks` for all unregistered IR sub-types. - if partition_info[ir].count > 1: + if state.parts(ir).count > 1: raise NotImplementedError( f"Failed to generate multiple output tasks for {ir}." ) # pragma: no cover @@ -113,7 +120,7 @@ def _default_ir_tasks( child_names = [] for child in ir.children: child_names.append(get_key_name(child)) - if partition_info[child].count > 1: + if state.parts(child).count > 1: raise NotImplementedError( f"Failed to generate tasks for {ir} with child {child}." ) # pragma: no cover @@ -128,13 +135,11 @@ def _default_ir_tasks( } -def task_graph( - ir: IR, partition_info: MutableMapping[IR, PartitionInfo] -) -> tuple[MutableMapping[str, Any], str]: +def task_graph(ir: IR, state: StateInfo) -> tuple[MutableMapping[str, Any], str]: """Construct a Dask-compatible task graph.""" graph = reduce( operator.or_, - [generate_ir_tasks(node, partition_info) for node in traversal(ir)], + [generate_ir_tasks(node, state) for node in traversal(ir)], ) key_name = get_key_name(ir) graph[key_name] = (key_name, 0) @@ -146,7 +151,7 @@ def evaluate_dask(ir: IR) -> DataFrame: """Evaluate an IR graph with Dask.""" from dask import get - ir, partition_info = lower_ir_graph(ir) + ir, parts_info = lower_ir_graph(ir) - graph, key = task_graph(ir, partition_info) + graph, key = task_graph(ir, StateInfo(parts_info=parts_info)) return get(graph, key)