diff --git a/python/cudf_polars/cudf_polars/experimental/__init__.py b/python/cudf_polars/cudf_polars/experimental/__init__.py index 6fd93bf5157..4a2aef68e18 100644 --- a/python/cudf_polars/cudf_polars/experimental/__init__.py +++ b/python/cudf_polars/cudf_polars/experimental/__init__.py @@ -6,3 +6,8 @@ from __future__ import annotations __all__: list[str] = [] + +from cudf_polars.experimental.parallel import _register + +# Register multi-partition dispatch functions +_register() diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py new file mode 100644 index 00000000000..d321283ebaa --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-partition IO Logic.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +from cudf_polars.dsl.ir import DataFrameScan, Union +from cudf_polars.experimental.parallel import lower_ir_node + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from cudf_polars.dsl.ir import IR + from cudf_polars.experimental.parallel import LowerIRTransformer, PartitionInfo + + +## +## DataFrameScan +## + + +@lower_ir_node.register(DataFrameScan) +def _( + ir: DataFrameScan, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + rows_per_partition = ir.config_options.get("executor_options", {}).get( + "num_rows_threshold", 1_000_000 + ) + + nrows = max(ir.df.shape()[0], 1) + count = math.ceil(nrows / rows_per_partition) + + if count > 1: + length = math.ceil(nrows / count) + slices = [ + DataFrameScan( + ir.schema, + ir.df.slice(offset, length), + ir.projection, + ir.predicate, + ir.config_options, + ) + for offset in range(0, nrows, length) + ] + return rec(Union(ir.schema, None, *slices)) + + return rec.state["default_mapper"](ir) diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 8a8a6e336d4..091ba6d0791 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -4,12 +4,11 @@ from __future__ import annotations -import math import operator from functools import reduce, singledispatch from typing import TYPE_CHECKING, Any -from cudf_polars.dsl.ir import IR, DataFrameScan, Union +from cudf_polars.dsl.ir import IR, Union from cudf_polars.dsl.traversal import traversal if TYPE_CHECKING: @@ -263,39 +262,6 @@ def _concat(dfs: Sequence[DataFrame]) -> DataFrame: return Union.do_evaluate(None, *dfs) -## -## DataFrameScan -## - - -@lower_ir_node.register(DataFrameScan) -def _( - ir: DataFrameScan, rec: LowerIRTransformer -) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: - rows_per_partition = ir.config_options.get("executor_options", {}).get( - "num_rows_threshold", 1_000_000 - ) - - nrows = max(ir.df.shape()[0], 1) - count = math.ceil(nrows / rows_per_partition) - - if count > 1: - length = math.ceil(nrows / count) - slices = [ - DataFrameScan( - ir.schema, - ir.df.slice(offset, length), - ir.projection, - ir.predicate, - ir.config_options, - ) - for offset in range(0, nrows, length) - ] - return rec(Union(ir.schema, None, *slices)) - - return rec.state["default_mapper"](ir) - - ## ## Union ## @@ -334,3 +300,13 @@ def _( graph[(key_name, part_out)] = (get_key_name(child), i) part_out += 1 return graph + + +## +## Other +## + + +def _register(): + """Register multi-partition dispatch functions.""" + import cudf_polars.experimental.io # noqa: F401