Skip to content

Commit

Permalink
Merge pull request rapidsai#11 from res-life/multi-string-contains-fo…
Browse files Browse the repository at this point in the history
…r-merge-2

Fix comments; warp optimization
  • Loading branch information
wjxiz1992 authored Sep 6, 2024
2 parents 069ea47 + c0b30fe commit adea127
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 97 deletions.
11 changes: 0 additions & 11 deletions cpp/benchmarks/string/find.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ static void bench_find_string(nvbench::state& state)
constexpr bool combine = false; // test true/false
bool has_same_target_first_char = false; // test true/false
constexpr int iters = 10; // test 4/10
bool check_result = false;

std::vector<std::string> match_targets({" abc",
"W43",
Expand Down Expand Up @@ -102,16 +101,6 @@ static void bench_find_string(nvbench::state& state)
cudf::strings::contains(input, cudf::string_scalar(multi_targets[i])));
contains_cvs.emplace_back(contains_results.back()->view());
}

if (check_result) {
cudf::test::strings_column_wrapper multi_targets_column(multi_targets.begin(),
multi_targets.end());
auto tab =
cudf::strings::multi_contains(input, cudf::strings_column_view(multi_targets_column));
for (int i = 0; i < tab->num_columns(); i++) {
cudf::test::detail::expect_columns_equal(contains_cvs[i], tab->get_column(i).view());
}
}
});
} else { // combine
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
Expand Down
3 changes: 3 additions & 0 deletions cpp/include/cudf/strings/find.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,14 @@ std::unique_ptr<column> contains(
*
* Any null string entries return corresponding null entries in the output columns.
* e.g.:
* @code
* input: "a", "b", "c"
* targets: "a", "c"
* output is a table with two boolean columns:
* column_0: true, false, false
* column_1: false, false, true
* @endcode
*
* @param input Strings instance for this operation
* @param targets UTF-8 encoded strings to search for in each string in `input`
* @param stream CUDA stream used for device memory operations and kernel launches
Expand Down
183 changes: 97 additions & 86 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -418,45 +418,94 @@ std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
return results;
}

CUDF_KERNEL void multi_contains_warp_parallel_multi_scalars_fn(column_device_view const d_strings,
column_device_view const d_targets,
cudf::device_span<bool*> d_results)
/**
* Each string uses a warp(32 threads) to handle all the targets.
* Each thread uses num_targets bools shared memory to store temp result for each lane.
*/
CUDF_KERNEL void multi_contains_warp_parallel_multi_scalars_fn(
column_device_view const d_strings,
column_device_view const d_targets,
cudf::device_span<char> const d_target_first_bytes,
column_device_view const d_target_indexes_for_first_bytes,
cudf::device_span<bool*> d_results)
{
auto const num_targets = d_targets.size();
auto const num_rows = d_strings.size();

auto const idx = static_cast<size_type>(threadIdx.x + blockIdx.x * blockDim.x);
using warp_reduce = cub::WarpReduce<bool>;
__shared__ typename warp_reduce::TempStorage temp_storage;

auto const idx = static_cast<size_type>(threadIdx.x + blockIdx.x * blockDim.x);
if (idx >= (num_rows * cudf::detail::warp_size)) { return; }

auto const lane_idx = idx % cudf::detail::warp_size;
auto const str_idx = idx / cudf::detail::warp_size;
if (d_strings.is_null(str_idx)) { return; } // bitmask will set result to null.

// get the string for this warp
auto const d_str = d_strings.element<string_view>(str_idx);

for (size_t target_idx = 0; target_idx < num_targets; target_idx++) {
// Identify the target.
/**
* size of shared_bools = Min(targets_size * block_size, target_group * block_size)
* each thread uses targets_size bools
*/
extern __shared__ bool shared_bools[];

// initialize temp result:
// set true if target is empty, set false otherwise
for (int target_idx = 0; target_idx < num_targets; target_idx++) {
auto const d_target = d_targets.element<string_view>(target_idx);
shared_bools[threadIdx.x * num_targets + target_idx] = d_target.size_bytes() == 0;
}

// each thread of the warp will check just part of the string
auto found = false;
if (d_target.empty()) {
found = true;
} else {
for (auto i = static_cast<size_type>(lane_idx);
!found && ((i + d_target.size_bytes()) <= d_str.size_bytes());
i += cudf::detail::warp_size) {
// check the target matches this part of the d_str data
if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
for (size_type str_byte_idx = lane_idx; str_byte_idx < d_str.size_bytes();
str_byte_idx += cudf::detail::warp_size) {
// 1. check the first chars using binary search on first char set
char c = *(d_str.data() + str_byte_idx);
auto first_byte_ptr =
thrust::lower_bound(thrust::seq, d_target_first_bytes.begin(), d_target_first_bytes.end(), c);
if (not(first_byte_ptr != d_target_first_bytes.end() && *first_byte_ptr == c)) {
// first char is not matched for all targets, already set result as found
continue;
}

// 2. check the 2nd chars
int first_char_index_in_list = first_byte_ptr - d_target_first_bytes.begin();
// get possible targets
auto const possible_targets_list =
cudf::list_device_view{d_target_indexes_for_first_bytes, first_char_index_in_list};
for (auto list_idx = 0; list_idx < possible_targets_list.size();
++list_idx) { // iterate possible targets
auto target_idx = possible_targets_list.element<size_type>(list_idx);
int temp_result_idx = threadIdx.x * num_targets + target_idx;
if (!shared_bools[temp_result_idx]) { // not found before
auto const d_target = d_targets.element<string_view>(target_idx);
if (d_str.size_bytes() - str_byte_idx >= d_target.size_bytes()) {
// first char already checked, only need to check the [2nd, end) chars if has.
bool found = true;
for (auto i = 1; i < d_target.size_bytes(); i++) {
if (*(d_str.data() + str_byte_idx + i) != *(d_target.data() + i)) {
found = false;
break;
}
}
if (found) { shared_bools[temp_result_idx] = true; }
}
}
}
}

// wait all lanes are done in a warp
__syncwarp();

if (lane_idx == 0) {
for (int target_idx = 0; target_idx < num_targets; target_idx++) {
bool found = false;
for (int lane_idx = 0; lane_idx < cudf::detail::warp_size; lane_idx++) {
bool temp_idx = (str_idx * cudf::detail::warp_size + lane_idx) * num_targets + target_idx;
if (shared_bools[temp_idx]) {
found = true;
break;
}
}
d_results[target_idx][str_idx] = found;
}
__syncwarp();
auto const result = warp_reduce(temp_storage).Reduce(found, cub::Max());
if (lane_idx == 0) { d_results[target_idx][str_idx] = result; }
}
}

Expand All @@ -483,7 +532,7 @@ CUDF_KERNEL void multi_contains_using_indexes_fn(
for (auto str_byte_idx = 0; str_byte_idx < d_str.size_bytes();
++str_byte_idx) { // iterate the start index in the string

// binary search in the target first char set.
// 1. check the first chars using binary search on first char set
char c = *(d_str.data() + str_byte_idx);
auto first_byte_ptr =
thrust::lower_bound(thrust::seq, d_target_first_bytes.begin(), d_target_first_bytes.end(), c);
Expand All @@ -499,8 +548,9 @@ CUDF_KERNEL void multi_contains_using_indexes_fn(
auto const possible_targets_list =
cudf::list_device_view{d_target_indexes_for_first_bytes, first_char_index_in_list};

for (auto i = 0; i < possible_targets_list.size(); ++i) { // iterate possible targets
auto target_idx = possible_targets_list.element<size_type>(i);
for (auto list_idx = 0; list_idx < possible_targets_list.size();
++list_idx) { // iterate possible targets
auto target_idx = possible_targets_list.element<size_type>(list_idx);
if (!d_results[target_idx][str_idx]) { // not found before
auto const d_target = d_targets.element<string_view>(target_idx);
if (d_str.size_bytes() - str_byte_idx >= d_target.size_bytes()) {
Expand Down Expand Up @@ -540,11 +590,11 @@ CUDF_KERNEL void multi_contains_using_indexes_fn(
* if char in string is 'a', then only need to try ["ac", "ad", "af"] targets.
*
*/
std::vector<std::unique_ptr<column>> multi_contains_using_indexes(
strings_column_view const& input,
strings_column_view const& targets,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
std::vector<std::unique_ptr<column>> multi_contains(bool warp_parallel,
strings_column_view const& input,
strings_column_view const& targets,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto const num_targets = static_cast<size_type>(targets.size());
CUDF_EXPECTS(not targets.is_empty(), "Must specify at least one target string.");
Expand Down Expand Up @@ -644,59 +694,20 @@ std::vector<std::unique_ptr<column>> multi_contains_using_indexes(
constexpr int block_size = 256;
cudf::detail::grid_1d grid{input.size(), block_size};

multi_contains_using_indexes_fn<<<grid.num_blocks,
grid.num_threads_per_block,
0,
stream.value()>>>(
*d_strings, *d_targets, d_first_bytes, *d_list_column, device_results_list);
return results_list;
}

/**
* Execute multi contains for long strings
*/
std::vector<std::unique_ptr<column>> multi_contains_using_warp_parallel(
strings_column_view const& input,
strings_column_view const& targets,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto const num_targets = static_cast<size_type>(targets.size());
CUDF_EXPECTS(not targets.is_empty(), "Must specify at least one target string.");

// Create output columns.
auto const results_iter =
thrust::make_transform_iterator(thrust::counting_iterator<cudf::size_type>(0), [&](int i) {
return make_numeric_column(data_type{type_id::BOOL8},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);
});
auto results_list =
std::vector<std::unique_ptr<column>>(results_iter, results_iter + targets.size());
auto device_results_list = [&] {
auto host_results_pointer_iter =
thrust::make_transform_iterator(results_list.begin(), [](auto const& results_column) {
return results_column->mutable_view().template data<bool>();
});
auto host_results_pointers = std::vector<bool*>(
host_results_pointer_iter, host_results_pointer_iter + results_list.size());
return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr);
}();

constexpr int block_size = 256;
auto const d_strings = column_device_view::create(input.parent(), stream);
auto const d_targets = column_device_view::create(targets.parent(), stream);

// launch warp per string; one warp handles multi-targets for the same string.
cudf::detail::grid_1d grid{input.size() * cudf::detail::warp_size, block_size};
multi_contains_warp_parallel_multi_scalars_fn<<<grid.num_blocks,
grid.num_threads_per_block,
0,
stream.value()>>>(
*d_strings, *d_targets, device_results_list);
if (warp_parallel) {
int shared_mem_size = block_size * targets.size();
multi_contains_warp_parallel_multi_scalars_fn<<<grid.num_blocks,
grid.num_threads_per_block,
shared_mem_size,
stream.value()>>>(
*d_strings, *d_targets, d_first_bytes, *d_list_column, device_results_list);
} else {
multi_contains_using_indexes_fn<<<grid.num_blocks,
grid.num_threads_per_block,
0,
stream.value()>>>(
*d_strings, *d_targets, d_first_bytes, *d_list_column, device_results_list);
}

return results_list;
}
Expand Down Expand Up @@ -859,10 +870,10 @@ std::unique_ptr<table> multi_contains(strings_column_view const& input,
((input.chars_size(stream) / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) {
// Large strings.
// use warp parallel when the average string width is greater than the threshold
return multi_contains_using_warp_parallel(input, targets, stream, mr);
return multi_contains(/**warp parallel**/ true, input, targets, stream, mr);
} else {
// Small strings. Searching for multiple targets in one thread seems to work fastest.
return multi_contains_using_indexes(input, targets, stream, mr);
return multi_contains(/**warp parallel**/ false, input, targets, stream, mr);
}
}();
return std::make_unique<table>(std::move(result_columns));
Expand Down

0 comments on commit adea127

Please sign in to comment.