Skip to content

Commit

Permalink
Improve performance for cudf::strings::count_re (#15578)
Browse files Browse the repository at this point in the history
Improves performance of `cudf::strings::count_re` when pattern starts with a literal character.
Although this is a specific use case, the regex code has special logic to help speed up the search in this case.

Since the pattern indicates the target must contain this character at the start of the matching sequence, the regex code first does a normal find for the character before continuing to match the remaining pattern. The `find()` function can be inefficient for long strings since it is character-based and must resolve the character's byte position by counting from the beginning of the string. For a function like `count_re()`, all occurrences are matched within a target, meaning longer target strings can incur expensive counting.

The solution included here is to introduce a more efficient `find_char()` utility that accepts a `string_view::const_iterator`, which automatically keeps track of its byte and character positions. This helps minimize byte/character counting between calls from `count_re()` and other similar functions that make repeated calls for all matches (e.g. `replace_re()` and `split_re()`).

Close #15567

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Nghia Truong (https://github.com/ttnghia)

URL: #15578
  • Loading branch information
davidwendt authored Apr 25, 2024
1 parent 70a5b2b commit 4dc9ebb
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
4 changes: 2 additions & 2 deletions cpp/benchmarks/string/contains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows,
}

// longer pattern lengths demand more working memory per string
std::string patterns[] = {"^\\d+ [a-z]+", "[A-Z ]+\\d+ +\\d+[A-Z]+\\d+$"};
std::string patterns[] = {"^\\d+ [a-z]+", "[A-Z ]+\\d+ +\\d+[A-Z]+\\d+$", "5W43"};

static void bench_contains(nvbench::state& state)
{
Expand Down Expand Up @@ -114,4 +114,4 @@ NVBENCH_BENCH(bench_contains)
.add_int64_axis("row_width", {32, 64, 128, 256, 512})
.add_int64_axis("num_rows", {32768, 262144, 2097152, 16777216})
.add_int64_axis("hit_rate", {50, 100}) // percentage
.add_int64_axis("pattern", {0, 1});
.add_int64_axis("pattern", {0, 1, 2});
12 changes: 8 additions & 4 deletions cpp/benchmarks/string/count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@

#include <nvbench/nvbench.cuh>

static std::string patterns[] = {"\\d+", "a"};

static void bench_count(nvbench::state& state)
{
auto const num_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const num_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const pattern_index = static_cast<cudf::size_type>(state.get_int64("pattern"));

if (static_cast<std::size_t>(num_rows) * static_cast<std::size_t>(row_width) >=
static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
Expand All @@ -41,7 +44,7 @@ static void bench_count(nvbench::state& state)
create_random_table({cudf::type_id::STRING}, row_count{num_rows}, table_profile);
cudf::strings_column_view input(table->view().column(0));

std::string pattern = "\\d+";
auto const pattern = patterns[pattern_index];

auto prog = cudf::strings::regex_program::create(pattern);

Expand All @@ -59,4 +62,5 @@ static void bench_count(nvbench::state& state)
NVBENCH_BENCH(bench_count)
.set_name("count")
.add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024, 2048})
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216});
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216})
.add_int64_axis("pattern", {0, 1});
19 changes: 14 additions & 5 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,15 @@ __device__ __forceinline__ reprog_device reprog_device::load(reprog_device const
: reinterpret_cast<reprog_device*>(buffer)[0];
}

/**
 * @brief Advance an iterator to the next occurrence of a character.
 *
 * Steps `itr` forward one character at a time until it either points at
 * `chr` or its byte offset reaches the byte size of `d_str`.
 *
 * @param chr Character to search for
 * @param d_str String whose byte size bounds the search
 * @param itr Iterator at which the search begins
 * @return Iterator positioned at the first match at or after `itr`; if no
 *         match is found, its byte offset is not less than `d_str.size_bytes()`
 */
__device__ __forceinline__ static string_view::const_iterator find_char(
cudf::char_utf8 chr, string_view const d_str, string_view::const_iterator itr)
{
auto const end_offset = d_str.size_bytes();
for (; itr.byte_offset() < end_offset; ++itr) {
// stop on the first matching character; the caller inspects the byte offset
if (*itr == chr) { break; }
}
return itr;
}

/**
* @brief Evaluate a specific string against regex pattern compiled to this instance.
*
Expand Down Expand Up @@ -253,16 +262,16 @@ __device__ __forceinline__ match_result reprog_device::regexec(string_view const
case BOL:
if (pos == 0) break;
if (jnk.startchar != '^') { return thrust::nullopt; }
--pos;
--itr;
startchar = static_cast<char_utf8>('\n');
case CHAR: {
auto const fidx = dstr.find(startchar, pos);
if (fidx == string_view::npos) { return thrust::nullopt; }
pos = fidx + (jnk.starttype == BOL);
auto const find_itr = find_char(startchar, dstr, itr);
if (find_itr.byte_offset() >= dstr.size_bytes()) { return thrust::nullopt; }
itr = find_itr + (jnk.starttype == BOL);
pos = itr.position();
break;
}
}
itr += (pos - itr.position()); // faster to increment position
}

if (((eos < 0) || (pos < eos)) && match == 0) {
Expand Down

0 comments on commit 4dc9ebb

Please sign in to comment.