From 123605395b9be89ecc2b6d6ea9fff2609240cac5 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 3 Dec 2024 11:35:06 -0800 Subject: [PATCH 1/4] add multi-partition scan support --- python/cudf_polars/cudf_polars/callback.py | 3 +- .../cudf_polars/experimental/io.py | 227 +++++++++++++++++- .../tests/experimental/test_scan.py | 80 ++++++ 3 files changed, 305 insertions(+), 5 deletions(-) create mode 100644 python/cudf_polars/tests/experimental/test_scan.py diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 29d3dc4ae79..074096446fd 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -231,7 +231,8 @@ def validate_config_options(config: dict) -> None: executor = config.get("executor", "pylibcudf") if executor == "dask-experimental": unsupported = config.get("executor_options", {}).keys() - { - "max_rows_per_partition" + "max_rows_per_partition", + "parquet_blocksize", } else: unsupported = config.get("executor_options", {}).keys() diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py index 3a1fec36079..13c2be505cd 100644 --- a/python/cudf_polars/cudf_polars/experimental/io.py +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -5,17 +5,20 @@ from __future__ import annotations import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from cudf_polars.dsl.ir import DataFrameScan, Union +import pylibcudf as plc + +from cudf_polars.dsl.ir import IR, DataFrameScan, Scan, Union from cudf_polars.experimental.base import PartitionInfo from cudf_polars.experimental.dispatch import lower_ir_node if TYPE_CHECKING: - from collections.abc import MutableMapping + from collections.abc import Hashable, MutableMapping - from cudf_polars.dsl.ir import IR + from cudf_polars.dsl.expr import NamedExpr from cudf_polars.experimental.dispatch import LowerIRTransformer + from cudf_polars.typing import Schema @lower_ir_node.register(DataFrameScan) @@ -47,3 +50,219 @@ def _( } return ir, {ir: PartitionInfo(count=1)} + + +class SplitScan(IR): + """Input from a split file.""" + + __slots__ = ( + "base_scan", + "schema", + "split_index", + "total_splits", + ) + _non_child = ( + "base_scan", + "split_index", + "total_splits", + ) + base_scan: Scan + """Scan operation this node is based on.""" + split_index: int + """Index of the current split.""" + total_splits: int + """Total number of splits.""" + + def __init__(self, base_scan: Scan, split_index: int, total_splits: int): + self.schema = base_scan.schema + self.base_scan = base_scan + self.split_index = split_index + self.total_splits = total_splits + self._non_child_args = ( + split_index, + total_splits, + *base_scan._non_child_args, + ) + self.children = () + if base_scan.typ not in ("parquet",): # pragma: no cover + raise NotImplementedError( + f"Unhandled Scan type for file splitting: {base_scan.typ}" + ) + + def get_hashable(self) -> Hashable: + """Hashable representation of node.""" + return ( + hash(self.base_scan), + self.split_index, + self.total_splits, + ) + + @classmethod + def do_evaluate( + cls, + split_index: int, + total_splits: int, + schema: Schema, + typ: str, + reader_options: dict[str, Any], + config_options: dict[str, Any], + paths: list[str], + with_columns: list[str] | None, + skip_rows: int, + n_rows: int, + row_index: tuple[str, int] | None, + predicate: NamedExpr | None, + ): + """Evaluate and return a dataframe.""" + if typ not in ("parquet",): # pragma: no cover + raise NotImplementedError(f"Unhandled Scan type for file splitting: {typ}") + + rowgroup_metadata = plc.io.parquet_metadata.read_parquet_metadata( + plc.io.SourceInfo(paths) + ).rowgroup_metadata() + total_row_groups = len(rowgroup_metadata) + if total_splits > total_row_groups: + # Don't bother aligning on row-groups + total_rows = sum(rg["num_rows"] for rg in rowgroup_metadata) + n_rows = int(total_rows / total_splits) + skip_rows = n_rows * split_index + else: + # Align split with row-groups + rg_stride = int(total_row_groups / total_splits) + skip_rgs = rg_stride * split_index + skip_rows = ( + sum(rg["num_rows"] for rg in rowgroup_metadata[:skip_rgs]) + if skip_rgs + else 0 + ) + n_rows = sum( + rg["num_rows"] + for rg in rowgroup_metadata[skip_rgs : skip_rgs + rg_stride] + ) + + # Last split should always read to end of file + if split_index == (total_splits - 1): + n_rows = -1 + + return Scan.do_evaluate( + schema, + typ, + reader_options, + config_options, + paths, + with_columns, + skip_rows, + n_rows, + row_index, + predicate, + ) + + +def _sample_pq_statistics(ir: Scan) -> dict[str, float]: + import numpy as np + import pyarrow.dataset as pa_ds + + # Use average total_uncompressed_size of three files + # TODO: Use plc.io.parquet_metadata.read_parquet_metadata + n_sample = 3 + column_sizes = {} + ds = pa_ds.dataset(ir.paths[:n_sample], format="parquet") + for i, frag in enumerate(ds.get_fragments()): + md = frag.metadata + for rg in range(md.num_row_groups): + row_group = md.row_group(rg) + for col in range(row_group.num_columns): + column = row_group.column(col) + name = column.path_in_schema + if name not in column_sizes: + column_sizes[name] = np.zeros(n_sample, dtype="int64") + column_sizes[name][i] += column.total_uncompressed_size + + return {name: np.mean(sizes) for name, sizes in column_sizes.items()} + + +def _scan_partitioning(ir: Scan) -> tuple[int, int]: + split, stride = 1, 1 + if ir.typ == "parquet": + file_size: float = 0 + # TODO: Use system info to set default blocksize + parallel_options = ir.config_options.get("executor_options", {}) + blocksize: int = parallel_options.get("parquet_blocksize", 1024**3) + stats = _sample_pq_statistics(ir) + columns: list = ir.with_columns or list(stats.keys()) + for name in columns: + file_size += float(stats[name]) + if file_size > 0: + if file_size > blocksize: + # Split large files + split = math.ceil(file_size / blocksize) + else: + # Aggregate small files + stride = max(int(blocksize / file_size), 1) + + # TODO: Use file sizes for csv and json + return (split, stride) + + +@lower_ir_node.register(Scan) +def _( + ir: Scan, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + partition_info: MutableMapping[IR, PartitionInfo] + if ir.typ in ("csv", "parquet", "ndjson") and ir.n_rows == -1 and ir.skip_rows == 0: + split, stride = _scan_partitioning(ir) + paths = list(ir.paths) + if split > 1: + # Disable chunked reader when splitting files + config_options = ir.config_options.copy() + config_options["parquet_options"] = config_options.get( + "parquet_options", {} + ).copy() + config_options["parquet_options"]["chunked"] = False + + slices: list[SplitScan] = [] + for path in paths: + base_scan = Scan( + ir.schema, + ir.typ, + ir.reader_options, + ir.cloud_options, + config_options, + [path], + ir.with_columns, + ir.skip_rows, + ir.n_rows, + ir.row_index, + ir.predicate, + ) + slices.extend( + SplitScan(base_scan, sindex, split) for sindex in range(split) + ) + new_node = Union(ir.schema, None, *slices) + partition_info = {slice: PartitionInfo(count=1) for slice in slices} | { + new_node: PartitionInfo(count=len(slices)) + } + else: + groups: list[Scan] = [ + Scan( + ir.schema, + ir.typ, + ir.reader_options, + ir.cloud_options, + ir.config_options, + paths[i : i + stride], + ir.with_columns, + ir.skip_rows, + ir.n_rows, + ir.row_index, + ir.predicate, + ) + for i in range(0, len(paths), stride) + ] + new_node = Union(ir.schema, None, *groups) + partition_info = {group: PartitionInfo(count=1) for group in groups} | { + new_node: PartitionInfo(count=len(groups)) + } + return new_node, partition_info + + return ir, {ir: PartitionInfo(count=1)} diff --git a/python/cudf_polars/tests/experimental/test_scan.py b/python/cudf_polars/tests/experimental/test_scan.py new file mode 100644 index 00000000000..a26d751dc86 --- /dev/null +++ b/python/cudf_polars/tests/experimental/test_scan.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars import Translator +from cudf_polars.experimental.parallel import lower_ir_graph +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture(scope="module") +def df(): + return pl.DataFrame( + { + "x": range(3_000), + "y": ["cat", "dog", "fish"] * 1_000, + "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 600, + } + ) + + +def make_source(df, path, fmt, n_files=3): + n_rows = len(df) + stride = int(n_rows / n_files) + for i in range(n_files): + offset = stride * i + part = df.slice(offset, stride) + if fmt == "csv": + part.write_csv(path / f"part.{i}.csv") + elif fmt == "ndjson": + part.write_ndjson(path / f"part.{i}.ndjson") + else: + part.write_parquet( + path / f"part.{i}.parquet", + row_group_size=int(stride / 2), + ) + + +@pytest.mark.parametrize( + "fmt, scan_fn", + [ + ("csv", pl.scan_csv), + ("ndjson", pl.scan_ndjson), + ("parquet", pl.scan_parquet), + ], +) +def test_parallel_scan(tmp_path, df, fmt, scan_fn): + make_source(df, tmp_path, fmt) + q = scan_fn(tmp_path) + engine = pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + ) + assert_gpu_result_equal(q, engine=engine) + + +@pytest.mark.parametrize("blocksize", [1_000, 10_000, 1_000_000]) +def test_parquet_blocksize(tmp_path, df, blocksize): + n_files = 3 + make_source(df, tmp_path, "parquet", n_files) + q = pl.scan_parquet(tmp_path) + engine = pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"parquet_blocksize": blocksize}, + ) + assert_gpu_result_equal(q, engine=engine) + + # Check partitioning + qir = Translator(q._ldf.visit(), engine).translate_ir() + ir, info = lower_ir_graph(qir) + count = info[ir].count + if blocksize <= 12_000: + assert count > n_files + else: + assert count < n_files From ddb5f71e5367ba10a31c2c56a13ea63bfe4a2176 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 4 Dec 2024 07:38:31 -0800 Subject: [PATCH 2/4] update coverage --- python/cudf_polars/cudf_polars/experimental/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py index 13c2be505cd..aa534024ec0 100644 --- a/python/cudf_polars/cudf_polars/experimental/io.py +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -265,4 +265,4 @@ def _( } return new_node, partition_info - return ir, {ir: PartitionInfo(count=1)} + return ir, {ir: PartitionInfo(count=1)} # pragma: no cover From ecbc1044e82317e99186274650d342ded6018f40 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 12 Dec 2024 11:59:06 -0800 Subject: [PATCH 3/4] use ScanPartitionPlan to clarify the logic a bit (maybe) --- .../cudf_polars/experimental/io.py | 143 ++++++++++++------ 1 file changed, 99 insertions(+), 44 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py index aa534024ec0..6315a5315b3 100644 --- a/python/cudf_polars/cudf_polars/experimental/io.py +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -4,7 +4,10 @@ from __future__ import annotations +import enum import math +import random +from enum import IntEnum from typing import TYPE_CHECKING, Any import pylibcudf as plc @@ -52,8 +55,66 @@ def _( return ir, {ir: PartitionInfo(count=1)} +class ScanPartitionFlavor(IntEnum): + """Flavor of Scan partitioning.""" + + SINGLE_FILE = enum.auto() # 1:1 mapping between files and partitions + SPLIT_FILES = enum.auto() # Split each file into >1 partition + FUSED_FILES = enum.auto() # Fuse multiple files into each partition + + +class ScanPartitionPlan: + """Scan partitioning plan.""" + + __slots__ = ("factor", "flavor") + factor: int + flavor: ScanPartitionFlavor + + def __init__(self, factor: int, flavor: ScanPartitionFlavor) -> None: + if ( + flavor == ScanPartitionFlavor.SINGLE_FILE and factor != 1 + ): # pragma: no cover + raise ValueError(f"Expected factor == 1 for {flavor}, got: {factor}") + self.factor = factor + self.flavor = flavor + + @staticmethod + def from_scan(ir: Scan) -> ScanPartitionPlan: + """Extract the partitioning plan of a Scan operation.""" + plan = ScanPartitionPlan(1, ScanPartitionFlavor.SINGLE_FILE) + if ir.typ == "parquet": + # TODO: Use system info to set default blocksize + parallel_options = ir.config_options.get("executor_options", {}) + blocksize: int = parallel_options.get("parquet_blocksize", 1024**3) + stats = _sample_pq_statistics(ir) + file_size = sum(float(stats[column]) for column in ir.schema) + if file_size > 0: + if file_size > blocksize: + # Split large files + plan = ScanPartitionPlan( + math.ceil(file_size / blocksize), + ScanPartitionFlavor.SPLIT_FILES, + ) + else: + # Fuse small files + plan = ScanPartitionPlan( + max(int(blocksize / file_size), 1), + ScanPartitionFlavor.FUSED_FILES, + ) + + # TODO: Use file sizes for csv and json + return plan + + class SplitScan(IR): - """Input from a split file.""" + """ + Input from a split file. + + This class wraps a single-file `Scan` object. At + IO/evaluation time, this class will only perform + a partial read of the underlying file. The range + (skip_rows and n_rows) is calculated at IO time. + """ __slots__ = ( "base_scan", @@ -92,7 +153,7 @@ def __init__(self, base_scan: Scan, split_index: int, total_splits: int): def get_hashable(self) -> Hashable: """Hashable representation of node.""" return ( - hash(self.base_scan), + self.base_scan, self.split_index, self.total_splits, ) @@ -117,33 +178,49 @@ def do_evaluate( if typ not in ("parquet",): # pragma: no cover raise NotImplementedError(f"Unhandled Scan type for file splitting: {typ}") + if len(paths) > 1: # pragma: no cover + raise ValueError(f"Expected a single path, got: {paths}") + + # Parquet logic: + # - We are one of "total_splits" SplitScan nodes + # assigned to the same file. + # - We know our index within this file ("split_index") + # - We can also use parquet metadata to query the + # total number of rows in each row-group of the file. + # - We can use all this information to calculate the + # "skip_rows" and "n_rows" options to use locally. + rowgroup_metadata = plc.io.parquet_metadata.read_parquet_metadata( plc.io.SourceInfo(paths) ).rowgroup_metadata() total_row_groups = len(rowgroup_metadata) - if total_splits > total_row_groups: - # Don't bother aligning on row-groups - total_rows = sum(rg["num_rows"] for rg in rowgroup_metadata) - n_rows = int(total_rows / total_splits) - skip_rows = n_rows * split_index - else: - # Align split with row-groups - rg_stride = int(total_row_groups / total_splits) + if total_splits <= total_row_groups: + # We have enough row-groups in the file to align + # all "total_splits" of our reads with row-group + # boundaries. Calculate which row-groups to include + # in the current read, and use metadata to translate + # the row-group indices to "skip_rows" and "n_rows". + rg_stride = total_row_groups // total_splits skip_rgs = rg_stride * split_index - skip_rows = ( - sum(rg["num_rows"] for rg in rowgroup_metadata[:skip_rgs]) - if skip_rgs - else 0 - ) + skip_rows = sum(rg["num_rows"] for rg in rowgroup_metadata[:skip_rgs]) n_rows = sum( rg["num_rows"] for rg in rowgroup_metadata[skip_rgs : skip_rgs + rg_stride] ) + else: + # There are not enough row-groups to align + # all "total_splits" of our reads with row-group + # boundaries. Use metadata to directly calculate + # "skip_rows" and "n_rows" for the current read. + total_rows = sum(rg["num_rows"] for rg in rowgroup_metadata) + n_rows = total_rows // total_splits + skip_rows = n_rows * split_index # Last split should always read to end of file if split_index == (total_splits - 1): n_rows = -1 + # Perform the partial read return Scan.do_evaluate( schema, typ, @@ -166,7 +243,7 @@ def _sample_pq_statistics(ir: Scan) -> dict[str, float]: # TODO: Use plc.io.parquet_metadata.read_parquet_metadata n_sample = 3 column_sizes = {} - ds = pa_ds.dataset(ir.paths[:n_sample], format="parquet") + ds = pa_ds.dataset(random.sample(ir.paths, n_sample), format="parquet") for i, frag in enumerate(ds.get_fragments()): md = frag.metadata for rg in range(md.num_row_groups): @@ -181,38 +258,15 @@ def _sample_pq_statistics(ir: Scan) -> dict[str, float]: return {name: np.mean(sizes) for name, sizes in column_sizes.items()} -def _scan_partitioning(ir: Scan) -> tuple[int, int]: - split, stride = 1, 1 - if ir.typ == "parquet": - file_size: float = 0 - # TODO: Use system info to set default blocksize - parallel_options = ir.config_options.get("executor_options", {}) - blocksize: int = parallel_options.get("parquet_blocksize", 1024**3) - stats = _sample_pq_statistics(ir) - columns: list = ir.with_columns or list(stats.keys()) - for name in columns: - file_size += float(stats[name]) - if file_size > 0: - if file_size > blocksize: - # Split large files - split = math.ceil(file_size / blocksize) - else: - # Aggregate small files - stride = max(int(blocksize / file_size), 1) - - # TODO: Use file sizes for csv and json - return (split, stride) - - @lower_ir_node.register(Scan) def _( ir: Scan, rec: LowerIRTransformer ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: partition_info: MutableMapping[IR, PartitionInfo] if ir.typ in ("csv", "parquet", "ndjson") and ir.n_rows == -1 and ir.skip_rows == 0: - split, stride = _scan_partitioning(ir) + plan = ScanPartitionPlan.from_scan(ir) paths = list(ir.paths) - if split > 1: + if plan.flavor == ScanPartitionFlavor.SPLIT_FILES: # Disable chunked reader when splitting files config_options = ir.config_options.copy() config_options["parquet_options"] = config_options.get( @@ -236,7 +290,8 @@ def _( ir.predicate, ) slices.extend( - SplitScan(base_scan, sindex, split) for sindex in range(split) + SplitScan(base_scan, sindex, plan.factor) + for sindex in range(plan.factor) ) new_node = Union(ir.schema, None, *slices) partition_info = {slice: PartitionInfo(count=1) for slice in slices} | { @@ -250,14 +305,14 @@ def _( ir.reader_options, ir.cloud_options, ir.config_options, - paths[i : i + stride], + paths[i : i + plan.factor], ir.with_columns, ir.skip_rows, ir.n_rows, ir.row_index, ir.predicate, ) - for i in range(0, len(paths), stride) + for i in range(0, len(paths), plan.factor) ] new_node = Union(ir.schema, None, *groups) partition_info = {group: PartitionInfo(count=1) for group in groups} | { From ecc8443ba7863bb1388cb94073d0f0be2569cbd5 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 19 Dec 2024 10:03:06 -0800 Subject: [PATCH 4/4] address review comments --- .../cudf_polars/experimental/io.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py index 6315a5315b3..2a5b400af4c 100644 --- a/python/cudf_polars/cudf_polars/experimental/io.py +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -17,7 +17,7 @@ from cudf_polars.experimental.dispatch import lower_ir_node if TYPE_CHECKING: - from collections.abc import Hashable, MutableMapping + from collections.abc import MutableMapping from cudf_polars.dsl.expr import NamedExpr from cudf_polars.experimental.dispatch import LowerIRTransformer @@ -64,7 +64,16 @@ class ScanPartitionFlavor(IntEnum): class ScanPartitionPlan: - """Scan partitioning plan.""" + """ + Scan partitioning plan. + + Notes + ----- + The meaning of `factor` depends on the value of `flavor`: + - SINGLE_FILE: `factor` must be `1`. + - SPLIT_FILES: `factor` is the number of partitions per file. + - FUSED_FILES: `factor` is the number of files per partition. + """ __slots__ = ("factor", "flavor") factor: int @@ -81,7 +90,6 @@ def __init__(self, factor: int, flavor: ScanPartitionFlavor) -> None: @staticmethod def from_scan(ir: Scan) -> ScanPartitionPlan: """Extract the partitioning plan of a Scan operation.""" - plan = ScanPartitionPlan(1, ScanPartitionFlavor.SINGLE_FILE) if ir.typ == "parquet": # TODO: Use system info to set default blocksize parallel_options = ir.config_options.get("executor_options", {}) @@ -91,19 +99,19 @@ def from_scan(ir: Scan) -> ScanPartitionPlan: if file_size > 0: if file_size > blocksize: # Split large files - plan = ScanPartitionPlan( + return ScanPartitionPlan( math.ceil(file_size / blocksize), ScanPartitionFlavor.SPLIT_FILES, ) else: # Fuse small files - plan = ScanPartitionPlan( - max(int(blocksize / file_size), 1), + return ScanPartitionPlan( + max(blocksize // int(file_size), 1), ScanPartitionFlavor.FUSED_FILES, ) # TODO: Use file sizes for csv and json - return plan + return ScanPartitionPlan(1, ScanPartitionFlavor.SINGLE_FILE) class SplitScan(IR): @@ -123,6 +131,7 @@ class SplitScan(IR): "total_splits", ) _non_child = ( + "schema", "base_scan", "split_index", "total_splits", @@ -134,8 +143,10 @@ class SplitScan(IR): total_splits: int """Total number of splits.""" - def __init__(self, base_scan: Scan, split_index: int, total_splits: int): - self.schema = base_scan.schema + def __init__( + self, schema: Schema, base_scan: Scan, split_index: int, total_splits: int + ): + self.schema = schema self.base_scan = base_scan self.split_index = split_index self.total_splits = total_splits @@ -150,14 +161,6 @@ def __init__(self, base_scan: Scan, split_index: int, total_splits: int): f"Unhandled Scan type for file splitting: {base_scan.typ}" ) - def get_hashable(self) -> Hashable: - """Hashable representation of node.""" - return ( - self.base_scan, - self.split_index, - self.total_splits, - ) - @classmethod def do_evaluate( cls, @@ -290,7 +293,7 @@ def _( ir.predicate, ) slices.extend( - SplitScan(base_scan, sindex, plan.factor) + SplitScan(ir.schema, base_scan, sindex, plan.factor) for sindex in range(plan.factor) ) new_node = Union(ir.schema, None, *slices)