Skip to content

Commit

Permalink
Fix the topk function name
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Dec 3, 2024
1 parent 3d676e2 commit 49acbab
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const termin
}

template <unsigned MAX_CANDIDATES, class IdxT = void>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_1st(
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_full(
float* candidate_distances, // [num_candidates]
IdxT* candidate_indices, // [num_candidates]
const std::uint32_t num_candidates,
Expand Down Expand Up @@ -215,7 +215,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_1st(
}

template <unsigned MAX_ITOPK, class IdxT = void>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_2nd(
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_merge(
float* itopk_distances, // [num_itopk]
IdxT* itopk_indices, // [num_itopk]
const std::uint32_t num_itopk,
Expand Down Expand Up @@ -424,7 +424,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_2nd(
template <unsigned MAX_ITOPK,
unsigned MAX_CANDIDATES,
class IdxT>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge(
float* itopk_distances, // [num_itopk]
IdxT* itopk_indices, // [num_itopk]
const std::uint32_t num_itopk,
Expand All @@ -437,20 +437,20 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(
const unsigned MULTI_WARPS_2)
{
// The results in candidate_distances/indices are sorted by bitonic sort.
topk_by_bitonic_sort_1st<MAX_CANDIDATES, IdxT>(
topk_by_bitonic_sort_and_merge_full<MAX_CANDIDATES, IdxT>(
candidate_distances, candidate_indices, num_candidates, num_itopk, MULTI_WARPS_1);

// The results sorted above are merged with the internal intermediate top-k
// results so far using bitonic merge.
topk_by_bitonic_sort_2nd<MAX_ITOPK, IdxT>(itopk_distances,
itopk_indices,
num_itopk,
candidate_distances,
candidate_indices,
num_candidates,
work_buf,
first,
MULTI_WARPS_2);
topk_by_bitonic_sort_and_merge_merge<MAX_ITOPK, IdxT>(itopk_distances,
itopk_indices,
num_itopk,
candidate_distances,
candidate_indices,
num_candidates,
work_buf,
first,
MULTI_WARPS_2);
}

// This function move the invalid index element to the end of the itopk list.
Expand Down Expand Up @@ -631,10 +631,10 @@ __device__ void search_core(
// sort
if constexpr (TOPK_BY_BITONIC_SORT) {
// [Notice]
// It is good to use multiple warps in topk_by_bitonic_sort() when
// It is good to use multiple warps in topk_by_bitonic_sort_and_merge() when
// batch size is small (short-latency), but it might not be always good
// when batch size is large (high-throughput).
// topk_by_bitonic_sort() consists of two operations:
// topk_by_bitonic_sort_and_merge() consists of two operations:
// if MAX_CANDIDATES is greater than 128, the first operation uses two warps;
// if MAX_ITOPK is greater than 256, the second operation used two warps.
const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0;
Expand All @@ -643,9 +643,9 @@ __device__ void search_core(
// reset small-hash table.
if ((iter + 1) % small_hash_reset_interval == 0) {
// Depending on the block size and the number of warps used in
// topk_by_bitonic_sort(), determine which warps are used to reset
// topk_by_bitonic_sort_and_merge(), determine which warps are used to reset
// the small hash and whether they are performed in overlap with
// topk_by_bitonic_sort().
// topk_by_bitonic_sort_and_merge().
_CLK_START();
unsigned hash_start_tid;
if (blockDim.x == 32) {
Expand Down Expand Up @@ -679,16 +679,17 @@ __device__ void search_core(

if (threadIdx.x == 0) { *terminate_flag = 0; }
}
topk_by_bitonic_sort<MAX_ITOPK, MAX_CANDIDATES>(result_distances_buffer,
result_indices_buffer,
internal_topk,
result_distances_buffer + internal_topk,
result_indices_buffer + internal_topk,
search_width * graph_degree,
topk_ws,
(iter == 0),
multi_warps_1,
multi_warps_2);
topk_by_bitonic_sort_and_merge<MAX_ITOPK, MAX_CANDIDATES>(
result_distances_buffer,
result_indices_buffer,
internal_topk,
result_distances_buffer + internal_topk,
result_indices_buffer + internal_topk,
search_width * graph_degree,
topk_ws,
(iter == 0),
multi_warps_1,
multi_warps_2);
__syncthreads();
_CLK_REC(clk_topk);
} else {
Expand Down Expand Up @@ -829,8 +830,8 @@ __device__ void search_core(
}

if (num_found_valid < top_k) {
// Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort` is usable in
// the next step
// Fill the remaining buffer with invalid values so that `topk_by_bitonic_sort_and_merge` is
// usable in the next step
for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) {
result_indices_buffer[i] = invalid_index;
result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
Expand All @@ -844,16 +845,17 @@ __device__ void search_core(
__syncthreads();
const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0;
const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 1 : 0;
topk_by_bitonic_sort<MAX_ITOPK, MAX_CANDIDATES>(result_distances_buffer,
result_indices_buffer,
internal_topk,
result_distances_buffer + internal_topk,
result_indices_buffer + internal_topk,
search_width * graph_degree,
topk_ws,
(iter == 0),
multi_warps_1,
multi_warps_2);
topk_by_bitonic_sort_and_merge<MAX_ITOPK, MAX_CANDIDATES>(
result_distances_buffer,
result_indices_buffer,
internal_topk,
result_distances_buffer + internal_topk,
result_indices_buffer + internal_topk,
search_width * graph_degree,
topk_ws,
(iter == 0),
multi_warps_1,
multi_warps_2);
}
__syncthreads();
}
Expand Down

0 comments on commit 49acbab

Please sign in to comment.