From 1e7ba4ffa05a28a12fbbfd2442995cc628ee8f8a Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 12 Sep 2023 17:04:12 -0400 Subject: [PATCH] Cleaning up includes --- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 25 ++++++++----------- .../ann/src/raft/raft_ivf_flat_wrapper.h | 1 - cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 4 +-- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 02aa2ea28b..727a6ed830 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -23,13 +23,11 @@ #include #include #include -#include #include #include #include #include #include -#include #include #include #include @@ -107,19 +105,16 @@ class RaftCagra : public ANN { template void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t) { - switch (raft::spatial::knn::detail::utils::check_pointer_residency(dataset)) { - case raft::spatial::knn::detail::utils::pointer_residency::host_only: { - auto dataset_view = - raft::make_host_matrix_view(dataset, IdxT(nrow), dimension_); - index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); - return; - } - default: { - auto dataset_view = - raft::make_device_matrix_view(dataset, IdxT(nrow), dimension_); - index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); - return; - } + if (raft::get_device_for_address(dataset) == -1) { + auto dataset_view = + raft::make_host_matrix_view(dataset, IdxT(nrow), dimension_); + index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); + return; + } else { + auto dataset_view = + raft::make_device_matrix_view(dataset, IdxT(nrow), dimension_); + index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); + return; } } diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index da457e32f1..b6df7de068 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 8f1e43a706..1b74dcf975 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -25,7 +25,6 @@ #include #include #include -#include #include #include #include @@ -174,8 +173,7 @@ void RaftIvfPQ::search(const T* queries, raft::runtime::neighbors::ivf_pq::search( handle_, search_params_, *index_, queries_v, candidates.view(), distances_tmp.view()); - if (raft::spatial::knn::detail::utils::check_pointer_residency(dataset_.data_handle()) == - raft::spatial::knn::detail::utils::pointer_residency::device_only) { + if (raft::get_device_for_address(dataset_.data_handle()) >= 0) { auto queries_v = raft::make_device_matrix_view(queries, batch_size, index_->dim()); auto neighbors_v = raft::make_device_matrix_view((IdxT*)neighbors, batch_size, k);