Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of hash_character_ngrams using warp-per-string kernel #16212

Merged
merged 26 commits into from
Aug 1, 2024
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
00de3e5
Improve performance of hash_character_ngrams
davidwendt Jul 8, 2024
2a00e2e
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 9, 2024
9bb5600
create char-ngram counting kernel
davidwendt Jul 9, 2024
b4cbafd
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 9, 2024
20c206e
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 10, 2024
3761822
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 11, 2024
84f949b
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 11, 2024
06c399b
update some var types
davidwendt Jul 11, 2024
f38ded5
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 11, 2024
2efb048
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 12, 2024
0b08a91
fix type of launch parameters
davidwendt Jul 12, 2024
2496beb
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 12, 2024
f7ac0ea
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 17, 2024
fbbbe90
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 17, 2024
7bffad4
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 18, 2024
f7a2689
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 19, 2024
5e6b54c
minor variable changes
davidwendt Jul 19, 2024
cae4bc0
Merge branch 'branch-24.08' into ngram-hash-wide-chars
davidwendt Jul 23, 2024
f0a9956
Merge branch 'branch-24.10' into ngram-hash-wide-chars
davidwendt Jul 25, 2024
06925ed
Merge branch 'branch-24.10' into ngram-hash-wide-chars
davidwendt Jul 25, 2024
6b7cfd9
use updated grid_1d class
davidwendt Jul 25, 2024
4fbbf07
Merge branch 'branch-24.10' into ngram-hash-wide-chars
davidwendt Jul 30, 2024
88476a8
replace cub-warp-reduce with cg reduce
davidwendt Jul 31, 2024
9c6ce7f
add cast to grid_1d ctor call to prevent overflow
davidwendt Jul 31, 2024
b25403f
use cuda::std::max
davidwendt Jul 31, 2024
899ad1f
Merge branch 'branch-24.10' into ngram-hash-wide-chars
davidwendt Jul 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 113 additions & 48 deletions cpp/src/text/generate_ngrams.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
#include <rmm/exec_policy.hpp>
#include <rmm/resource_ref.hpp>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda/functional>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform_scan.h>

#include <stdexcept>

Expand Down Expand Up @@ -165,6 +167,47 @@ std::unique_ptr<cudf::column> generate_ngrams(cudf::strings_column_view const& s
namespace detail {
namespace {

constexpr cudf::thread_index_type block_size = 256;
constexpr cudf::thread_index_type bytes_per_thread = 4;

/**
* @brief Counts the number of ngrams in each row of the given strings column
*
* Each warp processes a single string.
* Formula is `count = max(0,str.length() - ngrams + 1)`
* If a string has less than ngrams characters, its count is 0.
*/
CUDF_KERNEL void count_char_ngrams_kernel(cudf::column_device_view const d_strings,
cudf::size_type ngrams,
cudf::size_type* d_counts)
{
auto const idx = cudf::detail::grid_1d::global_thread_id();

auto const str_idx = idx / cudf::detail::warp_size;
if (str_idx >= d_strings.size()) { return; }
if (d_strings.is_null(str_idx)) {
d_counts[str_idx] = 0;
return;
}

namespace cg = cooperative_groups;
auto const warp = cg::tiled_partition<cudf::detail::warp_size>(cg::this_thread_block());

auto const d_str = d_strings.element<cudf::string_view>(str_idx);
auto const end = d_str.data() + d_str.size_bytes();

auto const lane_idx = warp.thread_rank();
cudf::size_type count = 0;
for (auto itr = d_str.data() + (lane_idx * bytes_per_thread); itr < end;
itr += cudf::detail::warp_size * bytes_per_thread) {
for (auto s = itr; (s < (itr + bytes_per_thread)) && (s < end); ++s) {
count += static_cast<cudf::size_type>(cudf::strings::detail::is_begin_utf8_char(*s));
}
}
auto const char_count = cg::reduce(warp, count, cg::plus<int>());
if (lane_idx == 0) { d_counts[str_idx] = cuda::std::max(0, char_count - ngrams + 1); }
}

/**
* @brief Generate character ngrams for each string
*
Expand Down Expand Up @@ -220,17 +263,16 @@ std::unique_ptr<cudf::column> generate_character_ngrams(cudf::strings_column_vie

auto const d_strings = cudf::column_device_view::create(input.parent(), stream);

auto sizes_itr = cudf::detail::make_counting_transform_iterator(
0,
cuda::proclaim_return_type<cudf::size_type>(
[d_strings = *d_strings, ngrams] __device__(auto idx) {
if (d_strings.is_null(idx)) { return 0; }
auto const length = d_strings.element<cudf::string_view>(idx).length();
return std::max(0, static_cast<cudf::size_type>(length + 1 - ngrams));
}));
auto [offsets, total_ngrams] =
cudf::detail::make_offsets_child_column(sizes_itr, sizes_itr + input.size(), stream, mr);
auto [offsets, total_ngrams] = [&] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat!

auto counts = rmm::device_uvector<cudf::size_type>(input.size(), stream);
auto const num_blocks = cudf::util::div_rounding_up_safe(
static_cast<cudf::thread_index_type>(input.size()) * cudf::detail::warp_size, block_size);
count_char_ngrams_kernel<<<num_blocks, block_size, 0, stream.value()>>>(
*d_strings, ngrams, counts.data());
return cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr);
}();
auto d_offsets = offsets->view().data<cudf::size_type>();

CUDF_EXPECTS(total_ngrams > 0,
"Insufficient number of characters in each string to generate ngrams");

Expand All @@ -246,36 +288,64 @@ std::unique_ptr<cudf::column> generate_character_ngrams(cudf::strings_column_vie
}

namespace {

/**
* @brief Computes the hash of each character ngram
*
* Each thread processes a single string. Substrings are resolved for every character
* Each warp processes a single string. Substrings are resolved for every character
* of the string and hashed.
*/
struct character_ngram_hash_fn {
cudf::column_device_view const d_strings;
cudf::size_type ngrams;
cudf::size_type const* d_ngram_offsets;
cudf::hash_value_type* d_results;
CUDF_KERNEL void character_ngram_hash_kernel(cudf::column_device_view const d_strings,
cudf::size_type ngrams,
cudf::size_type const* d_ngram_offsets,
cudf::hash_value_type* d_results)
{
auto const idx = cudf::detail::grid_1d::global_thread_id();
if (idx >= (static_cast<cudf::thread_index_type>(d_strings.size()) * cudf::detail::warp_size)) {
return;
}

__device__ void operator()(cudf::size_type idx) const
{
if (d_strings.is_null(idx)) return;
auto const d_str = d_strings.element<cudf::string_view>(idx);
if (d_str.empty()) return;
auto itr = d_str.begin();
auto const ngram_offset = d_ngram_offsets[idx];
auto const ngram_count = d_ngram_offsets[idx + 1] - ngram_offset;
auto const hasher = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>{0};
auto d_hashes = d_results + ngram_offset;
for (cudf::size_type n = 0; n < ngram_count; ++n, ++itr) {
auto const begin = itr.byte_offset();
auto const end = (itr + ngrams).byte_offset();
auto const ngram = cudf::string_view(d_str.data() + begin, end - begin);
*d_hashes++ = hasher(ngram);
auto const str_idx = idx / cudf::detail::warp_size;

if (d_strings.is_null(str_idx)) { return; }
auto const d_str = d_strings.element<cudf::string_view>(str_idx);
if (d_str.empty()) { return; }

__shared__ cudf::hash_value_type hvs[block_size]; // temp store for hash values

auto const ngram_offset = d_ngram_offsets[str_idx];
auto const hasher = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>{0};

auto const end = d_str.data() + d_str.size_bytes();
auto const warp_count = (d_str.size_bytes() / cudf::detail::warp_size) + 1;
auto const lane_idx = idx % cudf::detail::warp_size;

auto d_hashes = d_results + ngram_offset;
auto itr = d_str.data() + lane_idx;
for (auto i = 0; i < warp_count; ++i) {
cudf::hash_value_type hash = 0;
if (itr < end && cudf::strings::detail::is_begin_utf8_char(*itr)) {
// resolve ngram substring
auto const sub_str =
cudf::string_view(itr, static_cast<cudf::size_type>(thrust::distance(itr, end)));
auto const [bytes, left] =
cudf::strings::detail::bytes_to_character_position(sub_str, ngrams);
if (left == 0) { hash = hasher(cudf::string_view(itr, bytes)); }
}
hvs[threadIdx.x] = hash; // store hash into shared memory
__syncwarp();
if (lane_idx == 0) {
// copy valid hash values into d_hashes
auto const hashes = &hvs[threadIdx.x];
d_hashes = thrust::copy_if(
thrust::seq, hashes, hashes + cudf::detail::warp_size, d_hashes, [](auto h) {
return h != 0;
});
}
__syncwarp();
itr += cudf::detail::warp_size;
}
};
}
} // namespace

std::unique_ptr<cudf::column> hash_character_ngrams(cudf::strings_column_view const& input,
Expand All @@ -291,18 +361,16 @@ std::unique_ptr<cudf::column> hash_character_ngrams(cudf::strings_column_view co
if (input.is_empty()) { return cudf::make_empty_column(output_type); }

auto const d_strings = cudf::column_device_view::create(input.parent(), stream);
auto const grid = cudf::detail::grid_1d(
static_cast<cudf::thread_index_type>(input.size()) * cudf::detail::warp_size, block_size);

// build offsets column by computing the number of ngrams per string
auto sizes_itr = cudf::detail::make_counting_transform_iterator(
0,
cuda::proclaim_return_type<cudf::size_type>(
[d_strings = *d_strings, ngrams] __device__(auto idx) {
if (d_strings.is_null(idx)) { return 0; }
auto const length = d_strings.element<cudf::string_view>(idx).length();
return std::max(0, static_cast<cudf::size_type>(length + 1 - ngrams));
}));
auto [offsets, total_ngrams] =
cudf::detail::make_offsets_child_column(sizes_itr, sizes_itr + input.size(), stream, mr);
auto [offsets, total_ngrams] = [&] {
auto counts = rmm::device_uvector<cudf::size_type>(input.size(), stream);
count_char_ngrams_kernel<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
*d_strings, ngrams, counts.data());
return cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr);
}();
auto d_offsets = offsets->view().data<cudf::size_type>();

CUDF_EXPECTS(total_ngrams > 0,
Expand All @@ -313,11 +381,8 @@ std::unique_ptr<cudf::column> hash_character_ngrams(cudf::strings_column_view co
cudf::make_numeric_column(output_type, total_ngrams, cudf::mask_state::UNALLOCATED, stream, mr);
auto d_hashes = hashes->mutable_view().data<cudf::hash_value_type>();

character_ngram_hash_fn generator{*d_strings, ngrams, d_offsets, d_hashes};
thrust::for_each_n(rmm::exec_policy(stream),
thrust::counting_iterator<cudf::size_type>(0),
input.size(),
generator);
character_ngram_hash_kernel<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
*d_strings, ngrams, d_offsets, d_hashes);

return make_lists_column(
input.size(), std::move(offsets), std::move(hashes), 0, rmm::device_buffer{}, stream, mr);
Expand Down
Loading