Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve strings contains/find performance for smaller strings #17330

Merged
merged 26 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3e15aac
Use a KMP search algorithm for strings contains()
davidwendt Nov 14, 2024
453fdb3
Merge branch 'branch-24.12' into kmp-contains
davidwendt Nov 15, 2024
b7d88f7
add stream to kernel launch
davidwendt Nov 15, 2024
cf8f782
replace kmp with compare()
davidwendt Nov 15, 2024
1d3c292
Merge branch 'branch-24.12' into kmp-contains
davidwendt Nov 15, 2024
4b83025
remove commented out code line
davidwendt Nov 15, 2024
0d12680
Merge branch 'branch-24.12' into kmp-contains
davidwendt Nov 18, 2024
1ae61d2
Merge branch 'branch-24.12' into kmp-contains
davidwendt Nov 18, 2024
884633d
Merge branch 'branch-24.12' into kmp-contains
davidwendt Nov 20, 2024
e50e37a
Merge branch 'branch-24.12' into kmp-contains
davidwendt Nov 20, 2024
c62f99f
Merge branch 'branch-24.12' into kmp-contains
davidwendt Nov 20, 2024
0b27a4d
Merge branch 'branch-25.02' into kmp-contains
davidwendt Nov 21, 2024
076fc06
Merge branch 'kmp-contains' of github.com:davidwendt/cudf into kmp-co…
davidwendt Nov 22, 2024
ca5c5e2
Merge branch 'branch-25.02' into kmp-contains
davidwendt Nov 22, 2024
d77b2a4
move fast-path to string_view::find()
davidwendt Nov 22, 2024
83285ec
Merge branch 'branch-25.02' into kmp-contains
davidwendt Nov 22, 2024
71811e6
refactor to improve rfind as well
davidwendt Nov 23, 2024
03aae8c
Merge branch 'branch-25.02' into kmp-contains
davidwendt Nov 23, 2024
6dbb2f4
Merge branch 'branch-25.02' into kmp-contains
davidwendt Nov 25, 2024
287f1f5
Merge branch 'branch-25.02' into kmp-contains
davidwendt Nov 25, 2024
6716389
fix merge conflict
davidwendt Nov 26, 2024
e9afaa4
Merge branch 'kmp-contains' of github.com:davidwendt/cudf into kmp-co…
davidwendt Dec 2, 2024
770463a
Merge branch 'branch-25.02' into kmp-contains
davidwendt Dec 2, 2024
45e428b
Merge branch 'branch-25.02' into kmp-contains
davidwendt Dec 5, 2024
2a55490
Merge branch 'branch-25.02' into kmp-contains
davidwendt Dec 5, 2024
1cdaf31
refactor detail::contains lambdas
davidwendt Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 39 additions & 20 deletions cpp/benchmarks/string/find.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,19 @@

static void bench_find_string(nvbench::state& state)
{
auto const n_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const num_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const max_width = static_cast<cudf::size_type>(state.get_int64("max_width"));
auto const hit_rate = static_cast<cudf::size_type>(state.get_int64("hit_rate"));
auto const api = state.get_string("api");

if (static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(row_width) >=
static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
state.skip("Skip benchmarks greater than size_type limit");
}
auto const tgt_type = state.get_string("target");

auto const stream = cudf::get_default_stream();
auto const col = create_string_column(n_rows, row_width, hit_rate);
auto const col = create_string_column(num_rows, max_width, hit_rate);
auto const input = cudf::strings_column_view(col->view());

cudf::string_scalar target("0987 5W43");
auto target = cudf::string_scalar("0987 5W43");
auto targets_col = cudf::make_column_from_scalar(target, num_rows);
auto const targets = cudf::strings_column_view(targets_col->view());

state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value()));
auto const chars_size = input.chars_size(stream);
Expand All @@ -55,23 +53,44 @@ static void bench_find_string(nvbench::state& state)
}

if (api == "find") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::find(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::find(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::find(input, targets); });
}
} else if (api == "contains") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::contains(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::contains(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::contains(input, targets); });
}
} else if (api == "starts_with") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::starts_with(input, targets); });
}
} else if (api == "ends_with") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::ends_with(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::ends_with(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::ends_with(input, targets); });
}
}
}

NVBENCH_BENCH(bench_find_string)
.set_name("find_string")
.add_int64_axis("max_width", {32, 64, 128, 256})
.add_int64_axis("num_rows", {32768, 262144, 2097152})
.add_int64_axis("hit_rate", {20, 80}) // percentage
.add_string_axis("api", {"find", "contains", "starts_with", "ends_with"})
.add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024})
.add_int64_axis("num_rows", {260'000, 1'953'000, 16'777'216})
.add_int64_axis("hit_rate", {20, 80}); // percentage
.add_string_axis("target", {"scalar", "column"});
17 changes: 8 additions & 9 deletions cpp/include/cudf/strings/string_view.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -373,24 +373,23 @@ __device__ inline size_type string_view::find_impl(char const* str,
size_type pos,
size_type count) const
{
auto const nchars = length();
if (!str || pos < 0 || pos > nchars) return npos;
if (count < 0) count = nchars;
if (!str || pos < 0) { return npos; }
lamarrr marked this conversation as resolved.
Show resolved Hide resolved
if (pos > 0 && pos > length()) { return npos; }
lamarrr marked this conversation as resolved.
Show resolved Hide resolved

// use iterator to help reduce character/byte counting
auto itr = begin() + pos;
auto const itr = begin() + pos;
auto const spos = itr.byte_offset();
auto const epos = ((pos + count) < nchars) ? (itr + count).byte_offset() : size_bytes();
auto const epos =
(count >= 0) && ((pos + count) < length()) ? (itr + count).byte_offset() : size_bytes();

auto const find_length = (epos - spos) - bytes + 1;
auto const d_target = string_view{str, bytes};

auto ptr = data() + (forward ? spos : (epos - bytes));
for (size_type idx = 0; idx < find_length; ++idx) {
bool match = true;
for (size_type jdx = 0; match && (jdx < bytes); ++jdx) {
match = (ptr[jdx] == str[jdx]);
if (d_target.compare(ptr, bytes) == 0) {
lamarrr marked this conversation as resolved.
Show resolved Hide resolved
return forward ? pos : character_offset(epos - bytes - idx);
}
if (match) { return forward ? pos : character_offset(epos - bytes - idx); }
// use pos to record the current find position
pos += strings::detail::is_begin_utf8_char(*ptr);
forward ? ++ptr : --ptr;
Expand Down
24 changes: 14 additions & 10 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,11 @@ struct finder_fn {
if (d_strings.is_null(idx)) { return -1; }
auto const d_str = d_strings.element<string_view>(idx);
if (d_str.empty() && (start > 0)) { return -1; }
if (stop >= 0 && start > stop) { return -1; }
auto const d_target = d_targets[idx];

auto const length = d_str.length();
auto const begin = (start > length) ? length : start;
auto const end = (stop < 0) || (stop > length) ? length : stop;
return forward ? d_str.find(d_target, begin, end - begin)
: d_str.rfind(d_target, begin, end - begin);
auto const count = (stop < 0) ? stop : (stop - start);
return forward ? d_str.find(d_target, start, count) : d_str.rfind(d_target, start, count);
}
};

Expand Down Expand Up @@ -367,7 +365,7 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings,
i += cudf::detail::warp_size * bytes_per_warp) {
// check the target matches this part of the d_str data
// this is definitely faster for very long strings > 128B
for (auto j = 0; j < bytes_per_warp; j++) {
for (auto j = 0; !found && (j < bytes_per_warp); j++) {
if (((i + j + d_target.size_bytes()) <= d_str.size_bytes()) &&
d_target.compare(d_str.data() + i + j, d_target.size_bytes()) == 0) {
found = true;
Expand Down Expand Up @@ -531,7 +529,6 @@ std::unique_ptr<column> contains_fn(strings_column_view const& strings,
results->set_null_count(strings.null_count());
return results;
}

} // namespace

std::unique_ptr<column> contains(strings_column_view const& input,
Expand All @@ -541,13 +538,17 @@ std::unique_ptr<column> contains(strings_column_view const& input,
{
// use warp parallel when the average string width is greater than the threshold
if ((input.null_count() < input.size()) &&
((input.chars_size(stream) / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) {
((input.chars_size(stream) / (input.size() - input.null_count())) >
AVG_CHAR_BYTES_THRESHOLD)) {
return contains_warp_parallel(input, target, stream, mr);
}

// benchmark measurements showed this to be faster for smaller strings
auto pfn = [] __device__(string_view d_string, string_view d_target) {
return d_string.find(d_target) != string_view::npos;
for (size_type i = 0; i <= (d_string.size_bytes() - d_target.size_bytes()); ++i) {
if (d_target.compare(d_string.data() + i, d_target.size_bytes()) == 0) { return true; }
}
return false;
};
return contains_fn(input, target, pfn, stream, mr);
}
Expand All @@ -558,7 +559,10 @@ std::unique_ptr<column> contains(strings_column_view const& strings,
rmm::device_async_resource_ref mr)
{
auto pfn = [] __device__(string_view d_string, string_view d_target) {
return d_string.find(d_target) != string_view::npos;
for (size_type i = 0; i <= (d_string.size_bytes() - d_target.size_bytes()); ++i) {
if (d_target.compare(d_string.data() + i, d_target.size_bytes()) == 0) { return true; }
}
return false;
};
return contains_fn(strings, targets, pfn, stream, mr);
}
Expand Down
Loading