From 62f10bc90058fc1dc0b361dd860137f2fbcf00df Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 21 Nov 2024 10:30:26 -0800 Subject: [PATCH] revert (for now) --- .../cudf_polars/experimental/parallel.py | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 1131fb2acf6..e8237c1eb4d 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -39,17 +39,6 @@ def __init__(self, count: int): """Protocol for Lowering IR nodes.""" -class StateInfo: - """Bag of arbitrary state information.""" - - def __init__(self, *, parts_info: MutableMapping[IR, PartitionInfo] | None = None): - self.__parts_info = parts_info or {} - - def parts(self, ir: IR) -> PartitionInfo: - """Return partitioning information for an IR node.""" - return self.__parts_info[ir] - - def get_key_name(node: Node) -> str: """Generate the key name for a Node.""" return f"{type(node).__name__.lower()}-{hash(node)}" @@ -98,7 +87,9 @@ def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: @singledispatch -def generate_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]: +def generate_ir_tasks( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: """ Generate tasks for an IR node. @@ -109,10 +100,12 @@ def generate_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]: @generate_ir_tasks.register(IR) -def _default_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]: +def _default_ir_tasks( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: # Single-partition default behavior. # This is used by `generate_ir_tasks` for all unregistered IR sub-types. - if state.parts(ir).count > 1: + if partition_info[ir].count > 1: raise NotImplementedError( f"Failed to generate multiple output tasks for {ir}." ) # pragma: no cover @@ -120,7 +113,7 @@ def _default_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]: child_names = [] for child in ir.children: child_names.append(get_key_name(child)) - if state.parts(child).count > 1: + if partition_info[child].count > 1: raise NotImplementedError( f"Failed to generate tasks for {ir} with child {child}." ) # pragma: no cover @@ -135,11 +128,13 @@ def _default_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]: } -def task_graph(ir: IR, state: StateInfo) -> tuple[MutableMapping[str, Any], str]: +def task_graph( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> tuple[MutableMapping[str, Any], str]: """Construct a Dask-compatible task graph.""" graph = reduce( operator.or_, - [generate_ir_tasks(node, state) for node in traversal(ir)], + [generate_ir_tasks(node, partition_info) for node in traversal(ir)], ) key_name = get_key_name(ir) graph[key_name] = (key_name, 0) @@ -151,7 +146,7 @@ def evaluate_dask(ir: IR) -> DataFrame: """Evaluate an IR graph with Dask.""" from dask import get - ir, parts_info = lower_ir_graph(ir) + ir, partition_info = lower_ir_graph(ir) - graph, key = task_graph(ir, StateInfo(parts_info=parts_info)) + graph, key = task_graph(ir, partition_info) return get(graph, key)