Skip to content

Commit

Permalink
Merge branch 'branch-23.12' into cuda-120-arm
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice authored Sep 26, 2023
2 parents 0ed1cf3 + 317525d commit 5303bd0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
7 changes: 4 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ struct search : public search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILT

void set_params(raft::resources const& res, const search_params& params)
{
this->itopk_size = 32;
search_width = 1;
num_cta_per_query = max(params.search_width, params.itopk_size / 32);
constexpr unsigned muti_cta_itopk_size = 32;
this->itopk_size = muti_cta_itopk_size;
search_width = 1;
num_cta_per_query = max(params.search_width, params.itopk_size / muti_cta_itopk_size);
result_buffer_size = itopk_size + search_width * graph_degree;
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
Expand Down
7 changes: 4 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ struct search_plan_impl_base : public search_params {
{
set_max_dim_team(dim);
if (algo == search_algo::AUTO) {
if (itopk_size <= 512) {
const size_t num_sm = raft::getMultiProcessorCount();
if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) {
algo = search_algo::SINGLE_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting single-cta");
} else {
algo = search_algo::MULTI_KERNEL;
RAFT_LOG_DEBUG("Auto strategy: selecting multi-kernel");
algo = search_algo::MULTI_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/test/test_cagra.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_cagra_index_params(params):
"search_width": 4,
"min_iterations": 0,
"thread_block_size": 0,
"hashmap_mode": "small",
"hashmap_mode": "auto",
"hashmap_min_bitlen": 0,
"hashmap_max_fill_rate": 0.5,
"num_random_samplings": 1,
Expand Down

0 comments on commit 5303bd0

Please sign in to comment.