Skip to content

Commit

Permalink
revert (for now)
Browse files (browse the repository at this point in the history)
  • Loading branch information
rjzamora committed Nov 21, 2024
1 parent 09c5217 commit 62f10bc
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,6 @@ def __init__(self, count: int):
"""Protocol for Lowering IR nodes."""


class StateInfo:
    """
    Bag of arbitrary state information.

    Parameters
    ----------
    parts_info
        Mapping from each IR node to its partitioning information.
        When omitted (``None``), a new empty mapping is created.
    """

    def __init__(self, *, parts_info: MutableMapping[IR, PartitionInfo] | None = None):
        # Compare against None instead of relying on truthiness: an empty
        # mapping supplied by the caller must be kept (not swapped for a
        # fresh dict) so that entries added to it later stay visible here.
        self.__parts_info = {} if parts_info is None else parts_info

    def parts(self, ir: IR) -> PartitionInfo:
        """
        Return partitioning information for an IR node.

        This is a plain subscript lookup, so a node without an entry
        raises whatever the underlying mapping raises (``KeyError`` for
        a ``dict``).
        """
        return self.__parts_info[ir]


def get_key_name(node: Node) -> str:
    """
    Generate the key name for a Node.

    The name is the lower-cased class name of ``node`` joined by a
    hyphen to ``hash(node)``.
    """
    prefix = type(node).__name__.lower()
    token = hash(node)
    return f"{prefix}-{token}"
Expand Down Expand Up @@ -98,7 +87,9 @@ def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:


@singledispatch
def generate_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]:
def generate_ir_tasks(
ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
"""
Generate tasks for an IR node.
Expand All @@ -109,18 +100,20 @@ def generate_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]:


@generate_ir_tasks.register(IR)
def _default_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]:
def _default_ir_tasks(
ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
# Single-partition default behavior.
# This is used by `generate_ir_tasks` for all unregistered IR sub-types.
if state.parts(ir).count > 1:
if partition_info[ir].count > 1:
raise NotImplementedError(
f"Failed to generate multiple output tasks for {ir}."
) # pragma: no cover

child_names = []
for child in ir.children:
child_names.append(get_key_name(child))
if state.parts(child).count > 1:
if partition_info[child].count > 1:
raise NotImplementedError(
f"Failed to generate tasks for {ir} with child {child}."
) # pragma: no cover
Expand All @@ -135,11 +128,13 @@ def _default_ir_tasks(ir: IR, state: StateInfo) -> MutableMapping[Any, Any]:
}


def task_graph(ir: IR, state: StateInfo) -> tuple[MutableMapping[str, Any], str]:
def task_graph(
ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
) -> tuple[MutableMapping[str, Any], str]:
"""Construct a Dask-compatible task graph."""
graph = reduce(
operator.or_,
[generate_ir_tasks(node, state) for node in traversal(ir)],
[generate_ir_tasks(node, partition_info) for node in traversal(ir)],
)
key_name = get_key_name(ir)
graph[key_name] = (key_name, 0)
Expand All @@ -151,7 +146,7 @@ def evaluate_dask(ir: IR) -> DataFrame:
"""Evaluate an IR graph with Dask."""
from dask import get

ir, parts_info = lower_ir_graph(ir)
ir, partition_info = lower_ir_graph(ir)

graph, key = task_graph(ir, StateInfo(parts_info=parts_info))
graph, key = task_graph(ir, partition_info)
return get(graph, key)

0 comments on commit 62f10bc

Please sign in to comment.