Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Dec 5, 2024
1 parent 8ff6991 commit 37e26c1
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 29 deletions.
6 changes: 3 additions & 3 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ class factory {
if (plan.algo == search_algo::SINGLE_CTA) {
return std::make_unique<
single_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::make_unique<
multi_cta_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
} else {
return std::make_unique<
multi_kernel_search::search<DataT, IndexT, DistanceT, CagraSampleFilterT>>(
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
res, plan, dataset_desc, plan.dim, plan.dataset_size, plan.graph_degree, plan.topk);
}
}
};
Expand Down
32 changes: 15 additions & 17 deletions cpp/src/neighbors/detail/cagra/hashmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT old = atomicCAS(&table[index], hashval_empty, key);
if (old == hashval_empty) {
Expand All @@ -86,19 +86,19 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t insert(IdxT* const table,
template <class IdxT, unsigned SUPPORT_REMOVE = 0>
RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen, const IdxT key)
{
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
// Double hashing
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
const IdxT val = table[index];
if (val == key) {
Expand All @@ -107,9 +107,7 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen,
return 0;
} else if (SUPPORT_REMOVE) {
// Check if this key has been removed.
if (val == removed_key) {
return 0;
}
if (val == removed_key) { return 0; }
}
index = (index + stride) & bit_mask;
}
Expand All @@ -119,19 +117,19 @@ RAFT_DEVICE_INLINE_FUNCTION uint32_t search(IdxT* table, const uint32_t bitlen,
template <class IdxT>
RAFT_DEVICE_INLINE_FUNCTION uint32_t remove(IdxT* table, const uint32_t bitlen, const IdxT key)
{
const uint32_t size = get_size(bitlen);
const uint32_t size = get_size(bitlen);
const uint32_t bit_mask = size - 1;
#ifdef HASHMAP_LINEAR_PROBING
// Linear probing
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
IdxT index = (key ^ (key >> bitlen)) & bit_mask;
constexpr uint32_t stride = 1;
#else
// Double hashing
IdxT index = key & bit_mask;
IdxT index = key & bit_mask;
const uint32_t stride = (key >> bitlen) * 2 + 1;
#endif
constexpr IdxT hashval_empty = ~static_cast<IdxT>(0);
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
const IdxT removed_key = key | utils::gen_index_msb_1_mask<IdxT>::value;
for (unsigned i = 0; i < size; i++) {
// To remove a key, set the MSB to 1.
const uint32_t old = atomicCAS(&table[index], key, removed_key);
Expand Down
19 changes: 12 additions & 7 deletions cpp/src/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,13 @@ struct search_plan_impl_base : public search_params {
int64_t dim;
int64_t graph_degree;
uint32_t topk;
search_plan_impl_base(search_params params, int64_t dim, int64_t dataset_size,
int64_t graph_degree, uint32_t topk)
: search_params(params), dim(dim), dataset_size(dataset_size), graph_degree(graph_degree), topk(topk)
search_plan_impl_base(
search_params params, int64_t dim, int64_t dataset_size, int64_t graph_degree, uint32_t topk)
: search_params(params),
dim(dim),
dataset_size(dataset_size),
graph_degree(graph_degree),
topk(topk)
{
if (algo == search_algo::AUTO) {
const size_t num_sm = raft::getMultiProcessorCount();
Expand Down Expand Up @@ -195,9 +199,9 @@ struct search_plan_impl : public search_plan_impl_base {
uint32_t _max_iterations = max_iterations;
if (max_iterations == 0) {
if (algo == search_algo::MULTI_CTA) {
constexpr uint32_t mc_itopk_size = 32;
constexpr uint32_t mc_itopk_size = 32;
constexpr uint32_t mc_search_width = 1;
_max_iterations = mc_itopk_size / mc_search_width;
_max_iterations = mc_itopk_size / mc_search_width;
} else {
_max_iterations = itopk_size / search_width;
}
Expand Down Expand Up @@ -246,7 +250,7 @@ struct search_plan_impl : public search_plan_impl_base {
// shared among CTAs.
//
const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations);
small_hash_bitlen = 11; // 2K
small_hash_bitlen = 11; // 2K
while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) {
small_hash_bitlen += 1;
}
Expand All @@ -257,7 +261,8 @@ struct search_plan_impl : public search_plan_impl_base {
// in each iteration is managed in a separate hash table, which is shared
// among the CTAs.
//
const auto max_traversed_nodes = mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations);
const auto max_traversed_nodes =
mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations);
unsigned min_bitlen = 11; // 2K
if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; }
hash_bitlen = min_bitlen;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ __device__ void search_core(
num_seeds,
local_visited_hashmap_ptr,
hash_bitlen,
(INDEX_T*) nullptr,
(INDEX_T*)nullptr,
0);
__syncthreads();
_CLK_REC(clk_compute_1st_distance);
Expand Down Expand Up @@ -751,7 +751,7 @@ __device__ void search_core(
graph_degree,
local_visited_hashmap_ptr,
hash_bitlen,
(INDEX_T*) nullptr,
(INDEX_T*)nullptr,
0,
parent_list_buffer,
result_indices_buffer,
Expand Down

0 comments on commit 37e26c1

Please sign in to comment.