diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index a4684ce26..289c13362 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -229,6 +229,13 @@ struct search_params : cuvs::neighbors::search_params { * impact on the throughput. */ float persistent_device_usage = 1.0; + + /** + * A parameter indicating the rate of nodes to be filtered-out, when filtering is used. + * The value must be equal to or greater than 0.0 and less than 1.0. Default value is + * negative, in which case the filtering rate is automatically set. + */ + float filtering_rate = -1.0; }; /** diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index dacfd6f63..f294c9b44 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -336,11 +336,13 @@ void search(raft::resources const& res, const cuvs::neighbors::filtering::base_filter& sample_filter_ref) { try { - using none_filter_type = cuvs::neighbors::filtering::none_sample_filter; - auto& sample_filter = dynamic_cast(sample_filter_ref); + using none_filter_type = cuvs::neighbors::filtering::none_sample_filter; + auto& sample_filter = dynamic_cast(sample_filter_ref); + search_params params_copy = params; + if (params.filtering_rate < 0.0) { params_copy.filtering_rate = 0.0; } auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy); + res, params_copy, idx, queries, neighbors, distances, sample_filter_copy); return; } catch (const std::bad_cast&) { } @@ -349,9 +351,18 @@ void search(raft::resources const& res, auto& sample_filter = dynamic_cast&>( sample_filter_ref); + search_params params_copy = params; + if (params.filtering_rate < 0.0) { + const auto num_set_bits = sample_filter.bitset_view_.count(res); + auto filtering_rate = (float)(idx.data().n_rows() - num_set_bits) / idx.data().n_rows(); + const float min_filtering_rate = 0.0; + const float max_filtering_rate = 0.999; + params_copy.filtering_rate = + std::min(std::max(filtering_rate, min_filtering_rate), max_filtering_rate); + } auto sample_filter_copy = sample_filter; return search_with_filtering( - res, params, idx, queries, neighbors, distances, sample_filter_copy); + res, params_copy, idx, queries, neighbors, distances, sample_filter_copy); } catch (const std::bad_cast&) { RAFT_FAIL("Unsupported sample filter type"); } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 8d425ca67..a50047e74 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -115,11 +115,12 @@ struct search : public search_plan_implitopk_size = multi_cta_itopk_size; search_width = 1; + itopk_size = multi_cta_itopk_size; num_cta_per_query = - max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size)); + max(params.search_width, raft::ceildiv(global_itopk_size, (size_t)multi_cta_itopk_size)); result_buffer_size = itopk_size + search_width * graph_degree; typedef raft::Pow2<32> AlignBytes; unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 2bbf3d56a..e3bb6723b 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -112,21 +112,22 @@ struct search_plan_impl_base : public search_params { int64_t dim; int64_t graph_degree; uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t dataset_size, - int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), dataset_size(dataset_size), graph_degree(graph_degree), topk(topk) + search_plan_impl_base( + search_params params, int64_t dim, int64_t dataset_size, int64_t graph_degree, uint32_t topk) + : search_params(params), + dim(dim), + dataset_size(dataset_size), + graph_degree(graph_degree), + topk(topk) { if (algo == search_algo::AUTO) { const size_t num_sm = raft::getMultiProcessorCount(); if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { algo = search_algo::SINGLE_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); - } else if (topk <= 1024) { + } else { algo = search_algo::MULTI_CTA; RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); - } else { - algo = search_algo::MULTI_KERNEL; - RAFT_LOG_DEBUG("Auto strategy: selecting multi kernel"); } } } @@ -146,7 +147,6 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t result_buffer_size; uint32_t smem_size; - uint32_t topk; uint32_t num_seeds; lightweight_uvector hashmap; @@ -195,9 +195,9 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t _max_iterations = max_iterations; if (max_iterations == 0) { if (algo == search_algo::MULTI_CTA) { - constexpr uint32_t mc_itopk_size = 32; + constexpr uint32_t mc_itopk_size = 32; constexpr uint32_t mc_search_width = 1; - _max_iterations = mc_itopk_size / mc_search_width; + _max_iterations = mc_itopk_size / mc_search_width; } else { _max_iterations = itopk_size / search_width; } @@ -213,6 +213,20 @@ struct search_plan_impl : public search_plan_impl_base { "# max_iterations is increased from %lu to %u.", max_iterations, _max_iterations); max_iterations = _max_iterations; } + if (algo == search_algo::MULTI_CTA && (0.0 < filtering_rate && filtering_rate < 1.0)) { + size_t adjusted_itopk_size = + (size_t)((float)topk / (1.0 - filtering_rate) + + (float)(itopk_size - topk) / std::sqrt(1.0 - filtering_rate)); + if (adjusted_itopk_size % 32) { adjusted_itopk_size += 32 - (adjusted_itopk_size % 32); } + if (itopk_size < adjusted_itopk_size) { + RAFT_LOG_DEBUG( + "# internal_topk is increased from %lu to %lu, considering fintering rate %f.", + itopk_size, + adjusted_itopk_size, + filtering_rate); + itopk_size = adjusted_itopk_size; + } + } if (itopk_size % 32) { uint32_t itopk32 = itopk_size; itopk32 += 32 - (itopk_size % 32); @@ -246,18 +260,20 @@ struct search_plan_impl : public search_plan_impl_base { // shared among CTAs. // const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations); - small_hash_bitlen = 11; // 2K + small_hash_bitlen = 11; // 2K while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) { small_hash_bitlen += 1; } RAFT_EXPECTS(small_hash_bitlen <= 14, "small_hash_bitlen cannot be largen than 14 (16K)"); + small_hash_reset_interval = 1024 * 1024; // This is not used. // // [traversed_hash_table] // Whether a node has ever been used as the starting point for a traversal // in each iteration is managed in a separate hash table, which is shared // among the CTAs. // - const auto max_traversed_nodes = mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); + const auto max_traversed_nodes = + mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations); unsigned min_bitlen = 11; // 2K if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } hash_bitlen = min_bitlen; @@ -322,9 +338,7 @@ struct search_plan_impl : public search_plan_impl_base { while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { hash_bitlen += 1; } - RAFT_EXPECTS(hash_bitlen <= 20, - "hash_bitlen cannot be largen than 20 (1M). You can decrease itopk_size, " - "search_width or max_iterations to reduce the required hashmap size."); + RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be largen than 20 (1M)"); } } RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size);