From 1fdccd4e5c46625e8f1f467052d7b2ab4b5556e7 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 26 Nov 2024 14:31:46 +0100 Subject: [PATCH] updating MG tests --- cpp/test/neighbors/mg.cuh | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/cpp/test/neighbors/mg.cuh b/cpp/test/neighbors/mg.cuh index 853dc8c0e..b4131acdb 100644 --- a/cpp/test/neighbors/mg.cuh +++ b/cpp/test/neighbors/mg.cuh @@ -46,10 +46,10 @@ template class AnnMGTest : public ::testing::TestWithParam { public: AnnMGTest() - : stream_(resource::get_cuda_stream(clique_)), + : clique_(), ps(::testing::TestWithParam::GetParam()), - d_index_dataset(0, stream_), - d_queries(0, stream_), + d_index_dataset(0, resource::get_cuda_stream(clique_)), + d_queries(0, resource::get_cuda_stream(clique_)), h_index_dataset(0), h_queries(0) { @@ -67,8 +67,9 @@ class AnnMGTest : public ::testing::TestWithParam { std::vector neighbors_snmg_ann_32bits(queries_size); { - rmm::device_uvector distances_ref_dev(queries_size, stream_); - rmm::device_uvector neighbors_ref_dev(queries_size, stream_); + rmm::device_uvector distances_ref_dev(queries_size, resource::get_cuda_stream(clique_)); + rmm::device_uvector neighbors_ref_dev(queries_size, + resource::get_cuda_stream(clique_)); cuvs::neighbors::naive_knn(clique_, distances_ref_dev.data(), neighbors_ref_dev.data(), @@ -79,8 +80,14 @@ class AnnMGTest : public ::testing::TestWithParam { ps.dim, ps.k, ps.metric); - update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_); - update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_); + update_host(distances_ref.data(), + distances_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); + update_host(neighbors_ref.data(), + neighbors_ref_dev.data(), + queries_size, + resource::get_cuda_stream(clique_)); resource::sync_stream(clique_); } @@ -602,8 +609,8 @@ class AnnMGTest : public ::testing::TestWithParam { void SetUp() override { - d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_); - d_queries.resize(ps.num_queries * ps.dim, stream_); + d_index_dataset.resize(ps.num_db_vecs * ps.dim, resource::get_cuda_stream(clique_)); + d_queries.resize(ps.num_queries * ps.dim, resource::get_cuda_stream(clique_)); h_index_dataset.resize(ps.num_db_vecs * ps.dim); h_queries.resize(ps.num_queries * ps.dim); @@ -630,7 +637,6 @@ class AnnMGTest : public ::testing::TestWithParam { void TearDown() override {} private: - rmm::cuda_stream_view stream_; raft::device_resources_snmg clique_; AnnMGInputs ps; std::vector h_index_dataset;