diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 51c6e765edd..775a2580f60 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -716,13 +716,13 @@ class regex_parser { if (item.type != COUNTED && item.type != COUNTED_LAZY) { out.push_back(item); if (item.type == LBRA || item.type == LBRA_NC) { - lbra_stack.push(index); + lbra_stack.push(out.size() - 1); repeat_start_index = -1; } else if (item.type == RBRA) { repeat_start_index = lbra_stack.top(); lbra_stack.pop(); } else if ((item.type & ITEM_MASK) != OPERATOR_MASK) { - repeat_start_index = index; + repeat_start_index = out.size() - 1; } } else { // item is of type COUNTED or COUNTED_LAZY @@ -731,26 +731,39 @@ class regex_parser { CUDF_EXPECTS(repeat_start_index >= 0, "regex: invalid counted quantifier location"); // range of affected item(s) to repeat - auto const begin = in.begin() + repeat_start_index; - auto const end = in.begin() + index; + auto const begin = out.begin() + repeat_start_index; + auto const end = out.end(); + // count range values auto const n = item.d.count.n; // minimum count auto const m = item.d.count.m; // maximum count - assert(n >= 0 && "invalid repeat count value n"); // zero-repeat edge-case: need to erase the previous items - if (n == 0) { out.erase(out.end() - (index - repeat_start_index), out.end()); } - - // minimum repeats (n) - for (int j = 1; j < n; j++) { - out.insert(out.end(), begin, end); + if (n == 0 && m == 0) { out.erase(begin, end); } + + std::vector repeat_copy(begin, end); + // special handling for quantified capture groups + if ((n > 1) && (*begin).type == LBRA) { + (*begin).type = LBRA_NC; // change first one to non-capture + // add intermediate groups as non-capture + std::vector ncg_copy(begin, end); + for (int j = 1; j < (n - 1); j++) { + out.insert(out.end(), ncg_copy.begin(), ncg_copy.end()); + } + // add the last entry as a regular capture-group + out.insert(out.end(), repeat_copy.begin(), repeat_copy.end()); + } else { + // minimum repeats (n) + for (int j = 1; j < n; j++) { + out.insert(out.end(), repeat_copy.begin(), repeat_copy.end()); + } } // optional maximum repeats (m) if (m >= 0) { for (int j = n; j < m; j++) { out.emplace_back(LBRA_NC, 0); - out.insert(out.end(), begin, end); + out.insert(out.end(), repeat_copy.begin(), repeat_copy.end()); } for (int j = n; j < m; j++) { out.emplace_back(RBRA, 0); @@ -760,8 +773,9 @@ class regex_parser { // infinite repeats if (n > 0) { // append '+' after last repetition out.emplace_back(item.type == COUNTED ? PLUS : PLUS_LAZY, 0); - } else { // copy it once then append '*' - out.insert(out.end(), begin, end); + } else { + // copy it once then append '*' + out.insert(out.end(), repeat_copy.begin(), repeat_copy.end()); out.emplace_back(item.type == COUNTED ? STAR : STAR_LAZY, 0); } } diff --git a/cpp/tests/strings/contains_tests.cpp b/cpp/tests/strings/contains_tests.cpp index bdfd38267e6..216ddfce5f1 100644 --- a/cpp/tests/strings/contains_tests.cpp +++ b/cpp/tests/strings/contains_tests.cpp @@ -474,6 +474,20 @@ TEST_F(StringsContainsTests, FixedQuantifier) } } +TEST_F(StringsContainsTests, NestedQuantifier) +{ + auto input = cudf::test::strings_column_wrapper({"TEST12 1111 2222 3333 4444 5555", + "0000 AAAA 9999 BBBB 8888", + "7777 6666 4444 3333", + "12345 3333 4444 1111 ABCD"}); + auto sv = cudf::strings_column_view(input); + auto pattern = std::string(R"((\d{4}\s){4})"); + cudf::test::fixed_width_column_wrapper expected({true, false, false, true}); + auto prog = cudf::strings::regex_program::create(pattern); + auto results = cudf::strings::contains_re(sv, *prog); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); +} + TEST_F(StringsContainsTests, QuantifierErrors) { EXPECT_THROW(cudf::strings::regex_program::create("^+"), cudf::logic_error); diff --git a/cpp/tests/strings/extract_tests.cpp b/cpp/tests/strings/extract_tests.cpp index 61246fb098d..7e0338f1bf4 100644 --- a/cpp/tests/strings/extract_tests.cpp +++ b/cpp/tests/strings/extract_tests.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include @@ -240,6 +239,21 @@ TEST_F(StringsExtractTests, SpecialNewLines) CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view().column(0), expected); } +TEST_F(StringsExtractTests, NestedQuantifier) +{ + auto input = cudf::test::strings_column_wrapper({"TEST12 1111 2222 3333 4444 5555", + "0000 AAAA 9999 BBBB 8888", + "7777 6666 4444 3333", + "12345 3333 4444 1111 ABCD"}); + auto sv = cudf::strings_column_view(input); + auto pattern = std::string(R"((\d{4}\s){4})"); + auto prog = cudf::strings::regex_program::create(pattern); + auto results = cudf::strings::extract(sv, *prog); + // fixed quantifier on capture group only honors the last group + auto expected = cudf::test::strings_column_wrapper({"4444 ", "", "", "1111 "}, {1, 0, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view().column(0), expected); +} + TEST_F(StringsExtractTests, EmptyExtractTest) { std::vector h_strings{nullptr, "AAA", "AAA_A", "AAA_AAA_", "A__", ""};