diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 289ec91dfa7..8a8a6e336d4 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -35,7 +35,7 @@ def __init__(self, count: int): LowerIRTransformer: TypeAlias = ( - "GenericTransformer[IR, MutableMapping[IR, PartitionInfo]]" + "GenericTransformer[IR, tuple[IR, MutableMapping[IR, PartitionInfo]]]" ) """Protocol for Lowering IR nodes.""" @@ -278,22 +278,22 @@ def _( nrows = max(ir.df.shape()[0], 1) count = math.ceil(nrows / rows_per_partition) - length = math.ceil(nrows / count) - - slices = [ - DataFrameScan( - ir.schema, - ir.df.slice(offset, length), - ir.projection, - ir.predicate, - ir.config_options, - ) - for offset in range(0, nrows, length) - ] - new_node = Union(ir.schema, None, *slices) - return new_node, {slice: PartitionInfo(count=1) for slice in slices} | { - new_node: PartitionInfo(count=count) - } + + if count > 1: + length = math.ceil(nrows / count) + slices = [ + DataFrameScan( + ir.schema, + ir.df.slice(offset, length), + ir.projection, + ir.predicate, + ir.config_options, + ) + for offset in range(0, nrows, length) + ] + return rec(Union(ir.schema, None, *slices)) + + return rec.state["default_mapper"](ir) ## @@ -307,7 +307,7 @@ def _( ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: # zlice must be None if ir.zlice is not None: - return rec.state["default_mapper"](ir) + return rec.state["default_mapper"](ir) # pragma: no cover # Lower children children, _partition_info = zip(*(rec(c) for c in ir.children), strict=False)