Skip to content

Commit

Permalink
Adjust itopk size according to filtering rate
Browse files Browse the repository at this point in the history
  • Loading branch information
anaruse committed Dec 5, 2024
1 parent 8ff6991 commit aa7cfde
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 21 deletions.
7 changes: 7 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ struct search_params : cuvs::neighbors::search_params {
* impact on the throughput.
*/
float persistent_device_usage = 1.0;

/**
* A parameter indicating the rate of nodes to be filtered-out, when filtering is used.
* The value must be equal to or greater than 0.0 and less than 1.0. Default value is
* negative, in which case the filtering rate is automatically set.
*/
float filtering_rate = -1.0;
};

/**
Expand Down
19 changes: 15 additions & 4 deletions cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,13 @@ void search(raft::resources const& res,
const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
{
try {
using none_filter_type = cuvs::neighbors::filtering::none_sample_filter;
auto& sample_filter = dynamic_cast<const none_filter_type&>(sample_filter_ref);
using none_filter_type = cuvs::neighbors::filtering::none_sample_filter;
auto& sample_filter = dynamic_cast<const none_filter_type&>(sample_filter_ref);
search_params params_copy = params;
if (params.filtering_rate < 0.0) { params_copy.filtering_rate = 0.0; }
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, none_filter_type>(
res, params, idx, queries, neighbors, distances, sample_filter_copy);
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
return;
} catch (const std::bad_cast&) {
}
Expand All @@ -349,9 +351,18 @@ void search(raft::resources const& res,
auto& sample_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>&>(
sample_filter_ref);
search_params params_copy = params;
if (params.filtering_rate < 0.0) {
const auto num_set_bits = sample_filter.bitset_view_.count(res);
auto filtering_rate = (float)(idx.data().n_rows() - num_set_bits) / idx.data().n_rows();
const float min_filtering_rate = 0.0;
const float max_filtering_rate = 0.999;
params_copy.filtering_rate =
std::min(std::max(filtering_rate, min_filtering_rate), max_filtering_rate);
}
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, decltype(sample_filter_copy)>(
res, params, idx, queries, neighbors, distances, sample_filter_copy);
res, params_copy, idx, queries, neighbors, distances, sample_filter_copy);
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type");
}
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,12 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_

void set_params(raft::resources const& res, const search_params& params)
{
const size_t global_itopk_size = itopk_size;
constexpr unsigned multi_cta_itopk_size = 32;
this->itopk_size = multi_cta_itopk_size;
search_width = 1;
itopk_size = multi_cta_itopk_size;
num_cta_per_query =
max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)multi_cta_itopk_size));
max(params.search_width, raft::ceildiv(global_itopk_size, (size_t)multi_cta_itopk_size));
result_buffer_size = itopk_size + search_width * graph_degree;
typedef raft::Pow2<32> AlignBytes;
unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size);
Expand Down
44 changes: 29 additions & 15 deletions cpp/src/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -112,21 +112,22 @@ struct search_plan_impl_base : public search_params {
int64_t dim;
int64_t graph_degree;
uint32_t topk;
search_plan_impl_base(search_params params, int64_t dim, int64_t dataset_size,
int64_t graph_degree, uint32_t topk)
: search_params(params), dim(dim), dataset_size(dataset_size), graph_degree(graph_degree), topk(topk)
search_plan_impl_base(
search_params params, int64_t dim, int64_t dataset_size, int64_t graph_degree, uint32_t topk)
: search_params(params),
dim(dim),
dataset_size(dataset_size),
graph_degree(graph_degree),
topk(topk)
{
if (algo == search_algo::AUTO) {
const size_t num_sm = raft::getMultiProcessorCount();
if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) {
algo = search_algo::SINGLE_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting single-cta");
} else if (topk <= 1024) {
} else {
algo = search_algo::MULTI_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta");
} else {
algo = search_algo::MULTI_KERNEL;
RAFT_LOG_DEBUG("Auto strategy: selecting multi kernel");
}
}
}
Expand All @@ -146,7 +147,6 @@ struct search_plan_impl : public search_plan_impl_base {
uint32_t result_buffer_size;

uint32_t smem_size;
uint32_t topk;
uint32_t num_seeds;

lightweight_uvector<INDEX_T> hashmap;
Expand Down Expand Up @@ -195,9 +195,9 @@ struct search_plan_impl : public search_plan_impl_base {
uint32_t _max_iterations = max_iterations;
if (max_iterations == 0) {
if (algo == search_algo::MULTI_CTA) {
constexpr uint32_t mc_itopk_size = 32;
constexpr uint32_t mc_itopk_size = 32;
constexpr uint32_t mc_search_width = 1;
_max_iterations = mc_itopk_size / mc_search_width;
_max_iterations = mc_itopk_size / mc_search_width;
} else {
_max_iterations = itopk_size / search_width;
}
Expand All @@ -213,6 +213,20 @@ struct search_plan_impl : public search_plan_impl_base {
"# max_iterations is increased from %lu to %u.", max_iterations, _max_iterations);
max_iterations = _max_iterations;
}
if (algo == search_algo::MULTI_CTA && (0.0 < filtering_rate && filtering_rate < 1.0)) {
size_t adjusted_itopk_size =
(size_t)((float)topk / (1.0 - filtering_rate) +
(float)(itopk_size - topk) / std::sqrt(1.0 - filtering_rate));
if (adjusted_itopk_size % 32) { adjusted_itopk_size += 32 - (adjusted_itopk_size % 32); }
if (itopk_size < adjusted_itopk_size) {
RAFT_LOG_DEBUG(
"# internal_topk is increased from %lu to %lu, considering fintering rate %f.",
itopk_size,
adjusted_itopk_size,
filtering_rate);
itopk_size = adjusted_itopk_size;
}
}
if (itopk_size % 32) {
uint32_t itopk32 = itopk_size;
itopk32 += 32 - (itopk_size % 32);
Expand Down Expand Up @@ -246,18 +260,20 @@ struct search_plan_impl : public search_plan_impl_base {
// shared among CTAs.
//
const uint32_t max_visited_nodes = mc_itopk_size + (graph_degree * max_iterations);
small_hash_bitlen = 11; // 2K
small_hash_bitlen = 11; // 2K
while (max_visited_nodes > hashmap::get_size(small_hash_bitlen) * max_fill_rate) {
small_hash_bitlen += 1;
}
RAFT_EXPECTS(small_hash_bitlen <= 14, "small_hash_bitlen cannot be largen than 14 (16K)");
small_hash_reset_interval = 1024 * 1024; // This is not used.
//
// [traversed_hash_table]
// Whether a node has ever been used as the starting point for a traversal
// in each iteration is managed in a separate hash table, which is shared
// among the CTAs.
//
const auto max_traversed_nodes = mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations);
const auto max_traversed_nodes =
mc_num_cta_per_query * max((size_t)mc_itopk_size, max_iterations);
unsigned min_bitlen = 11; // 2K
if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; }
hash_bitlen = min_bitlen;
Expand Down Expand Up @@ -322,9 +338,7 @@ struct search_plan_impl : public search_plan_impl_base {
while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) {
hash_bitlen += 1;
}
RAFT_EXPECTS(hash_bitlen <= 20,
"hash_bitlen cannot be largen than 20 (1M). You can decrease itopk_size, "
"search_width or max_iterations to reduce the required hashmap size.");
RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be largen than 20 (1M)");
}
}
RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size);
Expand Down

0 comments on commit aa7cfde

Please sign in to comment.