From 9aa1bb10566428077fb37737341eb6ae67ac45bc Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 19 Nov 2024 13:10:15 -0800 Subject: [PATCH] move the threshold-to-bf into search_params --- cpp/include/cuvs/neighbors/cagra.hpp | 24 ++++++----------- cpp/src/neighbors/cagra.cuh | 23 +++++----------- cpp/src/neighbors/cagra_search_float.cu | 23 ++++++++-------- cpp/src/neighbors/cagra_search_half.cu | 23 ++++++++-------- cpp/src/neighbors/cagra_search_int8.cu | 23 ++++++++-------- cpp/src/neighbors/cagra_search_uint8.cu | 23 ++++++++-------- .../neighbors/detail/cagra/cagra_search.cuh | 26 +++++-------------- cpp/test/neighbors/ann_cagra.cuh | 6 ++--- 8 files changed, 67 insertions(+), 104 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 8ff9664f6..2150f4214 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -229,6 +229,10 @@ struct search_params : cuvs::neighbors::search_params { * impact on the throughput. */ float persistent_device_usage = 1.0; + + /** A sparsity threshold; brute force is used when sparsity exceeds this threshold, in the range + * [0, 1] */ + double threshold_to_bf = 0.9; }; /** @@ -1057,8 +1061,6 @@ void extend( * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, @@ -1068,8 +1070,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1086,8 +1087,6 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1096,8 +1095,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1114,8 +1112,6 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1124,8 +1120,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. @@ -1142,8 +1137,6 @@ void search(raft::resources const& res, * k] * @param[in] sample_filter an optional device filter function object that greenlights samples * for a given query. (none_sample_filter for no filtering) - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] */ void search(raft::resources const& res, cuvs::neighbors::cagra::search_params const& params, @@ -1152,8 +1145,7 @@ void search(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances, const cuvs::neighbors::filtering::base_filter& sample_filter = - cuvs::neighbors::filtering::none_sample_filter{}, - double threshold_to_bf = 0.9f); + cuvs::neighbors::filtering::none_sample_filter{}); /** * @} diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index f0bad1431..dacfd6f63 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -293,9 +293,6 @@ index build( * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] - * */ template void search_with_filtering(raft::resources const& res, @@ -304,8 +301,7 @@ void search_with_filtering(raft::resources const& res, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT(), - double threshold_to_bf = 0.9) + CagraSampleFilterT sample_filter = CagraSampleFilterT()) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -326,14 +322,8 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - return cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - sample_filter, - threshold_to_bf); + return cagra::detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); } template @@ -343,15 +333,14 @@ void search(raft::resources const& res, raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - const cuvs::neighbors::filtering::base_filter& sample_filter_ref, - double threshold_to_bf = 0.9) + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) { try { using none_filter_type = cuvs::neighbors::filtering::none_sample_filter; auto& sample_filter = dynamic_cast(sample_filter_ref); auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf); + res, params, idx, queries, neighbors, distances, sample_filter_copy); return; } catch (const std::bad_cast&) { } @@ -362,7 +351,7 @@ void search(raft::resources const& res, sample_filter_ref); auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf); + res, params, idx, queries, neighbors, distances, sample_filter_copy); } catch (const std::bad_cast&) { RAFT_FAIL("Unsupported sample filter type"); } diff --git a/cpp/src/neighbors/cagra_search_float.cu b/cpp/src/neighbors/cagra_search_float.cu index d1d790121..3aca84f74 100644 --- a/cpp/src/neighbors/cagra_search_float.cu +++ b/cpp/src/neighbors/cagra_search_float.cu @@ -19,18 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(float, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_half.cu b/cpp/src/neighbors/cagra_search_half.cu index 5112e25dd..02be12731 100644 --- a/cpp/src/neighbors/cagra_search_half.cu +++ b/cpp/src/neighbors/cagra_search_half.cu @@ -19,18 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(half, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_int8.cu b/cpp/src/neighbors/cagra_search_int8.cu index a8bfaa7a7..3442ef55f 100644 --- a/cpp/src/neighbors/cagra_search_int8.cu +++ b/cpp/src/neighbors/cagra_search_int8.cu @@ -18,18 +18,17 @@ #include namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(int8_t, uint32_t); diff --git a/cpp/src/neighbors/cagra_search_uint8.cu b/cpp/src/neighbors/cagra_search_uint8.cu index 411ff9c79..08fe1861b 100644 --- a/cpp/src/neighbors/cagra_search_uint8.cu +++ b/cpp/src/neighbors/cagra_search_uint8.cu @@ -19,18 +19,17 @@ namespace cuvs::neighbors::cagra { -#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ - void search(raft::resources const& handle, \ - cuvs::neighbors::cagra::search_params const& params, \ - const cuvs::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const cuvs::neighbors::filtering::base_filter& sample_filter, \ - double threshold_to_bf) \ - { \ - cuvs::neighbors::cagra::search( \ - handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \ +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ } CUVS_INST_CAGRA_SEARCH(uint8_t, uint32_t); diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 555f91e75..ab8ee12cc 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -116,7 +116,7 @@ void search_main_core(raft::resources const& res, * * This function switches to a brute force search approach to improve recall rate when the * `sample_filter` function filters out a high proportion of samples, resulting in a sparsity level - * (proportion of unfiltered samples) exceeding the specified `threshold_to_bf`. + * (proportion of unfiltered samples) exceeding the specified `params.threshold_to_bf`. * * @tparam T data element type * @tparam IdxT type of database vector indices @@ -133,8 +133,6 @@ void search_main_core(raft::resources const& res, * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] * * @return true If the brute force search was applied successfully. * @return false If the brute force search was not applied. @@ -152,8 +150,7 @@ bool search_using_brute_force( raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT& sample_filter, - double threshold_to_bf = 0.9) + CagraSampleFilterT& sample_filter) { auto n_queries = queries.extent(0); auto n_dataset = strided_dataset.n_rows(); @@ -161,7 +158,7 @@ bool search_using_brute_force( auto bitset_filter_view = sample_filter.bitset_view_; auto sparsity = bitset_filter_view.sparsity(res); - if (sparsity < threshold_to_bf) { return false; } + if (sparsity < params.threshold_to_bf) { return false; } // TODO: Support host dataset in `brute_force::build` RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%f", sparsity); @@ -239,9 +236,6 @@ bool search_using_brute_force( * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] * @param[in] sample_filter a device filter function that greenlights samples for a given query - * @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this - * threshold, in the range [0, 1] - * */ template queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT(), - double threshold_to_bf = 0.9) + CagraSampleFilterT sample_filter = CagraSampleFilterT()) { auto stream = raft::resource::get_cuda_stream(res); const auto& graph = index.graph(); @@ -270,15 +263,8 @@ void search_main(raft::resources const& res, if constexpr (!std::is_same_v && (std::is_same_v || std::is_same_v)) { - bool bf_search_done = search_using_brute_force(res, - params, - *strided_dset, - index.metric(), - queries, - neighbors, - distances, - sample_filter, - threshold_to_bf); + bool bf_search_done = search_using_brute_force( + res, params, *strided_dset, index.metric(), queries, neighbors, distances, sample_filter); if (bf_search_done) return; } diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 5f56f24a6..fa3e8e855 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -792,6 +792,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.team_size = ps.threshold_to_bf; // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for // k>1024 skip these tests until fixed @@ -848,8 +849,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_queries_view, indices_out_view, dists_out_view, - bitset_filter_obj, - ps.threshold_to_bf); + bitset_filter_obj); raft::update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); raft::update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); @@ -1096,7 +1096,7 @@ inline std::vector generate_bf_inputs() {false}, {true}, {1.0}, - {0.1, 0.4, 0.91}); + {0.1, 0.4, 0.8}); for (auto input : inputs_original) { input.filter_offset = 0.5 * input.n_rows; input.min_recall = input.threshold_to_bf <= 0.5 ? 1.0 : 0.6;