diff --git a/ci/build_wheel_libcudf.sh b/ci/build_wheel_libcudf.sh index 8975381ceba..91bc071583e 100755 --- a/ci/build_wheel_libcudf.sh +++ b/ci/build_wheel_libcudf.sh @@ -5,11 +5,15 @@ set -euo pipefail package_dir="python/libcudf" +export SKBUILD_CMAKE_ARGS="-DUSE_NVCOMP_RUNTIME_WHEEL=ON" ./ci/build_wheel.sh ${package_dir} RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" mkdir -p ${package_dir}/final_dist -python -m auditwheel repair -w ${package_dir}/final_dist ${package_dir}/dist/* +python -m auditwheel repair \ + --exclude libnvcomp.so.4 \ + -w ${package_dir}/final_dist \ + ${package_dir}/dist/* RAPIDS_PY_WHEEL_NAME="libcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 cpp ${package_dir}/final_dist diff --git a/ci/test_cudf_polars_polars_tests.sh b/ci/test_cudf_polars_polars_tests.sh index 55399d0371a..f5bcdc62604 100755 --- a/ci/test_cudf_polars_polars_tests.sh +++ b/ci/test_cudf_polars_polars_tests.sh @@ -24,14 +24,17 @@ rapids-logger "Download wheels" RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" RAPIDS_PY_WHEEL_NAME="cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}" RAPIDS_PY_WHEEL_PURE="1" rapids-download-wheels-from-s3 ./dist -# Download the pylibcudf built in the previous step -RAPIDS_PY_WHEEL_NAME="pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-pylibcudf-dep +# Download libcudf and pylibcudf built in the previous step +RAPIDS_PY_WHEEL_NAME="libcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 cpp ./local-libcudf-dep +RAPIDS_PY_WHEEL_NAME="pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 python ./local-pylibcudf-dep -rapids-logger "Install pylibcudf" -python -m pip install ./local-pylibcudf-dep/pylibcudf*.whl +rapids-logger "Install libcudf, pylibcudf and cudf_polars" +python -m pip install \ + -v \ + "$(echo ./dist/cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}*.whl)[test]" \ + "$(echo ./local-libcudf-dep/libcudf_${RAPIDS_PY_CUDA_SUFFIX}*.whl)" \ + "$(echo ./local-pylibcudf-dep/pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}*.whl)" -rapids-logger "Install cudf_polars" -python -m pip install $(echo ./dist/cudf_polars*.whl) TAG=$(python -c 'import polars; print(f"py-{polars.__version__}")') rapids-logger "Clone polars to ${TAG}" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 136f43ee706..f7a5dd2f2fb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -52,6 +52,7 @@ option(JITIFY_USE_CACHE "Use a file cache for JIT compiled kernels" ON) option(CUDF_BUILD_TESTUTIL "Whether to build the test utilities contained in libcudf" ON) mark_as_advanced(CUDF_BUILD_TESTUTIL) option(CUDF_USE_PROPRIETARY_NVCOMP "Download and use NVCOMP with proprietary extensions" ON) +option(CUDF_EXPORT_NVCOMP "Export NVCOMP as a dependency" ON) option(CUDF_LARGE_STRINGS_DISABLED "Build with large string support disabled" OFF) mark_as_advanced(CUDF_LARGE_STRINGS_DISABLED) option( diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 15d12ab2766..02519afb89d 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -392,11 +392,6 @@ ConfigureNVBench(JSON_READER_NVBENCH io/json/nested_json.cpp io/json/json_reader ConfigureNVBench(JSON_READER_OPTION_NVBENCH io/json/json_reader_option.cpp) ConfigureNVBench(JSON_WRITER_NVBENCH io/json/json_writer.cpp) -# ################################################################################################## -# * multi buffer memset benchmark -# ---------------------------------------------------------------------- -ConfigureNVBench(BATCHED_MEMSET_BENCH io/utilities/batched_memset_bench.cpp) - # ################################################################################################## # * io benchmark --------------------------------------------------------------------- ConfigureNVBench(MULTIBYTE_SPLIT_NVBENCH io/text/multibyte_split.cpp) diff --git a/cpp/benchmarks/io/utilities/batched_memset_bench.cpp b/cpp/benchmarks/io/utilities/batched_memset_bench.cpp deleted file mode 100644 index 2905895a63b..00000000000 --- a/cpp/benchmarks/io/utilities/batched_memset_bench.cpp +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include -#include - -#include - -// Size of the data in the benchmark dataframe; chosen to be low enough to allow benchmarks to -// run on most GPUs, but large enough to allow highest throughput -constexpr size_t data_size = 512 << 20; - -void parquet_read_common(cudf::size_type num_rows_to_read, - cudf::size_type num_cols_to_read, - cuio_source_sink_pair& source_sink, - nvbench::state& state) -{ - cudf::io::parquet_reader_options read_opts = - cudf::io::parquet_reader_options::builder(source_sink.make_source_info()); - - auto mem_stats_logger = cudf::memory_stats_logger(); - state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value())); - state.exec( - nvbench::exec_tag::sync | nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { - try_drop_l3_cache(); - - timer.start(); - auto const result = cudf::io::read_parquet(read_opts); - timer.stop(); - - CUDF_EXPECTS(result.tbl->num_columns() == num_cols_to_read, "Unexpected number of columns"); - CUDF_EXPECTS(result.tbl->num_rows() == num_rows_to_read, "Unexpected number of rows"); - }); - - auto const time = state.get_summary("nv/cold/time/gpu/mean").get_float64("value"); - state.add_element_count(static_cast(data_size) / time, "bytes_per_second"); - state.add_buffer_size( - mem_stats_logger.peak_memory_usage(), "peak_memory_usage", "peak_memory_usage"); - state.add_buffer_size(source_sink.size(), "encoded_file_size", "encoded_file_size"); -} - -template -void bench_batched_memset(nvbench::state& state, nvbench::type_list>) -{ - auto const d_type = get_type_or_group(static_cast(DataType)); - auto const num_cols = static_cast(state.get_int64("num_cols")); - auto const cardinality = static_cast(state.get_int64("cardinality")); - auto const run_length = static_cast(state.get_int64("run_length")); - auto const source_type = retrieve_io_type_enum(state.get_string("io_type")); - auto const compression = cudf::io::compression_type::NONE; - cuio_source_sink_pair source_sink(source_type); - auto const tbl = - create_random_table(cycle_dtypes(d_type, num_cols), - table_size_bytes{data_size}, - data_profile_builder().cardinality(cardinality).avg_run_length(run_length)); - auto const view = tbl->view(); - - cudf::io::parquet_writer_options write_opts = - cudf::io::parquet_writer_options::builder(source_sink.make_sink_info(), view) - .compression(compression); - cudf::io::write_parquet(write_opts); - auto const num_rows = view.num_rows(); - - parquet_read_common(num_rows, num_cols, source_sink, state); -} - -using d_type_list = nvbench::enum_type_list; - -NVBENCH_BENCH_TYPES(bench_batched_memset, NVBENCH_TYPE_AXES(d_type_list)) - .set_name("batched_memset") - .set_type_axes_names({"data_type"}) - .add_int64_axis("num_cols", {1000}) - .add_string_axis("io_type", {"DEVICE_BUFFER"}) - .set_min_samples(4) - .add_int64_axis("cardinality", {0, 1000}) - .add_int64_axis("run_length", {1, 32}); diff --git a/cpp/cmake/thirdparty/get_nvcomp.cmake b/cpp/cmake/thirdparty/get_nvcomp.cmake index 41bbf44abc8..33b1b45fb44 100644 --- a/cpp/cmake/thirdparty/get_nvcomp.cmake +++ b/cpp/cmake/thirdparty/get_nvcomp.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -16,11 +16,11 @@ function(find_and_configure_nvcomp) include(${rapids-cmake-dir}/cpm/nvcomp.cmake) - rapids_cpm_nvcomp( - BUILD_EXPORT_SET cudf-exports - INSTALL_EXPORT_SET cudf-exports - USE_PROPRIETARY_BINARY ${CUDF_USE_PROPRIETARY_NVCOMP} - ) + set(export_args) + if(CUDF_EXPORT_NVCOMP) + set(export_args BUILD_EXPORT_SET cudf-exports INSTALL_EXPORT_SET cudf-exports) + endif() + rapids_cpm_nvcomp(${export_args} USE_PROPRIETARY_BINARY ${CUDF_USE_PROPRIETARY_NVCOMP}) # Per-thread default stream if(TARGET nvcomp AND CUDF_USE_PER_THREAD_DEFAULT_STREAM) diff --git a/cpp/doxygen/regex.md b/cpp/doxygen/regex.md index 6d1c91a5752..6902b1948bd 100644 --- a/cpp/doxygen/regex.md +++ b/cpp/doxygen/regex.md @@ -8,6 +8,7 @@ This page specifies which regular expression (regex) features are currently supp - cudf::strings::extract() - cudf::strings::extract_all_record() - cudf::strings::findall() +- cudf::strings::find_re() - cudf::strings::replace_re() - cudf::strings::replace_with_backrefs() - cudf::strings::split_re() diff --git a/cpp/include/cudf/detail/utilities/batched_memcpy.hpp b/cpp/include/cudf/detail/utilities/batched_memcpy.hpp new file mode 100644 index 00000000000..ed0ab9e6e5b --- /dev/null +++ b/cpp/include/cudf/detail/utilities/batched_memcpy.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include +#include +#include + +namespace CUDF_EXPORT cudf { +namespace detail { + +/** + * @brief A helper function that copies a vector of vectors from source to destination addresses in + * a batched manner. + * + * @tparam SrcIterator **[inferred]** The type of device-accessible source addresses iterator + * @tparam DstIterator **[inferred]** The type of device-accessible destination address iterator + * @tparam SizeIterator **[inferred]** The type of device-accessible buffer size iterator + * + * @param src_iter Device-accessible iterator to source addresses + * @param dst_iter Device-accessible iterator to destination addresses + * @param size_iter Device-accessible iterator to the buffer sizes (in bytes) + * @param num_buffs Number of buffers to be copied + * @param stream CUDA stream to use + */ +template +void batched_memcpy_async(SrcIterator src_iter, + DstIterator dst_iter, + SizeIterator size_iter, + size_t num_buffs, + rmm::cuda_stream_view stream) +{ + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, num_buffs, stream.value()); + + rmm::device_buffer d_temp_storage{temp_storage_bytes, stream.value()}; + + cub::DeviceMemcpy::Batched(d_temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + num_buffs, + stream.value()); +} + +} // namespace detail +} // namespace CUDF_EXPORT cudf diff --git a/cpp/include/cudf/io/detail/batched_memset.hpp b/cpp/include/cudf/detail/utilities/batched_memset.hpp similarity index 98% rename from cpp/include/cudf/io/detail/batched_memset.hpp rename to cpp/include/cudf/detail/utilities/batched_memset.hpp index 1c74be4a9fe..75f738f7529 100644 --- a/cpp/include/cudf/io/detail/batched_memset.hpp +++ b/cpp/include/cudf/detail/utilities/batched_memset.hpp @@ -28,7 +28,7 @@ #include namespace CUDF_EXPORT cudf { -namespace io::detail { +namespace detail { /** * @brief A helper function that takes in a vector of device spans and memsets them to the @@ -78,5 +78,5 @@ void batched_memset(std::vector> const& bufs, d_temp_storage.data(), temp_storage_bytes, iter_in, iter_out, sizes, num_bufs, stream); } -} // namespace io::detail +} // namespace detail } // namespace CUDF_EXPORT cudf diff --git a/cpp/include/cudf/strings/findall.hpp b/cpp/include/cudf/strings/findall.hpp index c6b9bc7e58a..867764b6d9a 100644 --- a/cpp/include/cudf/strings/findall.hpp +++ b/cpp/include/cudf/strings/findall.hpp @@ -66,6 +66,35 @@ std::unique_ptr findall( rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); +/** + * @brief Returns the starting character index of the first match for the given pattern + * in each row of the input column + * + * @code{.pseudo} + * Example: + * s = ["bunny", "rabbit", "hare", "dog"] + * p = regex_program::create("[be]") + * r = find_re(s, p) + * r is now [0, 2, 3, -1] + * @endcode + * + * A null output row occurs if the corresponding input row is null. + * A -1 is returned for rows that do not contain a match. + * + * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. + * + * @param input Strings instance for this operation + * @param prog Regex program instance + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New column of integers + */ +std::unique_ptr find_re( + strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); + /** @} */ // end of doxygen group } // namespace strings } // namespace CUDF_EXPORT cudf diff --git a/cpp/src/io/orc/stripe_enc.cu b/cpp/src/io/orc/stripe_enc.cu index 5c70e35fd2e..ed0b6969154 100644 --- a/cpp/src/io/orc/stripe_enc.cu +++ b/cpp/src/io/orc/stripe_enc.cu @@ -20,6 +20,8 @@ #include "orc_gpu.hpp" #include +#include +#include #include #include #include @@ -1087,37 +1089,42 @@ CUDF_KERNEL void __launch_bounds__(block_size) /** * @brief Merge chunked column data into a single contiguous stream * - * @param[in,out] strm_desc StripeStream device array [stripe][stream] - * @param[in,out] streams List of encoder chunk streams [column][rowgroup] + * @param[in] strm_desc StripeStream device array [stripe][stream] + * @param[in] streams List of encoder chunk streams [column][rowgroup] + * @param[out] srcs List of source encoder chunk stream data addresses + * @param[out] dsts List of destination StripeStream data addresses + * @param[out] sizes List of stream sizes in bytes */ // blockDim {compact_streams_block_size,1,1} CUDF_KERNEL void __launch_bounds__(compact_streams_block_size) - gpuCompactOrcDataStreams(device_2dspan strm_desc, - device_2dspan streams) + gpuInitBatchedMemcpy(device_2dspan strm_desc, + device_2dspan streams, + device_span srcs, + device_span dsts, + device_span sizes) { - __shared__ __align__(16) StripeStream ss; - - auto const stripe_id = blockIdx.x; + auto const stripe_id = cudf::detail::grid_1d::global_thread_id(); auto const stream_id = blockIdx.y; - auto const t = threadIdx.x; + if (stripe_id >= strm_desc.size().first) { return; } - if (t == 0) { ss = strm_desc[stripe_id][stream_id]; } - __syncthreads(); + auto const out_id = stream_id * strm_desc.size().first + stripe_id; + StripeStream ss = strm_desc[stripe_id][stream_id]; if (ss.data_ptr == nullptr) { return; } auto const cid = ss.stream_type; auto dst_ptr = ss.data_ptr; for (auto group = ss.first_chunk_id; group < ss.first_chunk_id + ss.num_chunks; ++group) { + auto const out_id = stream_id * streams.size().second + group; + srcs[out_id] = streams[ss.column_id][group].data_ptrs[cid]; + dsts[out_id] = dst_ptr; + + // Also update the stream here, data will be copied in a separate kernel + streams[ss.column_id][group].data_ptrs[cid] = dst_ptr; + auto const len = streams[ss.column_id][group].lengths[cid]; - if (len > 0) { - auto const src_ptr = streams[ss.column_id][group].data_ptrs[cid]; - for (uint32_t i = t; i < len; i += blockDim.x) { - dst_ptr[i] = src_ptr[i]; - } - __syncthreads(); - } - if (t == 0) { streams[ss.column_id][group].data_ptrs[cid] = dst_ptr; } + // len is the size (in bytes) of the current stream. + sizes[out_id] = len; dst_ptr += len; } } @@ -1325,9 +1332,26 @@ void CompactOrcDataStreams(device_2dspan strm_desc, device_2dspan enc_streams, rmm::cuda_stream_view stream) { + auto const num_rowgroups = enc_streams.size().second; + auto const num_streams = strm_desc.size().second; + auto const num_stripes = strm_desc.size().first; + auto const num_chunks = num_rowgroups * num_streams; + auto srcs = cudf::detail::make_zeroed_device_uvector_async( + num_chunks, stream, rmm::mr::get_current_device_resource()); + auto dsts = cudf::detail::make_zeroed_device_uvector_async( + num_chunks, stream, rmm::mr::get_current_device_resource()); + auto lengths = cudf::detail::make_zeroed_device_uvector_async( + num_chunks, stream, rmm::mr::get_current_device_resource()); + dim3 dim_block(compact_streams_block_size, 1); - dim3 dim_grid(strm_desc.size().first, strm_desc.size().second); - gpuCompactOrcDataStreams<<>>(strm_desc, enc_streams); + dim3 dim_grid(cudf::util::div_rounding_up_unsafe(num_stripes, compact_streams_block_size), + strm_desc.size().second); + gpuInitBatchedMemcpy<<>>( + strm_desc, enc_streams, srcs, dsts, lengths); + + // Copy streams in a batched manner. + cudf::detail::batched_memcpy_async( + srcs.begin(), dsts.begin(), lengths.begin(), lengths.size(), stream); } std::optional CompressOrcDataStreams( diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index e0d50d7ccf9..b3276c81c1f 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -17,6 +17,8 @@ #include "page_data.cuh" #include "page_decode.cuh" +#include + #include #include @@ -466,4 +468,28 @@ void __host__ DecodeSplitPageData(cudf::detail::hostdevice_span pages, } } +void WriteFinalOffsets(host_span offsets, + host_span buff_addrs, + rmm::cuda_stream_view stream) +{ + // Copy offsets to device and create an iterator + auto d_src_data = cudf::detail::make_device_uvector_async( + offsets, stream, cudf::get_current_device_resource_ref()); + // Iterator for the source (scalar) data + auto src_iter = cudf::detail::make_counting_transform_iterator( + static_cast(0), + cuda::proclaim_return_type( + [src = d_src_data.begin()] __device__(std::size_t i) { return src + i; })); + + // Copy buffer addresses to device and create an iterator + auto d_dst_addrs = cudf::detail::make_device_uvector_async( + buff_addrs, stream, cudf::get_current_device_resource_ref()); + // size_iter is simply a constant iterator of sizeof(size_type) bytes. + auto size_iter = thrust::make_constant_iterator(sizeof(size_type)); + + // Copy offsets to buffers in batched manner. + cudf::detail::batched_memcpy_async( + src_iter, d_dst_addrs.begin(), size_iter, offsets.size(), stream); +} + } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index e631e12119d..a8ba3a969ce 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -797,6 +797,18 @@ void DecodeSplitPageData(cudf::detail::hostdevice_span pages, kernel_error::pointer error_code, rmm::cuda_stream_view stream); +/** + * @brief Writes the final offsets to the corresponding list and string buffer end addresses in a + * batched manner. + * + * @param offsets Host span of final offsets + * @param buff_addrs Host span of corresponding output col buffer end addresses + * @param stream CUDA stream to use + */ +void WriteFinalOffsets(host_span offsets, + host_span buff_addrs, + rmm::cuda_stream_view stream); + /** * @brief Launches kernel for reading the string column data stored in the pages * diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 7d817bde7af..1b69ccb7742 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -371,13 +371,15 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num CUDF_FAIL("Parquet data decode failed with code(s) " + kernel_error::to_string(error)); } - // for list columns, add the final offset to every offset buffer. - // TODO : make this happen in more efficiently. Maybe use thrust::for_each - // on each buffer. + // For list and string columns, add the final offset to every offset buffer. // Note : the reason we are doing this here instead of in the decode kernel is // that it is difficult/impossible for a given page to know that it is writing the very // last value that should then be followed by a terminator (because rows can span // page boundaries). + std::vector out_buffers; + std::vector final_offsets; + out_buffers.reserve(_input_columns.size()); + final_offsets.reserve(_input_columns.size()); for (size_t idx = 0; idx < _input_columns.size(); idx++) { input_column_info const& input_col = _input_columns[idx]; @@ -393,25 +395,21 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num // the final offset for a list at level N is the size of it's child size_type const offset = child.type.id() == type_id::LIST ? child.size - 1 : child.size; - CUDF_CUDA_TRY(cudaMemcpyAsync(static_cast(out_buf.data()) + (out_buf.size - 1), - &offset, - sizeof(size_type), - cudaMemcpyDefault, - _stream.value())); + out_buffers.emplace_back(static_cast(out_buf.data()) + (out_buf.size - 1)); + final_offsets.emplace_back(offset); out_buf.user_data |= PARQUET_COLUMN_BUFFER_FLAG_LIST_TERMINATED; } else if (out_buf.type.id() == type_id::STRING) { // need to cap off the string offsets column auto const sz = static_cast(col_string_sizes[idx]); if (sz <= strings::detail::get_offset64_threshold()) { - CUDF_CUDA_TRY(cudaMemcpyAsync(static_cast(out_buf.data()) + out_buf.size, - &sz, - sizeof(size_type), - cudaMemcpyDefault, - _stream.value())); + out_buffers.emplace_back(static_cast(out_buf.data()) + out_buf.size); + final_offsets.emplace_back(sz); } } } } + // Write the final offsets for list and string columns in a batched manner + WriteFinalOffsets(final_offsets, out_buffers, _stream); // update null counts in the final column buffers for (size_t idx = 0; idx < subpass.pages.size(); idx++) { diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 3763c2e8e6d..8cab68ea721 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -19,9 +19,9 @@ #include #include +#include #include #include -#include #include #include @@ -1656,9 +1656,9 @@ void reader::impl::allocate_columns(read_mode mode, size_t skip_rows, size_t num } } - cudf::io::detail::batched_memset(memset_bufs, static_cast(0), _stream); + cudf::detail::batched_memset(memset_bufs, static_cast(0), _stream); // Need to set null mask bufs to all high bits - cudf::io::detail::batched_memset( + cudf::detail::batched_memset( nullmask_bufs, std::numeric_limits::max(), _stream); } diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index d8c1b50a94b..21708e48a25 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -126,6 +126,43 @@ std::unique_ptr findall(strings_column_view const& input, mr); } +namespace { +struct find_re_fn { + column_device_view d_strings; + + __device__ size_type operator()(size_type const idx, + reprog_device const prog, + int32_t const thread_idx) const + { + if (d_strings.is_null(idx)) { return 0; } + auto const d_str = d_strings.element(idx); + + auto const result = prog.find(thread_idx, d_str, d_str.begin()); + return result.has_value() ? result.value().first : -1; + } +}; +} // namespace + +std::unique_ptr find_re(strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto results = make_numeric_column(data_type{type_to_id()}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + if (input.is_empty()) { return results; } + + auto d_results = results->mutable_view().data(); + auto d_prog = regex_device_builder::create_prog_device(prog, stream); + auto const d_strings = column_device_view::create(input.parent(), stream); + launch_transform_kernel(find_re_fn{*d_strings}, *d_prog, d_results, input.size(), stream); + + return results; +} } // namespace detail // external API @@ -139,5 +176,14 @@ std::unique_ptr findall(strings_column_view const& input, return detail::findall(input, prog, stream, mr); } +std::unique_ptr find_re(strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + return detail::find_re(input, prog, stream, mr); +} + } // namespace strings } // namespace cudf diff --git a/cpp/src/text/generate_ngrams.cu b/cpp/src/text/generate_ngrams.cu index a87ecb81b9d..997b0278fe2 100644 --- a/cpp/src/text/generate_ngrams.cu +++ b/cpp/src/text/generate_ngrams.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,9 @@ namespace nvtext { namespace detail { namespace { +// long strings threshold found with benchmarking +constexpr cudf::size_type AVG_CHAR_BYTES_THRESHOLD = 64; + /** * @brief Generate ngrams from strings column. * @@ -173,33 +177,39 @@ constexpr cudf::thread_index_type bytes_per_thread = 4; /** * @brief Counts the number of ngrams in each row of the given strings column * - * Each warp processes a single string. + * Each warp/thread processes a single string. * Formula is `count = max(0,str.length() - ngrams + 1)` * If a string has less than ngrams characters, its count is 0. */ CUDF_KERNEL void count_char_ngrams_kernel(cudf::column_device_view const d_strings, cudf::size_type ngrams, + cudf::size_type tile_size, cudf::size_type* d_counts) { auto const idx = cudf::detail::grid_1d::global_thread_id(); - auto const str_idx = idx / cudf::detail::warp_size; + auto const str_idx = idx / tile_size; if (str_idx >= d_strings.size()) { return; } if (d_strings.is_null(str_idx)) { d_counts[str_idx] = 0; return; } + auto const d_str = d_strings.element(str_idx); + if (tile_size == 1) { + d_counts[str_idx] = cuda::std::max(0, (d_str.length() + 1 - ngrams)); + return; + } + namespace cg = cooperative_groups; auto const warp = cg::tiled_partition(cg::this_thread_block()); - auto const d_str = d_strings.element(str_idx); - auto const end = d_str.data() + d_str.size_bytes(); + auto const end = d_str.data() + d_str.size_bytes(); auto const lane_idx = warp.thread_rank(); cudf::size_type count = 0; for (auto itr = d_str.data() + (lane_idx * bytes_per_thread); itr < end; - itr += cudf::detail::warp_size * bytes_per_thread) { + itr += tile_size * bytes_per_thread) { for (auto s = itr; (s < (itr + bytes_per_thread)) && (s < end); ++s) { count += static_cast(cudf::strings::detail::is_begin_utf8_char(*s)); } @@ -256,19 +266,27 @@ std::unique_ptr generate_character_ngrams(cudf::strings_column_vie "Parameter ngrams should be an integer value of 2 or greater", std::invalid_argument); - auto const strings_count = input.size(); - if (strings_count == 0) { // if no strings, return an empty column - return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + if (input.is_empty()) { // if no strings, return an empty column + return cudf::lists::detail::make_empty_lists_column( + cudf::data_type{cudf::type_id::STRING}, stream, mr); + } + if (input.size() == input.null_count()) { + return cudf::lists::detail::make_all_nulls_lists_column( + input.size(), cudf::data_type{cudf::type_id::STRING}, stream, mr); } auto const d_strings = cudf::column_device_view::create(input.parent(), stream); auto [offsets, total_ngrams] = [&] { - auto counts = rmm::device_uvector(input.size(), stream); - auto const num_blocks = cudf::util::div_rounding_up_safe( - static_cast(input.size()) * cudf::detail::warp_size, block_size); - count_char_ngrams_kernel<<>>( - *d_strings, ngrams, counts.data()); + auto counts = rmm::device_uvector(input.size(), stream); + auto const avg_char_bytes = (input.chars_size(stream) / (input.size() - input.null_count())); + auto const tile_size = (avg_char_bytes < AVG_CHAR_BYTES_THRESHOLD) + ? 1 // thread per row + : cudf::detail::warp_size; // warp per row + auto const grid = cudf::detail::grid_1d( + static_cast(input.size()) * tile_size, block_size); + count_char_ngrams_kernel<<>>( + *d_strings, ngrams, tile_size, counts.data()); return cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr); }(); auto d_offsets = offsets->view().data(); @@ -277,8 +295,8 @@ std::unique_ptr generate_character_ngrams(cudf::strings_column_vie "Insufficient number of characters in each string to generate ngrams"); character_ngram_generator_fn generator{*d_strings, ngrams, d_offsets}; - auto [offsets_column, chars] = cudf::strings::detail::make_strings_children( - generator, strings_count, total_ngrams, stream, mr); + auto [offsets_column, chars] = + cudf::strings::detail::make_strings_children(generator, input.size(), total_ngrams, stream, mr); auto output = cudf::make_strings_column( total_ngrams, std::move(offsets_column), chars.release(), 0, rmm::device_buffer{}); @@ -368,7 +386,7 @@ std::unique_ptr hash_character_ngrams(cudf::strings_column_view co auto [offsets, total_ngrams] = [&] { auto counts = rmm::device_uvector(input.size(), stream); count_char_ngrams_kernel<<>>( - *d_strings, ngrams, counts.data()); + *d_strings, ngrams, cudf::detail::warp_size, counts.data()); return cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr); }(); auto d_offsets = offsets->view().data(); diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b67d922d377..4596ec65ce7 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -385,6 +385,8 @@ ConfigureTest( # * utilities tests ------------------------------------------------------------------------------- ConfigureTest( UTILITIES_TEST + utilities_tests/batched_memcpy_tests.cu + utilities_tests/batched_memset_tests.cu utilities_tests/column_debug_tests.cpp utilities_tests/column_utilities_tests.cpp utilities_tests/column_wrapper_tests.cpp @@ -395,7 +397,6 @@ ConfigureTest( utilities_tests/pinned_memory_tests.cpp utilities_tests/type_check_tests.cpp utilities_tests/type_list_tests.cpp - utilities_tests/batched_memset_tests.cu ) # ################################################################################################## diff --git a/cpp/tests/streams/strings/find_test.cpp b/cpp/tests/streams/strings/find_test.cpp index 52839c6fc9f..e5a1ee0988c 100644 --- a/cpp/tests/streams/strings/find_test.cpp +++ b/cpp/tests/streams/strings/find_test.cpp @@ -46,4 +46,5 @@ TEST_F(StringsFindTest, Find) auto const pattern = std::string("[a-z]"); auto const prog = cudf::strings::regex_program::create(pattern); cudf::strings::findall(view, *prog, cudf::test::get_default_stream()); + cudf::strings::find_re(view, *prog, cudf::test::get_default_stream()); } diff --git a/cpp/tests/strings/findall_tests.cpp b/cpp/tests/strings/findall_tests.cpp index 73da4d081e2..4821a7fa999 100644 --- a/cpp/tests/strings/findall_tests.cpp +++ b/cpp/tests/strings/findall_tests.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -149,6 +150,22 @@ TEST_F(StringsFindallTests, LargeRegex) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); } +TEST_F(StringsFindallTests, FindTest) +{ + auto const valids = cudf::test::iterators::null_at(5); + cudf::test::strings_column_wrapper input( + {"3A", "May4", "Jan2021", "March", "A9BC", "", "", "abcdef ghijklm 12345"}, valids); + auto sv = cudf::strings_column_view(input); + + auto pattern = std::string("\\d+"); + + auto prog = cudf::strings::regex_program::create(pattern); + auto results = cudf::strings::find_re(sv, *prog); + auto expected = + cudf::test::fixed_width_column_wrapper({0, 3, 3, -1, 1, 0, -1, 15}, valids); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); +} + TEST_F(StringsFindallTests, NoMatches) { cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"}); @@ -169,10 +186,16 @@ TEST_F(StringsFindallTests, EmptyTest) auto prog = cudf::strings::regex_program::create(pattern); cudf::test::strings_column_wrapper input; - auto sv = cudf::strings_column_view(input); - auto results = cudf::strings::findall(sv, *prog); - - using LCW = cudf::test::lists_column_wrapper; - LCW expected; - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + auto sv = cudf::strings_column_view(input); + { + auto results = cudf::strings::findall(sv, *prog); + using LCW = cudf::test::lists_column_wrapper; + LCW expected; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + } + { + auto results = cudf::strings::find_re(sv, *prog); + auto expected = cudf::test::fixed_width_column_wrapper{}; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + } } diff --git a/cpp/tests/utilities_tests/batched_memcpy_tests.cu b/cpp/tests/utilities_tests/batched_memcpy_tests.cu new file mode 100644 index 00000000000..98657f8e224 --- /dev/null +++ b/cpp/tests/utilities_tests/batched_memcpy_tests.cu @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +template +struct BatchedMemcpyTest : public cudf::test::BaseFixture {}; + +TEST(BatchedMemcpyTest, BasicTest) +{ + using T1 = int64_t; + + // Device init + auto stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + + // Buffer lengths (in number of elements) + std::vector const h_lens{ + 50000, 4, 1000, 0, 250000, 1, 100, 8000, 0, 1, 100, 1000, 10000, 100000, 0, 1, 100000}; + + // Total number of buffers + auto const num_buffs = h_lens.size(); + + // Exclusive sum of buffer lengths for pointers + std::vector h_lens_excl_sum(num_buffs); + std::exclusive_scan(h_lens.begin(), h_lens.end(), h_lens_excl_sum.begin(), 0); + + // Corresponding buffer sizes (in bytes) + std::vector h_sizes_bytes; + h_sizes_bytes.reserve(num_buffs); + std::transform( + h_lens.cbegin(), h_lens.cend(), std::back_inserter(h_sizes_bytes), [&](auto& size) { + return size * sizeof(T1); + }); + + // Initialize random engine + auto constexpr seed = 0xcead; + std::mt19937 engine{seed}; + using uniform_distribution = + typename std::conditional_t, + std::bernoulli_distribution, + std::conditional_t, + std::uniform_real_distribution, + std::uniform_int_distribution>>; + uniform_distribution dist{}; + + // Generate a src vector of random data vectors + std::vector> h_sources; + h_sources.reserve(num_buffs); + std::transform(h_lens.begin(), h_lens.end(), std::back_inserter(h_sources), [&](auto size) { + std::vector data(size); + std::generate_n(data.begin(), size, [&]() { return T1{dist(engine)}; }); + return data; + }); + // Copy the vectors to device + std::vector> h_device_vecs; + h_device_vecs.reserve(h_sources.size()); + std::transform( + h_sources.begin(), h_sources.end(), std::back_inserter(h_device_vecs), [stream, mr](auto& vec) { + return cudf::detail::make_device_uvector_async(vec, stream, mr); + }); + // Pointers to the source vectors + std::vector h_src_ptrs; + h_src_ptrs.reserve(h_sources.size()); + std::transform( + h_device_vecs.begin(), h_device_vecs.end(), std::back_inserter(h_src_ptrs), [](auto& vec) { + return static_cast(vec.data()); + }); + // Copy the source data pointers to device + auto d_src_ptrs = cudf::detail::make_device_uvector_async(h_src_ptrs, stream, mr); + + // Total number of elements in all buffers + auto const total_buff_len = std::accumulate(h_lens.cbegin(), h_lens.cend(), 0); + + // Create one giant buffer for destination + auto d_dst_data = cudf::detail::make_zeroed_device_uvector_async(total_buff_len, stream, mr); + // Pointers to destination buffers within the giant destination buffer + std::vector h_dst_ptrs(num_buffs); + std::for_each(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(num_buffs), + [&](auto i) { return h_dst_ptrs[i] = d_dst_data.data() + h_lens_excl_sum[i]; }); + // Copy destination data pointers to device + auto d_dst_ptrs = cudf::detail::make_device_uvector_async(h_dst_ptrs, stream, mr); + + // Copy buffer size iterators (in bytes) to device + auto d_sizes_bytes = cudf::detail::make_device_uvector_async(h_sizes_bytes, stream, mr); + + // Run the batched memcpy + cudf::detail::batched_memcpy_async( + d_src_ptrs.begin(), d_dst_ptrs.begin(), d_sizes_bytes.begin(), num_buffs, stream); + + // Expected giant destination buffer after the memcpy + std::vector expected_buffer; + expected_buffer.reserve(total_buff_len); + std::for_each(h_sources.cbegin(), h_sources.cend(), [&expected_buffer](auto& source) { + expected_buffer.insert(expected_buffer.end(), source.begin(), source.end()); + }); + + // Copy over the result destination buffer to host and synchronize the stream + auto result_dst_buffer = + cudf::detail::make_std_vector_sync(cudf::device_span(d_dst_data), stream); + + // Check if both vectors are equal + EXPECT_TRUE( + std::equal(expected_buffer.begin(), expected_buffer.end(), result_dst_buffer.begin())); +} diff --git a/cpp/tests/utilities_tests/batched_memset_tests.cu b/cpp/tests/utilities_tests/batched_memset_tests.cu index bed0f40d70e..0eeb7b95318 100644 --- a/cpp/tests/utilities_tests/batched_memset_tests.cu +++ b/cpp/tests/utilities_tests/batched_memset_tests.cu @@ -18,8 +18,8 @@ #include #include +#include #include -#include #include #include #include @@ -78,7 +78,7 @@ TEST(MultiBufferTestIntegral, BasicTest1) }); // Function Call - cudf::io::detail::batched_memset(memset_bufs, uint64_t{0}, stream); + cudf::detail::batched_memset(memset_bufs, uint64_t{0}, stream); // Set all buffer regions to 0 for expected comparison std::for_each( diff --git a/dependencies.yaml b/dependencies.yaml index ed36a23e5c3..b192158c4ea 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -15,6 +15,7 @@ files: - depends_on_cupy - depends_on_libkvikio - depends_on_librmm + - depends_on_nvcomp - depends_on_rmm - develop - docs @@ -152,6 +153,13 @@ files: - build_cpp - depends_on_libkvikio - depends_on_librmm + py_run_libcudf: + output: pyproject + pyproject_dir: python/libcudf + extras: + table: project + includes: + - depends_on_nvcomp py_build_pylibcudf: output: pyproject pyproject_dir: python/pylibcudf @@ -367,9 +375,27 @@ dependencies: - fmt>=11.0.2,<12 - flatbuffers==24.3.25 - librdkafka>=2.5.0,<2.6.0a0 + - spdlog>=1.14.1,<1.15 + depends_on_nvcomp: + common: + - output_types: conda + packages: # Align nvcomp version with rapids-cmake - nvcomp==4.0.1 - - spdlog>=1.14.1,<1.15 + specific: + - output_types: [requirements, pyproject] + matrices: + - matrix: + cuda: "12.*" + packages: + - nvidia-nvcomp-cu12==4.0.1 + - matrix: + cuda: "11.*" + packages: + - nvidia-nvcomp-cu11==4.0.1 + - matrix: + packages: + - nvidia-nvcomp==4.0.1 rapids_build_skbuild: common: - output_types: [conda, requirements, pyproject] diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst index e73ea3370ec..48dc8a13c3e 100644 --- a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst @@ -11,10 +11,13 @@ strings find find_multiple findall + padding regex_flags regex_program repeat replace + side_type slice split strip + wrap diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/padding.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/padding.rst new file mode 100644 index 00000000000..5b417024fd5 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/padding.rst @@ -0,0 +1,6 @@ +======= +padding +======= + +.. automodule:: pylibcudf.strings.padding + :members: diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/side_type.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/side_type.rst new file mode 100644 index 00000000000..d5aef9c4f75 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/side_type.rst @@ -0,0 +1,6 @@ +========= +side_type +========= + +.. automodule:: pylibcudf.strings.side_type + :members: diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/wrap.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/wrap.rst new file mode 100644 index 00000000000..bd825f78568 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/wrap.rst @@ -0,0 +1,6 @@ +==== +wrap +==== + +.. automodule:: pylibcudf.strings.wrap + :members: diff --git a/python/cudf/cudf/_lib/strings/__init__.py b/python/cudf/cudf/_lib/strings/__init__.py index 4bf8a9b1a8f..e712937f816 100644 --- a/python/cudf/cudf/_lib/strings/__init__.py +++ b/python/cudf/cudf/_lib/strings/__init__.py @@ -71,16 +71,9 @@ startswith_multiple, ) from cudf._lib.strings.find_multiple import find_multiple -from cudf._lib.strings.findall import findall +from cudf._lib.strings.findall import find_re, findall from cudf._lib.strings.json import GetJsonObjectOptions, get_json_object -from cudf._lib.strings.padding import ( - SideType, - center, - ljust, - pad, - rjust, - zfill, -) +from cudf._lib.strings.padding import center, ljust, pad, rjust, zfill from cudf._lib.strings.repeat import repeat_scalar, repeat_sequence from cudf._lib.strings.replace import ( insert, diff --git a/python/cudf/cudf/_lib/strings/findall.pyx b/python/cudf/cudf/_lib/strings/findall.pyx index 0e758d5b322..3e7a504d535 100644 --- a/python/cudf/cudf/_lib/strings/findall.pyx +++ b/python/cudf/cudf/_lib/strings/findall.pyx @@ -23,3 +23,19 @@ def findall(Column source_strings, object pattern, uint32_t flags): prog, ) return Column.from_pylibcudf(plc_result) + + +@acquire_spill_lock() +def find_re(Column source_strings, object pattern, uint32_t flags): + """ + Returns character positions where the pattern first matches + the elements in source_strings. + """ + prog = plc.strings.regex_program.RegexProgram.create( + str(pattern), flags + ) + plc_result = plc.strings.findall.find_re( + source_strings.to_pylibcudf(mode="read"), + prog, + ) + return Column.from_pylibcudf(plc_result) diff --git a/python/cudf/cudf/_lib/strings/padding.pyx b/python/cudf/cudf/_lib/strings/padding.pyx index d0239e91ec3..015a2ebab8a 100644 --- a/python/cudf/cudf/_lib/strings/padding.pyx +++ b/python/cudf/cudf/_lib/strings/padding.pyx @@ -1,64 +1,31 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. - -from libcpp.memory cimport unique_ptr -from libcpp.string cimport string -from libcpp.utility cimport move - from cudf.core.buffer import acquire_spill_lock -from pylibcudf.libcudf.column.column_view cimport column_view from pylibcudf.libcudf.types cimport size_type from cudf._lib.column cimport Column -from enum import IntEnum - -from pylibcudf.libcudf.column.column cimport column -from pylibcudf.libcudf.strings.padding cimport ( - pad as cpp_pad, - zfill as cpp_zfill, -) -from pylibcudf.libcudf.strings.side_type cimport ( - side_type, - underlying_type_t_side_type, -) - - -class SideType(IntEnum): - LEFT = side_type.LEFT - RIGHT = side_type.RIGHT - BOTH = side_type.BOTH +import pylibcudf as plc @acquire_spill_lock() def pad(Column source_strings, size_type width, fill_char, - side=SideType.LEFT): + side=plc.strings.side_type.SideType.LEFT): """ Returns a Column by padding strings in `source_strings` up to the given `width`. Direction of padding is to be specified by `side`. The additional characters being filled can be changed by specifying `fill_char`. """ - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef string f_char = str(fill_char).encode() - - cdef side_type pad_direction = ( - side + plc_result = plc.strings.padding.pad( + source_strings.to_pylibcudf(mode="read"), + width, + side, + fill_char, ) - - with nogil: - c_result = move(cpp_pad( - source_view, - width, - pad_direction, - f_char - )) - - return Column.from_unique_ptr(move(c_result)) + return Column.from_pylibcudf(plc_result) @acquire_spill_lock() @@ -68,19 +35,13 @@ def zfill(Column source_strings, Returns a Column by prepending strings in `source_strings` with '0' characters up to the given `width`. """ - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - with nogil: - c_result = move(cpp_zfill( - source_view, - width - )) - - return Column.from_unique_ptr(move(c_result)) + plc_result = plc.strings.padding.zfill( + source_strings.to_pylibcudf(mode="read"), + width + ) + return Column.from_pylibcudf(plc_result) -@acquire_spill_lock() def center(Column source_strings, size_type width, fill_char): @@ -89,23 +50,9 @@ def center(Column source_strings, in `source_strings` with additional character, `fill_char` up to the given `width`. """ - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef string f_char = str(fill_char).encode() - - with nogil: - c_result = move(cpp_pad( - source_view, - width, - side_type.BOTH, - f_char - )) + return pad(source_strings, width, fill_char, plc.strings.side_type.SideType.BOTH) - return Column.from_unique_ptr(move(c_result)) - -@acquire_spill_lock() def ljust(Column source_strings, size_type width, fill_char): @@ -113,23 +60,9 @@ def ljust(Column source_strings, Returns a Column by filling right side of strings in `source_strings` with additional character, `fill_char` up to the given `width`. """ - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef string f_char = str(fill_char).encode() + return pad(source_strings, width, fill_char, plc.strings.side_type.SideType.RIGHT) - with nogil: - c_result = move(cpp_pad( - source_view, - width, - side_type.RIGHT, - f_char - )) - return Column.from_unique_ptr(move(c_result)) - - -@acquire_spill_lock() def rjust(Column source_strings, size_type width, fill_char): @@ -137,17 +70,4 @@ def rjust(Column source_strings, Returns a Column by filling left side of strings in `source_strings` with additional character, `fill_char` up to the given `width`. """ - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef string f_char = str(fill_char).encode() - - with nogil: - c_result = move(cpp_pad( - source_view, - width, - side_type.LEFT, - f_char - )) - - return Column.from_unique_ptr(move(c_result)) + return pad(source_strings, width, fill_char, plc.strings.side_type.SideType.LEFT) diff --git a/python/cudf/cudf/_lib/strings/strip.pyx b/python/cudf/cudf/_lib/strings/strip.pyx index 38ecb21a94c..982c5a600e7 100644 --- a/python/cudf/cudf/_lib/strings/strip.pyx +++ b/python/cudf/cudf/_lib/strings/strip.pyx @@ -1,18 +1,8 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. -from libcpp.memory cimport unique_ptr -from libcpp.utility cimport move - from cudf.core.buffer import acquire_spill_lock -from pylibcudf.libcudf.column.column cimport column -from pylibcudf.libcudf.column.column_view cimport column_view -from pylibcudf.libcudf.scalar.scalar cimport string_scalar -from pylibcudf.libcudf.strings.side_type cimport side_type -from pylibcudf.libcudf.strings.strip cimport strip as cpp_strip - from cudf._lib.column cimport Column -from cudf._lib.scalar cimport DeviceScalar import pylibcudf as plc @@ -24,15 +14,12 @@ def strip(Column source_strings, The set of characters need be stripped from left and right side can be specified by `py_repl`. """ - - cdef DeviceScalar repl = py_repl.device_value - return Column.from_pylibcudf( - plc.strings.strip.strip( - source_strings.to_pylibcudf(mode="read"), - plc.strings.SideType.BOTH, - repl.c_value - ) + plc_result = plc.strings.strip.strip( + source_strings.to_pylibcudf(mode="read"), + plc.strings.side_type.SideType.BOTH, + py_repl.device_value.c_value, ) + return Column.from_pylibcudf(plc_result) @acquire_spill_lock() @@ -43,24 +30,12 @@ def lstrip(Column source_strings, The set of characters need be stripped from left side can be specified by `py_repl`. """ - - cdef DeviceScalar repl = py_repl.device_value - - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef const string_scalar* scalar_str = ( - repl.get_raw_ptr() + plc_result = plc.strings.strip.strip( + source_strings.to_pylibcudf(mode="read"), + plc.strings.side_type.SideType.LEFT, + py_repl.device_value.c_value, ) - - with nogil: - c_result = move(cpp_strip( - source_view, - side_type.LEFT, - scalar_str[0] - )) - - return Column.from_unique_ptr(move(c_result)) + return Column.from_pylibcudf(plc_result) @acquire_spill_lock() @@ -71,21 +46,9 @@ def rstrip(Column source_strings, The set of characters need be stripped from right side can be specified by `py_repl`. """ - - cdef DeviceScalar repl = py_repl.device_value - - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef const string_scalar* scalar_str = ( - repl.get_raw_ptr() + plc_result = plc.strings.strip.strip( + source_strings.to_pylibcudf(mode="read"), + plc.strings.side_type.SideType.RIGHT, + py_repl.device_value.c_value, ) - - with nogil: - c_result = move(cpp_strip( - source_view, - side_type.RIGHT, - scalar_str[0] - )) - - return Column.from_unique_ptr(move(c_result)) + return Column.from_pylibcudf(plc_result) diff --git a/python/cudf/cudf/_lib/strings/wrap.pyx b/python/cudf/cudf/_lib/strings/wrap.pyx index eed5cf33b10..2b40f01f818 100644 --- a/python/cudf/cudf/_lib/strings/wrap.pyx +++ b/python/cudf/cudf/_lib/strings/wrap.pyx @@ -1,17 +1,13 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. -from libcpp.memory cimport unique_ptr -from libcpp.utility cimport move - from cudf.core.buffer import acquire_spill_lock -from pylibcudf.libcudf.column.column cimport column -from pylibcudf.libcudf.column.column_view cimport column_view -from pylibcudf.libcudf.strings.wrap cimport wrap as cpp_wrap from pylibcudf.libcudf.types cimport size_type from cudf._lib.column cimport Column +import pylibcudf as plc + @acquire_spill_lock() def wrap(Column source_strings, @@ -21,14 +17,8 @@ def wrap(Column source_strings, in the Column to be formatted in paragraphs with length less than a given `width`. """ - - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - with nogil: - c_result = move(cpp_wrap( - source_view, - width - )) - - return Column.from_unique_ptr(move(c_result)) + plc_result = plc.strings.wrap.wrap( + source_strings.to_pylibcudf(mode="read"), + width + ) + return Column.from_pylibcudf(plc_result) diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index d0ea4612a1b..2c9b0baa9b6 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -480,6 +480,11 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn: if dtype == self.dtype: return self + elif isinstance(dtype, pd.DatetimeTZDtype): + raise TypeError( + "Cannot use .astype to convert from timezone-naive dtype to timezone-aware dtype. " + "Use tz_localize instead." + ) return libcudf.unary.cast(self, dtype=dtype) def as_timedelta_column(self, dtype: Dtype) -> None: # type: ignore[override] @@ -940,6 +945,16 @@ def strftime(self, format: str) -> cudf.core.column.StringColumn: def as_string_column(self) -> cudf.core.column.StringColumn: return self._local_time.as_string_column() + def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn: + if isinstance(dtype, pd.DatetimeTZDtype) and dtype != self.dtype: + if dtype.unit != self.time_unit: + # TODO: Doesn't check that new unit is valid. + casted = self._with_type_metadata(dtype) + else: + casted = self + return casted.tz_convert(str(dtype.tz)) + return super().as_datetime_column(dtype) + def get_dt_field(self, field: str) -> ColumnBase: return libcudf.datetime.extract_datetime_component( self._local_time, field diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index da422db5eae..b50e23bd52e 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -11,6 +11,8 @@ import pandas as pd import pyarrow as pa +import pylibcudf as plc + import cudf import cudf.api.types from cudf import _lib as libcudf @@ -2966,7 +2968,7 @@ def pad( raise TypeError(msg) try: - side = libstrings.SideType[side.upper()] + side = plc.strings.side_type.SideType[side.upper()] except KeyError: raise ValueError( "side has to be either one of {'left', 'right', 'both'}" @@ -3624,6 +3626,46 @@ def findall(self, pat: str, flags: int = 0) -> SeriesOrIndex: data = libstrings.findall(self._column, pat, flags) return self._return_or_inplace(data) + def find_re(self, pat: str, flags: int = 0) -> SeriesOrIndex: + """ + Find first occurrence of pattern or regular expression in the + Series/Index. + + Parameters + ---------- + pat : str + Pattern or regular expression. + flags : int, default 0 (no flags) + Flags to pass through to the regex engine (e.g. re.MULTILINE) + + Returns + ------- + Series + A Series of position values where the pattern first matches + each string. + + Examples + -------- + >>> import cudf + >>> s = cudf.Series(['Lion', 'Monkey', 'Rabbit', 'Cat']) + >>> s.str.find_re('[ti]') + 0 1 + 1 -1 + 2 4 + 3 2 + dtype: int32 + """ + if isinstance(pat, re.Pattern): + flags = pat.flags & ~re.U + pat = pat.pattern + if not _is_supported_regex_flags(flags): + raise NotImplementedError( + "Unsupported value for `flags` parameter" + ) + + data = libstrings.find_re(self._column, pat, flags) + return self._return_or_inplace(data) + def find_multiple(self, patterns: SeriesOrIndex) -> cudf.Series: """ Find all first occurrences of patterns in the Series/Index. diff --git a/python/cudf/cudf/tests/series/test_datetimelike.py b/python/cudf/cudf/tests/series/test_datetimelike.py index cea86a5499e..691da224f44 100644 --- a/python/cudf/cudf/tests/series/test_datetimelike.py +++ b/python/cudf/cudf/tests/series/test_datetimelike.py @@ -266,3 +266,25 @@ def test_pandas_compatible_non_zoneinfo_raises(klass): with cudf.option_context("mode.pandas_compatible", True): with pytest.raises(NotImplementedError): cudf.from_pandas(pandas_obj) + + +def test_astype_naive_to_aware_raises(): + ser = cudf.Series([datetime.datetime(2020, 1, 1)]) + with pytest.raises(TypeError): + ser.astype("datetime64[ns, UTC]") + with pytest.raises(TypeError): + ser.to_pandas().astype("datetime64[ns, UTC]") + + +@pytest.mark.parametrize("unit", ["ns", "us"]) +def test_astype_aware_to_aware(unit): + ser = cudf.Series( + [datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc)] + ) + result = ser.astype(f"datetime64[{unit}, US/Pacific]") + expected = ser.to_pandas().astype(f"datetime64[{unit}, US/Pacific]") + zoneinfo_type = pd.DatetimeTZDtype( + expected.dtype.unit, zoneinfo.ZoneInfo(str(expected.dtype.tz)) + ) + expected = ser.astype(zoneinfo_type) + assert_eq(result, expected) diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index cc88cc79769..45143211a11 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -1899,6 +1899,26 @@ def test_string_findall(pat, flags): assert_eq(expected, actual) +@pytest.mark.parametrize( + "pat, flags, pos", + [ + ("Monkey", 0, [-1, 0, -1, -1]), + ("on", 0, [2, 1, -1, 1]), + ("bit", 0, [-1, -1, 3, -1]), + ("on$", 0, [2, -1, -1, -1]), + ("on$", re.MULTILINE, [2, -1, -1, 1]), + ("o.*k", re.DOTALL, [-1, 1, -1, 1]), + ], +) +def test_string_find_re(pat, flags, pos): + test_data = ["Lion", "Monkey", "Rabbit", "Don\nkey"] + gs = cudf.Series(test_data) + + expected = pd.Series(pos, dtype=np.int32) + actual = gs.str.find_re(pat, flags) + assert_eq(expected, actual) + + def test_string_replace_multi(): ps = pd.Series(["hello", "goodbye"]) gs = cudf.Series(["hello", "goodbye"]) diff --git a/python/libcudf/CMakeLists.txt b/python/libcudf/CMakeLists.txt index 0a8f5c4807d..5f9a04d3cee 100644 --- a/python/libcudf/CMakeLists.txt +++ b/python/libcudf/CMakeLists.txt @@ -22,6 +22,8 @@ project( LANGUAGES CXX ) +option(USE_NVCOMP_RUNTIME_WHEEL "Use the nvcomp wheel at runtime instead of the system library" OFF) + # Check if cudf is already available. If so, it is the user's responsibility to ensure that the # CMake package is also available at build time of the Python cudf package. find_package(cudf "${RAPIDS_VERSION}") @@ -39,14 +41,20 @@ set(BUILD_TESTS OFF) set(BUILD_BENCHMARKS OFF) set(CUDF_BUILD_TESTUTIL OFF) set(CUDF_BUILD_STREAMS_TEST_UTIL OFF) +if(USE_NVCOMP_RUNTIME_WHEEL) + set(CUDF_EXPORT_NVCOMP OFF) +endif() set(CUDA_STATIC_RUNTIME ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) add_subdirectory(../../cpp cudf-cpp) -# Ensure other libraries needed by libcudf.so get installed alongside it. -include(cmake/Modules/WheelHelpers.cmake) -install_aliased_imported_targets( - TARGETS cudf nvcomp::nvcomp DESTINATION ${CMAKE_LIBRARY_OUTPUT_DIRECTORY} -) +if(USE_NVCOMP_RUNTIME_WHEEL) + set(rpaths "$ORIGIN/../../nvidia/nvcomp") + set_property( + TARGET cudf + PROPERTY INSTALL_RPATH ${rpaths} + APPEND + ) +endif() diff --git a/python/libcudf/pyproject.toml b/python/libcudf/pyproject.toml index 5bffe9fd96c..84660cbc276 100644 --- a/python/libcudf/pyproject.toml +++ b/python/libcudf/pyproject.toml @@ -37,6 +37,9 @@ classifiers = [ "Programming Language :: C++", "Environment :: GPU :: NVIDIA CUDA", ] +dependencies = [ + "nvidia-nvcomp==4.0.1", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls] Homepage = "https://github.com/rapidsai/cudf" diff --git a/python/pylibcudf/LICENSE b/python/pylibcudf/LICENSE new file mode 120000 index 00000000000..30cff7403da --- /dev/null +++ b/python/pylibcudf/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd index e0a8b776465..0d286c36446 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd @@ -11,3 +11,7 @@ cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] findall( column_view input, regex_program prog) except + + + cdef unique_ptr[column] find_re( + column_view input, + regex_program prog) except + diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/padding.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/padding.pxd index 657fe61eb14..875f8cafd14 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/padding.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/padding.pxd @@ -12,11 +12,11 @@ from pylibcudf.libcudf.types cimport size_type cdef extern from "cudf/strings/padding.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] pad( - column_view source_strings, + column_view input, size_type width, side_type side, string fill_char) except + cdef unique_ptr[column] zfill( - column_view source_strings, + column_view input, size_type width) except + diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd index 019ff3f17ba..e92c5dc1d66 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd @@ -1,12 +1,10 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION. -from libc.stdint cimport int32_t +from libcpp cimport int cdef extern from "cudf/strings/side_type.hpp" namespace "cudf::strings" nogil: - cpdef enum class side_type(int32_t): - LEFT 'cudf::strings::side_type::LEFT' - RIGHT 'cudf::strings::side_type::RIGHT' - BOTH 'cudf::strings::side_type::BOTH' - -ctypedef int32_t underlying_type_t_side_type + cpdef enum class side_type(int): + LEFT + RIGHT + BOTH diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/strip.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/strip.pxd index b0ca771762d..dd527a78e7f 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/strip.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/strip.pxd @@ -10,6 +10,6 @@ from pylibcudf.libcudf.strings.side_type cimport side_type cdef extern from "cudf/strings/strip.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] strip( - column_view source_strings, - side_type stype, + column_view input, + side_type side, string_scalar to_strip) except + diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/wrap.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/wrap.pxd index c0053391328..abc1bd43ad2 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/wrap.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/wrap.pxd @@ -9,5 +9,5 @@ from pylibcudf.libcudf.types cimport size_type cdef extern from "cudf/strings/wrap.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] wrap( - column_view source_strings, + column_view input, size_type width) except + diff --git a/python/pylibcudf/pylibcudf/strings/CMakeLists.txt b/python/pylibcudf/pylibcudf/strings/CMakeLists.txt index d92f806efbe..eeb44d19333 100644 --- a/python/pylibcudf/pylibcudf/strings/CMakeLists.txt +++ b/python/pylibcudf/pylibcudf/strings/CMakeLists.txt @@ -22,6 +22,7 @@ set(cython_sources find.pyx find_multiple.pyx findall.pyx + padding.pyx regex_flags.pyx regex_program.pyx repeat.pyx @@ -30,6 +31,7 @@ set(cython_sources slice.pyx strip.pyx translate.pyx + wrap.pyx ) set(linked_libraries cudf::cudf) diff --git a/python/pylibcudf/pylibcudf/strings/__init__.pxd b/python/pylibcudf/pylibcudf/strings/__init__.pxd index 788e2c99ab1..187ef113073 100644 --- a/python/pylibcudf/pylibcudf/strings/__init__.pxd +++ b/python/pylibcudf/pylibcudf/strings/__init__.pxd @@ -11,13 +11,16 @@ from . cimport ( find, find_multiple, findall, + padding, regex_flags, regex_program, replace, + side_type, slice, split, strip, translate, + wrap, ) from .side_type cimport side_type @@ -39,4 +42,5 @@ __all__ = [ "split", "side_type", "translate", + "wrap", ] diff --git a/python/pylibcudf/pylibcudf/strings/__init__.py b/python/pylibcudf/pylibcudf/strings/__init__.py index bcaeb073d0b..6033cea0625 100644 --- a/python/pylibcudf/pylibcudf/strings/__init__.py +++ b/python/pylibcudf/pylibcudf/strings/__init__.py @@ -11,14 +11,17 @@ find, find_multiple, findall, + padding, regex_flags, regex_program, repeat, replace, + side_type, slice, split, strip, translate, + wrap, ) from .side_type import SideType @@ -40,4 +43,5 @@ "split", "SideType", "translate", + "wrap", ] diff --git a/python/pylibcudf/pylibcudf/strings/findall.pxd b/python/pylibcudf/pylibcudf/strings/findall.pxd index 54afa088141..3c35a9c9aa9 100644 --- a/python/pylibcudf/pylibcudf/strings/findall.pxd +++ b/python/pylibcudf/pylibcudf/strings/findall.pxd @@ -4,4 +4,5 @@ from pylibcudf.column cimport Column from pylibcudf.strings.regex_program cimport RegexProgram +cpdef Column find_re(Column input, RegexProgram pattern) cpdef Column findall(Column input, RegexProgram pattern) diff --git a/python/pylibcudf/pylibcudf/strings/findall.pyx b/python/pylibcudf/pylibcudf/strings/findall.pyx index 3a6b87504b3..5212dc4594d 100644 --- a/python/pylibcudf/pylibcudf/strings/findall.pyx +++ b/python/pylibcudf/pylibcudf/strings/findall.pyx @@ -38,3 +38,35 @@ cpdef Column findall(Column input, RegexProgram pattern): ) return Column.from_libcudf(move(c_result)) + + +cpdef Column find_re(Column input, RegexProgram pattern): + """ + Returns character positions where the pattern first matches + the elements in input strings. + + For details, see :cpp:func:`cudf::strings::find_re` + + Parameters + ---------- + input : Column + Strings instance for this operation + pattern : RegexProgram + Regex pattern + + Returns + ------- + Column + New column of integers + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_findall.find_re( + input.view(), + pattern.c_obj.get()[0] + ) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/strings/padding.pxd b/python/pylibcudf/pylibcudf/strings/padding.pxd new file mode 100644 index 00000000000..a035a5ad187 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/padding.pxd @@ -0,0 +1,11 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.string cimport string +from pylibcudf.column cimport Column +from pylibcudf.libcudf.strings.side_type cimport side_type +from pylibcudf.libcudf.types cimport size_type + + +cpdef Column pad(Column input, size_type width, side_type side, str fill_char) + +cpdef Column zfill(Column input, size_type width) diff --git a/python/pylibcudf/pylibcudf/strings/padding.pyx b/python/pylibcudf/pylibcudf/strings/padding.pyx new file mode 100644 index 00000000000..24daaaa3838 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/padding.pyx @@ -0,0 +1,75 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move +from pylibcudf.column cimport Column +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.strings cimport padding as cpp_padding +from pylibcudf.libcudf.strings.side_type cimport side_type + + +cpdef Column pad(Column input, size_type width, side_type side, str fill_char): + """ + Add padding to each string using a provided character. + + For details, see :cpp:func:`cudf::strings::pad`. + + Parameters + ---------- + input : Column + Strings instance for this operation + width : int + The minimum number of characters for each string. + side : SideType + Where to place the padding characters. + fill_char : str + Single UTF-8 character to use for padding + + Returns + ------- + Column + New column with padded strings. + """ + cdef unique_ptr[column] c_result + cdef string c_fill_char = fill_char.encode("utf-8") + + with nogil: + c_result = move( + cpp_padding.pad( + input.view(), + width, + side, + c_fill_char, + ) + ) + + return Column.from_libcudf(move(c_result)) + +cpdef Column zfill(Column input, size_type width): + """ + Add '0' as padding to the left of each string. + + For details, see :cpp:func:`cudf::strings::zfill`. + + Parameters + ---------- + input : Column + Strings instance for this operation + width : int + The minimum number of characters for each string. + + Returns + ------- + Column + New column of strings. + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_padding.zfill( + input.view(), + width, + ) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/strings/side_type.pxd b/python/pylibcudf/pylibcudf/strings/side_type.pxd index 34b7a580380..34b03e9bc27 100644 --- a/python/pylibcudf/pylibcudf/strings/side_type.pxd +++ b/python/pylibcudf/pylibcudf/strings/side_type.pxd @@ -1,3 +1,2 @@ # Copyright (c) 2024, NVIDIA CORPORATION. - from pylibcudf.libcudf.strings.side_type cimport side_type diff --git a/python/pylibcudf/pylibcudf/strings/side_type.pyx b/python/pylibcudf/pylibcudf/strings/side_type.pyx index acdc7d6ff1f..cf0c770cc11 100644 --- a/python/pylibcudf/pylibcudf/strings/side_type.pyx +++ b/python/pylibcudf/pylibcudf/strings/side_type.pyx @@ -1,4 +1,3 @@ # Copyright (c) 2024, NVIDIA CORPORATION. - from pylibcudf.libcudf.strings.side_type import \ side_type as SideType # no-cython-lint diff --git a/python/pylibcudf/pylibcudf/strings/wrap.pxd b/python/pylibcudf/pylibcudf/strings/wrap.pxd new file mode 100644 index 00000000000..fcc86650acf --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/wrap.pxd @@ -0,0 +1,7 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from pylibcudf.column cimport Column +from pylibcudf.libcudf.types cimport size_type + + +cpdef Column wrap(Column input, size_type width) diff --git a/python/pylibcudf/pylibcudf/strings/wrap.pyx b/python/pylibcudf/pylibcudf/strings/wrap.pyx new file mode 100644 index 00000000000..11e31f54eee --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/wrap.pyx @@ -0,0 +1,42 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move +from pylibcudf.column cimport Column +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.strings cimport wrap as cpp_wrap +from pylibcudf.libcudf.types cimport size_type + + +cpdef Column wrap(Column input, size_type width): + """ + Wraps strings onto multiple lines shorter than `width` by + replacing appropriate white space with + new-line characters (ASCII 0x0A). + + For details, see :cpp:func:`cudf::strings::wrap`. + + Parameters + ---------- + input : Column + String column + + width : int + Maximum character width of a line within each string + + Returns + ------- + Column + Column of wrapped strings + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_wrap.wrap( + input.view(), + width, + ) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_findall.py b/python/pylibcudf/pylibcudf/tests/test_string_findall.py index 994552fa276..debfad92d00 100644 --- a/python/pylibcudf/pylibcudf/tests/test_string_findall.py +++ b/python/pylibcudf/pylibcudf/tests/test_string_findall.py @@ -21,3 +21,20 @@ def test_findall(): type=pa_result.type, ) assert_column_eq(result, expected) + + +def test_find_re(): + arr = pa.array(["bunny", "rabbit", "hare", "dog"]) + pattern = "[eb]" + result = plc.strings.findall.find_re( + plc.interop.from_arrow(arr), + plc.strings.regex_program.RegexProgram.create( + pattern, plc.strings.regex_flags.RegexFlags.DEFAULT + ), + ) + pa_result = plc.interop.to_arrow(result) + expected = pa.array( + [0, 2, 3, -1], + type=pa_result.type, + ) + assert_column_eq(result, expected) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_padding.py b/python/pylibcudf/pylibcudf/tests/test_string_padding.py new file mode 100644 index 00000000000..2ba775d17ae --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_string_padding.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +import pyarrow as pa +import pyarrow.compute as pc +import pylibcudf as plc + + +def test_pad(): + arr = pa.array(["a", "1", None]) + plc_result = plc.strings.padding.pad( + plc.interop.from_arrow(arr), + 2, + plc.strings.side_type.SideType.LEFT, + "!", + ) + result = plc.interop.to_arrow(plc_result) + expected = pa.chunked_array(pc.utf8_lpad(arr, 2, padding="!")) + assert result.equals(expected) + + +def test_zfill(): + arr = pa.array(["a", "1", None]) + plc_result = plc.strings.padding.zfill(plc.interop.from_arrow(arr), 2) + result = plc.interop.to_arrow(plc_result) + expected = pa.chunked_array(pc.utf8_lpad(arr, 2, padding="0")) + assert result.equals(expected) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_wrap.py b/python/pylibcudf/pylibcudf/tests/test_string_wrap.py new file mode 100644 index 00000000000..85abd3a2bae --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_string_wrap.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +import textwrap + +import pyarrow as pa +import pylibcudf as plc +from utils import assert_column_eq + + +def test_wrap(): + pa_array = pa.array( + [ + "the quick brown fox jumped over the lazy brown dog", + "hello, world", + None, + ] + ) + result = plc.strings.wrap.wrap(plc.interop.from_arrow(pa_array), 12) + expected = pa.array( + [ + textwrap.fill(val, 12) if isinstance(val, str) else val + for val in pa_array.to_pylist() + ] + ) + assert_column_eq(expected, result)