diff --git a/cpp/doxygen/regex.md b/cpp/doxygen/regex.md index 6d1c91a5752..6902b1948bd 100644 --- a/cpp/doxygen/regex.md +++ b/cpp/doxygen/regex.md @@ -8,6 +8,7 @@ This page specifies which regular expression (regex) features are currently supp - cudf::strings::extract() - cudf::strings::extract_all_record() - cudf::strings::findall() +- cudf::strings::find_re() - cudf::strings::replace_re() - cudf::strings::replace_with_backrefs() - cudf::strings::split_re() diff --git a/cpp/include/cudf/strings/findall.hpp b/cpp/include/cudf/strings/findall.hpp index c6b9bc7e58a..867764b6d9a 100644 --- a/cpp/include/cudf/strings/findall.hpp +++ b/cpp/include/cudf/strings/findall.hpp @@ -66,6 +66,35 @@ std::unique_ptr findall( rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); +/** + * @brief Returns the starting character index of the first match for the given pattern + * in each row of the input column + * + * @code{.pseudo} + * Example: + * s = ["bunny", "rabbit", "hare", "dog"] + * p = regex_program::create("[be]") + * r = find_re(s, p) + * r is now [0, 2, 3, -1] + * @endcode + * + * A null output row occurs if the corresponding input row is null. + * A -1 is returned for rows that do not contain a match. + * + * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. + * + * @param input Strings instance for this operation + * @param prog Regex program instance + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return New column of integers + */ +std::unique_ptr find_re( + strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); + /** @} */ // end of doxygen group } // namespace strings } // namespace CUDF_EXPORT cudf diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index d8c1b50a94b..21708e48a25 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -126,6 +126,43 @@ std::unique_ptr findall(strings_column_view const& input, mr); } +namespace { +struct find_re_fn { + column_device_view d_strings; + + __device__ size_type operator()(size_type const idx, + reprog_device const prog, + int32_t const thread_idx) const + { + if (d_strings.is_null(idx)) { return 0; } + auto const d_str = d_strings.element(idx); + + auto const result = prog.find(thread_idx, d_str, d_str.begin()); + return result.has_value() ? result.value().first : -1; + } +}; +} // namespace + +std::unique_ptr find_re(strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto results = make_numeric_column(data_type{type_to_id()}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + if (input.is_empty()) { return results; } + + auto d_results = results->mutable_view().data(); + auto d_prog = regex_device_builder::create_prog_device(prog, stream); + auto const d_strings = column_device_view::create(input.parent(), stream); + launch_transform_kernel(find_re_fn{*d_strings}, *d_prog, d_results, input.size(), stream); + + return results; +} } // namespace detail // external API @@ -139,5 +176,14 @@ std::unique_ptr findall(strings_column_view const& input, return detail::findall(input, prog, stream, mr); } +std::unique_ptr find_re(strings_column_view const& input, + regex_program const& prog, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + return detail::find_re(input, prog, stream, mr); +} + } // namespace strings } // namespace cudf diff --git a/cpp/tests/streams/strings/find_test.cpp b/cpp/tests/streams/strings/find_test.cpp index 52839c6fc9f..e5a1ee0988c 100644 --- a/cpp/tests/streams/strings/find_test.cpp +++ b/cpp/tests/streams/strings/find_test.cpp @@ -46,4 +46,5 @@ TEST_F(StringsFindTest, Find) auto const pattern = std::string("[a-z]"); auto const prog = cudf::strings::regex_program::create(pattern); cudf::strings::findall(view, *prog, cudf::test::get_default_stream()); + cudf::strings::find_re(view, *prog, cudf::test::get_default_stream()); } diff --git a/cpp/tests/strings/findall_tests.cpp b/cpp/tests/strings/findall_tests.cpp index 73da4d081e2..4821a7fa999 100644 --- a/cpp/tests/strings/findall_tests.cpp +++ b/cpp/tests/strings/findall_tests.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -149,6 +150,22 @@ TEST_F(StringsFindallTests, LargeRegex) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); } +TEST_F(StringsFindallTests, FindTest) +{ + auto const valids = cudf::test::iterators::null_at(5); + cudf::test::strings_column_wrapper input( + {"3A", "May4", "Jan2021", "March", "A9BC", "", "", "abcdef ghijklm 12345"}, valids); + auto sv = cudf::strings_column_view(input); + + auto pattern = std::string("\\d+"); + + auto prog = cudf::strings::regex_program::create(pattern); + auto results = cudf::strings::find_re(sv, *prog); + auto expected = + cudf::test::fixed_width_column_wrapper({0, 3, 3, -1, 1, 0, -1, 15}, valids); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); +} + TEST_F(StringsFindallTests, NoMatches) { cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"}); @@ -169,10 +186,16 @@ TEST_F(StringsFindallTests, EmptyTest) auto prog = cudf::strings::regex_program::create(pattern); cudf::test::strings_column_wrapper input; - auto sv = cudf::strings_column_view(input); - auto results = cudf::strings::findall(sv, *prog); - - using LCW = cudf::test::lists_column_wrapper; - LCW expected; - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + auto sv = cudf::strings_column_view(input); + { + auto results = cudf::strings::findall(sv, *prog); + using LCW = cudf::test::lists_column_wrapper; + LCW expected; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + } + { + auto results = cudf::strings::find_re(sv, *prog); + auto expected = cudf::test::fixed_width_column_wrapper{}; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + } } diff --git a/python/cudf/cudf/_lib/strings/__init__.py b/python/cudf/cudf/_lib/strings/__init__.py index 4bf8a9b1a8f..fc712a6b577 100644 --- a/python/cudf/cudf/_lib/strings/__init__.py +++ b/python/cudf/cudf/_lib/strings/__init__.py @@ -71,7 +71,7 @@ startswith_multiple, ) from cudf._lib.strings.find_multiple import find_multiple -from cudf._lib.strings.findall import findall +from cudf._lib.strings.findall import find_re, findall from cudf._lib.strings.json import GetJsonObjectOptions, get_json_object from cudf._lib.strings.padding import ( SideType, diff --git a/python/cudf/cudf/_lib/strings/findall.pyx b/python/cudf/cudf/_lib/strings/findall.pyx index 0e758d5b322..3e7a504d535 100644 --- a/python/cudf/cudf/_lib/strings/findall.pyx +++ b/python/cudf/cudf/_lib/strings/findall.pyx @@ -23,3 +23,19 @@ def findall(Column source_strings, object pattern, uint32_t flags): prog, ) return Column.from_pylibcudf(plc_result) + + +@acquire_spill_lock() +def find_re(Column source_strings, object pattern, uint32_t flags): + """ + Returns character positions where the pattern first matches + the elements in source_strings. + """ + prog = plc.strings.regex_program.RegexProgram.create( + str(pattern), flags + ) + plc_result = plc.strings.findall.find_re( + source_strings.to_pylibcudf(mode="read"), + prog, + ) + return Column.from_pylibcudf(plc_result) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 4463e3280df..69e42e58cd0 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -3624,6 +3624,46 @@ def findall(self, pat: str, flags: int = 0) -> SeriesOrIndex: data = libstrings.findall(self._column, pat, flags) return self._return_or_inplace(data) + def find_re(self, pat: str, flags: int = 0) -> SeriesOrIndex: + """ + Find first occurrence of pattern or regular expression in the + Series/Index. + + Parameters + ---------- + pat : str + Pattern or regular expression. + flags : int, default 0 (no flags) + Flags to pass through to the regex engine (e.g. re.MULTILINE) + + Returns + ------- + Series + A Series of position values where the pattern first matches + each string. + + Examples + -------- + >>> import cudf + >>> s = cudf.Series(['Lion', 'Monkey', 'Rabbit', 'Cat']) + >>> s.str.find_re('[ti]') + 0 1 + 1 -1 + 2 4 + 3 2 + dtype: int32 + """ + if isinstance(pat, re.Pattern): + flags = pat.flags & ~re.U + pat = pat.pattern + if not _is_supported_regex_flags(flags): + raise NotImplementedError( + "Unsupported value for `flags` parameter" + ) + + data = libstrings.find_re(self._column, pat, flags) + return self._return_or_inplace(data) + def find_multiple(self, patterns: SeriesOrIndex) -> cudf.Series: """ Find all first occurrences of patterns in the Series/Index. diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index cc88cc79769..45143211a11 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -1899,6 +1899,26 @@ def test_string_findall(pat, flags): assert_eq(expected, actual) +@pytest.mark.parametrize( + "pat, flags, pos", + [ + ("Monkey", 0, [-1, 0, -1, -1]), + ("on", 0, [2, 1, -1, 1]), + ("bit", 0, [-1, -1, 3, -1]), + ("on$", 0, [2, -1, -1, -1]), + ("on$", re.MULTILINE, [2, -1, -1, 1]), + ("o.*k", re.DOTALL, [-1, 1, -1, 1]), + ], +) +def test_string_find_re(pat, flags, pos): + test_data = ["Lion", "Monkey", "Rabbit", "Don\nkey"] + gs = cudf.Series(test_data) + + expected = pd.Series(pos, dtype=np.int32) + actual = gs.str.find_re(pat, flags) + assert_eq(expected, actual) + + def test_string_replace_multi(): ps = pd.Series(["hello", "goodbye"]) gs = cudf.Series(["hello", "goodbye"]) diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd index e0a8b776465..0d286c36446 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/findall.pxd @@ -11,3 +11,7 @@ cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] findall( column_view input, regex_program prog) except + + + cdef unique_ptr[column] find_re( + column_view input, + regex_program prog) except + diff --git a/python/pylibcudf/pylibcudf/strings/findall.pxd b/python/pylibcudf/pylibcudf/strings/findall.pxd index 54afa088141..3c35a9c9aa9 100644 --- a/python/pylibcudf/pylibcudf/strings/findall.pxd +++ b/python/pylibcudf/pylibcudf/strings/findall.pxd @@ -4,4 +4,5 @@ from pylibcudf.column cimport Column from pylibcudf.strings.regex_program cimport RegexProgram +cpdef Column find_re(Column input, RegexProgram pattern) cpdef Column findall(Column input, RegexProgram pattern) diff --git a/python/pylibcudf/pylibcudf/strings/findall.pyx b/python/pylibcudf/pylibcudf/strings/findall.pyx index 3a6b87504b3..5212dc4594d 100644 --- a/python/pylibcudf/pylibcudf/strings/findall.pyx +++ b/python/pylibcudf/pylibcudf/strings/findall.pyx @@ -38,3 +38,35 @@ cpdef Column findall(Column input, RegexProgram pattern): ) return Column.from_libcudf(move(c_result)) + + +cpdef Column find_re(Column input, RegexProgram pattern): + """ + Returns character positions where the pattern first matches + the elements in input strings. + + For details, see :cpp:func:`cudf::strings::find_re` + + Parameters + ---------- + input : Column + Strings instance for this operation + pattern : RegexProgram + Regex pattern + + Returns + ------- + Column + New column of integers + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_findall.find_re( + input.view(), + pattern.c_obj.get()[0] + ) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_findall.py b/python/pylibcudf/pylibcudf/tests/test_string_findall.py index 994552fa276..debfad92d00 100644 --- a/python/pylibcudf/pylibcudf/tests/test_string_findall.py +++ b/python/pylibcudf/pylibcudf/tests/test_string_findall.py @@ -21,3 +21,20 @@ def test_findall(): type=pa_result.type, ) assert_column_eq(result, expected) + + +def test_find_re(): + arr = pa.array(["bunny", "rabbit", "hare", "dog"]) + pattern = "[eb]" + result = plc.strings.findall.find_re( + plc.interop.from_arrow(arr), + plc.strings.regex_program.RegexProgram.create( + pattern, plc.strings.regex_flags.RegexFlags.DEFAULT + ), + ) + pa_result = plc.interop.to_arrow(result) + expected = pa.array( + [0, 2, 3, -1], + type=pa_result.type, + ) + assert_column_eq(result, expected)