Skip to content

Commit

Permalink
Cleaning up includes
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Sep 12, 2023
1 parent 28bee2b commit 1e7ba4f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 19 deletions.
25 changes: 10 additions & 15 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/operators.hpp>
#include <raft/distance/detail/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <stdexcept>
Expand Down Expand Up @@ -107,19 +105,16 @@ class RaftCagra : public ANN<T> {
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
{
switch (raft::spatial::knn::detail::utils::check_pointer_residency(dataset)) {
case raft::spatial::knn::detail::utils::pointer_residency::host_only: {
auto dataset_view =
raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
}
default: {
auto dataset_view =
raft::make_device_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
}
if (raft::get_device_for_address(dataset) == -1) {
auto dataset_view =
raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
} else {
auto dataset_view =
raft::make_device_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
return;
}
}

Expand Down
1 change: 0 additions & 1 deletion cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/detail/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/neighbors/ivf_flat.cuh>
Expand Down
4 changes: 1 addition & 3 deletions cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft_runtime/neighbors/ivf_pq.hpp>
#include <raft_runtime/neighbors/refine.hpp>
Expand Down Expand Up @@ -174,8 +173,7 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
raft::runtime::neighbors::ivf_pq::search(
handle_, search_params_, *index_, queries_v, candidates.view(), distances_tmp.view());

if (raft::spatial::knn::detail::utils::check_pointer_residency(dataset_.data_handle()) ==
raft::spatial::knn::detail::utils::pointer_residency::device_only) {
if (raft::get_device_for_address(dataset_.data_handle()) >= 0) {
auto queries_v =
raft::make_device_matrix_view<const T, IdxT>(queries, batch_size, index_->dim());
auto neighbors_v = raft::make_device_matrix_view<IdxT, IdxT>((IdxT*)neighbors, batch_size, k);
Expand Down

0 comments on commit 1e7ba4f

Please sign in to comment.