Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CAGRA: reduce argument count in select_and_run() kernel wrappers #227

Merged: 6 commits were merged on Jul 23, 2024.
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,15 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
thread_block_size,
result_buffer_size,
smem_size,
hash_bitlen,
hashmap.data(),
num_cta_per_query,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,15 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
19 changes: 7 additions & 12 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ void select_and_run(
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
// multi_cta_search (params struct)
uint32_t block_size, //
Expand All @@ -466,13 +467,7 @@ void select_and_run(
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream)
Expand Down Expand Up @@ -507,16 +502,16 @@ void select_and_run(
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
hash_bitlen,
itopk_size,
search_width,
min_iterations,
max_iterations,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
sample_filter,
metric);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
num_itopk_candidates,
static_cast<uint32_t>(thread_block_size),
Expand All @@ -241,13 +242,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
hashmap.data(),
small_hash_bitlen,
small_hash_reset_interval,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -40,13 +41,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
21 changes: 8 additions & 13 deletions cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,7 @@ void select_and_run(
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
uint32_t num_itopk_candidates,
uint32_t block_size, //
Expand All @@ -927,20 +928,14 @@ void select_and_run(
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
size_t small_hash_bitlen,
size_t small_hash_reset_interval,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream)
{
auto kernel =
search_kernel_config<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>::
choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size);
choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size);
RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte));
Expand All @@ -955,15 +950,15 @@ void select_and_run(
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
itopk_size,
search_width,
min_iterations,
max_iterations,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
hash_bitlen,
small_hash_bitlen,
Expand Down
Loading