diff --git a/build.sh b/build.sh index 211e1db9fbf..69d6481af42 100755 --- a/build.sh +++ b/build.sh @@ -239,11 +239,6 @@ if hasArg --pydevelop; then PYTHON_ARGS_FOR_INSTALL="${PYTHON_ARGS_FOR_INSTALL} -e" fi -# Append `-DFIND_CUDF_CPP=ON` to EXTRA_CMAKE_ARGS unless a user specified the option. -if [[ "${EXTRA_CMAKE_ARGS}" != *"DFIND_CUDF_CPP"* ]]; then - EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DFIND_CUDF_CPP=ON" -fi - if hasArg --disable_large_strings; then BUILD_DISABLE_LARGE_STRINGS="ON" fi diff --git a/ci/cudf_pandas_scripts/run_tests.sh b/ci/cudf_pandas_scripts/run_tests.sh index c6228a4ef33..f6bdc6f9484 100755 --- a/ci/cudf_pandas_scripts/run_tests.sh +++ b/ci/cudf_pandas_scripts/run_tests.sh @@ -56,10 +56,10 @@ else echo "" > ./constraints.txt if [[ $RAPIDS_DEPENDENCIES == "oldest" ]]; then - # `test_python` constraints are for `[test]` not `[cudf-pandas-tests]` + # `test_python_cudf_pandas` constraints are for `[test]` not `[cudf-pandas-tests]` rapids-dependency-file-generator \ --output requirements \ - --file-key test_python \ + --file-key test_python_cudf_pandas \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION};dependencies=${RAPIDS_DEPENDENCIES}" \ | tee ./constraints.txt fi diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index b0346327319..f73e88bc0c8 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -25,9 +25,9 @@ NEXT_PATCH=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[3]}') NEXT_SHORT_TAG=${NEXT_MAJOR}.${NEXT_MINOR} # Need to distutils-normalize the versions for some use cases -CURRENT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${CURRENT_SHORT_TAG}'))") -NEXT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_SHORT_TAG}'))") -PATCH_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_PATCH}'))") +CURRENT_SHORT_TAG_PEP440=$(python -c "from packaging.version import Version; print(Version('${CURRENT_SHORT_TAG}'))") +NEXT_SHORT_TAG_PEP440=$(python -c "from packaging.version import Version; print(Version('${NEXT_SHORT_TAG}'))") +PATCH_PEP440=$(python -c "from packaging.version import Version; print(Version('${NEXT_PATCH}'))") echo "Preparing release $CURRENT_TAG => $NEXT_FULL_TAG" diff --git a/ci/test_python_common.sh b/ci/test_python_common.sh index d0675b0431a..dc70661a17a 100755 --- a/ci/test_python_common.sh +++ b/ci/test_python_common.sh @@ -10,10 +10,10 @@ set -euo pipefail rapids-logger "Generate Python testing dependencies" ENV_YAML_DIR="$(mktemp -d)" - +FILE_KEY=$1 rapids-dependency-file-generator \ --output conda \ - --file-key test_python \ + --file-key ${FILE_KEY} \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION};dependencies=${RAPIDS_DEPENDENCIES}" \ | tee "${ENV_YAML_DIR}/env.yaml" diff --git a/ci/test_python_cudf.sh b/ci/test_python_cudf.sh index ae34047e87f..2386414b32e 100755 --- a/ci/test_python_cudf.sh +++ b/ci/test_python_cudf.sh @@ -5,7 +5,7 @@ cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../; # Common setup steps shared by Python test jobs -source ./ci/test_python_common.sh +source ./ci/test_python_common.sh test_python_cudf rapids-logger "Check GPU usage" nvidia-smi diff --git a/ci/test_python_other.sh b/ci/test_python_other.sh index 06a24773cae..67c97ad29a5 100755 --- a/ci/test_python_other.sh +++ b/ci/test_python_other.sh @@ -5,7 +5,7 @@ cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../ # Common setup steps shared by Python test jobs -source ./ci/test_python_common.sh +source ./ci/test_python_common.sh test_python_other rapids-mamba-retry install \ --channel "${CPP_CHANNEL}" \ diff --git a/dependencies.yaml b/dependencies.yaml index 620dac95c28..9c95b9f399f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -43,15 +43,28 @@ files: includes: - cuda_version - test_cpp - test_python: + test_python_cudf_pandas: output: none includes: - cuda_version - py_version - test_python_common - test_python_cudf - - test_python_dask_cudf - test_python_cudf_pandas + test_python_cudf: + output: none + includes: + - cuda_version + - py_version + - test_python_common + - test_python_cudf + test_python_other: + output: none + includes: + - cuda_version + - py_version + - test_python_common + - test_python_dask_cudf test_java: output: none includes: @@ -707,9 +720,7 @@ dependencies: - matrix: {dependencies: "oldest"} packages: - numba==0.57.* - - numpy==1.23.* - pandas==2.0.* - - pyarrow==14.0.0 - matrix: packages: - output_types: conda @@ -764,6 +775,14 @@ dependencies: - &transformers transformers==4.39.3 - tzdata specific: + - output_types: [conda, requirements] + matrices: + - matrix: {dependencies: "oldest"} + packages: + - numpy==1.23.* + - pyarrow==14.0.0 + - matrix: + packages: - output_types: conda matrices: - matrix: @@ -783,6 +802,15 @@ dependencies: packages: - dask-cuda==24.12.*,>=0.0.0a0 - *numba + specific: + - output_types: [conda, requirements] + matrices: + - matrix: {dependencies: "oldest"} + packages: + - numpy==1.24.* + - pyarrow==14.0.1 + - matrix: + packages: depends_on_libcudf: common: - output_types: conda diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/findall.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/findall.rst new file mode 100644 index 00000000000..9850ee10098 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/findall.rst @@ -0,0 +1,6 @@ +==== +find +==== + +.. automodule:: pylibcudf.strings.findall + :members: diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst index 003e7c0c35e..9b1a6b72a88 100644 --- a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst @@ -9,6 +9,7 @@ strings contains extract find + findall regex_flags regex_program repeat diff --git a/java/src/main/java/ai/rapids/cudf/JSONOptions.java b/java/src/main/java/ai/rapids/cudf/JSONOptions.java index 17b497be5ee..2bb74c3e3b1 100644 --- a/java/src/main/java/ai/rapids/cudf/JSONOptions.java +++ b/java/src/main/java/ai/rapids/cudf/JSONOptions.java @@ -38,6 +38,7 @@ public final class JSONOptions extends ColumnFilterOptions { private final boolean allowLeadingZeros; private final boolean allowNonNumericNumbers; private final boolean allowUnquotedControlChars; + private final boolean cudfPruneSchema; private final byte lineDelimiter; private JSONOptions(Builder builder) { @@ -53,9 +54,14 @@ private JSONOptions(Builder builder) { allowLeadingZeros = builder.allowLeadingZeros; allowNonNumericNumbers = builder.allowNonNumericNumbers; allowUnquotedControlChars = builder.allowUnquotedControlChars; + cudfPruneSchema = builder.cudfPruneSchema; lineDelimiter = builder.lineDelimiter; } + public boolean shouldCudfPruneSchema() { + return cudfPruneSchema; + } + public byte getLineDelimiter() { return lineDelimiter; } @@ -129,8 +135,14 @@ public static final class Builder extends ColumnFilterOptions.Builder Byte.MAX_VALUE) { throw new IllegalArgumentException("Only basic ASCII values are supported as line delimiters " + delimiter); diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 19c72809cea..6d370ca27b2 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -259,6 +259,7 @@ private static native long readJSON(int[] numChildren, String[] columnNames, boolean allowLeadingZeros, boolean allowNonNumericNumbers, boolean allowUnquotedControl, + boolean pruneColumns, byte lineDelimiter) throws CudfException; private static native long readJSONFromDataSource(int[] numChildren, String[] columnNames, @@ -273,6 +274,7 @@ private static native long readJSONFromDataSource(int[] numChildren, String[] co boolean allowLeadingZeros, boolean allowNonNumericNumbers, boolean allowUnquotedControl, + boolean pruneColumns, byte lineDelimiter, long dsHandle) throws CudfException; @@ -1312,6 +1314,10 @@ private static Table gatherJSONColumns(Schema schema, TableWithMeta twm, int emp * @return the file parsed as a table on the GPU. */ public static Table readJSON(Schema schema, JSONOptions opts, File path) { + // only prune the schema if one is provided + boolean cudfPruneSchema = schema.getColumnNames() != null && + schema.getColumnNames().length != 0 && + opts.shouldCudfPruneSchema(); try (TableWithMeta twm = new TableWithMeta( readJSON(schema.getFlattenedNumChildren(), schema.getFlattenedColumnNames(), schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), @@ -1326,6 +1332,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, File path) { opts.leadingZerosAllowed(), opts.nonNumericNumbersAllowed(), opts.unquotedControlChars(), + cudfPruneSchema, opts.getLineDelimiter()))) { return gatherJSONColumns(schema, twm, -1); @@ -1472,6 +1479,10 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b assert len > 0; assert len <= buffer.length - offset; assert offset >= 0 && offset < buffer.length; + // only prune the schema if one is provided + boolean cudfPruneSchema = schema.getColumnNames() != null && + schema.getColumnNames().length != 0 && + opts.shouldCudfPruneSchema(); try (TableWithMeta twm = new TableWithMeta(readJSON( schema.getFlattenedNumChildren(), schema.getFlattenedColumnNames(), schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), null, @@ -1487,6 +1498,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b opts.leadingZerosAllowed(), opts.nonNumericNumbersAllowed(), opts.unquotedControlChars(), + cudfPruneSchema, opts.getLineDelimiter()))) { return gatherJSONColumns(schema, twm, emptyRowCount); } @@ -1513,6 +1525,10 @@ public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds) { */ public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds, int emptyRowCount) { long dsHandle = DataSourceHelper.createWrapperDataSource(ds); + // only prune the schema if one is provided + boolean cudfPruneSchema = schema.getColumnNames() != null && + schema.getColumnNames().length != 0 && + opts.shouldCudfPruneSchema(); try (TableWithMeta twm = new TableWithMeta(readJSONFromDataSource(schema.getFlattenedNumChildren(), schema.getFlattenedColumnNames(), schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), opts.isDayFirst(), @@ -1526,6 +1542,7 @@ public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds, int opts.leadingZerosAllowed(), opts.nonNumericNumbersAllowed(), opts.unquotedControlChars(), + cudfPruneSchema, opts.getLineDelimiter(), dsHandle))) { return gatherJSONColumns(schema, twm, emptyRowCount); diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 96d4c2c4eeb..0f77da54152 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -1649,7 +1649,8 @@ Java_ai_rapids_cudf_Table_readAndInferJSONFromDataSource(JNIEnv* env, .mixed_types_as_string(mixed_types_as_string) .delimiter(static_cast(line_delimiter)) .strict_validation(strict_validation) - .keep_quotes(keep_quotes); + .keep_quotes(keep_quotes) + .prune_columns(false); if (strict_validation) { opts.numeric_leading_zeros(allow_leading_zeros) .nonnumeric_numbers(allow_nonnumeric_numbers) @@ -1703,6 +1704,7 @@ Java_ai_rapids_cudf_Table_readAndInferJSON(JNIEnv* env, .normalize_whitespace(static_cast(normalize_whitespace)) .strict_validation(strict_validation) .mixed_types_as_string(mixed_types_as_string) + .prune_columns(false) .delimiter(static_cast(line_delimiter)) .keep_quotes(keep_quotes); if (strict_validation) { @@ -1818,6 +1820,7 @@ Java_ai_rapids_cudf_Table_readJSONFromDataSource(JNIEnv* env, jboolean allow_leading_zeros, jboolean allow_nonnumeric_numbers, jboolean allow_unquoted_control, + jboolean prune_columns, jbyte line_delimiter, jlong ds_handle) { @@ -1855,7 +1858,8 @@ Java_ai_rapids_cudf_Table_readJSONFromDataSource(JNIEnv* env, .mixed_types_as_string(mixed_types_as_string) .delimiter(static_cast(line_delimiter)) .strict_validation(strict_validation) - .keep_quotes(keep_quotes); + .keep_quotes(keep_quotes) + .prune_columns(prune_columns); if (strict_validation) { opts.numeric_leading_zeros(allow_leading_zeros) .nonnumeric_numbers(allow_nonnumeric_numbers) @@ -1915,6 +1919,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON(JNIEnv* env, jboolean allow_leading_zeros, jboolean allow_nonnumeric_numbers, jboolean allow_unquoted_control, + jboolean prune_columns, jbyte line_delimiter) { bool read_buffer = true; @@ -1966,7 +1971,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON(JNIEnv* env, .mixed_types_as_string(mixed_types_as_string) .delimiter(static_cast(line_delimiter)) .strict_validation(strict_validation) - .keep_quotes(keep_quotes); + .keep_quotes(keep_quotes) + .prune_columns(prune_columns); if (strict_validation) { opts.numeric_leading_zeros(allow_leading_zeros) .nonnumeric_numbers(allow_nonnumeric_numbers) diff --git a/python/cudf/cudf/_lib/strings/findall.pyx b/python/cudf/cudf/_lib/strings/findall.pyx index 1db0fc89490..c1125d1ebb7 100644 --- a/python/cudf/cudf/_lib/strings/findall.pyx +++ b/python/cudf/cudf/_lib/strings/findall.pyx @@ -1,13 +1,10 @@ # Copyright (c) 2019-2024, NVIDIA CORPORATION. -from cython.operator cimport dereference from libc.stdint cimport uint32_t -from libcpp.memory cimport unique_ptr -from libcpp.string cimport string -from libcpp.utility cimport move from cudf.core.buffer import acquire_spill_lock +<<<<<<< HEAD from pylibcudf.libcudf.column.column cimport column from pylibcudf.libcudf.column.column_view cimport column_view from pylibcudf.libcudf.strings.findall cimport ( @@ -17,8 +14,12 @@ from pylibcudf.libcudf.strings.findall cimport ( from pylibcudf.libcudf.strings.regex_flags cimport regex_flags from pylibcudf.libcudf.strings.regex_program cimport regex_program +======= +>>>>>>> branch-24.12 from cudf._lib.column cimport Column +import pylibcudf as plc + @acquire_spill_lock() def findall(Column source_strings, object pattern, uint32_t flags): @@ -26,21 +27,14 @@ def findall(Column source_strings, object pattern, uint32_t flags): Returns data with all non-overlapping matches of `pattern` in each string of `source_strings` as a lists column. """ - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef string pattern_string = str(pattern).encode() - cdef regex_flags c_flags = flags - cdef unique_ptr[regex_program] c_prog - - with nogil: - c_prog = move(regex_program.create(pattern_string, c_flags)) - c_result = move(cpp_findall( - source_view, - dereference(c_prog) - )) - - return Column.from_unique_ptr(move(c_result)) + prog = plc.strings.regex_program.RegexProgram.create( + str(pattern), flags + ) + plc_result = plc.strings.findall.findall( + source_strings.to_pylibcudf(mode="read"), + prog, + ) + return Column.from_pylibcudf(plc_result) @acquire_spill_lock() @@ -49,18 +43,11 @@ def find_re(Column source_strings, object pattern, uint32_t flags): Returns character positions where the pattern first matches the elements in source_strings. """ - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef string pattern_string = str(pattern).encode() - cdef regex_flags c_flags = flags - cdef unique_ptr[regex_program] c_prog - - with nogil: - c_prog = move(regex_program.create(pattern_string, c_flags)) - c_result = move(cpp_find_re( - source_view, - dereference(c_prog) - )) - - return Column.from_unique_ptr(move(c_result)) + prog = plc.strings.regex_program.RegexProgram.create( + str(pattern), flags + ) + plc_result = plc.strings.findall.find_re( + source_strings.to_pylibcudf(mode="read"), + prog, + ) + return Column.from_pylibcudf(plc_result) diff --git a/python/cudf/cudf/pandas/fast_slow_proxy.py b/python/cudf/cudf/pandas/fast_slow_proxy.py index bf2ee6ae624..0c1cda8810b 100644 --- a/python/cudf/cudf/pandas/fast_slow_proxy.py +++ b/python/cudf/cudf/pandas/fast_slow_proxy.py @@ -881,6 +881,12 @@ def _assert_fast_slow_eq(left, right): assert_eq(left, right) +class ProxyFallbackError(Exception): + """Raised when fallback occurs""" + + pass + + def _fast_function_call(): """ Placeholder fast function for pytest profiling purposes. @@ -957,6 +963,10 @@ def _fast_slow_function_call( f"The exception was {e}." ) except Exception as err: + if _env_get_bool("CUDF_PANDAS_FAIL_ON_FALLBACK", False): + raise ProxyFallbackError( + f"The operation failed with cuDF, the reason was {type(err)}: {err}." + ) from err with nvtx.annotate( "EXECUTE_SLOW", color=_CUDF_PANDAS_NVTX_COLORS["EXECUTE_SLOW"], diff --git a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py index c4ab4b0a853..2bbed40e34e 100644 --- a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py +++ b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py @@ -26,7 +26,11 @@ from cudf.core._compat import PANDAS_GE_220 from cudf.pandas import LOADED, Profiler -from cudf.pandas.fast_slow_proxy import _Unusable, is_proxy_object +from cudf.pandas.fast_slow_proxy import ( + ProxyFallbackError, + _Unusable, + is_proxy_object, +) from cudf.testing import assert_eq if not LOADED: @@ -1738,3 +1742,13 @@ def add_one_ufunc(a): return a + 1 assert_eq(cp.asarray(add_one_ufunc(arr1)), cp.asarray(add_one_ufunc(arr2))) + + +@pytest.mark.xfail( + reason="Fallback expected because casting to object is not supported", +) +def test_fallback_raises_error(monkeypatch): + with monkeypatch.context() as monkeycontext: + monkeycontext.setenv("CUDF_PANDAS_FAIL_ON_FALLBACK", "True") + with pytest.raises(ProxyFallbackError): + pd.Series(range(2)).astype(object) diff --git a/python/cudf/cudf_pandas_tests/test_cudf_pandas_no_fallback.py b/python/cudf/cudf_pandas_tests/test_cudf_pandas_no_fallback.py new file mode 100644 index 00000000000..896256bf6d7 --- /dev/null +++ b/python/cudf/cudf_pandas_tests/test_cudf_pandas_no_fallback.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from cudf.pandas import LOADED + +if not LOADED: + raise ImportError("These tests must be run with cudf.pandas loaded") + +import numpy as np +import pandas as pd + + +@pytest.fixture(autouse=True) +def fail_on_fallback(monkeypatch): + monkeypatch.setenv("CUDF_PANDAS_FAIL_ON_FALLBACK", "True") + + +@pytest.fixture +def dataframe(): + df = pd.DataFrame( + { + "a": [1, 1, 1, 2, 3], + "b": [1, 2, 3, 4, 5], + "c": [1.2, 1.3, 1.5, 1.7, 1.11], + } + ) + return df + + +@pytest.fixture +def series(dataframe): + return dataframe["a"] + + +@pytest.fixture +def array(series): + return series.values + + +@pytest.mark.parametrize( + "op", + [ + "sum", + "min", + "max", + "mean", + "std", + "var", + "prod", + "median", + ], +) +def test_no_fallback_in_reduction_ops(series, op): + s = series + getattr(s, op)() + + +def test_groupby(dataframe): + df = dataframe + df.groupby("a", sort=True).max() + + +def test_no_fallback_in_binops(dataframe): + df = dataframe + df + df + df - df + df * df + df**df + df[["a", "b"]] & df[["a", "b"]] + df <= df + + +def test_no_fallback_in_groupby_rolling_sum(dataframe): + df = dataframe + df.groupby("a").rolling(2).sum() + + +def test_no_fallback_in_concat(dataframe): + df = dataframe + pd.concat([df, df]) + + +def test_no_fallback_in_get_shape(dataframe): + df = dataframe + df.shape + + +def test_no_fallback_in_array_ufunc_op(array): + np.add(array, array) + + +def test_no_fallback_in_merge(dataframe): + df = dataframe + pd.merge(df * df, df + df, how="inner") + pd.merge(df * df, df + df, how="outer") + pd.merge(df * df, df + df, how="left") + pd.merge(df * df, df + df, how="right") diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 97e1dffc65b..907abaa2bfc 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -15,6 +15,7 @@ from dask import config from dask.dataframe.core import is_dataframe_like +from dask.typing import no_default import cudf @@ -90,6 +91,17 @@ def var( ) ) + def rename_axis( + self, mapper=no_default, index=no_default, columns=no_default, axis=0 + ): + from dask_cudf.expr._expr import RenameAxisCudf + + return new_collection( + RenameAxisCudf( + self, mapper=mapper, index=index, columns=columns, axis=axis + ) + ) + class DataFrame(DXDataFrame, CudfFrameBase): @classmethod @@ -202,27 +214,58 @@ class Index(DXIndex, CudfFrameBase): ## -try: - from dask_expr._backends import create_array_collection - - @get_collection_type.register_lazy("cupy") - def _register_cupy(): - import cupy - - @get_collection_type.register(cupy.ndarray) - def get_collection_type_cupy_array(_): - return create_array_collection - - @get_collection_type.register_lazy("cupyx") - def _register_cupyx(): - # Needed for cuml - from cupyx.scipy.sparse import spmatrix - - @get_collection_type.register(spmatrix) - def get_collection_type_csr_matrix(_): - return create_array_collection - -except ImportError: - # Older version of dask-expr. - # Implicit conversion to array wont work. - pass +def _create_array_collection_with_meta(expr): + # NOTE: This is the GPU compatible version of + # `new_dd_object` for DataFrame -> Array conversion. + # This can be removed if dask#11017 is resolved + # (See: https://github.com/dask/dask/issues/11017) + import numpy as np + + import dask.array as da + from dask.blockwise import Blockwise + from dask.highlevelgraph import HighLevelGraph + + result = expr.optimize() + dsk = result.__dask_graph__() + name = result._name + meta = result._meta + divisions = result.divisions + chunks = ((np.nan,) * (len(divisions) - 1),) + tuple( + (d,) for d in meta.shape[1:] + ) + if len(chunks) > 1: + if isinstance(dsk, HighLevelGraph): + layer = dsk.layers[name] + else: + # dask-expr provides a dict only + layer = dsk + if isinstance(layer, Blockwise): + layer.new_axes["j"] = chunks[1][0] + layer.output_indices = layer.output_indices + ("j",) + else: + suffix = (0,) * (len(chunks) - 1) + for i in range(len(chunks[0])): + layer[(name, i) + suffix] = layer.pop((name, i)) + + return da.Array(dsk, name=name, chunks=chunks, meta=meta) + + +@get_collection_type.register_lazy("cupy") +def _register_cupy(): + import cupy + + get_collection_type.register( + cupy.ndarray, + lambda _: _create_array_collection_with_meta, + ) + + +@get_collection_type.register_lazy("cupyx") +def _register_cupyx(): + # Needed for cuml + from cupyx.scipy.sparse import spmatrix + + get_collection_type.register( + spmatrix, + lambda _: _create_array_collection_with_meta, + ) diff --git a/python/dask_cudf/dask_cudf/expr/_expr.py b/python/dask_cudf/dask_cudf/expr/_expr.py index 8a2c50d3fe7..b284ab3774d 100644 --- a/python/dask_cudf/dask_cudf/expr/_expr.py +++ b/python/dask_cudf/dask_cudf/expr/_expr.py @@ -4,11 +4,12 @@ import dask_expr._shuffle as _shuffle_module from dask_expr import new_collection from dask_expr._cumulative import CumulativeBlockwise -from dask_expr._expr import Elemwise, Expr, VarColumns +from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns from dask_expr._reductions import Reduction, Var from dask.dataframe.core import is_dataframe_like, make_meta, meta_nonempty from dask.dataframe.dispatch import is_categorical_dtype +from dask.typing import no_default import cudf @@ -17,6 +18,19 @@ ## +class RenameAxisCudf(RenameAxis): + # TODO: Remove this after rename_axis is supported in cudf + # (See: https://github.com/rapidsai/cudf/issues/16895) + @staticmethod + def operation(df, index=no_default, **kwargs): + if index != no_default: + df.index.name = index + return df + raise NotImplementedError( + "Only `index` is supported for the cudf backend" + ) + + class ToCudfBackend(Elemwise): # TODO: Inherit from ToBackend when rapids-dask-dependency # is pinned to dask>=2024.8.1 diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py index 7aa0f6320f2..5f0fae86691 100644 --- a/python/dask_cudf/dask_cudf/tests/test_core.py +++ b/python/dask_cudf/dask_cudf/tests/test_core.py @@ -16,6 +16,7 @@ import dask_cudf from dask_cudf.tests.utils import ( + QUERY_PLANNING_ON, require_dask_expr, skip_dask_expr, xfail_dask_expr, @@ -950,12 +951,16 @@ def test_implicit_array_conversion_cupy(): def func(x): return x.values - # Need to compute the dask collection for now. - # See: https://github.com/dask/dask/issues/11017 - result = ds.map_partitions(func, meta=s.values).compute() - expect = func(s) + result = ds.map_partitions(func, meta=s.values) - dask.array.assert_eq(result, expect) + if QUERY_PLANNING_ON: + # Check Array and round-tripped DataFrame + dask.array.assert_eq(result, func(s)) + dd.assert_eq(result.to_dask_dataframe(), s, check_index=False) + else: + # Legacy version still carries numpy metadata + # See: https://github.com/dask/dask/issues/11017 + dask.array.assert_eq(result.compute(), func(s)) def test_implicit_array_conversion_cupy_sparse(): @@ -967,8 +972,6 @@ def test_implicit_array_conversion_cupy_sparse(): def func(x): return cupyx.scipy.sparse.csr_matrix(x.values) - # Need to compute the dask collection for now. - # See: https://github.com/dask/dask/issues/11017 result = ds.map_partitions(func, meta=s.values).compute() expect = func(s) @@ -1024,3 +1027,15 @@ def test_cov_corr(op, numeric_only): # (See: https://github.com/rapidsai/cudf/issues/12626) expect = getattr(df.to_pandas(), op)(numeric_only=numeric_only) dd.assert_eq(res, expect) + + +def test_rename_axis_after_join(): + df1 = cudf.DataFrame(index=["a", "b", "c"], data=dict(a=[1, 2, 3])) + df1.index.name = "test" + ddf1 = dd.from_pandas(df1, 2) + + df2 = cudf.DataFrame(index=["a", "b", "d"], data=dict(b=[1, 2, 3])) + ddf2 = dd.from_pandas(df2, 2) + result = ddf1.join(ddf2, how="outer") + expected = df1.join(df2, how="outer") + dd.assert_eq(result, expected, check_index=False) diff --git a/python/pylibcudf/pylibcudf/strings/CMakeLists.txt b/python/pylibcudf/pylibcudf/strings/CMakeLists.txt index 8b4fbb1932f..77f20b0b917 100644 --- a/python/pylibcudf/pylibcudf/strings/CMakeLists.txt +++ b/python/pylibcudf/pylibcudf/strings/CMakeLists.txt @@ -13,8 +13,8 @@ # ============================================================================= set(cython_sources - capitalize.pyx case.pyx char_types.pyx contains.pyx extract.pyx find.pyx regex_flags.pyx - regex_program.pyx repeat.pyx replace.pyx side_type.pyx slice.pyx strip.pyx + capitalize.pyx case.pyx char_types.pyx contains.pyx extract.pyx find.pyx findall.pyx + regex_flags.pyx regex_program.pyx repeat.pyx replace.pyx side_type.pyx slice.pyx strip.pyx ) set(linked_libraries cudf::cudf) diff --git a/python/pylibcudf/pylibcudf/strings/__init__.pxd b/python/pylibcudf/pylibcudf/strings/__init__.pxd index 4867d944dc7..91d884b294b 100644 --- a/python/pylibcudf/pylibcudf/strings/__init__.pxd +++ b/python/pylibcudf/pylibcudf/strings/__init__.pxd @@ -8,6 +8,7 @@ from . cimport ( convert, extract, find, + findall, regex_flags, regex_program, replace, diff --git a/python/pylibcudf/pylibcudf/strings/__init__.py b/python/pylibcudf/pylibcudf/strings/__init__.py index a3bef64d19f..b4856784390 100644 --- a/python/pylibcudf/pylibcudf/strings/__init__.py +++ b/python/pylibcudf/pylibcudf/strings/__init__.py @@ -8,6 +8,7 @@ convert, extract, find, + findall, regex_flags, regex_program, repeat, diff --git a/python/pylibcudf/pylibcudf/strings/findall.pxd b/python/pylibcudf/pylibcudf/strings/findall.pxd new file mode 100644 index 00000000000..54afa088141 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/findall.pxd @@ -0,0 +1,7 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from pylibcudf.column cimport Column +from pylibcudf.strings.regex_program cimport RegexProgram + + +cpdef Column findall(Column input, RegexProgram pattern) diff --git a/python/pylibcudf/pylibcudf/strings/findall.pyx b/python/pylibcudf/pylibcudf/strings/findall.pyx new file mode 100644 index 00000000000..03ecb13a50e --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/findall.pyx @@ -0,0 +1,40 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move +from pylibcudf.column cimport Column +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.strings cimport findall as cpp_findall +from pylibcudf.strings.regex_program cimport RegexProgram + + +cpdef Column findall(Column input, RegexProgram pattern): + """ + Returns a lists column of strings for each matching occurrence using + the regex_program pattern within each string. + + For details, see For details, see :cpp:func:`cudf::strings::findall`. + + Parameters + ---------- + input : Column + Strings instance for this operation + pattern : RegexProgram + Regex pattern + + Returns + ------- + Column + New lists column of strings + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_findall.findall( + input.view(), + pattern.c_obj.get()[0] + ) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_findall.py b/python/pylibcudf/pylibcudf/tests/test_string_findall.py new file mode 100644 index 00000000000..994552fa276 --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_string_findall.py @@ -0,0 +1,23 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +import re + +import pyarrow as pa +import pylibcudf as plc +from utils import assert_column_eq + + +def test_findall(): + arr = pa.array(["bunny", "rabbit", "hare", "dog"]) + pattern = "[ab]" + result = plc.strings.findall.findall( + plc.interop.from_arrow(arr), + plc.strings.regex_program.RegexProgram.create( + pattern, plc.strings.regex_flags.RegexFlags.DEFAULT + ), + ) + pa_result = plc.interop.to_arrow(result) + expected = pa.array( + [re.findall(pattern, elem) for elem in arr.to_pylist()], + type=pa_result.type, + ) + assert_column_eq(result, expected)