From fa12901024fcc810fcf7f695d2f2e41f472f2306 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Thu, 26 Sep 2024 19:07:48 -0400 Subject: [PATCH] Fix cudf::strings::findall error with empty input (#16928) Fixes `cudf::strings::findall` error when passed an empty input column. Also adds gtests for empty input and for the case where no rows match. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Yunsong Wang (https://github.com/PointKernel) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/16928 --- cpp/src/strings/search/findall.cu | 10 +++++++--- cpp/tests/strings/findall_tests.cpp | 28 ++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 067a513af96..d8c1b50a94b 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -97,8 +98,11 @@ std::unique_ptr findall(strings_column_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto const strings_count = input.size(); - auto const d_strings = column_device_view::create(input.parent(), stream); + if (input.is_empty()) { + return cudf::lists::detail::make_empty_lists_column(input.parent().type(), stream, mr); + } + + auto const d_strings = column_device_view::create(input.parent(), stream); // create device object from regex_program auto d_prog = regex_device_builder::create_prog_device(prog, stream); @@ -113,7 +117,7 @@ std::unique_ptr findall(strings_column_view const& input, auto strings_output = findall_util(*d_strings, *d_prog, total_matches, d_offsets, stream, mr); // Build the lists column from the offsets and the strings - return make_lists_column(strings_count, + return make_lists_column(input.size(), std::move(offsets), std::move(strings_output), 
input.null_count(), diff --git a/cpp/tests/strings/findall_tests.cpp b/cpp/tests/strings/findall_tests.cpp index 47606b9b3ed..6eea1895fb1 100644 --- a/cpp/tests/strings/findall_tests.cpp +++ b/cpp/tests/strings/findall_tests.cpp @@ -148,3 +148,31 @@ TEST_F(StringsFindallTests, LargeRegex) LCW expected({LCW{large_regex.c_str()}, LCW{}, LCW{}}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); } + +TEST_F(StringsFindallTests, NoMatches) +{ + cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"}); + auto sv = cudf::strings_column_view(input); + + auto pattern = std::string("(^zzz$)"); + using LCW = cudf::test::lists_column_wrapper; + LCW expected({LCW{}, LCW{}, LCW{}, LCW{}, LCW{}}); + auto prog = cudf::strings::regex_program::create(pattern); + auto results = cudf::strings::findall(sv, *prog); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); +} + +TEST_F(StringsFindallTests, EmptyTest) +{ + std::string pattern = R"(\w+)"; + + auto prog = cudf::strings::regex_program::create(pattern); + + cudf::test::strings_column_wrapper input; + auto sv = cudf::strings_column_view(input); + auto results = cudf::strings::findall(sv, *prog); + + using LCW = cudf::test::lists_column_wrapper; + LCW expected; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); +}