Skip to content

Commit

Permalink
[BUG] Fix a bug in the filtering operation in CAGRA multi-kernel (#1862)
Browse files Browse the repository at this point in the history
This PR fixes a bug in the filtering operation of the CAGRA multi-kernel search implementation. The bug caused the test in #1837 to fail.

Authors:
   - tsuki (https://github.com/enp1s0)

Approvers:
   - Micka (https://github.com/lowener)
   - Corey J. Nolet (https://github.com/cjnolet)
  • Loading branch information
enp1s0 authored Oct 2, 2023
1 parent e618fb0 commit 1ee423b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
23 changes: 16 additions & 7 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,15 @@ __global__ void apply_filter_kernel(INDEX_T* const result_indices_ptr,
const INDEX_T query_id_offset,
SAMPLE_FILTER_T sample_filter)
{
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= result_buffer_size * num_queries) { return; }
const auto i = tid % result_buffer_size;
const auto j = tid / result_buffer_size;
const auto index = i + j * lds;

if (!sample_filter(query_id_offset + j, result_indices_ptr[index])) {
if (result_indices_ptr[index] != ~index_msb_1_mask &&
!sample_filter(query_id_offset + j, result_indices_ptr[index])) {
result_indices_ptr[index] = utils::get_max_value<INDEX_T>();
result_distances_ptr[index] = utils::get_max_value<DISTANCE_T>();
}
Expand Down Expand Up @@ -788,12 +790,15 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size;
auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size;

// Remove parent bit in search results
remove_parent_bit(
num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream);
if constexpr (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
// Remove parent bit in search results
remove_parent_bit(num_queries,
result_buffer_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
stream);

if (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
apply_filter<INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>(
result_indices.data() + (iter & 0x1) * itopk_size,
result_distances.data() + (iter & 0x1) * itopk_size,
Expand Down Expand Up @@ -821,6 +826,10 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
true,
topk_hint.data(),
stream);
} else {
// Remove parent bit in search results
remove_parent_bit(
num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream);
}

// Copy results from working buffer to final buffer
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ struct search_plan_impl : public search_plan_impl_base {
"`hashmap_max_fill_rate` must be equal to or greater than 0.1 and smaller than 0.9. " +
std::to_string(hashmap_max_fill_rate) + " has been given.";
}
if constexpr (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
if (hashmap_mode == hash_mode::SMALL) {
error_message += "`SMALL` hash is not available when filtering";
} else {
hashmap_mode = hash_mode::HASH;
}
}
if (algo == search_algo::MULTI_CTA) {
if (hashmap_mode == hash_mode::SMALL) {
error_message += "`small_hash` is not available when 'search_mode' is \"multi-cta\"";
Expand Down
3 changes: 1 addition & 2 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1278,8 +1278,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
std::thread update_and_sample_thread(update_and_sample, it);
std::cout << "# GNND iteraton: " << it + 1 << "/" << build_config_.max_iterations << "\r";
std::fflush(stdout);
RAFT_LOG_DEBUG("# GNND iteraton: %lu / %lu", it + 1, build_config_.max_iterations);
// Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it
// contains some information for local_join.
Expand Down

0 comments on commit 1ee423b

Please sign in to comment.