diff --git a/cpp/benchmarks/string/contains.cpp b/cpp/benchmarks/string/contains.cpp index 6d839c1de64..ae6c8b844c8 100644 --- a/cpp/benchmarks/string/contains.cpp +++ b/cpp/benchmarks/string/contains.cpp @@ -80,7 +80,7 @@ std::unique_ptr build_input_column(cudf::size_type n_rows, } // longer pattern lengths demand more working memory per string -std::string patterns[] = {"^\\d+ [a-z]+", "[A-Z ]+\\d+ +\\d+[A-Z]+\\d+$"}; +std::string patterns[] = {"^\\d+ [a-z]+", "[A-Z ]+\\d+ +\\d+[A-Z]+\\d+$", "5W43"}; static void bench_contains(nvbench::state& state) { @@ -114,4 +114,4 @@ NVBENCH_BENCH(bench_contains) .add_int64_axis("row_width", {32, 64, 128, 256, 512}) .add_int64_axis("num_rows", {32768, 262144, 2097152, 16777216}) .add_int64_axis("hit_rate", {50, 100}) // percentage - .add_int64_axis("pattern", {0, 1}); + .add_int64_axis("pattern", {0, 1, 2}); diff --git a/cpp/benchmarks/string/count.cpp b/cpp/benchmarks/string/count.cpp index a656010dca5..f964bc5d224 100644 --- a/cpp/benchmarks/string/count.cpp +++ b/cpp/benchmarks/string/count.cpp @@ -25,10 +25,13 @@ #include +static std::string patterns[] = {"\\d+", "a"}; + static void bench_count(nvbench::state& state) { - auto const num_rows = static_cast(state.get_int64("num_rows")); - auto const row_width = static_cast(state.get_int64("row_width")); + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const row_width = static_cast(state.get_int64("row_width")); + auto const pattern_index = static_cast(state.get_int64("pattern")); if (static_cast(num_rows) * static_cast(row_width) >= static_cast(std::numeric_limits::max())) { @@ -41,7 +44,7 @@ static void bench_count(nvbench::state& state) create_random_table({cudf::type_id::STRING}, row_count{num_rows}, table_profile); cudf::strings_column_view input(table->view().column(0)); - std::string pattern = "\\d+"; + auto const pattern = patterns[pattern_index]; auto prog = cudf::strings::regex_program::create(pattern); @@ -59,4 +62,5 @@ static void bench_count(nvbench::state& state) NVBENCH_BENCH(bench_count) .set_name("count") .add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024, 2048}) - .add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216}); + .add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216}) + .add_int64_axis("pattern", {0, 1}); diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index ce12dc17aa4..10e06505094 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -217,6 +217,15 @@ __device__ __forceinline__ reprog_device reprog_device::load(reprog_device const : reinterpret_cast(buffer)[0]; } +__device__ __forceinline__ static string_view::const_iterator find_char( + cudf::char_utf8 chr, string_view const d_str, string_view::const_iterator itr) +{ + while (itr.byte_offset() < d_str.size_bytes() && *itr != chr) { + ++itr; + } + return itr; +} + /** * @brief Evaluate a specific string against regex pattern compiled to this instance. * @@ -253,16 +262,16 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const case BOL: if (pos == 0) break; if (jnk.startchar != '^') { return thrust::nullopt; } - --pos; + --itr; startchar = static_cast('\n'); case CHAR: { - auto const fidx = dstr.find(startchar, pos); - if (fidx == string_view::npos) { return thrust::nullopt; } - pos = fidx + (jnk.starttype == BOL); + auto const find_itr = find_char(startchar, dstr, itr); + if (find_itr.byte_offset() >= dstr.size_bytes()) { return thrust::nullopt; } + itr = find_itr + (jnk.starttype == BOL); + pos = itr.position(); break; } } - itr += (pos - itr.position()); // faster to increment position } if (((eos < 0) || (pos < eos)) && match == 0) {