From 37e26c1bd2bca4b41ce0510e842a07ea3ead0816 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 5 Dec 2024 07:51:06 -0800 Subject: [PATCH] fix style --- cpp/src/neighbors/detail/cagra/factory.cuh | 6 ++-- cpp/src/neighbors/detail/cagra/hashmap.hpp | 32 +++++++++---------- .../neighbors/detail/cagra/search_plan.cuh | 19 +++++++---- .../cagra/search_single_cta_kernel-inl.cuh | 4 +-- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 064f880ad..d2ae5c55b 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -57,15 +57,15 @@ class factory { if (plan.algo == search_algo::SINGLE_CTA) { return std::make_unique< single_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else if (plan.algo == search_algo::MULTI_CTA) { return std::make_unique< multi_cta_search::search>( - res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } else { return std::make_unique< multi_kernel_search::search>( - res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); + res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk); } } }; diff --git a/cpp/src/neighbors/detail/cagra/hashmap.hpp b/cpp/src/neighbors/detail/cagra/hashmap.hpp index 6dbdd5a8a..da736ef5e 100644 --- a/cpp/src/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/src/neighbors/detail/cagra/hashmap.hpp @@ -62,7 +62,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, const uint32_t stride = (key >> bitlen) * 2 + 1; #endif constexpr IdxT hashval_empty = ~static_cast(0); - const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { const IdxT old = atomicCAS(&table[index], hashval_empty, key); if (old == hashval_empty) { @@ -86,19 +86,19 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table, template RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key) { - const uint32_t size = get_size(bitlen); - const uint32_t bit_mask = size - 1; + const uint32_t size = get_size(bitlen); + const uint32_t bit_mask = size - 1; #ifdef HASHMAP_LINEAR_PROBING - // Linear probing - IdxT index = (key ^ (key >> bitlen)) & bit_mask; - constexpr uint32_t stride = 1; + // Linear probing + IdxT index = (key ^ (key >> bitlen)) & bit_mask; + constexpr uint32_t stride = 1; #else - // Double hashing - IdxT index = key & bit_mask; - const uint32_t stride = (key >> bitlen) * 2 + 1; + // Double hashing + IdxT index = key & bit_mask; + const uint32_t stride = (key >> bitlen) * 2 + 1; #endif constexpr IdxT hashval_empty = ~static_cast(0); - const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { const IdxT val = table[index]; if (val == key) { @@ -107,9 +107,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, return 0; } else if (SUPPORT_REMOVE) { // Check if this key has been removed. - if (val == removed_key) { - return 0; - } + if (val == removed_key) { return 0; } } index = (index + stride) & bit_mask; } @@ -119,19 +117,19 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, template RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key) { - const uint32_t size = get_size(bitlen); + const uint32_t size = get_size(bitlen); const uint32_t bit_mask = size - 1; #ifdef HASHMAP_LINEAR_PROBING // Linear probing - IdxT index = (key ^ (key >> bitlen)) & bit_mask; + IdxT index = (key ^ (key >> bitlen)) & bit_mask; constexpr uint32_t stride = 1; #else // Double hashing - IdxT index = key & bit_mask; + IdxT index = key & bit_mask; const uint32_t stride = (key >> bitlen) * 2 + 1; #endif constexpr IdxT hashval_empty = ~static_cast(0); - const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; + const IdxT removed_key = key | utils::gen_index_msb_1_mask::value; for (unsigned i = 0; i < size; i++) { // To remove a key, set the MSB to 1. const uint32_t old = atomicCAS(&table[index], key, removed_key); diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 2bbf3d56a..5b6b58a13 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -112,9 +112,13 @@ struct search_plan_impl_base : public search_params { int64_t dim; int64_t graph_degree; uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t dataset_size, - int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), dataset_size(dataset_size), graph_degree(graph_degree), topk(topk) + search_plan_impl_base( + search_params params, int64_t dim, int64_t dataset_size, int64_t graph_degree, uint32_t topk) + : search_params(params), + dim(dim), + dataset_size(dataset_size), + graph_degree(graph_degree), + topk(topk) { if (algo == search_algo::AUTO) { const size_t num_sm = raft::getMultiProcessorCount(); @@ -195,9 +199,9 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t _max_iterations = max_iterations; if (max_iterations == 0) { if (algo == search_algo::MULTI_CTA) { - constexpr uint32_t mc_itopk_size = 32; + constexpr uint32_t mc_itopk_size = 32; constexpr uint32_t mc_search_width = 1; - _max_iterations = mc_itopk_size / mc_search_width; + _max_iterations = mc_itopk_size / mc_search_width; } else { _max_iterations = itopk_size / search_width; } @@ -246,7 +250,7 @@ struct search_plan_impl : public search_plan_impl_base { // shared among CTAs. // const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations); - small_hash_bitlen = 11; // 2K + small_hash_bitlen = 11; // 2K while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { small_hash_bitlen += 1; } @@ -257,7 +261,8 @@ struct search_plan_impl : public search_plan_impl_base { // in each iteration is managed in a separate hash table, which is shared // among the CTAs. // - const auto max_traversed_nodes = mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); + const auto max_traversed_nodes = + mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); unsigned min_bitlen = 11; // 2K if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } hash_bitlen = min_bitlen; diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 0eedb8d09..94c97ed16 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -623,7 +623,7 @@ __device__ void search_core( num_seeds, local_visited_hashmap_ptr, hash_bitlen, - (INDEX_T*) nullptr, + (INDEX_T*)nullptr, 0); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -751,7 +751,7 @@ __device__ void search_core( graph_degree, local_visited_hashmap_ptr, hash_bitlen, - (INDEX_T*) nullptr, + (INDEX_T*)nullptr, 0, parent_list_buffer, result_indices_buffer,