Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InnerProduct Distance Metric for CAGRA search #2260

Merged
merged 34 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b6d9980
apply updates to 24.06
tarang-jain Apr 5, 2024
fac65b8
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 5, 2024
ad48b03
remove build errors
tarang-jain Apr 5, 2024
2b8d898
search inputs
tarang-jain Apr 5, 2024
e44ab17
inner product in compute_distance_vpq.cuh
tarang-jain Apr 9, 2024
c506216
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 9, 2024
cb7fbba
inner product in index build; debug statements
tarang-jain Apr 10, 2024
0029bba
tests passing
tarang-jain Apr 10, 2024
293bc8f
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 10, 2024
c91d895
style
tarang-jain Apr 10, 2024
7a5f876
update testing
tarang-jain Apr 10, 2024
c37aa27
rm log statements
tarang-jain Apr 10, 2024
890372b
pass CagraSort
tarang-jain Apr 10, 2024
810ddd1
tests passing
tarang-jain Apr 11, 2024
f92e68b
remove dbg statements
tarang-jain Apr 11, 2024
5bbbc70
update docs
tarang-jain Apr 11, 2024
7febd73
metric assertions
tarang-jain Apr 11, 2024
7e19937
add metric as const arg
tarang-jain Apr 12, 2024
e40b967
make metric template:
tarang-jain Apr 12, 2024
b102393
clean up metric template
tarang-jain Apr 15, 2024
6da7d55
update assertion
tarang-jain Apr 15, 2024
1e666e3
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 17, 2024
c1f4dcd
metric runtime dispatch
tarang-jain Apr 17, 2024
2e9e3fe
Merge branch 'cagra-dists' of https://github.com/tarang-jain/raft int…
tarang-jain Apr 17, 2024
d8a4b39
address all PR reviews
tarang-jain Apr 18, 2024
c5cd0e7
update docs, passing gtests
tarang-jain Apr 19, 2024
eabb031
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 19, 2024
77ff0c2
add ivf_pq::index_params helper
tarang-jain Apr 23, 2024
b61c1c3
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 23, 2024
54061d0
tracking issue; styling
tarang-jain Apr 23, 2024
dda1810
make helper static
tarang-jain Apr 24, 2024
743d5bc
update from_dataset verbiage
tarang-jain Apr 24, 2024
5791743
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 24, 2024
f5f30b7
Merge branch 'branch-24.06' into cagra-dists
cjnolet Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

Expand All @@ -48,6 +49,7 @@ namespace raft::neighbors::cagra {
*
* The following distance metrics are supported:
* - L2Expanded
* - InnerProduct
*
* Usage example:
* @code{.cpp}
Expand All @@ -57,7 +59,8 @@ namespace raft::neighbors::cagra {
* ivf_pq::search_params search_params
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
* cagra::build_knn_graph(res, dataset, knn_graph.view(),
* raft::distance::DistanceType::L2Expanded, 2, build_params, search_params);
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
* auto optimized_gaph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
Expand All @@ -71,6 +74,7 @@ namespace raft::neighbors::cagra {
* @param[in] res raft resources
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
* @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree]
* @param[in] metric distance metric (default = raft::distance::DistanceType::L2Expanded)
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
* @param[in] refine_rate (optional) refinement rate for ivf-pq search
* @param[in] build_params (optional) ivf_pq index building parameters for knn graph
* @param[in] search_params (optional) ivf_pq search parameters
Expand All @@ -79,7 +83,8 @@ template <typename DataT, typename IdxT, typename accessor>
void build_knn_graph(raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
std::optional<float> refine_rate = std::nullopt,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
std::optional<float> refine_rate = std::nullopt,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
Expand All @@ -93,7 +98,7 @@ void build_knn_graph(raft::resources const& res,
dataset.data_handle(), dataset.extent(0), dataset.extent(1));

cagra::detail::build_knn_graph(
res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params);
res, dataset_internal, knn_graph_internal, metric, refine_rate, build_params, search_params);
}

/**
Expand Down
26 changes: 22 additions & 4 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
Expand All @@ -44,12 +45,17 @@ template <typename DataT, typename IdxT, typename accessor>
void build_knn_graph(raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
std::optional<float> refine_rate = std::nullopt,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
std::optional<float> refine_rate = std::nullopt,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");
std::cout << "metric from build_knn_graph" << metric << std::endl;
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
RAFT_EXPECTS(!build_params || build_params->metric == metric,
"Mismatch between index metric and IVF-PQ metric");
RAFT_EXPECTS(
metric == distance::DistanceType::L2Expanded || metric == distance::DistanceType::InnerProduct,
"Currently only L2Expanded metric is supported");
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved

uint32_t node_degree = knn_graph.extent(1);
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::build_graph(%zu, %zu, %u)",
Expand All @@ -65,6 +71,7 @@ void build_knn_graph(raft::resources const& res,
build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10;
build_params->kmeans_n_iters = 25;
build_params->add_data_on_build = true;
build_params->metric = metric;
}

// Make model name
Expand Down Expand Up @@ -300,6 +307,8 @@ index<T, IdxT> build(
std::optional<ivf_pq::search_params> search_params = std::nullopt,
bool construct_index_with_dataset = true)
{
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
if (intermediate_degree >= static_cast<size_t>(dataset.extent(0))) {
Expand All @@ -321,9 +330,18 @@ index<T, IdxT> build(
raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree));

if (params.build_algo == graph_build_algo::IVF_PQ) {
build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params);
RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded ||
params.metric == raft::distance::DistanceType::InnerProduct,
"L2Expanded and InnerProduct are the only distance metrics supported with IVF-PQ");
build_knn_graph(
res, dataset, knn_graph->view(), params.metric, refine_rate, pq_build_params, search_params);

} else {
RAFT_EXPECTS(
params.metric != raft::distance::DistanceType::InnerProduct,
"InnerProduct distance metric supported with nn_descent. Use IVF_PQ as the build_algo");
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded,
"L2Expanded is the only distance metrics supported with nn_descent");
// Use nn-descent to build CAGRA knn graph
if (!nn_descent_params) {
nn_descent_params = experimental::nn_descent::index_params();
Expand Down
35 changes: 26 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
Expand Down Expand Up @@ -87,7 +88,8 @@ void search_main_core(
raft::device_matrix_view<const typename DatasetDescriptorT::DATA_T, int64_t, row_major> queries,
raft::device_matrix_view<typename DatasetDescriptorT::INDEX_T, int64_t, row_major> neighbors,
raft::device_matrix_view<typename DatasetDescriptorT::DISTANCE_T, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
CagraSampleFilterT sample_filter = CagraSampleFilterT(),
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.data().n_rows()),
Expand All @@ -112,7 +114,7 @@ void search_main_core(
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DatasetDescriptorT, CagraSampleFilterT_s>> plan =
factory<DatasetDescriptorT, CagraSampleFilterT_s>::create(
res, params, dataset_desc.dim, graph.extent(1), topk);
res, params, dataset_desc.dim, graph.extent(1), topk, metric);

plan->check(topk);

Expand Down Expand Up @@ -163,7 +165,8 @@ void launch_vpq_search_main_core(
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<InternalIdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter)
CagraSampleFilterT sample_filter,
raft::distance::DistanceType metric)
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
Expand Down Expand Up @@ -192,7 +195,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else if (vpq_dset->pq_len() == 4) {
using dataset_desc_t = cagra_q_dataset_descriptor_t<T,
DatasetT,
Expand All @@ -210,7 +213,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else {
RAFT_FAIL("Subspace dimension must be 2 or 4");
}
Expand Down Expand Up @@ -268,17 +271,31 @@ void search_main(raft::resources const& res,
strided_dset->n_rows(),
strided_dset->dim(),
strided_dset->stride());

search_main_core<dataset_desc_t, CagraSampleFilterT>(
res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter);
search_main_core<dataset_desc_t, CagraSampleFilterT>(res,
params,
dataset_desc,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<float, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
// Search using a compressed dataset
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
launch_vpq_search_main_core<T, half, ds_idx_type, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, vpq_dset, params, graph_internal, queries, neighbors, distances, sample_filter);
res,
vpq_dset,
params,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* empty_dset = dynamic_cast<const empty_dataset<ds_idx_type>*>(&index.data());
empty_dset != nullptr) {
// Forgot to add a dataset.
Expand Down
28 changes: 18 additions & 10 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const uint32_t num_seeds,
INDEX_T* const visited_hash_ptr,
const uint32_t hash_bitlen,
raft::distance::DistanceType metric,
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand All @@ -79,7 +80,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
}

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, seed_index, valid_i);
query_buffer, seed_index, valid_i, metric);

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
Expand Down Expand Up @@ -121,7 +122,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
const std::uint32_t search_width,
const raft::distance::DistanceType metric)
{
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
Expand Down Expand Up @@ -154,7 +156,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
if (valid_i) { child_id = result_child_indices_ptr[i]; }

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, child_id, child_id != invalid_index);
query_buffer, child_id, child_id != invalid_index, metric);

// Store the distance
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand Down Expand Up @@ -223,7 +225,8 @@ struct standard_dataset_descriptor_t
template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
const bool valid,
raft::distance::DistanceType metric) const
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
{
const auto dataset_ptr = ptr + dataset_i * ld;
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand All @@ -232,7 +235,7 @@ struct standard_dataset_descriptor_t
constexpr unsigned reg_nelem = raft::ceildiv<unsigned>(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T norm2 = 0;
DISTANCE_T dist = 0;
if (valid) {
for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
#pragma unroll
Expand All @@ -252,17 +255,22 @@ struct standard_dataset_descriptor_t
// because:
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T diff = query_ptr[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
norm2 += diff * diff;
DISTANCE_T d = query_ptr[device::swizzling(kv)];
if (metric == raft::distance::L2Expanded) {
tfeher marked this conversation as resolved.
Show resolved Hide resolved
d -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
dist += d * d;
} else {
d *= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
dist -= d;
}
}
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
dist += __shfl_xor_sync(0xffffffff, dist, offset);
}
return norm2;
return dist;
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "compute_distance.hpp"

#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>

namespace raft::neighbors::cagra::detail {
Expand Down Expand Up @@ -115,7 +116,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T node_id,
const bool valid) const
const bool valid,
raft::distance::DistanceType metric) const
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
{
float norm = 0;
if (valid) {
Expand Down Expand Up @@ -227,4 +229,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
};

} // namespace raft::neighbors::cagra::detail
} // namespace raft::neighbors::cagra::detail
11 changes: 6 additions & 5 deletions cpp/include/raft/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class factory {
search_params const& params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
uint32_t topk,
raft::distance::DistanceType metric)
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, graph_degree, topk, metric);
switch (plan.dataset_block_dim) {
case 128:
switch (plan.team_size) {
Expand Down Expand Up @@ -77,17 +78,17 @@ class factory {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new single_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_kernel_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
}
}
};
Expand Down
10 changes: 8 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible
#include "utils.hpp"

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_properties.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/map.cuh>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cuda_rt_essentials.hpp>
#include <raft/util/cudart_utils.hpp> // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp
Expand Down Expand Up @@ -96,8 +99,10 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
search_params params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
: search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>(res, params, dim, graph_degree, topk),
uint32_t topk,
raft::distance::DistanceType metric)
: search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>(
res, params, dim, graph_degree, topk, metric),
intermediate_indices(0, resource::get_cuda_stream(res)),
intermediate_distances(0, resource::get_cuda_stream(res)),
topk_workspace(0, resource::get_cuda_stream(res))
Expand Down Expand Up @@ -235,6 +240,7 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

Expand Down
Loading
Loading