diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index a81133181f..55b7b47508 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -164,7 +164,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in // computaiton is necessary. for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += blockDim.x) { const INDEX_T smem_parent_id = parent_indices[i / knn_k]; - INDEX_T child_id = invalid_index; + INDEX_T child_id = invalid_index; if (smem_parent_id != invalid_index) { const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)]; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index e41d9854cc..358a183971 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -273,21 +273,20 @@ __launch_bounds__(1024, 1) __global__ void search_kernel( _CLK_START(); // constexpr unsigned max_n_frags = 16; constexpr unsigned max_n_frags = 0; - device:: - compute_distance_to_child_nodes( - result_indices_buffer + itopk_size, - result_distances_buffer + itopk_size, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_indices_buffer, - result_indices_buffer, - search_width); + device::compute_distance_to_child_nodes( + result_indices_buffer + itopk_size, + result_distances_buffer + itopk_size, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_ld, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_indices_buffer, + result_indices_buffer, + search_width); _CLK_REC(clk_compute_distance); __syncthreads(); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index b84a805b5f..3a5501f545 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -627,16 +627,16 @@ __launch_bounds__(1024, 1) __global__ if (std::is_same::value || *filter_flag == 0) { - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - search_width * graph_degree, - topk_ws, - (iter == 0), - multi_warps_1, - multi_warps_2); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + search_width * graph_degree, + topk_ws, + (iter == 0), + multi_warps_1, + multi_warps_2); __syncthreads(); } else { topk_by_bitonic_sort_1st( @@ -644,7 +644,7 @@ __launch_bounds__(1024, 1) __global__ result_indices_buffer, internal_topk + search_width * graph_degree, internal_topk, - false); + false); if (threadIdx.x == 0) { *terminate_flag = 0; } } _CLK_REC(clk_topk); @@ -703,21 +703,20 @@ __launch_bounds__(1024, 1) __global__ // compute the norms between child nodes and query node _CLK_START(); constexpr unsigned max_n_frags = 16; - device:: - compute_distance_to_child_nodes( - result_indices_buffer + internal_topk, - result_distances_buffer + internal_topk, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_list_buffer, - result_indices_buffer, - search_width); + device::compute_distance_to_child_nodes( + result_indices_buffer + internal_topk, + result_distances_buffer + internal_topk, + query_buffer, + dataset_ptr, + dataset_dim, + dataset_ld, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + parent_list_buffer, + result_indices_buffer, + search_width); __syncthreads(); _CLK_REC(clk_compute_distance); @@ -768,7 +767,7 @@ __launch_bounds__(1024, 1) __global__ result_indices_buffer, internal_topk + search_width * graph_degree, top_k, - false); + false); __syncthreads(); } @@ -834,8 +833,7 @@ struct search_kernel_config { T, DistT, IdxT, - SAMPLE_FILTER_T - >; + SAMPLE_FILTER_T>; } else if (itopk_size <= 256) { return search_kernel; + SAMPLE_FILTER_T>; } else if (itopk_size <= 512) { return search_kernel; + SAMPLE_FILTER_T>; } THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); }