From 466e37973d3b9aef4d14a7aa0cd48df0b886300d Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Wed, 2 Oct 2024 20:09:21 -0400 Subject: [PATCH 01/12] Fix performance regression for generate_character_ngrams (#16849) Fixes performance regression in `nvtext::generate_character_ngrams` introduced in #16212. Thread-per-row kernel is faster for smaller strings. Authors: - David Wendt (https://github.com/davidwendt) - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Bradley Dice (https://github.com/bdice) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/16849 --- cpp/src/text/generate_ngrams.cu | 50 ++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/cpp/src/text/generate_ngrams.cu b/cpp/src/text/generate_ngrams.cu index a87ecb81b9d..997b0278fe2 100644 --- a/cpp/src/text/generate_ngrams.cu +++ b/cpp/src/text/generate_ngrams.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,9 @@ namespace nvtext { namespace detail { namespace { +// long strings threshold found with benchmarking +constexpr cudf::size_type AVG_CHAR_BYTES_THRESHOLD = 64; + /** * @brief Generate ngrams from strings column. * @@ -173,33 +177,39 @@ constexpr cudf::thread_index_type bytes_per_thread = 4; /** * @brief Counts the number of ngrams in each row of the given strings column * - * Each warp processes a single string. + * Each warp/thread processes a single string. * Formula is `count = max(0,str.length() - ngrams + 1)` * If a string has less than ngrams characters, its count is 0. */ CUDF_KERNEL void count_char_ngrams_kernel(cudf::column_device_view const d_strings, cudf::size_type ngrams, + cudf::size_type tile_size, cudf::size_type* d_counts) { auto const idx = cudf::detail::grid_1d::global_thread_id(); - auto const str_idx = idx / cudf::detail::warp_size; + auto const str_idx = idx / tile_size; if (str_idx >= d_strings.size()) { return; } if (d_strings.is_null(str_idx)) { d_counts[str_idx] = 0; return; } + auto const d_str = d_strings.element(str_idx); + if (tile_size == 1) { + d_counts[str_idx] = cuda::std::max(0, (d_str.length() + 1 - ngrams)); + return; + } + namespace cg = cooperative_groups; auto const warp = cg::tiled_partition(cg::this_thread_block()); - auto const d_str = d_strings.element(str_idx); - auto const end = d_str.data() + d_str.size_bytes(); + auto const end = d_str.data() + d_str.size_bytes(); auto const lane_idx = warp.thread_rank(); cudf::size_type count = 0; for (auto itr = d_str.data() + (lane_idx * bytes_per_thread); itr < end; - itr += cudf::detail::warp_size * bytes_per_thread) { + itr += tile_size * bytes_per_thread) { for (auto s = itr; (s < (itr + bytes_per_thread)) && (s < end); ++s) { count += static_cast(cudf::strings::detail::is_begin_utf8_char(*s)); } @@ -256,19 +266,27 @@ std::unique_ptr generate_character_ngrams(cudf::strings_column_vie "Parameter ngrams should be an integer value of 2 or greater", std::invalid_argument); - auto const strings_count = input.size(); - if (strings_count == 0) { // if no strings, return an empty column - return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + if (input.is_empty()) { // if no strings, return an empty column + return cudf::lists::detail::make_empty_lists_column( + cudf::data_type{cudf::type_id::STRING}, stream, mr); + } + if (input.size() == input.null_count()) { + return cudf::lists::detail::make_all_nulls_lists_column( + input.size(), cudf::data_type{cudf::type_id::STRING}, stream, mr); } auto const d_strings = cudf::column_device_view::create(input.parent(), stream); auto [offsets, total_ngrams] = [&] { - auto counts = rmm::device_uvector(input.size(), stream); - auto const num_blocks = cudf::util::div_rounding_up_safe( - static_cast(input.size()) * cudf::detail::warp_size, block_size); - count_char_ngrams_kernel<<>>( - *d_strings, ngrams, counts.data()); + auto counts = rmm::device_uvector(input.size(), stream); + auto const avg_char_bytes = (input.chars_size(stream) / (input.size() - input.null_count())); + auto const tile_size = (avg_char_bytes < AVG_CHAR_BYTES_THRESHOLD) + ? 1 // thread per row + : cudf::detail::warp_size; // warp per row + auto const grid = cudf::detail::grid_1d( + static_cast(input.size()) * tile_size, block_size); + count_char_ngrams_kernel<<>>( + *d_strings, ngrams, tile_size, counts.data()); return cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr); }(); auto d_offsets = offsets->view().data(); @@ -277,8 +295,8 @@ std::unique_ptr generate_character_ngrams(cudf::strings_column_vie "Insufficient number of characters in each string to generate ngrams"); character_ngram_generator_fn generator{*d_strings, ngrams, d_offsets}; - auto [offsets_column, chars] = cudf::strings::detail::make_strings_children( - generator, strings_count, total_ngrams, stream, mr); + auto [offsets_column, chars] = + cudf::strings::detail::make_strings_children(generator, input.size(), total_ngrams, stream, mr); auto output = cudf::make_strings_column( total_ngrams, std::move(offsets_column), chars.release(), 0, rmm::device_buffer{}); @@ -368,7 +386,7 @@ std::unique_ptr hash_character_ngrams(cudf::strings_column_view co auto [offsets, total_ngrams] = [&] { auto counts = rmm::device_uvector(input.size(), stream); count_char_ngrams_kernel<<>>( - *d_strings, ngrams, counts.data()); + *d_strings, ngrams, cudf::detail::warp_size, counts.data()); return cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr); }(); auto d_offsets = offsets->view().data(); From 7ae536031effd31d1c7aab63d1af812b0fc2a291 Mon Sep 17 00:00:00 2001 From: Muhammad Haseeb <14217455+mhaseeb123@users.noreply.github.com> Date: Wed, 2 Oct 2024 20:26:17 -0700 Subject: [PATCH 02/12] Batch memcpy the last offsets for output buffers of str and list cols in PQ reader (#16905) This PR adds the capability to batch memcpy the last offsets for the output buffers of string and list columns in PQ reader. This reduces the overhead from several `cudaMemcpyAsync` calls when reading wide strings and/or list columns tables. This optimization was found as well as ORC changes were contributed by @vuule. See this [comment](https://github.com/rapidsai/cudf/pull/16905#issuecomment-2375532577) for performance improvement data and discussion. Authors: - Muhammad Haseeb (https://github.com/mhaseeb123) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/16905 --- cpp/benchmarks/CMakeLists.txt | 5 - .../io/utilities/batched_memset_bench.cpp | 101 ------------- .../cudf/detail/utilities/batched_memcpy.hpp | 67 +++++++++ .../utilities}/batched_memset.hpp | 4 +- cpp/src/io/orc/stripe_enc.cu | 64 +++++--- cpp/src/io/parquet/page_data.cu | 26 ++++ cpp/src/io/parquet/parquet_gpu.hpp | 12 ++ cpp/src/io/parquet/reader_impl.cpp | 24 ++- cpp/src/io/parquet/reader_impl_preprocess.cu | 6 +- cpp/tests/CMakeLists.txt | 3 +- .../utilities_tests/batched_memcpy_tests.cu | 139 ++++++++++++++++++ .../utilities_tests/batched_memset_tests.cu | 4 +- 12 files changed, 308 insertions(+), 147 deletions(-) delete mode 100644 cpp/benchmarks/io/utilities/batched_memset_bench.cpp create mode 100644 cpp/include/cudf/detail/utilities/batched_memcpy.hpp rename cpp/include/cudf/{io/detail => detail/utilities}/batched_memset.hpp (98%) create mode 100644 cpp/tests/utilities_tests/batched_memcpy_tests.cu diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 4113e38dcf4..110b4557840 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -392,11 +392,6 @@ ConfigureNVBench(JSON_READER_NVBENCH io/json/nested_json.cpp io/json/json_reader ConfigureNVBench(JSON_READER_OPTION_NVBENCH io/json/json_reader_option.cpp) ConfigureNVBench(JSON_WRITER_NVBENCH io/json/json_writer.cpp) -# ################################################################################################## -# * multi buffer memset benchmark -# ---------------------------------------------------------------------- -ConfigureNVBench(BATCHED_MEMSET_BENCH io/utilities/batched_memset_bench.cpp) - # ################################################################################################## # * io benchmark --------------------------------------------------------------------- ConfigureNVBench(MULTIBYTE_SPLIT_NVBENCH io/text/multibyte_split.cpp) diff --git a/cpp/benchmarks/io/utilities/batched_memset_bench.cpp b/cpp/benchmarks/io/utilities/batched_memset_bench.cpp deleted file mode 100644 index 2905895a63b..00000000000 --- a/cpp/benchmarks/io/utilities/batched_memset_bench.cpp +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include -#include - -#include - -// Size of the data in the benchmark dataframe; chosen to be low enough to allow benchmarks to -// run on most GPUs, but large enough to allow highest throughput -constexpr size_t data_size = 512 << 20; - -void parquet_read_common(cudf::size_type num_rows_to_read, - cudf::size_type num_cols_to_read, - cuio_source_sink_pair& source_sink, - nvbench::state& state) -{ - cudf::io::parquet_reader_options read_opts = - cudf::io::parquet_reader_options::builder(source_sink.make_source_info()); - - auto mem_stats_logger = cudf::memory_stats_logger(); - state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value())); - state.exec( - nvbench::exec_tag::sync | nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { - try_drop_l3_cache(); - - timer.start(); - auto const result = cudf::io::read_parquet(read_opts); - timer.stop(); - - CUDF_EXPECTS(result.tbl->num_columns() == num_cols_to_read, "Unexpected number of columns"); - CUDF_EXPECTS(result.tbl->num_rows() == num_rows_to_read, "Unexpected number of rows"); - }); - - auto const time = state.get_summary("nv/cold/time/gpu/mean").get_float64("value"); - state.add_element_count(static_cast(data_size) / time, "bytes_per_second"); - state.add_buffer_size( - mem_stats_logger.peak_memory_usage(), "peak_memory_usage", "peak_memory_usage"); - state.add_buffer_size(source_sink.size(), "encoded_file_size", "encoded_file_size"); -} - -template -void bench_batched_memset(nvbench::state& state, nvbench::type_list>) -{ - auto const d_type = get_type_or_group(static_cast(DataType)); - auto const num_cols = static_cast(state.get_int64("num_cols")); - auto const cardinality = static_cast(state.get_int64("cardinality")); - auto const run_length = static_cast(state.get_int64("run_length")); - auto const source_type = retrieve_io_type_enum(state.get_string("io_type")); - auto const compression = cudf::io::compression_type::NONE; - cuio_source_sink_pair source_sink(source_type); - auto const tbl = - create_random_table(cycle_dtypes(d_type, num_cols), - table_size_bytes{data_size}, - data_profile_builder().cardinality(cardinality).avg_run_length(run_length)); - auto const view = tbl->view(); - - cudf::io::parquet_writer_options write_opts = - cudf::io::parquet_writer_options::builder(source_sink.make_sink_info(), view) - .compression(compression); - cudf::io::write_parquet(write_opts); - auto const num_rows = view.num_rows(); - - parquet_read_common(num_rows, num_cols, source_sink, state); -} - -using d_type_list = nvbench::enum_type_list; - -NVBENCH_BENCH_TYPES(bench_batched_memset, NVBENCH_TYPE_AXES(d_type_list)) - .set_name("batched_memset") - .set_type_axes_names({"data_type"}) - .add_int64_axis("num_cols", {1000}) - .add_string_axis("io_type", {"DEVICE_BUFFER"}) - .set_min_samples(4) - .add_int64_axis("cardinality", {0, 1000}) - .add_int64_axis("run_length", {1, 32}); diff --git a/cpp/include/cudf/detail/utilities/batched_memcpy.hpp b/cpp/include/cudf/detail/utilities/batched_memcpy.hpp new file mode 100644 index 00000000000..ed0ab9e6e5b --- /dev/null +++ b/cpp/include/cudf/detail/utilities/batched_memcpy.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include +#include +#include + +namespace CUDF_EXPORT cudf { +namespace detail { + +/** + * @brief A helper function that copies a vector of vectors from source to destination addresses in + * a batched manner. + * + * @tparam SrcIterator **[inferred]** The type of device-accessible source addresses iterator + * @tparam DstIterator **[inferred]** The type of device-accessible destination address iterator + * @tparam SizeIterator **[inferred]** The type of device-accessible buffer size iterator + * + * @param src_iter Device-accessible iterator to source addresses + * @param dst_iter Device-accessible iterator to destination addresses + * @param size_iter Device-accessible iterator to the buffer sizes (in bytes) + * @param num_buffs Number of buffers to be copied + * @param stream CUDA stream to use + */ +template +void batched_memcpy_async(SrcIterator src_iter, + DstIterator dst_iter, + SizeIterator size_iter, + size_t num_buffs, + rmm::cuda_stream_view stream) +{ + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, num_buffs, stream.value()); + + rmm::device_buffer d_temp_storage{temp_storage_bytes, stream.value()}; + + cub::DeviceMemcpy::Batched(d_temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + num_buffs, + stream.value()); +} + +} // namespace detail +} // namespace CUDF_EXPORT cudf diff --git a/cpp/include/cudf/io/detail/batched_memset.hpp b/cpp/include/cudf/detail/utilities/batched_memset.hpp similarity index 98% rename from cpp/include/cudf/io/detail/batched_memset.hpp rename to cpp/include/cudf/detail/utilities/batched_memset.hpp index 1c74be4a9fe..75f738f7529 100644 --- a/cpp/include/cudf/io/detail/batched_memset.hpp +++ b/cpp/include/cudf/detail/utilities/batched_memset.hpp @@ -28,7 +28,7 @@ #include namespace CUDF_EXPORT cudf { -namespace io::detail { +namespace detail { /** * @brief A helper function that takes in a vector of device spans and memsets them to the @@ -78,5 +78,5 @@ void batched_memset(std::vector> const& bufs, d_temp_storage.data(), temp_storage_bytes, iter_in, iter_out, sizes, num_bufs, stream); } -} // namespace io::detail +} // namespace detail } // namespace CUDF_EXPORT cudf diff --git a/cpp/src/io/orc/stripe_enc.cu b/cpp/src/io/orc/stripe_enc.cu index 5c70e35fd2e..ed0b6969154 100644 --- a/cpp/src/io/orc/stripe_enc.cu +++ b/cpp/src/io/orc/stripe_enc.cu @@ -20,6 +20,8 @@ #include "orc_gpu.hpp" #include +#include +#include #include #include #include @@ -1087,37 +1089,42 @@ CUDF_KERNEL void __launch_bounds__(block_size) /** * @brief Merge chunked column data into a single contiguous stream * - * @param[in,out] strm_desc StripeStream device array [stripe][stream] - * @param[in,out] streams List of encoder chunk streams [column][rowgroup] + * @param[in] strm_desc StripeStream device array [stripe][stream] + * @param[in] streams List of encoder chunk streams [column][rowgroup] + * @param[out] srcs List of source encoder chunk stream data addresses + * @param[out] dsts List of destination StripeStream data addresses + * @param[out] sizes List of stream sizes in bytes */ // blockDim {compact_streams_block_size,1,1} CUDF_KERNEL void __launch_bounds__(compact_streams_block_size) - gpuCompactOrcDataStreams(device_2dspan strm_desc, - device_2dspan streams) + gpuInitBatchedMemcpy(device_2dspan strm_desc, + device_2dspan streams, + device_span srcs, + device_span dsts, + device_span sizes) { - __shared__ __align__(16) StripeStream ss; - - auto const stripe_id = blockIdx.x; + auto const stripe_id = cudf::detail::grid_1d::global_thread_id(); auto const stream_id = blockIdx.y; - auto const t = threadIdx.x; + if (stripe_id >= strm_desc.size().first) { return; } - if (t == 0) { ss = strm_desc[stripe_id][stream_id]; } - __syncthreads(); + auto const out_id = stream_id * strm_desc.size().first + stripe_id; + StripeStream ss = strm_desc[stripe_id][stream_id]; if (ss.data_ptr == nullptr) { return; } auto const cid = ss.stream_type; auto dst_ptr = ss.data_ptr; for (auto group = ss.first_chunk_id; group < ss.first_chunk_id + ss.num_chunks; ++group) { + auto const out_id = stream_id * streams.size().second + group; + srcs[out_id] = streams[ss.column_id][group].data_ptrs[cid]; + dsts[out_id] = dst_ptr; + + // Also update the stream here, data will be copied in a separate kernel + streams[ss.column_id][group].data_ptrs[cid] = dst_ptr; + auto const len = streams[ss.column_id][group].lengths[cid]; - if (len > 0) { - auto const src_ptr = streams[ss.column_id][group].data_ptrs[cid]; - for (uint32_t i = t; i < len; i += blockDim.x) { - dst_ptr[i] = src_ptr[i]; - } - __syncthreads(); - } - if (t == 0) { streams[ss.column_id][group].data_ptrs[cid] = dst_ptr; } + // len is the size (in bytes) of the current stream. + sizes[out_id] = len; dst_ptr += len; } } @@ -1325,9 +1332,26 @@ void CompactOrcDataStreams(device_2dspan strm_desc, device_2dspan enc_streams, rmm::cuda_stream_view stream) { + auto const num_rowgroups = enc_streams.size().second; + auto const num_streams = strm_desc.size().second; + auto const num_stripes = strm_desc.size().first; + auto const num_chunks = num_rowgroups * num_streams; + auto srcs = cudf::detail::make_zeroed_device_uvector_async( + num_chunks, stream, rmm::mr::get_current_device_resource()); + auto dsts = cudf::detail::make_zeroed_device_uvector_async( + num_chunks, stream, rmm::mr::get_current_device_resource()); + auto lengths = cudf::detail::make_zeroed_device_uvector_async( + num_chunks, stream, rmm::mr::get_current_device_resource()); + dim3 dim_block(compact_streams_block_size, 1); - dim3 dim_grid(strm_desc.size().first, strm_desc.size().second); - gpuCompactOrcDataStreams<<>>(strm_desc, enc_streams); + dim3 dim_grid(cudf::util::div_rounding_up_unsafe(num_stripes, compact_streams_block_size), + strm_desc.size().second); + gpuInitBatchedMemcpy<<>>( + strm_desc, enc_streams, srcs, dsts, lengths); + + // Copy streams in a batched manner. + cudf::detail::batched_memcpy_async( + srcs.begin(), dsts.begin(), lengths.begin(), lengths.size(), stream); } std::optional CompressOrcDataStreams( diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index e0d50d7ccf9..b3276c81c1f 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -17,6 +17,8 @@ #include "page_data.cuh" #include "page_decode.cuh" +#include + #include #include @@ -466,4 +468,28 @@ void __host__ DecodeSplitPageData(cudf::detail::hostdevice_span pages, } } +void WriteFinalOffsets(host_span offsets, + host_span buff_addrs, + rmm::cuda_stream_view stream) +{ + // Copy offsets to device and create an iterator + auto d_src_data = cudf::detail::make_device_uvector_async( + offsets, stream, cudf::get_current_device_resource_ref()); + // Iterator for the source (scalar) data + auto src_iter = cudf::detail::make_counting_transform_iterator( + static_cast(0), + cuda::proclaim_return_type( + [src = d_src_data.begin()] __device__(std::size_t i) { return src + i; })); + + // Copy buffer addresses to device and create an iterator + auto d_dst_addrs = cudf::detail::make_device_uvector_async( + buff_addrs, stream, cudf::get_current_device_resource_ref()); + // size_iter is simply a constant iterator of sizeof(size_type) bytes. + auto size_iter = thrust::make_constant_iterator(sizeof(size_type)); + + // Copy offsets to buffers in batched manner. + cudf::detail::batched_memcpy_async( + src_iter, d_dst_addrs.begin(), size_iter, offsets.size(), stream); +} + } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index e631e12119d..a8ba3a969ce 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -797,6 +797,18 @@ void DecodeSplitPageData(cudf::detail::hostdevice_span pages, kernel_error::pointer error_code, rmm::cuda_stream_view stream); +/** + * @brief Writes the final offsets to the corresponding list and string buffer end addresses in a + * batched manner. + * + * @param offsets Host span of final offsets + * @param buff_addrs Host span of corresponding output col buffer end addresses + * @param stream CUDA stream to use + */ +void WriteFinalOffsets(host_span offsets, + host_span buff_addrs, + rmm::cuda_stream_view stream); + /** * @brief Launches kernel for reading the string column data stored in the pages * diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 7d817bde7af..1b69ccb7742 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -371,13 +371,15 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num CUDF_FAIL("Parquet data decode failed with code(s) " + kernel_error::to_string(error)); } - // for list columns, add the final offset to every offset buffer. - // TODO : make this happen in more efficiently. Maybe use thrust::for_each - // on each buffer. + // For list and string columns, add the final offset to every offset buffer. // Note : the reason we are doing this here instead of in the decode kernel is // that it is difficult/impossible for a given page to know that it is writing the very // last value that should then be followed by a terminator (because rows can span // page boundaries). + std::vector out_buffers; + std::vector final_offsets; + out_buffers.reserve(_input_columns.size()); + final_offsets.reserve(_input_columns.size()); for (size_t idx = 0; idx < _input_columns.size(); idx++) { input_column_info const& input_col = _input_columns[idx]; @@ -393,25 +395,21 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num // the final offset for a list at level N is the size of it's child size_type const offset = child.type.id() == type_id::LIST ? child.size - 1 : child.size; - CUDF_CUDA_TRY(cudaMemcpyAsync(static_cast(out_buf.data()) + (out_buf.size - 1), - &offset, - sizeof(size_type), - cudaMemcpyDefault, - _stream.value())); + out_buffers.emplace_back(static_cast(out_buf.data()) + (out_buf.size - 1)); + final_offsets.emplace_back(offset); out_buf.user_data |= PARQUET_COLUMN_BUFFER_FLAG_LIST_TERMINATED; } else if (out_buf.type.id() == type_id::STRING) { // need to cap off the string offsets column auto const sz = static_cast(col_string_sizes[idx]); if (sz <= strings::detail::get_offset64_threshold()) { - CUDF_CUDA_TRY(cudaMemcpyAsync(static_cast(out_buf.data()) + out_buf.size, - &sz, - sizeof(size_type), - cudaMemcpyDefault, - _stream.value())); + out_buffers.emplace_back(static_cast(out_buf.data()) + out_buf.size); + final_offsets.emplace_back(sz); } } } } + // Write the final offsets for list and string columns in a batched manner + WriteFinalOffsets(final_offsets, out_buffers, _stream); // update null counts in the final column buffers for (size_t idx = 0; idx < subpass.pages.size(); idx++) { diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 3763c2e8e6d..8cab68ea721 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -19,9 +19,9 @@ #include #include +#include #include #include -#include #include #include @@ -1656,9 +1656,9 @@ void reader::impl::allocate_columns(read_mode mode, size_t skip_rows, size_t num } } - cudf::io::detail::batched_memset(memset_bufs, static_cast(0), _stream); + cudf::detail::batched_memset(memset_bufs, static_cast(0), _stream); // Need to set null mask bufs to all high bits - cudf::io::detail::batched_memset( + cudf::detail::batched_memset( nullmask_bufs, std::numeric_limits::max(), _stream); } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b67d922d377..4596ec65ce7 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -385,6 +385,8 @@ ConfigureTest( # * utilities tests ------------------------------------------------------------------------------- ConfigureTest( UTILITIES_TEST + utilities_tests/batched_memcpy_tests.cu + utilities_tests/batched_memset_tests.cu utilities_tests/column_debug_tests.cpp utilities_tests/column_utilities_tests.cpp utilities_tests/column_wrapper_tests.cpp @@ -395,7 +397,6 @@ ConfigureTest( utilities_tests/pinned_memory_tests.cpp utilities_tests/type_check_tests.cpp utilities_tests/type_list_tests.cpp - utilities_tests/batched_memset_tests.cu ) # ################################################################################################## diff --git a/cpp/tests/utilities_tests/batched_memcpy_tests.cu b/cpp/tests/utilities_tests/batched_memcpy_tests.cu new file mode 100644 index 00000000000..98657f8e224 --- /dev/null +++ b/cpp/tests/utilities_tests/batched_memcpy_tests.cu @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +template +struct BatchedMemcpyTest : public cudf::test::BaseFixture {}; + +TEST(BatchedMemcpyTest, BasicTest) +{ + using T1 = int64_t; + + // Device init + auto stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + + // Buffer lengths (in number of elements) + std::vector const h_lens{ + 50000, 4, 1000, 0, 250000, 1, 100, 8000, 0, 1, 100, 1000, 10000, 100000, 0, 1, 100000}; + + // Total number of buffers + auto const num_buffs = h_lens.size(); + + // Exclusive sum of buffer lengths for pointers + std::vector h_lens_excl_sum(num_buffs); + std::exclusive_scan(h_lens.begin(), h_lens.end(), h_lens_excl_sum.begin(), 0); + + // Corresponding buffer sizes (in bytes) + std::vector h_sizes_bytes; + h_sizes_bytes.reserve(num_buffs); + std::transform( + h_lens.cbegin(), h_lens.cend(), std::back_inserter(h_sizes_bytes), [&](auto& size) { + return size * sizeof(T1); + }); + + // Initialize random engine + auto constexpr seed = 0xcead; + std::mt19937 engine{seed}; + using uniform_distribution = + typename std::conditional_t, + std::bernoulli_distribution, + std::conditional_t, + std::uniform_real_distribution, + std::uniform_int_distribution>>; + uniform_distribution dist{}; + + // Generate a src vector of random data vectors + std::vector> h_sources; + h_sources.reserve(num_buffs); + std::transform(h_lens.begin(), h_lens.end(), std::back_inserter(h_sources), [&](auto size) { + std::vector data(size); + std::generate_n(data.begin(), size, [&]() { return T1{dist(engine)}; }); + return data; + }); + // Copy the vectors to device + std::vector> h_device_vecs; + h_device_vecs.reserve(h_sources.size()); + std::transform( + h_sources.begin(), h_sources.end(), std::back_inserter(h_device_vecs), [stream, mr](auto& vec) { + return cudf::detail::make_device_uvector_async(vec, stream, mr); + }); + // Pointers to the source vectors + std::vector h_src_ptrs; + h_src_ptrs.reserve(h_sources.size()); + std::transform( + h_device_vecs.begin(), h_device_vecs.end(), std::back_inserter(h_src_ptrs), [](auto& vec) { + return static_cast(vec.data()); + }); + // Copy the source data pointers to device + auto d_src_ptrs = cudf::detail::make_device_uvector_async(h_src_ptrs, stream, mr); + + // Total number of elements in all buffers + auto const total_buff_len = std::accumulate(h_lens.cbegin(), h_lens.cend(), 0); + + // Create one giant buffer for destination + auto d_dst_data = cudf::detail::make_zeroed_device_uvector_async(total_buff_len, stream, mr); + // Pointers to destination buffers within the giant destination buffer + std::vector h_dst_ptrs(num_buffs); + std::for_each(thrust::make_counting_iterator(static_cast(0)), + thrust::make_counting_iterator(num_buffs), + [&](auto i) { return h_dst_ptrs[i] = d_dst_data.data() + h_lens_excl_sum[i]; }); + // Copy destination data pointers to device + auto d_dst_ptrs = cudf::detail::make_device_uvector_async(h_dst_ptrs, stream, mr); + + // Copy buffer size iterators (in bytes) to device + auto d_sizes_bytes = cudf::detail::make_device_uvector_async(h_sizes_bytes, stream, mr); + + // Run the batched memcpy + cudf::detail::batched_memcpy_async( + d_src_ptrs.begin(), d_dst_ptrs.begin(), d_sizes_bytes.begin(), num_buffs, stream); + + // Expected giant destination buffer after the memcpy + std::vector expected_buffer; + expected_buffer.reserve(total_buff_len); + std::for_each(h_sources.cbegin(), h_sources.cend(), [&expected_buffer](auto& source) { + expected_buffer.insert(expected_buffer.end(), source.begin(), source.end()); + }); + + // Copy over the result destination buffer to host and synchronize the stream + auto result_dst_buffer = + cudf::detail::make_std_vector_sync(cudf::device_span(d_dst_data), stream); + + // Check if both vectors are equal + EXPECT_TRUE( + std::equal(expected_buffer.begin(), expected_buffer.end(), result_dst_buffer.begin())); +} diff --git a/cpp/tests/utilities_tests/batched_memset_tests.cu b/cpp/tests/utilities_tests/batched_memset_tests.cu index bed0f40d70e..0eeb7b95318 100644 --- a/cpp/tests/utilities_tests/batched_memset_tests.cu +++ b/cpp/tests/utilities_tests/batched_memset_tests.cu @@ -18,8 +18,8 @@ #include #include +#include #include -#include #include #include #include @@ -78,7 +78,7 @@ TEST(MultiBufferTestIntegral, BasicTest1) }); // Function Call - cudf::io::detail::batched_memset(memset_bufs, uint64_t{0}, stream); + cudf::detail::batched_memset(memset_bufs, uint64_t{0}, stream); // Set all buffer regions to 0 for expected comparison std::for_each( From 2ec6cb32d825d2ef255d0e56497c20be30713d32 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 2 Oct 2024 18:07:52 -1000 Subject: [PATCH 03/12] Fix astype from tz-aware type to tz-aware type (#16980) closes https://github.com/rapidsai/cudf/issues/16973 Also matches astype from tz-naive to tz-aware type like pandas Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/16980 --- python/cudf/cudf/core/column/datetime.py | 15 +++++++++++++ .../cudf/tests/series/test_datetimelike.py | 22 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index d0ea4612a1b..2c9b0baa9b6 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -480,6 +480,11 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn: if dtype == self.dtype: return self + elif isinstance(dtype, pd.DatetimeTZDtype): + raise TypeError( + "Cannot use .astype to convert from timezone-naive dtype to timezone-aware dtype. " + "Use tz_localize instead." + ) return libcudf.unary.cast(self, dtype=dtype) def as_timedelta_column(self, dtype: Dtype) -> None: # type: ignore[override] @@ -940,6 +945,16 @@ def strftime(self, format: str) -> cudf.core.column.StringColumn: def as_string_column(self) -> cudf.core.column.StringColumn: return self._local_time.as_string_column() + def as_datetime_column(self, dtype: Dtype) -> DatetimeColumn: + if isinstance(dtype, pd.DatetimeTZDtype) and dtype != self.dtype: + if dtype.unit != self.time_unit: + # TODO: Doesn't check that new unit is valid. + casted = self._with_type_metadata(dtype) + else: + casted = self + return casted.tz_convert(str(dtype.tz)) + return super().as_datetime_column(dtype) + def get_dt_field(self, field: str) -> ColumnBase: return libcudf.datetime.extract_datetime_component( self._local_time, field diff --git a/python/cudf/cudf/tests/series/test_datetimelike.py b/python/cudf/cudf/tests/series/test_datetimelike.py index cea86a5499e..691da224f44 100644 --- a/python/cudf/cudf/tests/series/test_datetimelike.py +++ b/python/cudf/cudf/tests/series/test_datetimelike.py @@ -266,3 +266,25 @@ def test_pandas_compatible_non_zoneinfo_raises(klass): with cudf.option_context("mode.pandas_compatible", True): with pytest.raises(NotImplementedError): cudf.from_pandas(pandas_obj) + + +def test_astype_naive_to_aware_raises(): + ser = cudf.Series([datetime.datetime(2020, 1, 1)]) + with pytest.raises(TypeError): + ser.astype("datetime64[ns, UTC]") + with pytest.raises(TypeError): + ser.to_pandas().astype("datetime64[ns, UTC]") + + +@pytest.mark.parametrize("unit", ["ns", "us"]) +def test_astype_aware_to_aware(unit): + ser = cudf.Series( + [datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc)] + ) + result = ser.astype(f"datetime64[{unit}, US/Pacific]") + expected = ser.to_pandas().astype(f"datetime64[{unit}, US/Pacific]") + zoneinfo_type = pd.DatetimeTZDtype( + expected.dtype.unit, zoneinfo.ZoneInfo(str(expected.dtype.tz)) + ) + expected = ser.astype(zoneinfo_type) + assert_eq(result, expected) From 3faa3ee8b869a8450f6352c7770fb155b321d926 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Thu, 3 Oct 2024 08:53:08 -0400 Subject: [PATCH 04/12] Add cudf::strings::find_re API (#16742) Adds the `cudf::strings::find_re` and `str.find_re` API to libcudf/pylibcudf/cudf. This function returns the character position where the pattern first matches in each row of the input column. If a match is not found, -1 is returned for that corresponding row. Closes #16729 Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Nghia Truong (https://github.com/ttnghia) - Matthew Murray (https://github.com/Matt711) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/16742 --- cpp/doxygen/regex.md | 1 + cpp/include/cudf/strings/findall.hpp | 29 ++++++++++++ cpp/src/strings/search/findall.cu | 46 +++++++++++++++++++ cpp/tests/streams/strings/find_test.cpp | 1 + cpp/tests/strings/findall_tests.cpp | 35 +++++++++++--- python/cudf/cudf/_lib/strings/__init__.py | 2 +- python/cudf/cudf/_lib/strings/findall.pyx | 16 +++++++ python/cudf/cudf/core/column/string.py | 40 ++++++++++++++++ python/cudf/cudf/tests/test_string.py | 20 ++++++++ .../pylibcudf/libcudf/strings/findall.pxd | 4 ++ .../pylibcudf/pylibcudf/strings/findall.pxd | 1 + .../pylibcudf/pylibcudf/strings/findall.pyx | 32 +++++++++++++ .../pylibcudf/tests/test_string_findall.py | 17 +++++++ 13 files changed, 237 insertions(+), 7 deletions(-) diff --git a/cpp/doxygen/regex.md b/cpp/doxygen/regex.md index 6d1c91a5752..6902b1948bd 100644 --- a/cpp/doxygen/regex.md +++ b/cpp/doxygen/regex.md @@ -8,6 +8,7 @@ This page specifies which regular expression (regex) features are currently supp - cudf::strings::extract() - cudf::strings::extract_all_record() - cudf::strings::findall() +- cudf::strings::find_re() - cudf::strings::replace_re() - cudf::strings::replace_with_backrefs() - cudf::strings::split_re() diff --git a/cpp/include/cudf/strings/findall.hpp b/cpp/include/cudf/strings/findall.hpp index c6b9bc7e58a..867764b6d9a 100644 --- a/cpp/include/cudf/strings/findall.hpp +++ b/cpp/include/cudf/strings/findall.hpp @@ -66,6 +66,35 @@ std::unique_ptr findall( rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); +/** + * @brief Returns the starting character index of the first match for the given pattern + * in each row of the input column + * + * @code{.pseudo} + * Example: + * s = ["bunny", "rabbit", "hare", "dog"] + * p = regex_program::create("[be]") + * r = find_re(s, p) + * r is now [0, 2, 3, -1] + * @endcode + * + * A null output row occurs if the corresponding input row is null. + * A -1 is returned for rows that do not contain a match. + * + * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. + * + * @param input Strings instance for this operation + * @param prog Regex program instance + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New column of integers + */ +std::unique_ptr find_re( + strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); + /** @} */ // end of doxygen group } // namespace strings } // namespace CUDF_EXPORT cudf diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index d8c1b50a94b..21708e48a25 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -126,6 +126,43 @@ std::unique_ptr findall(strings_column_view const& input, mr); } +namespace { +struct find_re_fn { + column_device_view d_strings; + + __device__ size_type operator()(size_type const idx, + reprog_device const prog, + int32_t const thread_idx) const + { + if (d_strings.is_null(idx)) { return 0; } + auto const d_str = d_strings.element(idx); + + auto const result = prog.find(thread_idx, d_str, d_str.begin()); + return result.has_value() ? result.value().first : -1; + } +}; +} // namespace + +std::unique_ptr find_re(strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto results = make_numeric_column(data_type{type_to_id()}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + if (input.is_empty()) { return results; } + + auto d_results = results->mutable_view().data(); + auto d_prog = regex_device_builder::create_prog_device(prog, stream); + auto const d_strings = column_device_view::create(input.parent(), stream); + launch_transform_kernel(find_re_fn{*d_strings}, *d_prog, d_results, input.size(), stream); + + return results; +} } // namespace detail // external API @@ -139,5 +176,14 @@ std::unique_ptr findall(strings_column_view const& input, return detail::findall(input, prog, stream, mr); } +std::unique_ptr find_re(strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + return detail::find_re(input, prog, stream, mr); +} + } // namespace strings } // namespace cudf diff --git a/cpp/tests/streams/strings/find_test.cpp b/cpp/tests/streams/strings/find_test.cpp index 52839c6fc9f..e5a1ee0988c 100644 --- a/cpp/tests/streams/strings/find_test.cpp +++ b/cpp/tests/streams/strings/find_test.cpp @@ -46,4 +46,5 @@ TEST_F(StringsFindTest, Find) auto const pattern = std::string("[a-z]"); auto const prog = cudf::strings::regex_program::create(pattern); cudf::strings::findall(view, *prog, cudf::test::get_default_stream()); + cudf::strings::find_re(view, *prog, cudf::test::get_default_stream()); } diff --git a/cpp/tests/strings/findall_tests.cpp b/cpp/tests/strings/findall_tests.cpp index 73da4d081e2..4821a7fa999 100644 --- a/cpp/tests/strings/findall_tests.cpp +++ b/cpp/tests/strings/findall_tests.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -149,6 +150,22 @@ TEST_F(StringsFindallTests, LargeRegex) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); } +TEST_F(StringsFindallTests, FindTest) +{ + auto const valids = cudf::test::iterators::null_at(5); + cudf::test::strings_column_wrapper input( + {"3A", "May4", "Jan2021", "March", "A9BC", "", "", "abcdef ghijklm 12345"}, valids); + auto sv = cudf::strings_column_view(input); + + auto pattern = std::string("\\d+"); + + auto prog = cudf::strings::regex_program::create(pattern); + auto results = cudf::strings::find_re(sv, *prog); + auto expected = + cudf::test::fixed_width_column_wrapper({0, 3, 3, -1, 1, 0, -1, 15}, valids); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); +} + TEST_F(StringsFindallTests, NoMatches) { cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"}); @@ -169,10 +186,16 @@ TEST_F(StringsFindallTests, EmptyTest) auto prog = cudf::strings::regex_program::create(pattern); cudf::test::strings_column_wrapper input; - auto sv = cudf::strings_column_view(input); - auto results = cudf::strings::findall(sv, *prog); - - using LCW = cudf::test::lists_column_wrapper; - LCW expected; - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + auto sv = cudf::strings_column_view(input); + { + auto results = cudf::strings::findall(sv, *prog); + using LCW = cudf::test::lists_column_wrapper; + LCW expected; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + } + { + auto results = cudf::strings::find_re(sv, *prog); + auto expected = cudf::test::fixed_width_column_wrapper{}; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + } } diff --git a/python/cudf/cudf/_lib/strings/__init__.py b/python/cudf/cudf/_lib/strings/__init__.py index 049dbab4851..e712937f816 100644 --- a/python/cudf/cudf/_lib/strings/__init__.py +++ b/python/cudf/cudf/_lib/strings/__init__.py @@ -71,7 +71,7 @@ startswith_multiple, ) from cudf._lib.strings.find_multiple import find_multiple -from cudf._lib.strings.findall import findall +from cudf._lib.strings.findall import find_re, findall from cudf._lib.strings.json import GetJsonObjectOptions, get_json_object from cudf._lib.strings.padding import center, ljust, pad, rjust, zfill from cudf._lib.strings.repeat import repeat_scalar, repeat_sequence diff --git a/python/cudf/cudf/_lib/strings/findall.pyx b/python/cudf/cudf/_lib/strings/findall.pyx index 0e758d5b322..3e7a504d535 100644 --- a/python/cudf/cudf/_lib/strings/findall.pyx +++ b/python/cudf/cudf/_lib/strings/findall.pyx @@ -23,3 +23,19 @@ def findall(Column source_strings, object pattern, uint32_t flags): prog, ) return Column.from_pylibcudf(plc_result) + + +@acquire_spill_lock() +def find_re(Column source_strings, object pattern, uint32_t flags): + """ + Returns character positions where the pattern first matches + the elements in source_strings. + """ + prog = plc.strings.regex_program.RegexProgram.create( + str(pattern), flags + ) + plc_result = plc.strings.findall.find_re( + source_strings.to_pylibcudf(mode="read"), + prog, + ) + return Column.from_pylibcudf(plc_result) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 88df57b1b3b..b50e23bd52e 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -3626,6 +3626,46 @@ def findall(self, pat: str, flags: int = 0) -> SeriesOrIndex: data = libstrings.findall(self._column, pat, flags) return self._return_or_inplace(data) + def find_re(self, pat: str, flags: int = 0) -> SeriesOrIndex: + """ + Find first occurrence of pattern or regular expression in the + Series/Index. + + Parameters + ---------- + pat : str + Pattern or regular expression. + flags : int, default 0 (no flags) + Flags to pass through to the regex engine (e.g. re.MULTILINE) + + Returns + ------- + Series + A Series of position values where the pattern first matches + each string. + + Examples + -------- + >>> import cudf + >>> s = cudf.Series(['Lion', 'Monkey', 'Rabbit', 'Cat']) + >>> s.str.find_re('[ti]') + 0 1 + 1 -1 + 2 4 + 3 2 + dtype: int32 + """ + if isinstance(pat, re.Pattern): + flags = pat.flags & ~re.U + pat = pat.pattern + if not _is_supported_regex_flags(flags): + raise NotImplementedError( + "Unsupported value for `flags` parameter" + ) + + data = libstrings.find_re(self._column, pat, flags) + return self._return_or_inplace(data) + def find_multiple(self, patterns: SeriesOrIndex) -> cudf.Series: """ Find all first occurrences of patterns in the Series/Index. diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index cc88cc79769..45143211a11 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -1899,6 +1899,26 @@ def test_string_findall(pat, flags): assert_eq(expected, actual) +@pytest.mark.parametrize( + "pat, flags, pos", + [ + ("Monkey", 0, [-1, 0, -1, -1]), + ("on", 0, [2, 1, -1, 1]), + ("bit", 0, [-1, -1, 3, -1]), + ("on$", 0, [2, -1, -1, -1]), + ("on$", re.MULTILINE, [2, -1, -1, 1]), + ("o.*k", re.DOTALL, [-1, 1, -1, 1]), + ], +) +def test_string_find_re(pat, flags, pos): + test_data = ["Lion", "Monkey", "Rabbit", "Don\nkey"] + gs = cudf.Series(test_data) + + expected = pd.Series(pos, dtype=np.int32) + actual = gs.str.find_re(pat, flags) + assert_eq(expected, actual) + + def test_string_replace_multi(): ps = pd.Series(["hello", "goodbye"]) gs = cudf.Series(["hello", "goodbye"]) diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd index e0a8b776465..0d286c36446 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd @@ -11,3 +11,7 @@ cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] findall( column_view input, regex_program prog) except + + + cdef unique_ptr[column] find_re( + column_view input, + regex_program prog) except + diff --git a/python/pylibcudf/pylibcudf/strings/findall.pxd b/python/pylibcudf/pylibcudf/strings/findall.pxd index 54afa088141..3c35a9c9aa9 100644 --- a/python/pylibcudf/pylibcudf/strings/findall.pxd +++ b/python/pylibcudf/pylibcudf/strings/findall.pxd @@ -4,4 +4,5 @@ from pylibcudf.column cimport Column from pylibcudf.strings.regex_program cimport RegexProgram +cpdef Column find_re(Column input, RegexProgram pattern) cpdef Column findall(Column input, RegexProgram pattern) diff --git a/python/pylibcudf/pylibcudf/strings/findall.pyx b/python/pylibcudf/pylibcudf/strings/findall.pyx index 3a6b87504b3..5212dc4594d 100644 --- a/python/pylibcudf/pylibcudf/strings/findall.pyx +++ b/python/pylibcudf/pylibcudf/strings/findall.pyx @@ -38,3 +38,35 @@ cpdef Column findall(Column input, RegexProgram pattern): ) return Column.from_libcudf(move(c_result)) + + +cpdef Column find_re(Column input, RegexProgram pattern): + """ + Returns character positions where the pattern first matches + the elements in input strings. + + For details, see :cpp:func:`cudf::strings::find_re` + + Parameters + ---------- + input : Column + Strings instance for this operation + pattern : RegexProgram + Regex pattern + + Returns + ------- + Column + New column of integers + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_findall.find_re( + input.view(), + pattern.c_obj.get()[0] + ) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_findall.py b/python/pylibcudf/pylibcudf/tests/test_string_findall.py index 994552fa276..debfad92d00 100644 --- a/python/pylibcudf/pylibcudf/tests/test_string_findall.py +++ b/python/pylibcudf/pylibcudf/tests/test_string_findall.py @@ -21,3 +21,20 @@ def test_findall(): type=pa_result.type, ) assert_column_eq(result, expected) + + +def test_find_re(): + arr = pa.array(["bunny", "rabbit", "hare", "dog"]) + pattern = "[eb]" + result = plc.strings.findall.find_re( + plc.interop.from_arrow(arr), + plc.strings.regex_program.RegexProgram.create( + pattern, plc.strings.regex_flags.RegexFlags.DEFAULT + ), + ) + pa_result = plc.interop.to_arrow(result) + expected = pa.array( + [0, 2, 3, -1], + type=pa_result.type, + ) + assert_column_eq(result, expected) From bd3b3327a6326ffea4658d682b8b9087e32da98a Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 3 Oct 2024 16:25:09 -0400 Subject: [PATCH 05/12] Restore export of nvcomp outside of wheel builds (#16988) Fixes https://github.com/rapidsai/cudf/issues/16986 Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/16988 --- cpp/CMakeLists.txt | 1 + cpp/cmake/thirdparty/get_nvcomp.cmake | 6 +++++- python/libcudf/CMakeLists.txt | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 136f43ee706..f7a5dd2f2fb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -52,6 +52,7 @@ option(JITIFY_USE_CACHE "Use a file cache for JIT compiled kernels" ON) option(CUDF_BUILD_TESTUTIL "Whether to build the test utilities contained in libcudf" ON) mark_as_advanced(CUDF_BUILD_TESTUTIL) option(CUDF_USE_PROPRIETARY_NVCOMP "Download and use NVCOMP with proprietary extensions" ON) +option(CUDF_EXPORT_NVCOMP "Export NVCOMP as a dependency" ON) option(CUDF_LARGE_STRINGS_DISABLED "Build with large string support disabled" OFF) mark_as_advanced(CUDF_LARGE_STRINGS_DISABLED) option( diff --git a/cpp/cmake/thirdparty/get_nvcomp.cmake b/cpp/cmake/thirdparty/get_nvcomp.cmake index 1b6a1730161..33b1b45fb44 100644 --- a/cpp/cmake/thirdparty/get_nvcomp.cmake +++ b/cpp/cmake/thirdparty/get_nvcomp.cmake @@ -16,7 +16,11 @@ function(find_and_configure_nvcomp) include(${rapids-cmake-dir}/cpm/nvcomp.cmake) - rapids_cpm_nvcomp(USE_PROPRIETARY_BINARY ${CUDF_USE_PROPRIETARY_NVCOMP}) + set(export_args) + if(CUDF_EXPORT_NVCOMP) + set(export_args BUILD_EXPORT_SET cudf-exports INSTALL_EXPORT_SET cudf-exports) + endif() + rapids_cpm_nvcomp(${export_args} USE_PROPRIETARY_BINARY ${CUDF_USE_PROPRIETARY_NVCOMP}) # Per-thread default stream if(TARGET nvcomp AND CUDF_USE_PER_THREAD_DEFAULT_STREAM) diff --git a/python/libcudf/CMakeLists.txt b/python/libcudf/CMakeLists.txt index 2b208e2e021..5f9a04d3cee 100644 --- a/python/libcudf/CMakeLists.txt +++ b/python/libcudf/CMakeLists.txt @@ -41,6 +41,9 @@ set(BUILD_TESTS OFF) set(BUILD_BENCHMARKS OFF) set(CUDF_BUILD_TESTUTIL OFF) set(CUDF_BUILD_STREAMS_TEST_UTIL OFF) +if(USE_NVCOMP_RUNTIME_WHEEL) + set(CUDF_EXPORT_NVCOMP OFF) +endif() set(CUDA_STATIC_RUNTIME ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) From 010839172ecb5a99609044a98031ff5b7578cd64 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Thu, 3 Oct 2024 19:44:20 -0500 Subject: [PATCH 06/12] Use `libcudf` wheel from PR rather than nightly for `polars-polars` CI test job (#16975) This PR fixes an issue where one `cudf-polars` CI job uses the `pylibcudf` wheel generated from the branch being tested, but pulls a libcudf nightly which can cause issues when introducing cython/c++ changes simultaneously. Authors: - https://github.com/brandon-b-miller Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/16975 --- ci/test_cudf_polars_polars_tests.sh | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ci/test_cudf_polars_polars_tests.sh b/ci/test_cudf_polars_polars_tests.sh index 55399d0371a..f5bcdc62604 100755 --- a/ci/test_cudf_polars_polars_tests.sh +++ b/ci/test_cudf_polars_polars_tests.sh @@ -24,14 +24,17 @@ rapids-logger "Download wheels" RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" RAPIDS_PY_WHEEL_NAME="cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}" RAPIDS_PY_WHEEL_PURE="1" rapids-download-wheels-from-s3 ./dist -# Download the pylibcudf built in the previous step -RAPIDS_PY_WHEEL_NAME="pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-pylibcudf-dep +# Download libcudf and pylibcudf built in the previous step +RAPIDS_PY_WHEEL_NAME="libcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 cpp ./local-libcudf-dep +RAPIDS_PY_WHEEL_NAME="pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 python ./local-pylibcudf-dep -rapids-logger "Install pylibcudf" -python -m pip install ./local-pylibcudf-dep/pylibcudf*.whl +rapids-logger "Install libcudf, pylibcudf and cudf_polars" +python -m pip install \ + -v \ + "$(echo ./dist/cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}*.whl)[test]" \ + "$(echo ./local-libcudf-dep/libcudf_${RAPIDS_PY_CUDA_SUFFIX}*.whl)" \ + "$(echo ./local-pylibcudf-dep/pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}*.whl)" -rapids-logger "Install cudf_polars" -python -m pip install $(echo ./dist/cudf_polars*.whl) TAG=$(python -c 'import polars; print(f"py-{polars.__version__}")') rapids-logger "Clone polars to ${TAG}" From 2fa2e6a554096181b0a625cdc50368893dbaaa1f Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Fri, 4 Oct 2024 16:08:37 +0100 Subject: [PATCH 07/12] Switched AST benchmarks from GoogleBench to NVBench (#16952) This merge request switches the Benchmarking solution for the AST benchmark from GoogleBench to NVBench. ~It also refactors the L2 cache flushing functionality of `cuda_event_timer` into a separate function `flush_L2_device_cache`, since NVBench already performs the timing, synchronization, and timer setup necessary.~ Authors: - Basit Ayantunde (https://github.com/lamarrr) Approvers: - Bradley Dice (https://github.com/bdice) - Yunsong Wang (https://github.com/PointKernel) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/16952 --- cpp/benchmarks/CMakeLists.txt | 2 +- cpp/benchmarks/ast/transform.cpp | 51 +++++++++++--------------------- 2 files changed, 18 insertions(+), 35 deletions(-) diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 110b4557840..1e13bf176c1 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -330,7 +330,7 @@ ConfigureNVBench(CSV_WRITER_NVBENCH io/csv/csv_writer.cpp) # ################################################################################################## # * ast benchmark --------------------------------------------------------------------------------- -ConfigureBench(AST_BENCH ast/transform.cpp) +ConfigureNVBench(AST_NVBENCH ast/transform.cpp) # ################################################################################################## # * binaryop benchmark ---------------------------------------------------------------------------- diff --git a/cpp/benchmarks/ast/transform.cpp b/cpp/benchmarks/ast/transform.cpp index 65a44532cf1..f44f26e4d2c 100644 --- a/cpp/benchmarks/ast/transform.cpp +++ b/cpp/benchmarks/ast/transform.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,14 +15,16 @@ */ #include -#include -#include #include #include +#include + #include +#include + #include #include #include @@ -35,13 +37,10 @@ enum class TreeType { }; template -class AST : public cudf::benchmark {}; - -template -static void BM_ast_transform(benchmark::State& state) +static void BM_ast_transform(nvbench::state& state) { - auto const table_size{static_cast(state.range(0))}; - auto const tree_levels{static_cast(state.range(1))}; + auto const table_size = static_cast(state.get_int64("table_size")); + auto const tree_levels = static_cast(state.get_int64("tree_levels")); // Create table data auto const n_cols = reuse_columns ? 1 : tree_levels + 1; @@ -86,38 +85,22 @@ static void BM_ast_transform(benchmark::State& state) auto const& expression_tree_root = expressions.back(); - // Execute benchmark - for (auto _ : state) { - cuda_event_timer raii(state, true); // flush_l2_cache = true, stream = 0 - cudf::compute_column(table, expression_tree_root); - } - // Use the number of bytes read from global memory - state.SetBytesProcessed(static_cast(state.iterations()) * state.range(0) * - (tree_levels + 1) * sizeof(key_type)); -} + state.add_global_memory_reads(table_size * (tree_levels + 1)); -static void CustomRanges(benchmark::internal::Benchmark* b) -{ - auto row_counts = std::vector{100'000, 1'000'000, 10'000'000, 100'000'000}; - auto operation_counts = std::vector{1, 5, 10}; - for (auto const& row_count : row_counts) { - for (auto const& operation_count : operation_counts) { - b->Args({row_count, operation_count}); - } - } + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch&) { cudf::compute_column(table, expression_tree_root); }); } #define AST_TRANSFORM_BENCHMARK_DEFINE(name, key_type, tree_type, reuse_columns, nullable) \ - BENCHMARK_TEMPLATE_DEFINE_F(AST, name, key_type, tree_type, reuse_columns, nullable) \ - (::benchmark::State & st) \ + static void name(::nvbench::state& st) \ { \ - BM_ast_transform(st); \ + ::BM_ast_transform(st); \ } \ - BENCHMARK_REGISTER_F(AST, name) \ - ->Apply(CustomRanges) \ - ->Unit(benchmark::kMillisecond) \ - ->UseManualTime(); + NVBENCH_BENCH(name) \ + .set_name(#name) \ + .add_int64_axis("tree_levels", {1, 5, 10}) \ + .add_int64_axis("table_size", {100'000, 1'000'000, 10'000'000, 100'000'000}) AST_TRANSFORM_BENCHMARK_DEFINE( ast_int32_imbalanced_unique, int32_t, TreeType::IMBALANCED_LEFT, false, false); From a78432184f20f7acf493eaa8d1928cfee29d1771 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Fri, 4 Oct 2024 16:19:37 +0100 Subject: [PATCH 08/12] Switched BINARY_OP Benchmarks from GoogleBench to NVBench (#16963) This merge request switches the Benchmarking solution for the BINARY_OP benchmarks from GoogleBench to NVBench Authors: - Basit Ayantunde (https://github.com/lamarrr) Approvers: - Nghia Truong (https://github.com/ttnghia) - Tianyu Liu (https://github.com/kingcrimsontianyu) URL: https://github.com/rapidsai/cudf/pull/16963 --- cpp/benchmarks/CMakeLists.txt | 2 +- cpp/benchmarks/binaryop/binaryop.cpp | 65 ++++++------------- cpp/benchmarks/binaryop/compiled_binaryop.cpp | 47 ++++++-------- 3 files changed, 40 insertions(+), 74 deletions(-) diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 1e13bf176c1..b8a53cd8bd9 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -334,7 +334,7 @@ ConfigureNVBench(AST_NVBENCH ast/transform.cpp) # ################################################################################################## # * binaryop benchmark ---------------------------------------------------------------------------- -ConfigureBench(BINARYOP_BENCH binaryop/binaryop.cpp binaryop/compiled_binaryop.cpp) +ConfigureNVBench(BINARYOP_NVBENCH binaryop/binaryop.cpp binaryop/compiled_binaryop.cpp) # ################################################################################################## # * nvtext benchmark ------------------------------------------------------------------- diff --git a/cpp/benchmarks/binaryop/binaryop.cpp b/cpp/benchmarks/binaryop/binaryop.cpp index fa98d9e601a..7d267a88764 100644 --- a/cpp/benchmarks/binaryop/binaryop.cpp +++ b/cpp/benchmarks/binaryop/binaryop.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,15 +15,14 @@ */ #include -#include -#include #include #include #include +#include + #include -#include // This set of benchmarks is designed to be a comparison for the AST benchmarks @@ -33,13 +32,10 @@ enum class TreeType { }; template -class BINARYOP : public cudf::benchmark {}; - -template -static void BM_binaryop_transform(benchmark::State& state) +static void BM_binaryop_transform(nvbench::state& state) { - auto const table_size{static_cast(state.range(0))}; - auto const tree_levels{static_cast(state.range(1))}; + auto const table_size{static_cast(state.get_int64("table_size"))}; + auto const tree_levels{static_cast(state.get_int64("tree_levels"))}; // Create table data auto const n_cols = reuse_columns ? 1 : tree_levels + 1; @@ -47,9 +43,10 @@ static void BM_binaryop_transform(benchmark::State& state) cycle_dtypes({cudf::type_to_id()}, n_cols), row_count{table_size}); cudf::table_view table{*source_table}; - // Execute benchmark - for (auto _ : state) { - cuda_event_timer raii(state, true); // flush_l2_cache = true, stream = 0 + // Use the number of bytes read from global memory + state.add_global_memory_reads(table_size * (tree_levels + 1)); + + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { // Execute tree that chains additions like (((a + b) + c) + d) auto const op = cudf::binary_operator::ADD; auto const result_data_type = cudf::data_type(cudf::type_to_id()); @@ -64,16 +61,18 @@ static void BM_binaryop_transform(benchmark::State& state) result = cudf::binary_operation(result->view(), col, op, result_data_type); }); } - } - - // Use the number of bytes read from global memory - state.SetBytesProcessed(static_cast(state.iterations()) * state.range(0) * - (tree_levels + 1) * sizeof(key_type)); + }); } #define BINARYOP_TRANSFORM_BENCHMARK_DEFINE(name, key_type, tree_type, reuse_columns) \ - BENCHMARK_TEMPLATE_DEFINE_F(BINARYOP, name, key_type, tree_type, reuse_columns) \ - (::benchmark::State & st) { BM_binaryop_transform(st); } + \ + static void name(::nvbench::state& st) \ + { \ + BM_binaryop_transform(st); \ + } \ + NVBENCH_BENCH(name) \ + .add_int64_axis("tree_levels", {1, 2, 5, 10}) \ + .add_int64_axis("table_size", {100'000, 1'000'000, 10'000'000, 100'000'000}) BINARYOP_TRANSFORM_BENCHMARK_DEFINE(binaryop_int32_imbalanced_unique, int32_t, @@ -87,29 +86,3 @@ BINARYOP_TRANSFORM_BENCHMARK_DEFINE(binaryop_double_imbalanced_unique, double, TreeType::IMBALANCED_LEFT, false); - -static void CustomRanges(benchmark::internal::Benchmark* b) -{ - auto row_counts = std::vector{100'000, 1'000'000, 10'000'000, 100'000'000}; - auto operation_counts = std::vector{1, 2, 5, 10}; - for (auto const& row_count : row_counts) { - for (auto const& operation_count : operation_counts) { - b->Args({row_count, operation_count}); - } - } -} - -BENCHMARK_REGISTER_F(BINARYOP, binaryop_int32_imbalanced_unique) - ->Apply(CustomRanges) - ->Unit(benchmark::kMillisecond) - ->UseManualTime(); - -BENCHMARK_REGISTER_F(BINARYOP, binaryop_int32_imbalanced_reuse) - ->Apply(CustomRanges) - ->Unit(benchmark::kMillisecond) - ->UseManualTime(); - -BENCHMARK_REGISTER_F(BINARYOP, binaryop_double_imbalanced_unique) - ->Apply(CustomRanges) - ->Unit(benchmark::kMillisecond) - ->UseManualTime(); diff --git a/cpp/benchmarks/binaryop/compiled_binaryop.cpp b/cpp/benchmarks/binaryop/compiled_binaryop.cpp index 7086a61c7c5..bc0ff69bce9 100644 --- a/cpp/benchmarks/binaryop/compiled_binaryop.cpp +++ b/cpp/benchmarks/binaryop/compiled_binaryop.cpp @@ -15,20 +15,18 @@ */ #include -#include -#include #include -class COMPILED_BINARYOP : public cudf::benchmark {}; +#include template -void BM_compiled_binaryop(benchmark::State& state, cudf::binary_operator binop) +void BM_compiled_binaryop(nvbench::state& state, cudf::binary_operator binop) { - auto const column_size{static_cast(state.range(0))}; + auto const table_size = static_cast(state.get_int64("table_size")); auto const source_table = create_random_table( - {cudf::type_to_id(), cudf::type_to_id()}, row_count{column_size}); + {cudf::type_to_id(), cudf::type_to_id()}, row_count{table_size}); auto lhs = cudf::column_view(source_table->get_column(0)); auto rhs = cudf::column_view(source_table->get_column(1)); @@ -38,31 +36,26 @@ void BM_compiled_binaryop(benchmark::State& state, cudf::binary_operator binop) // Call once for hot cache. cudf::binary_operation(lhs, rhs, binop, output_dtype); - for (auto _ : state) { - cuda_event_timer timer(state, true); - cudf::binary_operation(lhs, rhs, binop, output_dtype); - } - // use number of bytes read and written to global memory - state.SetBytesProcessed(static_cast(state.iterations()) * column_size * - (sizeof(TypeLhs) + sizeof(TypeRhs) + sizeof(TypeOut))); + state.add_global_memory_reads(table_size); + state.add_global_memory_reads(table_size); + state.add_global_memory_reads(table_size); + + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch&) { cudf::binary_operation(lhs, rhs, binop, output_dtype); }); } +#define BM_STRINGIFY(a) #a + // TODO tparam boolean for null. -#define BM_BINARYOP_BENCHMARK_DEFINE(name, lhs, rhs, bop, tout) \ - BENCHMARK_DEFINE_F(COMPILED_BINARYOP, name) \ - (::benchmark::State & st) \ - { \ - BM_compiled_binaryop(st, cudf::binary_operator::bop); \ - } \ - BENCHMARK_REGISTER_F(COMPILED_BINARYOP, name) \ - ->Unit(benchmark::kMicrosecond) \ - ->UseManualTime() \ - ->Arg(10000) /* 10k */ \ - ->Arg(100000) /* 100k */ \ - ->Arg(1000000) /* 1M */ \ - ->Arg(10000000) /* 10M */ \ - ->Arg(100000000); /* 100M */ +#define BM_BINARYOP_BENCHMARK_DEFINE(name, lhs, rhs, bop, tout) \ + static void name(::nvbench::state& st) \ + { \ + ::BM_compiled_binaryop(st, ::cudf::binary_operator::bop); \ + } \ + NVBENCH_BENCH(name) \ + .set_name("compiled_binary_op_" BM_STRINGIFY(name)) \ + .add_int64_axis("table_size", {10'000, 100'000, 1'000'000, 10'000'000, 100'000'000}) #define build_name(a, b, c, d) a##_##b##_##c##_##d From 39342b8762c734aa2a94b94815bef75869a4e59c Mon Sep 17 00:00:00 2001 From: Vukasin Milovanovic Date: Fri, 4 Oct 2024 09:39:20 -0700 Subject: [PATCH 09/12] Properly handle the mapped and registered regions in `memory_mapped_source` (#16865) Depends on https://github.com/rapidsai/cudf/pull/16826 Set of fixes that improve robustness on the non-GDS file input: 1. Avoid registering beyond the byte range - addresses problems when reading adjacent byte ranges from multiple threads (GH only). 2. Allow reading data outside of the memory mapped region. This prevents issues with very long rows in CSV or JSON input. 3. Copy host data when the range being read is only partially registered. This avoids errors when trying to copy the host data range to the device (GH only). Modifies the datasource class hierarchy to avoid reuse of direct file `host_read`s Authors: - Vukasin Milovanovic (https://github.com/vuule) Approvers: - Basit Ayantunde (https://github.com/lamarrr) - Mads R. B. Kristensen (https://github.com/madsbk) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/16865 --- cpp/include/cudf/io/datasource.hpp | 22 +++- cpp/src/io/functions.cpp | 14 ++- cpp/src/io/utilities/datasource.cpp | 157 +++++++++++++++++----------- cpp/tests/io/csv_test.cpp | 35 +++++++ 4 files changed, 158 insertions(+), 70 deletions(-) diff --git a/cpp/include/cudf/io/datasource.hpp b/cpp/include/cudf/io/datasource.hpp index b12fbe39a57..dc14802adc1 100644 --- a/cpp/include/cudf/io/datasource.hpp +++ b/cpp/include/cudf/io/datasource.hpp @@ -86,14 +86,28 @@ class datasource { /** * @brief Creates a source from a file path. * + * @note Parameters `offset`, `max_size_estimate` and `min_size_estimate` are hints to the + * `datasource` implementation about the expected range of the data that will be read. The + * implementation may use these hints to optimize the read operation. These parameters are usually + * based on the byte range option. In this case, `min_size_estimate` should be no greater than the + * byte range to avoid potential issues when reading adjacent ranges. `max_size_estimate` can + * include padding after the byte range, to include additional data that may be needed for + * processing. + * + @throws cudf::logic_error if the minimum size estimate is greater than the maximum size estimate + * * @param[in] filepath Path to the file to use - * @param[in] offset Bytes from the start of the file (the default is zero) - * @param[in] size Bytes from the offset; use zero for entire file (the default is zero) + * @param[in] offset Starting byte offset from which data will be read (the default is zero) + * @param[in] max_size_estimate Upper estimate of the data range that will be read (the default is + * zero, which means the whole file after `offset`) + * @param[in] min_size_estimate Lower estimate of the data range that will be read (the default is + * zero, which means the whole file after `offset`) * @return Constructed datasource object */ static std::unique_ptr create(std::string const& filepath, - size_t offset = 0, - size_t size = 0); + size_t offset = 0, + size_t max_size_estimate = 0, + size_t min_size_estimate = 0); /** * @brief Creates a source from a host memory buffer. diff --git a/cpp/src/io/functions.cpp b/cpp/src/io/functions.cpp index de8eea9e99b..5a060902eb2 100644 --- a/cpp/src/io/functions.cpp +++ b/cpp/src/io/functions.cpp @@ -122,14 +122,16 @@ chunked_parquet_writer_options_builder chunked_parquet_writer_options::builder( namespace { std::vector> make_datasources(source_info const& info, - size_t range_offset = 0, - size_t range_size = 0) + size_t offset = 0, + size_t max_size_estimate = 0, + size_t min_size_estimate = 0) { switch (info.type()) { case io_type::FILEPATH: { auto sources = std::vector>(); for (auto const& filepath : info.filepaths()) { - sources.emplace_back(cudf::io::datasource::create(filepath, range_offset, range_size)); + sources.emplace_back( + cudf::io::datasource::create(filepath, offset, max_size_estimate, min_size_estimate)); } return sources; } @@ -211,7 +213,8 @@ table_with_metadata read_json(json_reader_options options, auto datasources = make_datasources(options.get_source(), options.get_byte_range_offset(), - options.get_byte_range_size_with_padding()); + options.get_byte_range_size_with_padding(), + options.get_byte_range_size()); return json::detail::read_json(datasources, options, stream, mr); } @@ -238,7 +241,8 @@ table_with_metadata read_csv(csv_reader_options options, auto datasources = make_datasources(options.get_source(), options.get_byte_range_offset(), - options.get_byte_range_size_with_padding()); + options.get_byte_range_size_with_padding(), + options.get_byte_range_size()); CUDF_EXPECTS(datasources.size() == 1, "Only a single source is currently supported."); diff --git a/cpp/src/io/utilities/datasource.cpp b/cpp/src/io/utilities/datasource.cpp index e4313eba454..0be976b6144 100644 --- a/cpp/src/io/utilities/datasource.cpp +++ b/cpp/src/io/utilities/datasource.cpp @@ -32,6 +32,7 @@ #include #include +#include namespace cudf { namespace io { @@ -54,6 +55,30 @@ class file_source : public datasource { } } + std::unique_ptr host_read(size_t offset, size_t size) override + { + lseek(_file.desc(), offset, SEEK_SET); + + // Clamp length to available data + ssize_t const read_size = std::min(size, _file.size() - offset); + + std::vector v(read_size); + CUDF_EXPECTS(read(_file.desc(), v.data(), read_size) == read_size, "read failed"); + return buffer::create(std::move(v)); + } + + size_t host_read(size_t offset, size_t size, uint8_t* dst) override + { + lseek(_file.desc(), offset, SEEK_SET); + + // Clamp length to available data + auto const read_size = std::min(size, _file.size() - offset); + + CUDF_EXPECTS(read(_file.desc(), dst, read_size) == static_cast(read_size), + "read failed"); + return read_size; + } + ~file_source() override = default; [[nodiscard]] bool supports_device_read() const override @@ -138,40 +163,63 @@ class file_source : public datasource { */ class memory_mapped_source : public file_source { public: - explicit memory_mapped_source(char const* filepath, size_t offset, size_t size) + explicit memory_mapped_source(char const* filepath, + size_t offset, + size_t max_size_estimate, + size_t min_size_estimate) : file_source(filepath) { if (_file.size() != 0) { - map(_file.desc(), offset, size); - register_mmap_buffer(); + // Memory mapping is not exclusive, so we can include the whole region we expect to read + map(_file.desc(), offset, max_size_estimate); + // Buffer registration is exclusive (can't overlap with other registered buffers) so we + // register the lower estimate; this avoids issues when reading adjacent ranges from the same + // file from multiple threads + register_mmap_buffer(offset, min_size_estimate); } } ~memory_mapped_source() override { if (_map_addr != nullptr) { - munmap(_map_addr, _map_size); + unmap(); unregister_mmap_buffer(); } } std::unique_ptr host_read(size_t offset, size_t size) override { - CUDF_EXPECTS(offset >= _map_offset, "Requested offset is outside mapping"); + // Clamp length to available data + auto const read_size = std::min(size, +_file.size() - offset); + + // If the requested range is outside of the mapped region, read from the file + if (offset < _map_offset or offset + read_size > (_map_offset + _map_size)) { + return file_source::host_read(offset, read_size); + } - // Clamp length to available data in the mapped region - auto const read_size = std::min(size, _map_size - (offset - _map_offset)); + // If the requested range is only partially within the registered region, copy to a new + // host buffer to make the data safe to copy to the device + if (_reg_addr != nullptr and + (offset < _reg_offset or offset + read_size > (_reg_offset + _reg_size))) { + auto const src = static_cast(_map_addr) + (offset - _map_offset); + + return std::make_unique>>( + std::vector(src, src + read_size)); + } return std::make_unique( - static_cast(_map_addr) + (offset - _map_offset), read_size); + static_cast(_map_addr) + offset - _map_offset, read_size); } size_t host_read(size_t offset, size_t size, uint8_t* dst) override { - CUDF_EXPECTS(offset >= _map_offset, "Requested offset is outside mapping"); + // Clamp length to available data + auto const read_size = std::min(size, +_file.size() - offset); - // Clamp length to available data in the mapped region - auto const read_size = std::min(size, _map_size - (offset - _map_offset)); + // If the requested range is outside of the mapped region, read from the file + if (offset < _map_offset or offset + read_size > (_map_offset + _map_size)) { + return file_source::host_read(offset, read_size, dst); + } auto const src = static_cast(_map_addr) + (offset - _map_offset); std::memcpy(dst, src, read_size); @@ -184,16 +232,18 @@ class memory_mapped_source : public file_source { * * Fixes nvbugs/4215160 */ - void register_mmap_buffer() + void register_mmap_buffer(size_t offset, size_t size) { - if (_map_addr == nullptr or _map_size == 0 or not pageableMemoryAccessUsesHostPageTables()) { - return; - } + if (_map_addr == nullptr or not pageableMemoryAccessUsesHostPageTables()) { return; } - auto const result = cudaHostRegister(_map_addr, _map_size, cudaHostRegisterDefault); - if (result == cudaSuccess) { - _is_map_registered = true; - } else { + // Registered region must be within the mapped region + _reg_offset = std::max(offset, _map_offset); + _reg_size = std::min(size != 0 ? size : _map_size, (_map_offset + _map_size) - _reg_offset); + + _reg_addr = static_cast(_map_addr) - _map_offset + _reg_offset; + auto const result = cudaHostRegister(_reg_addr, _reg_size, cudaHostRegisterReadOnly); + if (result != cudaSuccess) { + _reg_addr = nullptr; CUDF_LOG_WARN("cudaHostRegister failed with {} ({})", static_cast(result), cudaGetErrorString(result)); @@ -205,10 +255,12 @@ class memory_mapped_source : public file_source { */ void unregister_mmap_buffer() { - if (not _is_map_registered) { return; } + if (_reg_addr == nullptr) { return; } - auto const result = cudaHostUnregister(_map_addr); - if (result != cudaSuccess) { + auto const result = cudaHostUnregister(_reg_addr); + if (result == cudaSuccess) { + _reg_addr = nullptr; + } else { CUDF_LOG_WARN("cudaHostUnregister failed with {} ({})", static_cast(result), cudaGetErrorString(result)); @@ -226,52 +278,30 @@ class memory_mapped_source : public file_source { // Size for `mmap()` needs to include the page padding _map_size = size + (offset - _map_offset); + if (_map_size == 0) { return; } // Check if accessing a region within already mapped area _map_addr = mmap(nullptr, _map_size, PROT_READ, MAP_PRIVATE, fd, _map_offset); CUDF_EXPECTS(_map_addr != MAP_FAILED, "Cannot create memory mapping"); } - private: - size_t _map_size = 0; - size_t _map_offset = 0; - void* _map_addr = nullptr; - bool _is_map_registered = false; -}; - -/** - * @brief Implementation class for reading from a file using `read` calls - * - * Potentially faster than `memory_mapped_source` when only a small portion of the file is read - * through the host. - */ -class direct_read_source : public file_source { - public: - explicit direct_read_source(char const* filepath) : file_source(filepath) {} - - std::unique_ptr host_read(size_t offset, size_t size) override + void unmap() { - lseek(_file.desc(), offset, SEEK_SET); - - // Clamp length to available data - ssize_t const read_size = std::min(size, _file.size() - offset); - - std::vector v(read_size); - CUDF_EXPECTS(read(_file.desc(), v.data(), read_size) == read_size, "read failed"); - return buffer::create(std::move(v)); + if (_map_addr != nullptr) { + auto const result = munmap(_map_addr, _map_size); + if (result != 0) { CUDF_LOG_WARN("munmap failed with {}", result); } + _map_addr = nullptr; + } } - size_t host_read(size_t offset, size_t size, uint8_t* dst) override - { - lseek(_file.desc(), offset, SEEK_SET); - - // Clamp length to available data - auto const read_size = std::min(size, _file.size() - offset); + private: + size_t _map_offset = 0; + size_t _map_size = 0; + void* _map_addr = nullptr; - CUDF_EXPECTS(read(_file.desc(), dst, read_size) == static_cast(read_size), - "read failed"); - return read_size; - } + size_t _reg_offset = 0; + size_t _reg_size = 0; + void* _reg_addr = nullptr; }; /** @@ -431,16 +461,21 @@ class user_datasource_wrapper : public datasource { std::unique_ptr datasource::create(std::string const& filepath, size_t offset, - size_t size) + size_t max_size_estimate, + size_t min_size_estimate) { + CUDF_EXPECTS(max_size_estimate == 0 or min_size_estimate <= max_size_estimate, + "Invalid min/max size estimates for datasource creation"); + #ifdef CUFILE_FOUND if (cufile_integration::is_always_enabled()) { // avoid mmap as GDS is expected to be used for most reads - return std::make_unique(filepath.c_str()); + return std::make_unique(filepath.c_str()); } #endif // Use our own memory mapping implementation for direct file reads - return std::make_unique(filepath.c_str(), offset, size); + return std::make_unique( + filepath.c_str(), offset, max_size_estimate, min_size_estimate); } std::unique_ptr datasource::create(host_buffer const& buffer) diff --git a/cpp/tests/io/csv_test.cpp b/cpp/tests/io/csv_test.cpp index dc14824d834..0028dd946e3 100644 --- a/cpp/tests/io/csv_test.cpp +++ b/cpp/tests/io/csv_test.cpp @@ -2516,4 +2516,39 @@ TEST_F(CsvReaderTest, UTF8BOM) CUDF_TEST_EXPECT_TABLES_EQUIVALENT(result_view, expected); } +void expect_buffers_equal(cudf::io::datasource::buffer* lhs, cudf::io::datasource::buffer* rhs) +{ + ASSERT_EQ(lhs->size(), rhs->size()); + EXPECT_EQ(0, std::memcmp(lhs->data(), rhs->data(), lhs->size())); +} + +TEST_F(CsvReaderTest, OutOfMapBoundsReads) +{ + // write a lot of data into a file + auto filepath = temp_env->get_temp_dir() + "OutOfMapBoundsReads.csv"; + auto const num_rows = 1 << 20; + auto const row = std::string{"0,1,2,3,4,5,6,7,8,9\n"}; + auto const file_size = num_rows * row.size(); + { + std::ofstream outfile(filepath, std::ofstream::out); + for (size_t i = 0; i < num_rows; ++i) { + outfile << row; + } + } + + // Only memory map the middle of the file + auto source = cudf::io::datasource::create(filepath, file_size / 2, file_size / 4); + auto full_source = cudf::io::datasource::create(filepath); + auto const all_data = source->host_read(0, file_size); + auto ref_data = full_source->host_read(0, file_size); + expect_buffers_equal(ref_data.get(), all_data.get()); + + auto const start_data = source->host_read(file_size / 2, file_size / 2); + expect_buffers_equal(full_source->host_read(file_size / 2, file_size / 2).get(), + start_data.get()); + + auto const end_data = source->host_read(0, file_size / 2 + 512); + expect_buffers_equal(full_source->host_read(0, file_size / 2 + 512).get(), end_data.get()); +} + CUDF_TEST_PROGRAM_MAIN() From d15bbfdded7181fdc23d33fa5efae181b4af2e2b Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 4 Oct 2024 07:45:54 -1000 Subject: [PATCH 10/12] Allow melt(var_name=) to be a falsy label (#16981) closes https://github.com/rapidsai/cudf/issues/16972 Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/16981 --- python/cudf/cudf/core/reshape.py | 2 +- python/cudf/cudf/tests/test_reshape.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/reshape.py b/python/cudf/cudf/core/reshape.py index 6e5abb2b82b..3d132c92d54 100644 --- a/python/cudf/cudf/core/reshape.py +++ b/python/cudf/cudf/core/reshape.py @@ -681,7 +681,7 @@ def _tile(A, reps): nval = len(value_vars) dtype = min_unsigned_type(nval) - if not var_name: + if var_name is None: var_name = "variable" if not value_vars: diff --git a/python/cudf/cudf/tests/test_reshape.py b/python/cudf/cudf/tests/test_reshape.py index 4235affd4d1..3adbe1d2a74 100644 --- a/python/cudf/cudf/tests/test_reshape.py +++ b/python/cudf/cudf/tests/test_reshape.py @@ -119,6 +119,15 @@ def test_melt_str_scalar_id_var(): assert_eq(result, expected) +def test_melt_falsy_var_name(): + df = cudf.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5], "C": [2, 4, 6]}) + result = cudf.melt(df, id_vars=["A"], value_vars=["B"], var_name="") + expected = pd.melt( + df.to_pandas(), id_vars=["A"], value_vars=["B"], var_name="" + ) + assert_eq(result, expected) + + @pytest.mark.parametrize("num_cols", [1, 2, 10]) @pytest.mark.parametrize("num_rows", [1, 2, 1000]) @pytest.mark.parametrize( From 04c17ded6563f4caaeeb51319672c10587401e33 Mon Sep 17 00:00:00 2001 From: Matthew Murray <41342305+Matt711@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:06:23 -0400 Subject: [PATCH 11/12] [FEA] Migrate nvtext/edit_distance APIs to pylibcudf (#16957) Apart of #15162. This PR migrates `edit_distance.pxd` to pylibcudf Authors: - Matthew Murray (https://github.com/Matt711) Approvers: - Matthew Roeschke (https://github.com/mroeschke) - Yunsong Wang (https://github.com/PointKernel) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/16957 --- cpp/include/nvtext/edit_distance.hpp | 2 +- .../user_guide/api_docs/pylibcudf/index.rst | 1 + .../pylibcudf/nvtext/edit_distance.rst | 6 ++ .../api_docs/pylibcudf/nvtext/index.rst | 7 +++ .../cudf/cudf/_lib/nvtext/edit_distance.pyx | 34 +++------- python/pylibcudf/pylibcudf/CMakeLists.txt | 1 + python/pylibcudf/pylibcudf/__init__.pxd | 2 + python/pylibcudf/pylibcudf/__init__.py | 2 + .../pylibcudf/pylibcudf/nvtext/CMakeLists.txt | 22 +++++++ .../pylibcudf/pylibcudf/nvtext/__init__.pxd | 7 +++ python/pylibcudf/pylibcudf/nvtext/__init__.py | 7 +++ .../pylibcudf/nvtext/edit_distance.pxd | 8 +++ .../pylibcudf/nvtext/edit_distance.pyx | 63 +++++++++++++++++++ .../tests/test_nvtext_edit_distance.py | 34 ++++++++++ 14 files changed, 171 insertions(+), 25 deletions(-) create mode 100644 docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/edit_distance.rst create mode 100644 docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/index.rst create mode 100644 python/pylibcudf/pylibcudf/nvtext/CMakeLists.txt create mode 100644 python/pylibcudf/pylibcudf/nvtext/__init__.pxd create mode 100644 python/pylibcudf/pylibcudf/nvtext/__init__.py create mode 100644 python/pylibcudf/pylibcudf/nvtext/edit_distance.pxd create mode 100644 python/pylibcudf/pylibcudf/nvtext/edit_distance.pyx create mode 100644 python/pylibcudf/pylibcudf/tests/test_nvtext_edit_distance.py diff --git a/cpp/include/nvtext/edit_distance.hpp b/cpp/include/nvtext/edit_distance.hpp index 723ba310a1e..dca590baebf 100644 --- a/cpp/include/nvtext/edit_distance.hpp +++ b/cpp/include/nvtext/edit_distance.hpp @@ -57,7 +57,7 @@ namespace CUDF_EXPORT nvtext { * @param targets Strings to compute edit distance against `input` * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the returned column's device memory - * @return New strings columns of with replaced strings + * @return New lists column of edit distance values */ std::unique_ptr edit_distance( cudf::strings_column_view const& input, diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst index e21536e2e97..052479d6720 100644 --- a/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst @@ -49,3 +49,4 @@ This page provides API documentation for pylibcudf. io/index.rst strings/index.rst + nvtext/index.rst diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/edit_distance.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/edit_distance.rst new file mode 100644 index 00000000000..abb45e426a8 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/edit_distance.rst @@ -0,0 +1,6 @@ +============= +edit_distance +============= + +.. automodule:: pylibcudf.nvtext.edit_distance + :members: diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/index.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/index.rst new file mode 100644 index 00000000000..b5cd5ee42c3 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/nvtext/index.rst @@ -0,0 +1,7 @@ +nvtext +====== + +.. toctree:: + :maxdepth: 1 + + edit_distance diff --git a/python/cudf/cudf/_lib/nvtext/edit_distance.pyx b/python/cudf/cudf/_lib/nvtext/edit_distance.pyx index e3c2273345a..3dd99c42d76 100644 --- a/python/cudf/cudf/_lib/nvtext/edit_distance.pyx +++ b/python/cudf/cudf/_lib/nvtext/edit_distance.pyx @@ -2,37 +2,23 @@ from cudf.core.buffer import acquire_spill_lock -from libcpp.memory cimport unique_ptr -from libcpp.utility cimport move - -from pylibcudf.libcudf.column.column cimport column -from pylibcudf.libcudf.column.column_view cimport column_view -from pylibcudf.libcudf.nvtext.edit_distance cimport ( - edit_distance as cpp_edit_distance, - edit_distance_matrix as cpp_edit_distance_matrix, -) +from pylibcudf cimport nvtext from cudf._lib.column cimport Column @acquire_spill_lock() def edit_distance(Column strings, Column targets): - cdef column_view c_strings = strings.view() - cdef column_view c_targets = targets.view() - cdef unique_ptr[column] c_result - - with nogil: - c_result = move(cpp_edit_distance(c_strings, c_targets)) - - return Column.from_unique_ptr(move(c_result)) + result = nvtext.edit_distance.edit_distance( + strings.to_pylibcudf(mode="read"), + targets.to_pylibcudf(mode="read") + ) + return Column.from_pylibcudf(result) @acquire_spill_lock() def edit_distance_matrix(Column strings): - cdef column_view c_strings = strings.view() - cdef unique_ptr[column] c_result - - with nogil: - c_result = move(cpp_edit_distance_matrix(c_strings)) - - return Column.from_unique_ptr(move(c_result)) + result = nvtext.edit_distance.edit_distance_matrix( + strings.to_pylibcudf(mode="read") + ) + return Column.from_pylibcudf(result) diff --git a/python/pylibcudf/pylibcudf/CMakeLists.txt b/python/pylibcudf/pylibcudf/CMakeLists.txt index a7cb66d7b16..1d72eacac12 100644 --- a/python/pylibcudf/pylibcudf/CMakeLists.txt +++ b/python/pylibcudf/pylibcudf/CMakeLists.txt @@ -66,3 +66,4 @@ target_link_libraries(pylibcudf_interop PUBLIC nanoarrow) add_subdirectory(libcudf) add_subdirectory(strings) add_subdirectory(io) +add_subdirectory(nvtext) diff --git a/python/pylibcudf/pylibcudf/__init__.pxd b/python/pylibcudf/pylibcudf/__init__.pxd index a384edd456d..b98b37fe0fd 100644 --- a/python/pylibcudf/pylibcudf/__init__.pxd +++ b/python/pylibcudf/pylibcudf/__init__.pxd @@ -17,6 +17,7 @@ from . cimport ( lists, merge, null_mask, + nvtext, partitioning, quantiles, reduce, @@ -78,4 +79,5 @@ __all__ = [ "transpose", "types", "unary", + "nvtext", ] diff --git a/python/pylibcudf/pylibcudf/__init__.py b/python/pylibcudf/pylibcudf/__init__.py index 2a5365e8fad..304f27be340 100644 --- a/python/pylibcudf/pylibcudf/__init__.py +++ b/python/pylibcudf/pylibcudf/__init__.py @@ -28,6 +28,7 @@ lists, merge, null_mask, + nvtext, partitioning, quantiles, reduce, @@ -92,4 +93,5 @@ "transpose", "types", "unary", + "nvtext", ] diff --git a/python/pylibcudf/pylibcudf/nvtext/CMakeLists.txt b/python/pylibcudf/pylibcudf/nvtext/CMakeLists.txt new file mode 100644 index 00000000000..ebe1fda1f12 --- /dev/null +++ b/python/pylibcudf/pylibcudf/nvtext/CMakeLists.txt @@ -0,0 +1,22 @@ +# ============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +set(cython_sources edit_distance.pyx) + +set(linked_libraries cudf::cudf) +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX pylibcudf_nvtext_ ASSOCIATED_TARGETS cudf +) diff --git a/python/pylibcudf/pylibcudf/nvtext/__init__.pxd b/python/pylibcudf/pylibcudf/nvtext/__init__.pxd new file mode 100644 index 00000000000..82f7c425b1d --- /dev/null +++ b/python/pylibcudf/pylibcudf/nvtext/__init__.pxd @@ -0,0 +1,7 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from . cimport edit_distance + +__all__ = [ + "edit_distance", +] diff --git a/python/pylibcudf/pylibcudf/nvtext/__init__.py b/python/pylibcudf/pylibcudf/nvtext/__init__.py new file mode 100644 index 00000000000..986652a241f --- /dev/null +++ b/python/pylibcudf/pylibcudf/nvtext/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from . import edit_distance + +__all__ = [ + "edit_distance", +] diff --git a/python/pylibcudf/pylibcudf/nvtext/edit_distance.pxd b/python/pylibcudf/pylibcudf/nvtext/edit_distance.pxd new file mode 100644 index 00000000000..446b95afabb --- /dev/null +++ b/python/pylibcudf/pylibcudf/nvtext/edit_distance.pxd @@ -0,0 +1,8 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from pylibcudf.column cimport Column + + +cpdef Column edit_distance(Column input, Column targets) + +cpdef Column edit_distance_matrix(Column input) diff --git a/python/pylibcudf/pylibcudf/nvtext/edit_distance.pyx b/python/pylibcudf/pylibcudf/nvtext/edit_distance.pyx new file mode 100644 index 00000000000..fc98ccbc50c --- /dev/null +++ b/python/pylibcudf/pylibcudf/nvtext/edit_distance.pyx @@ -0,0 +1,63 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.column.column_view cimport column_view +from pylibcudf.libcudf.nvtext.edit_distance cimport ( + edit_distance as cpp_edit_distance, + edit_distance_matrix as cpp_edit_distance_matrix, +) + + +cpdef Column edit_distance(Column input, Column targets): + """ + Returns the edit distance between individual strings in two strings columns + + For details, see :cpp:func:`edit_distance` + + Parameters + ---------- + input : Column + Input strings + targets : Column + Strings to compute edit distance against + + Returns + ------- + Column + New column of edit distance values + """ + cdef column_view c_strings = input.view() + cdef column_view c_targets = targets.view() + cdef unique_ptr[column] c_result + + with nogil: + c_result = move(cpp_edit_distance(c_strings, c_targets)) + + return Column.from_libcudf(move(c_result)) + + +cpdef Column edit_distance_matrix(Column input): + """ + Returns the edit distance between all strings in the input strings column + + For details, see :cpp:func:`edit_distance_matrix` + + Parameters + ---------- + input : Column + Input strings + + Returns + ------- + Column + New column of edit distance values + """ + cdef column_view c_strings = input.view() + cdef unique_ptr[column] c_result + + with nogil: + c_result = move(cpp_edit_distance_matrix(c_strings)) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/tests/test_nvtext_edit_distance.py b/python/pylibcudf/pylibcudf/tests/test_nvtext_edit_distance.py new file mode 100644 index 00000000000..7d93c471cc4 --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_nvtext_edit_distance.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +import pyarrow as pa +import pylibcudf as plc +import pytest +from utils import assert_column_eq + + +@pytest.fixture(scope="module") +def edit_distance_data(): + arr1 = ["hallo", "goodbye", "world"] + arr2 = ["hello", "", "world"] + return pa.array(arr1), pa.array(arr2) + + +def test_edit_distance(edit_distance_data): + input_col, targets = edit_distance_data + result = plc.nvtext.edit_distance.edit_distance( + plc.interop.from_arrow(input_col), + plc.interop.from_arrow(targets), + ) + expected = pa.array([1, 7, 0], type=pa.int32()) + assert_column_eq(result, expected) + + +def test_edit_distance_matrix(edit_distance_data): + input_col, _ = edit_distance_data + result = plc.nvtext.edit_distance.edit_distance_matrix( + plc.interop.from_arrow(input_col) + ) + expected = pa.array( + [[0, 7, 4], [7, 0, 6], [4, 6, 0]], type=pa.list_(pa.int32()) + ) + assert_column_eq(expected, result) From efaa0b50c6ffd15c6506847987cb531e5f6ba955 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 4 Oct 2024 08:20:34 -1000 Subject: [PATCH 12/12] Add string.convert.convert_datetime/convert_booleans APIs to pylibcudf (#16971) Contributes to https://github.com/rapidsai/cudf/issues/15162 Also address a review in https://github.com/rapidsai/cudf/pull/16935#discussion_r1783726477 This also modifies some `format` arguments in `convert_datetime.pyx` to accept `str` instead of `bytes` (`const string&`) to align more with Python. Let me know if you prefer to change this back Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/16971 --- python/cudf/cudf/_lib/string_casting.pyx | 110 +++--------------- python/cudf_polars/cudf_polars/dsl/expr.py | 4 +- .../strings/convert/convert_booleans.pxd | 4 +- .../strings/convert/convert_datetime.pxd | 6 +- .../pylibcudf/strings/convert/CMakeLists.txt | 2 +- .../pylibcudf/strings/convert/__init__.pxd | 2 +- .../pylibcudf/strings/convert/__init__.py | 2 +- .../strings/convert/convert_booleans.pxd | 9 ++ .../strings/convert/convert_booleans.pyx | 91 +++++++++++++++ .../strings/convert/convert_datetime.pxd | 11 +- .../strings/convert/convert_datetime.pyx | 82 +++++++++++-- .../pylibcudf/tests/test_string_convert.py | 2 +- .../tests/test_string_convert_booleans.py | 26 +++++ .../tests/test_string_convert_datetime.py | 46 ++++++++ .../pylibcudf/tests/test_string_wrap.py | 5 +- 15 files changed, 286 insertions(+), 116 deletions(-) create mode 100644 python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pxd create mode 100644 python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pyx create mode 100644 python/pylibcudf/pylibcudf/tests/test_string_convert_booleans.py create mode 100644 python/pylibcudf/pylibcudf/tests/test_string_convert_datetime.py diff --git a/python/cudf/cudf/_lib/string_casting.pyx b/python/cudf/cudf/_lib/string_casting.pyx index 60a6795a402..55ff38f472d 100644 --- a/python/cudf/cudf/_lib/string_casting.pyx +++ b/python/cudf/cudf/_lib/string_casting.pyx @@ -3,9 +3,6 @@ from cudf._lib.column cimport Column from cudf._lib.scalar import as_device_scalar - -from cudf._lib.scalar cimport DeviceScalar - from cudf._lib.types import SUPPORTED_NUMPY_TO_LIBCUDF_TYPES from libcpp.memory cimport unique_ptr @@ -14,14 +11,6 @@ from libcpp.utility cimport move from pylibcudf.libcudf.column.column cimport column from pylibcudf.libcudf.column.column_view cimport column_view -from pylibcudf.libcudf.scalar.scalar cimport string_scalar -from pylibcudf.libcudf.strings.convert.convert_booleans cimport ( - from_booleans as cpp_from_booleans, - to_booleans as cpp_to_booleans, -) -from pylibcudf.libcudf.strings.convert.convert_datetime cimport ( - is_timestamp as cpp_is_timestamp, -) from pylibcudf.libcudf.strings.convert.convert_floats cimport ( from_floats as cpp_from_floats, to_floats as cpp_to_floats, @@ -427,77 +416,21 @@ def stoul(Column input_col): return string_to_integer(input_col, cudf.dtype("uint64")) -def _to_booleans(Column input_col, object string_true="True"): - """ - Converting/Casting input column of type string to boolean column - - Parameters - ---------- - input_col : input column of type string - string_true : string that represents True - - Returns - ------- - A Column with string values cast to boolean - """ - - cdef DeviceScalar str_true = as_device_scalar(string_true) - cdef column_view input_column_view = input_col.view() - cdef const string_scalar* string_scalar_true = ( - str_true.get_raw_ptr()) - cdef unique_ptr[column] c_result - with nogil: - c_result = move( - cpp_to_booleans( - input_column_view, - string_scalar_true[0])) - - return Column.from_unique_ptr(move(c_result)) - - def to_booleans(Column input_col): - - return _to_booleans(input_col) - - -def _from_booleans( - Column input_col, - object string_true="True", - object string_false="False"): - """ - Converting/Casting input column of type boolean to string column - - Parameters - ---------- - input_col : input column of type boolean - string_true : string that represents True - string_false : string that represents False - - Returns - ------- - A Column with boolean values cast to string - """ - - cdef DeviceScalar str_true = as_device_scalar(string_true) - cdef DeviceScalar str_false = as_device_scalar(string_false) - cdef column_view input_column_view = input_col.view() - cdef const string_scalar* string_scalar_true = ( - str_true.get_raw_ptr()) - cdef const string_scalar* string_scalar_false = ( - str_false.get_raw_ptr()) - cdef unique_ptr[column] c_result - with nogil: - c_result = move( - cpp_from_booleans( - input_column_view, - string_scalar_true[0], - string_scalar_false[0])) - - return Column.from_unique_ptr(move(c_result)) + plc_column = plc.strings.convert.convert_booleans.to_booleans( + input_col.to_pylibcudf(mode="read"), + as_device_scalar("True").c_value, + ) + return Column.from_pylibcudf(plc_column) def from_booleans(Column input_col): - return _from_booleans(input_col) + plc_column = plc.strings.convert.convert_booleans.from_booleans( + input_col.to_pylibcudf(mode="read"), + as_device_scalar("True").c_value, + as_device_scalar("False").c_value, + ) + return Column.from_pylibcudf(plc_column) def int2timestamp( @@ -520,11 +453,10 @@ def int2timestamp( A Column with date-time represented in string format """ - cdef string c_timestamp_format = format.encode("UTF-8") return Column.from_pylibcudf( plc.strings.convert.convert_datetime.from_timestamps( input_col.to_pylibcudf(mode="read"), - c_timestamp_format, + format, names.to_pylibcudf(mode="read") ) ) @@ -545,12 +477,11 @@ def timestamp2int(Column input_col, dtype, format): """ dtype = dtype_to_pylibcudf_type(dtype) - cdef string c_timestamp_format = format.encode('UTF-8') return Column.from_pylibcudf( plc.strings.convert.convert_datetime.to_timestamps( input_col.to_pylibcudf(mode="read"), dtype, - c_timestamp_format + format ) ) @@ -572,16 +503,11 @@ def istimestamp(Column input_col, str format): """ if input_col.size == 0: return cudf.core.column.column_empty(0, dtype=cudf.dtype("bool")) - cdef column_view input_column_view = input_col.view() - cdef string c_timestamp_format = str(format).encode('UTF-8') - cdef unique_ptr[column] c_result - with nogil: - c_result = move( - cpp_is_timestamp( - input_column_view, - c_timestamp_format)) - - return Column.from_unique_ptr(move(c_result)) + plc_column = plc.strings.convert.convert_datetime.is_timestamp( + input_col.to_pylibcudf(mode="read"), + format + ) + return Column.from_pylibcudf(plc_column) def timedelta2int(Column input_col, dtype, format): diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index c401e5a2f17..54476b7fedc 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -914,7 +914,7 @@ def do_evaluate( col = self.children[0].evaluate(df, context=context, mapping=mapping) is_timestamps = plc.strings.convert.convert_datetime.is_timestamp( - col.obj, format.encode() + col.obj, format ) if strict: @@ -937,7 +937,7 @@ def do_evaluate( ) return Column( plc.strings.convert.convert_datetime.to_timestamps( - res.columns()[0], self.dtype, format.encode() + res.columns()[0], self.dtype, format ) ) elif self.name == pl_expr.StringFunction.Replace: diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_booleans.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_booleans.pxd index 83a9573baad..e6688cfff81 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_booleans.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_booleans.pxd @@ -8,10 +8,10 @@ from pylibcudf.libcudf.scalar.scalar cimport string_scalar cdef extern from "cudf/strings/convert/convert_booleans.hpp" namespace \ "cudf::strings" nogil: cdef unique_ptr[column] to_booleans( - column_view input_col, + column_view input, string_scalar true_string) except + cdef unique_ptr[column] from_booleans( - column_view input_col, + column_view booleans, string_scalar true_string, string_scalar false_string) except + diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_datetime.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_datetime.pxd index fa8975c4df9..fceddd58df0 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_datetime.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/convert/convert_datetime.pxd @@ -10,14 +10,14 @@ from pylibcudf.libcudf.types cimport data_type cdef extern from "cudf/strings/convert/convert_datetime.hpp" namespace \ "cudf::strings" nogil: cdef unique_ptr[column] to_timestamps( - column_view input_col, + column_view input, data_type timestamp_type, string format) except + cdef unique_ptr[column] from_timestamps( - column_view input_col, + column_view timestamps, string format, - column_view input_strings_names) except + + column_view names) except + cdef unique_ptr[column] is_timestamp( column_view input_col, diff --git a/python/pylibcudf/pylibcudf/strings/convert/CMakeLists.txt b/python/pylibcudf/pylibcudf/strings/convert/CMakeLists.txt index 175c9b3738e..3febc78dfd2 100644 --- a/python/pylibcudf/pylibcudf/strings/convert/CMakeLists.txt +++ b/python/pylibcudf/pylibcudf/strings/convert/CMakeLists.txt @@ -12,7 +12,7 @@ # the License. # ============================================================================= -set(cython_sources convert_durations.pyx convert_datetime.pyx) +set(cython_sources convert_booleans.pyx convert_durations.pyx convert_datetime.pyx) set(linked_libraries cudf::cudf) rapids_cython_create_modules( diff --git a/python/pylibcudf/pylibcudf/strings/convert/__init__.pxd b/python/pylibcudf/pylibcudf/strings/convert/__init__.pxd index 05324cb49df..5525bca46d6 100644 --- a/python/pylibcudf/pylibcudf/strings/convert/__init__.pxd +++ b/python/pylibcudf/pylibcudf/strings/convert/__init__.pxd @@ -1,2 +1,2 @@ # Copyright (c) 2024, NVIDIA CORPORATION. -from . cimport convert_datetime, convert_durations +from . cimport convert_booleans, convert_datetime, convert_durations diff --git a/python/pylibcudf/pylibcudf/strings/convert/__init__.py b/python/pylibcudf/pylibcudf/strings/convert/__init__.py index d803399d53c..2340ebe9a26 100644 --- a/python/pylibcudf/pylibcudf/strings/convert/__init__.py +++ b/python/pylibcudf/pylibcudf/strings/convert/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2024, NVIDIA CORPORATION. -from . import convert_datetime, convert_durations +from . import convert_booleans, convert_datetime, convert_durations diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pxd b/python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pxd new file mode 100644 index 00000000000..312ac3c0ca0 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pxd @@ -0,0 +1,9 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from pylibcudf.column cimport Column +from pylibcudf.scalar cimport Scalar + + +cpdef Column to_booleans(Column input, Scalar true_string) + +cpdef Column from_booleans(Column booleans, Scalar true_string, Scalar false_string) diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pyx b/python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pyx new file mode 100644 index 00000000000..0c10f821ab6 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_booleans.pyx @@ -0,0 +1,91 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move +from pylibcudf.column cimport Column +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.scalar.scalar cimport string_scalar +from pylibcudf.libcudf.strings.convert cimport ( + convert_booleans as cpp_convert_booleans, +) +from pylibcudf.scalar cimport Scalar + +from cython.operator import dereference + + +cpdef Column to_booleans(Column input, Scalar true_string): + """ + Returns a new bool column by parsing boolean values from the strings + in the provided strings column. + + For details, see :cpp:func:`cudf::strings::to_booleans`. + + Parameters + ---------- + input : Column + Strings instance for this operation + + true_string : Scalar + String to expect for true. Non-matching strings are false + + Returns + ------- + Column + New bool column converted from strings. + """ + cdef unique_ptr[column] c_result + cdef const string_scalar* c_true_string = ( + true_string.c_obj.get() + ) + + with nogil: + c_result = move( + cpp_convert_booleans.to_booleans( + input.view(), + dereference(c_true_string) + ) + ) + + return Column.from_libcudf(move(c_result)) + +cpdef Column from_booleans(Column booleans, Scalar true_string, Scalar false_string): + """ + Returns a new strings column converting the boolean values from the + provided column into strings. + + For details, see :cpp:func:`cudf::strings::from_booleans`. + + Parameters + ---------- + booleans : Column + Boolean column to convert. + + true_string : Scalar + String to use for true in the output column. + + false_string : Scalar + String to use for false in the output column. + + Returns + ------- + Column + New strings column. + """ + cdef unique_ptr[column] c_result + cdef const string_scalar* c_true_string = ( + true_string.c_obj.get() + ) + cdef const string_scalar* c_false_string = ( + false_string.c_obj.get() + ) + + with nogil: + c_result = move( + cpp_convert_booleans.from_booleans( + booleans.view(), + dereference(c_true_string), + dereference(c_false_string), + ) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pxd b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pxd index 07c84d263d6..80ec168644b 100644 --- a/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pxd +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pxd @@ -8,11 +8,16 @@ from pylibcudf.types cimport DataType cpdef Column to_timestamps( Column input, DataType timestamp_type, - const string& format + str format ) cpdef Column from_timestamps( - Column input, - const string& format, + Column timestamps, + str format, Column input_strings_names ) + +cpdef Column is_timestamp( + Column input, + str format, +) diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pyx b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pyx index fcacb096f87..0ee60812e00 100644 --- a/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pyx +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pyx @@ -15,28 +15,74 @@ from pylibcudf.types import DataType cpdef Column to_timestamps( Column input, DataType timestamp_type, - const string& format + str format ): + """ + Returns a new timestamp column converting a strings column into + timestamps using the provided format pattern. + + For details, see cpp:`cudf::strings::to_timestamps`. + + Parameters + ---------- + input : Column + Strings instance for this operation. + + timestamp_type : DataType + The timestamp type used for creating the output column. + + format : str + String specifying the timestamp format in strings. + + Returns + ------- + Column + New datetime column + """ cdef unique_ptr[column] c_result + cdef string c_format = format.encode() with nogil: c_result = cpp_convert_datetime.to_timestamps( input.view(), timestamp_type.c_obj, - format + c_format ) return Column.from_libcudf(move(c_result)) cpdef Column from_timestamps( - Column input, - const string& format, + Column timestamps, + str format, Column input_strings_names ): + """ + Returns a new strings column converting a timestamp column into + strings using the provided format pattern. + + For details, see cpp:`cudf::strings::from_timestamps`. + + Parameters + ---------- + timestamps : Column + Timestamp values to convert + + format : str + The string specifying output format. + + input_strings_names : Column + The string names to use for weekdays ("%a", "%A") and months ("%b", "%B"). + + Returns + ------- + Column + New strings column with formatted timestamps. + """ cdef unique_ptr[column] c_result + cdef string c_format = format.encode() with nogil: c_result = cpp_convert_datetime.from_timestamps( - input.view(), - format, + timestamps.view(), + c_format, input_strings_names.view() ) @@ -44,13 +90,33 @@ cpdef Column from_timestamps( cpdef Column is_timestamp( Column input, - const string& format + str format ): + """ + Verifies the given strings column can be parsed to timestamps + using the provided format pattern. + + For details, see cpp:`cudf::strings::is_timestamp`. + + Parameters + ---------- + input : Column + Strings instance for this operation. + + format : str + String specifying the timestamp format in strings. + + Returns + ------- + Column + New bool column. + """ cdef unique_ptr[column] c_result + cdef string c_format = format.encode() with nogil: c_result = cpp_convert_datetime.is_timestamp( input.view(), - format + c_format ) return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_convert.py b/python/pylibcudf/pylibcudf/tests/test_string_convert.py index e9e95459d0e..22bb4971cb1 100644 --- a/python/pylibcudf/pylibcudf/tests/test_string_convert.py +++ b/python/pylibcudf/pylibcudf/tests/test_string_convert.py @@ -62,7 +62,7 @@ def test_to_datetime( got = plc.strings.convert.convert_datetime.to_timestamps( plc_timestamp_col, plc.interop.from_arrow(timestamp_type), - format.encode(), + format, ) assert_column_eq(expect, got) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_convert_booleans.py b/python/pylibcudf/pylibcudf/tests/test_string_convert_booleans.py new file mode 100644 index 00000000000..117c59ff1b8 --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_string_convert_booleans.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +import pyarrow as pa +import pylibcudf as plc +from utils import assert_column_eq + + +def test_to_booleans(): + pa_array = pa.array(["true", None, "True"]) + result = plc.strings.convert.convert_booleans.to_booleans( + plc.interop.from_arrow(pa_array), + plc.interop.from_arrow(pa.scalar("True")), + ) + expected = pa.array([False, None, True]) + assert_column_eq(result, expected) + + +def test_from_booleans(): + pa_array = pa.array([True, None, False]) + result = plc.strings.convert.convert_booleans.from_booleans( + plc.interop.from_arrow(pa_array), + plc.interop.from_arrow(pa.scalar("A")), + plc.interop.from_arrow(pa.scalar("B")), + ) + expected = pa.array(["A", None, "B"]) + assert_column_eq(result, expected) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_convert_datetime.py b/python/pylibcudf/pylibcudf/tests/test_string_convert_datetime.py new file mode 100644 index 00000000000..f3e84286a36 --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_string_convert_datetime.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +import datetime + +import pyarrow as pa +import pyarrow.compute as pc +import pylibcudf as plc +import pytest +from utils import assert_column_eq + + +@pytest.fixture +def fmt(): + return "%Y-%m-%dT%H:%M:%S" + + +def test_to_timestamp(fmt): + arr = pa.array(["2020-01-01T01:01:01", None]) + result = plc.strings.convert.convert_datetime.to_timestamps( + plc.interop.from_arrow(arr), + plc.DataType(plc.TypeId.TIMESTAMP_SECONDS), + fmt, + ) + expected = pc.strptime(arr, fmt, "s") + assert_column_eq(result, expected) + + +def test_from_timestamp(fmt): + arr = pa.array([datetime.datetime(2020, 1, 1, 1, 1, 1), None]) + result = plc.strings.convert.convert_datetime.from_timestamps( + plc.interop.from_arrow(arr), + fmt, + plc.interop.from_arrow(pa.array([], type=pa.string())), + ) + # pc.strftime will add the extra %f + expected = pa.array(["2020-01-01T01:01:01", None]) + assert_column_eq(result, expected) + + +def test_is_timestamp(fmt): + arr = pa.array(["2020-01-01T01:01:01", None, "2020-01-01"]) + result = plc.strings.convert.convert_datetime.is_timestamp( + plc.interop.from_arrow(arr), + fmt, + ) + expected = pa.array([True, None, False]) + assert_column_eq(result, expected) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_wrap.py b/python/pylibcudf/pylibcudf/tests/test_string_wrap.py index 85abd3a2bae..a1c820cd586 100644 --- a/python/pylibcudf/pylibcudf/tests/test_string_wrap.py +++ b/python/pylibcudf/pylibcudf/tests/test_string_wrap.py @@ -7,6 +7,7 @@ def test_wrap(): + width = 12 pa_array = pa.array( [ "the quick brown fox jumped over the lazy brown dog", @@ -14,10 +15,10 @@ def test_wrap(): None, ] ) - result = plc.strings.wrap.wrap(plc.interop.from_arrow(pa_array), 12) + result = plc.strings.wrap.wrap(plc.interop.from_arrow(pa_array), width) expected = pa.array( [ - textwrap.fill(val, 12) if isinstance(val, str) else val + textwrap.fill(val, width) if isinstance(val, str) else val for val in pa_array.to_pylist() ] )