diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 0736abe637..fe9132b223 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -302,6 +302,8 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu @@ -592,8 +594,6 @@ if(RAFT_COMPILE_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - - foreach(target raft_lib raft_lib_static raft_objs) target_link_libraries( ${target} diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 1c30da41e2..da8eab76b5 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) #include "raft_cagra_wrapper.h" #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA @@ -70,11 +70,11 @@ extern template class raft::bench::ann::RaftAnnMG_Cagra; #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT) template void parse_build_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) - #else +#else typename raft::bench::ann::RaftAnnMG_IvfFlat::BuildParam& param) - #endif +#endif { param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -83,26 +83,26 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) - #else +#else typename raft::bench::ann::RaftAnnMG_IvfFlat::SearchParam& param) - #endif +#endif { param.ivf_flat_params.n_probes = conf.at("nprobe"); } #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam& param) - #else +#else typename raft::bench::ann::RaftIvfPQ::BuildParam& param) - #endif +#endif { if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -124,11 +124,11 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ typename raft::bench::ann::RaftAnnMG_IvfPq::SearchParam& param) - #else +#else typename raft::bench::ann::RaftIvfPQ::SearchParam& param) - #endif +#endif { if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } if (conf.contains("internalDistanceDtype")) { @@ -170,7 +170,7 @@ void parse_search_param(const nlohmann::json& conf, #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, raft::neighbors::experimental::nn_descent::index_params& param) @@ -217,11 +217,11 @@ nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, template void parse_build_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam& param) - #else +#else typename raft::bench::ann::RaftCagra::BuildParam& param) - #endif +#endif { if (conf.contains("graph_degree")) { param.cagra_params.graph_degree = conf.at("graph_degree"); @@ -285,11 +285,11 @@ raft::bench::ann::AllocatorType parse_allocator(std::string mem_type) template void parse_search_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA - typename raft::bench::ann::RaftAnnMG_Cagra::SearchParam& param) - #else - typename raft::bench::ann::RaftCagra::SearchParam& param) - #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + typename raft::bench::ann::RaftAnnMG_Cagra::SearchParam& param) +#else + typename raft::bench::ann::RaftCagra::SearchParam& param) +#endif { if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } @@ -309,14 +309,14 @@ void parse_search_param(const nlohmann::json& conf, } } - #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) if (conf.contains("graph_memory_type")) { param.graph_mem = parse_allocator(conf.at("graph_memory_type")); } if (conf.contains("internal_dataset_memory_type")) { param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type")); } - #endif +#endif // Same ratio as in IVF-PQ param.refine_ratio = conf.value("refine_ratio", 1.0f); } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu index 84699e052b..0243529a67 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu @@ -14,15 +14,15 @@ * limitations under the License. */ - #include "raft_ann_mg_cagra_wrapper.hpp" -#include + #include +#include namespace raft::bench::ann { - template class RaftAnnMG_Cagra; - template class RaftAnnMG_Cagra; - template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp index cab2605b4d..ef3ece2839 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp @@ -17,6 +17,7 @@ #pragma once #include "raft_ann_mg_wrapper.hpp" + #include #include #include @@ -39,26 +40,26 @@ class RaftAnnMG_Cagra : public RaftAnnMG { struct BuildParam { raft::neighbors::cagra::mg_index_params cagra_params; - std::optional nn_descent_params = std::nullopt; + std::optional nn_descent_params = + std::nullopt; std::optional ivf_pq_refine_rate = std::nullopt; std::optional ivf_pq_build_params = std::nullopt; std::optional ivf_pq_search_params = std::nullopt; }; RaftAnnMG_Cagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) - : RaftAnnMG(metric, dim), - index_params_(param), - dimension_(dim) + : RaftAnnMG(metric, dim), index_params_(param), dimension_(dim) { - index_params_.cagra_params.add_data_on_build = true; - index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED; - index_params_.cagra_params.metric = parse_metric_type(metric); - index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); + index_params_.cagra_params.add_data_on_build = true; + index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED; + index_params_.cagra_params.metric = parse_metric_type(metric); + index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); } void build(const T* dataset, size_t nrow) final; void set_search_param(const AnnSearchParam& param) override; - void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; void save(const std::string& file) const override; void load(const std::string&) override; std::unique_ptr> copy() override; @@ -66,7 +67,9 @@ class RaftAnnMG_Cagra : public RaftAnnMG { private: BuildParam index_params_; raft::neighbors::cagra::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + std::shared_ptr< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>> + index_; float refine_ratio_ = 1.0; int dimension_; }; @@ -74,10 +77,14 @@ class RaftAnnMG_Cagra : public RaftAnnMG { template void RaftAnnMG_Cagra::build(const T* dataset, size_t nrow) { - const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); - auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_.cagra_params, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view( + dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = raft::neighbors::mg::build( + handle, *this->clique_, index_params_.cagra_params, dataset_matrix); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); return; } @@ -103,7 +110,9 @@ void RaftAnnMG_Cagra::load(const std::string& file) { const auto& handle = this->clique_->set_current_device_to_root_rank(); auto idx = raft::neighbors::mg::deserialize_cagra(handle, *this->clique_, file); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); } template @@ -113,14 +122,24 @@ std::unique_ptr> RaftAnnMG_Cagra::copy() } template -void RaftAnnMG_Cagra::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +void RaftAnnMG_Cagra::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); - auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); - auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - - raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + auto query_matrix = raft::make_host_matrix_view( + queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = + raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = + raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, + *this->clique_, + *index_, + search_params_, + query_matrix, + neighbors_matrix, + distances_matrix); resource::sync_stream(handle); return; } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu index 4f9d04072c..a11460dea0 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu @@ -14,15 +14,15 @@ * limitations under the License. */ - #include "raft_ann_mg_ivf_flat_wrapper.hpp" -#include + #include +#include namespace raft::bench::ann { - template class RaftAnnMG_IvfFlat; - template class RaftAnnMG_IvfFlat; - template class RaftAnnMG_IvfFlat; +template class RaftAnnMG_IvfFlat; +template class RaftAnnMG_IvfFlat; +template class RaftAnnMG_IvfFlat; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp index 5308f07161..02e55423be 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp @@ -17,6 +17,7 @@ #pragma once #include "raft_ann_mg_wrapper.hpp" + #include #include @@ -43,7 +44,8 @@ class RaftAnnMG_IvfFlat : public RaftAnnMG { void build(const T* dataset, size_t nrow) final; void set_search_param(const AnnSearchParam& param) override; - void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; void save(const std::string& file) const override; void load(const std::string&) override; std::unique_ptr> copy() override; @@ -51,16 +53,22 @@ class RaftAnnMG_IvfFlat : public RaftAnnMG { private: BuildParam index_params_; raft::neighbors::ivf_flat::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + std::shared_ptr< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>> + index_; }; template void RaftAnnMG_IvfFlat::build(const T* dataset, size_t nrow) { - const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); - auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view( + dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = + raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); return; } @@ -84,7 +92,9 @@ template void RaftAnnMG_IvfFlat::load(const std::string& file) { const auto& handle = this->clique_->set_current_device_to_root_rank(); - index_ = std::make_shared, T, IdxT>>(std::move(raft::neighbors::mg::deserialize_flat(handle, *this->clique_, file))); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(raft::neighbors::mg::deserialize_flat(handle, *this->clique_, file))); } template @@ -94,17 +104,27 @@ std::unique_ptr> RaftAnnMG_IvfFlat::copy() } template -void RaftAnnMG_IvfFlat::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +void RaftAnnMG_IvfFlat::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); - auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); - auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - - raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + auto query_matrix = raft::make_host_matrix_view( + queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = + raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = + raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, + *this->clique_, + *index_, + search_params_, + query_matrix, + neighbors_matrix, + distances_matrix); resource::sync_stream(handle); return; } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu index 84fdc0da91..b0bf8a9d84 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu @@ -14,15 +14,15 @@ * limitations under the License. */ - #include "raft_ann_mg_ivf_pq_wrapper.hpp" -#include + #include +#include namespace raft::bench::ann { - template class RaftAnnMG_IvfPq; - template class RaftAnnMG_IvfPq; - template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp index 8e7049d971..8c0867462d 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp @@ -17,6 +17,7 @@ #pragma once #include "raft_ann_mg_wrapper.hpp" + #include #include @@ -45,7 +46,8 @@ class RaftAnnMG_IvfPq : public RaftAnnMG { void build(const T* dataset, size_t nrow) final; void set_search_param(const AnnSearchParam& param) override; - void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; void save(const std::string& file) const override; void load(const std::string&) override; std::unique_ptr> copy() override; @@ -53,17 +55,23 @@ class RaftAnnMG_IvfPq : public RaftAnnMG { private: BuildParam index_params_; raft::neighbors::ivf_pq::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + std::shared_ptr< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>> + index_; float refine_ratio_ = 1.0; }; template void RaftAnnMG_IvfPq::build(const T* dataset, size_t nrow) { - const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); - auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view( + dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = + raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); return; } @@ -88,7 +96,9 @@ template void RaftAnnMG_IvfPq::load(const std::string& file) { const auto& handle = this->clique_->set_current_device_to_root_rank(); - index_ = std::make_shared, T, IdxT>>(std::move(raft::neighbors::mg::deserialize_pq(handle, *this->clique_, file))); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(raft::neighbors::mg::deserialize_pq(handle, *this->clique_, file))); } template @@ -98,17 +108,27 @@ std::unique_ptr> RaftAnnMG_IvfPq::copy() } template -void RaftAnnMG_IvfPq::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +void RaftAnnMG_IvfPq::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); - auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); - auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - - raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + auto query_matrix = raft::make_host_matrix_view( + queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = + raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = + raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, + *this->clique_, + *index_, + search_params_, + query_matrix, + neighbors_matrix, + distances_matrix); resource::sync_stream(handle); return; } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp index 0cade61f03..3e642fe0c8 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp @@ -18,48 +18,48 @@ #include "../common/ann_types.hpp" #include "raft_ann_bench_utils.h" + #include namespace raft::bench::ann { template class RaftAnnMG : public ANN, public AnnGPU { + public: + RaftAnnMG(Metric metric, int dim) : ANN(metric, dim), dimension_(dim) + { + this->init_nccl_clique(); + } - public: - RaftAnnMG(Metric metric, int dim) - : ANN(metric, dim), dimension_(dim) - { - this->init_nccl_clique(); - } - - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::HostMmap; - property.query_memory_type = MemoryType::HostMmap; - return property; - } + AlgoProperty get_preference() const override + { + AlgoProperty property; + property.dataset_memory_type = MemoryType::HostMmap; + property.query_memory_type = MemoryType::HostMmap; + return property; + } - private: - void init_nccl_clique() { - int n_devices; - cudaGetDeviceCount(&n_devices); - std::cout << n_devices << " GPUs detected" << std::endl; + private: + void init_nccl_clique() + { + int n_devices; + cudaGetDeviceCount(&n_devices); + std::cout << n_devices << " GPUs detected" << std::endl; - std::vector device_ids(n_devices); - std::iota(device_ids.begin(), device_ids.end(), 0); - clique_ = std::make_shared(device_ids); - } + std::vector device_ids(n_devices); + std::iota(device_ids.begin(), device_ids.end(), 0); + clique_ = std::make_shared(device_ids); + } - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - const auto& handle = clique_->set_current_device_to_root_rank(); - return resource::get_cuda_stream(handle); - } + [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override + { + const auto& handle = clique_->set_current_device_to_root_rank(); + return resource::get_cuda_stream(handle); + } - protected: - std::shared_ptr clique_; - int dimension_; + protected: + std::shared_ptr clique_; + int dimension_; }; -} +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 54cb33fd59..495cf639f8 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -92,24 +92,24 @@ std::unique_ptr> create_algo(const std::string& algo, } #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ -if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { - if (algo == "raft_ann_mg_ivf_pq") { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_pq") { typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam param; parse_build_param(conf, param); ann = std::make_unique>(metric, dim, param); } -} + } #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA -if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { if (algo == "raft_ann_mg_cagra") { typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam param; parse_build_param(conf, param); ann = std::make_unique>(metric, dim, param); } -} + } #endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } @@ -166,7 +166,8 @@ std::unique_ptr::AnnSearchParam> create_search if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { if (algo == "raft_ann_mg_ivf_pq") { - auto param = std::make_unique::SearchParam>(); + auto param = + std::make_unique::SearchParam>(); parse_search_param(conf, *param); return param; } @@ -176,7 +177,8 @@ std::unique_ptr::AnnSearchParam> create_search if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { if (algo == "raft_ann_mg_cagra") { - auto param = std::make_unique::SearchParam>(); + auto param = + std::make_unique::SearchParam>(); parse_search_param(conf, *param); return param; } diff --git a/cpp/include/raft/neighbors/ann_mg_helpers.cuh b/cpp/include/raft/neighbors/ann_mg_helpers.cuh index ebf006050b..4b1a5405a6 100644 --- a/cpp/include/raft/neighbors/ann_mg_helpers.cuh +++ b/cpp/include/raft/neighbors/ann_mg_helpers.cuh @@ -16,14 +16,17 @@ #pragma once -#include -#include -#include -#include #include +#include + +#include + +#include + +#include namespace raft::comms { - void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank); +void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank); } namespace raft::neighbors::mg { @@ -31,7 +34,6 @@ namespace raft::neighbors::mg { using pool_mr = rmm::mr::pool_memory_resource; struct nccl_clique { - nccl_clique(const std::vector& device_ids) : root_rank_(0), num_ranks_(device_ids.size()), @@ -48,7 +50,8 @@ struct nccl_clique { // create a pool memory resource for each device auto old_mr = rmm::mr::get_current_device_resource(); - per_device_pools_.push_back(std::make_unique(old_mr, rmm::percent_of_free_device_memory(80))); + per_device_pools_.push_back( + std::make_unique(old_mr, rmm::percent_of_free_device_memory(80))); rmm::cuda_device_id id(device_ids[rank]); rmm::mr::set_per_device_resource(id, per_device_pools_.back().get()); @@ -56,7 +59,8 @@ struct nccl_clique { device_resources_.emplace_back(); // add NCCL communications to the device resource handle - raft::comms::build_comms_nccl_only(&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); + raft::comms::build_comms_nccl_only( + &device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); } for (int rank = 0; rank < num_ranks_; rank++) { @@ -92,4 +96,4 @@ struct nccl_clique { std::vector device_resources_; }; -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ann_mg_types.hpp b/cpp/include/raft/neighbors/ann_mg_types.hpp index 4242f3fb9f..bdc653bbdb 100644 --- a/cpp/include/raft/neighbors/ann_mg_types.hpp +++ b/cpp/include/raft/neighbors/ann_mg_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,28 +16,28 @@ #pragma once +#include #include #include -#include namespace raft::neighbors::mg { - enum parallel_mode { REPLICATED, SHARDED }; +enum parallel_mode { REPLICATED, SHARDED }; } namespace raft::neighbors::ivf_flat { - struct mg_index_params : raft::neighbors::ivf_flat::index_params { - raft::neighbors::mg::parallel_mode mode; - }; -} +struct mg_index_params : raft::neighbors::ivf_flat::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::ivf_flat namespace raft::neighbors::ivf_pq { - struct mg_index_params : raft::neighbors::ivf_pq::index_params { - raft::neighbors::mg::parallel_mode mode; - }; -} +struct mg_index_params : raft::neighbors::ivf_pq::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::ivf_pq namespace raft::neighbors::cagra { - struct mg_index_params : raft::neighbors::cagra::index_params { - raft::neighbors::mg::parallel_mode mode; - }; -} +struct mg_index_params : raft::neighbors::cagra::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index 60bb2b133f..36941aa6cd 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/cagra_mg.cuh b/cpp/include/raft/neighbors/cagra_mg.cuh index 62be51ca7b..6cf51f580b 100644 --- a/cpp/include/raft/neighbors/cagra_mg.cuh +++ b/cpp/include/raft/neighbors/cagra_mg.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,9 +39,9 @@ void search(const raft::resources& handle, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances, - uint64_t n_rows_per_batch = 1 << 20) // 2^20 + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } - // 2^20 -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh index fd8c4b7667..afb640f8f5 100644 --- a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,19 +31,21 @@ void serialize(const raft::resources& handle, } template -detail::ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> deserialize_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::deserialize_cagra(handle, clique, filename); } template -detail::ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> distribute_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::distribute_cagra(handle, clique, filename); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index fe85a70901..070821e2a6 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,30 +16,28 @@ #pragma once -#include #include #include #include -#include -#include #include +#include +#include +#include #undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include +#include #include #include -#include #define RAFT_EXPLICIT_INSTANTIATE_ONLY -#include +#include #include - -#include #include #include -#include - +#include +#include namespace raft::neighbors::mg::detail { using namespace raft::neighbors; @@ -47,58 +45,63 @@ using namespace raft::neighbors; template class ann_interface { public: - template void build(raft::resources const& handle, const ann::index_params* index_params, raft::mdspan, row_major, Accessor> index_dataset) { if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_flat::build(handle, - *static_cast(index_params), - index_dataset.data_handle(), - index_dataset.extent(0), - index_dataset.extent(1)); + auto idx = + raft::neighbors::ivf_flat::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_pq::build(handle, - *static_cast(index_params), - index_dataset.data_handle(), - index_dataset.extent(0), - index_dataset.extent(1)); + auto idx = + raft::neighbors::ivf_pq::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - auto extents = raft::make_extents(index_dataset.extent(0), index_dataset.extent(1)); + auto extents = raft::make_extents(index_dataset.extent(0), index_dataset.extent(1)); const bool host_acc = decltype(index_dataset)::accessor_type::is_host_type::value; const bool device_acc = decltype(index_dataset)::accessor_type::is_device_type::value; - auto dataset = raft::make_mdspan(index_dataset.data_handle(), extents); + auto dataset = raft::make_mdspan( + index_dataset.data_handle(), extents); cagra::index idx(handle); - idx = raft::neighbors::cagra::build(handle, - *static_cast(index_params), - dataset); + idx = raft::neighbors::cagra::build( + handle, *static_cast(index_params), dataset); index_.emplace(std::move(idx)); } resource::sync_stream(handle); } template - void extend(raft::resources const& handle, - raft::mdspan, row_major, Accessor1> new_vectors, - std::optional, layout_c_contiguous, Accessor2>> new_indices) + void extend( + raft::resources const& handle, + raft::mdspan, row_major, Accessor1> new_vectors, + std::optional, layout_c_contiguous, Accessor2>> + new_indices) { if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_flat::extend(handle, - index_.value(), - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); + auto idx = raft::neighbors::ivf_flat::extend( + handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_pq::extend(handle, - index_.value(), - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); + auto idx = raft::neighbors::ivf_pq::extend( + handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { RAFT_FAIL("CAGRA does not implement the extend method"); @@ -112,39 +115,44 @@ class ann_interface { raft::device_matrix_view d_neighbors, raft::device_matrix_view d_distances) const { - IdxT n_rows = h_query_dataset.extent(0); - IdxT n_dims = h_query_dataset.extent(1); + IdxT n_rows = h_query_dataset.extent(0); + IdxT n_dims = h_query_dataset.extent(1); auto d_query_dataset = raft::make_device_matrix(handle, n_rows, n_dims); - raft::copy(d_query_dataset.data_handle(), h_query_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); + raft::copy(d_query_dataset.data_handle(), + h_query_dataset.data_handle(), + n_rows * n_dims, + resource::get_cuda_stream(handle)); if constexpr (std::is_same>::value) { - raft::runtime::neighbors::ivf_flat::search(handle, - *reinterpret_cast(search_params), - index_.value(), - d_query_dataset.view(), - d_neighbors, - d_distances); + raft::runtime::neighbors::ivf_flat::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); } else if constexpr (std::is_same>::value) { - raft::runtime::neighbors::ivf_pq::search(handle, - *reinterpret_cast(search_params), - index_.value(), - d_query_dataset.view(), - d_neighbors, - d_distances); + raft::runtime::neighbors::ivf_pq::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); } else if constexpr (std::is_same>::value) { - raft::runtime::neighbors::cagra::search(handle, - *reinterpret_cast(search_params), - index_.value(), - d_query_dataset.view(), - d_neighbors, - d_distances); + raft::runtime::neighbors::cagra::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); } resource::sync_stream(handle); } - void serialize(raft::resources const& handle, - std::ostream& os) const - { + void serialize(raft::resources const& handle, std::ostream& os) const + { if constexpr (std::is_same>::value) { ivf_flat::serialize(handle, os, index_.value()); } else if constexpr (std::is_same>::value) { @@ -154,8 +162,7 @@ class ann_interface { } } - void deserialize(raft::resources const& handle, - std::istream& is) + void deserialize(raft::resources const& handle, std::istream& is) { if constexpr (std::is_same>::value) { index_.emplace(std::move(ivf_flat::deserialize(handle, is))); @@ -166,8 +173,7 @@ class ann_interface { } } - void deserialize(raft::resources const& handle, - const std::string& filename) + void deserialize(raft::resources const& handle, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } @@ -201,14 +207,12 @@ class ann_interface { template class ann_mg_index { public: - ann_mg_index(parallel_mode mode, int num_ranks_) - : mode_(mode), - num_ranks_(num_ranks_) - {} + ann_mg_index(parallel_mode mode, int num_ranks_) : mode_(mode), num_ranks_(num_ranks_) {} ann_mg_index(const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) { + const std::string& filename) + { deserialize_mg_index(handle, clique, filename); } @@ -223,7 +227,7 @@ class ann_mg_index { const std::string& filename) { for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_.emplace_back(); @@ -239,11 +243,11 @@ class ann_mg_index { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); + mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); num_ranks_ = deserialize_scalar(handle, is); for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_.emplace_back(); @@ -262,38 +266,39 @@ class ann_mg_index { RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", num_ranks_, n_rows); ann_interfaces_.resize(num_ranks_); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_[rank]; ann_if.build(dev_res, index_params, index_dataset); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } else if (mode_ == SHARDED) { - IdxT n_rows = index_dataset.extent(0); - IdxT n_cols = index_dataset.extent(1); - IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + IdxT n_rows = index_dataset.extent(0); + IdxT n_cols = index_dataset.extent(1); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", num_ranks_, n_rows_per_shard); ann_interfaces_.resize(num_ranks_); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT offset = rank * n_rows_per_shard; - IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view(partition_ptr, n_rows_of_current_shard, n_cols); + IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( + partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = ann_interfaces_[rank]; ann_if.build(dev_res, index_params, partition); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } } @@ -305,42 +310,44 @@ class ann_mg_index { if (mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", num_ranks_, n_rows); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_res, new_vectors, new_indices); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } else if (mode_ == SHARDED) { IdxT n_cols = new_vectors.extent(1); - IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", num_ranks_, n_rows_per_shard); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT offset = rank * n_rows_per_shard; - IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view(new_vectors_ptr, n_rows_of_current_shard, n_cols); + IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( + new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; if (new_indices.has_value()) { const IdxT* new_indices_ptr = new_indices.value().data_handle() + offset; - new_indices_part = raft::make_host_vector_view(new_indices_ptr, n_rows_of_current_shard); + new_indices_part = + raft::make_host_vector_view(new_indices_ptr, n_rows_of_current_shard); } auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_res, new_vectors_part, new_indices_part); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } } @@ -355,35 +362,34 @@ class ann_mg_index { IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); - IdxT n_batches = raft::ceildiv(n_rows, (IdxT)n_rows_per_batch); - if (n_batches == 1) - n_rows_per_batch = n_rows; + IdxT n_batches = raft::ceildiv(n_rows, (IdxT)n_rows_per_batch); + if (n_batches == 1) n_rows_per_batch = n_rows; if (mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED SEARCH: %d*%drows", n_batches, n_rows_per_batch); - #pragma omp parallel for +#pragma omp parallel for for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - int rank = batch_idx % num_ranks_; // alternate GPUs - int dev_id = clique.device_ids_[rank]; + int rank = batch_idx % num_ranks_; // alternate GPUs + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT offset = batch_idx * n_rows_per_batch; - IdxT query_offset = offset * n_cols; - IdxT output_offset = offset * n_neighbors; - IdxT n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset); + IdxT offset = batch_idx * n_rows_per_batch; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; + IdxT n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset); - auto query_partition = raft::make_host_matrix_view(query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); - auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); - auto d_distances = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); + auto query_partition = raft::make_host_matrix_view( + query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); + auto d_neighbors = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_res, - search_params, - query_partition, - d_neighbors.view(), - d_distances.view()); + ann_if.search( + dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); raft::copy(neighbors.data_handle() + output_offset, d_neighbors.data_handle(), @@ -396,7 +402,7 @@ class ann_mg_index { resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } else if (mode_ == SHARDED) { RAFT_LOG_INFO("SHARDED SEARCH: %d*%drows", n_batches, n_rows_per_batch); @@ -405,39 +411,44 @@ class ann_mg_index { root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); auto in_distances = raft::make_device_matrix( root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); - auto out_neighbors = raft::make_device_matrix( - root_handle, n_rows_per_batch, n_neighbors); + auto out_neighbors = + raft::make_device_matrix(root_handle, n_rows_per_batch, n_neighbors); auto out_distances = raft::make_device_matrix( root_handle, n_rows_per_batch, n_neighbors); for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - IdxT offset = batch_idx * n_rows_per_batch; - IdxT query_offset = offset * n_cols; - IdxT output_offset = offset * n_neighbors; + IdxT offset = batch_idx * n_rows_per_batch; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; IdxT n_rows_of_current_batch = std::min((IdxT)n_rows_per_batch, n_rows - offset); - auto query_partition = raft::make_host_matrix_view( + auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); - // should use at least num_ranks_ threads to avoid NCCL hang - #pragma omp parallel for num_threads(num_ranks_) +// should use at least num_ranks_ threads to avoid NCCL hang +#pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; - auto& ann_if = ann_interfaces_[rank]; - const auto& comms = resource::get_comms(dev_res); + auto& ann_if = ann_interfaces_[rank]; + const auto& comms = resource::get_comms(dev_res); RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - if (rank == clique.root_rank_) { // root rank + if (rank == clique.root_rank_) { // root rank uint64_t batch_offset = clique.root_rank_ * n_rows_of_current_batch * n_neighbors; - auto d_neighbors = raft::make_device_matrix_view(in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); - auto d_distances = raft::make_device_matrix_view(in_distances.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); - ann_if.search(dev_res, search_params, query_partition, d_neighbors, d_distances); // write search results inplace + auto d_neighbors = raft::make_device_matrix_view( + in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix_view( + in_distances.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + ann_if.search(dev_res, + search_params, + query_partition, + d_neighbors, + d_distances); // write search results inplace // wait for results of other ranks RAFT_NCCL_TRY(ncclGroupStart()); for (int from_rank = 0; from_rank < num_ranks_; from_rank++) { - if (from_rank == clique.root_rank_) - continue; + if (from_rank == clique.root_rank_) continue; batch_offset = from_rank * n_rows_of_current_batch * n_neighbors; comms.device_recv(in_neighbors.data_handle() + batch_offset, @@ -451,28 +462,31 @@ class ann_mg_index { } RAFT_NCCL_TRY(ncclGroupEnd()); resource::sync_stream(dev_res); - } else { // non-root ranks - auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); - auto d_distances = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); - ann_if.search(dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); - - // send results to root rank - RAFT_NCCL_TRY(ncclGroupStart()); - comms.device_send(d_neighbors.data_handle(), - n_rows_of_current_batch * n_neighbors, - clique.root_rank_, - resource::get_cuda_stream(dev_res)); - comms.device_send(d_distances.data_handle(), - n_rows_of_current_batch * n_neighbors, - clique.root_rank_, - resource::get_cuda_stream(dev_res)); - RAFT_NCCL_TRY(ncclGroupEnd()); - resource::sync_stream(dev_res); - } + } else { // non-root ranks + auto d_neighbors = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + ann_if.search( + dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); + + // send results to root rank + RAFT_NCCL_TRY(ncclGroupStart()); + comms.device_send(d_neighbors.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + comms.device_send(d_distances.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + RAFT_NCCL_TRY(ncclGroupEnd()); + resource::sync_stream(dev_res); + } } - #pragma omp barrier +#pragma omp barrier - auto in_neighbors_view = raft::make_device_matrix_view( + auto in_neighbors_view = raft::make_device_matrix_view( in_neighbors.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); auto in_distances_view = raft::make_device_matrix_view( in_distances.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); @@ -482,15 +496,19 @@ class ann_mg_index { out_distances.data_handle(), n_rows_of_current_batch, n_neighbors); const auto& root_handle_ = clique.set_current_device_to_root_rank(); - auto h_trans = std::vector(num_ranks_); - IdxT translation_offset = 0; + auto h_trans = std::vector(num_ranks_); + IdxT translation_offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { h_trans[rank] = translation_offset; translation_offset += ann_interfaces_[rank].size(); } auto d_trans = raft::make_device_vector(root_handle_, num_ranks_); - raft::copy(d_trans.data_handle(), h_trans.data(), num_ranks_, resource::get_cuda_stream(root_handle_)); - auto translations = std::make_optional>(d_trans.view()); + raft::copy(d_trans.data_handle(), + h_trans.data(), + num_ranks_, + resource::get_cuda_stream(root_handle_)); + auto translations = + std::make_optional>(d_trans.view()); raft::neighbors::brute_force::knn_merge_parts(root_handle_, in_distances_view, in_neighbors_view, @@ -523,11 +541,11 @@ class ann_mg_index { serialize_scalar(handle, of, (int)mode_); serialize_scalar(handle, of, num_ranks_); for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = ann_interfaces_[rank]; - ann_if.serialize(dev_res, of); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_[rank]; + ann_if.serialize(dev_res, of); } of.close(); @@ -606,7 +624,12 @@ void search(const raft::resources& handle, raft::host_matrix_view distances, uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); } template @@ -619,7 +642,12 @@ void search(const raft::resources& handle, raft::host_matrix_view distances, uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); } template @@ -632,7 +660,12 @@ void search(const raft::resources& handle, raft::host_matrix_view distances, uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); } template @@ -663,36 +696,40 @@ void serialize(const raft::resources& handle, } template -ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> deserialize_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(handle, clique, filename); return index; } template -ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> deserialize_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(handle, clique, filename); return index; } template -ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> deserialize_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(handle, clique, filename); return index; } template -ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> distribute_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); @@ -700,9 +737,10 @@ ann_mg_index, T, IdxT> distribute_flat(const raft::reso } template -ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> distribute_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); @@ -710,13 +748,14 @@ ann_mg_index, T, IdxT> distribute_pq(const raft::resources& } template -ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> distribute_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); return index; } -} // namespace raft::neighbors::mg::detail \ No newline at end of file +} // namespace raft::neighbors::mg::detail diff --git a/cpp/include/raft/neighbors/ivf_flat_mg.cuh b/cpp/include/raft/neighbors/ivf_flat_mg.cuh index 12f76babe7..f5170d063a 100644 --- a/cpp/include/raft/neighbors/ivf_flat_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_mg.cuh @@ -49,9 +49,10 @@ void search(const raft::resources& handle, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances, - uint64_t n_rows_per_batch = 1 << 20) // 2^20 + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh index 6f4c9d820f..9459243a5f 100644 --- a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh @@ -31,19 +31,21 @@ void serialize(const raft::resources& handle, } template -detail::ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> deserialize_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::deserialize_flat(handle, clique, filename); } template -detail::ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> distribute_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::distribute_flat(handle, clique, filename); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_pq_mg.cuh b/cpp/include/raft/neighbors/ivf_pq_mg.cuh index 7cad3449d0..104745485a 100644 --- a/cpp/include/raft/neighbors/ivf_pq_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_mg.cuh @@ -49,9 +49,10 @@ void search(const raft::resources& handle, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances, - uint64_t n_rows_per_batch = 1 << 20) // 2^20 + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh index ca6f91e763..56e00576e9 100644 --- a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh @@ -31,19 +31,21 @@ void serialize(const raft::resources& handle, } template -detail::ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> deserialize_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::deserialize_pq(handle, clique, filename); } template -detail::ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> distribute_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::distribute_pq(handle, clique, filename); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/test/neighbors/ann_mg/test_ann_mg.cu b/cpp/test/neighbors/ann_mg/test_ann_mg.cu index b392f629bc..61ea54b808 100644 --- a/cpp/test/neighbors/ann_mg/test_ann_mg.cu +++ b/cpp/test/neighbors/ann_mg/test_ann_mg.cu @@ -1,10 +1,26 @@ -#include +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "../ann_mg.cuh" +#include + namespace raft::neighbors::mg { typedef AnnMGTest AnnMGTestF_float; - TEST_P(AnnMGTestF_float, AnnMG) { this->testAnnMG(); } - INSTANTIATE_TEST_CASE_P(AnnMGTest, AnnMGTestF_float, ::testing::ValuesIn(inputs)); -} +TEST_P(AnnMGTestF_float, AnnMG) { this->testAnnMG(); } +INSTANTIATE_TEST_CASE_P(AnnMGTest, AnnMGTestF_float, ::testing::ValuesIn(inputs)); +} // namespace raft::neighbors::mg diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json b/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json index e7aecfd239..59f6b7c5d8 100644 --- a/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json @@ -1349,4 +1349,5 @@ "search_result_file" : "result/mnist-784-euclidean/raft_cagra/dim64" } ] -} \ No newline at end of file +} + diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json b/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json index 3ca47a2566..5803e9bf7e 100644 --- a/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json @@ -495,4 +495,5 @@ ] } ] -} \ No newline at end of file +} +