From cb24d998771a72ea6bad12a65cfb4aaf6ab0122e Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Tue, 26 Sep 2023 04:09:43 +0800 Subject: [PATCH] [FEA] Add pre-filtering to CAGRA (#1811) This PR adds the pre-filtering feature to the CAGRA search implementations. Rel: taken over from https://github.com/rapidsai/raft/pull/1765 ## Algorithm The pre-filtering algorithm removes a node that should not be in the final result after it has behaved as a parent node. This way, the nodes that should not be in the final result are also used in the graph traversal, avoiding potential performance degradation. ## Changes - Add filtering operation on a parent node after internal top-M buffer candidate calculation. - Add filtering operation to result buffer before storing them in the device memory. Authors: - tsuki (https://github.com/enp1s0) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1811 --- cpp/include/raft/neighbors/cagra.cuh | 75 +++++- .../neighbors/detail/cagra/cagra_search.cuh | 60 ++++- .../detail/cagra/compute_distance.hpp | 13 +- .../raft/neighbors/detail/cagra/factory.cuh | 42 ++-- .../detail/cagra/search_multi_cta.cuh | 80 +++---- .../cagra/search_multi_cta_kernel-ext.cuh | 94 ++++---- .../cagra/search_multi_cta_kernel-inl.cuh | 89 ++++++-- .../detail/cagra/search_multi_kernel.cuh | 215 +++++++++++++----- .../neighbors/detail/cagra/search_plan.cuh | 9 +- .../detail/cagra/search_single_cta.cuh | 70 +++--- .../cagra/search_single_cta_kernel-ext.cuh | 96 ++++---- .../cagra/search_single_cta_kernel-inl.cuh | 214 ++++++++++++++--- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 12 +- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 12 +- .../raft/neighbors/sample_filter_types.hpp | 12 + .../cagra/search_multi_cta_00_generate.py | 58 ++--- ...arch_multi_cta_float_uint32_dim1024_t32.cu | 55 +++-- ...search_multi_cta_float_uint32_dim128_t8.cu | 55 +++-- ...earch_multi_cta_float_uint32_dim256_t16.cu | 55 +++-- ...earch_multi_cta_float_uint32_dim512_t32.cu | 55 +++-- ...arch_multi_cta_float_uint64_dim1024_t32.cu | 55 +++-- ...search_multi_cta_float_uint64_dim128_t8.cu | 55 +++-- ...earch_multi_cta_float_uint64_dim256_t16.cu | 55 +++-- ...earch_multi_cta_float_uint64_dim512_t32.cu | 55 +++-- ...earch_multi_cta_int8_uint32_dim1024_t32.cu | 55 +++-- .../search_multi_cta_int8_uint32_dim128_t8.cu | 55 +++-- ...search_multi_cta_int8_uint32_dim256_t16.cu | 55 +++-- ...search_multi_cta_int8_uint32_dim512_t32.cu | 55 +++-- ...arch_multi_cta_uint8_uint32_dim1024_t32.cu | 55 +++-- ...search_multi_cta_uint8_uint32_dim128_t8.cu | 55 +++-- ...earch_multi_cta_uint8_uint32_dim256_t16.cu | 55 +++-- ...earch_multi_cta_uint8_uint32_dim512_t32.cu | 55 +++-- .../cagra/search_single_cta_00_generate.py | 59 ++--- ...rch_single_cta_float_uint32_dim1024_t32.cu | 58 ++--- ...earch_single_cta_float_uint32_dim128_t8.cu | 58 ++--- ...arch_single_cta_float_uint32_dim256_t16.cu | 58 ++--- ...arch_single_cta_float_uint32_dim512_t32.cu | 58 ++--- ...rch_single_cta_float_uint64_dim1024_t32.cu | 58 ++--- ...earch_single_cta_float_uint64_dim128_t8.cu | 58 ++--- ...arch_single_cta_float_uint64_dim256_t16.cu | 58 ++--- ...arch_single_cta_float_uint64_dim512_t32.cu | 58 ++--- ...arch_single_cta_int8_uint32_dim1024_t32.cu | 58 ++--- ...search_single_cta_int8_uint32_dim128_t8.cu | 58 ++--- ...earch_single_cta_int8_uint32_dim256_t16.cu | 58 ++--- ...earch_single_cta_int8_uint32_dim512_t32.cu | 58 ++--- ...rch_single_cta_uint8_uint32_dim1024_t32.cu | 58 ++--- ...earch_single_cta_uint8_uint32_dim128_t8.cu | 58 ++--- ...arch_single_cta_uint8_uint32_dim256_t16.cu | 58 ++--- ...arch_single_cta_uint8_uint32_dim512_t32.cu | 58 ++--- cpp/test/neighbors/ann_cagra.cuh | 177 +++++++++++++- .../ann_cagra/search_kernel_uint64_t.cuh | 200 ++++++++-------- .../neighbors/ann_cagra/test_float_int64_t.cu | 4 +- .../ann_cagra/test_float_uint32_t.cu | 8 +- .../ann_cagra/test_int8_t_uint32_t.cu | 7 +- .../ann_cagra/test_uint8_t_uint32_t.cu | 8 +- 55 files changed, 2142 insertions(+), 1280 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 903d0571dc..1bd7010c83 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -54,14 +54,14 @@ namespace raft::neighbors::cagra { * // use default index parameters * cagra::index_params build_params; * cagra::search_params search_params - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); - * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); * // Construct an index from dataset and optimized knn_graph * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); + * optimized_graph.view()); * @endcode * * @tparam DataT data element type @@ -106,7 +106,7 @@ void build_knn_graph(raft::resources const& res, * @code{.cpp} * using namespace raft::neighbors; * cagra::index_params build_params; - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // build KNN graph not using `cagra::build_knn_graph` * // build(knn_graph, dataset, ...); * // sort graph index @@ -115,7 +115,7 @@ void build_knn_graph(raft::resources const& res, * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); * // Construct an index from dataset and optimized knn_graph * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); + * optimized_graph.view()); * @endcode * * @tparam DataT type of the data in the source dataset @@ -316,9 +316,70 @@ void search(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal); + cagra::detail::search_main(res, + params, + idx, + queries_internal, + neighbors_internal, + distances_internal, + raft::neighbors::filtering::none_cagra_sample_filter()); } + +/** + * @brief Search ANN using the constructed index with the given sample filter. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * @tparam CagraSampleFilterT Device filter function, with the signature + * `(uint32_t query ix, uint32_t sample_ix) -> bool` + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * @param[in] sample_filter a device filter function that greenlights samples for a given query + */ +template +void search_with_filtering(raft::resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT()) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + using internal_IdxT = typename std::make_unsigned::type; + auto queries_internal = raft::make_device_matrix_view( + queries.data_handle(), queries.extent(0), queries.extent(1)); + auto neighbors_internal = raft::make_device_matrix_view( + reinterpret_cast(neighbors.data_handle()), + neighbors.extent(0), + neighbors.extent(1)); + auto distances_internal = raft::make_device_matrix_view( + distances.data_handle(), distances.extent(0), distances.extent(1)); + + cagra::detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); +} + /** @} */ // end group cagra } // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index b484fa55f9..81e714dc4e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -34,6 +35,48 @@ namespace raft::neighbors::cagra::detail { +template +struct CagraSampleFilterWithQueryIdOffset { + const uint32_t offset; + CagraSampleFilterT filter; + + CagraSampleFilterWithQueryIdOffset(const uint32_t offset, const CagraSampleFilterT filter) + : offset(offset), filter(filter) + { + } + + _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id) + { + return filter(query_id + offset, sample_id); + } +}; + +template +struct CagraSampleFilterT_Selector { + using type = CagraSampleFilterWithQueryIdOffset; +}; +template <> +struct CagraSampleFilterT_Selector { + using type = raft::neighbors::filtering::none_cagra_sample_filter; +}; + +// A helper function to set a query id offset +template +inline typename CagraSampleFilterT_Selector::type set_offset( + CagraSampleFilterT filter, const uint32_t offset) +{ + typename CagraSampleFilterT_Selector::type new_filter(offset, filter); + return new_filter; +} +template <> +inline + typename CagraSampleFilterT_Selector::type + set_offset( + raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t) +{ + return filter; +} + /** * @brief Search ANN using the constructed index. * @@ -54,13 +97,18 @@ namespace raft::neighbors::cagra::detail { * k] */ -template +template void search_main(raft::resources const& res, search_params params, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT()) { resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search"); RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", @@ -77,8 +125,9 @@ void search_main(raft::resources const& res, common::nvtx::range fun_scope( "cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim()); - std::unique_ptr> plan = - factory::create( + using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; + std::unique_ptr> plan = + factory::create( res, params, index.dim(), index.graph_degree(), topk); plan->check(neighbors.extent(1)); @@ -119,7 +168,8 @@ void search_main(raft::resources const& res, n_queries, _seed_ptr, _num_executed_iterations, - topk); + topk, + set_offset(sample_filter, qid)); } static_assert(std::is_same_v, diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 47e976e252..624c1a35d6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -155,17 +155,20 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in INDEX_T* const visited_hashmap_ptr, const std::uint32_t hash_bitlen, const INDEX_T* const parent_indices, + const INDEX_T* const internal_topk_list, const std::uint32_t search_width) { - const INDEX_T invalid_index = utils::get_max_value(); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); // Read child indices of parents from knn graph and check if the distance // computaiton is necessary. for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += BLOCK_SIZE) { - const INDEX_T parent_id = parent_indices[i / knn_k]; - INDEX_T child_id = invalid_index; - if (parent_id != invalid_index) { - child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)]; + const INDEX_T smem_parent_id = parent_indices[i / knn_k]; + INDEX_T child_id = invalid_index; + if (smem_parent_id != invalid_index) { + const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; + child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)]; } if (child_id != invalid_index) { if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh index 625040194b..78111a9310 100644 --- a/cpp/include/raft/neighbors/detail/cagra/factory.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/factory.cuh @@ -20,20 +20,25 @@ #include "search_multi_kernel.cuh" #include "search_plan.cuh" #include "search_single_cta.cuh" +#include namespace raft::neighbors::cagra::detail { -template +template class factory { public: /** * Create a search structure for dataset with dim features. */ - static std::unique_ptr> create(raft::resources const& res, - search_params const& params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) + static std::unique_ptr> create( + raft::resources const& res, + search_params const& params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) { search_plan_impl_base plan(params, dim, graph_degree, topk); switch (plan.max_dim) { @@ -63,26 +68,29 @@ class factory { break; default: RAFT_LOG_DEBUG("Incorrect max_dim (%lu)\n", plan.max_dim); } - return std::unique_ptr>(); + return std::unique_ptr>(); } private: template - static std::unique_ptr> dispatch_kernel( + static std::unique_ptr> dispatch_kernel( raft::resources const& res, search_plan_impl_base& plan) { if (plan.algo == search_algo::SINGLE_CTA) { - return std::unique_ptr>( - new single_cta_search::search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + return std::unique_ptr>( + new single_cta_search:: + search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); } else if (plan.algo == search_algo::MULTI_CTA) { - return std::unique_ptr>( - new multi_cta_search::search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + return std::unique_ptr>( + new multi_cta_search:: + search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); } else { - return std::unique_ptr>( - new multi_kernel_search::search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + return std::unique_ptr>( + new multi_kernel_search:: + search( + res, plan, plan.dim, plan.graph_degree, plan.topk)); } } }; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 6ea1e34032..9a722a6dfe 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -48,42 +48,43 @@ template - -struct search : public search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; + typename DISTANCE_T, + typename SAMPLE_FILTER_T> + +struct search : public search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; uint32_t num_cta_per_query; rmm::device_uvector intermediate_indices; @@ -96,7 +97,8 @@ struct search : public search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk), + : search_plan_impl( + res, params, dim, graph_degree, topk), intermediate_indices(0, resource::get_cuda_stream(res)), intermediate_distances(0, resource::get_cuda_stream(res)), topk_workspace(0, resource::get_cuda_stream(res)) @@ -196,7 +198,8 @@ struct search : public search_plan_impl { const uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk) + uint32_t topk, + SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); @@ -223,6 +226,7 @@ struct search : public search_plan_impl { search_width, min_iterations, max_iterations, + sample_filter, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index de83acbb64..ee525587d7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include // RAFT_EXPLICIT +#include // none_cagra_sample_filter +#include // RAFT_EXPLICIT namespace raft::neighbors::cagra::detail { namespace multi_cta_search { @@ -26,7 +27,8 @@ template + class DISTANCE_T, + class SAMPLE_FILTER_T> void select_and_run(raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const topk_indices_ptr, @@ -49,47 +51,63 @@ void select_and_run(raft::device_matrix_view( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, float, uint32_t, float); -instantiate_kernel_selection(8, 128, float, uint32_t, float); -instantiate_kernel_selection(16, 256, float, uint32_t, float); -instantiate_kernel_selection(32, 512, float, uint32_t, float); -instantiate_kernel_selection(32, 1024, int8_t, uint32_t, float); -instantiate_kernel_selection(8, 128, int8_t, uint32_t, float); -instantiate_kernel_selection(16, 256, int8_t, uint32_t, float); -instantiate_kernel_selection(32, 512, int8_t, uint32_t, float); -instantiate_kernel_selection(32, 1024, uint8_t, uint32_t, float); -instantiate_kernel_selection(8, 128, uint8_t, uint32_t, float); -instantiate_kernel_selection(16, 256, uint8_t, uint32_t, float); -instantiate_kernel_selection(32, 512, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection } // namespace multi_cta_search diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 4fc051ac09..8bfbc48898 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include @@ -75,7 +76,7 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [sea if (new_parent) { const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; if (i < search_width) { - next_parent_indices[i] = index; + next_parent_indices[i] = j; itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node } } @@ -131,7 +132,8 @@ template + class LOAD_T, + class SAMPLE_FILTER_T> __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] @@ -152,8 +154,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const uint32_t search_width, const uint32_t min_iteration, const uint32_t max_iteration, - uint32_t* const num_executed_iterations /* stats */ -) + uint32_t* const num_executed_iterations, /* stats */ + SAMPLE_FILTER_T sample_filter) { assert(blockDim.x == BLOCK_SIZE); assert(dataset_dim <= MAX_DATASET_DIM); @@ -287,13 +289,57 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( local_visited_hashmap_ptr, hash_bitlen, parent_indices_buffer, + result_indices_buffer, search_width); _CLK_REC(clk_compute_distance); __syncthreads(); + // Filtering + if constexpr (!std::is_same::value) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + if (parent_indices_buffer[p] != invalid_index) { + const auto parent_id = + result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; + if (!sample_filter(query_id, parent_id)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_indices_buffer[p]] = invalid_index; + } + } + } + __syncthreads(); + } + iter++; } + // Post process for filtering + if constexpr (!std::is_same::value) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned i = threadIdx.x; i < itopk_size + search_width * graph_degree; i += blockDim.x) { + const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; + if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[i] = utils::get_max_value(); + result_indices_buffer[i] = invalid_index; + } + } + + __syncthreads(); + topk_by_bitonic_sort(result_distances_buffer, + result_indices_buffer, + itopk_size + (search_width * graph_degree), + itopk_size); + __syncthreads(); + } + for (uint32_t i = threadIdx.x; i < itopk_size; i += BLOCK_SIZE) { uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } @@ -361,7 +407,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> struct search_kernel_config { // Search kernel function type. Note that the actual values for the template value // parameters do not matter, because they are not part of the function signature. The @@ -374,7 +421,8 @@ struct search_kernel_config { DATA_T, DISTANCE_T, INDEX_T, - device::LOAD_128BIT_T>); + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>); static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t { @@ -401,7 +449,8 @@ struct search_kernel_config { DATA_T, DISTANCE_T, INDEX_T, - device::LOAD_128BIT_T>; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else if (block_size == 128) { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else if (block_size == 256) { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else if (block_size == 512) { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } else { return search_kernel; + device::LOAD_128BIT_T, + SAMPLE_FILTER_T>; } } }; @@ -450,7 +503,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> void select_and_run( // raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, @@ -475,10 +529,12 @@ void select_and_run( // raft::resources const& res, size_t search_width, size_t min_iterations, size_t max_iterations, + SAMPLE_FILTER_T sample_filter, cudaStream_t stream) { - auto kernel = search_kernel_config:: - choose_buffer_size(result_buffer_size, block_size); + auto kernel = + search_kernel_config:: + choose_buffer_size(result_buffer_size, block_size); RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -489,7 +545,7 @@ void select_and_run( // raft::resources const& res, dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); - RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %lu smem", + RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %u smem", block_size, num_cta_per_query, num_queries, @@ -513,7 +569,8 @@ void select_and_run( // raft::resources const& res, search_width, min_iterations, max_iterations, - num_executed_iterations); + num_executed_iterations, + sample_filter); } } // namespace multi_cta_search diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index f312226f42..ff1bb969e7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -242,7 +243,7 @@ __global__ void pickup_next_parents_kernel( if (new_parent) { const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; if (i < parent_list_size) { - parent_list_ptr[i + (ldd * query_id)] = index; + parent_list_ptr[i + (ldd * query_id)] = j; parent_candidates_ptr[j + (lds * query_id)] |= index_msb_1_mask; // set most significant bit as used node } @@ -306,9 +307,13 @@ template + class DISTANCE_T, + class SAMPLE_FILTER_T> __global__ void compute_distance_to_child_nodes_kernel( const INDEX_T* const parent_node_list, // [num_queries, search_width] + INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] + DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] + const std::size_t lds, const std::uint32_t search_width, const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, @@ -321,16 +326,25 @@ __global__ void compute_distance_to_child_nodes_kernel( const std::uint32_t hash_bitlen, INDEX_T* const result_indices_ptr, // [num_queries, ldd] DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd // (*) ldd >= search_width * graph_degree -) + const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + SAMPLE_FILTER_T sample_filter) { const uint32_t ldb = hashmap::get_size(hash_bitlen); const auto tid = threadIdx.x + blockDim.x * blockIdx.x; const auto global_team_id = tid / TEAM_SIZE; + const auto query_id = blockIdx.y; + if (global_team_id >= search_width * graph_degree) { return; } - const std::size_t parent_index = + const std::size_t parent_list_index = parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; + + if (parent_list_index == utils::get_max_value()) { return; } + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const auto parent_index = + parent_candidates_ptr[parent_list_index + (lds * query_id)] & ~index_msb_1_mask; + if (parent_index == utils::get_max_value()) { result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); return; @@ -361,15 +375,28 @@ __global__ void compute_distance_to_child_nodes_kernel( result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); } } + + if constexpr (!std::is_same::value) { + if (!sample_filter(query_id, parent_index)) { + parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); + parent_distance_ptr[parent_list_index + (lds * query_id)] = + utils::get_max_value(); + } + } } template + class DISTANCE_T, + class SAMPLE_FILTER_T> void compute_distance_to_child_nodes( const INDEX_T* const parent_node_list, // [num_queries, search_width] + INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] + DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] + const std::size_t lds, const uint32_t search_width, const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, @@ -384,6 +411,7 @@ void compute_distance_to_child_nodes( INDEX_T* const result_indices_ptr, // [num_queries, ldd] DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + SAMPLE_FILTER_T sample_filter, cudaStream_t cuda_stream = 0) { const auto block_size = 128; @@ -392,6 +420,9 @@ void compute_distance_to_child_nodes( num_queries); compute_distance_to_child_nodes_kernel <<>>(parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, search_width, dataset_ptr, data_dim, @@ -404,7 +435,8 @@ void compute_distance_to_child_nodes( hash_bitlen, result_indices_ptr, result_distances_ptr, - ldd); + ldd, + sample_filter); } template @@ -436,6 +468,50 @@ void remove_parent_bit(const std::uint32_t num_queries, num_queries, num_topk, topk_indices_ptr, ld); } +// This function called after the `remove_parent_bit` function +template +__global__ void apply_filter_kernel(INDEX_T* const result_indices_ptr, + DISTANCE_T* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const INDEX_T query_id_offset, + SAMPLE_FILTER_T sample_filter) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= result_buffer_size * num_queries) { return; } + const auto i = tid % result_buffer_size; + const auto j = tid / result_buffer_size; + const auto index = i + j * lds; + + if (!sample_filter(query_id_offset + j, result_indices_ptr[index])) { + result_indices_ptr[index] = utils::get_max_value(); + result_distances_ptr[index] = utils::get_max_value(); + } +} + +template +void apply_filter(INDEX_T* const result_indices_ptr, + DISTANCE_T* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const INDEX_T query_id_offset, + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream) +{ + const std::uint32_t block_size = 256; + const std::uint32_t grid_size = ceildiv(num_queries * result_buffer_size, block_size); + + apply_filter_kernel<<>>(result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + sample_filter); +} + template __global__ void batched_memcpy_kernel(T* const dst, // [batch_size, ld_dst] const uint64_t ld_dst, @@ -508,41 +584,42 @@ template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; + typename DISTANCE_T, + typename SAMPLE_FILTER_T> +struct search : search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; size_t result_buffer_allocation_size; rmm::device_uvector result_indices; // results_indices_buffer @@ -557,7 +634,8 @@ struct search : search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk), + : search_plan_impl( + res, params, dim, graph_degree, topk), result_indices(0, resource::get_cuda_stream(res)), result_distances(0, resource::get_cuda_stream(res)), parent_node_list(0, resource::get_cuda_stream(res)), @@ -602,7 +680,8 @@ struct search : search_plan_impl { const uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk) + uint32_t topk, + SAMPLE_FILTER_T sample_filter) { // Init hashmap cudaStream_t stream = resource::get_cuda_stream(res); @@ -684,6 +763,9 @@ struct search : search_plan_impl { // Compute distance to child nodes that are adjacent to the parent node compute_distance_to_child_nodes( parent_node_list.data(), + result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, + result_buffer_allocation_size, search_width, dataset.data_handle(), dataset.extent(1), @@ -698,22 +780,53 @@ struct search : search_plan_impl { result_indices.data() + itopk_size, result_distances.data() + itopk_size, result_buffer_allocation_size, + sample_filter, stream); iter++; } // while ( 1 ) + auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size; + auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size; // Remove parent bit in search results - remove_parent_bit(num_queries, - itopk_size, - result_indices.data() + (iter & 0x1) * result_buffer_size, - result_buffer_allocation_size, - stream); + remove_parent_bit( + num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream); + + if (!std::is_same::value) { + apply_filter( + result_indices.data() + (iter & 0x1) * itopk_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_buffer_size, + num_queries, + 0, + sample_filter, + stream); + + result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size; + result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size; + _cuann_find_topk(itopk_size, + num_queries, + result_buffer_size, + result_distances.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_indices.data() + (iter & 0x1) * itopk_size, + result_buffer_allocation_size, + result_distances_ptr, + result_buffer_allocation_size, + result_indices_ptr, + result_buffer_allocation_size, + topk_workspace.data(), + true, + topk_hint.data(), + stream); + } // Copy results from working buffer to final buffer batched_memcpy(topk_indices_ptr, topk, - result_indices.data() + (iter & 0x1) * result_buffer_size, + result_indices_ptr, result_buffer_allocation_size, topk, num_queries, @@ -721,7 +834,7 @@ struct search : search_plan_impl { if (topk_distances_ptr) { batched_memcpy(topk_distances_ptr, topk, - result_distances.data() + (iter & 0x1) * result_buffer_size, + result_distances_ptr, result_buffer_allocation_size, topk, num_queries, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 33c77db61e..9419385836 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -65,7 +65,7 @@ struct search_plan_impl_base : public search_params { } }; -template +template struct search_plan_impl : public search_plan_impl_base { int64_t hash_bitlen; @@ -113,7 +113,8 @@ struct search_plan_impl : public search_plan_impl_base { const std::uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk){}; + uint32_t topk, + SAMPLE_FILTER_T sample_filter){}; void adjust_search_params() { @@ -129,13 +130,13 @@ struct search_plan_impl : public search_plan_impl_base { if (max_iterations < min_iterations) { _max_iterations = min_iterations; } if (max_iterations < _max_iterations) { RAFT_LOG_DEBUG( - "# max_iterations is increased from %u to %u.", max_iterations, _max_iterations); + "# max_iterations is increased from %lu to %u.", max_iterations, _max_iterations); max_iterations = _max_iterations; } if (itopk_size % 32) { uint32_t itopk32 = itopk_size; itopk32 += 32 - (itopk_size % 32); - RAFT_LOG_DEBUG("# internal_topk is increased from %u to %u, as it must be multiple of 32.", + RAFT_LOG_DEBUG("# internal_topk is increased from %lu to %u, as it must be multiple of 32.", itopk_size, itopk32); itopk_size = itopk32; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 45dd535e1d..27d54f72cb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -49,41 +49,42 @@ template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; + typename DISTANCE_T, + typename SAMPLE_FILTER_T> +struct search : search_plan_impl { + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; - using search_plan_impl::max_dim; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; + using search_plan_impl::max_dim; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; - using search_plan_impl::hash_bitlen; + using search_plan_impl::hash_bitlen; - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; - using search_plan_impl::smem_size; + using search_plan_impl::smem_size; - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; uint32_t num_itopk_candidates; @@ -92,7 +93,8 @@ struct search : search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk) + : search_plan_impl( + res, params, dim, graph_degree, topk) { set_params(res); } @@ -111,7 +113,7 @@ struct search : search_plan_impl { RAFT_EXPECTS(itopk_size <= max_itopk, "itopk_size cannot be larger than %u", max_itopk); RAFT_LOG_DEBUG("# num_itopk_candidates: %u", num_itopk_candidates); - RAFT_LOG_DEBUG("# num_itopk: %u", itopk_size); + RAFT_LOG_DEBUG("# num_itopk: %lu", itopk_size); // // Determine the thread block size // @@ -234,7 +236,8 @@ struct search : search_plan_impl { const std::uint32_t num_queries, const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk) + uint32_t topk, + SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); select_and_run( @@ -261,6 +264,7 @@ struct search : search_plan_impl { search_width, min_iterations, max_iterations, + sample_filter, stream); } }; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index 5f5df1a818..35d239563a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -15,7 +15,9 @@ */ #pragma once +#include #include // RAFT_EXPLICIT + namespace raft::neighbors::cagra::detail { namespace single_cta_search { @@ -25,7 +27,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> void select_and_run( // raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, @@ -50,50 +53,65 @@ void select_and_run( // raft::resources const& res, size_t search_width, size_t min_iterations, size_t max_iterations, + SAMPLE_FILTER_T sample_filter, cudaStream_t stream) RAFT_EXPLICIT; #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - extern template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, float, uint32_t, float); -instantiate_single_cta_select_and_run(8, 128, float, uint32_t, float); -instantiate_single_cta_select_and_run(16, 256, float, uint32_t, float); -instantiate_single_cta_select_and_run(32, 512, float, uint32_t, float); -instantiate_single_cta_select_and_run(32, 1024, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(8, 128, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(16, 256, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(32, 512, int8_t, uint32_t, float); -instantiate_single_cta_select_and_run(32, 1024, uint8_t, uint32_t, float); -instantiate_single_cta_select_and_run(8, 128, uint8_t, uint32_t, float); -instantiate_single_cta_select_and_run(16, 256, uint8_t, uint32_t, float); -instantiate_single_cta_select_and_run(32, 512, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_select_and_run diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 81325fd5da..128dc8d116 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -78,7 +79,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag, if (new_parent) { const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; if (i < search_width) { - next_parent_indices[i] = index; + next_parent_indices[i] = jj; // set most significant bit as used node internal_topk_indices[jj] |= index_msb_1_mask; } @@ -458,7 +459,8 @@ template + class INDEX_T, + class SAMPLE_FILTER_T> __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] @@ -482,7 +484,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ std::uint32_t* const num_executed_iterations, // [num_queries] const std::uint32_t hash_bitlen, const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval) + const std::uint32_t small_hash_reset_interval, + SAMPLE_FILTER_T sample_filter) { using LOAD_T = device::LOAD_128BIT_T; const auto query_id = blockIdx.y; @@ -527,6 +530,9 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ auto terminate_flag = reinterpret_cast(topk_ws + 3); auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); + // A flag for filtering. + auto filter_flag = terminate_flag; + const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { unsigned j = device::swizzling(i); @@ -576,7 +582,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ std::uint32_t iter = 0; while (1) { // sort - if (TOPK_BY_BITONIC_SORT) { + if constexpr (TOPK_BY_BITONIC_SORT) { // [Notice] // It is good to use multiple warps in topk_by_bitonic_sort() when // batch size is small (short-latency), but it might not be always good @@ -614,17 +620,28 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ // topk with bitonic sort _CLK_START(); - topk_by_bitonic_sort( - result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - search_width * graph_degree, - topk_ws, - (iter == 0)); + if (std::is_same::value || + *filter_flag == 0) { + topk_by_bitonic_sort( + result_distances_buffer, + result_indices_buffer, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + search_width * graph_degree, + topk_ws, + (iter == 0)); + __syncthreads(); + } else { + topk_by_bitonic_sort_1st( + result_distances_buffer, + result_indices_buffer, + internal_topk + search_width * graph_degree, + internal_topk); + if (threadIdx.x == 0) { *terminate_flag = 0; } + } _CLK_REC(clk_topk); - } else { _CLK_START(); // topk with radix block sort @@ -693,12 +710,61 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ local_visited_hashmap_ptr, hash_bitlen, parent_list_buffer, + result_indices_buffer, search_width); __syncthreads(); _CLK_REC(clk_compute_distance); + // Filtering + if constexpr (!std::is_same::value) { + if (threadIdx.x == 0) { *filter_flag = 0; } + __syncthreads(); + + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + if (parent_list_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; + if (!sample_filter(query_id, parent_id)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_list_buffer[p]] = invalid_index; + *filter_flag = 1; + } + } + } + __syncthreads(); + } + iter++; } + + // Post process for filtering + if constexpr (!std::is_same::value) { + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const INDEX_T invalid_index = utils::get_max_value(); + + for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; + i += blockDim.x) { + const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; + if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { + result_distances_buffer[i] = utils::get_max_value(); + result_indices_buffer[i] = invalid_index; + } + } + + __syncthreads(); + topk_by_bitonic_sort_1st( + result_distances_buffer, + result_indices_buffer, + internal_topk + search_width * graph_degree, + top_k); + __syncthreads(); + } + for (std::uint32_t i = threadIdx.x; i < top_k; i += BLOCK_SIZE) { unsigned j = i + (top_k * query_id); unsigned ii = i; @@ -737,9 +803,15 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ #endif } -template +template struct search_kernel_config { - using kernel_t = decltype(&search_kernel); + using kernel_t = + decltype(&search_kernel); template static auto choose_block_size(unsigned block_size) -> kernel_t @@ -747,24 +819,104 @@ struct search_kernel_config { constexpr unsigned BS = USE_BITONIC_SORT; if constexpr (BS) { if (block_size == 64) { - return search_kernel; + return search_kernel; } else if (block_size == 128) { - return search_kernel; + return search_kernel; } else if (block_size == 256) { - return search_kernel; + return search_kernel; } else if (block_size == 512) { - return search_kernel; + return search_kernel; } else { - return search_kernel; + return search_kernel; } } else { if (block_size == 256) { - return search_kernel; + return search_kernel; } else if (block_size == 512) { - return search_kernel; + return search_kernel; } else { - return search_kernel; + return search_kernel; } } } @@ -826,7 +978,8 @@ template + typename DISTANCE_T, + typename SAMPLE_FILTER_T> void select_and_run( // raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, @@ -851,16 +1004,18 @@ void select_and_run( // raft::resources const& res, size_t search_width, size_t min_iterations, size_t max_iterations, + SAMPLE_FILTER_T sample_filter, cudaStream_t stream) { - auto kernel = search_kernel_config:: - choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); + auto kernel = + search_kernel_config:: + choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); dim3 thread_dims(block_size, 1, 1); dim3 block_dims(1, num_queries, 1); RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %lu smem", block_size, num_queries, smem_size); + "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); kernel<<>>(topk_indices_ptr, topk_distances_ptr, topk, @@ -883,7 +1038,8 @@ void select_and_run( // raft::resources const& res, num_executed_iterations, hash_bitlen, small_hash_bitlen, - small_hash_reset_interval); + small_hash_reset_interval, + sample_filter); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // namespace single_cta_search diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index a18ee065bf..6641346a67 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -342,7 +342,7 @@ void extend(raft::resources const& handle, /** @} */ /** - * @brief Search ANN using the constructed index. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * @@ -374,6 +374,8 @@ void extend(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -386,7 +388,7 @@ void extend(raft::resources const& handle, * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] * @param[in] mr an optional memory resource to use across the searches (you can provide a large * enough memory pool here to avoid memory allocations within search). - * @param[in] sample_filter a filter the greenlights samples for a given query + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template void search_with_filtering(raft::resources const& handle, @@ -475,7 +477,7 @@ void search(raft::resources const& handle, */ /** - * @brief Search ANN using the constructed index using the given filter. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * @@ -501,6 +503,8 @@ void search(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -509,7 +513,7 @@ void search(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a filter the greenlights samples for a given query + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template void search_with_filtering(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index ccf8717486..9f203d92fb 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -134,7 +134,7 @@ void extend(raft::resources const& handle, } /** - * @brief Search ANN using the constructed index using the given filter. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. * @@ -148,6 +148,8 @@ void extend(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -157,7 +159,7 @@ void extend(raft::resources const& handle, * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] - * @param[in] sample_filter a filter the greenlights samples for a given query. + * @param[in] sample_filter a device filter function that greenlights samples for a given query. */ template void search_with_filtering(raft::resources const& handle, @@ -343,7 +345,7 @@ void extend(raft::resources const& handle, } /** - * @brief Search ANN using the constructed index using the given filter. + * @brief Search ANN using the constructed index with the given filter. * * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. * @@ -372,6 +374,8 @@ void extend(raft::resources const& handle, * * @tparam T data element type * @tparam IdxT type of the indices + * @tparam IvfSampleFilterT Device filter function, with the signature + * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` * * @param[in] handle * @param[in] params configure the search @@ -382,7 +386,7 @@ void extend(raft::resources const& handle, * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a filter the greenlights samples for a given query + * @param[in] sample_filter a device filter function that greenlights samples for a given query */ template void search_with_filtering(raft::resources const& handle, diff --git a/cpp/include/raft/neighbors/sample_filter_types.hpp b/cpp/include/raft/neighbors/sample_filter_types.hpp index 5a301e9d2f..10c5e99372 100644 --- a/cpp/include/raft/neighbors/sample_filter_types.hpp +++ b/cpp/include/raft/neighbors/sample_filter_types.hpp @@ -37,6 +37,18 @@ struct none_ivf_sample_filter { } }; +/* A filter that filters nothing. This is the default behavior. */ +struct none_cagra_sample_filter { + inline _RAFT_HOST_DEVICE bool operator()( + // query index + const uint32_t query_ix, + // the index of the current sample + const uint32_t sample_ix) const + { + return true; + } +}; + /** * If the filtering depends on the index of a sample, then the following * filter template can be used: diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py index 784d116503..15eb0a9e65 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -39,41 +39,45 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \\ - template void select_and_run( \\ - raft::device_matrix_view dataset, \\ - raft::device_matrix_view graph, \\ - INDEX_T* const topk_indices_ptr, \\ - DISTANCE_T* const topk_distances_ptr, \\ - const DATA_T* const queries_ptr, \\ - const uint32_t num_queries, \\ - const INDEX_T* dev_seed_ptr, \\ - uint32_t* const num_executed_iterations, \\ - uint32_t topk, \\ - uint32_t block_size, \\ - uint32_t result_buffer_size, \\ - uint32_t smem_size, \\ - int64_t hash_bitlen, \\ - INDEX_T* hashmap_ptr, \\ - uint32_t num_cta_per_query, \\ - uint32_t num_random_samplings, \\ - uint64_t rand_xor_mask, \\ - uint32_t num_seeds, \\ - size_t itopk_size, \\ - size_t search_width, \\ - size_t min_iterations, \\ - size_t max_iterations, \\ - cudaStream_t stream); +#define instantiate_kernel_selection( \\ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \\ + template void \\ + select_and_run( \\ + raft::device_matrix_view dataset, \\ + raft::device_matrix_view graph, \\ + INDEX_T* const topk_indices_ptr, \\ + DISTANCE_T* const topk_distances_ptr, \\ + const DATA_T* const queries_ptr, \\ + const uint32_t num_queries, \\ + const INDEX_T* dev_seed_ptr, \\ + uint32_t* const num_executed_iterations, \\ + uint32_t topk, \\ + uint32_t block_size, \\ + uint32_t result_buffer_size, \\ + uint32_t smem_size, \\ + int64_t hash_bitlen, \\ + INDEX_T* hashmap_ptr, \\ + uint32_t num_cta_per_query, \\ + uint32_t num_random_samplings, \\ + uint64_t rand_xor_mask, \\ + uint32_t num_seeds, \\ + size_t itopk_size, \\ + size_t search_width, \\ + size_t min_iterations, \\ + size_t max_iterations, \\ + SAMPLE_FILTER_T sample_filter, \\ + cudaStream_t stream); """ trailer = """ #undef instantiate_kernel_selection -} // namespace raft::neighbors::cagra::detail::namespace multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search """ mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] @@ -97,7 +101,7 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_kernel_selection({team}, {mxdim}, {data_t}, {idx_t}, {distance_t});\n" + f"instantiate_kernel_selection(\n {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}, raft::neighbors::filtering::none_cagra_sample_filter);\n" ) f.write(trailer) # For pasting into CMakeLists.txt diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu index 2a4e7ac607..1a3b2284bd 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_widthhhhhhhhh, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, float, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu index 115ce3b48b..36e86d9ed6 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, float, uint32_t, float); +instantiate_kernel_selection( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu index c5e704a85f..6f1af2d93f 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, float, uint32_t, float); +instantiate_kernel_selection( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu index 3469facf39..1279f8e415 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, float, uint32_t, float); +instantiate_kernel_selection( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu index 327bfc73b4..0dabff0df5 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, float, uint64_t, float); +instantiate_kernel_selection( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu index 1abe0cd8af..72bb74cdb8 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, float, uint64_t, float); +instantiate_kernel_selection( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu index dd61810d06..dceea10b5d 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, float, uint64_t, float); +instantiate_kernel_selection( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu index 8e12bab514..acb8bd6a12 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, float, uint64_t, float); +instantiate_kernel_selection( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu index d946ac9c79..0254f09ff0 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, int8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu index e4d7b44d1e..2b67e7e968 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, int8_t, uint32_t, float); +instantiate_kernel_selection( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu index b8dc3b38a8..17d6722e58 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, int8_t, uint32_t, float); +instantiate_kernel_selection( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu index 749b35bad6..38f02812e2 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, int8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu index 428d460ba8..fa111196c6 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_widthh, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 1024, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu index 28a20b865e..1ef3c28aa3 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(8, 128, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu index e85a84ae8e..d26cb44843 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(16, 256, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu index 232b62ebcd..4d4322f261 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu @@ -25,36 +25,41 @@ */ #include +#include namespace raft::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection(32, 512, uint8_t, uint32_t, float); +instantiate_kernel_selection( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index cf61a45b4a..249555082e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -39,35 +39,38 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \\ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \\ - template void select_and_run( \\ - raft::device_matrix_view dataset, \\ - raft::device_matrix_view graph, \\ - INDEX_T* const topk_indices_ptr, \\ - DISTANCE_T* const topk_distances_ptr, \\ - const DATA_T* const queries_ptr, \\ - const uint32_t num_queries, \\ - const INDEX_T* dev_seed_ptr, \\ - uint32_t* const num_executed_iterations, \\ - uint32_t topk, \\ - uint32_t num_itopk_candidates, \\ - uint32_t block_size, \\ - uint32_t smem_size, \\ - int64_t hash_bitlen, \\ - INDEX_T* hashmap_ptr, \\ - size_t small_hash_bitlen, \\ - size_t small_hash_reset_interval, \\ - uint32_t num_random_samplings, \\ - uint64_t rand_xor_mask, \\ - uint32_t num_seeds, \\ - size_t itopk_size, \\ - size_t search_width, \\ - size_t min_iterations, \\ - size_t max_iterations, \\ +#define instantiate_single_cta_select_and_run( \\ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \\ + template void \\ + select_and_run( \\ + raft::device_matrix_view dataset, \\ + raft::device_matrix_view graph, \\ + INDEX_T* const topk_indices_ptr, \\ + DISTANCE_T* const topk_distances_ptr, \\ + const DATA_T* const queries_ptr, \\ + const uint32_t num_queries, \\ + const INDEX_T* dev_seed_ptr, \\ + uint32_t* const num_executed_iterations, \\ + uint32_t topk, \\ + uint32_t num_itopk_candidates, \\ + uint32_t block_size, \\ + uint32_t smem_size, \\ + int64_t hash_bitlen, \\ + INDEX_T* hashmap_ptr, \\ + size_t small_hash_bitlen, \\ + size_t small_hash_reset_interval, \\ + uint32_t num_random_samplings, \\ + uint64_t rand_xor_mask, \\ + uint32_t num_seeds, \\ + size_t itopk_size, \\ + size_t search_width, \\ + size_t min_iterations, \\ + size_t max_iterations, \\ + SAMPLE_FILTER_T sample_filter, \\ cudaStream_t stream); """ @@ -75,7 +78,7 @@ trailer = """ #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search """ mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] @@ -102,7 +105,7 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_single_cta_select_and_run({team}, {mxdim},{data_t}, {idx_t}, {distance_t});\n" + f"instantiate_single_cta_select_and_run(\n {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}, raft::neighbors::filtering::none_cagra_sample_filter);\n" ) f.write(trailer) diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu index eb45d4ff08..b8c23103ba 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu index 049715aa20..8ab1897119 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu index 6028c283db..9fd36b4cb9 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu index 2566e9cbd9..a9ee2c864b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, float, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu index 4cd96ad9c0..dadc574b65 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu index 822a2efb2f..30e043f47e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu index 80d1f76b9b..089e4c930f 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu index 06c3eaf10b..3e8ffb8bf8 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, float, uint64_t, float); +instantiate_single_cta_select_and_run( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu index b4c30ac943..279587738e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu index c8d0df3ac4..ef127d3f7d 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu index 19ecee91af..7fcfdcc28e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu index 52c4eb7d6b..a6c606d99b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, int8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu index 4675e17084..0b8be56614 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 1024, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu index e73e1071ee..4c193b9408 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(8, 128, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu index 01e26b5f29..bdf16d2f03 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(16, 256, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu index b0534b555f..93624df4aa 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu @@ -25,38 +25,42 @@ */ #include +#include namespace raft::neighbors::cagra::detail::single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run(32, 512, uint8_t, uint32_t, float); +instantiate_single_cta_select_and_run( + 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_single_cta_search_kernel diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index eadc88085f..90f271e3ee 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -15,6 +15,8 @@ */ #pragma once +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation + #include "../test_utils.cuh" #include "ann_utils.cuh" #include @@ -25,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -41,8 +44,22 @@ #include #include -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { namespace { + +/* A filter that excludes all indices below `offset`. */ +struct test_cagra_sample_filter { + static constexpr unsigned offset = 400; + inline _RAFT_HOST_DEVICE auto operator()( + // query index + const uint32_t query_ix, + // the index of the current sample inside the current inverted list + const uint32_t sample_ix) const + { + return sample_ix >= offset; + } +}; + // For sort_knn_graph test template void RandomSuffle(raft::host_matrix_view index) @@ -365,6 +382,162 @@ class AnnCagraSortTest : public ::testing::TestWithParam { rmm::device_uvector database; }; +template +class AnnCagraFilterTest : public ::testing::TestWithParam { + public: + AnnCagraFilterTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + protected: + void testCagraFilter() + { + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_Cagra(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_Cagra(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - test_cagra_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_cagra_sample_filter::offset), + queries_size, + stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + rmm::device_uvector distances_dev(queries_size, stream_); + rmm::device_uvector indices_dev(queries_size, stream_); + + { + cagra::index_params index_params; + index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is + // not used for knn_graph building. + cagra::search_params search_params; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + search_params.hashmap_mode = cagra::hash_mode::HASH; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + cagra::index index(handle_); + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + index = cagra::build(handle_, index_params, database_host_view); + } else { + index = cagra::build(handle_, index_params, database_view); + } + + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + + cagra::search_with_filtering(handle_, + search_params, + index, + search_queries_view, + indices_out_view, + dists_out_view, + test_cagra_sample_filter()); + update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); + update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + // Test filter + bool unacceptable_node = false; + for (int q = 0; q < ps.n_queries; q++) { + for (int i = 0; i < ps.k; i++) { + const auto n = indices_Cagra[q * ps.k + i]; + unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n); + } + } + EXPECT_FALSE(unacceptable_node); + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_Cagra, + distances_naive, + distances_Cagra, + ps.n_queries, + ps.k, + 0.001, + min_recall)); + EXPECT_TRUE(eval_distances(handle_, + database.data(), + search_queries.data(), + indices_dev.data(), + distances_dev.data(), + ps.n_rows, + ps.dim, + ps.n_queries, + ps.k, + ps.metric, + 1.0e-4)); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + search_queries.resize(ps.n_queries * ps.dim, stream_); + raft::random::Rng r(1234ULL); + if constexpr (std::is_same{}) { + r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); + r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_); + } else { + r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnCagraInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + inline std::vector generate_inputs() { // TODO(tfeher): test MULTI_CTA kernel with search_width > 1 to allow multiple CTA per queries @@ -467,4 +640,4 @@ inline std::vector generate_inputs() const std::vector inputs = generate_inputs(); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh index f61e476652..175e4ef483 100644 --- a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh +++ b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh @@ -1,93 +1,107 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // RAFT_EXPLICIT - -namespace raft::neighbors::cagra::detail { - -namespace multi_cta_search { -#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - extern template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - cudaStream_t stream); - -instantiate_kernel_selection(32, 1024, float, uint64_t, float); -instantiate_kernel_selection(8, 128, float, uint64_t, float); -instantiate_kernel_selection(16, 256, float, uint64_t, float); -instantiate_kernel_selection(32, 512, float, uint64_t, float); - -#undef instantiate_kernel_selection -} // namespace multi_cta_search - -namespace single_cta_search { - -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ - extern template void select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - cudaStream_t stream); - -instantiate_single_cta_select_and_run(32, 1024, float, uint64_t, float); -instantiate_single_cta_select_and_run(8, 128, float, uint64_t, float); -instantiate_single_cta_select_and_run(16, 256, float, uint64_t, float); -instantiate_single_cta_select_and_run(32, 512, float, uint64_t, float); - -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // none_cagra_sample_filter +#include // RAFT_EXPLICIT + +namespace raft::neighbors::cagra::detail { + +namespace multi_cta_search { +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +instantiate_kernel_selection( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); + +#undef instantiate_kernel_selection +} // namespace multi_cta_search + +namespace single_cta_search { + +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run( \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +instantiate_single_cta_select_and_run( + 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run( + 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace single_cta_search +} // namespace raft::neighbors::cagra::detail \ No newline at end of file diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu index fa3d76d066..6f9e8dbd43 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu @@ -19,11 +19,11 @@ #include "../ann_cagra.cuh" #include "search_kernel_uint64_t.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestF_I64; TEST_P(AnnCagraTestF_I64, AnnCagra) { this->testCagra(); } INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_I64, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index dbaf4dedd9..01d7e1e1ea 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -18,7 +18,7 @@ #include "../ann_cagra.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestF_U32; TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } @@ -26,7 +26,11 @@ TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestF_U32; TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraFilterTest AnnCagraFilterTestF_U32; +TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) { this->testCagraFilter(); } + INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu index ba60131677..ee06d369fa 100644 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -18,14 +18,17 @@ #include "../ann_cagra.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestI8_U32; TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestI8_U32; TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraFilterTest AnnCagraFilterTestI8_U32; +TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) { this->testCagraFilter(); } INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestI8_U32, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu index cc172e4833..3243e73ccd 100644 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -18,7 +18,7 @@ #include "../ann_cagra.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { typedef AnnCagraTest AnnCagraTestU8_U32; TEST_P(AnnCagraTestU8_U32, AnnCagra) { this->testCagra(); } @@ -26,7 +26,11 @@ TEST_P(AnnCagraTestU8_U32, AnnCagra) { this->testCagra(); } typedef AnnCagraSortTest AnnCagraSortTestU8_U32; TEST_P(AnnCagraSortTestU8_U32, AnnCagraSort) { this->testCagraSort(); } +typedef AnnCagraFilterTest AnnCagraFilterTestU8_U32; +TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) { this->testCagraFilter(); } + INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestU8_U32, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra