Skip to content

Commit

Permalink
Fix ANN bench ground truth generation for k>1024
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Feb 13, 2024
1 parent c9574d7 commit dbfa7dd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions cpp/include/raft/neighbors/detail/knn_merge_parts.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/error.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -168,5 +169,7 @@ inline void knn_merge_parts(const value_t* inK,
else if (k <= 1024)
knn_merge_parts_impl<value_idx, value_t, 1024, 8>(
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
else
THROW("Unimplemented for k=%d, knn_merge_parts works for k<=1024", k);
}
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):
queries,
k,
metric=metric,
handle=handle,
global_id_offset=i, # shift neighbor index by offset i
handle=handle
)
handle.sync()

D, Ind = cp.asarray(D), cp.asarray(Ind)
Ind += i # shift neighbor index by offset i

if distances is None:
distances = D
indices = Ind
Expand Down

0 comments on commit dbfa7dd

Please sign in to comment.