Skip to content

Commit

Permalink
Performance improvement for cudf::strings::like (rapidsai#13594)
Browse files Browse the repository at this point in the history
Minimizes character counting in the kernel logic for `cudf::strings::like` to improve overall performance especially for longer strings.
The nvbench benchmark is updated to include measurements for various strings sizes.

Reference rapidsai#13048

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Karthikeyan (https://github.com/karthikeyann)

URL: rapidsai#13594
  • Loading branch information
davidwendt authored Jun 23, 2023
1 parent 4f8afef commit 6aad528
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 29 deletions.
59 changes: 39 additions & 20 deletions cpp/benchmarks/string/like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,41 @@

#include <cudf/copying.hpp>
#include <cudf/filling.hpp>
#include <cudf/strings/combine.hpp>
#include <cudf/strings/contains.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <nvbench/nvbench.cuh>

namespace {
std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows, int32_t hit_rate)
std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows,
cudf::size_type row_width,
int32_t hit_rate)
{
// build input table using the following data
auto data = cudf::test::strings_column_wrapper({
"123 abc 4567890 DEFGHI 0987 5W43", // matches always;
"012345 6789 01234 56789 0123 456", // the rest do not match
"abc 4567890 DEFGHI 0987 Wxyz 123",
"abcdefghijklmnopqrstuvwxyz 01234",
"",
"AbcéDEFGHIJKLMNOPQRSTUVWXYZ 01",
"9876543210,abcdefghijklmnopqrstU",
"9876543210,abcdefghijklmnopqrstU",
"123 édf 4567890 DéFG 0987 X5",
"1",
});
auto data_view = cudf::column_view(data);
auto raw_data = cudf::test::strings_column_wrapper(
{
"123 abc 4567890 DEFGHI 0987 5W43", // matches always;
"012345 6789 01234 56789 0123 456", // the rest do not match
"abc 4567890 DEFGHI 0987 Wxyz 123",
"abcdefghijklmnopqrstuvwxyz 01234",
"",
"AbcéDEFGHIJKLMNOPQRSTUVWXYZ 01",
"9876543210,abcdefghijklmnopqrstU",
"9876543210,abcdefghijklmnopqrstU",
"123 édf 4567890 DéFG 0987 X5",
"1",
})
.release();
if (row_width / 32 > 1) {
std::vector<cudf::column_view> columns;
for (int i = 0; i < row_width / 32; ++i) {
columns.push_back(raw_data->view());
}
raw_data = cudf::strings::concatenate(cudf::table_view(columns));
}
auto data_view = raw_data->view();

// compute number of rows in n_rows that should match
auto matches = static_cast<int32_t>(n_rows * hit_rate) / 100;
Expand Down Expand Up @@ -71,14 +83,20 @@ std::unique_ptr<cudf::column> build_input_column(cudf::size_type n_rows, int32_t

static void bench_like(nvbench::state& state)
{
auto const n_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const hit_rate = static_cast<int32_t>(state.get_int64("hit_rate"));
auto const n_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const hit_rate = static_cast<int32_t>(state.get_int64("hit_rate"));

auto col = build_input_column(n_rows, hit_rate);
if (static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(row_width) >=
static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
state.skip("Skip benchmarks greater than size_type limit");
}

auto col = build_input_column(n_rows, row_width, hit_rate);
auto input = cudf::strings_column_view(col->view());

// This pattern forces reading the entire target string (when matched expected)
auto pattern = std::string("% 5W4_"); // regex equivalent: ".* 5W4."
auto pattern = std::string("% 5W4_"); // regex equivalent: ".* 5W4.$"

state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value()));
// gather some throughput statistics as well
Expand All @@ -93,5 +111,6 @@ static void bench_like(nvbench::state& state)

NVBENCH_BENCH(bench_like)
.set_name("strings_like")
.add_int64_axis("num_rows", {4096, 32768, 262144, 2097152, 16777216})
.add_int64_axis("hit_rate", {1, 5, 10, 25, 70, 100});
.add_int64_axis("row_width", {32, 64, 128, 256, 512})
.add_int64_axis("num_rows", {32768, 262144, 2097152, 16777216})
.add_int64_axis("hit_rate", {10, 25, 70, 100});
27 changes: 18 additions & 9 deletions cpp/src/strings/like.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ struct like_fn {
auto const d_str = d_strings.element<string_view>(idx);
auto const d_pattern = patterns_itr[idx];

// using only iterators to better handle UTF-8 characters
auto target_itr = d_str.begin();
// incrementing by bytes instead of character improves performance 10-20%
auto target_itr = d_str.data();
auto pattern_itr = d_pattern.begin();

auto const target_end = d_str.end();
auto const target_end = target_itr + d_str.size_bytes();
auto const pattern_end = d_pattern.end();
auto const esc_char = d_escape.empty() ? 0 : d_escape[0];

Expand All @@ -75,12 +75,20 @@ struct like_fn {
escaped && (pattern_itr + 1 < pattern_end) ? *(++pattern_itr) : *pattern_itr;

if (escaped || (pattern_char != multi_wildcard)) {
size_type char_width = 0;
// check match with the current character
result = ((target_itr != target_end) && ((!escaped && pattern_char == single_wildcard) ||
(pattern_char == *target_itr)));
result = (target_itr != target_end);
if (result) {
if (escaped || pattern_char != single_wildcard) {
char_utf8 target_char = 0;
// retrieve the target character to compare with the current pattern_char
char_width = to_char_utf8(target_itr, target_char);
result = (pattern_char == target_char);
}
}
if (!result) { break; }
++target_itr;
++pattern_itr;
target_itr += char_width ? char_width : bytes_in_utf8_byte(*target_itr);
} else {
// process wildcard '%'
result = true;
Expand All @@ -92,8 +100,8 @@ struct like_fn {
// save positions
last_pattern_itr = pattern_itr;
last_target_itr = target_itr;
}
} // next pattern character
} // next pattern character
}

if (result && (target_itr == target_end)) { break; } // success

Expand All @@ -103,7 +111,8 @@ struct like_fn {

// restore saved positions
pattern_itr = last_pattern_itr;
target_itr = ++last_target_itr;
last_target_itr += bytes_in_utf8_byte(*last_target_itr);
target_itr = last_target_itr;
}
return result;
}
Expand Down

0 comments on commit 6aad528

Please sign in to comment.