From 58b7dc9f186c1860d4f9df80188bf21214381b1b Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:01:25 -1000 Subject: [PATCH 1/6] interpolate returns new column if no values are interpolated (#16158) While cleaning up the `interpolate` implementation, I noticed that a interpolation no-op did not return a new column. Authors: - Matthew Roeschke (https://github.com/mroeschke) - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/16158 --- python/cudf/cudf/core/algorithms.py | 61 ++++++++-------------- python/cudf/cudf/core/indexed_frame.py | 14 +++-- python/cudf/cudf/core/multiindex.py | 4 +- python/cudf/cudf/tests/test_interpolate.py | 6 +++ 4 files changed, 39 insertions(+), 46 deletions(-) diff --git a/python/cudf/cudf/core/algorithms.py b/python/cudf/cudf/core/algorithms.py index e8b82ff60c2..6c69fbd2637 100644 --- a/python/cudf/cudf/core/algorithms.py +++ b/python/cudf/cudf/core/algorithms.py @@ -1,17 +1,22 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING import cupy as cp import numpy as np from cudf.core.column import as_column -from cudf.core.copy_types import BooleanMask from cudf.core.index import RangeIndex, ensure_index -from cudf.core.indexed_frame import IndexedFrame from cudf.core.scalar import Scalar from cudf.options import get_option from cudf.utils.dtypes import can_convert_to_column +if TYPE_CHECKING: + from cudf.core.column.column import ColumnBase + from cudf.core.index import BaseIndex + def factorize(values, sort=False, use_na_sentinel=True, size_hint=None): """Encode the input values as integer labels @@ -110,55 +115,31 @@ def factorize(values, sort=False, use_na_sentinel=True, size_hint=None): return labels, cats.values if return_cupy_array else ensure_index(cats) -def _linear_interpolation(column, index=None): - """ - Interpolate over a float column. Implicitly assumes that values are - evenly spaced with respect to the x-axis, for example the data - [1.0, NaN, 3.0] will be interpolated assuming the NaN is half way - between the two valid values, yielding [1.0, 2.0, 3.0] - """ - - index = RangeIndex(start=0, stop=len(column), step=1) - return _index_or_values_interpolation(column, index=index) - - -def _index_or_values_interpolation(column, index=None): +def _interpolation(column: ColumnBase, index: BaseIndex) -> ColumnBase: """ Interpolate over a float column. assumes a linear interpolation strategy using the index of the data to denote spacing of the x values. For example the data and index [1.0, NaN, 4.0], [1, 3, 4] - would result in [1.0, 3.0, 4.0] + would result in [1.0, 3.0, 4.0]. 
""" # figure out where the nans are - mask = cp.isnan(column) + mask = column.isnull() # trivial cases, all nan or no nans - num_nan = mask.sum() - if num_nan == 0 or num_nan == len(column): - return column + if not mask.any() or mask.all(): + return column.copy() - to_interp = IndexedFrame(data={None: column}, index=index) - known_x_and_y = to_interp._apply_boolean_mask( - BooleanMask(~mask, len(to_interp)) - ) - - known_x = known_x_and_y.index.to_cupy() - known_y = known_x_and_y._data.columns[0].values + valid_locs = ~mask + if isinstance(index, RangeIndex): + # Each point is evenly spaced, index values don't matter + known_x = cp.flatnonzero(valid_locs.values) + else: + known_x = index._column.apply_boolean_mask(valid_locs).values # type: ignore[attr-defined] + known_y = column.apply_boolean_mask(valid_locs).values result = cp.interp(index.to_cupy(), known_x, known_y) # find the first nan - first_nan_idx = (mask == 0).argmax().item() + first_nan_idx = valid_locs.values.argmax().item() result[:first_nan_idx] = np.nan - return result - - -def get_column_interpolator(method): - interpolator = { - "linear": _linear_interpolation, - "index": _index_or_values_interpolation, - "values": _index_or_values_interpolation, - }.get(method, None) - if not interpolator: - raise ValueError(f"Interpolation method `{method}` not found") - return interpolator + return as_column(result) diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index ff10051c52d..63fa96d0db0 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -26,6 +26,8 @@ import cudf import cudf._lib as libcudf +import cudf.core +import cudf.core.algorithms from cudf.api.extensions import no_default from cudf.api.types import ( _is_non_decimal_numeric_dtype, @@ -1987,6 +1989,8 @@ def interpolate( "Use obj.ffill() or obj.bfill() instead.", FutureWarning, ) + elif method not in {"linear", "values", "index"}: + raise ValueError(f"Interpolation method `{method}` not found") data = self @@ -2000,7 +2004,10 @@ def interpolate( ) ) - interpolator = cudf.core.algorithms.get_column_interpolator(method) + if method == "linear": + interp_index = RangeIndex(self._num_rows) + else: + interp_index = data.index columns = [] for col in data._columns: if isinstance(col, cudf.core.column.StringColumn): @@ -2012,8 +2019,9 @@ def interpolate( if col.nullable: col = col.astype("float64").fillna(np.nan) - # Interpolation methods may or may not need the index - columns.append(interpolator(col, index=data.index)) + columns.append( + cudf.core.algorithms._interpolation(col, index=interp_index) + ) result = self._from_data_like_self( self._data._from_columns_like_self(columns) diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index 9cbe863142b..dbbd1eab6c8 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -23,6 +23,7 @@ from cudf.api.types import is_integer, is_list_like, is_object_dtype from cudf.core import column from cudf.core._base_index import _return_get_indexer_result +from cudf.core.algorithms import factorize from cudf.core.column_accessor import ColumnAccessor from cudf.core.frame import Frame from cudf.core.index import ( @@ -1373,9 +1374,6 @@ def from_arrays( (2, 'blue')], names=['number', 'color']) """ - # Imported here due to circular import - from cudf.core.algorithms import factorize - error_msg = "Input must be a list / sequence of array-likes." 
if not is_list_like(arrays): raise TypeError(error_msg) diff --git a/python/cudf/cudf/tests/test_interpolate.py b/python/cudf/cudf/tests/test_interpolate.py index 4a0dc331e1a..a4f0b9fc97e 100644 --- a/python/cudf/cudf/tests/test_interpolate.py +++ b/python/cudf/cudf/tests/test_interpolate.py @@ -135,3 +135,9 @@ def test_interpolate_dataframe_error_cases(data, kwargs): lfunc_args_and_kwargs=([], kwargs), rfunc_args_and_kwargs=([], kwargs), ) + + +def test_interpolate_noop_new_column(): + ser = cudf.Series([1.0, 2.0, 3.0]) + result = ser.interpolate() + assert ser._column is not result._column From cf88f8e045b279cbe5caa2e19ffadc7c6400aa58 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:04:51 -1000 Subject: [PATCH 2/6] Defer copying in Column.astype(copy=True) (#16095) Avoids: 1. Copying `self` when the `astype` would already produce a new column with its own data 2. Copying `self` when the `astype` would raise an Exception Also cleans up some `as_categorical_column` logic. Authors: - Matthew Roeschke (https://github.com/mroeschke) - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/16095 --- python/cudf/cudf/core/column/categorical.py | 20 ++--- python/cudf/cudf/core/column/column.py | 91 ++++++++++----------- 2 files changed, 51 insertions(+), 60 deletions(-) diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index 231af30c06d..cec7d5e6663 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -1113,24 +1113,18 @@ def is_monotonic_decreasing(self) -> bool: def as_categorical_column(self, dtype: Dtype) -> CategoricalColumn: if isinstance(dtype, str) and dtype == "category": return self + if isinstance(dtype, pd.CategoricalDtype): + dtype = cudf.CategoricalDtype.from_pandas(dtype) if ( - isinstance( - dtype, (cudf.core.dtypes.CategoricalDtype, pd.CategoricalDtype) - ) - and (dtype.categories is None) - and (dtype.ordered is None) + isinstance(dtype, cudf.CategoricalDtype) + and dtype.categories is None + and dtype.ordered is None ): return self - - if isinstance(dtype, pd.CategoricalDtype): - dtype = CategoricalDtype( - categories=dtype.categories, ordered=dtype.ordered - ) - - if not isinstance(dtype, CategoricalDtype): + elif not isinstance(dtype, CategoricalDtype): raise ValueError("dtype must be CategoricalDtype") - if not isinstance(self.categories, type(dtype.categories._values)): + if not isinstance(self.categories, type(dtype.categories._column)): # If both categories are of different Column types, # return a column full of Nulls. 
return _create_empty_categorical_column(self, dtype) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index e7a2863da8c..adc783c20c4 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -962,59 +962,59 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase: if len(self) == 0: dtype = cudf.dtype(dtype) if self.dtype == dtype: - if copy: - return self.copy() - else: - return self + result = self else: - return column_empty(0, dtype=dtype, masked=self.nullable) - if copy: - col = self.copy() - else: - col = self - if dtype == "category": + result = column_empty(0, dtype=dtype, masked=self.nullable) + elif dtype == "category": # TODO: Figure out why `cudf.dtype("category")` # astype's different than just the string - return col.as_categorical_column(dtype) + result = self.as_categorical_column(dtype) elif ( isinstance(dtype, str) and dtype == "interval" and isinstance(self.dtype, cudf.IntervalDtype) ): # astype("interval") (the string only) should no-op - return col - was_object = dtype == object or dtype == np.dtype(object) - dtype = cudf.dtype(dtype) - if self.dtype == dtype: - return col - elif isinstance(dtype, CategoricalDtype): - return col.as_categorical_column(dtype) - elif isinstance(dtype, IntervalDtype): - return col.as_interval_column(dtype) - elif isinstance(dtype, (ListDtype, StructDtype)): - if not col.dtype == dtype: - raise NotImplementedError( - f"Casting {self.dtype} columns not currently supported" - ) - return col - elif isinstance(dtype, cudf.core.dtypes.DecimalDtype): - return col.as_decimal_column(dtype) - elif dtype.kind == "M": - return col.as_datetime_column(dtype) - elif dtype.kind == "m": - return col.as_timedelta_column(dtype) - elif dtype.kind == "O": - if cudf.get_option("mode.pandas_compatible") and was_object: - raise ValueError( - f"Casting to {dtype} is not supported, use " - "`.astype('str')` instead." - ) - return col.as_string_column(dtype) + result = self else: - return col.as_numerical_column(dtype) + was_object = dtype == object or dtype == np.dtype(object) + dtype = cudf.dtype(dtype) + if self.dtype == dtype: + result = self + elif isinstance(dtype, CategoricalDtype): + result = self.as_categorical_column(dtype) + elif isinstance(dtype, IntervalDtype): + result = self.as_interval_column(dtype) + elif isinstance(dtype, (ListDtype, StructDtype)): + if not self.dtype == dtype: + raise NotImplementedError( + f"Casting {self.dtype} columns not currently supported" + ) + result = self + elif isinstance(dtype, cudf.core.dtypes.DecimalDtype): + result = self.as_decimal_column(dtype) + elif dtype.kind == "M": + result = self.as_datetime_column(dtype) + elif dtype.kind == "m": + result = self.as_timedelta_column(dtype) + elif dtype.kind == "O": + if cudf.get_option("mode.pandas_compatible") and was_object: + raise ValueError( + f"Casting to {dtype} is not supported, use " + "`.astype('str')` instead." 
+ ) + result = self.as_string_column(dtype) + else: + result = self.as_numerical_column(dtype) + + if copy and result is self: + return result.copy() + return result def as_categorical_column(self, dtype) -> ColumnBase: - if isinstance(dtype, (cudf.CategoricalDtype, pd.CategoricalDtype)): + if isinstance(dtype, pd.CategoricalDtype): + dtype = cudf.CategoricalDtype.from_pandas(dtype) + if isinstance(dtype, cudf.CategoricalDtype): ordered = dtype.ordered else: ordered = False @@ -1023,14 +1023,11 @@ def as_categorical_column(self, dtype) -> ColumnBase: if ( isinstance(dtype, cudf.CategoricalDtype) and dtype._categories is not None - ) or ( - isinstance(dtype, pd.CategoricalDtype) - and dtype.categories is not None ): - labels = self._label_encoding(cats=as_column(dtype.categories)) - + cat_col = dtype._categories + labels = self._label_encoding(cats=cat_col) return build_categorical_column( - categories=as_column(dtype.categories), + categories=cat_col, codes=labels, mask=self.mask, ordered=dtype.ordered, From 65e4e99d702aedbbfd489840d112faecfaeb43b9 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Mon, 8 Jul 2024 23:10:23 -0500 Subject: [PATCH 3/6] Remove CCCL patch for PR 211. (#16207) While upgrading CCCL, we ran into a test failure in cuSpatial. We added a patch to revert some changes from CCCL but the root cause was a bug in cuSpatial. I have fixed that bug here: https://github.com/rapidsai/cuspatial/pull/1402 Once that PR is merged, we can remove this CCCL patch. See also: - rapids-cmake patch removal: https://github.com/rapidsai/rapids-cmake/pull/640 - Original rapids-cmake patch: https://github.com/rapidsai/rapids-cmake/pull/511 - CCCL epic to remove RAPIDS patches: https://github.com/NVIDIA/cccl/issues/1939 Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Robert Maynard (https://github.com/robertmaynard) URL: https://github.com/rapidsai/cudf/pull/16207 --- cpp/cmake/thirdparty/patches/cccl_override.json | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cpp/cmake/thirdparty/patches/cccl_override.json b/cpp/cmake/thirdparty/patches/cccl_override.json index e61102dffac..2f29578f7ae 100644 --- a/cpp/cmake/thirdparty/patches/cccl_override.json +++ b/cpp/cmake/thirdparty/patches/cccl_override.json @@ -3,11 +3,6 @@ "packages" : { "CCCL" : { "patches" : [ - { - "file" : "cccl/revert_pr_211.diff", - "issue" : "thrust::copy introduced a change in behavior that causes failures with cudaErrorInvalidValue.", - "fixed_in" : "" - }, { "file" : "${current_json_dir}/thrust_disable_64bit_dispatching.diff", "issue" : "Remove 64bit dispatching as not needed by libcudf and results in compiling twice as many kernels [https://github.com/rapidsai/cudf/pull/11437]", From b693e79b1813276700f70c2cb251d6fef71851a1 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 9 Jul 2024 13:22:35 +0100 Subject: [PATCH 4/6] Handler csv reader options in cudf-polars (#16211) Previously we were just relying on the default cudf read_csv options which doesn't do the right thing if the user has configured things. Now that polars passes through the information to us, we can handle things properly, and raise for unsupported cases. While here, update to new polars release and adapt tests to bug fixes that have been made upstream. 
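A rough sketch of what this enables (illustrative only: `data.csv` is a made-up path, and the asserts are assumed to be the same helpers the new tests in this PR use, presumably importable from cudf_polars' testing utilities):

    import polars as pl

    # assumed import path for the helpers used in tests/test_scan.py
    from cudf_polars.testing.asserts import (
        assert_gpu_result_equal,
        assert_ir_translation_raises,
    )

    # separator, null values and comment prefix are now read from the polars IR
    # and mapped onto the corresponding cudf.read_csv arguments
    q = pl.scan_csv("data.csv", separator="|", null_values=["NA"], comment_prefix="#")
    assert_gpu_result_equal(q)

    # options we cannot honour are rejected when translating the IR rather than
    # silently falling back to cudf's read_csv defaults
    q = pl.scan_csv("data.csv", skip_rows_after_header=1)
    assert_ir_translation_raises(q, NotImplementedError)
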
Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - James Lamb (https://github.com/jameslamb) - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/16211 --- python/cudf/cudf/_lib/csv.pyx | 2 +- python/cudf_polars/cudf_polars/dsl/expr.py | 4 +- python/cudf_polars/cudf_polars/dsl/ir.py | 104 +++++++++++++++-- .../cudf_polars/cudf_polars/dsl/translate.py | 12 +- python/cudf_polars/tests/test_scan.py | 107 ++++++++++++++++-- 5 files changed, 206 insertions(+), 23 deletions(-) diff --git a/python/cudf/cudf/_lib/csv.pyx b/python/cudf/cudf/_lib/csv.pyx index c706351a683..9fecff5f5f6 100644 --- a/python/cudf/cudf/_lib/csv.pyx +++ b/python/cudf/cudf/_lib/csv.pyx @@ -450,7 +450,7 @@ def read_csv( col_name = df._data.names[index] df._data[col_name] = df._data[col_name].astype(col_dtype) - if names is not None and isinstance(names[0], (int)): + if names is not None and len(names) and isinstance(names[0], (int)): df.columns = [int(x) for x in df._data] # Set index if the index_col parameter is passed diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 93cb9db7cbd..f83d9e82d30 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence - import polars.polars as plrs + import polars as pl import polars.type_aliases as pl_types from cudf_polars.containers import DataFrame @@ -377,7 +377,7 @@ class LiteralColumn(Expr): value: pa.Array[Any, Any] children: tuple[()] - def __init__(self, dtype: plc.DataType, value: plrs.PySeries) -> None: + def __init__(self, dtype: plc.DataType, value: pl.Series) -> None: super().__init__(dtype) data = value.to_arrow() self.value = data.cast(dtypes.downcast_arrow_lists(data.type)) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 6b552642e88..b32fa9c273e 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -15,9 +15,9 @@ import dataclasses import itertools -import json import types from functools import cache +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, ClassVar import pyarrow as pa @@ -185,8 +185,10 @@ class Scan(IR): typ: str """What type of file are we reading? Parquet, CSV, etc...""" - options: tuple[Any, ...] 
- """Type specific options, as json-encoded strings.""" + reader_options: dict[str, Any] + """Reader-specific options, as dictionary.""" + cloud_options: dict[str, Any] | None + """Cloud-related authentication options, currently ignored.""" paths: list[str] """List of paths to read from.""" file_options: Any @@ -206,9 +208,33 @@ def __post_init__(self) -> None: if self.file_options.n_rows is not None: raise NotImplementedError("row limit in scan") if self.typ not in ("csv", "parquet"): + raise NotImplementedError(f"Unhandled scan type: {self.typ}") + if self.cloud_options is not None and any( + self.cloud_options[k] is not None for k in ("aws", "azure", "gcp") + ): raise NotImplementedError( - f"Unhandled scan type: {self.typ}" - ) # pragma: no cover; polars raises on the rust side for now + "Read from cloud storage" + ) # pragma: no cover; no test yet + if self.typ == "csv": + if self.reader_options["skip_rows_after_header"] != 0: + raise NotImplementedError("Skipping rows after header in CSV reader") + parse_options = self.reader_options["parse_options"] + if ( + null_values := parse_options["null_values"] + ) is not None and "Named" in null_values: + raise NotImplementedError( + "Per column null value specification not supported for CSV reader" + ) + if ( + comment := parse_options["comment_prefix"] + ) is not None and "Multi" in comment: + raise NotImplementedError( + "Multi-character comment prefix not supported for CSV reader" + ) + if not self.reader_options["has_header"]: + # Need to do some file introspection to get the number + # of columns so that column projection works right. + raise NotImplementedError("Reading CSV without header") def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" @@ -216,14 +242,70 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: with_columns = options.with_columns row_index = options.row_index if self.typ == "csv": - opts, cloud_opts = map(json.loads, self.options) - df = DataFrame.from_cudf( - cudf.concat( - [cudf.read_csv(p, usecols=with_columns) for p in self.paths] + dtype_map = { + name: cudf._lib.types.PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES[typ.id()] + for name, typ in self.schema.items() + } + parse_options = self.reader_options["parse_options"] + sep = chr(parse_options["separator"]) + quote = chr(parse_options["quote_char"]) + eol = chr(parse_options["eol_char"]) + if self.reader_options["schema"] is not None: + # Reader schema provides names + column_names = list(self.reader_options["schema"]["inner"].keys()) + else: + # file provides column names + column_names = None + usecols = with_columns + # TODO: support has_header=False + header = 0 + + # polars defaults to no null recognition + null_values = [""] + if parse_options["null_values"] is not None: + ((typ, nulls),) = parse_options["null_values"].items() + if typ == "AllColumnsSingle": + # Single value + null_values.append(nulls) + else: + # List of values + null_values.extend(nulls) + if parse_options["comment_prefix"] is not None: + comment = chr(parse_options["comment_prefix"]["Single"]) + else: + comment = None + decimal = "," if parse_options["decimal_comma"] else "." + + # polars skips blank lines at the beginning of the file + pieces = [] + for p in self.paths: + skiprows = self.reader_options["skip_rows"] + # TODO: read_csv expands globs which we should not do, + # because polars will already have handled them. 
+ path = Path(p) + with path.open() as f: + while f.readline() == "\n": + skiprows += 1 + pieces.append( + cudf.read_csv( + path, + sep=sep, + quotechar=quote, + lineterminator=eol, + names=column_names, + header=header, + usecols=usecols, + na_filter=True, + na_values=null_values, + keep_default_na=False, + skiprows=skiprows, + comment=comment, + decimal=decimal, + dtype=dtype_map, + ) ) - ) + df = DataFrame.from_cudf(cudf.concat(pieces)) elif self.typ == "parquet": - opts, cloud_opts = map(json.loads, self.options) cdf = cudf.read_parquet(self.paths, columns=with_columns) assert isinstance(cdf, cudf.DataFrame) df = DataFrame.from_cudf(cdf) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 5a1e682abe7..dec45679c75 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -5,6 +5,7 @@ from __future__ import annotations +import json from contextlib import AbstractContextManager, nullcontext from functools import singledispatch from typing import Any @@ -12,6 +13,7 @@ import pyarrow as pa from typing_extensions import assert_never +import polars as pl import polars.polars as plrs from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir @@ -88,10 +90,16 @@ def _( node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType] ) -> ir.IR: typ, *options = node.scan_type + if typ == "ndjson": + (reader_options,) = map(json.loads, options) + cloud_options = None + else: + reader_options, cloud_options = map(json.loads, options) return ir.Scan( schema, typ, - tuple(options), + reader_options, + cloud_options, node.paths, node.file_options, translate_named_expr(visitor, n=node.predicate) @@ -402,7 +410,7 @@ def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr @_translate_expr.register def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: if isinstance(node.value, plrs.PySeries): - return expr.LiteralColumn(dtype, node.value) + return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value)) value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype)) return expr.Literal(dtype, value) diff --git a/python/cudf_polars/tests/test_scan.py b/python/cudf_polars/tests/test_scan.py index f129cc7ca32..c41a94da14b 100644 --- a/python/cudf_polars/tests/test_scan.py +++ b/python/cudf_polars/tests/test_scan.py @@ -22,22 +22,22 @@ def row_index(request): @pytest.fixture( params=[ - (None, 0), + None, pytest.param( - (2, 1), marks=pytest.mark.xfail(reason="No handling of row limit in scan") + 2, marks=pytest.mark.xfail(reason="No handling of row limit in scan") ), pytest.param( - (3, 0), marks=pytest.mark.xfail(reason="No handling of row limit in scan") + 3, marks=pytest.mark.xfail(reason="No handling of row limit in scan") ), ], ids=["all-rows", "n_rows-with-skip", "n_rows-no-skip"], ) -def n_rows_skip_rows(request): +def n_rows(request): return request.param @pytest.fixture(params=["csv", "parquet"]) -def df(request, tmp_path, row_index, n_rows_skip_rows): +def df(request, tmp_path, row_index, n_rows): df = pl.DataFrame( { "a": [1, 2, 3, None], @@ -46,14 +46,12 @@ def df(request, tmp_path, row_index, n_rows_skip_rows): } ) name, offset = row_index - n_rows, skip_rows = n_rows_skip_rows if request.param == "csv": df.write_csv(tmp_path / "file.csv") return pl.scan_csv( tmp_path / "file.csv", row_index_name=name, row_index_offset=offset, - skip_rows_after_header=skip_rows, n_rows=n_rows, ) else: @@ 
-97,3 +95,98 @@ def test_scan_unsupported_raises(tmp_path): df.write_ndjson(tmp_path / "df.json") q = pl.scan_ndjson(tmp_path / "df.json") assert_ir_translation_raises(q, NotImplementedError) + + +def test_scan_row_index_projected_out(tmp_path): + df = pl.DataFrame({"a": [1, 2, 3]}) + + df.write_parquet(tmp_path / "df.pq") + + q = pl.scan_parquet(tmp_path / "df.pq").with_row_index().select(pl.col("a")) + + assert_gpu_result_equal(q) + + +def test_scan_csv_column_renames_projection_schema(tmp_path): + with (tmp_path / "test.csv").open("w") as f: + f.write("""foo,bar,baz\n1,2\n3,4,5""") + + q = pl.scan_csv( + tmp_path / "test.csv", + with_column_names=lambda names: [f"{n}_suffix" for n in names], + schema_overrides={ + "foo_suffix": pl.String(), + "bar_suffix": pl.Int8(), + "baz_suffix": pl.UInt16(), + }, + ) + + assert_gpu_result_equal(q) + + +def test_scan_csv_skip_after_header_not_implemented(tmp_path): + with (tmp_path / "test.csv").open("w") as f: + f.write("""foo,bar,baz\n1,2,3\n3,4,5""") + + q = pl.scan_csv(tmp_path / "test.csv", skip_rows_after_header=1) + + assert_ir_translation_raises(q, NotImplementedError) + + +def test_scan_csv_null_values_per_column_not_implemented(tmp_path): + with (tmp_path / "test.csv").open("w") as f: + f.write("""foo,bar,baz\n1,2,3\n3,4,5""") + + q = pl.scan_csv(tmp_path / "test.csv", null_values={"foo": "1", "baz": "5"}) + + assert_ir_translation_raises(q, NotImplementedError) + + +def test_scan_csv_comment_str_not_implemented(tmp_path): + with (tmp_path / "test.csv").open("w") as f: + f.write("""foo,bar,baz\n// 1,2,3\n3,4,5""") + + q = pl.scan_csv(tmp_path / "test.csv", comment_prefix="// ") + + assert_ir_translation_raises(q, NotImplementedError) + + +def test_scan_csv_comment_char(tmp_path): + with (tmp_path / "test.csv").open("w") as f: + f.write("""foo,bar,baz\n# 1,2,3\n3,4,5""") + + q = pl.scan_csv(tmp_path / "test.csv", comment_prefix="#") + + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize("nulls", [None, "3", ["3", "5"]]) +def test_scan_csv_null_values(tmp_path, nulls): + with (tmp_path / "test.csv").open("w") as f: + f.write("""foo,bar,baz\n1,2,3\n3,4,5\n5,,2""") + + q = pl.scan_csv(tmp_path / "test.csv", null_values=nulls) + + assert_gpu_result_equal(q) + + +def test_scan_csv_decimal_comma(tmp_path): + with (tmp_path / "test.csv").open("w") as f: + f.write("""foo|bar|baz\n1,23|2,34|3,56\n1""") + + q = pl.scan_csv(tmp_path / "test.csv", separator="|", decimal_comma=True) + + assert_gpu_result_equal(q) + + +def test_scan_csv_skip_initial_empty_rows(tmp_path): + with (tmp_path / "test.csv").open("w") as f: + f.write("""\n\n\n\nfoo|bar|baz\n1|2|3\n1""") + + q = pl.scan_csv(tmp_path / "test.csv", separator="|", skip_rows=1, has_header=False) + + assert_ir_translation_raises(q, NotImplementedError) + + q = pl.scan_csv(tmp_path / "test.csv", separator="|", skip_rows=1) + + assert_gpu_result_equal(q) From 75966deef548754a5a7f5fb49f1cf5b1be991363 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 9 Jul 2024 06:59:56 -0700 Subject: [PATCH 5/6] Publish cudf-polars nightlies (#16213) Publish nightlies for cudf-polars. 
Authors: - Thomas Li (https://github.com/lithomas1) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/cudf/pull/16213 --- .github/workflows/build.yaml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index c5679cc5141..2e5959338b0 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -108,6 +108,28 @@ jobs: sha: ${{ inputs.sha }} date: ${{ inputs.date }} package-name: dask_cudf + wheel-build-cudf-polars: + needs: wheel-publish-cudf + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.08 + with: + # This selects "ARCH=amd64 + the latest supported Python + CUDA". + matrix_filter: map(select(.ARCH == "amd64")) | group_by(.CUDA_VER|split(".")|map(tonumber)|.[0]) | map(max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))])) + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + script: ci/build_wheel_cudf_polars.sh + wheel-publish-cudf-polars: + needs: wheel-build-cudf-polars + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.08 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: cudf_polars trigger-pandas-tests: if: inputs.build_type == 'nightly' needs: wheel-build-cudf From 433e959deab26ccf1eb9b75b8ea3e21659da4f0a Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Tue, 9 Jul 2024 10:45:05 -0400 Subject: [PATCH 6/6] Free temp memory no longer needed in multibyte_split processing (#16091) Updates the `multibyte_split` logic to free temporary memory once the chars and offsets have been resolved. This gives room to the remaining processing if more temp memory is required. 
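The gist of the restructuring is to resolve the offsets and characters inside an immediately-invoked lambda, so the tile state, output builders and forked streams are destroyed before the later column-assembly steps allocate. A loose Python analogue of that idea (illustrative only; the real change is in C++/CUDA and none of the names below are cudf API):

    import numpy as np

    def _resolve_offsets_and_chars(raw: bytes, delim: int):
        # scratch-heavy phase: these temporaries stand in for the tile state and
        # output builders, and are released as soon as this helper returns
        scratch = np.frombuffer(raw, dtype=np.uint8).astype(np.int32)
        offsets = np.flatnonzero(scratch == delim)
        return offsets, raw

    def split_rows(raw: bytes) -> list[bytes]:
        offsets, chars = _resolve_offsets_and_chars(raw, ord("\n"))
        # the remaining assembly runs after the scan temporaries are gone,
        # so that memory is available to it again
        starts = [0, *(int(i) + 1 for i in offsets)]
        ends = [*(int(i) + 1 for i in offsets), len(chars)]
        return [chars[s:e] for s, e in zip(starts, ends) if e > s]

    # split_rows(b"a\nbb\nccc\n") == [b"a\n", b"bb\n", b"ccc\n"]
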
Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Bradley Dice (https://github.com/bdice) - https://github.com/nvdbaranec URL: https://github.com/rapidsai/cudf/pull/16091 --- cpp/src/io/text/multibyte_split.cu | 324 ++++++++++++++--------------- 1 file changed, 162 insertions(+), 162 deletions(-) diff --git a/cpp/src/io/text/multibyte_split.cu b/cpp/src/io/text/multibyte_split.cu index 51dc0ca90af..be2e2b9a79c 100644 --- a/cpp/src/io/text/multibyte_split.cu +++ b/cpp/src/io/text/multibyte_split.cu @@ -55,6 +55,8 @@ #include #include +namespace cudf::io::text { +namespace detail { namespace { using cudf::io::text::detail::multistate; @@ -299,11 +301,6 @@ CUDF_KERNEL __launch_bounds__(THREADS_PER_TILE) void byte_split_kernel( } // namespace -namespace cudf { -namespace io { -namespace text { -namespace detail { - std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source const& source, std::string const& delimiter, byte_range_info byte_range, @@ -336,173 +333,181 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source CUDF_EXPECTS(delimiter.size() < multistate::max_segment_value, "delimiter contains too many total tokens to produce a deterministic result."); - auto const concurrency = 2; - - // must be at least 32 when using warp-reduce on partials - // must be at least 1 more than max possible concurrent tiles - // best when at least 32 more than max possible concurrent tiles, due to rolling `invalid`s - auto num_tile_states = std::max(32, TILES_PER_CHUNK * concurrency + 32); - auto tile_multistates = - scan_tile_state(num_tile_states, stream, rmm::mr::get_current_device_resource()); - auto tile_offsets = - scan_tile_state(num_tile_states, stream, rmm::mr::get_current_device_resource()); - - multibyte_split_init_kernel<<>>( // - -TILES_PER_CHUNK, - TILES_PER_CHUNK, - tile_multistates, - tile_offsets, - cudf::io::text::detail::scan_tile_status::oob); - - auto multistate_seed = multistate(); - multistate_seed.enqueue(0, 0); // this represents the first state in the pattern. - - // Seeding the tile state with an identity value allows the 0th tile to follow the same logic as - // the Nth tile, assuming it can look up an inclusive prefix. Without this seed, the 0th block - // would have to follow separate logic. - cudf::detail::device_single_thread( - [tm = scan_tile_state_view(tile_multistates), - to = scan_tile_state_view(tile_offsets), - multistate_seed] __device__() mutable { - tm.set_inclusive_prefix(-1, multistate_seed); - to.set_inclusive_prefix(-1, 0); - }, - stream); - - auto reader = source.create_reader(); - auto chunk_offset = std::max(0, byte_range.offset() - delimiter.size()); - auto const byte_range_end = byte_range.offset() + byte_range.size(); - reader->skip_bytes(chunk_offset); - // amortize output chunk allocations over 8 worst-case outputs. 
This limits the overallocation - constexpr auto max_growth = 8; - output_builder row_offset_storage(ITEMS_PER_CHUNK, max_growth, stream); - output_builder char_storage(ITEMS_PER_CHUNK, max_growth, stream); - - auto streams = cudf::detail::fork_streams(stream, concurrency); - - cudaEvent_t last_launch_event; - CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event)); - - auto& read_stream = streams[0]; - auto& scan_stream = streams[1]; - auto chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream); - int64_t base_tile_idx = 0; + auto chunk_offset = std::max(0, byte_range.offset() - delimiter.size()); std::optional first_row_offset; - std::optional last_row_offset; - bool found_last_offset = false; if (byte_range.offset() == 0) { first_row_offset = 0; } - std::swap(read_stream, scan_stream); - - while (chunk->size() > 0) { - // if we found the last delimiter, or didn't find delimiters inside the byte range at all: abort - if (last_row_offset.has_value() or - (not first_row_offset.has_value() and chunk_offset >= byte_range_end)) { - break; - } - - auto tiles_in_launch = - cudf::util::div_rounding_up_safe(chunk->size(), static_cast(ITEMS_PER_TILE)); - - auto row_offsets = row_offset_storage.next_output(scan_stream); + std::optional last_row_offset; - // reset the next chunk of tile state - multibyte_split_init_kernel<<(num_tile_states, stream, rmm::mr::get_current_device_resource()); + auto tile_offsets = scan_tile_state( + num_tile_states, stream, rmm::mr::get_current_device_resource()); + + multibyte_split_init_kernel<<>>( // - base_tile_idx, - tiles_in_launch, + stream.value()>>>( // + -TILES_PER_CHUNK, + TILES_PER_CHUNK, tile_multistates, - tile_offsets); + tile_offsets, + cudf::io::text::detail::scan_tile_status::oob); - CUDF_CUDA_TRY(cudaStreamWaitEvent(scan_stream.value(), last_launch_event)); + auto multistate_seed = multistate(); + multistate_seed.enqueue(0, 0); // this represents the first state in the pattern. - if (delimiter.size() == 1) { - // the single-byte case allows for a much more efficient kernel, so we special-case it - byte_split_kernel<<>>( // - base_tile_idx, - chunk_offset, - row_offset_storage.size(), - tile_offsets, - delimiter[0], - *chunk, - row_offsets); - } else { - multibyte_split_kernel<<>>( // + // Seeding the tile state with an identity value allows the 0th tile to follow the same logic as + // the Nth tile, assuming it can look up an inclusive prefix. Without this seed, the 0th block + // would have to follow separate logic. + cudf::detail::device_single_thread( + [tm = scan_tile_state_view(tile_multistates), + to = scan_tile_state_view(tile_offsets), + multistate_seed] __device__() mutable { + tm.set_inclusive_prefix(-1, multistate_seed); + to.set_inclusive_prefix(-1, 0); + }, + stream); + + auto reader = source.create_reader(); + auto const byte_range_end = byte_range.offset() + byte_range.size(); + reader->skip_bytes(chunk_offset); + // amortize output chunk allocations over 8 worst-case outputs. 
This limits the overallocation + constexpr auto max_growth = 8; + output_builder row_offset_storage(ITEMS_PER_CHUNK, max_growth, stream); + output_builder char_storage(ITEMS_PER_CHUNK, max_growth, stream); + + auto streams = cudf::detail::fork_streams(stream, concurrency); + + cudaEvent_t last_launch_event; + CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event)); + + auto& read_stream = streams[0]; + auto& scan_stream = streams[1]; + auto chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream); + int64_t base_tile_idx = 0; + bool found_last_offset = false; + std::swap(read_stream, scan_stream); + + while (chunk->size() > 0) { + // if we found the last delimiter, or didn't find delimiters inside the byte range at all: + // abort + if (last_row_offset.has_value() or + (not first_row_offset.has_value() and chunk_offset >= byte_range_end)) { + break; + } + + auto tiles_in_launch = + cudf::util::div_rounding_up_safe(chunk->size(), static_cast(ITEMS_PER_TILE)); + + auto row_offsets = row_offset_storage.next_output(scan_stream); + + // reset the next chunk of tile state + multibyte_split_init_kernel<<>>( // base_tile_idx, - chunk_offset, - row_offset_storage.size(), + tiles_in_launch, tile_multistates, - tile_offsets, - {device_delim.data(), static_cast(device_delim.size())}, - *chunk, - row_offsets); - } + tile_offsets); + + CUDF_CUDA_TRY(cudaStreamWaitEvent(scan_stream.value(), last_launch_event)); + + if (delimiter.size() == 1) { + // the single-byte case allows for a much more efficient kernel, so we special-case it + byte_split_kernel<<>>( // + base_tile_idx, + chunk_offset, + row_offset_storage.size(), + tile_offsets, + delimiter[0], + *chunk, + row_offsets); + } else { + multibyte_split_kernel<<>>( // + base_tile_idx, + chunk_offset, + row_offset_storage.size(), + tile_multistates, + tile_offsets, + {device_delim.data(), static_cast(device_delim.size())}, + *chunk, + row_offsets); + } - // load the next chunk - auto next_chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream); - // while that is running, determine how many offsets we output (synchronizes) - auto const new_offsets = [&] { - auto const new_offsets_unclamped = - tile_offsets.get_inclusive_prefix(base_tile_idx + tiles_in_launch - 1, scan_stream) - - static_cast(row_offset_storage.size()); - // if we are not in the last chunk, we can use all offsets - if (chunk_offset + static_cast(chunk->size()) < byte_range_end) { - return new_offsets_unclamped; + // load the next chunk + auto next_chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream); + // while that is running, determine how many offsets we output (synchronizes) + auto const new_offsets = [&] { + auto const new_offsets_unclamped = + tile_offsets.get_inclusive_prefix(base_tile_idx + tiles_in_launch - 1, scan_stream) - + static_cast(row_offset_storage.size()); + // if we are not in the last chunk, we can use all offsets + if (chunk_offset + static_cast(chunk->size()) < byte_range_end) { + return new_offsets_unclamped; + } + // if we are in the last chunk, we need to find the first out-of-bounds offset + auto const it = thrust::make_counting_iterator(output_offset{}); + auto const end_loc = + *thrust::find_if(rmm::exec_policy_nosync(scan_stream), + it, + it + new_offsets_unclamped, + [row_offsets, byte_range_end] __device__(output_offset i) { + return row_offsets[i] >= byte_range_end; + }); + // if we had no out-of-bounds offset, we copy all offsets + if (end_loc == new_offsets_unclamped) { return end_loc; } + // otherwise we copy only up to (including) the first 
out-of-bounds delimiter + found_last_offset = true; + return end_loc + 1; + }(); + row_offset_storage.advance_output(new_offsets, scan_stream); + // determine if we found the first or last field offset for the byte range + if (new_offsets > 0 and not first_row_offset) { + first_row_offset = row_offset_storage.front_element(scan_stream); + } + if (found_last_offset) { last_row_offset = row_offset_storage.back_element(scan_stream); } + // copy over the characters we need, if we already encountered the first field delimiter + if (first_row_offset.has_value()) { + auto const begin = + chunk->data() + std::max(0, *first_row_offset - chunk_offset); + auto const sentinel = last_row_offset.value_or(std::numeric_limits::max()); + auto const end = + chunk->data() + std::min(sentinel - chunk_offset, chunk->size()); + auto const output_size = end - begin; + auto char_output = char_storage.next_output(scan_stream); + thrust::copy(rmm::exec_policy_nosync(scan_stream), begin, end, char_output.begin()); + char_storage.advance_output(output_size, scan_stream); } - // if we are in the last chunk, we need to find the first out-of-bounds offset - auto const it = thrust::make_counting_iterator(output_offset{}); - auto const end_loc = - *thrust::find_if(rmm::exec_policy_nosync(scan_stream), - it, - it + new_offsets_unclamped, - [row_offsets, byte_range_end] __device__(output_offset i) { - return row_offsets[i] >= byte_range_end; - }); - // if we had no out-of-bounds offset, we copy all offsets - if (end_loc == new_offsets_unclamped) { return end_loc; } - // otherwise we copy only up to (including) the first out-of-bounds delimiter - found_last_offset = true; - return end_loc + 1; - }(); - row_offset_storage.advance_output(new_offsets, scan_stream); - // determine if we found the first or last field offset for the byte range - if (new_offsets > 0 and not first_row_offset) { - first_row_offset = row_offset_storage.front_element(scan_stream); - } - if (found_last_offset) { last_row_offset = row_offset_storage.back_element(scan_stream); } - // copy over the characters we need, if we already encountered the first field delimiter - if (first_row_offset.has_value()) { - auto const begin = chunk->data() + std::max(0, *first_row_offset - chunk_offset); - auto const sentinel = last_row_offset.value_or(std::numeric_limits::max()); - auto const end = - chunk->data() + std::min(sentinel - chunk_offset, chunk->size()); - auto const output_size = end - begin; - auto char_output = char_storage.next_output(scan_stream); - thrust::copy(rmm::exec_policy_nosync(scan_stream), begin, end, char_output.begin()); - char_storage.advance_output(output_size, scan_stream); - } - CUDF_CUDA_TRY(cudaEventRecord(last_launch_event, scan_stream.value())); + CUDF_CUDA_TRY(cudaEventRecord(last_launch_event, scan_stream.value())); - std::swap(read_stream, scan_stream); - base_tile_idx += tiles_in_launch; - chunk_offset += chunk->size(); - chunk = std::move(next_chunk); - } + std::swap(read_stream, scan_stream); + base_tile_idx += tiles_in_launch; + chunk_offset += chunk->size(); + chunk = std::move(next_chunk); + } + + CUDF_CUDA_TRY(cudaEventDestroy(last_launch_event)); - CUDF_CUDA_TRY(cudaEventDestroy(last_launch_event)); + cudf::detail::join_streams(streams, stream); - cudf::detail::join_streams(streams, stream); + auto chars = char_storage.gather(stream, mr); + auto global_offsets = row_offset_storage.gather(stream, mr); + return std::pair{std::move(global_offsets), std::move(chars)}; + }(); // if the input was empty, we didn't find a delimiter 
at all, // or the first delimiter was also the last: empty output @@ -511,9 +516,6 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source return make_empty_column(type_id::STRING); } - auto chars = char_storage.gather(stream, mr); - auto global_offsets = row_offset_storage.gather(stream, mr); - // insert an offset at the beginning if we started at the beginning of the input bool const insert_begin = first_row_offset.value_or(0) == 0; // insert an offset at the end if we have not terminated the last row @@ -591,6 +593,4 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source return result; } -} // namespace text -} // namespace io -} // namespace cudf +} // namespace cudf::io::text