From 12c77f32eee3b1aa0ba5592d9f25b4664104bd04 Mon Sep 17 00:00:00 2001
From: tequilayu <48981002+tequilayu@users.noreply.github.com>
Date: Tue, 3 Dec 2024 08:57:51 +0800
Subject: [PATCH 1/5] add comment to Series.tolist method (#17350)

closes #15767

This PR adds a docstring to the `Series.tolist` method. It notes that the
method raises a `TypeError` when called and suggests alternatives.

Authors:
  - https://github.com/tequilayu
  - Michael Wang (https://github.com/isVoid)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: https://github.com/rapidsai/cudf/pull/17350
---
 python/cudf/cudf/core/series.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py
index 928f3c3d666..58cefc6554e 100644
--- a/python/cudf/cudf/core/series.py
+++ b/python/cudf/cudf/core/series.py
@@ -943,6 +943,19 @@ def drop(
         )
 
     def tolist(self):
+        """Conversion to host memory lists is currently unsupported.
+
+        Raises
+        ------
+        TypeError
+            If this method is called
+
+        Notes
+        -----
+        cuDF currently does not support implicit conversion from GPU-stored series to
+        host-stored lists. A `TypeError` is raised when this method is called.
+        Consider calling `.to_arrow().to_pylist()` to construct a Python list.
+        """
         raise TypeError(
             "cuDF does not support conversion to host memory "
             "via the `tolist()` method. Consider using "

From 3785a48eb81be23b44b895624f21acbfc1a828c5 Mon Sep 17 00:00:00 2001
From: "Richard (Rick) Zamora"
Date: Tue, 3 Dec 2024 11:17:04 -0600
Subject: [PATCH 2/5] Add multi-partition `DataFrameScan` support to cuDF-Polars (#17441)

Follow-up to https://github.com/rapidsai/cudf/pull/17262

Adds support for parallel `DataFrameScan` operations.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: https://github.com/rapidsai/cudf/pull/17441
---
 python/cudf_polars/cudf_polars/callback.py    |  14 +-
 python/cudf_polars/cudf_polars/dsl/ir.py      |  17 +-
 .../cudf_polars/cudf_polars/dsl/translate.py  |   1 +
 .../cudf_polars/experimental/base.py          |  43 +++
 .../cudf_polars/experimental/dispatch.py      |  84 ++++++
 .../cudf_polars/experimental/io.py            |  49 ++++
 .../cudf_polars/experimental/parallel.py      | 245 +++++++++---------
 .../cudf_polars/tests/dsl/test_traversal.py   |  12 +-
 .../tests/experimental/test_dataframescan.py  |  53 ++++
 python/cudf_polars/tests/test_executors.py    |  16 ++
 10 files changed, 411 insertions(+), 123 deletions(-)
 create mode 100644 python/cudf_polars/cudf_polars/experimental/base.py
 create mode 100644 python/cudf_polars/cudf_polars/experimental/dispatch.py
 create mode 100644 python/cudf_polars/cudf_polars/experimental/io.py
 create mode 100644 python/cudf_polars/tests/experimental/test_dataframescan.py

diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py
index 95527028aa9..29d3dc4ae79 100644
--- a/python/cudf_polars/cudf_polars/callback.py
+++ b/python/cudf_polars/cudf_polars/callback.py
@@ -217,7 +217,8 @@ def validate_config_options(config: dict) -> None:
         If the configuration contains unsupported options.
""" if unsupported := ( - config.keys() - {"raise_on_fail", "parquet_options", "executor"} + config.keys() + - {"raise_on_fail", "parquet_options", "executor", "executor_options"} ): raise ValueError( f"Engine configuration contains unsupported settings: {unsupported}" @@ -226,6 +227,17 @@ def validate_config_options(config: dict) -> None: config.get("parquet_options", {}) ) + # Validate executor_options + executor = config.get("executor", "pylibcudf") + if executor == "dask-experimental": + unsupported = config.get("executor_options", {}).keys() - { + "max_rows_per_partition" + } + else: + unsupported = config.get("executor_options", {}).keys() + if unsupported: + raise ValueError(f"Unsupported executor_options for {executor}: {unsupported}") + def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None: """ diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index a28b4cf25b2..1faa778ccf6 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -688,14 +688,16 @@ class DataFrameScan(IR): This typically arises from ``q.collect().lazy()`` """ - __slots__ = ("df", "predicate", "projection") - _non_child = ("schema", "df", "projection", "predicate") + __slots__ = ("config_options", "df", "predicate", "projection") + _non_child = ("schema", "df", "projection", "predicate", "config_options") df: Any """Polars LazyFrame object.""" projection: tuple[str, ...] | None """List of columns to project out.""" predicate: expr.NamedExpr | None """Mask to apply.""" + config_options: dict[str, Any] + """GPU-specific configuration options""" def __init__( self, @@ -703,11 +705,13 @@ def __init__( df: Any, projection: Sequence[str] | None, predicate: expr.NamedExpr | None, + config_options: dict[str, Any], ): self.schema = schema self.df = df self.projection = tuple(projection) if projection is not None else None self.predicate = predicate + self.config_options = config_options self._non_child_args = (schema, df, self.projection, predicate) self.children = () @@ -719,7 +723,14 @@ def get_hashable(self) -> Hashable: not stable across runs, or repeat instances of the same equal dataframes. """ schema_hash = tuple(self.schema.items()) - return (type(self), schema_hash, id(self.df), self.projection, self.predicate) + return ( + type(self), + schema_hash, + id(self.df), + self.projection, + self.predicate, + json.dumps(self.config_options), + ) @classmethod def do_evaluate( diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index b1e2de63ba6..37cf36dc4dd 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -263,6 +263,7 @@ def _( translate_named_expr(translator, n=node.selection) if node.selection is not None else None, + translator.config.config.copy(), ) diff --git a/python/cudf_polars/cudf_polars/experimental/base.py b/python/cudf_polars/cudf_polars/experimental/base.py new file mode 100644 index 00000000000..8f660632df2 --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/base.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 +"""Multi-partition base classes.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from cudf_polars.dsl.ir import Union + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from cudf_polars.containers import DataFrame + from cudf_polars.dsl.nodebase import Node + + +class PartitionInfo: + """ + Partitioning information. + + This class only tracks the partition count (for now). + """ + + __slots__ = ("count",) + + def __init__(self, count: int): + self.count = count + + def keys(self, node: Node) -> Iterator[tuple[str, int]]: + """Return the partitioned keys for a given node.""" + name = get_key_name(node) + yield from ((name, i) for i in range(self.count)) + + +def get_key_name(node: Node) -> str: + """Generate the key name for a Node.""" + return f"{type(node).__name__.lower()}-{hash(node)}" + + +def _concat(dfs: Sequence[DataFrame]) -> DataFrame: + # Concatenate a sequence of DataFrames vertically + return Union.do_evaluate(None, *dfs) diff --git a/python/cudf_polars/cudf_polars/experimental/dispatch.py b/python/cudf_polars/cudf_polars/experimental/dispatch.py new file mode 100644 index 00000000000..79a52ff3cde --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/dispatch.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-partition dispatch functions.""" + +from __future__ import annotations + +from functools import singledispatch +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import MutableMapping + from typing import TypeAlias + + from cudf_polars.dsl.ir import IR + from cudf_polars.experimental.base import PartitionInfo + from cudf_polars.typing import GenericTransformer + + +LowerIRTransformer: TypeAlias = ( + "GenericTransformer[IR, tuple[IR, MutableMapping[IR, PartitionInfo]]]" +) +"""Protocol for Lowering IR nodes.""" + + +@singledispatch +def lower_ir_node( + ir: IR, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + """ + Rewrite an IR node and extract partitioning information. + + Parameters + ---------- + ir + IR node to rewrite. + rec + Recursive LowerIRTransformer callable. + + Returns + ------- + new_ir, partition_info + The rewritten node, and a mapping from unique nodes in + the full IR graph to associated partitioning information. + + Notes + ----- + This function is used by `lower_ir_graph`. + + See Also + -------- + lower_ir_graph + """ + raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + + +@singledispatch +def generate_ir_tasks( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + """ + Generate a task graph for evaluation of an IR node. + + Parameters + ---------- + ir + IR node to generate tasks for. + partition_info + Partitioning information, obtained from :func:`lower_ir_graph`. + + Returns + ------- + mapping + A (partial) dask task graph for the evaluation of an ir node. + + Notes + ----- + Task generation should only produce the tasks for the current node, + referring to child tasks by name. 
+ + See Also + -------- + task_graph + """ + raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover diff --git a/python/cudf_polars/cudf_polars/experimental/io.py b/python/cudf_polars/cudf_polars/experimental/io.py new file mode 100644 index 00000000000..3a1fec36079 --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/io.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-partition IO Logic.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +from cudf_polars.dsl.ir import DataFrameScan, Union +from cudf_polars.experimental.base import PartitionInfo +from cudf_polars.experimental.dispatch import lower_ir_node + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from cudf_polars.dsl.ir import IR + from cudf_polars.experimental.dispatch import LowerIRTransformer + + +@lower_ir_node.register(DataFrameScan) +def _( + ir: DataFrameScan, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + rows_per_partition = ir.config_options.get("executor_options", {}).get( + "max_rows_per_partition", 1_000_000 + ) + + nrows = max(ir.df.shape()[0], 1) + count = math.ceil(nrows / rows_per_partition) + + if count > 1: + length = math.ceil(nrows / count) + slices = [ + DataFrameScan( + ir.schema, + ir.df.slice(offset, length), + ir.projection, + ir.predicate, + ir.config_options, + ) + for offset in range(0, nrows, length) + ] + new_node = Union(ir.schema, None, *slices) + return new_node, {slice: PartitionInfo(count=1) for slice in slices} | { + new_node: PartitionInfo(count=count) + } + + return ir, {ir: PartitionInfo(count=1)} diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index 6518dd60c7d..e5884f1c574 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -1,93 +1,46 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 -"""Partitioned LogicalPlan nodes.""" +"""Multi-partition Dask execution.""" from __future__ import annotations +import itertools import operator -from functools import reduce, singledispatch +from functools import reduce from typing import TYPE_CHECKING, Any -from cudf_polars.dsl.ir import IR -from cudf_polars.dsl.traversal import traversal +import cudf_polars.experimental.io # noqa: F401 +from cudf_polars.dsl.ir import IR, Cache, Projection, Union +from cudf_polars.dsl.traversal import CachingVisitor, traversal +from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name +from cudf_polars.experimental.dispatch import ( + generate_ir_tasks, + lower_ir_node, +) if TYPE_CHECKING: from collections.abc import MutableMapping - from typing import TypeAlias from cudf_polars.containers import DataFrame - from cudf_polars.dsl.nodebase import Node - from cudf_polars.typing import GenericTransformer - - -class PartitionInfo: - """ - Partitioning information. - - This class only tracks the partition count (for now). 
- """ - - __slots__ = ("count",) - - def __init__(self, count: int): - self.count = count - - -LowerIRTransformer: TypeAlias = ( - "GenericTransformer[IR, MutableMapping[IR, PartitionInfo]]" -) -"""Protocol for Lowering IR nodes.""" - - -def get_key_name(node: Node) -> str: - """Generate the key name for a Node.""" - return f"{type(node).__name__.lower()}-{hash(node)}" - - -@singledispatch -def lower_ir_node( - ir: IR, rec: LowerIRTransformer -) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: - """ - Rewrite an IR node and extract partitioning information. - - Parameters - ---------- - ir - IR node to rewrite. - rec - Recursive LowerIRTransformer callable. - - Returns - ------- - new_ir, partition_info - The rewritten node, and a mapping from unique nodes in - the full IR graph to associated partitioning information. - - Notes - ----- - This function is used by `lower_ir_graph`. - - See Also - -------- - lower_ir_graph - """ - raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + from cudf_polars.experimental.dispatch import LowerIRTransformer @lower_ir_node.register(IR) def _(ir: IR, rec: LowerIRTransformer) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + # Default logic - Requires single partition + if len(ir.children) == 0: # Default leaf node has single partition - return ir, {ir: PartitionInfo(count=1)} + return ir, { + ir: PartitionInfo(count=1) + } # pragma: no cover; Missed by pylibcudf executor # Lower children - children, _partition_info = zip(*(rec(c) for c in ir.children), strict=False) + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) partition_info = reduce(operator.or_, _partition_info) # Check that child partitioning is supported - count = max(partition_info[c].count for c in children) - if count > 1: + if any(partition_info[c].count > 1 for c in children): raise NotImplementedError( f"Class {type(ir)} does not support multiple partitions." ) # pragma: no cover @@ -123,41 +76,62 @@ def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: -------- lower_ir_node """ - from cudf_polars.dsl.traversal import CachingVisitor - mapper = CachingVisitor(lower_ir_node) return mapper(ir) -@singledispatch -def generate_ir_tasks( +def task_graph( ir: IR, partition_info: MutableMapping[IR, PartitionInfo] -) -> MutableMapping[Any, Any]: +) -> tuple[MutableMapping[Any, Any], str | tuple[str, int]]: """ - Generate a task graph for evaluation of an IR node. + Construct a task graph for evaluation of an IR graph. Parameters ---------- ir - IR node to generate tasks for. + Root of the graph to rewrite. partition_info - Partitioning information, obtained from :func:`lower_ir_graph`. + A mapping from all unique IR nodes to the + associated partitioning information. Returns ------- - mapping - A (partial) dask task graph for the evaluation of an ir node. + graph + A Dask-compatible task graph for the entire + IR graph with root `ir`. Notes ----- - Task generation should only produce the tasks for the current node, - referring to child tasks by name. + This function traverses the unique nodes of the + graph with root `ir`, and extracts the tasks for + each node with :func:`generate_ir_tasks`. 
See Also -------- - task_graph + generate_ir_tasks """ - raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + graph = reduce( + operator.or_, + (generate_ir_tasks(node, partition_info) for node in traversal(ir)), + ) + + key_name = get_key_name(ir) + partition_count = partition_info[ir].count + if partition_count > 1: + graph[key_name] = (_concat, list(partition_info[ir].keys(ir))) + return graph, key_name + else: + return graph, (key_name, 0) + + +def evaluate_dask(ir: IR) -> DataFrame: + """Evaluate an IR graph with Dask.""" + from dask import get + + ir, partition_info = lower_ir_graph(ir) + + graph, key = task_graph(ir, partition_info) + return get(graph, key) @generate_ir_tasks.register(IR) @@ -189,48 +163,85 @@ def _( } -def task_graph( - ir: IR, partition_info: MutableMapping[IR, PartitionInfo] -) -> tuple[MutableMapping[Any, Any], str | tuple[str, int]]: - """ - Construct a task graph for evaluation of an IR graph. +@lower_ir_node.register(Union) +def _( + ir: Union, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + # Lower children + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) + partition_info = reduce(operator.or_, _partition_info) - Parameters - ---------- - ir - Root of the graph to rewrite. - partition_info - A mapping from all unique IR nodes to the - associated partitioning information. + # Check zlice + if ir.zlice is not None: # pragma: no cover + if any(p[c].count > 1 for p, c in zip(children, _partition_info, strict=False)): + raise NotImplementedError("zlice is not supported for multiple partitions.") + new_node = ir.reconstruct(children) + partition_info[new_node] = PartitionInfo(count=1) + return new_node, partition_info - Returns - ------- - graph - A Dask-compatible task graph for the entire - IR graph with root `ir`. + # Partition count is the sum of all child partitions + count = sum(partition_info[c].count for c in children) - Notes - ----- - This function traverses the unique nodes of the - graph with root `ir`, and extracts the tasks for - each node with :func:`generate_ir_tasks`. + # Return reconstructed node and partition-info dict + new_node = ir.reconstruct(children) + partition_info[new_node] = PartitionInfo(count=count) + return new_node, partition_info - See Also - -------- - generate_ir_tasks - """ - graph = reduce( - operator.or_, - (generate_ir_tasks(node, partition_info) for node in traversal(ir)), - ) - return graph, (get_key_name(ir), 0) +@generate_ir_tasks.register(Union) +def _( + ir: Union, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + key_name = get_key_name(ir) + partition = itertools.count() + return { + (key_name, next(partition)): child_key + for child in ir.children + for child_key in partition_info[child].keys(child) + } -def evaluate_dask(ir: IR) -> DataFrame: - """Evaluate an IR graph with Dask.""" - from dask import get - ir, partition_info = lower_ir_graph(ir) +def _lower_ir_pwise( + ir: IR, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + # Lower a partition-wise (i.e. 
embarrassingly-parallel) IR node - graph, key = task_graph(ir, partition_info) - return get(graph, key) + # Lower children + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) + partition_info = reduce(operator.or_, _partition_info) + counts = {partition_info[c].count for c in children} + + # Check that child partitioning is supported + if len(counts) > 1: + raise NotImplementedError( + f"Class {type(ir)} does not support unbalanced partitions." + ) # pragma: no cover + + # Return reconstructed node and partition-info dict + partition = PartitionInfo(count=max(counts)) + new_node = ir.reconstruct(children) + partition_info[new_node] = partition + return new_node, partition_info + + +lower_ir_node.register(Projection, _lower_ir_pwise) +lower_ir_node.register(Cache, _lower_ir_pwise) + + +def _generate_ir_tasks_pwise( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + # Generate partition-wise (i.e. embarrassingly-parallel) tasks + child_names = [get_key_name(c) for c in ir.children] + return { + key: ( + ir.do_evaluate, + *ir._non_child_args, + *[(child_name, i) for child_name in child_names], + ) + for i, key in enumerate(partition_info[ir].keys(ir)) + } + + +generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise) +generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise) diff --git a/python/cudf_polars/tests/dsl/test_traversal.py b/python/cudf_polars/tests/dsl/test_traversal.py index 2f4df9289f8..9755994c419 100644 --- a/python/cudf_polars/tests/dsl/test_traversal.py +++ b/python/cudf_polars/tests/dsl/test_traversal.py @@ -116,7 +116,11 @@ def test_rewrite_ir_node(): def replace_df(node, rec): if isinstance(node, ir.DataFrameScan): return ir.DataFrameScan( - node.schema, new_df._df, node.projection, node.predicate + node.schema, + new_df._df, + node.projection, + node.predicate, + node.config_options, ) return reuse_if_unchanged(node, rec) @@ -144,7 +148,11 @@ def test_rewrite_scan_node(tmp_path): def replace_scan(node, rec): if isinstance(node, ir.Scan): return ir.DataFrameScan( - node.schema, right._df, node.with_columns, node.predicate + node.schema, + right._df, + node.with_columns, + node.predicate, + node.config_options, ) return reuse_if_unchanged(node, rec) diff --git a/python/cudf_polars/tests/experimental/test_dataframescan.py b/python/cudf_polars/tests/experimental/test_dataframescan.py new file mode 100644 index 00000000000..77c7bf0c503 --- /dev/null +++ b/python/cudf_polars/tests/experimental/test_dataframescan.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars import Translator +from cudf_polars.experimental.parallel import lower_ir_graph +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture(scope="module") +def df(): + return pl.LazyFrame( + { + "x": range(30_000), + "y": ["cat", "dog", "fish"] * 10_000, + "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 6_000, + } + ) + + +@pytest.mark.parametrize("max_rows_per_partition", [1_000, 1_000_000]) +def test_parallel_dataframescan(df, max_rows_per_partition): + total_row_count = len(df.collect()) + engine = pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"max_rows_per_partition": max_rows_per_partition}, + ) + assert_gpu_result_equal(df, engine=engine) + + # Check partitioning + qir = Translator(df._ldf.visit(), engine).translate_ir() + ir, info = lower_ir_graph(qir) + count = info[ir].count + if max_rows_per_partition < total_row_count: + assert count > 1 + else: + assert count == 1 + + +def test_dataframescan_concat(df): + engine = pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"max_rows_per_partition": 1_000}, + ) + df2 = pl.concat([df, df]) + assert_gpu_result_equal(df2, engine=engine) diff --git a/python/cudf_polars/tests/test_executors.py b/python/cudf_polars/tests/test_executors.py index 3eaea2ec9ea..b8c0bb926ab 100644 --- a/python/cudf_polars/tests/test_executors.py +++ b/python/cudf_polars/tests/test_executors.py @@ -66,3 +66,19 @@ def test_unknown_executor(): match="ValueError: Unknown executor 'unknown-executor'", ): assert_gpu_result_equal(df, executor="unknown-executor") + + +@pytest.mark.parametrize("executor", [None, "pylibcudf", "dask-experimental"]) +def test_unknown_executor_options(executor): + df = pl.LazyFrame({}) + + with pytest.raises( + pl.exceptions.ComputeError, + match="Unsupported executor_options", + ): + df.collect( + engine=pl.GPUEngine( + executor=executor, + executor_options={"foo": None}, + ) + ) From 4696bbf91ca37ab6960b606d1f7763487ee03ef6 Mon Sep 17 00:00:00 2001 From: Matthew Murray <41342305+Matt711@users.noreply.github.com> Date: Tue, 3 Dec 2024 12:58:35 -0500 Subject: [PATCH 3/5] Revert "Temporarily skip tests due to dask/distributed#8953" (#17492) Reverts rapidsai/cudf#17472 The new dask nightly has resolved https://github.com/dask/distributed/issues/8953 --- .../custreamz/tests/test_dataframes.py | 56 +++---------------- 1 file changed, 7 insertions(+), 49 deletions(-) diff --git a/python/custreamz/custreamz/tests/test_dataframes.py b/python/custreamz/custreamz/tests/test_dataframes.py index 6905044039c..8c0130d2818 100644 --- a/python/custreamz/custreamz/tests/test_dataframes.py +++ b/python/custreamz/custreamz/tests/test_dataframes.py @@ -216,13 +216,7 @@ def test_set_index(): assert_eq(b[0], df.set_index(df.y + 1)) -def test_binary_stream_operators(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_binary_stream_operators(stream): df = cudf.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) expected = df.x + df.y @@ -248,13 +242,7 @@ def test_index(stream): assert_eq(L[1], df.index + 5) -def test_pair_arithmetic(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_pair_arithmetic(stream): df 
= cudf.DataFrame({"x": list(range(10)), "y": [1] * 10}) a = DataFrame(example=df.iloc[:0], stream=stream) @@ -267,13 +255,7 @@ def test_pair_arithmetic(request, stream): assert_eq(cudf.concat(L), (df.x + df.y) * 2) -def test_getitem(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_getitem(stream): df = cudf.DataFrame({"x": list(range(10)), "y": [1] * 10}) a = DataFrame(example=df.iloc[:0], stream=stream) @@ -350,13 +332,7 @@ def test_repr_html(stream): assert "1" in html -def test_setitem(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_setitem(stream): df = cudf.DataFrame({"x": list(range(10)), "y": [1] * 10}) sdf = DataFrame(example=df.iloc[:0], stream=stream) @@ -380,13 +356,7 @@ def test_setitem(request, stream): assert_eq(L[-1], df.mean()) -def test_setitem_overwrites(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_setitem_overwrites(stream): df = cudf.DataFrame({"x": list(range(10))}) sdf = DataFrame(example=df.iloc[:0], stream=stream) stream = sdf.stream @@ -443,14 +413,8 @@ def test_setitem_overwrites(request, stream): ], ) def test_rolling_count_aggregations( - request, op, window, m, pre_get, post_get, kwargs, stream + op, window, m, pre_get, post_get, kwargs, stream ): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream) and len(kwargs) == 0, - reason="https://github.com/dask/distributed/issues/8953", - ) - ) index = pd.DatetimeIndex( pd.date_range("2000-01-01", "2000-01-03", freq="1h") ) @@ -844,13 +808,7 @@ def test_reductions_with_start_state(stream): assert output2[0] == 360 -def test_rolling_aggs_with_start_state(request, stream): - request.applymarker( - pytest.mark.xfail( - isinstance(stream, DaskStream), - reason="https://github.com/dask/distributed/issues/8953", - ) - ) +def test_rolling_aggs_with_start_state(stream): example = cudf.DataFrame({"name": [], "amount": []}, dtype="float64") sdf = DataFrame(stream, example=example) output0 = ( From d3e94d458ddeaced5ba34a825ab0af5275b73dbe Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Tue, 3 Dec 2024 10:03:29 -0800 Subject: [PATCH 4/5] Apply clang-tidy autofixes from new rules (#17431) This PR contains all of clang-tidy's autofixes for the rules outlined in https://github.com/rapidsai/cudf/issues/17410. In the process I simplified the process of performing autofixes locally. 
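
As a rough sketch of how a local autofix pass can be driven with the new
`CUDF_CLANG_TIDY_AUTOFIX` CMake option (mirroring the configure line in
`ci/cpp_linters.sh`; the exact build invocation and extra flags for a given
setup may differ):

```bash
# Configure with clang-tidy enabled and let it apply fixes in place.
cmake -S cpp -B cpp/build -DCMAKE_BUILD_TYPE=Release \
      -DCUDF_CLANG_TIDY=ON -DCUDF_CLANG_TIDY_AUTOFIX=ON \
      -DBUILD_TESTS=OFF -GNinja

# clang-tidy runs (and writes fixes back to the sources) as files compile.
cmake --build cpp/build
```

With `CUDF_CLANG_TIDY_AUTOFIX=ON`, clang-tidy is invoked with `--fix` during
compilation, so the autofixes land directly in the source tree.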
Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - David Wendt (https://github.com/davidwendt) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/17431 --- ci/cpp_linters.sh | 2 +- cpp/CMakeLists.txt | 14 +- cpp/src/bitmask/is_element_valid.cpp | 4 +- cpp/src/column/column_view.cpp | 97 +++--- cpp/src/copying/copy.cpp | 12 +- cpp/src/copying/pack.cpp | 81 +++-- cpp/src/datetime/timezone.cpp | 2 +- cpp/src/groupby/sort/aggregate.cpp | 96 +++--- cpp/src/interop/dlpack.cpp | 4 +- cpp/src/interop/to_arrow_schema.cpp | 4 +- cpp/src/io/avro/avro.cpp | 12 +- cpp/src/io/comp/comp.cpp | 8 +- cpp/src/io/comp/nvcomp_adapter.cpp | 280 +++++++++--------- cpp/src/io/comp/uncomp.cpp | 40 +-- cpp/src/io/functions.cpp | 63 ++-- cpp/src/io/json/parser_features.cpp | 139 ++++----- cpp/src/io/parquet/arrow_schema_writer.cpp | 2 +- .../io/parquet/compact_protocol_reader.cpp | 131 ++++---- .../io/parquet/compact_protocol_writer.cpp | 2 +- cpp/src/io/parquet/predicate_pushdown.cpp | 5 +- cpp/src/io/parquet/reader_impl.cpp | 6 +- cpp/src/io/parquet/reader_impl_helpers.cpp | 18 +- cpp/src/io/text/bgzip_utils.cpp | 2 +- cpp/src/io/utilities/base64_utilities.cpp | 6 +- cpp/src/io/utilities/data_sink.cpp | 4 +- cpp/src/io/utilities/datasource.cpp | 8 +- cpp/src/io/utilities/file_io_utilities.cpp | 41 +-- cpp/src/jit/cache.cpp | 12 +- cpp/src/jit/parser.cpp | 56 ++-- .../quantiles/tdigest/tdigest_column_view.cpp | 8 +- cpp/src/reductions/scan/scan.cpp | 3 +- cpp/src/reductions/segmented/reductions.cpp | 3 + .../detail/optimized_unbounded_window.cpp | 54 ++-- cpp/src/strings/regex/regcomp.cpp | 14 +- cpp/src/strings/regex/regexec.cpp | 6 +- cpp/src/structs/utilities.cpp | 2 +- cpp/src/table/table_view.cpp | 33 ++- cpp/src/transform/transform.cpp | 7 +- cpp/src/utilities/prefetch.cpp | 4 +- cpp/src/utilities/stream_pool.cpp | 112 +++---- 40 files changed, 722 insertions(+), 675 deletions(-) diff --git a/ci/cpp_linters.sh b/ci/cpp_linters.sh index 4d5b62ba280..9702b055512 100755 --- a/ci/cpp_linters.sh +++ b/ci/cpp_linters.sh @@ -27,7 +27,7 @@ source rapids-configure-sccache # Run the build via CMake, which will run clang-tidy when CUDF_STATIC_LINTERS is enabled. iwyu_flag="" -if [[ "${RAPIDS_BUILD_TYPE}" == "nightly" ]]; then +if [[ "${RAPIDS_BUILD_TYPE:-}" == "nightly" ]]; then iwyu_flag="-DCUDF_IWYU=ON" fi cmake -S cpp -B cpp/build -DCMAKE_BUILD_TYPE=Release -DCUDF_CLANG_TIDY=ON ${iwyu_flag} -DBUILD_TESTS=OFF -GNinja diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index f25b46a52cd..12e6826f301 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -93,6 +93,7 @@ option( mark_as_advanced(CUDF_BUILD_STREAMS_TEST_UTIL) option(CUDF_CLANG_TIDY "Enable clang-tidy during compilation" OFF) option(CUDF_IWYU "Enable IWYU during compilation" OFF) +option(CUDF_CLANG_TIDY_AUTOFIX "Enable clang-tidy autofixes" OFF) option( CUDF_KVIKIO_REMOTE_IO @@ -205,9 +206,16 @@ function(enable_static_checkers target) if(_LINT_CLANG_TIDY) # clang will complain about unused link libraries on the compile line unless we specify # -Qunused-arguments. 
- set_target_properties( - ${target} PROPERTIES CXX_CLANG_TIDY "${CLANG_TIDY_EXE};--extra-arg=-Qunused-arguments" - ) + if(CUDF_CLANG_TIDY_AUTOFIX) + set_target_properties( + ${target} PROPERTIES CXX_CLANG_TIDY + "${CLANG_TIDY_EXE};--extra-arg=-Qunused-arguments;--fix" + ) + else() + set_target_properties( + ${target} PROPERTIES CXX_CLANG_TIDY "${CLANG_TIDY_EXE};--extra-arg=-Qunused-arguments" + ) + endif() endif() if(_LINT_IWYU) # A few extra warnings pop up when building with IWYU. I'm not sure why, but they are not diff --git a/cpp/src/bitmask/is_element_valid.cpp b/cpp/src/bitmask/is_element_valid.cpp index 4806c7a94e8..7eb80c4249e 100644 --- a/cpp/src/bitmask/is_element_valid.cpp +++ b/cpp/src/bitmask/is_element_valid.cpp @@ -30,9 +30,9 @@ bool is_element_valid_sync(column_view const& col_view, CUDF_EXPECTS(element_index >= 0 and element_index < col_view.size(), "invalid index."); if (!col_view.nullable()) { return true; } - bitmask_type word; + bitmask_type word = 0; // null_mask() returns device ptr to bitmask without offset - size_type index = element_index + col_view.offset(); + size_type const index = element_index + col_view.offset(); CUDF_CUDA_TRY(cudaMemcpyAsync(&word, col_view.null_mask() + word_index(index), sizeof(bitmask_type), diff --git a/cpp/src/column/column_view.cpp b/cpp/src/column/column_view.cpp index e831aa9645d..ea940676f6a 100644 --- a/cpp/src/column/column_view.cpp +++ b/cpp/src/column/column_view.cpp @@ -41,7 +41,7 @@ void prefetch_col_data(ColumnView& col, void const* data_ptr, std::string_view k cudf::experimental::prefetch::detail::prefetch_noexcept( key, data_ptr, col.size() * size_of(col.type()), cudf::get_default_stream()); } else if (col.type().id() == type_id::STRING) { - strings_column_view scv{col}; + strings_column_view const scv{col}; if (data_ptr == nullptr) { // Do not call chars_size if the data_ptr is nullptr. return; @@ -58,51 +58,6 @@ void prefetch_col_data(ColumnView& col, void const* data_ptr, std::string_view k } } -} // namespace - -column_view_base::column_view_base(data_type type, - size_type size, - void const* data, - bitmask_type const* null_mask, - size_type null_count, - size_type offset) - : _type{type}, - _size{size}, - _data{data}, - _null_mask{null_mask}, - _null_count{null_count}, - _offset{offset} -{ - CUDF_EXPECTS(size >= 0, "Column size cannot be negative."); - - if (type.id() == type_id::EMPTY) { - _null_count = size; - CUDF_EXPECTS(nullptr == data, "EMPTY column should have no data."); - CUDF_EXPECTS(nullptr == null_mask, "EMPTY column should have no null mask."); - } else if (is_compound(type)) { - if (type.id() != type_id::STRING) { - CUDF_EXPECTS(nullptr == data, "Compound (parent) columns cannot have data"); - } - } else if (size > 0) { - CUDF_EXPECTS(nullptr != data, "Null data pointer."); - } - - CUDF_EXPECTS(offset >= 0, "Invalid offset."); - - if ((null_count > 0) and (type.id() != type_id::EMPTY)) { - CUDF_EXPECTS(nullptr != null_mask, "Invalid null mask for non-zero null count."); - } -} - -size_type column_view_base::null_count(size_type begin, size_type end) const -{ - CUDF_EXPECTS((begin >= 0) && (end <= size()) && (begin <= end), "Range is out of bounds."); - return (null_count() == 0) - ? 
0 - : cudf::detail::null_count( - null_mask(), offset() + begin, offset() + end, cudf::get_default_stream()); -} - // Struct to use custom hash combine and fold expression struct HashValue { std::size_t hash; @@ -133,8 +88,6 @@ std::size_t shallow_hash_impl(column_view const& c, bool is_parent_empty = false }); } -std::size_t shallow_hash(column_view const& input) { return shallow_hash_impl(input); } - bool shallow_equivalent_impl(column_view const& lhs, column_view const& rhs, bool is_parent_empty = false) @@ -151,11 +104,59 @@ bool shallow_equivalent_impl(column_view const& lhs, return shallow_equivalent_impl(lhs_child, rhs_child, is_empty); }); } + +} // namespace + +column_view_base::column_view_base(data_type type, + size_type size, + void const* data, + bitmask_type const* null_mask, + size_type null_count, + size_type offset) + : _type{type}, + _size{size}, + _data{data}, + _null_mask{null_mask}, + _null_count{null_count}, + _offset{offset} +{ + CUDF_EXPECTS(size >= 0, "Column size cannot be negative."); + + if (type.id() == type_id::EMPTY) { + _null_count = size; + CUDF_EXPECTS(nullptr == data, "EMPTY column should have no data."); + CUDF_EXPECTS(nullptr == null_mask, "EMPTY column should have no null mask."); + } else if (is_compound(type)) { + if (type.id() != type_id::STRING) { + CUDF_EXPECTS(nullptr == data, "Compound (parent) columns cannot have data"); + } + } else if (size > 0) { + CUDF_EXPECTS(nullptr != data, "Null data pointer."); + } + + CUDF_EXPECTS(offset >= 0, "Invalid offset."); + + if ((null_count > 0) and (type.id() != type_id::EMPTY)) { + CUDF_EXPECTS(nullptr != null_mask, "Invalid null mask for non-zero null count."); + } +} + +size_type column_view_base::null_count(size_type begin, size_type end) const +{ + CUDF_EXPECTS((begin >= 0) && (end <= size()) && (begin <= end), "Range is out of bounds."); + return (null_count() == 0) + ? 0 + : cudf::detail::null_count( + null_mask(), offset() + begin, offset() + end, cudf::get_default_stream()); +} + bool is_shallow_equivalent(column_view const& lhs, column_view const& rhs) { return shallow_equivalent_impl(lhs, rhs); } +std::size_t shallow_hash(column_view const& input) { return shallow_hash_impl(input); } + } // namespace detail // Immutable view constructor diff --git a/cpp/src/copying/copy.cpp b/cpp/src/copying/copy.cpp index 5e2065ba844..89d8cc3f4aa 100644 --- a/cpp/src/copying/copy.cpp +++ b/cpp/src/copying/copy.cpp @@ -62,11 +62,12 @@ struct scalar_empty_like_functor_impl { auto ls = static_cast(&input); // TODO: add a manual constructor for lists_column_view. 
- column_view offsets{cudf::data_type{cudf::type_id::INT32}, 0, nullptr, nullptr, 0}; + column_view const offsets{cudf::data_type{cudf::type_id::INT32}, 0, nullptr, nullptr, 0}; std::vector children; children.push_back(offsets); children.push_back(ls->view()); - column_view lcv{cudf::data_type{cudf::type_id::LIST}, 0, nullptr, nullptr, 0, 0, children}; + column_view const lcv{ + cudf::data_type{cudf::type_id::LIST}, 0, nullptr, nullptr, 0, 0, children}; return empty_like(lcv); } @@ -81,8 +82,9 @@ struct scalar_empty_like_functor_impl { // TODO: add a manual constructor for structs_column_view // TODO: add cudf::get_element() support for structs cudf::table_view tbl = ss->view(); - std::vector children(tbl.begin(), tbl.end()); - column_view scv{cudf::data_type{cudf::type_id::STRUCT}, 0, nullptr, nullptr, 0, 0, children}; + std::vector const children(tbl.begin(), tbl.end()); + column_view const scv{ + cudf::data_type{cudf::type_id::STRUCT}, 0, nullptr, nullptr, 0, 0, children}; return empty_like(scv); } @@ -120,7 +122,7 @@ std::unique_ptr allocate_like(column_view const& input, CUDF_FUNC_RANGE(); CUDF_EXPECTS( is_fixed_width(input.type()), "Expects only fixed-width type column", cudf::data_type_error); - mask_state allocate_mask = should_allocate_mask(mask_alloc, input.nullable()); + mask_state const allocate_mask = should_allocate_mask(mask_alloc, input.nullable()); return std::make_unique(input.type(), size, diff --git a/cpp/src/copying/pack.cpp b/cpp/src/copying/pack.cpp index a001807c82b..42ea28f5961 100644 --- a/cpp/src/copying/pack.cpp +++ b/cpp/src/copying/pack.cpp @@ -48,20 +48,20 @@ struct serialized_column { null_count(_null_count), data_offset(_data_offset), null_mask_offset(_null_mask_offset), - num_children(_num_children), - pad(0) + num_children(_num_children) + { } data_type type; - size_type size; - size_type null_count; - int64_t data_offset; // offset into contiguous data buffer, or -1 if column data is null - int64_t null_mask_offset; // offset into contiguous data buffer, or -1 if column data is null - size_type num_children; + size_type size{}; + size_type null_count{}; + int64_t data_offset{}; // offset into contiguous data buffer, or -1 if column data is null + int64_t null_mask_offset{}; // offset into contiguous data buffer, or -1 if column data is null + size_type num_children{}; // Explicitly pad to avoid uninitialized padding bits, allowing `serialized_column` to be bit-wise // comparable - int pad; + int pad{}; }; /** @@ -137,6 +137,34 @@ void build_column_metadata(metadata_builder& mb, }); } +table_view unpack(uint8_t const* metadata, uint8_t const* gpu_data) +{ + // gpu data can be null if everything is empty but the metadata must always be valid + CUDF_EXPECTS(metadata != nullptr, "Encountered invalid packed column input"); + auto serialized_columns = reinterpret_cast(metadata); + uint8_t const* base_ptr = gpu_data; + // first entry is a stub where size == the total # of top level columns (see pack_metadata above) + auto const num_columns = serialized_columns[0].size; + size_t current_index = 1; + + std::function(size_type)> get_columns; + get_columns = [&serialized_columns, ¤t_index, base_ptr, &get_columns](size_t num_columns) { + std::vector cols; + for (size_t i = 0; i < num_columns; i++) { + auto serial_column = serialized_columns[current_index]; + current_index++; + + std::vector const children = get_columns(serial_column.num_children); + + cols.emplace_back(deserialize_column(serial_column, children, base_ptr)); + } + + return cols; + }; + + return 
table_view{get_columns(num_columns)}; +} + } // anonymous namespace /** @@ -198,37 +226,6 @@ class metadata_builder_impl { std::vector metadata; }; -/** - * @copydoc cudf::detail::unpack - */ -table_view unpack(uint8_t const* metadata, uint8_t const* gpu_data) -{ - // gpu data can be null if everything is empty but the metadata must always be valid - CUDF_EXPECTS(metadata != nullptr, "Encountered invalid packed column input"); - auto serialized_columns = reinterpret_cast(metadata); - uint8_t const* base_ptr = gpu_data; - // first entry is a stub where size == the total # of top level columns (see pack_metadata above) - auto const num_columns = serialized_columns[0].size; - size_t current_index = 1; - - std::function(size_type)> get_columns; - get_columns = [&serialized_columns, ¤t_index, base_ptr, &get_columns](size_t num_columns) { - std::vector cols; - for (size_t i = 0; i < num_columns; i++) { - auto serial_column = serialized_columns[current_index]; - current_index++; - - std::vector children = get_columns(serial_column.num_children); - - cols.emplace_back(deserialize_column(serial_column, children, base_ptr)); - } - - return cols; - }; - - return table_view{get_columns(num_columns)}; -} - metadata_builder::metadata_builder(size_type const num_root_columns) : impl(std::make_unique(num_root_columns + 1 /*one more extra metadata entry as below*/)) @@ -280,9 +277,6 @@ std::vector pack_metadata(table_view const& table, return detail::pack_metadata(table, contiguous_buffer, buffer_size, builder); } -/** - * @copydoc cudf::unpack - */ table_view unpack(packed_columns const& input) { CUDF_FUNC_RANGE(); @@ -292,9 +286,6 @@ table_view unpack(packed_columns const& input) reinterpret_cast(input.gpu_data->data())); } -/** - * @copydoc cudf::unpack(uint8_t const*, uint8_t const* ) - */ table_view unpack(uint8_t const* metadata, uint8_t const* gpu_data) { CUDF_FUNC_RANGE(); diff --git a/cpp/src/datetime/timezone.cpp b/cpp/src/datetime/timezone.cpp index f786624680c..78e4198f60c 100644 --- a/cpp/src/datetime/timezone.cpp +++ b/cpp/src/datetime/timezone.cpp @@ -62,7 +62,7 @@ struct dst_transition_s { #pragma pack(pop) struct timezone_file { - timezone_file_header header; + timezone_file_header header{}; bool is_header_from_64bit = false; std::vector transition_times; diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 3041e261945..7a8a1883ed4 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -45,6 +45,42 @@ namespace cudf { namespace groupby { namespace detail { +namespace { + +/** + * @brief Creates column views with only valid elements in both input column views + * + * @param column_0 The first column + * @param column_1 The second column + * @param stream CUDA stream used for device memory operations and kernel launches + * @return tuple with new null mask (if null masks of input differ) and new column views + */ +auto column_view_with_common_nulls(column_view const& column_0, + column_view const& column_1, + rmm::cuda_stream_view stream) +{ + auto [new_nullmask, null_count] = cudf::bitmask_and(table_view{{column_0, column_1}}, stream); + if (null_count == 0) { return std::make_tuple(std::move(new_nullmask), column_0, column_1); } + auto column_view_with_new_nullmask = [](auto const& col, void* nullmask, auto null_count) { + return column_view(col.type(), + col.size(), + col.head(), + static_cast(nullmask), + null_count, + col.offset(), + std::vector(col.child_begin(), col.child_end())); + }; + auto new_column_0 = 
null_count == column_0.null_count() + ? column_0 + : column_view_with_new_nullmask(column_0, new_nullmask.data(), null_count); + auto new_column_1 = null_count == column_1.null_count() + ? column_1 + : column_view_with_new_nullmask(column_1, new_nullmask.data(), null_count); + return std::make_tuple(std::move(new_nullmask), new_column_0, new_column_1); +} + +} // namespace + /** * @brief Functor to dispatch aggregation with * @@ -170,13 +206,13 @@ void aggregate_result_functor::operator()(aggregation const& a } else { auto argmin_agg = make_argmin_aggregation(); operator()(*argmin_agg); - column_view argmin_result = cache.get_result(values, *argmin_agg); + column_view const argmin_result = cache.get_result(values, *argmin_agg); // We make a view of ARGMIN result without a null mask and gather using // this mask. The values in data buffer of ARGMIN result corresponding // to null values was initialized to ARGMIN_SENTINEL which is an out of // bounds index value and causes the gathered value to be null. - column_view null_removed_map( + column_view const null_removed_map( data_type(type_to_id()), argmin_result.size(), static_cast(argmin_result.template data()), @@ -212,13 +248,13 @@ void aggregate_result_functor::operator()(aggregation const& a } else { auto argmax_agg = make_argmax_aggregation(); operator()(*argmax_agg); - column_view argmax_result = cache.get_result(values, *argmax_agg); + column_view const argmax_result = cache.get_result(values, *argmax_agg); // We make a view of ARGMAX result without a null mask and gather using // this mask. The values in data buffer of ARGMAX result corresponding // to null values was initialized to ARGMAX_SENTINEL which is an out of // bounds index value and causes the gathered value to be null. - column_view null_removed_map( + column_view const null_removed_map( data_type(type_to_id()), argmax_result.size(), static_cast(argmax_result.template data()), @@ -248,8 +284,8 @@ void aggregate_result_functor::operator()(aggregation const& auto count_agg = make_count_aggregation(); operator()(*sum_agg); operator()(*count_agg); - column_view sum_result = cache.get_result(values, *sum_agg); - column_view count_result = cache.get_result(values, *count_agg); + column_view const sum_result = cache.get_result(values, *sum_agg); + column_view const count_result = cache.get_result(values, *count_agg); // TODO (dm): Special case for timestamp. Add target_type_impl for it. 
// Blocked until we support operator+ on timestamps @@ -291,8 +327,8 @@ void aggregate_result_functor::operator()(aggregation con auto count_agg = make_count_aggregation(); operator()(*mean_agg); operator()(*count_agg); - column_view mean_result = cache.get_result(values, *mean_agg); - column_view group_sizes = cache.get_result(values, *count_agg); + column_view const mean_result = cache.get_result(values, *mean_agg); + column_view const group_sizes = cache.get_result(values, *count_agg); auto result = detail::group_var(get_grouped_values(), mean_result, @@ -312,7 +348,7 @@ void aggregate_result_functor::operator()(aggregation const& a auto& std_agg = dynamic_cast(agg); auto var_agg = make_variance_aggregation(std_agg._ddof); operator()(*var_agg); - column_view var_result = cache.get_result(values, *var_agg); + column_view const var_result = cache.get_result(values, *var_agg); auto result = cudf::detail::unary_operation(var_result, unary_operator::SQRT, stream, mr); cache.add_result(values, agg, std::move(result)); @@ -325,8 +361,8 @@ void aggregate_result_functor::operator()(aggregation con auto count_agg = make_count_aggregation(); operator()(*count_agg); - column_view group_sizes = cache.get_result(values, *count_agg); - auto& quantile_agg = dynamic_cast(agg); + column_view const group_sizes = cache.get_result(values, *count_agg); + auto& quantile_agg = dynamic_cast(agg); auto result = detail::group_quantiles(get_sorted_values(), group_sizes, @@ -346,7 +382,7 @@ void aggregate_result_functor::operator()(aggregation const auto count_agg = make_count_aggregation(); operator()(*count_agg); - column_view group_sizes = cache.get_result(values, *count_agg); + column_view const group_sizes = cache.get_result(values, *count_agg); auto result = detail::group_quantiles(get_sorted_values(), group_sizes, @@ -391,7 +427,7 @@ void aggregate_result_functor::operator()(aggregation } else { CUDF_FAIL("Wrong count aggregation kind"); } - column_view group_sizes = cache.get_result(values, *count_agg); + column_view const group_sizes = cache.get_result(values, *count_agg); cache.add_result(values, agg, @@ -564,38 +600,6 @@ void aggregate_result_functor::operator()(aggregat get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr)); } -/** - * @brief Creates column views with only valid elements in both input column views - * - * @param column_0 The first column - * @param column_1 The second column - * @param stream CUDA stream used for device memory operations and kernel launches - * @return tuple with new null mask (if null masks of input differ) and new column views - */ -auto column_view_with_common_nulls(column_view const& column_0, - column_view const& column_1, - rmm::cuda_stream_view stream) -{ - auto [new_nullmask, null_count] = cudf::bitmask_and(table_view{{column_0, column_1}}, stream); - if (null_count == 0) { return std::make_tuple(std::move(new_nullmask), column_0, column_1); } - auto column_view_with_new_nullmask = [](auto const& col, void* nullmask, auto null_count) { - return column_view(col.type(), - col.size(), - col.head(), - static_cast(nullmask), - null_count, - col.offset(), - std::vector(col.child_begin(), col.child_end())); - }; - auto new_column_0 = null_count == column_0.null_count() - ? column_0 - : column_view_with_new_nullmask(column_0, new_nullmask.data(), null_count); - auto new_column_1 = null_count == column_1.null_count() - ? 
column_1 - : column_view_with_new_nullmask(column_1, new_nullmask.data(), null_count); - return std::make_tuple(std::move(new_nullmask), new_column_0, new_column_1); -} - /** * @brief Perform covariance between two child columns of non-nullable struct column. * @@ -734,7 +738,7 @@ void aggregate_result_functor::operator()(aggregation cons auto count_agg = make_count_aggregation(); operator()(*count_agg); - column_view valid_counts = cache.get_result(values, *count_agg); + column_view const valid_counts = cache.get_result(values, *count_agg); cache.add_result(values, agg, diff --git a/cpp/src/interop/dlpack.cpp b/cpp/src/interop/dlpack.cpp index b5cc4cbba0d..fee767255c2 100644 --- a/cpp/src/interop/dlpack.cpp +++ b/cpp/src/interop/dlpack.cpp @@ -115,8 +115,8 @@ DLDataType data_type_to_DLDataType(data_type type) // Context object to own memory allocated for DLManagedTensor struct dltensor_context { - int64_t shape[2]; // NOLINT - int64_t strides[2]; // NOLINT + int64_t shape[2]{}; // NOLINT + int64_t strides[2]{}; // NOLINT rmm::device_buffer buffer; static void deleter(DLManagedTensor* arg) diff --git a/cpp/src/interop/to_arrow_schema.cpp b/cpp/src/interop/to_arrow_schema.cpp index 5afed772656..5dd8d77c261 100644 --- a/cpp/src/interop/to_arrow_schema.cpp +++ b/cpp/src/interop/to_arrow_schema.cpp @@ -44,7 +44,7 @@ struct dispatch_to_arrow_type { template ())> int operator()(column_view input_view, column_metadata const&, ArrowSchema* out) { - cudf::type_id id = input_view.type().id(); + cudf::type_id const id = input_view.type().id(); switch (id) { case cudf::type_id::TIMESTAMP_SECONDS: return ArrowSchemaSetTypeDateTime( @@ -186,7 +186,7 @@ int dispatch_to_arrow_type::operator()(column_view input, column_metadata const& metadata, ArrowSchema* out) { - cudf::dictionary_column_view dview{input}; + cudf::dictionary_column_view const dview{input}; NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(out, id_to_arrow_type(dview.indices().type().id()))); NANOARROW_RETURN_NOT_OK(ArrowSchemaAllocateDictionary(out)); diff --git a/cpp/src/io/avro/avro.cpp b/cpp/src/io/avro/avro.cpp index b3fcca62314..c3a7f0f3053 100644 --- a/cpp/src/io/avro/avro.cpp +++ b/cpp/src/io/avro/avro.cpp @@ -200,7 +200,7 @@ bool container::parse(file_metadata* md, size_t max_num_rows, size_t first_row) // encountered. If they don't, we have to assume the data is corrupted, // and thus, we terminate processing immediately. 
std::array const sync_marker = {get_raw(), get_raw()}; - bool valid_sync_markers = + bool const valid_sync_markers = ((sync_marker[0] == md->sync_marker[0]) && (sync_marker[1] == md->sync_marker[1])); if (!valid_sync_markers) { return false; } } @@ -218,10 +218,10 @@ bool container::parse(file_metadata* md, size_t max_num_rows, size_t first_row) md->selected_data_size = m_cur - m_start; // Extract columns for (size_t i = 0; i < md->schema.size(); i++) { - type_kind_e kind = md->schema[i].kind; - logicaltype_kind_e logical_kind = md->schema[i].logical_kind; + type_kind_e const kind = md->schema[i].kind; + logicaltype_kind_e const logical_kind = md->schema[i].logical_kind; - bool is_supported_kind = ((kind > type_null) && (kind < type_record)); + bool const is_supported_kind = ((kind > type_null) && (kind < type_record)); if (is_supported_logical_type(logical_kind) || is_supported_kind) { column_desc col; int parent_idx = md->schema[i].parent_idx; @@ -302,7 +302,7 @@ bool schema_parser::parse(std::vector& schema, std::string const& // Empty schema if (json_str == "[]") return true; - std::array depthbuf; + std::array depthbuf{}; int depth = 0, parent_idx = -1, entry_idx = -1; json_state_e state = state_attrname; std::string str; @@ -341,7 +341,7 @@ bool schema_parser::parse(std::vector& schema, std::string const& m_cur = m_base; m_end = m_base + json_str.length(); while (more_data()) { - int c = *m_cur++; + int const c = *m_cur++; switch (c) { case '"': str = get_str(); diff --git a/cpp/src/io/comp/comp.cpp b/cpp/src/io/comp/comp.cpp index b26a6292806..2dda2287e09 100644 --- a/cpp/src/io/comp/comp.cpp +++ b/cpp/src/io/comp/comp.cpp @@ -48,13 +48,13 @@ std::vector compress_gzip(host_span src) zs.avail_out = 0; zs.next_out = nullptr; - int windowbits = 15; - int gzip_encoding = 16; - int ret = deflateInit2( + constexpr int windowbits = 15; + constexpr int gzip_encoding = 16; + int ret = deflateInit2( &zs, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowbits | gzip_encoding, 8, Z_DEFAULT_STRATEGY); CUDF_EXPECTS(ret == Z_OK, "GZIP DEFLATE compression initialization failed."); - uint32_t estcomplen = deflateBound(&zs, src.size()); + uint32_t const estcomplen = deflateBound(&zs, src.size()); dst.resize(estcomplen); zs.avail_out = estcomplen; zs.next_out = dst.data(); diff --git a/cpp/src/io/comp/nvcomp_adapter.cpp b/cpp/src/io/comp/nvcomp_adapter.cpp index c3187f73a95..b8bf8be6d2d 100644 --- a/cpp/src/io/comp/nvcomp_adapter.cpp +++ b/cpp/src/io/comp/nvcomp_adapter.cpp @@ -31,6 +31,7 @@ #include namespace cudf::io::nvcomp { +namespace { // Dispatcher for nvcompBatchedDecompressGetTempSizeEx template @@ -50,19 +51,6 @@ auto batched_decompress_get_temp_size_ex(compression_type compression, Args&&... 
default: CUDF_FAIL("Unsupported compression type"); } } -size_t batched_decompress_temp_size(compression_type compression, - size_t num_chunks, - size_t max_uncomp_chunk_size, - size_t max_total_uncomp_size) -{ - size_t temp_size = 0; - nvcompStatus_t nvcomp_status = batched_decompress_get_temp_size_ex( - compression, num_chunks, max_uncomp_chunk_size, &temp_size, max_total_uncomp_size); - - CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, - "Unable to get scratch size for decompression"); - return temp_size; -} // Dispatcher for nvcompBatchedDecompressAsync template @@ -94,40 +82,6 @@ std::string compression_type_name(compression_type compression) return "compression_type(" + std::to_string(static_cast(compression)) + ")"; } -void batched_decompress(compression_type compression, - device_span const> inputs, - device_span const> outputs, - device_span results, - size_t max_uncomp_chunk_size, - size_t max_total_uncomp_size, - rmm::cuda_stream_view stream) -{ - auto const num_chunks = inputs.size(); - - // cuDF inflate inputs converted to nvcomp inputs - auto const nvcomp_args = create_batched_nvcomp_args(inputs, outputs, stream); - rmm::device_uvector actual_uncompressed_data_sizes(num_chunks, stream); - rmm::device_uvector nvcomp_statuses(num_chunks, stream); - // Temporary space required for decompression - auto const temp_size = batched_decompress_temp_size( - compression, num_chunks, max_uncomp_chunk_size, max_total_uncomp_size); - rmm::device_buffer scratch(temp_size, stream); - auto const nvcomp_status = batched_decompress_async(compression, - nvcomp_args.input_data_ptrs.data(), - nvcomp_args.input_data_sizes.data(), - nvcomp_args.output_data_sizes.data(), - actual_uncompressed_data_sizes.data(), - num_chunks, - scratch.data(), - scratch.size(), - nvcomp_args.output_data_ptrs.data(), - nvcomp_statuses.data(), - stream.value()); - CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, "unable to perform decompression"); - - update_compression_results(nvcomp_statuses, actual_uncompressed_data_sizes, results, stream); -} - size_t batched_compress_temp_size(compression_type compression, size_t batch_size, size_t max_uncompressed_chunk_bytes, @@ -172,52 +126,17 @@ size_t batched_compress_temp_size(compression_type compression, return temp_size; } -// Wrapper for nvcompBatchedCompressGetMaxOutputChunkSize -size_t compress_max_output_chunk_size(compression_type compression, - uint32_t max_uncompressed_chunk_bytes) -{ - auto const capped_uncomp_bytes = std::min( - compress_max_allowed_chunk_size(compression).value_or(max_uncompressed_chunk_bytes), - max_uncompressed_chunk_bytes); - - size_t max_comp_chunk_size = 0; - nvcompStatus_t status = nvcompStatus_t::nvcompSuccess; - switch (compression) { - case compression_type::SNAPPY: - status = nvcompBatchedSnappyCompressGetMaxOutputChunkSize( - capped_uncomp_bytes, nvcompBatchedSnappyDefaultOpts, &max_comp_chunk_size); - break; - case compression_type::DEFLATE: - status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize( - capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size); - break; - case compression_type::ZSTD: - status = nvcompBatchedZstdCompressGetMaxOutputChunkSize( - capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size); - break; - case compression_type::LZ4: - status = nvcompBatchedLZ4CompressGetMaxOutputChunkSize( - capped_uncomp_bytes, nvcompBatchedLZ4DefaultOpts, &max_comp_chunk_size); - break; - default: CUDF_FAIL("Unsupported compression type"); - } - - CUDF_EXPECTS(status 
== nvcompStatus_t::nvcompSuccess, - "failed to get max uncompressed chunk size"); - return max_comp_chunk_size; -} - // Dispatcher for nvcompBatchedCompressAsync -static void batched_compress_async(compression_type compression, - void const* const* device_uncompressed_ptrs, - size_t const* device_uncompressed_bytes, - size_t max_uncompressed_chunk_bytes, - size_t batch_size, - void* device_temp_ptr, - size_t temp_bytes, - void* const* device_compressed_ptrs, - size_t* device_compressed_bytes, - rmm::cuda_stream_view stream) +void batched_compress_async(compression_type compression, + void const* const* device_uncompressed_ptrs, + size_t const* device_uncompressed_bytes, + size_t max_uncompressed_chunk_bytes, + size_t batch_size, + void* device_temp_ptr, + size_t temp_bytes, + void* const* device_compressed_ptrs, + size_t* device_compressed_bytes, + rmm::cuda_stream_view stream) { nvcompStatus_t nvcomp_status = nvcompStatus_t::nvcompSuccess; switch (compression) { @@ -279,6 +198,137 @@ bool is_aligned(void const* ptr, std::uintptr_t alignment) noexcept return (reinterpret_cast(ptr) % alignment) == 0; } +std::optional is_compression_disabled_impl(compression_type compression, + feature_status_parameters params) +{ + switch (compression) { + case compression_type::DEFLATE: { + if (not params.are_all_integrations_enabled) { + return "DEFLATE compression is experimental, you can enable it through " + "`LIBCUDF_NVCOMP_POLICY` environment variable."; + } + return std::nullopt; + } + case compression_type::LZ4: + case compression_type::SNAPPY: + case compression_type::ZSTD: + if (not params.are_stable_integrations_enabled) { + return "nvCOMP use is disabled through the `LIBCUDF_NVCOMP_POLICY` environment variable."; + } + return std::nullopt; + default: return "Unsupported compression type"; + } +} + +std::optional is_decompression_disabled_impl(compression_type compression, + feature_status_parameters params) +{ + switch (compression) { + case compression_type::DEFLATE: + case compression_type::GZIP: { + if (not params.are_all_integrations_enabled) { + return "DEFLATE decompression is experimental, you can enable it through " + "`LIBCUDF_NVCOMP_POLICY` environment variable."; + } + return std::nullopt; + } + case compression_type::LZ4: + case compression_type::SNAPPY: + case compression_type::ZSTD: { + if (not params.are_stable_integrations_enabled) { + return "nvCOMP use is disabled through the `LIBCUDF_NVCOMP_POLICY` environment variable."; + } + return std::nullopt; + } + } + return "Unsupported compression type"; +} + +} // namespace + +size_t batched_decompress_temp_size(compression_type compression, + size_t num_chunks, + size_t max_uncomp_chunk_size, + size_t max_total_uncomp_size) +{ + size_t temp_size = 0; + nvcompStatus_t const nvcomp_status = batched_decompress_get_temp_size_ex( + compression, num_chunks, max_uncomp_chunk_size, &temp_size, max_total_uncomp_size); + + CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, + "Unable to get scratch size for decompression"); + return temp_size; +} + +void batched_decompress(compression_type compression, + device_span const> inputs, + device_span const> outputs, + device_span results, + size_t max_uncomp_chunk_size, + size_t max_total_uncomp_size, + rmm::cuda_stream_view stream) +{ + auto const num_chunks = inputs.size(); + + // cuDF inflate inputs converted to nvcomp inputs + auto const nvcomp_args = create_batched_nvcomp_args(inputs, outputs, stream); + rmm::device_uvector actual_uncompressed_data_sizes(num_chunks, stream); + 
rmm::device_uvector nvcomp_statuses(num_chunks, stream); + // Temporary space required for decompression + auto const temp_size = batched_decompress_temp_size( + compression, num_chunks, max_uncomp_chunk_size, max_total_uncomp_size); + rmm::device_buffer scratch(temp_size, stream); + auto const nvcomp_status = batched_decompress_async(compression, + nvcomp_args.input_data_ptrs.data(), + nvcomp_args.input_data_sizes.data(), + nvcomp_args.output_data_sizes.data(), + actual_uncompressed_data_sizes.data(), + num_chunks, + scratch.data(), + scratch.size(), + nvcomp_args.output_data_ptrs.data(), + nvcomp_statuses.data(), + stream.value()); + CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, "unable to perform decompression"); + + update_compression_results(nvcomp_statuses, actual_uncompressed_data_sizes, results, stream); +} + +// Wrapper for nvcompBatchedCompressGetMaxOutputChunkSize +size_t compress_max_output_chunk_size(compression_type compression, + uint32_t max_uncompressed_chunk_bytes) +{ + auto const capped_uncomp_bytes = std::min( + compress_max_allowed_chunk_size(compression).value_or(max_uncompressed_chunk_bytes), + max_uncompressed_chunk_bytes); + + size_t max_comp_chunk_size = 0; + nvcompStatus_t status = nvcompStatus_t::nvcompSuccess; + switch (compression) { + case compression_type::SNAPPY: + status = nvcompBatchedSnappyCompressGetMaxOutputChunkSize( + capped_uncomp_bytes, nvcompBatchedSnappyDefaultOpts, &max_comp_chunk_size); + break; + case compression_type::DEFLATE: + status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize( + capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size); + break; + case compression_type::ZSTD: + status = nvcompBatchedZstdCompressGetMaxOutputChunkSize( + capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size); + break; + case compression_type::LZ4: + status = nvcompBatchedLZ4CompressGetMaxOutputChunkSize( + capped_uncomp_bytes, nvcompBatchedLZ4DefaultOpts, &max_comp_chunk_size); + break; + default: CUDF_FAIL("Unsupported compression type"); + } + + CUDF_EXPECTS(status == nvcompStatus_t::nvcompSuccess, + "failed to get max uncompressed chunk size"); + return max_comp_chunk_size; +} + void batched_compress(compression_type compression, device_span const> inputs, device_span const> outputs, @@ -347,28 +397,6 @@ struct hash_feature_status_inputs { using feature_status_memo_map = std::unordered_map, hash_feature_status_inputs>; -std::optional is_compression_disabled_impl(compression_type compression, - feature_status_parameters params) -{ - switch (compression) { - case compression_type::DEFLATE: { - if (not params.are_all_integrations_enabled) { - return "DEFLATE compression is experimental, you can enable it through " - "`LIBCUDF_NVCOMP_POLICY` environment variable."; - } - return std::nullopt; - } - case compression_type::LZ4: - case compression_type::SNAPPY: - case compression_type::ZSTD: - if (not params.are_stable_integrations_enabled) { - return "nvCOMP use is disabled through the `LIBCUDF_NVCOMP_POLICY` environment variable."; - } - return std::nullopt; - default: return "Unsupported compression type"; - } -} - std::optional is_compression_disabled(compression_type compression, feature_status_parameters params) { @@ -398,30 +426,6 @@ std::optional is_compression_disabled(compression_type compression, return reason; } -std::optional is_decompression_disabled_impl(compression_type compression, - feature_status_parameters params) -{ - switch (compression) { - case compression_type::DEFLATE: - case 
compression_type::GZIP: { - if (not params.are_all_integrations_enabled) { - return "DEFLATE decompression is experimental, you can enable it through " - "`LIBCUDF_NVCOMP_POLICY` environment variable."; - } - return std::nullopt; - } - case compression_type::LZ4: - case compression_type::SNAPPY: - case compression_type::ZSTD: { - if (not params.are_stable_integrations_enabled) { - return "nvCOMP use is disabled through the `LIBCUDF_NVCOMP_POLICY` environment variable."; - } - return std::nullopt; - } - } - return "Unsupported compression type"; -} - std::optional is_decompression_disabled(compression_type compression, feature_status_parameters params) { diff --git a/cpp/src/io/comp/uncomp.cpp b/cpp/src/io/comp/uncomp.cpp index b3d43fa786a..4ab5174387e 100644 --- a/cpp/src/io/comp/uncomp.cpp +++ b/cpp/src/io/comp/uncomp.cpp @@ -127,7 +127,7 @@ struct zip_archive_s { bool ParseGZArchive(gz_archive_s* dst, uint8_t const* raw, size_t len) { - gz_file_header_s const* fhdr; + gz_file_header_s const* fhdr = nullptr; if (!dst) return false; memset(dst, 0, sizeof(gz_archive_s)); @@ -138,7 +138,7 @@ bool ParseGZArchive(gz_archive_s* dst, uint8_t const* raw, size_t len) raw += sizeof(gz_file_header_s); len -= sizeof(gz_file_header_s); if (fhdr->flags & GZIPHeaderFlag::fextra) { - uint32_t xlen; + uint32_t xlen = 0; if (len < 2) return false; xlen = raw[0] | (raw[1] << 8); @@ -151,8 +151,8 @@ bool ParseGZArchive(gz_archive_s* dst, uint8_t const* raw, size_t len) len -= xlen; } if (fhdr->flags & GZIPHeaderFlag::fname) { - size_t l = 0; - uint8_t c; + size_t l = 0; + uint8_t c = 0; do { if (l >= len) return false; c = raw[l]; @@ -163,8 +163,8 @@ bool ParseGZArchive(gz_archive_s* dst, uint8_t const* raw, size_t len) len -= l; } if (fhdr->flags & GZIPHeaderFlag::fcomment) { - size_t l = 0; - uint8_t c; + size_t l = 0; + uint8_t c = 0; do { if (l >= len) return false; c = raw[l]; @@ -219,7 +219,7 @@ bool OpenZipArchive(zip_archive_s* dst, uint8_t const* raw, size_t len) int cpu_inflate(uint8_t* uncomp_data, size_t* destLen, uint8_t const* comp_data, size_t comp_len) { - int zerr; + int zerr = 0; z_stream strm; memset(&strm, 0, sizeof(strm)); @@ -291,7 +291,7 @@ size_t decompress_zlib(host_span src, host_span dst) */ size_t decompress_gzip(host_span src, host_span dst) { - gz_archive_s gz; + gz_archive_s gz{}; auto const parse_succeeded = ParseGZArchive(&gz, src.data(), src.size()); CUDF_EXPECTS(parse_succeeded, "Failed to parse GZIP header"); return decompress_zlib({gz.comp_data, gz.comp_len}, dst); @@ -303,12 +303,12 @@ size_t decompress_gzip(host_span src, host_span dst) size_t decompress_snappy(host_span src, host_span dst) { CUDF_EXPECTS(not dst.empty() and src.size() >= 1, "invalid Snappy decompress inputs"); - uint32_t uncompressed_size, bytes_left, dst_pos; + uint32_t uncompressed_size = 0, bytes_left = 0, dst_pos = 0; auto cur = src.begin(); auto const end = src.end(); // Read uncompressed length (varint) { - uint32_t l = 0, c; + uint32_t l = 0, c = 0; uncompressed_size = 0; do { c = *cur++; @@ -328,7 +328,7 @@ size_t decompress_snappy(host_span src, host_span dst) if (blen & 3) { // Copy - uint32_t offset; + uint32_t offset = 0; if (blen & 2) { // xxxxxx1x: copy with 6-bit length, 2-byte or 4-byte offset if (cur + 2 > end) break; @@ -441,7 +441,7 @@ source_properties get_source_properties(compression_type compression, host_span< switch (compression) { case compression_type::AUTO: case compression_type::GZIP: { - gz_archive_s gz; + gz_archive_s gz{}; auto const parse_succeeded = ParseGZArchive(&gz, 
src.data(), src.size()); CUDF_EXPECTS(parse_succeeded, "Failed to parse GZIP header while fetching source properties"); compression = compression_type::GZIP; @@ -452,26 +452,28 @@ source_properties get_source_properties(compression_type compression, host_span< [[fallthrough]]; } case compression_type::ZIP: { - zip_archive_s za; + zip_archive_s za{}; if (OpenZipArchive(&za, raw, src.size())) { size_t cdfh_ofs = 0; for (int i = 0; i < za.eocd->num_entries; i++) { auto const* cdfh = reinterpret_cast( reinterpret_cast(za.cdfh) + cdfh_ofs); - int cdfh_len = sizeof(zip_cdfh_s) + cdfh->fname_len + cdfh->extra_len + cdfh->comment_len; + int const cdfh_len = + sizeof(zip_cdfh_s) + cdfh->fname_len + cdfh->extra_len + cdfh->comment_len; if (cdfh_ofs + cdfh_len > za.eocd->cdir_size || cdfh->sig != 0x0201'4b50) { // Bad cdir break; } // For now, only accept with non-zero file sizes and DEFLATE if (cdfh->comp_method == 8 && cdfh->comp_size > 0 && cdfh->uncomp_size > 0) { - size_t lfh_ofs = cdfh->hdr_ofs; - auto const* lfh = reinterpret_cast(raw + lfh_ofs); + size_t const lfh_ofs = cdfh->hdr_ofs; + auto const* lfh = reinterpret_cast(raw + lfh_ofs); if (lfh_ofs + sizeof(zip_lfh_s) <= src.size() && lfh->sig == 0x0403'4b50 && lfh_ofs + sizeof(zip_lfh_s) + lfh->fname_len + lfh->extra_len <= src.size()) { if (lfh->comp_method == 8 && lfh->comp_size > 0 && lfh->uncomp_size > 0) { - size_t file_start = lfh_ofs + sizeof(zip_lfh_s) + lfh->fname_len + lfh->extra_len; - size_t file_end = file_start + lfh->comp_size; + size_t const file_start = + lfh_ofs + sizeof(zip_lfh_s) + lfh->fname_len + lfh->extra_len; + size_t const file_end = file_start + lfh->comp_size; if (file_end <= src.size()) { // Pick the first valid file of non-zero size (only 1 file expected in archive) compression = compression_type::ZIP; @@ -510,7 +512,7 @@ source_properties get_source_properties(compression_type compression, host_span< auto const end = src.end(); // Read uncompressed length (varint) { - uint32_t l = 0, c; + uint32_t l = 0, c = 0; do { c = *cur++; auto const lo7 = c & 0x7f; diff --git a/cpp/src/io/functions.cpp b/cpp/src/io/functions.cpp index ceaeb5d8f85..88423122e16 100644 --- a/cpp/src/io/functions.cpp +++ b/cpp/src/io/functions.cpp @@ -39,6 +39,38 @@ #include namespace cudf::io { +namespace { + +compression_type infer_compression_type(compression_type compression, source_info const& info) +{ + if (compression != compression_type::AUTO) { return compression; } + + if (info.type() != io_type::FILEPATH) { return compression_type::NONE; } + + auto filepath = info.filepaths()[0]; + + // Attempt to infer from the file extension + auto const pos = filepath.find_last_of('.'); + + if (pos == std::string::npos) { return {}; } + + auto str_tolower = [](auto const& begin, auto const& end) { + std::string out; + std::transform(begin, end, std::back_inserter(out), ::tolower); + return out; + }; + + auto const ext = str_tolower(filepath.begin() + pos + 1, filepath.end()); + + if (ext == "gz") { return compression_type::GZIP; } + if (ext == "zip") { return compression_type::ZIP; } + if (ext == "bz2") { return compression_type::BZIP2; } + if (ext == "xz") { return compression_type::XZ; } + + return compression_type::NONE; +} + +} // namespace // Returns builder for csv_reader_options csv_reader_options_builder csv_reader_options::builder(source_info src) @@ -170,35 +202,6 @@ table_with_metadata read_avro(avro_reader_options const& options, rmm::device_as return avro::read_avro(std::move(datasources[0]), options, cudf::get_default_stream(), 
mr); } -compression_type infer_compression_type(compression_type compression, source_info const& info) -{ - if (compression != compression_type::AUTO) { return compression; } - - if (info.type() != io_type::FILEPATH) { return compression_type::NONE; } - - auto filepath = info.filepaths()[0]; - - // Attempt to infer from the file extension - auto const pos = filepath.find_last_of('.'); - - if (pos == std::string::npos) { return {}; } - - auto str_tolower = [](auto const& begin, auto const& end) { - std::string out; - std::transform(begin, end, std::back_inserter(out), ::tolower); - return out; - }; - - auto const ext = str_tolower(filepath.begin() + pos + 1, filepath.end()); - - if (ext == "gz") { return compression_type::GZIP; } - if (ext == "zip") { return compression_type::ZIP; } - if (ext == "bz2") { return compression_type::BZIP2; } - if (ext == "xz") { return compression_type::XZ; } - - return compression_type::NONE; -} - table_with_metadata read_json(json_reader_options options, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -287,7 +290,7 @@ raw_orc_statistics read_raw_orc_statistics(source_info const& src_info, CUDF_FAIL("Unsupported source type"); } - orc::metadata metadata(source.get(), stream); + orc::metadata const metadata(source.get(), stream); // Initialize statistics to return raw_orc_statistics result; diff --git a/cpp/src/io/json/parser_features.cpp b/cpp/src/io/json/parser_features.cpp index e795e8e09d8..ced7acb9cde 100644 --- a/cpp/src/io/json/parser_features.cpp +++ b/cpp/src/io/json/parser_features.cpp @@ -68,6 +68,77 @@ void json_reader_options::set_dtypes(schema_element types) } // namespace cudf::io namespace cudf::io::json::detail { +namespace { + +// example schema and its path. +// "a": int {"a", int} +// "a": [ int ] {"a", list}, {"element", int} +// "a": { "b": int} {"a", struct}, {"b", int} +// "a": [ {"b": int }] {"a", list}, {"element", struct}, {"b", int} +// "a": [ null] {"a", list}, {"element", str} +// back() is root. +// front() is leaf. +/** + * @brief Get the path data type of a column by path if present in input schema + * + * @param path path of the json column + * @param root root of input schema element + * @return data type of the column if present, otherwise std::nullopt + */ +std::optional get_path_data_type( + host_span const> path, schema_element const& root) +{ + if (path.empty() || path.size() == 1) { + return root.type; + } else { + if (path.back().second == NC_STRUCT && root.type.id() == type_id::STRUCT) { + auto const child_name = path.first(path.size() - 1).back().first; + auto const child_schema_it = root.child_types.find(child_name); + return (child_schema_it != std::end(root.child_types)) + ? get_path_data_type(path.first(path.size() - 1), child_schema_it->second) + : std::optional{}; + } else if (path.back().second == NC_LIST && root.type.id() == type_id::LIST) { + auto const child_schema_it = root.child_types.find(list_child_name); + return (child_schema_it != std::end(root.child_types)) + ? get_path_data_type(path.first(path.size() - 1), child_schema_it->second) + : std::optional{}; + } + return std::optional{}; + } +} + +std::optional child_schema_element(std::string const& col_name, + cudf::io::json_reader_options const& options) +{ + return std::visit( + cudf::detail::visitor_overload{ + [col_name](std::vector const& user_dtypes) -> std::optional { + auto column_index = atol(col_name.data()); + return (static_cast(column_index) < user_dtypes.size()) + ? 
std::optional{{user_dtypes[column_index]}} + : std::optional{}; + }, + [col_name]( + std::map const& user_dtypes) -> std::optional { + return (user_dtypes.find(col_name) != std::end(user_dtypes)) + ? std::optional{{user_dtypes.find(col_name)->second}} + : std::optional{}; + }, + [col_name]( + std::map const& user_dtypes) -> std::optional { + return (user_dtypes.find(col_name) != std::end(user_dtypes)) + ? user_dtypes.find(col_name)->second + : std::optional{}; + }, + [col_name](schema_element const& user_dtypes) -> std::optional { + return (user_dtypes.child_types.find(col_name) != std::end(user_dtypes.child_types)) + ? user_dtypes.child_types.find(col_name)->second + : std::optional{}; + }}, + options.get_dtypes()); +} + +} // namespace /// Created an empty column of the specified schema struct empty_column_functor { @@ -211,74 +282,6 @@ column_name_info make_column_name_info(schema_element const& schema, std::string return info; } -std::optional child_schema_element(std::string const& col_name, - cudf::io::json_reader_options const& options) -{ - return std::visit( - cudf::detail::visitor_overload{ - [col_name](std::vector const& user_dtypes) -> std::optional { - auto column_index = atol(col_name.data()); - return (static_cast(column_index) < user_dtypes.size()) - ? std::optional{{user_dtypes[column_index]}} - : std::optional{}; - }, - [col_name]( - std::map const& user_dtypes) -> std::optional { - return (user_dtypes.find(col_name) != std::end(user_dtypes)) - ? std::optional{{user_dtypes.find(col_name)->second}} - : std::optional{}; - }, - [col_name]( - std::map const& user_dtypes) -> std::optional { - return (user_dtypes.find(col_name) != std::end(user_dtypes)) - ? user_dtypes.find(col_name)->second - : std::optional{}; - }, - [col_name](schema_element const& user_dtypes) -> std::optional { - return (user_dtypes.child_types.find(col_name) != std::end(user_dtypes.child_types)) - ? user_dtypes.child_types.find(col_name)->second - : std::optional{}; - }}, - options.get_dtypes()); -} - -// example schema and its path. -// "a": int {"a", int} -// "a": [ int ] {"a", list}, {"element", int} -// "a": { "b": int} {"a", struct}, {"b", int} -// "a": [ {"b": int }] {"a", list}, {"element", struct}, {"b", int} -// "a": [ null] {"a", list}, {"element", str} -// back() is root. -// front() is leaf. -/** - * @brief Get the path data type of a column by path if present in input schema - * - * @param path path of the json column - * @param root root of input schema element - * @return data type of the column if present, otherwise std::nullopt - */ -std::optional get_path_data_type( - host_span const> path, schema_element const& root) -{ - if (path.empty() || path.size() == 1) { - return root.type; - } else { - if (path.back().second == NC_STRUCT && root.type.id() == type_id::STRUCT) { - auto const child_name = path.first(path.size() - 1).back().first; - auto const child_schema_it = root.child_types.find(child_name); - return (child_schema_it != std::end(root.child_types)) - ? get_path_data_type(path.first(path.size() - 1), child_schema_it->second) - : std::optional{}; - } else if (path.back().second == NC_LIST && root.type.id() == type_id::LIST) { - auto const child_schema_it = root.child_types.find(list_child_name); - return (child_schema_it != std::end(root.child_types)) - ? 
get_path_data_type(path.first(path.size() - 1), child_schema_it->second) - : std::optional{}; - } - return std::optional{}; - } -} - std::optional get_path_data_type( host_span const> path, cudf::io::json_reader_options const& options) diff --git a/cpp/src/io/parquet/arrow_schema_writer.cpp b/cpp/src/io/parquet/arrow_schema_writer.cpp index d15435b2553..a4536ac6a3b 100644 --- a/cpp/src/io/parquet/arrow_schema_writer.cpp +++ b/cpp/src/io/parquet/arrow_schema_writer.cpp @@ -336,7 +336,7 @@ std::string construct_arrow_schema_ipc_message(cudf::detail::LinkedColVector con { // Lambda function to convert int32 to a string of uint8 bytes auto const convert_int32_to_byte_string = [&](int32_t const value) { - std::array buffer; + std::array buffer{}; std::memcpy(buffer.data(), &value, sizeof(int32_t)); return std::string(reinterpret_cast(buffer.data()), buffer.size()); }; diff --git a/cpp/src/io/parquet/compact_protocol_reader.cpp b/cpp/src/io/parquet/compact_protocol_reader.cpp index d276e946a51..f1ecf66c29f 100644 --- a/cpp/src/io/parquet/compact_protocol_reader.cpp +++ b/cpp/src/io/parquet/compact_protocol_reader.cpp @@ -27,23 +27,7 @@ #include namespace cudf::io::parquet::detail { - -/** - * @brief Base class for parquet field functors. - * - * Holds the field value used by all of the specialized functors. - */ -class parquet_field { - private: - int _field_val; - - protected: - parquet_field(int f) : _field_val(f) {} - - public: - virtual ~parquet_field() = default; - [[nodiscard]] int field() const { return _field_val; } -}; +namespace { std::string field_type_string(FieldType type) { @@ -79,6 +63,72 @@ void assert_bool_field_type(int type) "expected bool field, got " + field_type_string(field_type) + " field instead"); } +template +struct FunctionSwitchImpl { + template + static inline void run(CompactProtocolReader* cpr, + int field_type, + int const& field, + std::tuple& ops) + { + if (field == std::get(ops).field()) { + std::get(ops)(cpr, field_type); + } else { + FunctionSwitchImpl::run(cpr, field_type, field, ops); + } + } +}; + +template <> +struct FunctionSwitchImpl<0> { + template + static inline void run(CompactProtocolReader* cpr, + int field_type, + int const& field, + std::tuple& ops) + { + if (field == std::get<0>(ops).field()) { + std::get<0>(ops)(cpr, field_type); + } else { + cpr->skip_struct_field(field_type); + } + } +}; + +template +inline void function_builder(CompactProtocolReader* cpr, std::tuple& op) +{ + constexpr int index = std::tuple_size>::value - 1; + int field = 0; + while (true) { + int const current_byte = cpr->getb(); + if (!current_byte) { break; } + int const field_delta = current_byte >> 4; + int const field_type = current_byte & 0xf; + field = field_delta ? field + field_delta : cpr->get_i16(); + FunctionSwitchImpl::run(cpr, field_type, field, op); + } +} + +} // namespace + +/** + * @brief Base class for parquet field functors. + * + * Holds the field value used by all of the specialized functors. + */ +class parquet_field { + private: + int _field_val; + + protected: + parquet_field(int f) : _field_val(f) {} + + public: + virtual ~parquet_field() = default; + [[nodiscard]] int field() const { return _field_val; } +}; + /** * @brief Abstract base class for list functors. 
*/ @@ -494,53 +544,6 @@ void CompactProtocolReader::skip_struct_field(int t, int depth) } } -template -struct FunctionSwitchImpl { - template - static inline void run(CompactProtocolReader* cpr, - int field_type, - int const& field, - std::tuple& ops) - { - if (field == std::get(ops).field()) { - std::get(ops)(cpr, field_type); - } else { - FunctionSwitchImpl::run(cpr, field_type, field, ops); - } - } -}; - -template <> -struct FunctionSwitchImpl<0> { - template - static inline void run(CompactProtocolReader* cpr, - int field_type, - int const& field, - std::tuple& ops) - { - if (field == std::get<0>(ops).field()) { - std::get<0>(ops)(cpr, field_type); - } else { - cpr->skip_struct_field(field_type); - } - } -}; - -template -inline void function_builder(CompactProtocolReader* cpr, std::tuple& op) -{ - constexpr int index = std::tuple_size>::value - 1; - int field = 0; - while (true) { - int const current_byte = cpr->getb(); - if (!current_byte) { break; } - int const field_delta = current_byte >> 4; - int const field_type = current_byte & 0xf; - field = field_delta ? field + field_delta : cpr->get_i16(); - FunctionSwitchImpl::run(cpr, field_type, field, op); - } -} - void CompactProtocolReader::read(FileMetaData* f) { using optional_list_column_order = diff --git a/cpp/src/io/parquet/compact_protocol_writer.cpp b/cpp/src/io/parquet/compact_protocol_writer.cpp index 14c99f728de..bf2db013118 100644 --- a/cpp/src/io/parquet/compact_protocol_writer.cpp +++ b/cpp/src/io/parquet/compact_protocol_writer.cpp @@ -291,7 +291,7 @@ uint32_t CompactProtocolFieldWriter::put_uint(uint64_t v) uint32_t CompactProtocolFieldWriter::put_int(int64_t v) { - int64_t s = (v < 0); + int64_t const s = (v < 0); return put_uint(((v ^ -s) << 1) + s); } diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index cd3dcd2bce4..b0cbabf1c12 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -426,7 +426,7 @@ std::optional>> aggregate_reader_metadata::fi // where min(col[i]) = columns[i*2], max(col[i])=columns[i*2+1] // For each column, it contains #sources * #column_chunks_per_src rows. std::vector> columns; - stats_caster stats_col{total_row_groups, per_file_metadata, input_row_group_indices}; + stats_caster const stats_col{total_row_groups, per_file_metadata, input_row_group_indices}; for (size_t col_idx = 0; col_idx < output_dtypes.size(); col_idx++) { auto const schema_idx = output_column_schemas[col_idx]; auto const& dtype = output_dtypes[col_idx]; @@ -447,7 +447,8 @@ std::optional>> aggregate_reader_metadata::fi auto stats_table = cudf::table(std::move(columns)); // Converts AST to StatsAST with reference to min, max columns in above `stats_table`. 
- stats_expression_converter stats_expr{filter.get(), static_cast(output_dtypes.size())}; + stats_expression_converter const stats_expr{filter.get(), + static_cast(output_dtypes.size())}; auto stats_ast = stats_expr.get_stats_expr(); auto predicate_col = cudf::detail::compute_column(stats_table, stats_ast.get(), stream, mr); auto predicate = predicate_col->view(); diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index d74ae83b635..c48ff896e33 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -148,7 +148,7 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num CUDF_EXPECTS(input_col.schema_idx == pass.chunks[c].src_col_schema, "Column/page schema index mismatch"); - size_t max_depth = _metadata->get_output_nesting_depth(pass.chunks[c].src_col_schema); + size_t const max_depth = _metadata->get_output_nesting_depth(pass.chunks[c].src_col_schema); chunk_offsets.push_back(chunk_off); // get a slice of size `nesting depth` from `chunk_nested_valids` to store an array of pointers @@ -203,7 +203,7 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num auto& out_buf = (*cols)[input_col.nesting[idx]]; cols = &out_buf.children; - int owning_schema = out_buf.user_data & PARQUET_COLUMN_BUFFER_SCHEMA_MASK; + int const owning_schema = out_buf.user_data & PARQUET_COLUMN_BUFFER_SCHEMA_MASK; if (owning_schema == 0 || owning_schema == input_col.schema_idx) { valids[idx] = out_buf.null_mask(); data[idx] = out_buf.data(); @@ -435,7 +435,7 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num ColumnChunkDesc* col = &pass.chunks[pi->chunk_idx]; input_column_info const& input_col = _input_columns[col->src_col_index]; - int index = pi->nesting_decode - page_nesting_decode.device_ptr(); + int const index = pi->nesting_decode - page_nesting_decode.device_ptr(); PageNestingDecodeInfo* pndi = &page_nesting_decode[index]; auto* cols = &_output_buffers; diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index a6562d33de2..bfd0cc992cf 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -833,7 +833,7 @@ std::optional aggregate_reader_metadata::decode_ipc_message( // Lambda function to read and return 4 bytes as int32_t from the ipc message buffer and update // buffer pointer and size auto read_int32_from_ipc_message = [&]() { - int32_t bytes; + int32_t bytes = 0; std::memcpy(&bytes, message_buf, sizeof(int32_t)); // Offset the message buf and reduce remaining size message_buf += sizeof(int32_t); @@ -991,7 +991,7 @@ std::string aggregate_reader_metadata::get_pandas_index() const // One-liner regex: // "index_columns"\s*:\s*\[\s*((?:"(?:|(?:.*?(?![^\\]")).?)[^\\]?",?\s*)*)\] // Documented below. 
- std::regex index_columns_expr{ + std::regex const index_columns_expr{ R"("index_columns"\s*:\s*\[\s*)" // match preamble, opening square bracket, whitespace R"(()" // Open first capturing group R"((?:")" // Open non-capturing group match opening quote @@ -1013,12 +1013,12 @@ std::vector aggregate_reader_metadata::get_pandas_index_names() con std::vector names; auto str = get_pandas_index(); if (str.length() != 0) { - std::regex index_name_expr{R"(\"((?:\\.|[^\"])*)\")"}; + std::regex const index_name_expr{R"(\"((?:\\.|[^\"])*)\")"}; std::smatch sm; while (std::regex_search(str, sm, index_name_expr)) { if (sm.size() == 2) { // 2 = whole match, first item if (std::find(names.begin(), names.end(), sm[1].str()) == names.end()) { - std::regex esc_quote{R"(\\")"}; + std::regex const esc_quote{R"(\\")"}; names.emplace_back(std::regex_replace(sm[1].str(), esc_quote, R"(")")); } } @@ -1362,8 +1362,8 @@ aggregate_reader_metadata::select_columns( std::vector all_paths; std::function add_path = [&](std::string path_till_now, int schema_idx) { - auto const& schema_elem = get_schema(schema_idx); - std::string curr_path = path_till_now + schema_elem.name; + auto const& schema_elem = get_schema(schema_idx); + std::string const curr_path = path_till_now + schema_elem.name; all_paths.push_back({curr_path, schema_idx}); for (auto const& child_idx : schema_elem.children_idx) { add_path(curr_path + ".", child_idx); @@ -1376,7 +1376,7 @@ aggregate_reader_metadata::select_columns( // Find which of the selected paths are valid and get their schema index std::vector valid_selected_paths; // vector reference pushback (*use_names). If filter names passed. - std::vector const>> column_names{ + std::vector const>> const column_names{ *use_names, *filter_columns_names}; for (auto const& used_column_names : column_names) { for (auto const& selected_path : used_column_names.get()) { @@ -1408,7 +1408,7 @@ aggregate_reader_metadata::select_columns( std::vector selected_columns; if (include_index) { - std::vector index_names = get_pandas_index_names(); + std::vector const index_names = get_pandas_index_names(); std::transform(index_names.cbegin(), index_names.cend(), std::back_inserter(selected_columns), @@ -1457,7 +1457,7 @@ aggregate_reader_metadata::select_columns( } for (auto& col : selected_columns) { auto const& top_level_col_schema_idx = find_schema_child(root, col.name); - bool valid_column = build_column(&col, top_level_col_schema_idx, output_columns, false); + bool const valid_column = build_column(&col, top_level_col_schema_idx, output_columns, false); if (valid_column) { output_column_schemas.push_back(top_level_col_schema_idx); diff --git a/cpp/src/io/text/bgzip_utils.cpp b/cpp/src/io/text/bgzip_utils.cpp index cb412828e2d..77da2a44c7c 100644 --- a/cpp/src/io/text/bgzip_utils.cpp +++ b/cpp/src/io/text/bgzip_utils.cpp @@ -40,7 +40,7 @@ IntType read_int(char* data) template void write_int(std::ostream& output_stream, T val) { - std::array bytes; + std::array bytes{}; // we assume little-endian std::memcpy(&bytes[0], &val, sizeof(T)); output_stream.write(bytes.data(), bytes.size()); diff --git a/cpp/src/io/utilities/base64_utilities.cpp b/cpp/src/io/utilities/base64_utilities.cpp index 856c29599a7..2a2a07afc8d 100644 --- a/cpp/src/io/utilities/base64_utilities.cpp +++ b/cpp/src/io/utilities/base64_utilities.cpp @@ -86,7 +86,7 @@ std::string base64_encode(std::string_view string_to_encode) num_iterations += (input_length % 3) ? 
1 : 0; std::string encoded; - size_t encoded_length = (input_length + 2) / 3 * 4; + size_t const encoded_length = (input_length + 2) / 3 * 4; encoded.reserve(encoded_length); // altered: modify base64 encoder loop using STL and Thrust. @@ -135,7 +135,7 @@ std::string base64_decode(std::string_view encoded_string) return std::string{}; } - size_t input_length = encoded_string.length(); + size_t const input_length = encoded_string.length(); std::string decoded; // altered: compute number of decoding iterations = floor (multiple of 4) @@ -147,7 +147,7 @@ std::string base64_decode(std::string_view encoded_string) // two bytes smaller, depending on the amount of trailing equal signs // in the encoded string. This approximation is needed to reserve // enough space in the string to be returned. - size_t approx_decoded_length = input_length / 4 * 3; + size_t const approx_decoded_length = input_length / 4 * 3; decoded.reserve(approx_decoded_length); // diff --git a/cpp/src/io/utilities/data_sink.cpp b/cpp/src/io/utilities/data_sink.cpp index b37a5ac900a..bed03869b34 100644 --- a/cpp/src/io/utilities/data_sink.cpp +++ b/cpp/src/io/utilities/data_sink.cpp @@ -86,7 +86,7 @@ class file_sink : public data_sink { { if (!supports_device_write()) CUDF_FAIL("Device writes are not supported for this file."); - size_t offset = _bytes_written; + size_t const offset = _bytes_written; _bytes_written += size; if (!_kvikio_file.closed()) { @@ -170,7 +170,7 @@ class void_sink : public data_sink { size_t bytes_written() override { return _bytes_written; } private: - size_t _bytes_written; + size_t _bytes_written{}; }; class user_sink_wrapper : public data_sink { diff --git a/cpp/src/io/utilities/datasource.cpp b/cpp/src/io/utilities/datasource.cpp index 10814eea458..62ef7c7a794 100644 --- a/cpp/src/io/utilities/datasource.cpp +++ b/cpp/src/io/utilities/datasource.cpp @@ -128,7 +128,8 @@ class file_source : public datasource { rmm::cuda_stream_view stream) override { rmm::device_buffer out_data(size, stream); - size_t read = device_read(offset, size, reinterpret_cast(out_data.data()), stream); + size_t const read = + device_read(offset, size, reinterpret_cast(out_data.data()), stream); out_data.resize(read, stream); return datasource::buffer::create(std::move(out_data)); } @@ -444,7 +445,8 @@ class remote_file_source : public datasource { rmm::cuda_stream_view stream) override { rmm::device_buffer out_data(size, stream); - size_t read = device_read(offset, size, reinterpret_cast(out_data.data()), stream); + size_t const read = + device_read(offset, size, reinterpret_cast(out_data.data()), stream); out_data.resize(read, stream); return datasource::buffer::create(std::move(out_data)); } @@ -471,7 +473,7 @@ class remote_file_source : public datasource { static bool is_supported_remote_url(std::string const& url) { // Regular expression to match "s3://" - static std::regex pattern{R"(^s3://)", std::regex_constants::icase}; + static std::regex const pattern{R"(^s3://)", std::regex_constants::icase}; return std::regex_search(url, pattern); } diff --git a/cpp/src/io/utilities/file_io_utilities.cpp b/cpp/src/io/utilities/file_io_utilities.cpp index f9750e4a505..9b17e7f6d55 100644 --- a/cpp/src/io/utilities/file_io_utilities.cpp +++ b/cpp/src/io/utilities/file_io_utilities.cpp @@ -33,6 +33,24 @@ namespace cudf { namespace io { namespace detail { +namespace { + +[[nodiscard]] int open_file_checked(std::string const& filepath, int flags, mode_t mode) +{ + auto const fd = open(filepath.c_str(), flags, mode); + if (fd == -1) { 
throw_on_file_open_failure(filepath, flags & O_CREAT); } + + return fd; +} + +[[nodiscard]] size_t get_file_size(int file_descriptor) +{ + struct stat st {}; + CUDF_EXPECTS(fstat(file_descriptor, &st) != -1, "Cannot query file size"); + return static_cast(st.st_size); +} + +} // namespace void force_init_cuda_context() { @@ -55,26 +73,11 @@ void force_init_cuda_context() CUDF_EXPECTS(std::filesystem::exists(path), "Cannot open file; it does not exist"); } - std::array error_msg_buffer; + std::array error_msg_buffer{}; auto const error_msg = strerror_r(err, error_msg_buffer.data(), 1024); CUDF_FAIL("Cannot open file; failed with errno: " + std::string{error_msg}); } -[[nodiscard]] int open_file_checked(std::string const& filepath, int flags, mode_t mode) -{ - auto const fd = open(filepath.c_str(), flags, mode); - if (fd == -1) { throw_on_file_open_failure(filepath, flags & O_CREAT); } - - return fd; -} - -[[nodiscard]] size_t get_file_size(int file_descriptor) -{ - struct stat st; - CUDF_EXPECTS(fstat(file_descriptor, &st) != -1, "Cannot query file size"); - return static_cast(st.st_size); -} - file_wrapper::file_wrapper(std::string const& filepath, int flags, mode_t mode) : fd(open_file_checked(filepath.c_str(), flags, mode)), _size{get_file_size(fd)} { @@ -125,7 +128,7 @@ class cufile_shim { void cufile_shim::modify_cufile_json() const { std::string const json_path_env_var = "CUFILE_ENV_PATH_JSON"; - static temp_directory tmp_config_dir{"cudf_cufile_config"}; + static temp_directory const tmp_config_dir{"cudf_cufile_config"}; // Modify the config file based on the policy auto const config_file_path = getenv_or(json_path_env_var, "/etc/cufile.json"); @@ -253,7 +256,7 @@ std::future cufile_input_impl::read_async(size_t offset, uint8_t* dst, rmm::cuda_stream_view stream) { - int device; + int device = 0; CUDF_CUDA_TRY(cudaGetDevice(&device)); auto read_slice = [device, gds_read = shim->read, file_handle = cf_file.handle()]( @@ -285,7 +288,7 @@ cufile_output_impl::cufile_output_impl(std::string const& filepath) std::future cufile_output_impl::write_async(void const* data, size_t offset, size_t size) { - int device; + int device = 0; CUDF_CUDA_TRY(cudaGetDevice(&device)); auto write_slice = [device, gds_write = shim->write, file_handle = cf_file.handle()]( diff --git a/cpp/src/jit/cache.cpp b/cpp/src/jit/cache.cpp index 34a0bdce124..49f92756e43 100644 --- a/cpp/src/jit/cache.cpp +++ b/cpp/src/jit/cache.cpp @@ -22,6 +22,7 @@ namespace cudf { namespace jit { +namespace { // Get the directory in home to use for storing the cache std::filesystem::path get_user_home_cache_dir() @@ -72,13 +73,13 @@ std::filesystem::path get_cache_dir() // Make per device cache based on compute capability. This is to avoid multiple devices of // different compute capability to access the same kernel cache. - int device; - int cc_major; - int cc_minor; + int device = 0; + int cc_major = 0; + int cc_minor = 0; CUDF_CUDA_TRY(cudaGetDevice(&device)); CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device)); CUDF_CUDA_TRY(cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device)); - int cc = cc_major * 10 + cc_minor; + int const cc = cc_major * 10 + cc_minor; kernel_cache_path /= std::to_string(cc); @@ -107,13 +108,14 @@ std::size_t try_parse_numeric_env_var(char const* const env_name, std::size_t de auto const value = std::getenv(env_name); return value != nullptr ? 
std::stoull(value) : default_val; } +} // namespace jitify2::ProgramCache<>& get_program_cache(jitify2::PreprocessedProgramData preprog) { static std::mutex caches_mutex{}; static std::unordered_map>> caches{}; - std::lock_guard caches_lock(caches_mutex); + std::lock_guard const caches_lock(caches_mutex); auto existing_cache = caches.find(preprog.name()); diff --git a/cpp/src/jit/parser.cpp b/cpp/src/jit/parser.cpp index 519ac2d1a2e..c79ba4347bf 100644 --- a/cpp/src/jit/parser.cpp +++ b/cpp/src/jit/parser.cpp @@ -26,10 +26,37 @@ namespace cudf { namespace jit { -constexpr char percent_escape[] = "_"; // NOLINT +namespace { inline bool is_white(char const c) { return c == ' ' || c == '\n' || c == '\r' || c == '\t'; } +std::string remove_comments(std::string const& src) +{ + std::string output; + auto f = src.cbegin(); + while (f < src.cend()) { + auto l = std::find(f, src.cend(), '/'); + output.append(f, l); // push chunk instead of 1 char at a time + f = std::next(l); // skip over '/' + if (l < src.cend()) { + char const n = f < src.cend() ? *f : '?'; + if (n == '/') { // found "//" + f = std::find(f, src.cend(), '\n'); // skip to end of line + } else if (n == '*') { // found "/*" + auto term = std::string("*/"); // skip to end of next "*/" + f = std::search(std::next(f), src.cend(), term.cbegin(), term.cend()) + term.size(); + } else { + output.push_back('/'); // lone '/' should be pushed into output + } + } + } + return output; +} + +} // namespace + +constexpr char percent_escape[] = "_"; // NOLINT + std::string ptx_parser::escape_percent(std::string const& src) { // b/c we're transforming into inline ptx we aren't allowed to have register names starting with % @@ -106,7 +133,7 @@ std::string ptx_parser::parse_instruction(std::string const& src) std::string output; std::string suffix; - std::string original_code = "\n /** " + src + " */\n"; + std::string const original_code = "\n /** " + src + " */\n"; int piece_count = 0; @@ -316,33 +343,10 @@ std::string ptx_parser::parse_function_header(std::string const& src) return "\n__device__ __inline__ void " + function_name + "(" + input_arg + "){" + "\n"; } -std::string remove_comments(std::string const& src) -{ - std::string output; - auto f = src.cbegin(); - while (f < src.cend()) { - auto l = std::find(f, src.cend(), '/'); - output.append(f, l); // push chunk instead of 1 char at a time - f = std::next(l); // skip over '/' - if (l < src.cend()) { - char n = f < src.cend() ? 
*f : '?'; - if (n == '/') { // found "//" - f = std::find(f, src.cend(), '\n'); // skip to end of line - } else if (n == '*') { // found "/*" - auto term = std::string("*/"); // skip to end of next "*/" - f = std::search(std::next(f), src.cend(), term.cbegin(), term.cend()) + term.size(); - } else { - output.push_back('/'); // lone '/' should be pushed into output - } - } - } - return output; -} - // The interface std::string ptx_parser::parse() { - std::string no_comments = remove_comments(ptx); + std::string const no_comments = remove_comments(ptx); input_arg_list.clear(); auto const _func = std::string(".func"); // Go directly to the .func mark diff --git a/cpp/src/quantiles/tdigest/tdigest_column_view.cpp b/cpp/src/quantiles/tdigest/tdigest_column_view.cpp index 17844b6bb0a..933ef1bfcbd 100644 --- a/cpp/src/quantiles/tdigest/tdigest_column_view.cpp +++ b/cpp/src/quantiles/tdigest/tdigest_column_view.cpp @@ -29,14 +29,14 @@ tdigest_column_view::tdigest_column_view(column_view const& col) : column_view(c CUDF_EXPECTS(col.offset() == 0, "Encountered a sliced tdigest column"); CUDF_EXPECTS(not col.nullable(), "Encountered nullable tdigest column"); - structs_column_view scv(col); + structs_column_view const scv(col); CUDF_EXPECTS(scv.num_children() == 3, "Encountered invalid tdigest column"); CUDF_EXPECTS(scv.child(min_column_index).type().id() == type_id::FLOAT64, "Encountered invalid tdigest column"); CUDF_EXPECTS(scv.child(max_column_index).type().id() == type_id::FLOAT64, "Encountered invalid tdigest column"); - lists_column_view lcv(scv.child(centroid_column_index)); + lists_column_view const lcv(scv.child(centroid_column_index)); auto data = lcv.child(); CUDF_EXPECTS(data.type().id() == type_id::STRUCT, "Encountered invalid tdigest column"); CUDF_EXPECTS(data.num_children() == 2, @@ -52,14 +52,14 @@ lists_column_view tdigest_column_view::centroids() const { return child(centroid column_view tdigest_column_view::means() const { auto c = centroids(); - structs_column_view inner(c.parent().child(lists_column_view::child_column_index)); + structs_column_view const inner(c.parent().child(lists_column_view::child_column_index)); return inner.child(mean_column_index); } column_view tdigest_column_view::weights() const { auto c = centroids(); - structs_column_view inner(c.parent().child(lists_column_view::child_column_index)); + structs_column_view const inner(c.parent().child(lists_column_view::child_column_index)); return inner.child(weight_column_index); } diff --git a/cpp/src/reductions/scan/scan.cpp b/cpp/src/reductions/scan/scan.cpp index b91ae19b51a..7afd3ba3c00 100644 --- a/cpp/src/reductions/scan/scan.cpp +++ b/cpp/src/reductions/scan/scan.cpp @@ -20,8 +20,8 @@ #include namespace cudf { - namespace detail { +namespace { std::unique_ptr scan(column_view const& input, scan_aggregation const& agg, scan_type inclusive, @@ -50,6 +50,7 @@ std::unique_ptr scan(column_view const& input, : detail::scan_inclusive(input, agg, null_handling, stream, mr); } +} // namespace } // namespace detail std::unique_ptr scan(column_view const& input, diff --git a/cpp/src/reductions/segmented/reductions.cpp b/cpp/src/reductions/segmented/reductions.cpp index c4f6c135dde..dedfc4b0734 100644 --- a/cpp/src/reductions/segmented/reductions.cpp +++ b/cpp/src/reductions/segmented/reductions.cpp @@ -26,6 +26,8 @@ namespace cudf { namespace reduction { namespace detail { +namespace { + struct segmented_reduce_dispatch_functor { column_view const& col; device_span offsets; @@ -126,6 +128,7 @@ std::unique_ptr 
segmented_reduce(column_view const& segmented_values, segmented_values, offsets, output_dtype, null_handling, init, stream, mr}, agg); } +} // namespace } // namespace detail } // namespace reduction diff --git a/cpp/src/rolling/detail/optimized_unbounded_window.cpp b/cpp/src/rolling/detail/optimized_unbounded_window.cpp index 7cad31c0658..9c22c27144d 100644 --- a/cpp/src/rolling/detail/optimized_unbounded_window.cpp +++ b/cpp/src/rolling/detail/optimized_unbounded_window.cpp @@ -25,32 +25,7 @@ #include namespace cudf::detail { - -bool can_optimize_unbounded_window(bool unbounded_preceding, - bool unbounded_following, - size_type min_periods, - rolling_aggregation const& agg) -{ - auto is_supported = [](auto const& agg) { - switch (agg.kind) { - case cudf::aggregation::Kind::COUNT_ALL: [[fallthrough]]; - case cudf::aggregation::Kind::COUNT_VALID: [[fallthrough]]; - case cudf::aggregation::Kind::SUM: [[fallthrough]]; - case cudf::aggregation::Kind::MIN: [[fallthrough]]; - case cudf::aggregation::Kind::MAX: return true; - default: - // COLLECT_LIST and COLLECT_SET can be added at a later date. - // Other aggregations do not fit into the [UNBOUNDED, UNBOUNDED] - // category. For instance: - // 1. Ranking functions (ROW_NUMBER, RANK, DENSE_RANK, PERCENT_RANK) - // use [UNBOUNDED PRECEDING, CURRENT ROW]. - // 2. LEAD/LAG are defined on finite row boundaries. - return false; - } - }; - - return unbounded_preceding && unbounded_following && (min_periods == 1) && is_supported(agg); -} +namespace { /// Converts rolling_aggregation to corresponding reduce/groupby_aggregation. template @@ -145,6 +120,33 @@ std::unique_ptr reduction_based_rolling_window(column_view const& input, // Blow up results into separate column. return cudf::make_column_from_scalar(*reduce_results, input.size(), stream, mr); } +} // namespace + +bool can_optimize_unbounded_window(bool unbounded_preceding, + bool unbounded_following, + size_type min_periods, + rolling_aggregation const& agg) +{ + auto is_supported = [](auto const& agg) { + switch (agg.kind) { + case cudf::aggregation::Kind::COUNT_ALL: [[fallthrough]]; + case cudf::aggregation::Kind::COUNT_VALID: [[fallthrough]]; + case cudf::aggregation::Kind::SUM: [[fallthrough]]; + case cudf::aggregation::Kind::MIN: [[fallthrough]]; + case cudf::aggregation::Kind::MAX: return true; + default: + // COLLECT_LIST and COLLECT_SET can be added at a later date. + // Other aggregations do not fit into the [UNBOUNDED, UNBOUNDED] + // category. For instance: + // 1. Ranking functions (ROW_NUMBER, RANK, DENSE_RANK, PERCENT_RANK) + // use [UNBOUNDED PRECEDING, CURRENT ROW]. + // 2. LEAD/LAG are defined on finite row boundaries. 
+ return false; + } + }; + + return unbounded_preceding && unbounded_following && (min_periods == 1) && is_supported(agg); +} std::unique_ptr optimized_unbounded_window(table_view const& group_keys, column_view const& input, diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index b923a301f84..b7b1338dd89 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -80,8 +80,8 @@ std::array const escapable_chars{ */ std::vector string_to_char32_vector(std::string_view pattern) { - auto size = static_cast(pattern.size()); - size_type count = std::count_if(pattern.cbegin(), pattern.cend(), [](char ch) { + auto size = static_cast(pattern.size()); + size_type const count = std::count_if(pattern.cbegin(), pattern.cend(), [](char ch) { return is_begin_utf8_char(static_cast(ch)); }); std::vector result(count + 1); @@ -89,7 +89,7 @@ std::vector string_to_char32_vector(std::string_view pattern) char const* input_ptr = pattern.data(); for (size_type idx = 0; idx < size; ++idx) { char_utf8 output_character = 0; - size_type ch_width = to_char_utf8(input_ptr, output_character); + size_type const ch_width = to_char_utf8(input_ptr, output_character); input_ptr += ch_width; idx += ch_width - 1; *output_ptr++ = output_character; @@ -102,7 +102,7 @@ std::vector string_to_char32_vector(std::string_view pattern) int32_t reprog::add_inst(int32_t t) { - reinst inst; + reinst inst{}; inst.type = t; inst.u2.left_id = 0; inst.u1.right_id = 0; @@ -968,7 +968,7 @@ class regex_compiler { } if (token != RBRA) { push_operator(token, subid); } - static std::vector tokens{STAR, STAR_LAZY, QUEST, QUEST_LAZY, PLUS, PLUS_LAZY, RBRA}; + static std::vector const tokens{STAR, STAR_LAZY, QUEST, QUEST_LAZY, PLUS, PLUS_LAZY, RBRA}; _last_was_and = std::any_of(tokens.cbegin(), tokens.cend(), [token](auto t) { return t == token; }); } @@ -1046,7 +1046,7 @@ reprog reprog::create_from(std::string_view pattern, { reprog rtn; auto pattern32 = string_to_char32_vector(pattern); - regex_compiler compiler(pattern32.data(), flags, capture, rtn); + regex_compiler const compiler(pattern32.data(), flags, capture, rtn); // for debugging, it can be helpful to call rtn.print(flags) here to dump // out the instructions that have been created from the given pattern return rtn; @@ -1114,7 +1114,7 @@ void reprog::build_start_ids() std::stack ids; ids.push(_startinst_id); while (!ids.empty()) { - int id = ids.top(); + int const id = ids.top(); ids.pop(); reinst const& inst = _insts[id]; if (inst.type == OR) { diff --git a/cpp/src/strings/regex/regexec.cpp b/cpp/src/strings/regex/regexec.cpp index 60ad714dfec..3d11b641b3f 100644 --- a/cpp/src/strings/regex/regexec.cpp +++ b/cpp/src/strings/regex/regexec.cpp @@ -99,9 +99,9 @@ std::unique_ptr> reprog_devic // place each class and append the variable length data for (int32_t idx = 0; idx < classes_count; ++idx) { auto const& h_class = h_prog.class_at(idx); - reclass_device d_class{h_class.builtins, - static_cast(h_class.literals.size()), - reinterpret_cast(d_end)}; + reclass_device const d_class{h_class.builtins, + static_cast(h_class.literals.size()), + reinterpret_cast(d_end)}; *classes++ = d_class; memcpy(h_end, h_class.literals.data(), h_class.literals.size() * sizeof(reclass_range)); h_end += h_class.literals.size() * sizeof(reclass_range); diff --git a/cpp/src/structs/utilities.cpp b/cpp/src/structs/utilities.cpp index 4012ee3d21c..22328726c0e 100644 --- a/cpp/src/structs/utilities.cpp +++ b/cpp/src/structs/utilities.cpp @@ -47,7 +47,7 @@ 
std::vector> extract_ordered_struct_children( std::vector children; children.reserve(num_cols); for (size_type col_index = 0; col_index < num_cols; col_index++) { - structs_column_view scv(struct_cols[col_index]); + structs_column_view const scv(struct_cols[col_index]); // all inputs must have the same # of children and they must all be of the // same type. diff --git a/cpp/src/table/table_view.cpp b/cpp/src/table/table_view.cpp index 659beb749af..ee7136d8f5e 100644 --- a/cpp/src/table/table_view.cpp +++ b/cpp/src/table/table_view.cpp @@ -25,6 +25,21 @@ namespace cudf { namespace detail { +namespace { + +template +auto concatenate_column_views(std::vector const& views) +{ + using ColumnView = typename ViewType::ColumnView; + std::vector concat_cols; + for (auto& view : views) { + concat_cols.insert(concat_cols.end(), view.begin(), view.end()); + } + return concat_cols; +} + +} // namespace + template table_view_base::table_view_base(std::vector const& cols) : _columns{cols} { @@ -38,17 +53,6 @@ table_view_base::table_view_base(std::vector const& cols } } -template -auto concatenate_column_views(std::vector const& views) -{ - using ColumnView = typename ViewType::ColumnView; - std::vector concat_cols; - for (auto& view : views) { - concat_cols.insert(concat_cols.end(), view.begin(), view.end()); - } - return concat_cols; -} - // Explicit instantiation for a table of `column_view`s template class table_view_base; @@ -65,17 +69,16 @@ table_view table_view::select(std::vector const& column_indices) cons // Convert mutable view to immutable view mutable_table_view::operator table_view() { - std::vector cols{begin(), end()}; - return table_view{cols}; + return table_view{std::vector{begin(), end()}}; } table_view::table_view(std::vector const& views) - : table_view{concatenate_column_views(views)} + : table_view{detail::concatenate_column_views(views)} { } mutable_table_view::mutable_table_view(std::vector const& views) - : mutable_table_view{concatenate_column_views(views)} + : mutable_table_view{detail::concatenate_column_views(views)} { } diff --git a/cpp/src/transform/transform.cpp b/cpp/src/transform/transform.cpp index b919ac16956..4a383bfba47 100644 --- a/cpp/src/transform/transform.cpp +++ b/cpp/src/transform/transform.cpp @@ -33,7 +33,7 @@ namespace cudf { namespace transformation { namespace jit { - +namespace { void unary_operation(mutable_column_view output, column_view input, std::string const& udf, @@ -41,7 +41,7 @@ void unary_operation(mutable_column_view output, bool is_ptx, rmm::cuda_stream_view stream) { - std::string kernel_name = + std::string const kernel_name = jitify2::reflection::Template("cudf::transformation::jit::kernel") // .instantiate(cudf::type_to_name(output.type()), // list of template arguments cudf::type_to_name(input.type())); @@ -62,6 +62,7 @@ void unary_operation(mutable_column_view output, cudf::jit::get_data_ptr(output), cudf::jit::get_data_ptr(input)); } +} // namespace } // namespace jit } // namespace transformation @@ -81,7 +82,7 @@ std::unique_ptr transform(column_view const& input, if (input.is_empty()) { return output; } - mutable_column_view output_view = *output; + mutable_column_view const output_view = *output; // transform transformation::jit::unary_operation(output_view, input, unary_udf, output_type, is_ptx, stream); diff --git a/cpp/src/utilities/prefetch.cpp b/cpp/src/utilities/prefetch.cpp index 000526723c4..6c9f677afb3 100644 --- a/cpp/src/utilities/prefetch.cpp +++ b/cpp/src/utilities/prefetch.cpp @@ -33,14 +33,14 @@ prefetch_config& 
prefetch_config::instance() bool prefetch_config::get(std::string_view key) { - std::shared_lock lock(config_mtx); + std::shared_lock const lock(config_mtx); auto const it = config_values.find(key.data()); return it == config_values.end() ? false : it->second; // default to not prefetching } void prefetch_config::set(std::string_view key, bool value) { - std::lock_guard lock(config_mtx); + std::lock_guard const lock(config_mtx); config_values[key.data()] = value; } diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index 7069b59be26..9d1bebd1937 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -55,6 +55,63 @@ std::size_t constexpr STREAM_POOL_SIZE = 32; } while (0) #endif +/** + * @brief RAII struct to wrap a cuda event and ensure its proper destruction. + */ +struct cuda_event { + cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); } + virtual ~cuda_event() { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_)); } + + // Moveable but not copyable. + cuda_event(const cuda_event&) = delete; + cuda_event& operator=(const cuda_event&) = delete; + + cuda_event(cuda_event&&) = default; + cuda_event& operator=(cuda_event&&) = default; + + operator cudaEvent_t() { return e_; } + + private: + cudaEvent_t e_{}; +}; + +namespace { + +// FIXME: these will be available in rmm soon +inline int get_num_cuda_devices() +{ + rmm::cuda_device_id::value_type num_dev{}; + CUDF_CUDA_TRY(cudaGetDeviceCount(&num_dev)); + return num_dev; +} + +rmm::cuda_device_id get_current_cuda_device() +{ + int device_id = 0; + CUDF_CUDA_TRY(cudaGetDevice(&device_id)); + return rmm::cuda_device_id{device_id}; +} + +/** + * @brief Returns a cudaEvent_t for the current thread. + * + * The returned event is valid for the current device. + * + * @return A cudaEvent_t unique to the current thread and valid on the current device. + */ +cudaEvent_t event_for_thread() +{ + // The program may crash if this function is called from the main thread and user application + // subsequently calls cudaDeviceReset(). + // As a workaround, here we intentionally disable RAII and leak cudaEvent_t. + thread_local static std::vector thread_events(get_num_cuda_devices()); + auto const device_id = get_current_cuda_device(); + if (not thread_events[device_id.value()]) { thread_events[device_id.value()] = new cuda_event(); } + return *thread_events[device_id.value()]; +} + +} // namespace + /** * @brief Implementation of `cuda_stream_pool` that wraps an `rmm::cuda_stram_pool`. */ @@ -109,59 +166,6 @@ cuda_stream_pool* create_global_cuda_stream_pool() return new rmm_cuda_stream_pool(); } -// FIXME: these will be available in rmm soon -inline int get_num_cuda_devices() -{ - rmm::cuda_device_id::value_type num_dev{}; - CUDF_CUDA_TRY(cudaGetDeviceCount(&num_dev)); - return num_dev; -} - -rmm::cuda_device_id get_current_cuda_device() -{ - int device_id; - CUDF_CUDA_TRY(cudaGetDevice(&device_id)); - return rmm::cuda_device_id{device_id}; -} - -/** - * @brief RAII struct to wrap a cuda event and ensure its proper destruction. - */ -struct cuda_event { - cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); } - virtual ~cuda_event() { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_)); } - - // Moveable but not copyable. 
- cuda_event(const cuda_event&) = delete; - cuda_event& operator=(const cuda_event&) = delete; - - cuda_event(cuda_event&&) = default; - cuda_event& operator=(cuda_event&&) = default; - - operator cudaEvent_t() { return e_; } - - private: - cudaEvent_t e_; -}; - -/** - * @brief Returns a cudaEvent_t for the current thread. - * - * The returned event is valid for the current device. - * - * @return A cudaEvent_t unique to the current thread and valid on the current device. - */ -cudaEvent_t event_for_thread() -{ - // The program may crash if this function is called from the main thread and user application - // subsequently calls cudaDeviceReset(). - // As a workaround, here we intentionally disable RAII and leak cudaEvent_t. - thread_local std::vector<cuda_event*> thread_events(get_num_cuda_devices()); - auto const device_id = get_current_cuda_device(); - if (not thread_events[device_id.value()]) { thread_events[device_id.value()] = new cuda_event(); } - return *thread_events[device_id.value()]; -} - /** * @brief Returns a reference to the global stream pool for the current device. * @return `cuda_stream_pool` valid on the current device. @@ -174,7 +178,7 @@ cuda_stream_pool& global_cuda_stream_pool() static std::mutex mutex; auto const device_id = get_current_cuda_device(); - std::lock_guard lock(mutex); + std::lock_guard const lock(mutex); if (pools[device_id.value()] == nullptr) { pools[device_id.value()] = create_global_cuda_stream_pool(); } From beb42960a7fbf2b0c1da17c943bb66050539b39c Mon Sep 17 00:00:00 2001 From: Vukasin Milovanovic Date: Tue, 3 Dec 2024 10:05:24 -0800 Subject: [PATCH 5/5] Workaround for a misaligned access in `read_csv` on some CUDA versions (#17477) Use a global array instead of a shared memory array in the `gather_row_offsets_gpu` kernel. Impact on the kernel performance is less than 5%, and this kernel accounts for a very small portion of the total read_csv execution time, so the impact on overall performance is negligible. Also modified the functions that take this array to take a `device_span` instead of a plain pointer.
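A minimal, self-contained sketch of the same pattern follows (illustrative only, not the actual `gather_row_offsets_gpu` kernel): each block works on its own slice of a global scratch buffer allocated on the host, instead of a `__shared__` array. Names such as `scratch_kernel` and `block_scratch_size` are placeholders, a raw pointer offset stands in for the `device_span`/`rmm::device_uvector` pair used in the patch, and error checking is omitted for brevity.

```cpp
// Sketch of the shared-memory -> global-scratch workaround (not the cudf kernel).
// Each block reduces its portion of the input using a private slice of a global
// scratch buffer rather than a __shared__ array.
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr int block_dim          = 128;        // threads per block
constexpr int block_scratch_size = block_dim;  // scratch elements per block
                                               // (the patch uses 2x for its merge tree)

__global__ void scratch_kernel(std::uint64_t* scratch, std::uint64_t* block_sums, int n)
{
  // Per-block slice of the global scratch buffer; the patch expresses this as
  // ctxtree.subspan(blockIdx.x * bk_ctxtree_size, bk_ctxtree_size).
  std::uint64_t* block_scratch =
    scratch + static_cast<std::size_t>(blockIdx.x) * block_scratch_size;

  int const t   = threadIdx.x;
  int const gid = blockIdx.x * blockDim.x + t;

  // Stage one value per thread, then combine within the block.
  block_scratch[t] = (gid < n) ? static_cast<std::uint64_t>(gid) : 0;
  __syncthreads();

  // Simple tree reduction over the block's scratch slice (block_dim is a power of two).
  for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
    if (t < stride) { block_scratch[t] += block_scratch[t + stride]; }
    __syncthreads();
  }
  if (t == 0) { block_sums[blockIdx.x] = block_scratch[0]; }
}

int main()
{
  int const n          = 1 << 16;
  int const num_blocks = (n + block_dim - 1) / block_dim;

  std::uint64_t* scratch    = nullptr;
  std::uint64_t* block_sums = nullptr;
  // One scratch slice per block, mirroring the device buffer sized
  // dim_grid * bk_ctxtree_size that the patch allocates on the host side.
  cudaMalloc(&scratch, sizeof(std::uint64_t) * num_blocks * block_scratch_size);
  cudaMalloc(&block_sums, sizeof(std::uint64_t) * num_blocks);

  scratch_kernel<<<num_blocks, block_dim>>>(scratch, block_sums, n);
  cudaDeviceSynchronize();

  std::printf("launched %d blocks\n", num_blocks);
  cudaFree(scratch);
  cudaFree(block_sums);
  return 0;
}
```

Keeping one scratch slice per block preserves the per-block isolation the shared-memory array provided, at the cost of a temporary allocation of `dim_grid * bk_ctxtree_size` elements, which is why the commit message notes that the overall performance impact is small.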
Authors: - Vukasin Milovanovic (https://github.com/vuule) Approvers: - Bradley Dice (https://github.com/bdice) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/17477 --- cpp/src/io/csv/csv_gpu.cu | 40 +++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/cpp/src/io/csv/csv_gpu.cu b/cpp/src/io/csv/csv_gpu.cu index 273e82edf8b..e2bc75d4bab 100644 --- a/cpp/src/io/csv/csv_gpu.cu +++ b/cpp/src/io/csv/csv_gpu.cu @@ -495,7 +495,7 @@ inline __device__ uint32_t select_rowmap(uint4 ctx_map, uint32_t ctxid) * @param t thread id (leaf node id) */ template -inline __device__ void ctx_merge(uint64_t* ctxtree, packed_rowctx_t* ctxb, uint32_t t) +inline __device__ void ctx_merge(device_span ctxtree, packed_rowctx_t* ctxb, uint32_t t) { uint64_t tmp = shuffle_xor(*ctxb, lanemask); if (!(t & tmask)) { @@ -518,7 +518,7 @@ inline __device__ void ctx_merge(uint64_t* ctxtree, packed_rowctx_t* ctxb, uint3 */ template inline __device__ void ctx_unmerge( - uint32_t base, uint64_t* ctxtree, uint32_t* ctx, uint32_t* brow4, uint32_t t) + uint32_t base, device_span ctxtree, uint32_t* ctx, uint32_t* brow4, uint32_t t) { rowctx32_t ctxb_left, ctxb_right, ctxb_sum; ctxb_sum = get_row_context(ctxtree[base], *ctx); @@ -550,7 +550,7 @@ inline __device__ void ctx_unmerge( * @param[in] ctxb packed row context for the current character block * @param t thread id (leaf node id) */ -static inline __device__ void rowctx_merge_transform(uint64_t ctxtree[1024], +static inline __device__ void rowctx_merge_transform(device_span ctxtree, packed_rowctx_t ctxb, uint32_t t) { @@ -584,8 +584,8 @@ static inline __device__ void rowctx_merge_transform(uint64_t ctxtree[1024], * * @return Final row context and count (row_position*4 + context_id format) */ -static inline __device__ rowctx32_t rowctx_inverse_merge_transform(uint64_t ctxtree[1024], - uint32_t t) +static inline __device__ rowctx32_t +rowctx_inverse_merge_transform(device_span ctxtree, uint32_t t) { uint32_t ctx = ctxtree[0] & 3; // Starting input context rowctx32_t brow4 = 0; // output row in block *4 @@ -603,6 +603,8 @@ static inline __device__ rowctx32_t rowctx_inverse_merge_transform(uint64_t ctxt return brow4 + ctx; } +constexpr auto bk_ctxtree_size = rowofs_block_dim * 2; + /** * @brief Gather row offsets from CSV character data split into 16KB chunks * @@ -634,6 +636,7 @@ static inline __device__ rowctx32_t rowctx_inverse_merge_transform(uint64_t ctxt */ CUDF_KERNEL void __launch_bounds__(rowofs_block_dim) gather_row_offsets_gpu(uint64_t* row_ctx, + device_span ctxtree, device_span offsets_out, device_span const data, size_t chunk_size, @@ -649,12 +652,8 @@ CUDF_KERNEL void __launch_bounds__(rowofs_block_dim) int escapechar, int commentchar) { - auto start = data.begin(); - using block_reduce = typename cub::BlockReduce; - __shared__ union { - typename block_reduce::TempStorage bk_storage; - __align__(8) uint64_t ctxtree[rowofs_block_dim * 2]; - } temp_storage; + auto start = data.begin(); + auto const bk_ctxtree = ctxtree.subspan(blockIdx.x * bk_ctxtree_size, bk_ctxtree_size); char const* end = start + (min(parse_pos + chunk_size, data_size) - start_offset); uint32_t t = threadIdx.x; @@ -723,16 +722,16 @@ CUDF_KERNEL void __launch_bounds__(rowofs_block_dim) // Convert the long-form {rowmap,outctx}[inctx] version into packed version // {rowcount,ouctx}[inctx], then merge the row contexts of the 32-character blocks into // a single 16K-character block context - 
rowctx_merge_transform(temp_storage.ctxtree, pack_rowmaps(ctx_map), t); + rowctx_merge_transform(bk_ctxtree, pack_rowmaps(ctx_map), t); // If this is the second phase, get the block's initial parser state and row counter if (offsets_out.data()) { - if (t == 0) { temp_storage.ctxtree[0] = row_ctx[blockIdx.x]; } + if (t == 0) { bk_ctxtree[0] = row_ctx[blockIdx.x]; } __syncthreads(); // Walk back the transform tree with the known initial parser state - rowctx32_t ctx = rowctx_inverse_merge_transform(temp_storage.ctxtree, t); - uint64_t row = (temp_storage.ctxtree[0] >> 2) + (ctx >> 2); + rowctx32_t ctx = rowctx_inverse_merge_transform(bk_ctxtree, t); + uint64_t row = (bk_ctxtree[0] >> 2) + (ctx >> 2); uint32_t rows_out_of_range = 0; uint32_t rowmap = select_rowmap(ctx_map, ctx & 3); // Output row positions @@ -749,11 +748,14 @@ CUDF_KERNEL void __launch_bounds__(rowofs_block_dim) } __syncthreads(); // Return the number of rows out of range - rows_out_of_range = block_reduce(temp_storage.bk_storage).Sum(rows_out_of_range); + + using block_reduce = typename cub::BlockReduce; + __shared__ typename block_reduce::TempStorage bk_storage; + rows_out_of_range = block_reduce(bk_storage).Sum(rows_out_of_range); if (t == 0) { row_ctx[blockIdx.x] = rows_out_of_range; } } else { // Just store the row counts and output contexts - if (t == 0) { row_ctx[blockIdx.x] = temp_storage.ctxtree[1]; } + if (t == 0) { row_ctx[blockIdx.x] = bk_ctxtree[1]; } } } @@ -829,7 +831,7 @@ void decode_row_column_data(cudf::io::parse_options_view const& options, // Calculate actual block count to use based on records count auto const block_size = csvparse_block_dim; auto const num_rows = row_offsets.size() - 1; - auto const grid_size = (num_rows + block_size - 1) / block_size; + auto const grid_size = cudf::util::div_rounding_up_safe(num_rows, block_size); convert_csv_to_cudf<<>>( options, data, column_flags, row_offsets, dtypes, columns, valids, valid_counts); @@ -849,9 +851,11 @@ uint32_t __host__ gather_row_offsets(parse_options_view const& options, rmm::cuda_stream_view stream) { uint32_t dim_grid = 1 + (chunk_size / rowofs_block_bytes); + auto ctxtree = rmm::device_uvector(dim_grid * bk_ctxtree_size, stream); gather_row_offsets_gpu<<>>( row_ctx, + ctxtree, offsets_out, data, chunk_size,