diff --git a/.gitignore b/.gitignore index 11b7bc3eba..ca1fc6b922 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,6 @@ _text # clang tooling compile_commands.json .clangd/ + +datasets/ +index/ diff --git a/build.sh b/build.sh index a77dd188f4..2867bbcc24 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_BRUTE_FORCE_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_BRUTE_FORCE_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;NEIGHBORS_ANN_MG_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" @@ -326,6 +326,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"NEIGHBORS_ANN_BRUTE_FORCE_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_MG_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 35df378438..2d4bcde02f 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -31,6 +31,9 @@ option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchm option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include 
raft's brute force knn in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT "Include raft's MG ANN IVF-FLAT in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ "Include raft's MG ANN IVF-PQ in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA "Include raft's MG ANN CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) option(RAFT_ANN_BENCH_SINGLE_EXE @@ -57,6 +60,9 @@ if(BUILD_CPU_ONLY) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF) set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF) + set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT OFF) + set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ OFF) + set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA OFF) set(RAFT_ANN_BENCH_USE_GGNN OFF) endif() @@ -66,6 +72,9 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_CAGRA OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB + OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT + OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -252,6 +261,50 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ) endif() +if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT) + ConfigureAnnBench( + NAME + RAFT_ANN_MG_IVF_FLAT + PATH + src/raft/raft_benchmark.cu + src/raft/raft_ann_mg_ivf_flat.cu + LINKS + raft::compiled + ucp ucs ucxx nccl + ) +endif() + +if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) + ConfigureAnnBench( + NAME + RAFT_ANN_MG_IVF_PQ + PATH + src/raft/raft_benchmark.cu + src/raft/raft_ann_mg_ivf_pq.cu + LINKS + raft::compiled + ucp ucs ucxx nccl + ) +endif() + +if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + ConfigureAnnBench( + NAME + RAFT_ANN_MG_CAGRA + PATH + src/raft/raft_benchmark.cu + src/raft/raft_ann_mg_cagra.cu + LINKS + raft::compiled + ucp ucs 
ucxx nccl + ) +endif() + +set(RAFT_FAISS_TARGETS faiss::faiss) +if(TARGET faiss::faiss_avx2) + set(RAFT_FAISS_TARGETS faiss::faiss_avx2) +endif() + message("RAFT_FAISS_TARGETS: ${RAFT_FAISS_TARGETS}") message("CUDAToolkit_LIBRARY_DIR: ${CUDAToolkit_LIBRARY_DIR}") if(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT) diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 185d54a0a3..d1091d42dc 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -36,6 +36,7 @@ #include #include #include +#include #include namespace raft::bench::ann { diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 48bf1d70d8..da8eab76b5 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -30,7 +30,7 @@ extern template class raft::bench::ann::RaftIvfFlatGpu; extern template class raft::bench::ann::RaftIvfFlatGpu; #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) #include "raft_ivf_pq_wrapper.h" #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ @@ -38,7 +38,8 @@ extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; #endif -#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) #include "raft_cagra_wrapper.h" #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA @@ -47,11 +48,33 @@ extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT +#include "raft_ann_mg_ivf_flat_wrapper.hpp" +extern template class raft::bench::ann::RaftAnnMG_IvfFlat; +extern template class raft::bench::ann::RaftAnnMG_IvfFlat; +extern template class raft::bench::ann::RaftAnnMG_IvfFlat; +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +#include "raft_ann_mg_ivf_pq_wrapper.hpp" +extern template class raft::bench::ann::RaftAnnMG_IvfPq; +extern template class raft::bench::ann::RaftAnnMG_IvfPq; +extern template class raft::bench::ann::RaftAnnMG_IvfPq; +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA +#include "raft_ann_mg_cagra_wrapper.hpp" +extern template class raft::bench::ann::RaftAnnMG_Cagra; +extern template class raft::bench::ann::RaftAnnMG_Cagra; +extern template class raft::bench::ann::RaftAnnMG_Cagra; +#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT +#if 
defined(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT) template void parse_build_param(const nlohmann::json& conf, +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) +#else + typename raft::bench::ann::RaftAnnMG_IvfFlat::BuildParam& param) +#endif { param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -60,17 +83,26 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) +#else + typename raft::bench::ann::RaftAnnMG_IvfFlat::SearchParam& param) +#endif { param.ivf_flat_params.n_probes = conf.at("nprobe"); } #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam& param) +#else typename raft::bench::ann::RaftIvfPQ::BuildParam& param) +#endif { if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -92,7 +124,11 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + typename raft::bench::ann::RaftAnnMG_IvfPq::SearchParam& param) +#else typename raft::bench::ann::RaftIvfPQ::SearchParam& param) +#endif { if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } if (conf.contains("internalDistanceDtype")) { @@ -133,7 +169,8 @@ void parse_search_param(const 
nlohmann::json& conf, } #endif -#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, raft::neighbors::experimental::nn_descent::index_params& param) @@ -180,7 +217,11 @@ nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, template void parse_build_param(const nlohmann::json& conf, +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam& param) +#else typename raft::bench::ann::RaftCagra::BuildParam& param) +#endif { if (conf.contains("graph_degree")) { param.cagra_params.graph_degree = conf.at("graph_degree"); @@ -244,7 +285,11 @@ raft::bench::ann::AllocatorType parse_allocator(std::string mem_type) template void parse_search_param(const nlohmann::json& conf, +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + typename raft::bench::ann::RaftAnnMG_Cagra::SearchParam& param) +#else typename raft::bench::ann::RaftCagra::SearchParam& param) +#endif { if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } @@ -263,12 +308,15 @@ void parse_search_param(const nlohmann::json& conf, THROW("Invalid value for algo: %s", tmp.c_str()); } } + +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) if (conf.contains("graph_memory_type")) { param.graph_mem = parse_allocator(conf.at("graph_memory_type")); } if (conf.contains("internal_dataset_memory_type")) { param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type")); } +#endif // Same ratio as in IVF-PQ param.refine_ratio = conf.value("refine_ratio", 1.0f); } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu new file mode 
100644 index 0000000000..0243529a67 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "raft_ann_mg_cagra_wrapper.hpp" + +#include +#include + +namespace raft::bench::ann { + +template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; + +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp new file mode 100644 index 0000000000..ef3ece2839 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#pragma once
+
+#include "raft_ann_mg_wrapper.hpp"
+
+#include
+#include
+#include
+
+namespace raft::bench::ann {
+
+enum class AllocatorType_t { HostPinned, HostHugePage, Device };
+/** Benchmark wrapper for a CAGRA index sharded across the GPUs of the NCCL clique. */
+template
+class RaftAnnMG_Cagra : public RaftAnnMG {
+ public:
+  using typename ANN::AnnSearchParam;
+
+  struct SearchParam : public AnnSearchParam {
+    raft::neighbors::cagra::search_params p;
+    float refine_ratio          = 1.0f;
+    AllocatorType_t graph_mem   = AllocatorType_t::Device;
+    AllocatorType_t dataset_mem = AllocatorType_t::Device;
+    auto needs_dataset() const -> bool override { return true; }
+  };
+
+  struct BuildParam {
+    raft::neighbors::cagra::mg_index_params cagra_params;
+    std::optional nn_descent_params =
+      std::nullopt;
+    std::optional ivf_pq_refine_rate = std::nullopt;
+    std::optional ivf_pq_build_params = std::nullopt;
+    std::optional ivf_pq_search_params = std::nullopt;
+  };
+
+  RaftAnnMG_Cagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
+    : RaftAnnMG(metric, dim), index_params_(param)
+  {
+    index_params_.cagra_params.add_data_on_build = true;
+    index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED;
+    index_params_.cagra_params.metric = parse_metric_type(metric);
+    // The IVF-PQ build params are optional; dereferencing them while empty is undefined behaviour.
+    auto& pq_build_params = index_params_.ivf_pq_build_params;
+    if (pq_build_params.has_value()) { pq_build_params->metric = parse_metric_type(metric); }
+  }
+
+  void build(const T* dataset, size_t nrow) final;
+  void set_search_param(const AnnSearchParam& param) override;
+  void search(
+    const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
+  void save(const std::string& file) const override;
+  void load(const std::string&) override;
+  std::unique_ptr> copy() override;
+
+ private:
+  BuildParam index_params_;
+  raft::neighbors::cagra::search_params search_params_;
+  std::shared_ptr<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>
+    index_;
+  float refine_ratio_ = 1.0;
+};
+
+template
+void RaftAnnMG_Cagra::build(const T* dataset, size_t nrow)
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  auto dataset_matrix = raft::make_host_matrix_view(
+    dataset, IdxT(nrow), IdxT(this->dimension_));
+  auto idx = raft::neighbors::mg::build(
+    handle, *this->clique_, index_params_.cagra_params, dataset_matrix);
+  index_ = std::make_shared<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>(
+    std::move(idx));
+}
+
+template
+void RaftAnnMG_Cagra::set_search_param(const AnnSearchParam& param)
+{
+  const auto& search_param = dynamic_cast(param);
+  search_params_ = search_param.p;
+  refine_ratio_ = search_param.refine_ratio;
+  // No n_probes/n_lists check here: BuildParam has no n_lists member, unlike the IVF wrappers.
+}
+
+template
+void RaftAnnMG_Cagra::save(const std::string& file) const
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  raft::neighbors::mg::serialize(handle, *this->clique_, *index_, file);
+}
+
+template
+void RaftAnnMG_Cagra::load(const std::string& file)
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  auto idx = raft::neighbors::mg::deserialize_cagra(handle, *this->clique_, file);
+  index_ = std::make_shared<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>(
+    std::move(idx));
+}
+
+template
+std::unique_ptr> RaftAnnMG_Cagra::copy()
+{
+  return std::make_unique>(*this);  // use copy constructor
+}
+
+template
+void RaftAnnMG_Cagra::search(
+  const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  auto query_matrix = raft::make_host_matrix_view(
+    queries, IdxT(batch_size), IdxT(this->dimension_));
+  // The size_t output buffer is viewed in place as IdxT, so IdxT must have the same width.
+  auto neighbors_matrix = raft::make_host_matrix_view(
+    reinterpret_cast<IdxT*>(neighbors), IdxT(batch_size), IdxT(k));
+  auto distances_matrix =
+    raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k));
+
+  raft::neighbors::mg::search(handle,
+                              *this->clique_,
+                              *index_,
+                              search_params_,
+                              query_matrix,
+                              neighbors_matrix,
+                              distances_matrix);
+  resource::sync_stream(handle);
+}
+}  // namespace raft::bench::ann
diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu
new file mode 100644
index 0000000000..a11460dea0
--- /dev/null
+++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "raft_ann_mg_ivf_flat_wrapper.hpp"
+
+#include
+#include
+
+namespace raft::bench::ann {
+
+template class RaftAnnMG_IvfFlat;
+template class RaftAnnMG_IvfFlat;
+template class RaftAnnMG_IvfFlat;
+
+}  // namespace raft::bench::ann
diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp
new file mode 100644
index 0000000000..02e55423be
--- /dev/null
+++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "raft_ann_mg_wrapper.hpp"
+
+#include
+#include
+
+namespace raft::bench::ann {
+
+/** Benchmark wrapper for an IVF-Flat index sharded across the GPUs of the NCCL clique. */
+template
+class RaftAnnMG_IvfFlat : public RaftAnnMG {
+ public:
+  using typename ANN::AnnSearchParam;
+
+  struct SearchParam : public AnnSearchParam {
+    raft::neighbors::ivf_flat::search_params ivf_flat_params;
+  };
+
+  using BuildParam = raft::neighbors::ivf_flat::mg_index_params;
+
+  RaftAnnMG_IvfFlat(Metric metric, int dim, const BuildParam& param)
+    : RaftAnnMG(metric, dim), index_params_(param)
+  {
+    index_params_.metric = parse_metric_type(metric);
+    index_params_.conservative_memory_allocation = true;
+    index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDED;
+  }
+
+  void build(const T* dataset, size_t nrow) final;
+  void set_search_param(const AnnSearchParam& param) override;
+  void search(
+    const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
+  void save(const std::string& file) const override;
+  void load(const std::string&) override;
+  std::unique_ptr> copy() override;
+
+ private:
+  BuildParam index_params_;
+  raft::neighbors::ivf_flat::search_params search_params_;
+  std::shared_ptr<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>
+    index_;
+};
+
+template
+void RaftAnnMG_IvfFlat::build(const T* dataset, size_t nrow)
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  auto dataset_matrix = raft::make_host_matrix_view(
+    dataset, IdxT(nrow), IdxT(this->dimension_));
+  auto idx =
+    raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix);
+  index_ = std::make_shared<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>(
+    std::move(idx));
+}
+
+template
+void RaftAnnMG_IvfFlat::set_search_param(const AnnSearchParam& param)
+{
+  const auto& search_param = dynamic_cast(param);
+  search_params_ = search_param.ivf_flat_params;
+  // Debug-only sanity check: assert() expands to nothing when NDEBUG is defined.
+  assert(search_params_.n_probes <= index_params_.n_lists);
+}
+
+template
+void RaftAnnMG_IvfFlat::save(const std::string& file) const
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  raft::neighbors::mg::serialize(handle, *this->clique_, *index_, file);
+}
+
+template
+void RaftAnnMG_IvfFlat::load(const std::string& file)
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  index_ = std::make_shared<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>(
+    raft::neighbors::mg::deserialize_flat(handle, *this->clique_, file));
+}
+
+template
+std::unique_ptr> RaftAnnMG_IvfFlat::copy()
+{
+  return std::make_unique>(*this);  // use copy constructor
+}
+
+template
+void RaftAnnMG_IvfFlat::search(
+  const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
+{
+  static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
+
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+
+  auto query_matrix = raft::make_host_matrix_view(
+    queries, IdxT(batch_size), IdxT(this->dimension_));
+  // neighbors is viewed in place as IdxT; the static_assert above guarantees equal widths.
+  auto neighbors_matrix = raft::make_host_matrix_view(
+    reinterpret_cast<IdxT*>(neighbors), IdxT(batch_size), IdxT(k));
+  auto distances_matrix =
+    raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k));
+
+  raft::neighbors::mg::search(handle,
+                              *this->clique_,
+                              *index_,
+                              search_params_,
+                              query_matrix,
+                              neighbors_matrix,
+                              distances_matrix);
+  resource::sync_stream(handle);
+}
+}  // namespace raft::bench::ann
diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu
new file mode 100644
index 0000000000..b0bf8a9d84
--- /dev/null
+++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "raft_ann_mg_ivf_pq_wrapper.hpp" + +#include +#include + +namespace raft::bench::ann { + +template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; + +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp new file mode 100644 index 0000000000..8c0867462d --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#pragma once
+
+#include "raft_ann_mg_wrapper.hpp"
+
+#include
+#include
+
+namespace raft::bench::ann {
+
+/** Benchmark wrapper for an IVF-PQ index sharded across the GPUs of the NCCL clique. */
+template
+class RaftAnnMG_IvfPq : public RaftAnnMG {
+ public:
+  using typename ANN::AnnSearchParam;
+
+  struct SearchParam : public AnnSearchParam {
+    raft::neighbors::ivf_pq::search_params pq_param;
+    float refine_ratio = 1.0f;
+    auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; }
+  };
+
+  using BuildParam = raft::neighbors::ivf_pq::mg_index_params;
+
+  RaftAnnMG_IvfPq(Metric metric, int dim, const BuildParam& param)
+    : RaftAnnMG(metric, dim), index_params_(param)
+  {
+    index_params_.metric = parse_metric_type(metric);
+    index_params_.conservative_memory_allocation = true;
+    index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDED;
+  }
+
+  void build(const T* dataset, size_t nrow) final;
+  void set_search_param(const AnnSearchParam& param) override;
+  void search(
+    const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
+  void save(const std::string& file) const override;
+  void load(const std::string&) override;
+  std::unique_ptr> copy() override;
+
+ private:
+  BuildParam index_params_;
+  raft::neighbors::ivf_pq::search_params search_params_;
+  std::shared_ptr<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>
+    index_;
+  float refine_ratio_ = 1.0;
+};
+
+template
+void RaftAnnMG_IvfPq::build(const T* dataset, size_t nrow)
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  auto dataset_matrix = raft::make_host_matrix_view(
+    dataset, IdxT(nrow), IdxT(this->dimension_));
+  auto idx =
+    raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix);
+  index_ = std::make_shared<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>(
+    std::move(idx));
+}
+
+template
+void RaftAnnMG_IvfPq::set_search_param(const AnnSearchParam& param)
+{
+  const auto& search_param = dynamic_cast(param);
+  search_params_ = search_param.pq_param;
+  // NOTE(review): refine_ratio_ is recorded here, but search() never applies a refinement step.
+  refine_ratio_ = search_param.refine_ratio;
+  assert(search_params_.n_probes <= index_params_.n_lists);
+}
+
+template
+void RaftAnnMG_IvfPq::save(const std::string& file) const
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  raft::neighbors::mg::serialize(handle, *this->clique_, *index_, file);
+}
+
+template
+void RaftAnnMG_IvfPq::load(const std::string& file)
+{
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+  index_ = std::make_shared<
+    raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>(
+    raft::neighbors::mg::deserialize_pq(handle, *this->clique_, file));
+}
+
+template
+std::unique_ptr> RaftAnnMG_IvfPq::copy()
+{
+  return std::make_unique>(*this);  // use copy constructor
+}
+
+template
+void RaftAnnMG_IvfPq::search(
+  const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
+{
+  static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
+
+  const auto& handle = this->clique_->set_current_device_to_root_rank();
+
+  auto query_matrix = raft::make_host_matrix_view(
+    queries, IdxT(batch_size), IdxT(this->dimension_));
+  // neighbors is viewed in place as IdxT; the static_assert above guarantees equal widths.
+  auto neighbors_matrix = raft::make_host_matrix_view(
+    reinterpret_cast<IdxT*>(neighbors), IdxT(batch_size), IdxT(k));
+  auto distances_matrix =
+    raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k));
+
+  raft::neighbors::mg::search(handle,
+                              *this->clique_,
+                              *index_,
+                              search_params_,
+                              query_matrix,
+                              neighbors_matrix,
+                              distances_matrix);
+  resource::sync_stream(handle);
+}
+}  // namespace raft::bench::ann
diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp
new file mode 100644
index 0000000000..3e642fe0c8
--- /dev/null
+++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "../common/ann_types.hpp"
+#include "raft_ann_bench_utils.h"
+
+#include
+
+namespace raft::bench::ann {
+
+/** Base of the multi-GPU wrappers: creates an NCCL clique over every CUDA device it detects. */
+template
+class RaftAnnMG : public ANN, public AnnGPU {
+ public:
+  RaftAnnMG(Metric metric, int dim) : ANN(metric, dim), dimension_(dim) { init_nccl_clique(); }
+
+  AlgoProperty get_preference() const override
+  {
+    AlgoProperty property;
+    property.dataset_memory_type = MemoryType::HostMmap;
+    property.query_memory_type = MemoryType::HostMmap;
+    return property;
+  }
+
+ private:
+  void init_nccl_clique()
+  {
+    // Check the device query: on failure n_devices would otherwise be read uninitialised.
+    int n_devices = 0;
+    RAFT_CUDA_TRY(cudaGetDeviceCount(&n_devices));
+    std::cout << n_devices << " GPUs detected" << std::endl;
+
+    std::vector device_ids(n_devices);
+    std::iota(device_ids.begin(), device_ids.end(), 0);
+    clique_ = std::make_shared(device_ids);
+  }
+
+  // NOTE(review): noexcept, yet set_current_device_to_root_rank() uses RAFT_CUDA_TRY - confirm.
+  [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
+  {
+    const auto& handle = clique_->set_current_device_to_root_rank();
+    return resource::get_cuda_stream(handle);
+  }
+
+ protected:
+  std::shared_ptr clique_;
+  int dimension_;
+};
+
+}  // namespace raft::bench::ann
diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu
index 8bb4d9423c..495cf639f8 100644
--- a/cpp/bench/ann/src/raft/raft_benchmark.cu
+++ b/cpp/bench/ann/src/raft/raft_benchmark.cu
@@ -81,7 +81,36 @@ std::unique_ptr> create_algo(const std::string& algo,
     ann =
std::make_unique>(metric, dim, param); } #endif - +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_flat") { + typename raft::bench::ann::RaftAnnMG_IvfFlat::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } + } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_pq") { + typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } + } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_cagra") { + typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } + } +#endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } return ann; @@ -122,6 +151,39 @@ std::unique_ptr::AnnSearchParam> create_search return param; } #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_flat") { + auto param = + std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_pq") { + auto param = + std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_cagra") { + auto param = + std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + } 
+#endif // else throw std::runtime_error("invalid algo: '" + algo + "'"); diff --git a/cpp/include/raft/neighbors/ann_mg_helpers.cuh b/cpp/include/raft/neighbors/ann_mg_helpers.cuh new file mode 100644 index 0000000000..4b1a5405a6 --- /dev/null +++ b/cpp/include/raft/neighbors/ann_mg_helpers.cuh @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +#include + +#include + +namespace raft::comms { +void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank); +} + +namespace raft::neighbors::mg { + +using pool_mr = rmm::mr::pool_memory_resource; + +struct nccl_clique { + nccl_clique(const std::vector& device_ids) + : root_rank_(0), + num_ranks_(device_ids.size()), + device_ids_(device_ids), + nccl_comms_(device_ids.size()), + per_device_pools_(0), + device_resources_(0) + { + RAFT_LOG_INFO("Starting NCCL initialization..."); + RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data())); + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(device_ids[rank])); + + // create a pool memory resource for each device + auto old_mr = rmm::mr::get_current_device_resource(); + per_device_pools_.push_back( + std::make_unique(old_mr, rmm::percent_of_free_device_memory(80))); + rmm::cuda_device_id id(device_ids[rank]); + rmm::mr::set_per_device_resource(id, 
per_device_pools_.back().get()); + + // create a device resource handle for each device + device_resources_.emplace_back(); + + // add NCCL communications to the device resource handle + raft::comms::build_comms_nccl_only( + &device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); + } + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(device_ids[rank])); + resource::sync_stream(device_resources_[rank]); + } + + RAFT_LOG_INFO("NCCL initialization completed"); + } + + const raft::device_resources& set_current_device_to_root_rank() const + { + int root_device_id = device_ids_[root_rank_]; + RAFT_CUDA_TRY(cudaSetDevice(root_device_id)); + return device_resources_[root_rank_]; + } + + ~nccl_clique() + { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(device_ids_[rank]); + ncclCommDestroy(nccl_comms_[rank]); + rmm::cuda_device_id id(device_ids_[rank]); + rmm::mr::set_per_device_resource(id, nullptr); + } + } + + int root_rank_; + int num_ranks_; + std::vector device_ids_; + std::vector nccl_comms_; + std::vector> per_device_pools_; + std::vector device_resources_; +}; + +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ann_mg_types.hpp b/cpp/include/raft/neighbors/ann_mg_types.hpp new file mode 100644 index 0000000000..bdc653bbdb --- /dev/null +++ b/cpp/include/raft/neighbors/ann_mg_types.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::neighbors::mg { +enum parallel_mode { REPLICATED, SHARDED }; +} + +namespace raft::neighbors::ivf_flat { +struct mg_index_params : raft::neighbors::ivf_flat::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::ivf_flat + +namespace raft::neighbors::ivf_pq { +struct mg_index_params : raft::neighbors::ivf_pq::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::ivf_pq + +namespace raft::neighbors::cagra { +struct mg_index_params : raft::neighbors::cagra::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index f955cc8518..36941aa6cd 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -92,7 +92,7 @@ inline void knn_merge_parts( RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), "in_keys and in_values must have the same shape."); RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples, + out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == idx_t(n_samples), "Number of rows in output keys and val matrices must equal number of rows in search matrix."); RAFT_EXPECTS( out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1), diff --git a/cpp/include/raft/neighbors/cagra_mg.cuh b/cpp/include/raft/neighbors/cagra_mg.cuh new file mode 100644 index 0000000000..6cf51f580b --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_mg.cuh @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +auto build(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const cagra::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(handle, clique, index_params, index_dataset); +} + +template +void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const detail::ann_mg_index, T, IdxT>& index, + const cagra::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + uint64_t n_rows_per_batch = 1 << 20) // 2^20 +{ + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); +} +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh new file mode 100644 index 0000000000..afb640f8f5 --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, clique, index, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + return mg::detail::deserialize_cagra(handle, clique, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + return mg::detail::distribute_cagra(handle, clique, filename); +} + +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh new file mode 100644 index 0000000000..070821e2a6 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -0,0 +1,761 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include +#include +#include +#include +#define RAFT_EXPLICIT_INSTANTIATE_ONLY + +#include +#include +#include + +#include +#include +#include + +namespace raft::neighbors::mg::detail { +using namespace raft::neighbors; + +template +class ann_interface { + public: + template + void build(raft::resources const& handle, + const ann::index_params* index_params, + raft::mdspan, row_major, Accessor> index_dataset) + { + if constexpr (std::is_same>::value) { + auto idx = + raft::neighbors::ivf_flat::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); + index_.emplace(std::move(idx)); + } else if constexpr (std::is_same>::value) { + auto idx = + raft::neighbors::ivf_pq::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); + index_.emplace(std::move(idx)); + } else if constexpr (std::is_same>::value) { + auto extents = raft::make_extents(index_dataset.extent(0), index_dataset.extent(1)); + const bool host_acc = decltype(index_dataset)::accessor_type::is_host_type::value; + const bool device_acc = decltype(index_dataset)::accessor_type::is_device_type::value; + auto dataset = raft::make_mdspan( + index_dataset.data_handle(), extents); + cagra::index idx(handle); + idx = raft::neighbors::cagra::build( + handle, *static_cast(index_params), dataset); + index_.emplace(std::move(idx)); + } + resource::sync_stream(handle); + } + + template + void extend( + raft::resources const& handle, + raft::mdspan, row_major, Accessor1> new_vectors, + std::optional, layout_c_contiguous, Accessor2>> + new_indices) + { + if constexpr (std::is_same>::value) { + auto idx = raft::neighbors::ivf_flat::extend( + handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? 
new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); + index_.emplace(std::move(idx)); + } else if constexpr (std::is_same>::value) { + auto idx = raft::neighbors::ivf_pq::extend( + handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); + index_.emplace(std::move(idx)); + } else if constexpr (std::is_same>::value) { + RAFT_FAIL("CAGRA does not implement the extend method"); + } + resource::sync_stream(handle); + } + + void search(raft::resources const& handle, + const ann::search_params* search_params, + raft::host_matrix_view h_query_dataset, + raft::device_matrix_view d_neighbors, + raft::device_matrix_view d_distances) const + { + IdxT n_rows = h_query_dataset.extent(0); + IdxT n_dims = h_query_dataset.extent(1); + auto d_query_dataset = raft::make_device_matrix(handle, n_rows, n_dims); + raft::copy(d_query_dataset.data_handle(), + h_query_dataset.data_handle(), + n_rows * n_dims, + resource::get_cuda_stream(handle)); + + if constexpr (std::is_same>::value) { + raft::runtime::neighbors::ivf_flat::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); + } else if constexpr (std::is_same>::value) { + raft::runtime::neighbors::ivf_pq::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); + } else if constexpr (std::is_same>::value) { + raft::runtime::neighbors::cagra::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); + } + resource::sync_stream(handle); + } + + void serialize(raft::resources const& handle, std::ostream& os) const + { + if constexpr (std::is_same>::value) { + ivf_flat::serialize(handle, os, index_.value()); + } else if constexpr (std::is_same>::value) { + ivf_pq::serialize(handle, os, index_.value()); + } else 
if constexpr (std::is_same>::value) { + cagra::serialize(handle, os, index_.value(), true); + } + } + + void deserialize(raft::resources const& handle, std::istream& is) + { + if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_flat::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_pq::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(cagra::deserialize(handle, is))); + } + } + + void deserialize(raft::resources const& handle, const std::string& filename) + { + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_flat::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_pq::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(cagra::deserialize(handle, is))); + } + + is.close(); + } + + const IdxT size() const + { + if constexpr (std::is_same>::value) { + return index_.value().size(); + } else if constexpr (std::is_same>::value) { + return index_.value().size(); + } else if constexpr (std::is_same>::value) { + return index_.value().size(); + } + } + + private: + std::optional index_; +}; + +template +class ann_mg_index { + public: + ann_mg_index(parallel_mode mode, int num_ranks_) : mode_(mode), num_ranks_(num_ranks_) {} + + ann_mg_index(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) + { + deserialize_mg_index(handle, clique, filename); + } + + ann_mg_index(const ann_mg_index&) = delete; + ann_mg_index(ann_mg_index&&) = default; + auto operator=(const ann_mg_index&) -> ann_mg_index& = delete; + auto operator=(ann_mg_index&&) -> ann_mg_index& = default; + + // local index deserialization and distribution + void 
deserialize_and_distribute(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) + { + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.deserialize(dev_res, filename); + } + } + + // MG index deserialization + void deserialize_mg_index(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) + { + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); + num_ranks_ = deserialize_scalar(handle, is); + + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.deserialize(dev_res, is); + } + + is.close(); + } + + void build(const raft::neighbors::mg::nccl_clique& clique, + const ann::index_params* index_params, + raft::host_matrix_view index_dataset) + { + if (mode_ == REPLICATED) { + IdxT n_rows = index_dataset.extent(0); + RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", num_ranks_, n_rows); + + ann_interfaces_.resize(num_ranks_); +#pragma omp parallel for + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_[rank]; + ann_if.build(dev_res, index_params, index_dataset); + resource::sync_stream(dev_res); + } +#pragma omp barrier + } else if (mode_ == SHARDED) { + IdxT n_rows = index_dataset.extent(0); + IdxT n_cols = 
index_dataset.extent(1); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + + RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", num_ranks_, n_rows_per_shard); + + ann_interfaces_.resize(num_ranks_); +#pragma omp parallel for + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( + partition_ptr, n_rows_of_current_shard, n_cols); + auto& ann_if = ann_interfaces_[rank]; + ann_if.build(dev_res, index_params, partition); + resource::sync_stream(dev_res); + } +#pragma omp barrier + } + } + + void extend(const raft::neighbors::mg::nccl_clique& clique, + raft::host_matrix_view new_vectors, + std::optional> new_indices) + { + IdxT n_rows = new_vectors.extent(0); + if (mode_ == REPLICATED) { + RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", num_ranks_, n_rows); + +#pragma omp parallel for + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_[rank]; + ann_if.extend(dev_res, new_vectors, new_indices); + resource::sync_stream(dev_res); + } +#pragma omp barrier + } else if (mode_ == SHARDED) { + IdxT n_cols = new_vectors.extent(1); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + + RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", num_ranks_, n_rows_per_shard); + +#pragma omp parallel for + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + 
IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( + new_vectors_ptr, n_rows_of_current_shard, n_cols); + + std::optional> new_indices_part = std::nullopt; + if (new_indices.has_value()) { + const IdxT* new_indices_ptr = new_indices.value().data_handle() + offset; + new_indices_part = + raft::make_host_vector_view(new_indices_ptr, n_rows_of_current_shard); + } + auto& ann_if = ann_interfaces_[rank]; + ann_if.extend(dev_res, new_vectors_part, new_indices_part); + resource::sync_stream(dev_res); + } +#pragma omp barrier + } + } + + void search(const raft::neighbors::mg::nccl_clique& clique, + const ann::search_params* search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + IdxT n_rows_per_batch) const + { + IdxT n_rows = query_dataset.extent(0); + IdxT n_cols = query_dataset.extent(1); + IdxT n_neighbors = neighbors.extent(1); + + IdxT n_batches = raft::ceildiv(n_rows, (IdxT)n_rows_per_batch); + if (n_batches == 1) n_rows_per_batch = n_rows; + + if (mode_ == REPLICATED) { + RAFT_LOG_INFO("REPLICATED SEARCH: %d*%drows", n_batches, n_rows_per_batch); + +#pragma omp parallel for + for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { + int rank = batch_idx % num_ranks_; // alternate GPUs + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + + IdxT offset = batch_idx * n_rows_per_batch; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; + IdxT n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset); + + auto query_partition = raft::make_host_matrix_view( + query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); + auto 
d_neighbors = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + + auto& ann_if = ann_interfaces_[rank]; + ann_if.search( + dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); + + raft::copy(neighbors.data_handle() + output_offset, + d_neighbors.data_handle(), + n_rows_of_current_batch * n_neighbors, + resource::get_cuda_stream(dev_res)); + raft::copy(distances.data_handle() + output_offset, + d_distances.data_handle(), + n_rows_of_current_batch * n_neighbors, + resource::get_cuda_stream(dev_res)); + + resource::sync_stream(dev_res); + } +#pragma omp barrier + } else if (mode_ == SHARDED) { + RAFT_LOG_INFO("SHARDED SEARCH: %d*%drows", n_batches, n_rows_per_batch); + + const auto& root_handle = clique.set_current_device_to_root_rank(); + auto in_neighbors = raft::make_device_matrix( + root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); + auto in_distances = raft::make_device_matrix( + root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); + auto out_neighbors = + raft::make_device_matrix(root_handle, n_rows_per_batch, n_neighbors); + auto out_distances = raft::make_device_matrix( + root_handle, n_rows_per_batch, n_neighbors); + + for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { + IdxT offset = batch_idx * n_rows_per_batch; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; + IdxT n_rows_of_current_batch = std::min((IdxT)n_rows_per_batch, n_rows - offset); + auto query_partition = raft::make_host_matrix_view( + query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); + +// should use at least num_ranks_ threads to avoid NCCL hang +#pragma omp parallel for num_threads(num_ranks_) + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; 
+ auto& ann_if = ann_interfaces_[rank]; + const auto& comms = resource::get_comms(dev_res); + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + + if (rank == clique.root_rank_) { // root rank + uint64_t batch_offset = clique.root_rank_ * n_rows_of_current_batch * n_neighbors; + auto d_neighbors = raft::make_device_matrix_view( + in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix_view( + in_distances.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + ann_if.search(dev_res, + search_params, + query_partition, + d_neighbors, + d_distances); // write search results inplace + + // wait for results of other ranks + RAFT_NCCL_TRY(ncclGroupStart()); + for (int from_rank = 0; from_rank < num_ranks_; from_rank++) { + if (from_rank == clique.root_rank_) continue; + + batch_offset = from_rank * n_rows_of_current_batch * n_neighbors; + comms.device_recv(in_neighbors.data_handle() + batch_offset, + n_rows_of_current_batch * n_neighbors, + from_rank, + resource::get_cuda_stream(dev_res)); + comms.device_recv(in_distances.data_handle() + batch_offset, + n_rows_of_current_batch * n_neighbors, + from_rank, + resource::get_cuda_stream(dev_res)); + } + RAFT_NCCL_TRY(ncclGroupEnd()); + resource::sync_stream(dev_res); + } else { // non-root ranks + auto d_neighbors = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + ann_if.search( + dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); + + // send results to root rank + RAFT_NCCL_TRY(ncclGroupStart()); + comms.device_send(d_neighbors.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + comms.device_send(d_distances.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + 
RAFT_NCCL_TRY(ncclGroupEnd()); + resource::sync_stream(dev_res); + } + } +#pragma omp barrier + + auto in_neighbors_view = raft::make_device_matrix_view( + in_neighbors.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); + auto in_distances_view = raft::make_device_matrix_view( + in_distances.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); + auto out_neighbors_view = raft::make_device_matrix_view( + out_neighbors.data_handle(), n_rows_of_current_batch, n_neighbors); + auto out_distances_view = raft::make_device_matrix_view( + out_distances.data_handle(), n_rows_of_current_batch, n_neighbors); + + const auto& root_handle_ = clique.set_current_device_to_root_rank(); + auto h_trans = std::vector(num_ranks_); + IdxT translation_offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + h_trans[rank] = translation_offset; + translation_offset += ann_interfaces_[rank].size(); + } + auto d_trans = raft::make_device_vector(root_handle_, num_ranks_); + raft::copy(d_trans.data_handle(), + h_trans.data(), + num_ranks_, + resource::get_cuda_stream(root_handle_)); + auto translations = + std::make_optional>(d_trans.view()); + raft::neighbors::brute_force::knn_merge_parts(root_handle_, + in_distances_view, + in_neighbors_view, + out_distances_view, + out_neighbors_view, + n_rows_of_current_batch, + translations); + + raft::copy(neighbors.data_handle() + output_offset, + out_neighbors.data_handle(), + n_rows_of_current_batch * n_neighbors, + resource::get_cuda_stream(root_handle_)); + raft::copy(distances.data_handle() + output_offset, + out_distances.data_handle(), + n_rows_of_current_batch * n_neighbors, + resource::get_cuda_stream(root_handle_)); + + resource::sync_stream(root_handle_); + } + } + } + + void serialize(raft::resources const& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) const + { + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open 
file %s", filename.c_str()); } + + serialize_scalar(handle, of, (int)mode_); + serialize_scalar(handle, of, num_ranks_); + for (int rank = 0; rank < num_ranks_; rank++) { + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_[rank]; + ann_if.serialize(dev_res, of); + } + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } + } + + private: + parallel_mode mode_; + int num_ranks_; + std::vector> ann_interfaces_; +}; + +template +ann_mg_index, T, IdxT> build( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ivf_flat::mg_index_params& index_params, + raft::host_matrix_view index_dataset) +{ + ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); + index.build(clique, static_cast(&index_params), index_dataset); + return index; +} + +template +ann_mg_index, T, IdxT> build( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ivf_pq::mg_index_params& index_params, + raft::host_matrix_view index_dataset) +{ + ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); + index.build(clique, static_cast(&index_params), index_dataset); + return index; +} + +template +ann_mg_index, T, IdxT> build( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const cagra::mg_index_params& index_params, + raft::host_matrix_view index_dataset) +{ + ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); + index.build(clique, static_cast(&index_params), index_dataset); + return index; +} + +template +void extend(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + index.extend(clique, new_vectors, new_indices); +} + +template +void extend(const 
raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + index.extend(clique, new_vectors, new_indices); +} + +template +void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + uint64_t n_rows_per_batch) +{ + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); +} + +template +void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ann_mg_index, T, IdxT>& index, + const ivf_pq::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + uint64_t n_rows_per_batch) +{ + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); +} + +template +void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ann_mg_index, T, IdxT>& index, + const cagra::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + uint64_t n_rows_per_batch) +{ + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); +} + +template +void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + index.serialize(handle, clique, filename); +} + +template +void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ann_mg_index, T, IdxT>& index, + const std::string& 
filename) +{ + index.serialize(handle, clique, filename); +} + +template +void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + index.serialize(handle, clique, filename); +} + +template +ann_mg_index, T, IdxT> deserialize_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + auto index = ann_mg_index, T, IdxT>(handle, clique, filename); + return index; +} + +template +ann_mg_index, T, IdxT> deserialize_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + auto index = ann_mg_index, T, IdxT>(handle, clique, filename); + return index; +} + +template +ann_mg_index, T, IdxT> deserialize_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + auto index = ann_mg_index, T, IdxT>(handle, clique, filename); + return index; +} + +template +ann_mg_index, T, IdxT> distribute_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); + index.deserialize_and_distribute(handle, clique, filename); + return index; +} + +template +ann_mg_index, T, IdxT> distribute_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); + index.deserialize_and_distribute(handle, clique, filename); + return index; +} + +template +ann_mg_index, T, IdxT> distribute_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); + index.deserialize_and_distribute(handle, clique, filename); + 
return index; +} + +} // namespace raft::neighbors::mg::detail diff --git a/cpp/include/raft/neighbors/ivf_flat_mg.cuh b/cpp/include/raft/neighbors/ivf_flat_mg.cuh new file mode 100644 index 0000000000..f5170d063a --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat_mg.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +auto build(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ivf_flat::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(handle, clique, index_params, index_dataset); +} + +template +void extend(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(handle, clique, index, new_vectors, new_indices); +} + +template +void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const detail::ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + uint64_t n_rows_per_batch = 1 << 20) // 2^20 +{ + mg::detail::search( + handle, clique, index, 
search_params, query_dataset, neighbors, distances, n_rows_per_batch); +} + +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh new file mode 100644 index 0000000000..9459243a5f --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, clique, index, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + return mg::detail::deserialize_flat(handle, clique, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + return mg::detail::distribute_flat(handle, clique, filename); +} + +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_pq_mg.cuh b/cpp/include/raft/neighbors/ivf_pq_mg.cuh new file mode 100644 index 0000000000..104745485a --- 
/dev/null +++ b/cpp/include/raft/neighbors/ivf_pq_mg.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +auto build(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const ivf_pq::mg_index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(handle, clique, index_params, index_dataset); +} + +template +void extend(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(handle, clique, index, new_vectors, new_indices); +} + +template +void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const detail::ann_mg_index, T, IdxT>& index, + const ivf_pq::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances, + uint64_t n_rows_per_batch = 1 << 20) // 2^20 +{ + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); +} + +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh 
b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh new file mode 100644 index 0000000000..56e00576e9 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, clique, index, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + return mg::detail::deserialize_pq(handle, clique, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) +{ + return mg::detail::distribute_pq(handle, clique, filename); +} + +} // namespace raft::neighbors::mg diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 08541ad135..cbf7859768 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -23,7 +23,7 @@ function(ConfigureTest) set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY NOCUDA) set(oneValueArgs NAME GPUS PERCENT) - set(multiValueArgs PATH TARGETS 
CONFIGURATIONS) + set(multiValueArgs PATH ADDITIONAL_LIBS TARGETS CONFIGURATIONS) cmake_parse_arguments(_RAFT_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) if(NOT DEFINED _RAFT_TEST_GPUS AND NOT DEFINED _RAFT_TEST_PERCENT) @@ -55,6 +55,7 @@ function(ConfigureTest) ${RAFT_CTK_MATH_DEPENDENCIES} $ $ + ${_RAFT_TEST_ADDITIONAL_LIBS} ) set_target_properties( ${TEST_NAME} @@ -447,6 +448,21 @@ if(BUILD_TESTS) 100 ) + ConfigureTest( + NAME + NEIGHBORS_ANN_MG_TEST + PATH + neighbors/ann_mg/test_ann_mg.cu + ADDITIONAL_LIBS + ucp ucs ucxx nccl + LIB + EXPLICIT_INSTANTIATE_ONLY + GPUS + 1 + PERCENT + 100 + ) + ConfigureTest( NAME NEIGHBORS_ANN_NN_DESCENT_TEST diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh new file mode 100644 index 0000000000..284871ef70 --- /dev/null +++ b/cpp/test/neighbors/ann_mg.cuh @@ -0,0 +1,408 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" + +#include + +#include +#include +#include + +#include +#include +#include + +#include + +#include + + +namespace raft::neighbors::mg { + +template +struct AnnMGInputs { + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + IdxT nprobe; + IdxT nlist; + raft::distance::DistanceType metric; + bool adaptive_centers; +}; + +template +class AnnMGTest : public ::testing::TestWithParam> { + public: + AnnMGTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam>::GetParam()), + d_index_dataset(0, stream_), + d_query_dataset(0, stream_), + h_index_dataset(0), + h_query_dataset(0) + { + } + + void testAnnMG() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_naive(queries_size); + std::vector distances_naive(queries_size); + std::vector indices_ann(queries_size); + std::vector distances_ann(queries_size); + std::vector indices_naive_32bits(queries_size); + std::vector indices_ann_32bits(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + raft::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + d_query_dataset.data(), + d_index_dataset.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + int n_devices; + cudaGetDeviceCount(&n_devices); + std::cout << n_devices << " GPUs detected" << std::endl; + std::vector device_ids(n_devices); + std::iota(device_ids.begin(), device_ids.end(), 0); + + uint64_t n_rows_per_batch = 3000; // [3000, 3000, 1000] == 7000 rows + raft::neighbors::mg::nccl_clique clique(device_ids); + + // IVF-Flat + for (parallel_mode d_mode : 
{parallel_mode::REPLICATED, parallel_mode::SHARDED}) { + ivf_flat::mg_index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + index_params.add_data_on_build = false; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + index_params.mode = d_mode; + + ivf_flat::search_params search_params; + search_params.n_probes = ps.nprobe; + + auto index_dataset = raft::make_host_matrix_view( + h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto query_dataset = raft::make_host_matrix_view( + h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view( + indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view( + distances_ann.data(), ps.num_queries, ps.k); + + { + auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::extend(handle_, clique, index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_ivf_flat_index"); + } + auto new_index = raft::neighbors::mg::deserialize_flat(handle_, clique, "./cpp/build/ann_mg_ivf_flat_index"); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + // IVF-PQ + for (parallel_mode d_mode : {parallel_mode::REPLICATED, parallel_mode::SHARDED}) { + ivf_pq::mg_index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.add_data_on_build = false; + 
index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + index_params.mode = d_mode; + + ivf_pq::search_params search_params; + search_params.n_probes = ps.nprobe; + + auto index_dataset = raft::make_host_matrix_view( + h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto query_dataset = raft::make_host_matrix_view( + h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view( + indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view( + distances_ann.data(), ps.num_queries, ps.k); + + { + auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::extend(handle_, clique, index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_ivf_pq_index"); + } + auto new_index = raft::neighbors::mg::deserialize_pq(handle_, clique, "./cpp/build/ann_mg_ivf_pq_index"); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + // CAGRA + for (parallel_mode d_mode : {parallel_mode::REPLICATED, parallel_mode::SHARDED}) { + cagra::mg_index_params index_params; + index_params.add_data_on_build = true; + index_params.intermediate_graph_degree = 128; + index_params.graph_degree = 64; + index_params.build_algo = cagra::graph_build_algo::IVF_PQ; + index_params.nn_descent_niter = 20; + index_params.mode = d_mode; + + cagra::search_params search_params; + + auto index_dataset = raft::make_host_matrix_view( + h_index_dataset.data(), ps.num_db_vecs, 
ps.dim); + auto query_dataset = raft::make_host_matrix_view( + h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view( + indices_ann_32bits.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view( + distances_ann.data(), ps.num_queries, ps.k); + + { + auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_cagra_index"); + } + auto new_index = raft::neighbors::mg::deserialize_cagra(handle_, clique, "./cpp/build/ann_mg_cagra_index"); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive_32bits, + indices_ann_32bits, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann_32bits.begin(), indices_ann_32bits.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + { + ivf_flat::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + + ivf_flat::search_params search_params; + search_params.n_probes = ps.nprobe; + + RAFT_CUDA_TRY(cudaSetDevice(0)); + + { + auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index = raft::runtime::neighbors::ivf_flat::build(handle_, index_params, index_dataset); + ivf_flat::serialize(handle_, "./cpp/build/local_ivf_flat_index", index); + } + + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann.data(), 
ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, clique, "./cpp/build/local_ivf_flat_index"); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + { + ivf_pq::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + + ivf_pq::search_params search_params; + search_params.n_probes = ps.nprobe; + + RAFT_CUDA_TRY(cudaSetDevice(0)); + + { + auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index = raft::runtime::neighbors::ivf_pq::build(handle_, index_params, index_dataset); + ivf_pq::serialize(handle_, "./cpp/build/local_ivf_pq_index", index); + } + + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, clique, "./cpp/build/local_ivf_pq_index"); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / 
static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + { + cagra::index_params index_params; + index_params.intermediate_graph_degree = 128; + index_params.graph_degree = 64; + index_params.build_algo = cagra::graph_build_algo::IVF_PQ; + index_params.nn_descent_niter = 20; + + cagra::search_params search_params; + + RAFT_CUDA_TRY(cudaSetDevice(0)); + + { + auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index = raft::runtime::neighbors::cagra::build(handle_, index_params, index_dataset); + raft::neighbors::cagra::serialize(handle_, "./cpp/build/local_cagra_index", index); + } + + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann_32bits.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, clique, "./cpp/build/local_cagra_index"); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive_32bits, + indices_ann_32bits, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann_32bits.begin(), indices_ann_32bits.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + } + + void SetUp() override + { + d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_); + d_query_dataset.resize(ps.num_queries * ps.dim, stream_); + 
h_index_dataset.resize(ps.num_db_vecs * ps.dim); + h_query_dataset.resize(ps.num_queries * ps.dim); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, d_query_dataset.data(), d_query_dataset.size(), DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, d_query_dataset.data(), d_query_dataset.size(), DataT(1), DataT(20)); + } + + raft::copy(h_index_dataset.data(), + d_index_dataset.data(), + d_index_dataset.size(), + resource::get_cuda_stream(handle_)); + raft::copy(h_query_dataset.data(), + d_query_dataset.data(), + d_query_dataset.size(), + resource::get_cuda_stream(handle_)); + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + h_index_dataset.clear(); + h_query_dataset.clear(); + d_index_dataset.resize(0, stream_); + d_query_dataset.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnMGInputs ps; + std::vector h_index_dataset; + std::vector h_query_dataset; + rmm::device_uvector d_index_dataset; + rmm::device_uvector d_query_dataset; +}; + +const std::vector> inputs = { + {7000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, +}; +} // namespace raft::neighbors::mg diff --git a/cpp/test/neighbors/ann_mg/test_ann_mg.cu b/cpp/test/neighbors/ann_mg/test_ann_mg.cu new file mode 100644 index 0000000000..61ea54b808 --- /dev/null +++ b/cpp/test/neighbors/ann_mg/test_ann_mg.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_mg.cuh" + +#include + +namespace raft::neighbors::mg { + +typedef AnnMGTest AnnMGTestF_float; +TEST_P(AnnMGTestF_float, AnnMG) { this->testAnnMG(); } +INSTANTIATE_TEST_CASE_P(AnnMGTest, AnnMGTestF_float, ::testing::ValuesIn(inputs)); +} // namespace raft::neighbors::mg diff --git a/docs/source/ann_benchmarks_build.md b/docs/source/ann_benchmarks_build.md index 56af8e555c..3c841c6313 100644 --- a/docs/source/ann_benchmarks_build.md +++ b/docs/source/ann_benchmarks_build.md @@ -47,5 +47,6 @@ Available targets to use with `--limit-bench-ann` are: - RAFT_CAGRA_ANN_BENCH - RAFT_IVF_PQ_ANN_BENCH - RAFT_IVF_FLAT_ANN_BENCH +- RAFT_ANN_MG_ANN_BENCH By default, the `*_ANN_BENCH` executables program infer the dataset's datatype from the filename's extension. For example, an extension of `fbin` uses a `float` datatype, `f16bin` uses a `float16` datatype, extension of `i8bin` uses `int8_t` datatype, and `u8bin` uses `uint8_t` type. Currently, only `float`, `float16`, int8_t`, and `unit8_t` are supported. 
\ No newline at end of file diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/algos.yaml b/python/raft-ann-bench/src/raft_ann_bench/run/algos.yaml index e382bdcba6..e63b9226b5 100644 --- a/python/raft-ann-bench/src/raft_ann_bench/run/algos.yaml +++ b/python/raft-ann-bench/src/raft_ann_bench/run/algos.yaml @@ -31,6 +31,15 @@ raft_cagra: raft_brute_force: executable: RAFT_BRUTE_FORCE_ANN_BENCH requires_gpu: true +raft_ann_mg_ivf_flat: + executable: RAFT_ANN_MG_IVF_FLAT_ANN_BENCH + requires_gpu: true +raft_ann_mg_ivf_pq: + executable: RAFT_ANN_MG_IVF_PQ_ANN_BENCH + requires_gpu: true +raft_ann_mg_cagra: + executable: RAFT_ANN_MG_CAGRA_ANN_BENCH + requires_gpu: true ggnn: executable: GGNN_ANN_BENCH requires_gpu: true diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_cagra.yaml b/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_cagra.yaml new file mode 100644 index 0000000000..232010ec0e --- /dev/null +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_cagra.yaml @@ -0,0 +1,13 @@ +name: raft_ann_mg_cagra +constraints: + build: raft-ann-bench.constraints.raft_cagra_build_constraints + search: raft-ann-bench.constraints.raft_cagra_search_constraints +groups: + base: + build: + graph_degree: [32, 64, 128, 256] + intermediate_graph_degree: [32, 64, 96, 128] + graph_build_algo: ["NN_DESCENT"] + search: + itopk: [32, 64, 128, 256, 512] + search_width: [1, 2, 4, 8, 16, 32, 64] diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_ivf_flat.yaml b/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_ivf_flat.yaml new file mode 100644 index 0000000000..760bb70ed8 --- /dev/null +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_ivf_flat.yaml @@ -0,0 +1,9 @@ +name: raft_ann_mg_ivf_flat +groups: + base: + build: + nlist: [1024, 2048, 4096, 8192, 16384, 32000, 64000] + ratio: [1, 2, 4] + niter: [20, 25] + search: + nprobe: [1, 5, 10, 50, 
100, 200, 500, 1000, 2000] diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_ivf_pq.yaml b/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_ivf_pq.yaml new file mode 100644 index 0000000000..9f5979912c --- /dev/null +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/algos/raft_ann_mg_ivf_pq.yaml @@ -0,0 +1,17 @@ +name: raft_ann_mg_ivf_pq +constraints: + build: raft-ann-bench.constraints.raft_ivf_pq_build_constraints + search: raft-ann-bench.constraints.raft_ivf_pq_search_constraints +groups: + base: + build: + nlist: [1024, 2048, 4096, 8192] + pq_dim: [64, 32] + pq_bits: [8, 6, 5, 4] + ratio: [10, 25] + niter: [25] + search: + nprobe: [1, 5, 10, 50, 100, 200] + internalDistanceDtype: ["float"] + smemLutDtype: ["float", "fp8", "half"] + refine_ratio: [1, 2, 4] diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json b/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json index 04e7ecb469..59f6b7c5d8 100644 --- a/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json @@ -1350,3 +1350,4 @@ } ] } + diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json b/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json index 791261251a..5803e9bf7e 100644 --- a/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json @@ -496,3 +496,4 @@ } ] } +