Skip to content

Commit

Permalink
updating MG tests
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Nov 26, 2024
1 parent 657bf9e commit 1fdccd4
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions cpp/test/neighbors/mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ template <typename T, typename DataT>
class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
public:
AnnMGTest()
: stream_(resource::get_cuda_stream(clique_)),
: clique_(),
ps(::testing::TestWithParam<AnnMGInputs>::GetParam()),
d_index_dataset(0, stream_),
d_queries(0, stream_),
d_index_dataset(0, resource::get_cuda_stream(clique_)),
d_queries(0, resource::get_cuda_stream(clique_)),
h_index_dataset(0),
h_queries(0)
{
Expand All @@ -67,8 +67,9 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
std::vector<uint32_t> neighbors_snmg_ann_32bits(queries_size);

{
rmm::device_uvector<T> distances_ref_dev(queries_size, stream_);
rmm::device_uvector<int64_t> neighbors_ref_dev(queries_size, stream_);
rmm::device_uvector<T> distances_ref_dev(queries_size, resource::get_cuda_stream(clique_));
rmm::device_uvector<int64_t> neighbors_ref_dev(queries_size,
resource::get_cuda_stream(clique_));
cuvs::neighbors::naive_knn<T, DataT, int64_t>(clique_,
distances_ref_dev.data(),
neighbors_ref_dev.data(),
Expand All @@ -79,8 +80,14 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
ps.dim,
ps.k,
ps.metric);
update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_);
update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_);
update_host(distances_ref.data(),
distances_ref_dev.data(),
queries_size,
resource::get_cuda_stream(clique_));
update_host(neighbors_ref.data(),
neighbors_ref_dev.data(),
queries_size,
resource::get_cuda_stream(clique_));
resource::sync_stream(clique_);
}

Expand Down Expand Up @@ -602,8 +609,8 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {

void SetUp() override
{
d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_);
d_queries.resize(ps.num_queries * ps.dim, stream_);
d_index_dataset.resize(ps.num_db_vecs * ps.dim, resource::get_cuda_stream(clique_));
d_queries.resize(ps.num_queries * ps.dim, resource::get_cuda_stream(clique_));
h_index_dataset.resize(ps.num_db_vecs * ps.dim);
h_queries.resize(ps.num_queries * ps.dim);

Expand All @@ -630,7 +637,6 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
void TearDown() override {}

private:
rmm::cuda_stream_view stream_;
raft::device_resources_snmg clique_;
AnnMGInputs ps;
std::vector<DataT> h_index_dataset;
Expand Down

0 comments on commit 1fdccd4

Please sign in to comment.