Skip to content

Commit

Permalink
move the threshold-to-bf into search_params
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Nov 19, 2024
1 parent d190b9d commit 9aa1bb1
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 104 deletions.
24 changes: 8 additions & 16 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ struct search_params : cuvs::neighbors::search_params {
* impact on the throughput.
*/
float persistent_device_usage = 1.0;

/** A sparsity threshold; brute force is used when sparsity exceeds this threshold, in the range
* [0, 1] */
double threshold_to_bf = 0.9;
};

/**
Expand Down Expand Up @@ -1057,8 +1061,6 @@ void extend(
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this
* threshold, in the range [0, 1]
*/

void search(raft::resources const& res,
Expand All @@ -1068,8 +1070,7 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
Expand All @@ -1086,8 +1087,6 @@ void search(raft::resources const& res,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this
* threshold, in the range [0, 1]
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
Expand All @@ -1096,8 +1095,7 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
Expand All @@ -1114,8 +1112,6 @@ void search(raft::resources const& res,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this
* threshold, in the range [0, 1]
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
Expand All @@ -1124,8 +1120,7 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @brief Search ANN using the constructed index.
Expand All @@ -1142,8 +1137,6 @@ void search(raft::resources const& res,
* k]
* @param[in] sample_filter an optional device filter function object that greenlights samples
* for a given query. (none_sample_filter for no filtering)
* @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this
* threshold, in the range [0, 1]
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
Expand All @@ -1152,8 +1145,7 @@ void search(raft::resources const& res,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{},
double threshold_to_bf = 0.9f);
cuvs::neighbors::filtering::none_sample_filter{});

/**
* @}
Expand Down
23 changes: 6 additions & 17 deletions cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,6 @@ index<T, IdxT> build(
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device filter function that greenlights samples for a given query
* @param[in] threshold_to_bf A sparsity threshold; brute force is used when sparsity exceeds this
* threshold, in the range [0, 1]
*
*/
template <typename T, typename IdxT, typename CagraSampleFilterT>
void search_with_filtering(raft::resources const& res,
Expand All @@ -304,8 +301,7 @@ void search_with_filtering(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT(),
double threshold_to_bf = 0.9)
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
Expand All @@ -326,14 +322,8 @@ void search_with_filtering(raft::resources const& res,
auto distances_internal = raft::make_device_matrix_view<float, int64_t, raft::row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

return cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(res,
params,
idx,
queries_internal,
neighbors_internal,
distances_internal,
sample_filter,
threshold_to_bf);
return cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
}

template <typename T, typename IdxT>
Expand All @@ -343,15 +333,14 @@ void search(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter_ref,
double threshold_to_bf = 0.9)
const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
{
try {
using none_filter_type = cuvs::neighbors::filtering::none_sample_filter;
auto& sample_filter = dynamic_cast<const none_filter_type&>(sample_filter_ref);
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, none_filter_type>(
res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf);
res, params, idx, queries, neighbors, distances, sample_filter_copy);
return;
} catch (const std::bad_cast&) {
}
Expand All @@ -362,7 +351,7 @@ void search(raft::resources const& res,
sample_filter_ref);
auto sample_filter_copy = sample_filter;
return search_with_filtering<T, IdxT, decltype(sample_filter_copy)>(
res, params, idx, queries, neighbors, distances, sample_filter_copy, threshold_to_bf);
res, params, idx, queries, neighbors, distances, sample_filter_copy);
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type");
}
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/neighbors/cagra_search_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@

namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter, \
double threshold_to_bf) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \
#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter); \
}

CUVS_INST_CAGRA_SEARCH(float, uint32_t);
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/neighbors/cagra_search_half.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@

namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter, \
double threshold_to_bf) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \
#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter); \
}

CUVS_INST_CAGRA_SEARCH(half, uint32_t);
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/neighbors/cagra_search_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@
#include <cuvs/neighbors/cagra.hpp>
namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter, \
double threshold_to_bf) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \
#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter); \
}

CUVS_INST_CAGRA_SEARCH(int8_t, uint32_t);
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/neighbors/cagra_search_uint8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@

namespace cuvs::neighbors::cagra {

#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter, \
double threshold_to_bf) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter, threshold_to_bf); \
#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
cuvs::neighbors::cagra::search_params const& params, \
const cuvs::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<float, int64_t, raft::row_major> distances, \
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
cuvs::neighbors::cagra::search<T, IdxT>( \
handle, params, index, queries, neighbors, distances, sample_filter); \
}

CUVS_INST_CAGRA_SEARCH(uint8_t, uint32_t);
Expand Down
Loading

0 comments on commit 9aa1bb1

Please sign in to comment.