Skip to content

Commit

Permalink
split logic into dedicated io.py file
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Nov 27, 2024
1 parent 1be7228 commit 36f59f1
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 35 deletions.
5 changes: 5 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@
from __future__ import annotations

__all__: list[str] = []

from cudf_polars.experimental.parallel import _register

# Register multi-partition dispatch functions
_register()
50 changes: 50 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Multi-partition IO Logic."""

from __future__ import annotations

import math
from typing import TYPE_CHECKING

from cudf_polars.dsl.ir import DataFrameScan, Union
from cudf_polars.experimental.parallel import lower_ir_node

if TYPE_CHECKING:
from collections.abc import MutableMapping

from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.parallel import LowerIRTransformer, PartitionInfo


##
## DataFrameScan
##


@lower_ir_node.register(DataFrameScan)
def _(
    ir: DataFrameScan, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """Lower a ``DataFrameScan``, splitting it into multiple partitions.

    When the in-memory frame holds more rows than the configured
    ``num_rows_threshold``, the scan is rewritten as a ``Union`` of
    evenly sized slices and re-lowered; otherwise the default
    single-partition mapping is used.
    """
    # Partitioning threshold, overridable through executor_options.
    executor_options = ir.config_options.get("executor_options", {})
    threshold = executor_options.get("num_rows_threshold", 1_000_000)

    # Clamp to >= 1 so an empty frame still maps to one partition.
    total_rows = max(ir.df.shape()[0], 1)
    num_pieces = math.ceil(total_rows / threshold)

    if num_pieces <= 1:
        # Small enough for a single partition; defer to default lowering.
        return rec.state["default_mapper"](ir)

    # Split into num_pieces slices of (near-)equal length; the final
    # slice may be shorter. Each slice is an independent DataFrameScan.
    step = math.ceil(total_rows / num_pieces)
    pieces = []
    for start in range(0, total_rows, step):
        pieces.append(
            DataFrameScan(
                ir.schema,
                ir.df.slice(start, step),
                ir.projection,
                ir.predicate,
                ir.config_options,
            )
        )
    # Wrap the slices in a Union and recurse so each is lowered on its own.
    return rec(Union(ir.schema, None, *pieces))
46 changes: 11 additions & 35 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@

from __future__ import annotations

import math
import operator
from functools import reduce, singledispatch
from typing import TYPE_CHECKING, Any

from cudf_polars.dsl.ir import IR, DataFrameScan, Union
from cudf_polars.dsl.ir import IR, Union
from cudf_polars.dsl.traversal import traversal

if TYPE_CHECKING:
Expand Down Expand Up @@ -263,39 +262,6 @@ def _concat(dfs: Sequence[DataFrame]) -> DataFrame:
return Union.do_evaluate(None, *dfs)


##
## DataFrameScan
##


@lower_ir_node.register(DataFrameScan)
def _(
    ir: DataFrameScan, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """Lower a ``DataFrameScan`` into one or more partitions.

    Splits the scan into a ``Union`` of slices when the frame exceeds
    the configured row threshold; otherwise uses the default mapping.
    """
    # Rows per output partition, configurable via executor_options.
    rows_per_partition = ir.config_options.get("executor_options", {}).get(
        "num_rows_threshold", 1_000_000
    )

    # Clamp to at least 1 so an empty frame still yields one partition.
    nrows = max(ir.df.shape()[0], 1)
    count = math.ceil(nrows / rows_per_partition)

    if count > 1:
        # Evenly sized slices (the last may be shorter), wrapped in a
        # Union and re-lowered so each slice is handled independently.
        length = math.ceil(nrows / count)
        slices = [
            DataFrameScan(
                ir.schema,
                ir.df.slice(offset, length),
                ir.projection,
                ir.predicate,
                ir.config_options,
            )
            for offset in range(0, nrows, length)
        ]
        return rec(Union(ir.schema, None, *slices))

    # Small enough for a single partition: fall back to default lowering.
    return rec.state["default_mapper"](ir)


##
## Union
##
Expand Down Expand Up @@ -334,3 +300,13 @@ def _(
graph[(key_name, part_out)] = (get_key_name(child), i)
part_out += 1
return graph


##
## Other
##


def _register() -> None:
    """Register multi-partition dispatch functions.

    Importing ``cudf_polars.experimental.io`` executes its module-level
    ``@lower_ir_node.register`` decorators as an import side effect; the
    module object itself is not used (hence the ``noqa: F401``).
    """
    import cudf_polars.experimental.io  # noqa: F401

0 comments on commit 36f59f1

Please sign in to comment.