Skip to content

Commit

Permalink
clique as device_resource
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Nov 25, 2024
1 parent 96e69fc commit 657bf9e
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 19 deletions.
3 changes: 1 addition & 2 deletions cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu {

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
auto stream = raft::resource::get_cuda_stream(handle_);
auto stream = raft::resource::get_cuda_stream(clique_);
return stream;
}

Expand All @@ -86,7 +86,6 @@ class cuvs_mg_cagra : public algo<T>, public algo_gpu {
std::unique_ptr<algo<T>> copy() override;

private:
raft::device_resources handle_;
raft::device_resources_snmg clique_;
float refine_ratio_;
build_param index_params_;
Expand Down
3 changes: 1 addition & 2 deletions cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class cuvs_mg_ivf_flat : public algo<T>, public algo_gpu {

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
auto stream = raft::resource::get_cuda_stream(handle_);
auto stream = raft::resource::get_cuda_stream(clique_);
return stream;
}

Expand All @@ -73,7 +73,6 @@ class cuvs_mg_ivf_flat : public algo<T>, public algo_gpu {
std::unique_ptr<algo<T>> copy() override;

private:
raft::device_resources handle_;
raft::device_resources_snmg clique_;
build_param index_params_;
cuvs::neighbors::mg::search_params<ivf_flat::search_params> search_params_;
Expand Down
3 changes: 1 addition & 2 deletions cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class cuvs_mg_ivf_pq : public algo<T>, public algo_gpu {

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
auto stream = raft::resource::get_cuda_stream(handle_);
auto stream = raft::resource::get_cuda_stream(clique_);
return stream;
}

Expand All @@ -73,7 +73,6 @@ class cuvs_mg_ivf_pq : public algo<T>, public algo_gpu {
std::unique_ptr<algo<T>> copy() override;

private:
raft::device_resources handle_;
raft::device_resources_snmg clique_;
build_param index_params_;
cuvs::neighbors::mg::search_params<ivf_pq::search_params> search_params_;
Expand Down
24 changes: 11 additions & 13 deletions cpp/test/neighbors/mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ template <typename T, typename DataT>
class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
public:
AnnMGTest()
: stream_(resource::get_cuda_stream(handle_)),
clique_(),
: stream_(resource::get_cuda_stream(clique_)),
ps(::testing::TestWithParam<AnnMGInputs>::GetParam()),
d_index_dataset(0, stream_),
d_queries(0, stream_),
Expand Down Expand Up @@ -82,7 +81,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
ps.metric);
update_host(distances_ref.data(), distances_ref_dev.data(), queries_size, stream_);
update_host(neighbors_ref.data(), neighbors_ref_dev.data(), queries_size, stream_);
resource::sync_stream(handle_);
resource::sync_stream(clique_);
}

int64_t n_rows_per_search_batch = 3000; // [3000, 3000, 1000] == 7000 rows
Expand Down Expand Up @@ -132,7 +131,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
search_params.merge_mode = TREE_MERGE;
cuvs::neighbors::mg::search(
clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch);
resource::sync_stream(handle_);
resource::sync_stream(clique_);

double min_recall = static_cast<double>(ps.nprobe) / static_cast<double>(ps.nlist);
ASSERT_TRUE(eval_neighbours(neighbors_ref,
Expand Down Expand Up @@ -191,7 +190,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
search_params.merge_mode = TREE_MERGE;
cuvs::neighbors::mg::search(
clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch);
resource::sync_stream(handle_);
resource::sync_stream(clique_);

double min_recall = static_cast<double>(ps.nprobe) / static_cast<double>(ps.nlist);
ASSERT_TRUE(eval_neighbours(neighbors_ref,
Expand Down Expand Up @@ -244,7 +243,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
search_params.merge_mode = TREE_MERGE;
cuvs::neighbors::mg::search(
clique_, new_index, search_params, queries, neighbors, distances, n_rows_per_search_batch);
resource::sync_stream(handle_);
resource::sync_stream(clique_);

double min_recall = static_cast<double>(ps.nprobe) / static_cast<double>(ps.nlist);
ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits,
Expand Down Expand Up @@ -297,7 +296,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances,
n_rows_per_search_batch);

resource::sync_stream(handle_);
resource::sync_stream(clique_);

double min_recall = static_cast<double>(ps.nprobe) / static_cast<double>(ps.nlist);
ASSERT_TRUE(eval_neighbours(neighbors_ref,
Expand Down Expand Up @@ -349,7 +348,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances,
n_rows_per_search_batch);

resource::sync_stream(handle_);
resource::sync_stream(clique_);

double min_recall = static_cast<double>(ps.nprobe) / static_cast<double>(ps.nlist);
ASSERT_TRUE(eval_neighbours(neighbors_ref,
Expand Down Expand Up @@ -397,7 +396,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
distances,
n_rows_per_search_batch);

resource::sync_stream(handle_);
resource::sync_stream(clique_);

double min_recall = static_cast<double>(ps.nprobe) / static_cast<double>(ps.nlist);
ASSERT_TRUE(eval_neighbours(neighbors_ref_32bits,
Expand Down Expand Up @@ -622,16 +621,15 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
raft::copy(h_index_dataset.data(),
d_index_dataset.data(),
d_index_dataset.size(),
resource::get_cuda_stream(handle_));
resource::get_cuda_stream(clique_));
raft::copy(
h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(handle_));
resource::sync_stream(handle_);
h_queries.data(), d_queries.data(), d_queries.size(), resource::get_cuda_stream(clique_));
resource::sync_stream(clique_);
}

void TearDown() override {}

private:
raft::device_resources handle_;
rmm::cuda_stream_view stream_;
raft::device_resources_snmg clique_;
AnnMGInputs ps;
Expand Down

0 comments on commit 657bf9e

Please sign in to comment.