diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index a6a20ca49..efbf9b56d 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -230,6 +230,7 @@ struct search : public search_plan_impl { num_queries, dev_seed_ptr, num_executed_iterations, + *this, topk, thread_block_size, result_buffer_size, @@ -237,13 +238,7 @@ struct search : public search_plan_impl { hash_bitlen, hashmap.data(), num_cta_per_query, - num_random_samplings, - rand_xor_mask, num_seeds, - itopk_size, - search_width, - min_iterations, - max_iterations, sample_filter, this->metric, stream); diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh index e28389f38..b1cfaf870 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh @@ -32,6 +32,7 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search { const uint32_t num_queries, \ const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \ uint32_t* const num_executed_iterations, \ + const search_params& ps, \ uint32_t topk, \ uint32_t block_size, \ uint32_t result_buffer_size, \ @@ -39,13 +40,7 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search { int64_t hash_bitlen, \ typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \ uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ cuvs::distance::DistanceType metric, \ cudaStream_t stream); diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index 495ec6a4d..b00d6617c 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -27,33 +27,29 @@ namespace multi_cta_search { #ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY template + unsigned DATASET_BLOCK_DIM, + typename DATASET_DESCRIPTOR_T, + typename SAMPLE_FILTER_T> void select_and_run( DATASET_DESCRIPTOR_T dataset_desc, raft::device_matrix_view graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, + typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, - uint32_t* const num_executed_iterations, + const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] + const search_params& ps, uint32_t topk, - uint32_t block_size, + // multi_cta_search (params struct) + uint32_t block_size, // uint32_t result_buffer_size, uint32_t smem_size, int64_t hash_bitlen, typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, SAMPLE_FILTER_T sample_filter, cuvs::distance::DistanceType metric, cudaStream_t stream) RAFT_EXPLICIT; @@ -75,6 +71,7 @@ void select_and_run( const uint32_t num_queries, \ const INDEX_T* dev_seed_ptr, \ uint32_t* const num_executed_iterations, \ + const search_params& ps, \ uint32_t topk, \ uint32_t block_size, \ uint32_t result_buffer_size, \ @@ -82,13 +79,7 @@ void select_and_run( int64_t hash_bitlen, \ INDEX_T* hashmap_ptr, \ uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ cuvs::distance::DistanceType metric, \ cudaStream_t stream); @@ -160,6 +151,7 @@ instantiate_kernel_selection( const uint32_t num_queries, \ const INDEX_T* dev_seed_ptr, \ uint32_t* const num_executed_iterations, \ + const search_params& ps, \ uint32_t topk, \ uint32_t block_size, \ uint32_t result_buffer_size, \ @@ -167,13 +159,7 @@ instantiate_kernel_selection( int64_t hash_bitlen, \ INDEX_T* hashmap_ptr, \ uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ cuvs::distance::DistanceType metric, \ cudaStream_t stream); diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 90e699f48..4d2030c6c 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -458,6 +458,7 @@ void select_and_run( const uint32_t num_queries, const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] uint32_t* const num_executed_iterations, // [num_queries,] + const search_params& ps, uint32_t topk, // multi_cta_search (params struct) uint32_t block_size, // @@ -466,13 +467,7 @@ void select_and_run( int64_t hash_bitlen, typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, SAMPLE_FILTER_T sample_filter, cuvs::distance::DistanceType metric, cudaStream_t stream) @@ -507,16 +502,16 @@ void select_and_run( queries_ptr, graph.data_handle(), graph.extent(1), - num_random_samplings, - rand_xor_mask, + ps.num_random_samplings, + ps.rand_xor_mask, dev_seed_ptr, num_seeds, hashmap_ptr, hash_bitlen, - itopk_size, - search_width, - min_iterations, - max_iterations, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, num_executed_iterations, sample_filter, metric); diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index b8e1726e7..0a101cbfe 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -233,6 +233,7 @@ struct search : search_plan_impl { num_queries, dev_seed_ptr, num_executed_iterations, + *this, topk, num_itopk_candidates, static_cast(thread_block_size), @@ -241,13 +242,7 @@ struct search : search_plan_impl { hashmap.data(), small_hash_bitlen, small_hash_reset_interval, - num_random_samplings, - rand_xor_mask, num_seeds, - itopk_size, - search_width, - min_iterations, - max_iterations, sample_filter, this->metric, stream); diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh index b3d75e923..a4581d15e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh @@ -32,6 +32,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search { const uint32_t num_queries, \ const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \ uint32_t* const num_executed_iterations, \ + const search_params& ps, \ uint32_t topk, \ uint32_t num_itopk_candidates, \ uint32_t block_size, \ @@ -40,13 +41,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search { typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \ size_t small_hash_bitlen, \ size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ cuvs::distance::DistanceType metric, \ cudaStream_t stream); diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index dfcdec28f..79f6e153c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -26,10 +26,10 @@ namespace single_cta_search { #ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY template -void select_and_run( // raft::resources const& res, +void select_and_run( DATASET_DESCRIPTOR_T dataset_desc, raft::device_matrix_view graph, @@ -39,21 +39,16 @@ void select_and_run( // raft::resources const& res, const uint32_t num_queries, const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] uint32_t* const num_executed_iterations, // [num_queries,] + const search_params& ps, uint32_t topk, uint32_t num_itopk_candidates, - uint32_t block_size, + uint32_t block_size, // uint32_t smem_size, int64_t hash_bitlen, typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, size_t small_hash_bitlen, size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, SAMPLE_FILTER_T sample_filter, cuvs::distance::DistanceType metric, cudaStream_t stream) RAFT_EXPLICIT; @@ -76,6 +71,7 @@ void select_and_run( // raft::resources const& res, const uint32_t num_queries, \ const INDEX_T* dev_seed_ptr, \ uint32_t* const num_executed_iterations, \ + const search_params& ps, \ uint32_t topk, \ uint32_t num_itopk_candidates, \ uint32_t block_size, \ @@ -84,13 +80,7 @@ void select_and_run( // raft::resources const& res, INDEX_T* hashmap_ptr, \ size_t small_hash_bitlen, \ size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ cuvs::distance::DistanceType metric, \ cudaStream_t stream); @@ -162,6 +152,7 @@ instantiate_single_cta_select_and_run( const uint32_t num_queries, \ const INDEX_T* dev_seed_ptr, \ uint32_t* const num_executed_iterations, \ + const search_params& ps, \ uint32_t topk, \ uint32_t num_itopk_candidates, \ uint32_t block_size, \ @@ -170,13 +161,7 @@ instantiate_single_cta_select_and_run( INDEX_T* hashmap_ptr, \ size_t small_hash_bitlen, \ size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ cuvs::distance::DistanceType metric, \ cudaStream_t stream); diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index e58167432..a101cdc1f 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -919,6 +919,7 @@ void select_and_run( const uint32_t num_queries, const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] uint32_t* const num_executed_iterations, // [num_queries,] + const search_params& ps, uint32_t topk, uint32_t num_itopk_candidates, uint32_t block_size, // @@ -927,20 +928,14 @@ void select_and_run( typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, size_t small_hash_bitlen, size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, SAMPLE_FILTER_T sample_filter, cuvs::distance::DistanceType metric, cudaStream_t stream) { auto kernel = search_kernel_config:: - choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); + choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size); RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte)); @@ -955,15 +950,15 @@ void select_and_run( queries_ptr, graph.data_handle(), graph.extent(1), - num_random_samplings, - rand_xor_mask, + ps.num_random_samplings, + ps.rand_xor_mask, dev_seed_ptr, num_seeds, hashmap_ptr, - itopk_size, - search_width, - min_iterations, - max_iterations, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, num_executed_iterations, hash_bitlen, small_hash_bitlen,