diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index ff66d1efab..13efa277ed 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -32,6 +33,48 @@ namespace raft::neighbors::cagra::detail { +template +struct CagraSampleFilterWithQueryIdOffset { + const std::size_t offset; + CagraSampleFilterT filter; + + CagraSampleFilterWithQueryIdOffset(const std::size_t offset, const CagraSampleFilterT filter) + : offset(offset), filter(filter) + { + } + + _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id) + { + return filter(query_id + offset, sample_id); + } +}; + +template +struct CagraSampleFilterT_Selector { + using type = CagraSampleFilterWithQueryIdOffset; +}; +template <> +struct CagraSampleFilterT_Selector { + using type = raft::neighbors::filtering::none_cagra_sample_filter; +}; + +// A helper function to set a query id offset +template +inline typename CagraSampleFilterT_Selector::type set_offset( + CagraSampleFilterT filter, const uint32_t offset) +{ + typename CagraSampleFilterT_Selector::type new_filter(offset, filter); + return new_filter; +} +template <> +inline + typename CagraSampleFilterT_Selector::type + set_offset( + raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t) +{ + return filter; +} + /** * @brief Search ANN using the constructed index. * @@ -76,8 +119,9 @@ void search_main(raft::resources const& res, if (params.max_queries == 0) { params.max_queries = queries.extent(0); } - std::unique_ptr> plan = - factory::create( + using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; + std::unique_ptr> plan = + factory::create( res, params, index.dim(), index.graph_degree(), topk); plan->check(neighbors.extent(1)); @@ -119,7 +163,7 @@ void search_main(raft::resources const& res, _seed_ptr, _num_executed_iterations, topk, - sample_filter); + set_offset(sample_filter, qid)); } static_assert(std::is_same_v,