Skip to content

Commit

Permalink
use general StateInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Nov 21, 2024
1 parent aeecd4d commit 09c5217
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ def __init__(self, count: int):
"""Protocol for Lowering IR nodes."""


class StateInfo:
    """
    Bag of arbitrary state information.

    Parameters
    ----------
    parts_info
        Optional mapping from an IR node to its partitioning
        information. The mapping object is stored as-is (not copied).
        When omitted (``None``), a fresh empty ``dict`` is used.
    """

    def __init__(self, *, parts_info: MutableMapping[IR, PartitionInfo] | None = None):
        # Test against None explicitly rather than relying on truthiness:
        # an empty mapping is falsy, so ``parts_info or {}`` would silently
        # swap a caller-supplied (still empty) mapping for an unrelated dict,
        # and entries the caller adds afterwards would never be seen here.
        self.__parts_info = {} if parts_info is None else parts_info

    def parts(self, ir: IR) -> PartitionInfo:
        """
        Return partitioning information for an IR node.

        The lookup is a plain subscript on the stored mapping, so a
        missing node propagates the mapping's lookup error (``KeyError``
        for a ``dict``).
        """
        return self.__parts_info[ir]


def get_key_name(node: Node) -> str:
    """
    Build the key name for a Node.

    The name is the lowercased class name of ``node`` and the value of
    ``hash(node)``, joined by a hyphen.
    """
    prefix = type(node).__name__.lower()
    suffix = str(hash(node))
    return "-".join((prefix, suffix))
Expand Down Expand Up @@ -87,9 +98,7 @@ def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:


@singledispatch
def generate_ir_tasks(
ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
def generate_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]:
"""
Generate tasks for an IR node.
Expand All @@ -100,20 +109,18 @@ def generate_ir_tasks(


@generate_ir_tasks.register(IR)
def _default_ir_tasks(
ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
def _default_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]:
# Single-partition default behavior.
# This is used by `generate_ir_tasks` for all unregistered IR sub-types.
if partition_info[ir].count > 1:
if state.parts(ir).count > 1:
raise NotImplementedError(
f"Failed to generate multiple output tasks for {ir}."
) # pragma: no cover

child_names = []
for child in ir.children:
child_names.append(get_key_name(child))
if partition_info[child].count > 1:
if state.parts(child).count > 1:
raise NotImplementedError(
f"Failed to generate tasks for {ir} with child {child}."
) # pragma: no cover
Expand All @@ -128,13 +135,11 @@ def _default_ir_tasks(
}


def task_graph(
ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
) -> tuple[MutableMapping[str, Any], str]:
def task_graph(ir: IR, state: StateInfo) -> tuple[MutableMapping[str, Any], str]:
"""Construct a Dask-compatible task graph."""
graph = reduce(
operator.or_,
[generate_ir_tasks(node, partition_info) for node in traversal(ir)],
[generate_ir_tasks(node, state) for node in traversal(ir)],
)
key_name = get_key_name(ir)
graph[key_name] = (key_name, 0)
Expand All @@ -146,7 +151,7 @@ def evaluate_dask(ir: IR) -> DataFrame:
"""Evaluate an IR graph with Dask."""
from dask import get

ir, partition_info = lower_ir_graph(ir)
ir, parts_info = lower_ir_graph(ir)

graph, key = task_graph(ir, partition_info)
graph, key = task_graph(ir, StateInfo(parts_info=parts_info))
return get(graph, key)

0 comments on commit 09c5217

Please sign in to comment.