From 3b74685d29e741527c7b9ffa96ba9f14354a7e73 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 14 Nov 2023 19:23:03 +0100 Subject: [PATCH 01/22] SNMG ANN --- build.sh | 3 +- cpp/CMakeLists.txt | 3 + cpp/include/raft/neighbors/ann_mg-ext.cuh | 114 +++++++ cpp/include/raft/neighbors/ann_mg-inl.cuh | 80 +++++ cpp/include/raft/neighbors/ann_mg.cuh | 24 ++ cpp/include/raft/neighbors/detail/ann_mg.cuh | 328 +++++++++++++++++++ cpp/src/neighbors/ann_mg.cu | 63 ++++ cpp/test/CMakeLists.txt | 13 + cpp/test/neighbors/ann_mg.cuh | 197 +++++++++++ cpp/test/neighbors/ann_mg/test_ann_mg.cu | 10 + 10 files changed, 834 insertions(+), 1 deletion(-) create mode 100644 cpp/include/raft/neighbors/ann_mg-ext.cuh create mode 100644 cpp/include/raft/neighbors/ann_mg-inl.cuh create mode 100644 cpp/include/raft/neighbors/ann_mg.cuh create mode 100644 cpp/include/raft/neighbors/detail/ann_mg.cuh create mode 100644 cpp/src/neighbors/ann_mg.cu create mode 100644 cpp/test/neighbors/ann_mg.cuh create mode 100644 cpp/test/neighbors/ann_mg/test_ann_mg.cu diff --git a/build.sh b/build.sh index 51e59cc259..d715c33e78 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;NEIGHBORS_ANN_MG_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" @@ -325,6 +325,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then $CMAKE_TARGET == *"MATRIX_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \ + $CMAKE_TARGET == *"NEIGHBORS_ANN_MG_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \ $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \ $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 5d2864e2e0..367ac7bb05 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -403,6 +403,7 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/refine_float_float.cu src/neighbors/refine_int8_t_float.cu src/neighbors/refine_uint8_t_float.cu + src/neighbors/ann_mg.cu src/raft_runtime/cluster/cluster_cost.cuh src/raft_runtime/cluster/cluster_cost_double.cu src/raft_runtime/cluster/cluster_cost_float.cu @@ -489,6 +490,8 @@ if(RAFT_COMPILE_LIBRARY) ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this # will just be cublas $ + nccl + ucp ) # So consumers know when using libraft.so/libraft.a diff --git a/cpp/include/raft/neighbors/ann_mg-ext.cuh b/cpp/include/raft/neighbors/ann_mg-ext.cuh new file mode 100644 index 0000000000..29750f9cb6 --- /dev/null +++ b/cpp/include/raft/neighbors/ann_mg-ext.cuh @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::mg { + using namespace raft::neighbors::mg; + + template + auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_flat::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> RAFT_EXPLICIT; + + template + auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, uint32_t> RAFT_EXPLICIT; + + template + void extend(detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + raft::host_matrix_view new_indices) RAFT_EXPLICIT; + + template + void extend(detail::ann_mg_index, T, uint32_t>& index, + raft::host_matrix_view new_vectors, + raft::host_matrix_view new_indices) RAFT_EXPLICIT; + + template + void search(detail::ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + IdxT n_neighbors, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) RAFT_EXPLICIT; + + template + void search(detail::ann_mg_index, T, uint32_t>& index, + const ivf_pq::search_params& search_params, + uint32_t n_neighbors, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) RAFT_EXPLICIT; + +} + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ann_mg_build(T, IdxT) \ + extern template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_flat::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + -> detail::ann_mg_index, T, IdxT>; \ + \ + extern template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_pq::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + -> detail::ann_mg_index, T, uint32_t>; \ + \ + extern template void raft::neighbors::mg::extend( \ + detail::ann_mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + raft::host_matrix_view new_indices); \ + \ + extern template void raft::neighbors::mg::extend( \ + detail::ann_mg_index, T, uint32_t>& index, \ + raft::host_matrix_view new_vectors, \ + raft::host_matrix_view new_indices); \ + \ + extern template void raft::neighbors::mg::search( \ + detail::ann_mg_index, T, IdxT>& index, \ + const ivf_flat::search_params& search_params, \ + IdxT n_neighbors, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + \ + extern template void raft::neighbors::mg::search( \ + detail::ann_mg_index, T, uint32_t>& index, \ + const ivf_pq::search_params& search_params, \ + uint32_t n_neighbors, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + +instantiate_raft_neighbors_ann_mg_build(float, uint32_t); + +#undef instantiate_raft_neighbors_ann_mg_build diff --git a/cpp/include/raft/neighbors/ann_mg-inl.cuh b/cpp/include/raft/neighbors/ann_mg-inl.cuh new file mode 100644 index 0000000000..c36ea0053e --- /dev/null +++ b/cpp/include/raft/neighbors/ann_mg-inl.cuh @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::neighbors::mg { + + template + auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_flat::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> + { + return mg::detail::build(device_ids, mode, index_params, index_dataset); + } + + template + auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, uint32_t> + { + return mg::detail::build(device_ids, mode, index_params, index_dataset); + } + + template + void extend(detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + raft::host_matrix_view new_indices) + { + mg::detail::extend(index, new_vectors, new_indices); + } + + template + void extend(detail::ann_mg_index, T, uint32_t>& index, + raft::host_matrix_view new_vectors, + raft::host_matrix_view new_indices) + { + mg::detail::extend(index, new_vectors, new_indices); + } + + template + void search(detail::ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + IdxT n_neighbors, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) + { + mg::detail::search(index, search_params, n_neighbors, query_dataset, neighbors, distances); + } + + template + void search(detail::ann_mg_index, T, uint32_t>& index, + const ivf_pq::search_params& search_params, + uint32_t n_neighbors, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) + { + mg::detail::search(index, search_params, n_neighbors, query_dataset, neighbors, distances); + } +} \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ann_mg.cuh b/cpp/include/raft/neighbors/ann_mg.cuh new file mode 100644 index 0000000000..2b22be273d --- /dev/null +++ b/cpp/include/raft/neighbors/ann_mg.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "ann_mg-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ann_mg-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh new file mode 100644 index 0000000000..a3ee05b99f --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include +#include +#define RAFT_EXPLICIT_INSTANTIATE_ONLY + + +namespace raft::neighbors::mg { + enum dist_mode { SHARDING, INDEX_DUPLICATION }; +} + +namespace raft::neighbors::mg::detail { + using namespace raft::neighbors; + + template + class ann_interface { + public: + void build(raft::resources const& handle, + const ann::index_params* index_params, + raft::host_matrix_view h_index_dataset) { + auto index_dataset_view = store_to_device(handle, index_dataset_, h_index_dataset); + + if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_flat::build(handle, + *static_cast(index_params), + index_dataset_view))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_pq::build(handle, + *static_cast(index_params), + index_dataset_view))); + } + } + + void extend(raft::resources const& handle, + raft::host_matrix_view h_new_vectors, + raft::host_matrix_view h_new_indices) { + auto new_vectors_view = store_to_device(handle, new_vectors_, h_new_vectors); + auto new_indices_view = store_to_device(handle, new_indices_, h_new_indices); + auto new_indices_vector_view = \ + raft::make_device_vector_view(new_indices_view.data_handle(), new_indices_view.extent(0)); + + std::optional> new_indices_opt = + std::make_optional>(new_indices_vector_view); + + if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_flat::extend(handle, + new_vectors_view, + new_indices_opt, + index_.value()))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_pq::extend(handle, + new_vectors_view, + new_indices_opt, + index_.value()))); + } + } + + void search(raft::resources const& handle, + const ann::search_params* search_params, + IdxT n_neighbors, + raft::host_matrix_view h_query_dataset, + raft::host_matrix_view h_neighbors, + raft::host_matrix_view h_distances) { + auto query_dataset_view = store_to_device(handle, query_dataset_, h_query_dataset); + IdxT n_rows = h_query_dataset.extent(0); + auto neighbors_view = neighbors_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))); + auto distances_view = distances_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))); + + if constexpr (std::is_same>::value) { + ivf_flat::search(handle, + *reinterpret_cast(search_params), + index_.value(), + query_dataset_view, + neighbors_view, + distances_view); + } else if constexpr (std::is_same>::value) { + ivf_pq::search(handle, + *reinterpret_cast(search_params), + index_.value(), + query_dataset_view, + neighbors_view, + distances_view); + } + } + + private: + template + raft::device_matrix_view store_to_device(raft::resources const& handle, + std::optional>& dev_mat_opt, + raft::host_matrix_view host_mat_view) { + DataIdxT n_rows = host_mat_view.extent(0); + DataIdxT n_cols = host_mat_view.extent(1); + dev_mat_opt.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_cols))); + raft::copy(dev_mat_opt.value().data_handle(), // async copy + host_mat_view.data_handle(), + n_rows * n_cols, + resource::get_cuda_stream(handle)); + auto const_dev_mat_view = dev_mat_opt.value().view(); + raft::device_matrix_view dev_mat_view = \ + raft::make_device_matrix_view(const_dev_mat_view.data_handle(), + const_dev_mat_view.extent(0), + const_dev_mat_view.extent(1)); + return dev_mat_view; + } + + std::optional> index_dataset_; + std::optional> new_vectors_; + std::optional> new_indices_; + std::optional> query_dataset_; + std::optional> neighbors_; + std::optional> distances_; + std::optional index_; + }; + + template + class ann_mg_index { + public: + ann_mg_index() = delete; + ann_mg_index(const std::vector& dev_list, + dist_mode mode = SHARDING) + : mode_(mode), + num_ranks_(dev_list.size()), + dev_ids_(dev_list), + nccl_comms_(dev_list.size()) + { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + + raft::resources& handle = dev_resources_.emplace_back(); + raft::comms::build_comms_nccl_only(&handle, nccl_comms_[rank], num_ranks_, rank); + } + ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data()); + } + + ~ann_mg_index() { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + ncclCommDestroy(nccl_comms_[rank]); + } + } + + ann_mg_index(const ann_mg_index&) = delete; + ann_mg_index(ann_mg_index&&) = default; + auto operator=(const ann_mg_index&) -> ann_mg_index& = delete; + auto operator=(ann_mg_index&&) -> ann_mg_index& = default; + + void build(const ann::index_params* index_params, + raft::host_matrix_view index_dataset) { + if (mode_ == INDEX_DUPLICATION) { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.build(dev_resources_[rank], index_params, index_dataset); + } + } else if (mode_ == SHARDING) { + IdxT n_rows = index_dataset.extent(0); + IdxT n_cols = index_dataset.extent(1); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + offset; + auto partition = raft::make_host_matrix_view(partition_ptr, n_rows_per_shard, n_cols); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.build(dev_resources_[rank], index_params, partition); + offset += n_rows_per_shard * n_cols; + } + } + } + + void extend(raft::host_matrix_view new_vectors, + raft::host_matrix_view new_indices) { + if (mode_ == INDEX_DUPLICATION) { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + auto& ann_if = ann_interfaces_[rank]; + ann_if.extend(dev_resources_[rank], new_vectors, new_indices); + } + } else if (mode_ == SHARDING) { + IdxT n_rows = new_vectors.extent(0); + IdxT n_cols = new_vectors.extent(1); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + offset; + const IdxT* new_indices_ptr = new_indices.data_handle() + offset; + auto new_vectors_part = raft::make_host_matrix_view(new_vectors_ptr, n_rows_per_shard, n_cols); + auto new_indices_part = raft::make_host_matrix_view(new_indices_ptr, n_rows_per_shard, 1); + auto& ann_if = ann_interfaces_[rank]; + ann_if.extend(dev_resources_[rank], new_vectors_part, new_indices_part); + offset += n_rows_per_shard * n_cols; + } + } + } + + void search(const ann::search_params* search_params, + IdxT n_neighbors, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) + { + if (mode_ == INDEX_DUPLICATION) { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + auto& ann_if = ann_interfaces_[rank]; + ann_if.search(dev_resources_[rank], search_params, n_neighbors, query_dataset, neighbors, distances); + } + } else if (mode_ == SHARDING) { + IdxT n_rows = query_dataset.extent(0); + IdxT n_cols = query_dataset.extent(1); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* query_dataset_ptr = query_dataset.data_handle() + offset; + auto query_dataset_part = raft::make_host_matrix_view(query_dataset_ptr, n_rows_per_shard, n_cols); + auto& ann_if = ann_interfaces_[rank]; + //ann_if.search(dev_resources_[rank], search_params, n_neighbors, query_dataset_part, neighbors, distances); + offset += n_rows_per_shard * n_cols; + } + } + } + + private: + dist_mode mode_; + int num_ranks_; + std::vector dev_ids_; + std::vector dev_resources_; + std::vector> ann_interfaces_; + std::vector nccl_comms_; + }; + + template + ann_mg_index, T, IdxT> build(const std::vector device_ids, + dist_mode mode, + const ivf_flat::index_params& index_params, + raft::host_matrix_view index_dataset) + { + ann_mg_index, T, IdxT> index(device_ids, mode); + index.build(static_cast(&index_params), index_dataset); + return index; + } + + template + ann_mg_index, T, uint32_t> build(const std::vector device_ids, + dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) + { + ann_mg_index, T, uint32_t> index(device_ids, mode); + index.build(static_cast(&index_params), index_dataset); + return index; + } + + template + void extend(ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + raft::host_matrix_view new_indices) + { + index.extend(new_vectors, new_indices); + } + + template + void extend(ann_mg_index, T, uint32_t>& index, + raft::host_matrix_view new_vectors, + raft::host_matrix_view new_indices) + { + index.extend(new_vectors, new_indices); + } + + template + void search(ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + IdxT n_neighbors, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) + { + index.search(static_cast(&search_params), + n_neighbors, + query_dataset, + neighbors, + distances); + } + + template + void search(ann_mg_index, T, uint32_t>& index, + const ivf_pq::search_params& search_params, + uint32_t n_neighbors, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) + { + index.search(static_cast(&search_params), + n_neighbors, + query_dataset, + neighbors, + distances); + } + +} \ No newline at end of file diff --git a/cpp/src/neighbors/ann_mg.cu b/cpp/src/neighbors/ann_mg.cu new file mode 100644 index 0000000000..1a7ad67dc5 --- /dev/null +++ b/cpp/src/neighbors/ann_mg.cu @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include + +#define instantiate_raft_neighbors_ann_mg(T, IdxT) \ + template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_flat::index_params& index_params, \ + raft::host_matrix_view index_dataset \ + ) -> raft::neighbors::mg::detail::ann_mg_index, T, IdxT>; \ + \ + template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_pq::index_params& index_params, \ + raft::host_matrix_view index_dataset \ + ) -> raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>; \ + \ + template void raft::neighbors::mg::extend( \ + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>& index, \ + raft::host_matrix_view new_vectors, \ + raft::host_matrix_view new_indices); \ + \ + template void raft::neighbors::mg::extend( \ + raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>& index, \ + raft::host_matrix_view new_vectors, \ + raft::host_matrix_view new_indices); \ + \ + template void raft::neighbors::mg::search( \ + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>& index, \ + const ivf_flat::search_params& search_params, \ + IdxT n_neighbors, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + \ + template void raft::neighbors::mg::search( \ + raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>& index, \ + const ivf_pq::search_params& search_params, \ + uint32_t n_neighbors, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + +instantiate_raft_neighbors_ann_mg(float, uint32_t); + +#undef instantiate_raft_neighbors_ann_mg diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6c03da8d7f..3a214a657c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -397,6 +397,19 @@ if(BUILD_TESTS) 100 ) + ConfigureTest( + NAME + NEIGHBORS_ANN_MG_TEST + PATH + test/neighbors/ann_mg/test_ann_mg.cu + LIB + EXPLICIT_INSTANTIATE_ONLY + GPUS + 1 + PERCENT + 100 + ) + ConfigureTest( NAME NEIGHBORS_ANN_NN_DESCENT_TEST diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh new file mode 100644 index 0000000000..b2686c9e35 --- /dev/null +++ b/cpp/test/neighbors/ann_mg.cuh @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include + +#include +#include +#include + +namespace raft::neighbors::mg { + +template +struct AnnMGInputs { + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + IdxT nprobe; + IdxT nlist; + raft::distance::DistanceType metric; + bool adaptive_centers; +}; + +template +class AnnMGTest : public ::testing::TestWithParam> { + public: + AnnMGTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam>::GetParam()), + d_index_dataset(0, stream_), + d_query_dataset(0, stream_), + h_index_dataset(0), + h_query_dataset(0) + { + } + + void testAnnMG() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector distances_ivfflat(queries_size); + std::vector distances_naive(queries_size); + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + d_query_dataset.data(), + d_index_dataset.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); + rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); + + std::vector device_ids{0, 1}; + raft::neighbors::mg::dist_mode mode = SHARDING; + ivf_flat::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + index_params.add_data_on_build = false; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + auto index_dataset_view = raft::make_host_matrix_view(h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index = raft::neighbors::mg::build(device_ids, mode, index_params, index_dataset_view); + + update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); + update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfflat, + distances_naive, + distances_ivfflat, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } + + void SetUp() override + { + d_index_dataset.resize(ps.num_db_vecs * ps.dim, stream_); + d_query_dataset.resize(ps.num_queries * ps.dim, stream_); + h_index_dataset.resize(ps.num_db_vecs * ps.dim); + h_query_dataset.resize(ps.num_queries * ps.dim); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, d_index_dataset.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, d_query_dataset.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, d_index_dataset.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, d_query_dataset.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + + raft::copy(h_index_dataset.data(), + d_index_dataset.data(), + d_index_dataset.size(), + resource::get_cuda_stream(handle_)); + raft::copy(h_query_dataset.data(), + d_query_dataset.data(), + d_query_dataset.size(), + resource::get_cuda_stream(handle_)); + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + h_index_dataset.clear(); + h_query_dataset.clear(); + d_index_dataset.resize(0, stream_); + d_query_dataset.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnMGInputs ps; + std::vector h_index_dataset; + std::vector h_query_dataset; + rmm::device_uvector d_index_dataset; + rmm::device_uvector d_query_dataset; +}; + +const std::vector> inputs = { + {1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, +};} \ No newline at end of file diff --git a/cpp/test/neighbors/ann_mg/test_ann_mg.cu b/cpp/test/neighbors/ann_mg/test_ann_mg.cu new file mode 100644 index 0000000000..c34cdabb36 --- /dev/null +++ b/cpp/test/neighbors/ann_mg/test_ann_mg.cu @@ -0,0 +1,10 @@ +#include + +#include "../ann_mg.cuh" + +namespace raft::neighbors::mg { + +typedef AnnMGTest AnnMGTestF_float; + TEST_P(AnnMGTestF_float, AnnMG) { this->testAnnMG(); } + INSTANTIATE_TEST_CASE_P(AnnMGTest, AnnMGTestF_float, ::testing::ValuesIn(inputs)); +} From c655d91f5db8a61d7ed5d39458b973177205ddd7 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 15 Nov 2023 18:48:46 +0100 Subject: [PATCH 02/22] Complete main parts and add tests --- cpp/include/raft/neighbors/ann_mg-ext.cuh | 4 - cpp/include/raft/neighbors/ann_mg-inl.cuh | 6 +- cpp/include/raft/neighbors/detail/ann_mg.cuh | 102 ++++++++++---- cpp/src/neighbors/ann_mg.cu | 2 - cpp/test/neighbors/ann_mg.cuh | 135 +++++++++---------- 5 files changed, 146 insertions(+), 103 deletions(-) diff --git a/cpp/include/raft/neighbors/ann_mg-ext.cuh b/cpp/include/raft/neighbors/ann_mg-ext.cuh index 29750f9cb6..2032166a55 100644 --- a/cpp/include/raft/neighbors/ann_mg-ext.cuh +++ b/cpp/include/raft/neighbors/ann_mg-ext.cuh @@ -51,7 +51,6 @@ namespace raft::neighbors::mg { template void search(detail::ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, - IdxT n_neighbors, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) RAFT_EXPLICIT; @@ -59,7 +58,6 @@ namespace raft::neighbors::mg { template void search(detail::ann_mg_index, T, uint32_t>& index, const ivf_pq::search_params& search_params, - uint32_t n_neighbors, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) RAFT_EXPLICIT; @@ -96,7 +94,6 @@ namespace raft::neighbors::mg { extern template void raft::neighbors::mg::search( \ detail::ann_mg_index, T, IdxT>& index, \ const ivf_flat::search_params& search_params, \ - IdxT n_neighbors, \ raft::host_matrix_view query_dataset, \ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances); \ @@ -104,7 +101,6 @@ namespace raft::neighbors::mg { extern template void raft::neighbors::mg::search( \ detail::ann_mg_index, T, uint32_t>& index, \ const ivf_pq::search_params& search_params, \ - uint32_t n_neighbors, \ raft::host_matrix_view query_dataset, \ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances); \ diff --git a/cpp/include/raft/neighbors/ann_mg-inl.cuh b/cpp/include/raft/neighbors/ann_mg-inl.cuh index c36ea0053e..0d8432fb00 100644 --- a/cpp/include/raft/neighbors/ann_mg-inl.cuh +++ b/cpp/include/raft/neighbors/ann_mg-inl.cuh @@ -59,22 +59,20 @@ namespace raft::neighbors::mg { template void search(detail::ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, - IdxT n_neighbors, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - mg::detail::search(index, search_params, n_neighbors, query_dataset, neighbors, distances); + mg::detail::search(index, search_params, query_dataset, neighbors, distances); } template void search(detail::ann_mg_index, T, uint32_t>& index, const ivf_pq::search_params& search_params, - uint32_t n_neighbors, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - mg::detail::search(index, search_params, n_neighbors, query_dataset, neighbors, distances); + mg::detail::search(index, search_params, query_dataset, neighbors, distances); } } \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index a3ee05b99f..9a4fe74c51 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -16,15 +16,17 @@ #pragma once +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include #include #include #include #include -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include #include + #define RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -57,11 +59,13 @@ namespace raft::neighbors::mg::detail { void extend(raft::resources const& handle, raft::host_matrix_view h_new_vectors, raft::host_matrix_view h_new_indices) { + resource::sync_stream(handle); + index_dataset_.reset(); + auto new_vectors_view = store_to_device(handle, new_vectors_, h_new_vectors); auto new_indices_view = store_to_device(handle, new_indices_, h_new_indices); auto new_indices_vector_view = \ raft::make_device_vector_view(new_indices_view.data_handle(), new_indices_view.extent(0)); - std::optional> new_indices_opt = std::make_optional>(new_indices_vector_view); @@ -80,14 +84,20 @@ namespace raft::neighbors::mg::detail { void search(raft::resources const& handle, const ann::search_params* search_params, - IdxT n_neighbors, raft::host_matrix_view h_query_dataset, raft::host_matrix_view h_neighbors, raft::host_matrix_view h_distances) { + resource::sync_stream(handle); + index_dataset_.reset(); + new_vectors_.reset(); + new_indices_.reset(); + query_dataset_.reset(); + auto query_dataset_view = store_to_device(handle, query_dataset_, h_query_dataset); IdxT n_rows = h_query_dataset.extent(0); - auto neighbors_view = neighbors_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))); - auto distances_view = distances_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))); + IdxT n_neighbors = h_neighbors.extent(1); + auto neighbors_view = neighbors_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))).view(); + auto distances_view = distances_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))).view(); if constexpr (std::is_same>::value) { ivf_flat::search(handle, @@ -100,10 +110,19 @@ namespace raft::neighbors::mg::detail { ivf_pq::search(handle, *reinterpret_cast(search_params), index_.value(), - query_dataset_view, + query_dataset_view, neighbors_view, distances_view); } + + raft::copy(h_neighbors.data_handle(), + neighbors_view.data_handle(), + n_rows * n_neighbors, + resource::get_cuda_stream(handle)); + raft::copy(h_distances.data_handle(), + distances_view.data_handle(), + n_rows * n_neighbors, + resource::get_cuda_stream(handle)); } private: @@ -220,30 +239,67 @@ namespace raft::neighbors::mg::detail { } void search(const ann::search_params* search_params, - IdxT n_neighbors, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { if (mode_ == INDEX_DUPLICATION) { + IdxT n_rows = query_dataset.extent(0); + IdxT n_cols = query_dataset.extent(1); + IdxT n_neighbors = neighbors.extent(1); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT query_offset = 0; + IdxT output_offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { cudaSetDevice(dev_ids_[rank]); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - query_offset); + auto query_partition = \ + raft::make_host_matrix_view(query_dataset.data_handle() + query_offset, + n_rows_per_shard, + n_cols); + auto neighbors_partition = \ + raft::make_host_matrix_view(neighbors.data_handle() + output_offset, + n_rows_per_shard, + n_neighbors); + auto distances_partition = \ + raft::make_host_matrix_view(distances.data_handle() + output_offset, + n_rows_per_shard, + n_neighbors); auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_resources_[rank], search_params, n_neighbors, query_dataset, neighbors, distances); + ann_if.search(dev_resources_[rank], search_params, query_partition, neighbors_partition, distances_partition); + query_offset += n_rows_per_shard * n_cols; + output_offset += n_rows_per_shard * n_neighbors; } } else if (mode_ == SHARDING) { IdxT n_rows = query_dataset.extent(0); IdxT n_cols = query_dataset.extent(1); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT offset = 0; - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* query_dataset_ptr = query_dataset.data_handle() + offset; - auto query_dataset_part = raft::make_host_matrix_view(query_dataset_ptr, n_rows_per_shard, n_cols); - auto& ann_if = ann_interfaces_[rank]; - //ann_if.search(dev_resources_[rank], search_params, n_neighbors, query_dataset_part, neighbors, distances); - offset += n_rows_per_shard * n_cols; + IdxT n_neighbors = neighbors.extent(1); + + IdxT n_rows_per_batches = 1000000; + IdxT n_batches = (n_rows + n_rows_per_batches - 1) / n_rows_per_batches; + IdxT query_offset = 0; + IdxT output_offset = 0; + for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { + n_rows_per_batches = std::min(n_rows_per_batches, n_rows - query_offset); + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + auto query_partition = \ + raft::make_host_matrix_view(query_dataset.data_handle() + query_offset, + n_rows_per_batches, + n_cols); + auto neighbors_partition = \ + raft::make_host_matrix_view(neighbors.data_handle() + output_offset, + n_rows_per_batches, + n_neighbors); + auto distances_partition = \ + raft::make_host_matrix_view(distances.data_handle() + output_offset, + n_rows_per_batches, + n_neighbors); + auto& ann_if = ann_interfaces_[rank]; + ann_if.search(dev_resources_[rank], search_params, query_partition, neighbors_partition, distances_partition); + query_offset += n_rows_per_batches * n_cols; + output_offset += n_rows_per_batches * n_neighbors; + } } } } @@ -270,9 +326,9 @@ namespace raft::neighbors::mg::detail { template ann_mg_index, T, uint32_t> build(const std::vector device_ids, - dist_mode mode, - const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) + dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) { ann_mg_index, T, uint32_t> index(device_ids, mode); index.build(static_cast(&index_params), index_dataset); @@ -298,13 +354,11 @@ namespace raft::neighbors::mg::detail { template void search(ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, - IdxT n_neighbors, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { index.search(static_cast(&search_params), - n_neighbors, query_dataset, neighbors, distances); @@ -313,13 +367,11 @@ namespace raft::neighbors::mg::detail { template void search(ann_mg_index, T, uint32_t>& index, const ivf_pq::search_params& search_params, - uint32_t n_neighbors, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { index.search(static_cast(&search_params), - n_neighbors, query_dataset, neighbors, distances); diff --git a/cpp/src/neighbors/ann_mg.cu b/cpp/src/neighbors/ann_mg.cu index 1a7ad67dc5..6eb175a191 100644 --- a/cpp/src/neighbors/ann_mg.cu +++ b/cpp/src/neighbors/ann_mg.cu @@ -45,7 +45,6 @@ template void raft::neighbors::mg::search( \ raft::neighbors::mg::detail::ann_mg_index, T, IdxT>& index, \ const ivf_flat::search_params& search_params, \ - IdxT n_neighbors, \ raft::host_matrix_view query_dataset, \ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances); \ @@ -53,7 +52,6 @@ template void raft::neighbors::mg::search( \ raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>& index, \ const ivf_pq::search_params& search_params, \ - uint32_t n_neighbors, \ raft::host_matrix_view query_dataset, \ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances); \ diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index b2686c9e35..1f5ec3794c 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -17,46 +17,9 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - #include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include - -#include - -#include -#include - -#include -#include -#include - namespace raft::neighbors::mg { template @@ -87,35 +50,33 @@ class AnnMGTest : public ::testing::TestWithParam> { void testAnnMG() { size_t queries_size = ps.num_queries * ps.k; - std::vector distances_ivfflat(queries_size); - std::vector distances_naive(queries_size); - std::vector indices_ivfflat(queries_size); std::vector indices_naive(queries_size); + std::vector distances_naive(queries_size); + std::vector indices_ann(queries_size); + std::vector distances_ann(queries_size); { rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - d_query_dataset.data(), - d_index_dataset.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.metric); + raft::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + d_query_dataset.data(), + d_index_dataset.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); resource::sync_stream(handle_); } - { - rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); - rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); + std::vector device_ids{0, 1}; - std::vector device_ids{0, 1}; - raft::neighbors::mg::dist_mode mode = SHARDING; + // IVF-Flat + for (dist_mode d_mode : { dist_mode::SHARDING, dist_mode::INDEX_DUPLICATION }) { ivf_flat::index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; @@ -123,23 +84,61 @@ class AnnMGTest : public ::testing::TestWithParam> { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - auto index_dataset_view = raft::make_host_matrix_view(h_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto index = raft::neighbors::mg::build(device_ids, mode, index_params, index_dataset_view); - update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); - update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); + ivf_flat::search_params search_params; + search_params.n_probes = ps.nprobe; + + auto index_dataset = raft::make_host_matrix_view(h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); } - double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); - ASSERT_TRUE(eval_neighbours(indices_naive, - indices_ivfflat, - distances_naive, - distances_ivfflat, - ps.num_queries, - ps.k, - 0.001, - min_recall)); + // IVF-PQ + for (dist_mode d_mode : { dist_mode::SHARDING, dist_mode::INDEX_DUPLICATION }) { + ivf_pq::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.add_data_on_build = false; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + + ivf_pq::search_params search_params; + search_params.n_probes = ps.nprobe; + + auto index_dataset = raft::make_host_matrix_view(h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } } void SetUp() override From 9aeb456094518262194cf59c88362ecd36d63081 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 20 Nov 2023 14:04:39 +0100 Subject: [PATCH 03/22] Debugging --- cpp/include/raft/neighbors/ann_mg-ext.cuh | 146 ++-- cpp/include/raft/neighbors/ann_mg-inl.cuh | 102 +-- cpp/include/raft/neighbors/detail/ann_mg.cuh | 701 ++++++++++--------- cpp/src/neighbors/ann_mg.cu | 77 +- cpp/test/neighbors/ann_mg.cuh | 58 +- 5 files changed, 562 insertions(+), 522 deletions(-) diff --git a/cpp/include/raft/neighbors/ann_mg-ext.cuh b/cpp/include/raft/neighbors/ann_mg-ext.cuh index 2032166a55..23209c95b7 100644 --- a/cpp/include/raft/neighbors/ann_mg-ext.cuh +++ b/cpp/include/raft/neighbors/ann_mg-ext.cuh @@ -22,88 +22,90 @@ #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft::neighbors::mg { - using namespace raft::neighbors::mg; +using namespace raft::neighbors::mg; - template - auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_flat::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, IdxT> RAFT_EXPLICIT; +template +auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_flat::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> RAFT_EXPLICIT; - template - auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, uint32_t> RAFT_EXPLICIT; +template +auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, uint32_t> RAFT_EXPLICIT; - template - void extend(detail::ann_mg_index, T, IdxT>& index, - raft::host_matrix_view new_vectors, - raft::host_matrix_view new_indices) RAFT_EXPLICIT; +template +void extend(detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) + RAFT_EXPLICIT; - template - void extend(detail::ann_mg_index, T, uint32_t>& index, - raft::host_matrix_view new_vectors, - raft::host_matrix_view new_indices) RAFT_EXPLICIT; +template +void extend(detail::ann_mg_index, T, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) + RAFT_EXPLICIT; - template - void search(detail::ann_mg_index, T, IdxT>& index, - const ivf_flat::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) RAFT_EXPLICIT; +template +void search(detail::ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) RAFT_EXPLICIT; - template - void search(detail::ann_mg_index, T, uint32_t>& index, - const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) RAFT_EXPLICIT; +template +void search(detail::ann_mg_index, T, uint32_t>& index, + const ivf_pq::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) RAFT_EXPLICIT; -} +} // namespace raft::neighbors::mg #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_ann_mg_build(T, IdxT) \ - extern template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_flat::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - -> detail::ann_mg_index, T, IdxT>; \ - \ - extern template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_pq::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - -> detail::ann_mg_index, T, uint32_t>; \ - \ - extern template void raft::neighbors::mg::extend( \ - detail::ann_mg_index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - raft::host_matrix_view new_indices); \ - \ - extern template void raft::neighbors::mg::extend( \ - detail::ann_mg_index, T, uint32_t>& index, \ - raft::host_matrix_view new_vectors, \ - raft::host_matrix_view new_indices); \ - \ - extern template void raft::neighbors::mg::search( \ - detail::ann_mg_index, T, IdxT>& index, \ - const ivf_flat::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ - \ - extern template void raft::neighbors::mg::search( \ - detail::ann_mg_index, T, uint32_t>& index, \ - const ivf_pq::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ +#define instantiate_raft_neighbors_ann_mg_build(T, IdxT) \ + extern template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_flat::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + ->detail::ann_mg_index, T, IdxT>; \ + \ + extern template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_pq::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + ->detail::ann_mg_index, T, uint32_t>; \ + \ + extern template void raft::neighbors::mg::extend( \ + detail::ann_mg_index, T, IdxT> & index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices); \ + \ + extern template void raft::neighbors::mg::extend( \ + detail::ann_mg_index, T, uint32_t> & index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices); \ + \ + extern template void raft::neighbors::mg::search( \ + detail::ann_mg_index, T, IdxT> & index, \ + const ivf_flat::search_params& search_params, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + \ + extern template void raft::neighbors::mg::search( \ + detail::ann_mg_index, T, uint32_t> & index, \ + const ivf_pq::search_params& search_params, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); instantiate_raft_neighbors_ann_mg_build(float, uint32_t); diff --git a/cpp/include/raft/neighbors/ann_mg-inl.cuh b/cpp/include/raft/neighbors/ann_mg-inl.cuh index 0d8432fb00..b04afd7432 100644 --- a/cpp/include/raft/neighbors/ann_mg-inl.cuh +++ b/cpp/include/raft/neighbors/ann_mg-inl.cuh @@ -20,59 +20,59 @@ namespace raft::neighbors::mg { - template - auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_flat::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, IdxT> - { - return mg::detail::build(device_ids, mode, index_params, index_dataset); - } +template +auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_flat::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(device_ids, mode, index_params, index_dataset); +} - template - auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, uint32_t> - { - return mg::detail::build(device_ids, mode, index_params, index_dataset); - } +template +auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, uint32_t> +{ + return mg::detail::build(device_ids, mode, index_params, index_dataset); +} - template - void extend(detail::ann_mg_index, T, IdxT>& index, - raft::host_matrix_view new_vectors, - raft::host_matrix_view new_indices) - { - mg::detail::extend(index, new_vectors, new_indices); - } +template +void extend(detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(index, new_vectors, new_indices); +} - template - void extend(detail::ann_mg_index, T, uint32_t>& index, - raft::host_matrix_view new_vectors, - raft::host_matrix_view new_indices) - { - mg::detail::extend(index, new_vectors, new_indices); - } +template +void extend(detail::ann_mg_index, T, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(index, new_vectors, new_indices); +} - template - void search(detail::ann_mg_index, T, IdxT>& index, - const ivf_flat::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) - { - mg::detail::search(index, search_params, query_dataset, neighbors, distances); - } +template +void search(detail::ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(index, search_params, query_dataset, neighbors, distances); +} - template - void search(detail::ann_mg_index, T, uint32_t>& index, - const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) - { - mg::detail::search(index, search_params, query_dataset, neighbors, distances); - } -} \ No newline at end of file +template +void search(detail::ann_mg_index, T, uint32_t>& index, + const ivf_pq::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(index, search_params, query_dataset, neighbors, distances); +} +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 9a4fe74c51..e20c7d736e 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -18,363 +18,388 @@ #undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include -#include +#include #include #include -#include +#include +#include #include #include #define RAFT_EXPLICIT_INSTANTIATE_ONLY - namespace raft::neighbors::mg { - enum dist_mode { SHARDING, INDEX_DUPLICATION }; +enum dist_mode { SHARDING, INDEX_DUPLICATION }; } namespace raft::neighbors::mg::detail { - using namespace raft::neighbors; - - template - class ann_interface { - public: - void build(raft::resources const& handle, - const ann::index_params* index_params, - raft::host_matrix_view h_index_dataset) { - auto index_dataset_view = store_to_device(handle, index_dataset_, h_index_dataset); - - if constexpr (std::is_same>::value) { - index_.emplace(std::move(ivf_flat::build(handle, - *static_cast(index_params), - index_dataset_view))); - } else if constexpr (std::is_same>::value) { - index_.emplace(std::move(ivf_pq::build(handle, - *static_cast(index_params), - index_dataset_view))); - } - } - - void extend(raft::resources const& handle, - raft::host_matrix_view h_new_vectors, - raft::host_matrix_view h_new_indices) { - resource::sync_stream(handle); - index_dataset_.reset(); - - auto new_vectors_view = store_to_device(handle, new_vectors_, h_new_vectors); - auto new_indices_view = store_to_device(handle, new_indices_, h_new_indices); - auto new_indices_vector_view = \ - raft::make_device_vector_view(new_indices_view.data_handle(), new_indices_view.extent(0)); - std::optional> new_indices_opt = - std::make_optional>(new_indices_vector_view); - - if constexpr (std::is_same>::value) { - index_.emplace(std::move(ivf_flat::extend(handle, - new_vectors_view, - new_indices_opt, - index_.value()))); - } else if constexpr (std::is_same>::value) { - index_.emplace(std::move(ivf_pq::extend(handle, - new_vectors_view, - new_indices_opt, - index_.value()))); - } - } - - void search(raft::resources const& handle, - const ann::search_params* search_params, - raft::host_matrix_view h_query_dataset, - raft::host_matrix_view h_neighbors, - raft::host_matrix_view h_distances) { - resource::sync_stream(handle); - index_dataset_.reset(); - new_vectors_.reset(); - new_indices_.reset(); - query_dataset_.reset(); - - auto query_dataset_view = store_to_device(handle, query_dataset_, h_query_dataset); - IdxT n_rows = h_query_dataset.extent(0); - IdxT n_neighbors = h_neighbors.extent(1); - auto neighbors_view = neighbors_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))).view(); - auto distances_view = distances_.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_neighbors))).view(); - - if constexpr (std::is_same>::value) { - ivf_flat::search(handle, - *reinterpret_cast(search_params), - index_.value(), - query_dataset_view, - neighbors_view, - distances_view); - } else if constexpr (std::is_same>::value) { - ivf_pq::search(handle, - *reinterpret_cast(search_params), - index_.value(), - query_dataset_view, - neighbors_view, - distances_view); - } - - raft::copy(h_neighbors.data_handle(), - neighbors_view.data_handle(), - n_rows * n_neighbors, - resource::get_cuda_stream(handle)); - raft::copy(h_distances.data_handle(), - distances_view.data_handle(), - n_rows * n_neighbors, - resource::get_cuda_stream(handle)); - } - - private: - template - raft::device_matrix_view store_to_device(raft::resources const& handle, - std::optional>& dev_mat_opt, - raft::host_matrix_view host_mat_view) { - DataIdxT n_rows = host_mat_view.extent(0); - DataIdxT n_cols = host_mat_view.extent(1); - dev_mat_opt.emplace(std::move(raft::make_device_matrix(handle, n_rows, n_cols))); - raft::copy(dev_mat_opt.value().data_handle(), // async copy - host_mat_view.data_handle(), - n_rows * n_cols, - resource::get_cuda_stream(handle)); - auto const_dev_mat_view = dev_mat_opt.value().view(); - raft::device_matrix_view dev_mat_view = \ - raft::make_device_matrix_view(const_dev_mat_view.data_handle(), - const_dev_mat_view.extent(0), - const_dev_mat_view.extent(1)); - return dev_mat_view; - } - - std::optional> index_dataset_; - std::optional> new_vectors_; - std::optional> new_indices_; - std::optional> query_dataset_; - std::optional> neighbors_; - std::optional> distances_; - std::optional index_; - }; - - template - class ann_mg_index { - public: - ann_mg_index() = delete; - ann_mg_index(const std::vector& dev_list, - dist_mode mode = SHARDING) - : mode_(mode), - num_ranks_(dev_list.size()), - dev_ids_(dev_list), - nccl_comms_(dev_list.size()) - { - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - - raft::resources& handle = dev_resources_.emplace_back(); - raft::comms::build_comms_nccl_only(&handle, nccl_comms_[rank], num_ranks_, rank); - } - ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data()); - } - - ~ann_mg_index() { - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - ncclCommDestroy(nccl_comms_[rank]); - } - } - - ann_mg_index(const ann_mg_index&) = delete; - ann_mg_index(ann_mg_index&&) = default; - auto operator=(const ann_mg_index&) -> ann_mg_index& = delete; - auto operator=(ann_mg_index&&) -> ann_mg_index& = default; - - void build(const ann::index_params* index_params, - raft::host_matrix_view index_dataset) { - if (mode_ == INDEX_DUPLICATION) { - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - auto& ann_if = ann_interfaces_.emplace_back(); - ann_if.build(dev_resources_[rank], index_params, index_dataset); - } - } else if (mode_ == SHARDING) { - IdxT n_rows = index_dataset.extent(0); - IdxT n_cols = index_dataset.extent(1); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT offset = 0; - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + offset; - auto partition = raft::make_host_matrix_view(partition_ptr, n_rows_per_shard, n_cols); - auto& ann_if = ann_interfaces_.emplace_back(); - ann_if.build(dev_resources_[rank], index_params, partition); - offset += n_rows_per_shard * n_cols; - } - } - } - - void extend(raft::host_matrix_view new_vectors, - raft::host_matrix_view new_indices) { - if (mode_ == INDEX_DUPLICATION) { - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - auto& ann_if = ann_interfaces_[rank]; - ann_if.extend(dev_resources_[rank], new_vectors, new_indices); - } - } else if (mode_ == SHARDING) { - IdxT n_rows = new_vectors.extent(0); - IdxT n_cols = new_vectors.extent(1); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT offset = 0; - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + offset; - const IdxT* new_indices_ptr = new_indices.data_handle() + offset; - auto new_vectors_part = raft::make_host_matrix_view(new_vectors_ptr, n_rows_per_shard, n_cols); - auto new_indices_part = raft::make_host_matrix_view(new_indices_ptr, n_rows_per_shard, 1); - auto& ann_if = ann_interfaces_[rank]; - ann_if.extend(dev_resources_[rank], new_vectors_part, new_indices_part); - offset += n_rows_per_shard * n_cols; - } - } - } - - void search(const ann::search_params* search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) - { - if (mode_ == INDEX_DUPLICATION) { - IdxT n_rows = query_dataset.extent(0); - IdxT n_cols = query_dataset.extent(1); - IdxT n_neighbors = neighbors.extent(1); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT query_offset = 0; - IdxT output_offset = 0; - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - query_offset); - auto query_partition = \ - raft::make_host_matrix_view(query_dataset.data_handle() + query_offset, - n_rows_per_shard, - n_cols); - auto neighbors_partition = \ - raft::make_host_matrix_view(neighbors.data_handle() + output_offset, - n_rows_per_shard, - n_neighbors); - auto distances_partition = \ - raft::make_host_matrix_view(distances.data_handle() + output_offset, - n_rows_per_shard, - n_neighbors); - auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_resources_[rank], search_params, query_partition, neighbors_partition, distances_partition); - query_offset += n_rows_per_shard * n_cols; - output_offset += n_rows_per_shard * n_neighbors; - } - } else if (mode_ == SHARDING) { - IdxT n_rows = query_dataset.extent(0); - IdxT n_cols = query_dataset.extent(1); - IdxT n_neighbors = neighbors.extent(1); - - IdxT n_rows_per_batches = 1000000; - IdxT n_batches = (n_rows + n_rows_per_batches - 1) / n_rows_per_batches; - IdxT query_offset = 0; - IdxT output_offset = 0; - for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - n_rows_per_batches = std::min(n_rows_per_batches, n_rows - query_offset); - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - auto query_partition = \ - raft::make_host_matrix_view(query_dataset.data_handle() + query_offset, - n_rows_per_batches, - n_cols); - auto neighbors_partition = \ - raft::make_host_matrix_view(neighbors.data_handle() + output_offset, - n_rows_per_batches, - n_neighbors); - auto distances_partition = \ - raft::make_host_matrix_view(distances.data_handle() + output_offset, - n_rows_per_batches, - n_neighbors); - auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_resources_[rank], search_params, query_partition, neighbors_partition, distances_partition); - query_offset += n_rows_per_batches * n_cols; - output_offset += n_rows_per_batches * n_neighbors; - } - } - } - } - - private: - dist_mode mode_; - int num_ranks_; - std::vector dev_ids_; - std::vector dev_resources_; - std::vector> ann_interfaces_; - std::vector nccl_comms_; - }; - - template - ann_mg_index, T, IdxT> build(const std::vector device_ids, - dist_mode mode, - const ivf_flat::index_params& index_params, - raft::host_matrix_view index_dataset) - { - ann_mg_index, T, IdxT> index(device_ids, mode); - index.build(static_cast(&index_params), index_dataset); - return index; +using namespace raft::neighbors; + +template +class ann_interface { + public: + void build(raft::resources const& handle, + const ann::index_params* index_params, + raft::host_matrix_view h_index_dataset) + { + auto index_dataset_view = store_to_device(handle, index_dataset_, h_index_dataset); + + if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_flat::build( + handle, *static_cast(index_params), index_dataset_view))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_pq::build( + handle, *static_cast(index_params), index_dataset_view))); } + } - template - ann_mg_index, T, uint32_t> build(const std::vector device_ids, - dist_mode mode, - const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) - { - ann_mg_index, T, uint32_t> index(device_ids, mode); - index.build(static_cast(&index_params), index_dataset); - return index; - } + void extend(raft::resources const& handle, + raft::host_matrix_view h_new_vectors, + std::optional> h_new_indices) + { + index_dataset_.reset(); + new_vectors_.reset(); + new_indices_.reset(); - template - void extend(ann_mg_index, T, IdxT>& index, - raft::host_matrix_view new_vectors, - raft::host_matrix_view new_indices) - { - index.extend(new_vectors, new_indices); - } + auto new_vectors_view = store_to_device(handle, new_vectors_, h_new_vectors); - template - void extend(ann_mg_index, T, uint32_t>& index, - raft::host_matrix_view new_vectors, - raft::host_matrix_view new_indices) - { - index.extend(new_vectors, new_indices); + std::optional> new_indices_opt = std::nullopt; + if (h_new_indices) { + new_indices_opt = store_to_device(handle, new_indices_, h_new_indices.value()); } - template - void search(ann_mg_index, T, IdxT>& index, - const ivf_flat::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) - { - index.search(static_cast(&search_params), - query_dataset, - neighbors, - distances); + if constexpr (std::is_same>::value) { + index_.emplace(std::move( + ivf_flat::extend(handle, new_vectors_view, new_indices_opt, index_.value()))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move( + ivf_pq::extend(handle, new_vectors_view, new_indices_opt, index_.value()))); + } + } + + void search(raft::resources const& handle, + const ann::search_params* search_params, + raft::host_matrix_view h_query_dataset, + raft::host_matrix_view h_neighbors, + raft::host_matrix_view h_distances) + { + index_dataset_.reset(); + new_vectors_.reset(); + new_indices_.reset(); + query_dataset_.reset(); + + IdxT n_rows = h_query_dataset.extent(0); + IdxT n_neighbors = h_neighbors.extent(1); + auto query_dataset_view = store_to_device(handle, query_dataset_, h_query_dataset); + auto neighbors_view = neighbors_ + .emplace(std::move(raft::make_device_matrix( + handle, n_rows, n_neighbors))) + .view(); + auto distances_view = distances_ + .emplace(std::move(raft::make_device_matrix( + handle, n_rows, n_neighbors))) + .view(); + + if constexpr (std::is_same>::value) { + ivf_flat::search(handle, + *reinterpret_cast(search_params), + index_.value(), + query_dataset_view, + neighbors_view, + distances_view); + } else if constexpr (std::is_same>::value) { + ivf_pq::search(handle, + *reinterpret_cast(search_params), + index_.value(), + query_dataset_view, + neighbors_view, + distances_view); } - template - void search(ann_mg_index, T, uint32_t>& index, - const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) - { - index.search(static_cast(&search_params), - query_dataset, - neighbors, - distances); + raft::copy(h_neighbors.data_handle(), + neighbors_view.data_handle(), + n_rows * n_neighbors, + resource::get_cuda_stream(handle)); + raft::copy(h_distances.data_handle(), + distances_view.data_handle(), + n_rows * n_neighbors, + resource::get_cuda_stream(handle)); + } + + private: + template + raft::device_matrix_view store_to_device( + raft::resources const& handle, + std::optional>& dev_mat_opt, + raft::host_matrix_view host_mat_view) + { + DataIdxT n_rows = host_mat_view.extent(0); + DataIdxT n_cols = host_mat_view.extent(1); + dev_mat_opt.emplace( + std::move(raft::make_device_matrix(handle, n_rows, n_cols))); + raft::copy(dev_mat_opt.value().data_handle(), // async copy + host_mat_view.data_handle(), + n_rows * n_cols, + resource::get_cuda_stream(handle)); + auto const_dev_mat_view = dev_mat_opt.value().view(); + raft::device_matrix_view dev_mat_view = + raft::make_device_matrix_view( + const_dev_mat_view.data_handle(), + const_dev_mat_view.extent(0), + const_dev_mat_view.extent(1)); + return dev_mat_view; + } + + template + raft::device_vector_view store_to_device( + raft::resources const& handle, + std::optional>& dev_vec_opt, + raft::host_vector_view host_vec_view) + { + DataIdxT n_rows = host_vec_view.extent(0); + dev_vec_opt.emplace(std::move(raft::make_device_vector(handle, n_rows))); + raft::copy(dev_vec_opt.value().data_handle(), // async copy + host_vec_view.data_handle(), + n_rows, + resource::get_cuda_stream(handle)); + auto const_dev_vec_view = dev_vec_opt.value().view(); + raft::device_vector_view dev_vec_view = + raft::make_device_vector_view(const_dev_vec_view.data_handle(), + const_dev_vec_view.extent(0)); + return dev_vec_view; + } + + std::optional> index_dataset_; + std::optional> new_vectors_; + std::optional> new_indices_; + std::optional> query_dataset_; + std::optional> neighbors_; + std::optional> distances_; + std::optional index_; +}; + +template +class ann_mg_index { + public: + ann_mg_index() = delete; + ann_mg_index(const std::vector& dev_list, dist_mode mode = SHARDING) + : mode_(mode), num_ranks_(dev_list.size()), dev_ids_(dev_list), nccl_comms_(dev_list.size()) + { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + + raft::device_resources& handle = dev_resources_.emplace_back(); + raft::comms::build_comms_nccl_only(&handle, nccl_comms_[rank], num_ranks_, rank); + } + ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data()); + } + + ~ann_mg_index() + { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + ncclCommDestroy(nccl_comms_[rank]); + } + } + + ann_mg_index(const ann_mg_index&) = delete; + ann_mg_index(ann_mg_index&&) = default; + auto operator=(const ann_mg_index&) -> ann_mg_index& = delete; + auto operator=(ann_mg_index&&) -> ann_mg_index& = default; + + void build(const ann::index_params* index_params, + raft::host_matrix_view index_dataset) + { + if (mode_ == INDEX_DUPLICATION) { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.build(dev_resources_[rank], index_params, index_dataset); + } + } else if (mode_ == SHARDING) { + IdxT n_rows = index_dataset.extent(0); + IdxT n_cols = index_dataset.extent(1); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + offset; + auto partition = raft::make_host_matrix_view( + partition_ptr, n_rows_per_shard, n_cols); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.build(dev_resources_[rank], index_params, partition); + offset += n_rows_per_shard * n_cols; + } + } + } + + void extend(raft::host_matrix_view new_vectors, + std::optional> new_indices) + { + if (mode_ == INDEX_DUPLICATION) { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + auto& ann_if = ann_interfaces_[rank]; + ann_if.extend(dev_resources_[rank], new_vectors, new_indices); + } + } else if (mode_ == SHARDING) { + IdxT n_rows = new_vectors.extent(0); + IdxT n_cols = new_vectors.extent(1); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + offset; + auto new_vectors_part = raft::make_host_matrix_view( + new_vectors_ptr, n_rows_per_shard, n_cols); + + std::optional> new_indices_part = std::nullopt; + if (new_indices) { + const IdxT* new_indices_ptr = new_indices.value().data_handle() + offset; + new_indices_part = + raft::make_host_vector_view(new_indices_ptr, n_rows_per_shard); + } + auto& ann_if = ann_interfaces_[rank]; + ann_if.extend(dev_resources_[rank], new_vectors_part, new_indices_part); + offset += n_rows_per_shard * n_cols; + } } + } + + void search(const ann::search_params* search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) + { + if (mode_ == INDEX_DUPLICATION) { + IdxT n_rows = query_dataset.extent(0); + IdxT n_cols = query_dataset.extent(1); + IdxT n_neighbors = neighbors.extent(1); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT query_offset = 0; + IdxT output_offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - query_offset); + auto query_partition = raft::make_host_matrix_view( + query_dataset.data_handle() + query_offset, n_rows_per_shard, n_cols); + auto neighbors_partition = raft::make_host_matrix_view( + neighbors.data_handle() + output_offset, n_rows_per_shard, n_neighbors); + auto distances_partition = raft::make_host_matrix_view( + distances.data_handle() + output_offset, n_rows_per_shard, n_neighbors); + auto& ann_if = ann_interfaces_[rank]; + ann_if.search(dev_resources_[rank], + search_params, + query_partition, + neighbors_partition, + distances_partition); + query_offset += n_rows_per_shard * n_cols; + output_offset += n_rows_per_shard * n_neighbors; + } + } else if (mode_ == SHARDING) { + IdxT n_rows = query_dataset.extent(0); + IdxT n_cols = query_dataset.extent(1); + IdxT n_neighbors = neighbors.extent(1); + + IdxT n_rows_per_batches = 1000000; + IdxT n_batches = (n_rows + n_rows_per_batches - 1) / n_rows_per_batches; + IdxT query_offset = 0; + IdxT output_offset = 0; + for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { + n_rows_per_batches = std::min(n_rows_per_batches, n_rows - query_offset); + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(dev_ids_[rank]); + auto query_partition = raft::make_host_matrix_view( + query_dataset.data_handle() + query_offset, n_rows_per_batches, n_cols); + auto neighbors_partition = raft::make_host_matrix_view( + neighbors.data_handle() + output_offset, n_rows_per_batches, n_neighbors); + auto distances_partition = raft::make_host_matrix_view( + distances.data_handle() + output_offset, n_rows_per_batches, n_neighbors); + auto& ann_if = ann_interfaces_[rank]; + ann_if.search(dev_resources_[rank], + search_params, + query_partition, + neighbors_partition, + distances_partition); + query_offset += n_rows_per_batches * n_cols; + output_offset += n_rows_per_batches * n_neighbors; + } + } + } + } + + private: + dist_mode mode_; + int num_ranks_; + std::vector dev_ids_; + std::vector dev_resources_; + std::vector> ann_interfaces_; + std::vector nccl_comms_; +}; + +template +ann_mg_index, T, IdxT> build( + const std::vector device_ids, + dist_mode mode, + const ivf_flat::index_params& index_params, + raft::host_matrix_view index_dataset) +{ + ann_mg_index, T, IdxT> index(device_ids, mode); + index.build(static_cast(&index_params), index_dataset); + return index; +} + +template +ann_mg_index, T, uint32_t> build( + const std::vector device_ids, + dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) +{ + ann_mg_index, T, uint32_t> index(device_ids, mode); + index.build(static_cast(&index_params), index_dataset); + return index; +} + +template +void extend(ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + index.extend(new_vectors, new_indices); +} + +template +void extend(ann_mg_index, T, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + index.extend(new_vectors, new_indices); +} + +template +void search(ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + index.search( + static_cast(&search_params), query_dataset, neighbors, distances); +} + +template +void search(ann_mg_index, T, uint32_t>& index, + const ivf_pq::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + index.search( + static_cast(&search_params), query_dataset, neighbors, distances); +} -} \ No newline at end of file +} // namespace raft::neighbors::mg::detail \ No newline at end of file diff --git a/cpp/src/neighbors/ann_mg.cu b/cpp/src/neighbors/ann_mg.cu index 6eb175a191..587feb53e1 100644 --- a/cpp/src/neighbors/ann_mg.cu +++ b/cpp/src/neighbors/ann_mg.cu @@ -14,47 +14,46 @@ * limitations under the License. */ - #include -#define instantiate_raft_neighbors_ann_mg(T, IdxT) \ - template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_flat::index_params& index_params, \ - raft::host_matrix_view index_dataset \ - ) -> raft::neighbors::mg::detail::ann_mg_index, T, IdxT>; \ - \ - template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_pq::index_params& index_params, \ - raft::host_matrix_view index_dataset \ - ) -> raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>; \ - \ - template void raft::neighbors::mg::extend( \ - raft::neighbors::mg::detail::ann_mg_index, T, IdxT>& index, \ - raft::host_matrix_view new_vectors, \ - raft::host_matrix_view new_indices); \ - \ - template void raft::neighbors::mg::extend( \ - raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>& index, \ - raft::host_matrix_view new_vectors, \ - raft::host_matrix_view new_indices); \ - \ - template void raft::neighbors::mg::search( \ - raft::neighbors::mg::detail::ann_mg_index, T, IdxT>& index, \ - const ivf_flat::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ - \ - template void raft::neighbors::mg::search( \ - raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>& index, \ - const ivf_pq::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ +#define instantiate_raft_neighbors_ann_mg(T, IdxT) \ + template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_flat::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + ->raft::neighbors::mg::detail::ann_mg_index, T, IdxT>; \ + \ + template auto raft::neighbors::mg::build( \ + const std::vector device_ids, \ + raft::neighbors::mg::dist_mode mode, \ + const ivf_pq::index_params& index_params, \ + raft::host_matrix_view index_dataset) \ + ->raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>; \ + \ + template void raft::neighbors::mg::extend( \ + raft::neighbors::mg::detail::ann_mg_index, T, IdxT> & index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices); \ + \ + template void raft::neighbors::mg::extend( \ + raft::neighbors::mg::detail::ann_mg_index, T, uint32_t> & index, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices); \ + \ + template void raft::neighbors::mg::search( \ + raft::neighbors::mg::detail::ann_mg_index, T, IdxT> & index, \ + const ivf_flat::search_params& search_params, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + \ + template void raft::neighbors::mg::search( \ + raft::neighbors::mg::detail::ann_mg_index, T, uint32_t> & index, \ + const ivf_pq::search_params& search_params, \ + raft::host_matrix_view query_dataset, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); instantiate_raft_neighbors_ann_mg(float, uint32_t); diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index 1f5ec3794c..4abdf1025d 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -17,8 +17,8 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" -#include #include +#include namespace raft::neighbors::mg { @@ -76,7 +76,7 @@ class AnnMGTest : public ::testing::TestWithParam> { std::vector device_ids{0, 1}; // IVF-Flat - for (dist_mode d_mode : { dist_mode::SHARDING, dist_mode::INDEX_DUPLICATION }) { + for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION}) { ivf_flat::index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; @@ -88,13 +88,20 @@ class AnnMGTest : public ::testing::TestWithParam> { ivf_flat::search_params search_params; search_params.n_probes = ps.nprobe; - auto index_dataset = raft::make_host_matrix_view(h_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); - auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); - auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - - auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); + auto index_dataset = raft::make_host_matrix_view( + h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto query_dataset = raft::make_host_matrix_view( + h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view( + indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view( + distances_ann.data(), ps.num_queries, ps.k); + + auto index = + raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::extend(index, index_dataset, std::nullopt); + raft::neighbors::mg::search( + index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -109,7 +116,7 @@ class AnnMGTest : public ::testing::TestWithParam> { } // IVF-PQ - for (dist_mode d_mode : { dist_mode::SHARDING, dist_mode::INDEX_DUPLICATION }) { + for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION}) { ivf_pq::index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; @@ -120,12 +127,18 @@ class AnnMGTest : public ::testing::TestWithParam> { ivf_pq::search_params search_params; search_params.n_probes = ps.nprobe; - auto index_dataset = raft::make_host_matrix_view(h_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); - auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); - auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - - auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + auto index_dataset = raft::make_host_matrix_view( + h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto query_dataset = raft::make_host_matrix_view( + h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view( + indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view( + distances_ann.data(), ps.num_queries, ps.k); + + auto index = + raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::extend(index, index_dataset, std::nullopt); raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); @@ -151,14 +164,14 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { raft::random::uniform( - handle_, r, d_index_dataset.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(0.1), DataT(2.0)); raft::random::uniform( - handle_, r, d_query_dataset.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + handle_, r, d_query_dataset.data(), d_query_dataset.size(), DataT(0.1), DataT(2.0)); } else { raft::random::uniformInt( - handle_, r, d_index_dataset.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + handle_, r, d_index_dataset.data(), d_index_dataset.size(), DataT(1), DataT(20)); raft::random::uniformInt( - handle_, r, d_query_dataset.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + handle_, r, d_query_dataset.data(), d_query_dataset.size(), DataT(1), DataT(20)); } raft::copy(h_index_dataset.data(), @@ -192,5 +205,6 @@ class AnnMGTest : public ::testing::TestWithParam> { }; const std::vector> inputs = { - {1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, -};} \ No newline at end of file + {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, +}; +} // namespace raft::neighbors::mg \ No newline at end of file From 224a59f1c95bae1b91b6c421cb97396e109c96bb Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 20 Nov 2023 19:13:24 +0100 Subject: [PATCH 04/22] Implement search on shards --- cpp/include/raft/neighbors/detail/ann_mg.cuh | 177 ++++++++++++++----- cpp/test/neighbors/ann_mg.cuh | 4 +- 2 files changed, 139 insertions(+), 42 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index e20c7d736e..906edc083c 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -78,6 +79,30 @@ class ann_interface { } } + void search_impl(raft::resources const& handle, + const ann::search_params* search_params, + raft::device_matrix_view query_dataset, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) + { + if constexpr (std::is_same>::value) { + ivf_flat::search(handle, + *reinterpret_cast(search_params), + index_.value(), + query_dataset, + neighbors, + distances); + } else if constexpr (std::is_same>::value) { + ivf_pq::search(handle, + *reinterpret_cast(search_params), + index_.value(), + query_dataset, + neighbors, + distances); + } + } + + // Index duplication, results stored on host memory without merge void search(raft::resources const& handle, const ann::search_params* search_params, raft::host_matrix_view h_query_dataset, @@ -101,21 +126,7 @@ class ann_interface { handle, n_rows, n_neighbors))) .view(); - if constexpr (std::is_same>::value) { - ivf_flat::search(handle, - *reinterpret_cast(search_params), - index_.value(), - query_dataset_view, - neighbors_view, - distances_view); - } else if constexpr (std::is_same>::value) { - ivf_pq::search(handle, - *reinterpret_cast(search_params), - index_.value(), - query_dataset_view, - neighbors_view, - distances_view); - } + search_impl(handle, search_params, query_dataset_view, neighbors_view, distances_view); raft::copy(h_neighbors.data_handle(), neighbors_view.data_handle(), @@ -127,6 +138,42 @@ class ann_interface { resource::get_cuda_stream(handle)); } + // Sharding, results sent to root rank, then merged by it + void search(raft::resources const& handle, + const ann::search_params* search_params, + raft::host_matrix_view h_query_dataset, + IdxT n_neighbors, + int root_rank) + { + index_dataset_.reset(); + new_vectors_.reset(); + new_indices_.reset(); + query_dataset_.reset(); + + IdxT n_rows = h_query_dataset.extent(0); + auto query_dataset_view = store_to_device(handle, query_dataset_, h_query_dataset); + auto neighbors_view = neighbors_ + .emplace(std::move(raft::make_device_matrix( + handle, n_rows, n_neighbors))) + .view(); + auto distances_view = distances_ + .emplace(std::move(raft::make_device_matrix( + handle, n_rows, n_neighbors))) + .view(); + + search_impl(handle, search_params, query_dataset_view, neighbors_view, distances_view); + + const auto& comms = resource::get_comms(handle); + comms.device_send(neighbors_view.data_handle(), + n_rows * n_neighbors * sizeof(IdxT), + root_rank, + resource::get_cuda_stream(handle)); + comms.device_send(distances_view.data_handle(), + n_rows * n_neighbors * sizeof(float), + root_rank, + resource::get_cuda_stream(handle)); + } + private: template raft::device_matrix_view store_to_device( @@ -184,15 +231,19 @@ class ann_mg_index { public: ann_mg_index() = delete; ann_mg_index(const std::vector& dev_list, dist_mode mode = SHARDING) - : mode_(mode), num_ranks_(dev_list.size()), dev_ids_(dev_list), nccl_comms_(dev_list.size()) + : mode_(mode), + root_rank_(0), + num_ranks_(dev_list.size()), + dev_ids_(dev_list), + nccl_comms_(dev_list.size()) { + RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data())); for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); raft::device_resources& handle = dev_resources_.emplace_back(); raft::comms::build_comms_nccl_only(&handle, nccl_comms_[rank], num_ranks_, rank); } - ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data()); } ~ann_mg_index() @@ -213,7 +264,7 @@ class ann_mg_index { { if (mode_ == INDEX_DUPLICATION) { for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); auto& ann_if = ann_interfaces_.emplace_back(); ann_if.build(dev_resources_[rank], index_params, index_dataset); } @@ -223,7 +274,7 @@ class ann_mg_index { IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; IdxT offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); const T* partition_ptr = index_dataset.data_handle() + offset; auto partition = raft::make_host_matrix_view( @@ -233,6 +284,7 @@ class ann_mg_index { offset += n_rows_per_shard * n_cols; } } + set_current_device_to_root_rank(); } void extend(raft::host_matrix_view new_vectors, @@ -240,7 +292,7 @@ class ann_mg_index { { if (mode_ == INDEX_DUPLICATION) { for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_resources_[rank], new_vectors, new_indices); } @@ -250,7 +302,7 @@ class ann_mg_index { IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; IdxT offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); const T* new_vectors_ptr = new_vectors.data_handle() + offset; auto new_vectors_part = raft::make_host_matrix_view( @@ -267,6 +319,7 @@ class ann_mg_index { offset += n_rows_per_shard * n_cols; } } + set_current_device_to_root_rank(); } void search(const ann::search_params* search_params, @@ -282,7 +335,7 @@ class ann_mg_index { IdxT query_offset = 0; IdxT output_offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); n_rows_per_shard = std::min(n_rows_per_shard, n_rows - query_offset); auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_per_shard, n_cols); @@ -306,33 +359,77 @@ class ann_mg_index { IdxT n_rows_per_batches = 1000000; IdxT n_batches = (n_rows + n_rows_per_batches - 1) / n_rows_per_batches; - IdxT query_offset = 0; - IdxT output_offset = 0; + + const auto& root_handle = set_current_device_to_root_rank(); + auto in_neighbors = raft::make_device_matrix( + root_handle, n_batches * n_rows_per_batches, n_neighbors); + auto in_distances = raft::make_device_matrix( + root_handle, n_batches * n_rows_per_batches, n_neighbors); + auto out_neighbors = raft::make_device_matrix( + root_handle, n_rows_per_batches, n_neighbors); + auto out_distances = raft::make_device_matrix( + root_handle, n_rows_per_batches, n_neighbors); + + IdxT query_offset = 0; + IdxT output_offset = 0; for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - n_rows_per_batches = std::min(n_rows_per_batches, n_rows - query_offset); + n_rows_per_batches = std::min(n_rows_per_batches, n_rows - query_offset); + auto query_partition = raft::make_host_matrix_view( + query_dataset.data_handle() + query_offset, n_rows_per_batches, n_cols); + for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - auto query_partition = raft::make_host_matrix_view( - query_dataset.data_handle() + query_offset, n_rows_per_batches, n_cols); - auto neighbors_partition = raft::make_host_matrix_view( - neighbors.data_handle() + output_offset, n_rows_per_batches, n_neighbors); - auto distances_partition = raft::make_host_matrix_view( - distances.data_handle() + output_offset, n_rows_per_batches, n_neighbors); + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_resources_[rank], - search_params, - query_partition, - neighbors_partition, - distances_partition); - query_offset += n_rows_per_batches * n_cols; - output_offset += n_rows_per_batches * n_neighbors; + ann_if.search( + dev_resources_[rank], search_params, query_partition, n_neighbors, root_rank_); + + const auto& root_handle = set_current_device_to_root_rank(); + const auto& comms = resource::get_comms(root_handle); + uint64_t batch_offset = rank * n_rows_per_batches * n_neighbors; + comms.device_recv(in_neighbors.data_handle() + batch_offset, + n_rows_per_batches * n_neighbors * sizeof(IdxT), + rank, + resource::get_cuda_stream(root_handle)); + comms.device_recv(in_distances.data_handle() + batch_offset, + n_rows_per_batches * n_neighbors * sizeof(float), + rank, + resource::get_cuda_stream(root_handle)); } + + query_offset += n_rows_per_batches * n_cols; + output_offset += n_rows_per_batches * n_neighbors; } + + const auto& root_handle_ = set_current_device_to_root_rank(); + raft::neighbors::brute_force::knn_merge_parts(root_handle_, + in_distances.view(), + in_neighbors.view(), + out_distances.view(), + out_neighbors.view(), + n_rows_per_batches); + + raft::copy(neighbors.data_handle() + output_offset, + out_neighbors.data_handle(), + n_rows_per_batches * n_neighbors, + resource::get_cuda_stream(root_handle_)); + raft::copy(distances.data_handle() + output_offset, + out_distances.data_handle(), + n_rows_per_batches * n_neighbors, + resource::get_cuda_stream(root_handle_)); } + + set_current_device_to_root_rank(); + } + + inline const raft::device_resources& set_current_device_to_root_rank() + { + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[root_rank_])); + return dev_resources_[root_rank_]; } private: dist_mode mode_; + int root_rank_; int num_ranks_; std::vector dev_ids_; std::vector dev_resources_; diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index 4abdf1025d..a866142761 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -76,7 +76,7 @@ class AnnMGTest : public ::testing::TestWithParam> { std::vector device_ids{0, 1}; // IVF-Flat - for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION}) { + for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION, dist_mode::SHARDING}) { ivf_flat::index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; @@ -116,7 +116,7 @@ class AnnMGTest : public ::testing::TestWithParam> { } // IVF-PQ - for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION}) { + for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION, dist_mode::SHARDING}) { ivf_pq::index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; From ebca042fa6089a83d2998cb38fd25fc49dc01dd1 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 23 Nov 2023 18:45:52 +0100 Subject: [PATCH 05/22] Debugging --- cpp/CMakeLists.txt | 1 - cpp/include/raft/neighbors/ann_mg-ext.cuh | 112 ------------------- cpp/include/raft/neighbors/ann_mg-inl.cuh | 78 ------------- cpp/include/raft/neighbors/ann_mg.cuh | 66 ++++++++++- cpp/include/raft/neighbors/detail/ann_mg.cuh | 70 +++++++----- cpp/src/neighbors/ann_mg.cu | 60 ---------- 6 files changed, 103 insertions(+), 284 deletions(-) delete mode 100644 cpp/include/raft/neighbors/ann_mg-ext.cuh delete mode 100644 cpp/include/raft/neighbors/ann_mg-inl.cuh delete mode 100644 cpp/src/neighbors/ann_mg.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 367ac7bb05..dbdd387511 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -403,7 +403,6 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/refine_float_float.cu src/neighbors/refine_int8_t_float.cu src/neighbors/refine_uint8_t_float.cu - src/neighbors/ann_mg.cu src/raft_runtime/cluster/cluster_cost.cuh src/raft_runtime/cluster/cluster_cost_double.cu src/raft_runtime/cluster/cluster_cost_float.cu diff --git a/cpp/include/raft/neighbors/ann_mg-ext.cuh b/cpp/include/raft/neighbors/ann_mg-ext.cuh deleted file mode 100644 index 23209c95b7..0000000000 --- a/cpp/include/raft/neighbors/ann_mg-ext.cuh +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::mg { -using namespace raft::neighbors::mg; - -template -auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_flat::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, IdxT> RAFT_EXPLICIT; - -template -auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, uint32_t> RAFT_EXPLICIT; - -template -void extend(detail::ann_mg_index, T, IdxT>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) - RAFT_EXPLICIT; - -template -void extend(detail::ann_mg_index, T, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) - RAFT_EXPLICIT; - -template -void search(detail::ann_mg_index, T, IdxT>& index, - const ivf_flat::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) RAFT_EXPLICIT; - -template -void search(detail::ann_mg_index, T, uint32_t>& index, - const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) RAFT_EXPLICIT; - -} // namespace raft::neighbors::mg - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ann_mg_build(T, IdxT) \ - extern template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_flat::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - ->detail::ann_mg_index, T, IdxT>; \ - \ - extern template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_pq::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - ->detail::ann_mg_index, T, uint32_t>; \ - \ - extern template void raft::neighbors::mg::extend( \ - detail::ann_mg_index, T, IdxT> & index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices); \ - \ - extern template void raft::neighbors::mg::extend( \ - detail::ann_mg_index, T, uint32_t> & index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices); \ - \ - extern template void raft::neighbors::mg::search( \ - detail::ann_mg_index, T, IdxT> & index, \ - const ivf_flat::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ - \ - extern template void raft::neighbors::mg::search( \ - detail::ann_mg_index, T, uint32_t> & index, \ - const ivf_pq::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); - -instantiate_raft_neighbors_ann_mg_build(float, uint32_t); - -#undef instantiate_raft_neighbors_ann_mg_build diff --git a/cpp/include/raft/neighbors/ann_mg-inl.cuh b/cpp/include/raft/neighbors/ann_mg-inl.cuh deleted file mode 100644 index b04afd7432..0000000000 --- a/cpp/include/raft/neighbors/ann_mg-inl.cuh +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::neighbors::mg { - -template -auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_flat::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, IdxT> -{ - return mg::detail::build(device_ids, mode, index_params, index_dataset); -} - -template -auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, uint32_t> -{ - return mg::detail::build(device_ids, mode, index_params, index_dataset); -} - -template -void extend(detail::ann_mg_index, T, IdxT>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) -{ - mg::detail::extend(index, new_vectors, new_indices); -} - -template -void extend(detail::ann_mg_index, T, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) -{ - mg::detail::extend(index, new_vectors, new_indices); -} - -template -void search(detail::ann_mg_index, T, IdxT>& index, - const ivf_flat::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) -{ - mg::detail::search(index, search_params, query_dataset, neighbors, distances); -} - -template -void search(detail::ann_mg_index, T, uint32_t>& index, - const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) -{ - mg::detail::search(index, search_params, query_dataset, neighbors, distances); -} -} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ann_mg.cuh b/cpp/include/raft/neighbors/ann_mg.cuh index 2b22be273d..b04afd7432 100644 --- a/cpp/include/raft/neighbors/ann_mg.cuh +++ b/cpp/include/raft/neighbors/ann_mg.cuh @@ -13,12 +13,66 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #pragma once -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "ann_mg-inl.cuh" -#endif +#include + +namespace raft::neighbors::mg { + +template +auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_flat::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(device_ids, mode, index_params, index_dataset); +} + +template +auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const ivf_pq::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, uint32_t> +{ + return mg::detail::build(device_ids, mode, index_params, index_dataset); +} + +template +void extend(detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(index, new_vectors, new_indices); +} + +template +void extend(detail::ann_mg_index, T, uint32_t>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(index, new_vectors, new_indices); +} + +template +void search(detail::ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(index, search_params, query_dataset, neighbors, distances); +} -#ifdef RAFT_COMPILED -#include "ann_mg-ext.cuh" -#endif +template +void search(detail::ann_mg_index, T, uint32_t>& index, + const ivf_pq::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(index, search_params, query_dataset, neighbors, distances); +} +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 906edc083c..5fdd26525a 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -16,20 +16,21 @@ #pragma once -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY - #include #include #include #include #include +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include #include #include - #define RAFT_EXPLICIT_INSTANTIATE_ONLY +// Number of rows per batch (search on shards) +#define N_ROWS_PER_BATCH 33554432 // 2**25 + namespace raft::neighbors::mg { enum dist_mode { SHARDING, INDEX_DUPLICATION }; } @@ -165,11 +166,11 @@ class ann_interface { const auto& comms = resource::get_comms(handle); comms.device_send(neighbors_view.data_handle(), - n_rows * n_neighbors * sizeof(IdxT), + n_rows * n_neighbors, root_rank, resource::get_cuda_stream(handle)); comms.device_send(distances_view.data_handle(), - n_rows * n_neighbors * sizeof(float), + n_rows * n_neighbors, root_rank, resource::get_cuda_stream(handle)); } @@ -357,14 +358,14 @@ class ann_mg_index { IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); - IdxT n_rows_per_batches = 1000000; + IdxT n_rows_per_batches = N_ROWS_PER_BATCH; IdxT n_batches = (n_rows + n_rows_per_batches - 1) / n_rows_per_batches; const auto& root_handle = set_current_device_to_root_rank(); auto in_neighbors = raft::make_device_matrix( - root_handle, n_batches * n_rows_per_batches, n_neighbors); + root_handle, num_ranks_ * n_rows_per_batches, n_neighbors); auto in_distances = raft::make_device_matrix( - root_handle, n_batches * n_rows_per_batches, n_neighbors); + root_handle, num_ranks_ * n_rows_per_batches, n_neighbors); auto out_neighbors = raft::make_device_matrix( root_handle, n_rows_per_batches, n_neighbors); auto out_distances = raft::make_device_matrix( @@ -377,6 +378,7 @@ class ann_mg_index { auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_per_batches, n_cols); + RAFT_NCCL_TRY(ncclGroupStart()); for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); auto& ann_if = ann_interfaces_[rank]; @@ -387,35 +389,49 @@ class ann_mg_index { const auto& comms = resource::get_comms(root_handle); uint64_t batch_offset = rank * n_rows_per_batches * n_neighbors; comms.device_recv(in_neighbors.data_handle() + batch_offset, - n_rows_per_batches * n_neighbors * sizeof(IdxT), + n_rows_per_batches * n_neighbors, rank, resource::get_cuda_stream(root_handle)); comms.device_recv(in_distances.data_handle() + batch_offset, - n_rows_per_batches * n_neighbors * sizeof(float), + n_rows_per_batches * n_neighbors, rank, resource::get_cuda_stream(root_handle)); } + RAFT_NCCL_TRY(ncclGroupEnd()); + + auto in_neighbors_view = raft::make_device_matrix_view( + in_neighbors.data_handle(), num_ranks_ * n_rows_per_batches, n_neighbors); + auto in_distances_view = raft::make_device_matrix_view( + in_distances.data_handle(), num_ranks_ * n_rows_per_batches, n_neighbors); + auto out_neighbors_view = raft::make_device_matrix_view( + out_neighbors.data_handle(), n_rows_per_batches, n_neighbors); + auto out_distances_view = raft::make_device_matrix_view( + out_distances.data_handle(), n_rows_per_batches, n_neighbors); + + const auto& root_handle_ = set_current_device_to_root_rank(); + auto trans = raft::make_device_vector(root_handle_, num_ranks_); + RAFT_CUDA_TRY(cudaMemsetAsync(trans.data_handle(), 0, num_ranks_ * sizeof(IdxT), resource::get_cuda_stream(root_handle_))); + auto translations = std::make_optional>(trans.view()); + raft::neighbors::brute_force::knn_merge_parts(root_handle_, + in_distances_view, + in_neighbors_view, + out_distances_view, + out_neighbors_view, + n_rows_per_batches, + translations); + + raft::copy(neighbors.data_handle() + output_offset, + out_neighbors.data_handle(), + n_rows_per_batches * n_neighbors, + resource::get_cuda_stream(root_handle_)); + raft::copy(distances.data_handle() + output_offset, + out_distances.data_handle(), + n_rows_per_batches * n_neighbors, + resource::get_cuda_stream(root_handle_)); query_offset += n_rows_per_batches * n_cols; output_offset += n_rows_per_batches * n_neighbors; } - - const auto& root_handle_ = set_current_device_to_root_rank(); - raft::neighbors::brute_force::knn_merge_parts(root_handle_, - in_distances.view(), - in_neighbors.view(), - out_distances.view(), - out_neighbors.view(), - n_rows_per_batches); - - raft::copy(neighbors.data_handle() + output_offset, - out_neighbors.data_handle(), - n_rows_per_batches * n_neighbors, - resource::get_cuda_stream(root_handle_)); - raft::copy(distances.data_handle() + output_offset, - out_distances.data_handle(), - n_rows_per_batches * n_neighbors, - resource::get_cuda_stream(root_handle_)); } set_current_device_to_root_rank(); diff --git a/cpp/src/neighbors/ann_mg.cu b/cpp/src/neighbors/ann_mg.cu deleted file mode 100644 index 587feb53e1..0000000000 --- a/cpp/src/neighbors/ann_mg.cu +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#define instantiate_raft_neighbors_ann_mg(T, IdxT) \ - template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_flat::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - ->raft::neighbors::mg::detail::ann_mg_index, T, IdxT>; \ - \ - template auto raft::neighbors::mg::build( \ - const std::vector device_ids, \ - raft::neighbors::mg::dist_mode mode, \ - const ivf_pq::index_params& index_params, \ - raft::host_matrix_view index_dataset) \ - ->raft::neighbors::mg::detail::ann_mg_index, T, uint32_t>; \ - \ - template void raft::neighbors::mg::extend( \ - raft::neighbors::mg::detail::ann_mg_index, T, IdxT> & index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices); \ - \ - template void raft::neighbors::mg::extend( \ - raft::neighbors::mg::detail::ann_mg_index, T, uint32_t> & index, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices); \ - \ - template void raft::neighbors::mg::search( \ - raft::neighbors::mg::detail::ann_mg_index, T, IdxT> & index, \ - const ivf_flat::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ - \ - template void raft::neighbors::mg::search( \ - raft::neighbors::mg::detail::ann_mg_index, T, uint32_t> & index, \ - const ivf_pq::search_params& search_params, \ - raft::host_matrix_view query_dataset, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); - -instantiate_raft_neighbors_ann_mg(float, uint32_t); - -#undef instantiate_raft_neighbors_ann_mg From 8141f72adb8b4ce42cf2ee296ab9b7ff43d162d9 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 10 Jan 2024 15:42:11 +0100 Subject: [PATCH 06/22] ANN benchmark integration + offset fix + translations fix --- cpp/bench/ann/CMakeLists.txt | 14 + cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h | 121 +++++++ cpp/bench/ann/src/raft/raft_benchmark.cu | 36 ++ cpp/include/raft/neighbors/ann_mg.cuh | 34 +- .../raft/neighbors/brute_force-inl.cuh | 2 +- cpp/include/raft/neighbors/detail/ann_mg.cuh | 311 +++++++++++------- cpp/test/neighbors/ann_mg.cuh | 4 + .../run/conf/sift-128-euclidean.json | 13 + 8 files changed, 411 insertions(+), 124 deletions(-) create mode 100644 cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index d6a5fddb98..c186d7a0ee 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -30,6 +30,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG "Include raft's MG ANN in benchmark" ON) option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) option(RAFT_ANN_BENCH_SINGLE_EXE @@ -54,6 +55,7 @@ if(BUILD_CPU_ONLY) set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF) set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF) + set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG OFF) set(RAFT_ANN_BENCH_USE_GGNN OFF) else() # Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled. @@ -88,6 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_CAGRA + OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -250,6 +253,17 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA) ) endif() +if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG) + ConfigureAnnBench( + NAME + RAFT_ANN_MG + PATH + bench/ann/src/raft/raft_benchmark.cu + LINKS + raft::compiled + ) +endif() + set(RAFT_FAISS_TARGETS faiss::faiss) if(TARGET faiss::faiss_avx2) set(RAFT_FAISS_TARGETS faiss::faiss_avx2) diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h new file mode 100644 index 0000000000..9a15e2d2bc --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../common/ann_types.hpp" +#include "raft_ann_bench_utils.h" +#include + +namespace raft::bench::ann { + +template +class RaftAnnMG : public ANN { + public: + using typename ANN::AnnSearchParam; + + struct SearchParam : public AnnSearchParam { + raft::neighbors::ivf_flat::search_params ivf_flat_params; + }; + + using BuildParam = raft::neighbors::ivf_flat::index_params; + + RaftAnnMG(Metric metric, int dim, const BuildParam& param) + : ANN(metric, dim), index_params_(param), dimension_(dim) + { + index_params_.metric = parse_metric_type(metric); + index_params_.conservative_memory_allocation = true; + RAFT_CUDA_TRY(cudaGetDevice(&device_)); + } + + ~RaftAnnMG() noexcept {} + + void build(const T* dataset, size_t nrow, cudaStream_t stream) final; + + void set_search_param(const AnnSearchParam& param) override; + + // TODO: if the number of results is less than k, the remaining elements of 'neighbors' + // will be filled with (size_t)-1 + void search(const T* queries, + int batch_size, + int k, + size_t* neighbors, + float* distances, + cudaStream_t stream = 0) const override; + + // to enable dataset access from GPU memory + AlgoProperty get_preference() const override + { + AlgoProperty property; + property.dataset_memory_type = MemoryType::Host; + property.query_memory_type = MemoryType::Host; + return property; + } + void save(const std::string& file) const override; + void load(const std::string&) override; + + private: + raft::device_resources handle_; + BuildParam index_params_; + raft::neighbors::ivf_flat::search_params search_params_; + std::optional, T, IdxT>> index_; + int device_; + int dimension_; +}; + +template +void RaftAnnMG::build(const T* dataset, size_t nrow, cudaStream_t) +{ + std::vector device_ids{0, 1}; + raft::neighbors::mg::dist_mode d_mode = raft::neighbors::mg::dist_mode::INDEX_DUPLICATION; + auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dimension_)); + index_ = neighbors::mg::build(device_ids, d_mode, index_params_, dataset_matrix); + return; +} + +template +void RaftAnnMG::set_search_param(const AnnSearchParam& param) +{ + auto search_param = dynamic_cast(param); + search_params_ = search_param.ivf_flat_params; + assert(search_params_.n_probes <= index_params_.n_lists); +} + +template +void RaftAnnMG::save(const std::string& file) const +{ + raft::neighbors::mg::serialize(handle_, index_.value(), file); + return; +} + +template +void RaftAnnMG::load(const std::string& file) +{ + index_.emplace(raft::neighbors::mg::deserialize(handle_, file)); +} + +template +void RaftAnnMG::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const +{ + static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); + auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(dimension_)); + auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + raft::neighbors::mg::search(index_.value(), search_params_, query_matrix, neighbors_matrix, distances_matrix); + resource::sync_stream(handle_); + return; +} +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 6888340b4d..b5ce857af3 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -47,6 +48,9 @@ extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG +#include "raft_ann_mg_wrapper.h" +#endif #define JSON_DIAGNOSTICS 1 #include @@ -182,6 +186,24 @@ void parse_search_param(const nlohmann::json& conf, } #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftAnnMG::BuildParam& param) +{ + param.n_lists = conf.at("nlist"); + if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } + if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftAnnMG::SearchParam& param) +{ + param.ivf_flat_params.n_probes = conf.at("nprobe"); +} +#endif + template std::unique_ptr> create_algo(const std::string& algo, const std::string& distance, @@ -223,6 +245,13 @@ std::unique_ptr> create_algo(const std::string& algo, parse_build_param(conf, param); ann = std::make_unique>(metric, dim, param); } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG + if (algo == "raft_ann_mg") { + typename raft::bench::ann::RaftAnnMG::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } #endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } @@ -260,6 +289,13 @@ std::unique_ptr::AnnSearchParam> create_search parse_search_param(conf, *param); return param; } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG + if (algo == "raft_ann_mg") { + auto param = std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } #endif // else throw std::runtime_error("invalid algo: '" + algo + "'"); diff --git a/cpp/include/raft/neighbors/ann_mg.cuh b/cpp/include/raft/neighbors/ann_mg.cuh index b04afd7432..537a44de83 100644 --- a/cpp/include/raft/neighbors/ann_mg.cuh +++ b/cpp/include/raft/neighbors/ann_mg.cuh @@ -57,7 +57,7 @@ void extend(detail::ann_mg_index, T, uint32_t>& index, } template -void search(detail::ann_mg_index, T, IdxT>& index, +void search(const detail::ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, @@ -67,7 +67,7 @@ void search(detail::ann_mg_index, T, IdxT>& index, } template -void search(detail::ann_mg_index, T, uint32_t>& index, +void search(const detail::ann_mg_index, T, uint32_t>& index, const ivf_pq::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, @@ -75,4 +75,34 @@ void search(detail::ann_mg_index, T, uint32_t>& index, { mg::detail::search(index, search_params, query_dataset, neighbors, distances); } + +template +void serialize(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, index, filename); +} + +template +void serialize(const raft::resources& handle, + const detail::ann_mg_index, T, uint32_t>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, index, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize(const raft::resources& handle, + const std::string& filename) +{ + return mg::detail::deserialize(handle, filename); +} + +template +detail::ann_mg_index, T, uint32_t> deserialize(const raft::resources& handle, + const std::string& filename) +{ + return mg::detail::deserialize(handle, filename); +} } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index 52a40da272..afff8095c0 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -91,7 +91,7 @@ inline void knn_merge_parts( RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), "in_keys and in_values must have the same shape."); RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples, + out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == idx_t(n_samples), "Number of rows in output keys and val matrices must equal number of rows in search matrix."); RAFT_EXPECTS( out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1), diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 5fdd26525a..d0a4ec22ac 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -45,14 +45,18 @@ class ann_interface { const ann::index_params* index_params, raft::host_matrix_view h_index_dataset) { - auto index_dataset_view = store_to_device(handle, index_dataset_, h_index_dataset); + IdxT n_rows = h_index_dataset.extent(0); + IdxT n_dims = h_index_dataset.extent(1); + auto d_index_dataset = raft::make_device_matrix(handle, n_rows, n_dims); + raft::copy(d_index_dataset.data_handle(), h_index_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); + raft::device_matrix_view d_index_dataset_view = raft::make_device_matrix_view(d_index_dataset.data_handle(), n_rows, n_dims); if constexpr (std::is_same>::value) { index_.emplace(std::move(ivf_flat::build( - handle, *static_cast(index_params), index_dataset_view))); + handle, *static_cast(index_params), d_index_dataset_view))); } else if constexpr (std::is_same>::value) { index_.emplace(std::move(ivf_pq::build( - handle, *static_cast(index_params), index_dataset_view))); + handle, *static_cast(index_params), d_index_dataset_view))); } } @@ -60,23 +64,26 @@ class ann_interface { raft::host_matrix_view h_new_vectors, std::optional> h_new_indices) { - index_dataset_.reset(); - new_vectors_.reset(); - new_indices_.reset(); - - auto new_vectors_view = store_to_device(handle, new_vectors_, h_new_vectors); + IdxT n_rows = h_new_vectors.extent(0); + IdxT n_dims = h_new_vectors.extent(1); + auto d_new_vectors = raft::make_device_matrix(handle, n_rows, n_dims); + raft::copy(d_new_vectors.data_handle(), h_new_vectors.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); + raft::device_matrix_view d_new_vectors_view = raft::make_device_matrix_view(d_new_vectors.data_handle(), n_rows, n_dims); std::optional> new_indices_opt = std::nullopt; if (h_new_indices) { - new_indices_opt = store_to_device(handle, new_indices_, h_new_indices.value()); + auto d_new_indices = raft::make_device_vector(handle, n_rows); + raft::copy(d_new_indices.data_handle(), h_new_indices.value().data_handle(), n_rows, resource::get_cuda_stream(handle)); + auto d_new_indices_view = raft::device_vector_view(d_new_indices.data_handle(), n_rows); + new_indices_opt = std::move(d_new_indices_view); } if constexpr (std::is_same>::value) { index_.emplace(std::move( - ivf_flat::extend(handle, new_vectors_view, new_indices_opt, index_.value()))); + ivf_flat::extend(handle, d_new_vectors_view, new_indices_opt, index_.value()))); } else if constexpr (std::is_same>::value) { index_.emplace(std::move( - ivf_pq::extend(handle, new_vectors_view, new_indices_opt, index_.value()))); + ivf_pq::extend(handle, d_new_vectors_view, new_indices_opt, index_.value()))); } } @@ -84,7 +91,7 @@ class ann_interface { const ann::search_params* search_params, raft::device_matrix_view query_dataset, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances) const { if constexpr (std::is_same>::value) { ivf_flat::search(handle, @@ -108,33 +115,27 @@ class ann_interface { const ann::search_params* search_params, raft::host_matrix_view h_query_dataset, raft::host_matrix_view h_neighbors, - raft::host_matrix_view h_distances) + raft::host_matrix_view h_distances) const { - index_dataset_.reset(); - new_vectors_.reset(); - new_indices_.reset(); - query_dataset_.reset(); - IdxT n_rows = h_query_dataset.extent(0); + IdxT n_dims = h_query_dataset.extent(1); IdxT n_neighbors = h_neighbors.extent(1); - auto query_dataset_view = store_to_device(handle, query_dataset_, h_query_dataset); - auto neighbors_view = neighbors_ - .emplace(std::move(raft::make_device_matrix( - handle, n_rows, n_neighbors))) - .view(); - auto distances_view = distances_ - .emplace(std::move(raft::make_device_matrix( - handle, n_rows, n_neighbors))) - .view(); - - search_impl(handle, search_params, query_dataset_view, neighbors_view, distances_view); + + auto d_query = raft::make_device_matrix(handle, n_rows, n_dims); + raft::copy(d_query.data_handle(), h_query_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); + raft::device_matrix_view d_query_view = raft::make_device_matrix_view(d_query.data_handle(), n_rows, n_dims); + + auto d_neighbors = raft::make_device_matrix(handle, n_rows, n_neighbors); + auto d_distances = raft::make_device_matrix(handle, n_rows, n_neighbors); + + search_impl(handle, search_params, d_query_view, d_neighbors.view(), d_distances.view()); raft::copy(h_neighbors.data_handle(), - neighbors_view.data_handle(), + d_neighbors.data_handle(), n_rows * n_neighbors, resource::get_cuda_stream(handle)); raft::copy(h_distances.data_handle(), - distances_view.data_handle(), + d_distances.data_handle(), n_rows * n_neighbors, resource::get_cuda_stream(handle)); } @@ -144,93 +145,67 @@ class ann_interface { const ann::search_params* search_params, raft::host_matrix_view h_query_dataset, IdxT n_neighbors, - int root_rank) + int root_rank) const { - index_dataset_.reset(); - new_vectors_.reset(); - new_indices_.reset(); - query_dataset_.reset(); - IdxT n_rows = h_query_dataset.extent(0); - auto query_dataset_view = store_to_device(handle, query_dataset_, h_query_dataset); - auto neighbors_view = neighbors_ - .emplace(std::move(raft::make_device_matrix( - handle, n_rows, n_neighbors))) - .view(); - auto distances_view = distances_ - .emplace(std::move(raft::make_device_matrix( - handle, n_rows, n_neighbors))) - .view(); - - search_impl(handle, search_params, query_dataset_view, neighbors_view, distances_view); + IdxT n_dims = h_query_dataset.extent(1); + + auto d_query = raft::make_device_matrix(handle, n_rows, n_dims); + raft::copy(d_query.data_handle(), h_query_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); + raft::device_matrix_view d_query_view = raft::make_device_matrix_view(d_query.data_handle(), n_rows, n_dims); + + auto d_neighbors = raft::make_device_matrix(handle, n_rows, n_neighbors); + auto d_distances = raft::make_device_matrix(handle, n_rows, n_neighbors); + + search_impl(handle, search_params, d_query_view, d_neighbors.view(), d_distances.view()); const auto& comms = resource::get_comms(handle); - comms.device_send(neighbors_view.data_handle(), + comms.device_send(d_neighbors.data_handle(), n_rows * n_neighbors, root_rank, resource::get_cuda_stream(handle)); - comms.device_send(distances_view.data_handle(), + comms.device_send(d_distances.data_handle(), n_rows * n_neighbors, root_rank, resource::get_cuda_stream(handle)); } - private: - template - raft::device_matrix_view store_to_device( - raft::resources const& handle, - std::optional>& dev_mat_opt, - raft::host_matrix_view host_mat_view) + void serialize(raft::resources const& handle, + std::ostream& os) const + { + if constexpr (std::is_same>::value) { + ivf_flat::serialize(handle, os, index_.value()); + } else if constexpr (std::is_same>::value) { + ivf_pq::serialize(handle, os, index_.value()); + } + } + + void deserialize(raft::resources const& handle, + std::istream& is) { - DataIdxT n_rows = host_mat_view.extent(0); - DataIdxT n_cols = host_mat_view.extent(1); - dev_mat_opt.emplace( - std::move(raft::make_device_matrix(handle, n_rows, n_cols))); - raft::copy(dev_mat_opt.value().data_handle(), // async copy - host_mat_view.data_handle(), - n_rows * n_cols, - resource::get_cuda_stream(handle)); - auto const_dev_mat_view = dev_mat_opt.value().view(); - raft::device_matrix_view dev_mat_view = - raft::make_device_matrix_view( - const_dev_mat_view.data_handle(), - const_dev_mat_view.extent(0), - const_dev_mat_view.extent(1)); - return dev_mat_view; + if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_flat::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_pq::deserialize(handle, is))); + } } - template - raft::device_vector_view store_to_device( - raft::resources const& handle, - std::optional>& dev_vec_opt, - raft::host_vector_view host_vec_view) + const IdxT size() const { - DataIdxT n_rows = host_vec_view.extent(0); - dev_vec_opt.emplace(std::move(raft::make_device_vector(handle, n_rows))); - raft::copy(dev_vec_opt.value().data_handle(), // async copy - host_vec_view.data_handle(), - n_rows, - resource::get_cuda_stream(handle)); - auto const_dev_vec_view = dev_vec_opt.value().view(); - raft::device_vector_view dev_vec_view = - raft::make_device_vector_view(const_dev_vec_view.data_handle(), - const_dev_vec_view.extent(0)); - return dev_vec_view; + if constexpr (std::is_same>::value) { + return index_.value().size(); + } else if constexpr (std::is_same>::value) { + return index_.value().size(); + } } - std::optional> index_dataset_; - std::optional> new_vectors_; - std::optional> new_indices_; - std::optional> query_dataset_; - std::optional> neighbors_; - std::optional> distances_; + private: std::optional index_; }; template class ann_mg_index { public: - ann_mg_index() = delete; ann_mg_index(const std::vector& dev_list, dist_mode mode = SHARDING) : mode_(mode), root_rank_(0), @@ -238,28 +213,62 @@ class ann_mg_index { dev_ids_(dev_list), nccl_comms_(dev_list.size()) { - RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data())); + init_device_resources(); + init_nccl_clique(); + } + + // deserialization + ann_mg_index(const raft::resources& handle, + const std::string& filename) { + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + mode_ = (raft::neighbors::mg::dist_mode)deserialize_scalar(handle, is); + root_rank_ = 0; + num_ranks_ = deserialize_scalar(handle, is); + dev_ids_.resize(num_ranks_); + std::iota(std::begin(dev_ids_), std::end(dev_ids_), 0); + nccl_comms_.resize(num_ranks_); + + init_device_resources(); + init_nccl_clique(); + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.deserialize(dev_resources_[rank], is); + } + + is.close(); + } + + ann_mg_index(const ann_mg_index&) = delete; + ann_mg_index(ann_mg_index&&) = default; + auto operator=(const ann_mg_index&) -> ann_mg_index& = delete; + auto operator=(ann_mg_index&&) -> ann_mg_index& = default; + + void init_device_resources() { for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + dev_resources_.emplace_back(); + } + } - raft::device_resources& handle = dev_resources_.emplace_back(); - raft::comms::build_comms_nccl_only(&handle, nccl_comms_[rank], num_ranks_, rank); + void init_nccl_clique() { + RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data())); + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + raft::comms::build_comms_nccl_only(&dev_resources_[rank], nccl_comms_[rank], num_ranks_, rank); } } - ~ann_mg_index() - { + void destroy_nccl_clique() { for (int rank = 0; rank < num_ranks_; rank++) { cudaSetDevice(dev_ids_[rank]); ncclCommDestroy(nccl_comms_[rank]); } } - ann_mg_index(const ann_mg_index&) = delete; - ann_mg_index(ann_mg_index&&) = default; - auto operator=(const ann_mg_index&) -> ann_mg_index& = delete; - auto operator=(ann_mg_index&&) -> ann_mg_index& = default; - void build(const ann::index_params* index_params, raft::host_matrix_view index_dataset) { @@ -277,12 +286,12 @@ class ann_mg_index { for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + offset; + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_per_shard, n_cols); auto& ann_if = ann_interfaces_.emplace_back(); ann_if.build(dev_resources_[rank], index_params, partition); - offset += n_rows_per_shard * n_cols; + offset += n_rows_per_shard; } } set_current_device_to_root_rank(); @@ -305,7 +314,7 @@ class ann_mg_index { for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + offset; + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_per_shard, n_cols); @@ -317,7 +326,7 @@ class ann_mg_index { } auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_resources_[rank], new_vectors_part, new_indices_part); - offset += n_rows_per_shard * n_cols; + offset += n_rows_per_shard; } } set_current_device_to_root_rank(); @@ -326,32 +335,36 @@ class ann_mg_index { void search(const ann::search_params* search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view distances) const { if (mode_ == INDEX_DUPLICATION) { IdxT n_rows = query_dataset.extent(0); IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + + IdxT offset = 0; IdxT query_offset = 0; IdxT output_offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - query_offset); + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_per_shard, n_cols); auto neighbors_partition = raft::make_host_matrix_view( neighbors.data_handle() + output_offset, n_rows_per_shard, n_neighbors); auto distances_partition = raft::make_host_matrix_view( distances.data_handle() + output_offset, n_rows_per_shard, n_neighbors); + auto& ann_if = ann_interfaces_[rank]; ann_if.search(dev_resources_[rank], search_params, query_partition, neighbors_partition, distances_partition); - query_offset += n_rows_per_shard * n_cols; - output_offset += n_rows_per_shard * n_neighbors; + offset += n_rows_per_shard; + query_offset = offset * n_cols; + output_offset = offset * n_neighbors; } } else if (mode_ == SHARDING) { IdxT n_rows = query_dataset.extent(0); @@ -371,10 +384,11 @@ class ann_mg_index { auto out_distances = raft::make_device_matrix( root_handle, n_rows_per_batches, n_neighbors); + IdxT offset = 0; IdxT query_offset = 0; IdxT output_offset = 0; for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - n_rows_per_batches = std::min(n_rows_per_batches, n_rows - query_offset); + n_rows_per_batches = std::min(n_rows_per_batches, n_rows - offset); auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_per_batches, n_cols); @@ -409,9 +423,15 @@ class ann_mg_index { out_distances.data_handle(), n_rows_per_batches, n_neighbors); const auto& root_handle_ = set_current_device_to_root_rank(); - auto trans = raft::make_device_vector(root_handle_, num_ranks_); - RAFT_CUDA_TRY(cudaMemsetAsync(trans.data_handle(), 0, num_ranks_ * sizeof(IdxT), resource::get_cuda_stream(root_handle_))); - auto translations = std::make_optional>(trans.view()); + auto h_trans = std::vector(num_ranks_); + IdxT translation_offset = 0; + for (int rank = 0; rank < num_ranks_; rank++) { + h_trans[rank] = translation_offset; + translation_offset += ann_interfaces_[rank].size(); + } + auto d_trans = raft::make_device_vector(root_handle_, num_ranks_); + raft::copy(d_trans.data_handle(), h_trans.data(), num_ranks_, resource::get_cuda_stream(root_handle_)); + auto translations = std::make_optional>(d_trans.view()); raft::neighbors::brute_force::knn_merge_parts(root_handle_, in_distances_view, in_neighbors_view, @@ -429,15 +449,34 @@ class ann_mg_index { n_rows_per_batches * n_neighbors, resource::get_cuda_stream(root_handle_)); - query_offset += n_rows_per_batches * n_cols; - output_offset += n_rows_per_batches * n_neighbors; + offset += n_rows_per_batches; + query_offset = offset * n_cols; + output_offset = offset * n_neighbors; } } set_current_device_to_root_rank(); } - inline const raft::device_resources& set_current_device_to_root_rank() + void serialize(raft::resources const& handle, + const std::string& filename) const + { + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + serialize_scalar(handle, of, (int)mode_); + serialize_scalar(handle, of, num_ranks_); + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + auto& ann_if = ann_interfaces_[rank]; + ann_if.serialize(dev_resources_[rank], of); + } + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } + } + + inline const raft::device_resources& set_current_device_to_root_rank() const { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[root_rank_])); return dev_resources_[root_rank_]; @@ -494,7 +533,7 @@ void extend(ann_mg_index, T, uint32_t>& index, } template -void search(ann_mg_index, T, IdxT>& index, +void search(const ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, @@ -505,7 +544,7 @@ void search(ann_mg_index, T, IdxT>& index, } template -void search(ann_mg_index, T, uint32_t>& index, +void search(const ann_mg_index, T, uint32_t>& index, const ivf_pq::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, @@ -515,4 +554,34 @@ void search(ann_mg_index, T, uint32_t>& index, static_cast(&search_params), query_dataset, neighbors, distances); } +template +void serialize(const raft::resources& handle, + const ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + index.serialize(handle, filename); +} + +template +void serialize(const raft::resources& handle, + const ann_mg_index, T, uint32_t>& index, + const std::string& filename) +{ + index.serialize(handle, filename); +} + +template +ann_mg_index, T, IdxT> deserialize(const raft::resources& handle, + const std::string& filename) +{ + return ann_mg_index, T, IdxT>(handle, filename); +} + +template +ann_mg_index, T, uint32_t> deserialize(const raft::resources& handle, + const std::string& filename) +{ + return ann_mg_index, T, uint32_t>(handle, filename); +} + } // namespace raft::neighbors::mg::detail \ No newline at end of file diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index a866142761..ed868403b5 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -113,6 +113,8 @@ class AnnMGTest : public ::testing::TestWithParam> { ps.k, 0.001, min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); } // IVF-PQ @@ -151,6 +153,8 @@ class AnnMGTest : public ::testing::TestWithParam> { ps.k, 0.001, min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); } } diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json index 791261251a..d4bd4dda15 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json @@ -472,6 +472,19 @@ {"nprobe": 2000} ] }, + { + "name": "raft_ann_mg.nlist16384", + "algo": "raft_ann_mg", + "build_param": {"nlist": 16384, "ratio": 2, "niter": 20}, + "file": "sift-128-euclidean/raft_ann_mg/nlist16384", + "dataset_memory_type": "host", + "query_memory_type": "host", + "search_params": [ + {"nprobe": 100}, + {"nprobe": 200}, + {"nprobe": 500} + ] + }, { "name": "raft_cagra.dim32", "algo": "raft_cagra", From 6d29eb10b8dce7b312ea5e8908ac959c3040571a Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 26 Jan 2024 19:04:33 +0100 Subject: [PATCH 07/22] Adding CAGRA capability --- cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h | 2 +- cpp/include/raft/neighbors/ann_mg.cuh | 45 +++++++++++-- cpp/include/raft/neighbors/detail/ann_mg.cuh | 68 ++++++++++++++++++-- cpp/test/neighbors/ann_mg.cuh | 41 ++++++++++++ 4 files changed, 145 insertions(+), 11 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h index 9a15e2d2bc..02d77b3400 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h @@ -103,7 +103,7 @@ void RaftAnnMG::save(const std::string& file) const template void RaftAnnMG::load(const std::string& file) { - index_.emplace(raft::neighbors::mg::deserialize(handle_, file)); + index_.emplace(raft::neighbors::mg::deserialize_flat(handle_, file)); } template diff --git a/cpp/include/raft/neighbors/ann_mg.cuh b/cpp/include/raft/neighbors/ann_mg.cuh index 537a44de83..32df67cf62 100644 --- a/cpp/include/raft/neighbors/ann_mg.cuh +++ b/cpp/include/raft/neighbors/ann_mg.cuh @@ -40,6 +40,16 @@ auto build(const std::vector device_ids, return mg::detail::build(device_ids, mode, index_params, index_dataset); } +template +auto build(const std::vector device_ids, + raft::neighbors::mg::dist_mode mode, + const cagra::index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(device_ids, mode, index_params, index_dataset); +} + template void extend(detail::ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, @@ -76,6 +86,16 @@ void search(const detail::ann_mg_index, T, uint32_t>& in mg::detail::search(index, search_params, query_dataset, neighbors, distances); } +template +void search(const detail::ann_mg_index, T, IdxT>& index, + const cagra::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(index, search_params, query_dataset, neighbors, distances); +} + template void serialize(const raft::resources& handle, const detail::ann_mg_index, T, IdxT>& index, @@ -93,16 +113,31 @@ void serialize(const raft::resources& handle, } template -detail::ann_mg_index, T, IdxT> deserialize(const raft::resources& handle, - const std::string& filename) +void serialize(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) { - return mg::detail::deserialize(handle, filename); + mg::detail::serialize(handle, index, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, + const std::string& filename) +{ + return mg::detail::deserialize_flat(handle, filename); } template -detail::ann_mg_index, T, uint32_t> deserialize(const raft::resources& handle, +detail::ann_mg_index, T, uint32_t> deserialize_pq(const raft::resources& handle, + const std::string& filename) +{ + return mg::detail::deserialize_pq(handle, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, const std::string& filename) { - return mg::detail::deserialize(handle, filename); + return mg::detail::deserialize_cagra(handle, filename); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index d0a4ec22ac..126d8132ec 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -26,6 +26,8 @@ #include #include #include +#include +#include #define RAFT_EXPLICIT_INSTANTIATE_ONLY // Number of rows per batch (search on shards) @@ -57,6 +59,9 @@ class ann_interface { } else if constexpr (std::is_same>::value) { index_.emplace(std::move(ivf_pq::build( handle, *static_cast(index_params), d_index_dataset_view))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(cagra::build( + handle, *static_cast(index_params), d_index_dataset_view))); } } @@ -84,6 +89,8 @@ class ann_interface { } else if constexpr (std::is_same>::value) { index_.emplace(std::move( ivf_pq::extend(handle, d_new_vectors_view, new_indices_opt, index_.value()))); + } else if constexpr (std::is_same>::value) { + RAFT_FAIL("CAGRA does not implement the extend method"); } } @@ -107,7 +114,14 @@ class ann_interface { query_dataset, neighbors, distances); - } + } else if constexpr (std::is_same>::value) { + cagra::search(handle, + *reinterpret_cast(search_params), + index_.value(), + query_dataset, + neighbors, + distances); + } } // Index duplication, results stored on host memory without merge @@ -177,6 +191,8 @@ class ann_interface { ivf_flat::serialize(handle, os, index_.value()); } else if constexpr (std::is_same>::value) { ivf_pq::serialize(handle, os, index_.value()); + } else if constexpr (std::is_same>::value) { + cagra::serialize(handle, os, index_.value()); } } @@ -187,6 +203,8 @@ class ann_interface { index_.emplace(std::move(ivf_flat::deserialize(handle, is))); } else if constexpr (std::is_same>::value) { index_.emplace(std::move(ivf_pq::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(cagra::deserialize(handle, is))); } } @@ -196,6 +214,8 @@ class ann_interface { return index_.value().size(); } else if constexpr (std::is_same>::value) { return index_.value().size(); + } else if constexpr (std::is_same>::value) { + return index_.value().size(); } } @@ -516,6 +536,18 @@ ann_mg_index, T, uint32_t> build( return index; } +template +ann_mg_index, T, IdxT> build( + const std::vector device_ids, + dist_mode mode, + const cagra::index_params& index_params, + raft::host_matrix_view index_dataset) +{ + ann_mg_index, T, IdxT> index(device_ids, mode); + index.build(static_cast(&index_params), index_dataset); + return index; +} + template void extend(ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, @@ -554,6 +586,17 @@ void search(const ann_mg_index, T, uint32_t>& index, static_cast(&search_params), query_dataset, neighbors, distances); } +template +void search(const ann_mg_index, T, IdxT>& index, + const cagra::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + index.search( + static_cast(&search_params), query_dataset, neighbors, distances); +} + template void serialize(const raft::resources& handle, const ann_mg_index, T, IdxT>& index, @@ -571,17 +614,32 @@ void serialize(const raft::resources& handle, } template -ann_mg_index, T, IdxT> deserialize(const raft::resources& handle, - const std::string& filename) +void serialize(const raft::resources& handle, + const ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + index.serialize(handle, filename); +} + +template +ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, + const std::string& filename) { return ann_mg_index, T, IdxT>(handle, filename); } template -ann_mg_index, T, uint32_t> deserialize(const raft::resources& handle, - const std::string& filename) +ann_mg_index, T, uint32_t> deserialize_pq(const raft::resources& handle, + const std::string& filename) { return ann_mg_index, T, uint32_t>(handle, filename); } +template +ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, + const std::string& filename) +{ + return ann_mg_index, T, IdxT>(handle, filename); +} + } // namespace raft::neighbors::mg::detail \ No newline at end of file diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index ed868403b5..66edc81b9a 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -54,6 +54,7 @@ class AnnMGTest : public ::testing::TestWithParam> { std::vector distances_naive(queries_size); std::vector indices_ann(queries_size); std::vector distances_ann(queries_size); + std::vector indices_ann_int64(queries_size); { rmm::device_uvector distances_naive_dev(queries_size, stream_); @@ -156,6 +157,46 @@ class AnnMGTest : public ::testing::TestWithParam> { std::fill(indices_ann.begin(), indices_ann.end(), 0); std::fill(distances_ann.begin(), distances_ann.end(), 0); } + + // CAGRA + for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION, dist_mode::SHARDING}) { + cagra::index_params index_params; + index_params.intermediate_graph_degree = 128; + index_params.graph_degree = 64; + index_params.build_algo = cagra::graph_build_algo::IVF_PQ; + index_params.nn_descent_niter = 20; + + cagra::search_params search_params; + + auto index_dataset = raft::make_host_matrix_view( + h_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto query_dataset = raft::make_host_matrix_view( + h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view( + indices_ann_int64.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view( + distances_ann.data(), ps.num_queries, ps.k); + + auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); + resource::sync_stream(handle_); + + std::transform(indices_ann_int64.begin(), indices_ann_int64.end(), + indices_ann.begin(), [](int x) { return (int64_t)x;}); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(indices_ann_int64.begin(), indices_ann_int64.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } } void SetUp() override From 34bb6c4920dac880cf9f6894da22e2929892390f Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 1 Feb 2024 18:57:13 +0100 Subject: [PATCH 08/22] Testing serialization + use of pre-computed methods --- cpp/include/raft/neighbors/ann_mg.cuh | 36 +++++----- cpp/include/raft/neighbors/detail/ann_mg.cuh | 43 ++++++------ cpp/test/neighbors/ann_mg.cuh | 70 ++++++++++++-------- cpp/test/neighbors/ann_mg/test_ann_mg.cu | 2 +- 4 files changed, 82 insertions(+), 69 deletions(-) diff --git a/cpp/include/raft/neighbors/ann_mg.cuh b/cpp/include/raft/neighbors/ann_mg.cuh index 32df67cf62..9f9b37e22c 100644 --- a/cpp/include/raft/neighbors/ann_mg.cuh +++ b/cpp/include/raft/neighbors/ann_mg.cuh @@ -30,12 +30,12 @@ auto build(const std::vector device_ids, return mg::detail::build(device_ids, mode, index_params, index_dataset); } -template +template auto build(const std::vector device_ids, raft::neighbors::mg::dist_mode mode, const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, uint32_t> + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> { return mg::detail::build(device_ids, mode, index_params, index_dataset); } @@ -58,10 +58,10 @@ void extend(detail::ann_mg_index, T, IdxT>& index, mg::detail::extend(index, new_vectors, new_indices); } -template -void extend(detail::ann_mg_index, T, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) +template +void extend(detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) { mg::detail::extend(index, new_vectors, new_indices); } @@ -76,12 +76,12 @@ void search(const detail::ann_mg_index, T, IdxT>& index mg::detail::search(index, search_params, query_dataset, neighbors, distances); } -template -void search(const detail::ann_mg_index, T, uint32_t>& index, +template +void search(const detail::ann_mg_index, T, IdxT>& index, const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) { mg::detail::search(index, search_params, query_dataset, neighbors, distances); } @@ -104,9 +104,9 @@ void serialize(const raft::resources& handle, mg::detail::serialize(handle, index, filename); } -template +template void serialize(const raft::resources& handle, - const detail::ann_mg_index, T, uint32_t>& index, + const detail::ann_mg_index, T, IdxT>& index, const std::string& filename) { mg::detail::serialize(handle, index, filename); @@ -127,11 +127,11 @@ detail::ann_mg_index, T, IdxT> deserialize_flat(const r return mg::detail::deserialize_flat(handle, filename); } -template -detail::ann_mg_index, T, uint32_t> deserialize_pq(const raft::resources& handle, - const std::string& filename) +template +detail::ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, + const std::string& filename) { - return mg::detail::deserialize_pq(handle, filename); + return mg::detail::deserialize_pq(handle, filename); } template diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 126d8132ec..ff6f35affc 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #define RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -190,7 +191,7 @@ class ann_interface { if constexpr (std::is_same>::value) { ivf_flat::serialize(handle, os, index_.value()); } else if constexpr (std::is_same>::value) { - ivf_pq::serialize(handle, os, index_.value()); + ivf_pq::serialize(handle, os, index_.value()); } else if constexpr (std::is_same>::value) { cagra::serialize(handle, os, index_.value()); } @@ -202,7 +203,7 @@ class ann_interface { if constexpr (std::is_same>::value) { index_.emplace(std::move(ivf_flat::deserialize(handle, is))); } else if constexpr (std::is_same>::value) { - index_.emplace(std::move(ivf_pq::deserialize(handle, is))); + index_.emplace(std::move(ivf_pq::deserialize(handle, is))); } else if constexpr (std::is_same>::value) { index_.emplace(std::move(cagra::deserialize(handle, is))); } @@ -524,14 +525,14 @@ ann_mg_index, T, IdxT> build( return index; } -template -ann_mg_index, T, uint32_t> build( +template +ann_mg_index, T, IdxT> build( const std::vector device_ids, dist_mode mode, const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) + raft::host_matrix_view index_dataset) { - ann_mg_index, T, uint32_t> index(device_ids, mode); + ann_mg_index, T, IdxT> index(device_ids, mode); index.build(static_cast(&index_params), index_dataset); return index; } @@ -556,10 +557,10 @@ void extend(ann_mg_index, T, IdxT>& index, index.extend(new_vectors, new_indices); } -template -void extend(ann_mg_index, T, uint32_t>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) +template +void extend(ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) { index.extend(new_vectors, new_indices); } @@ -575,12 +576,12 @@ void search(const ann_mg_index, T, IdxT>& index, static_cast(&search_params), query_dataset, neighbors, distances); } -template -void search(const ann_mg_index, T, uint32_t>& index, +template +void search(const ann_mg_index, T, IdxT>& index, const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) { index.search( static_cast(&search_params), query_dataset, neighbors, distances); @@ -605,9 +606,9 @@ void serialize(const raft::resources& handle, index.serialize(handle, filename); } -template +template void serialize(const raft::resources& handle, - const ann_mg_index, T, uint32_t>& index, + const ann_mg_index, T, IdxT>& index, const std::string& filename) { index.serialize(handle, filename); @@ -628,11 +629,11 @@ ann_mg_index, T, IdxT> deserialize_flat(const raft::res return ann_mg_index, T, IdxT>(handle, filename); } -template -ann_mg_index, T, uint32_t> deserialize_pq(const raft::resources& handle, - const std::string& filename) +template +ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, + const std::string& filename) { - return ann_mg_index, T, uint32_t>(handle, filename); + return ann_mg_index, T, IdxT>(handle, filename); } template diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index 66edc81b9a..d1d4b93a10 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -51,10 +51,9 @@ class AnnMGTest : public ::testing::TestWithParam> { { size_t queries_size = ps.num_queries * ps.k; std::vector indices_naive(queries_size); - std::vector distances_naive(queries_size); + std::vector distances_naive(queries_size); std::vector indices_ann(queries_size); - std::vector distances_ann(queries_size); - std::vector indices_ann_int64(queries_size); + std::vector distances_ann(queries_size); { rmm::device_uvector distances_naive_dev(queries_size, stream_); @@ -98,11 +97,13 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); - auto index = - raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::extend(index, index_dataset, std::nullopt); - raft::neighbors::mg::search( - index, search_params, query_dataset, neighbors, distances); + { + auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::extend(index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, index, "ann_mg_ivf_flat_index"); + } + auto new_index = raft::neighbors::mg::deserialize_flat(handle_, "ann_mg_ivf_flat_index"); + raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -130,19 +131,22 @@ class AnnMGTest : public ::testing::TestWithParam> { ivf_pq::search_params search_params; search_params.n_probes = ps.nprobe; - auto index_dataset = raft::make_host_matrix_view( + auto index_dataset = raft::make_host_matrix_view( h_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto query_dataset = raft::make_host_matrix_view( + auto query_dataset = raft::make_host_matrix_view( h_query_dataset.data(), ps.num_queries, ps.dim); - auto neighbors = raft::make_host_matrix_view( + auto neighbors = raft::make_host_matrix_view( indices_ann.data(), ps.num_queries, ps.k); - auto distances = raft::make_host_matrix_view( + auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); - auto index = - raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::extend(index, index_dataset, std::nullopt); - raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); + { + auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::extend(index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, index, "ann_mg_ivf_pq_index"); + } + auto new_index = raft::neighbors::mg::deserialize_pq(handle_, "ann_mg_ivf_pq_index"); + raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -168,21 +172,30 @@ class AnnMGTest : public ::testing::TestWithParam> { cagra::search_params search_params; - auto index_dataset = raft::make_host_matrix_view( + auto index_dataset = raft::make_host_matrix_view( h_index_dataset.data(), ps.num_db_vecs, ps.dim); - auto query_dataset = raft::make_host_matrix_view( + auto query_dataset = raft::make_host_matrix_view( h_query_dataset.data(), ps.num_queries, ps.dim); - auto neighbors = raft::make_host_matrix_view( - indices_ann_int64.data(), ps.num_queries, ps.k); - auto distances = raft::make_host_matrix_view( + auto neighbors = raft::make_host_matrix_view( + indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); - auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); - resource::sync_stream(handle_); + /* + TODO : fix CAGRA serialization issue - std::transform(indices_ann_int64.begin(), indices_ann_int64.end(), - indices_ann.begin(), [](int x) { return (int64_t)x;}); + { + auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::serialize(handle_, index, "ann_mg_cagra_index"); + } + auto new_index = raft::neighbors::mg::deserialize_cagra(handle_, "ann_mg_cagra_index"); + raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); + */ + + auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); + raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); + + resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); ASSERT_TRUE(eval_neighbours(indices_naive, @@ -194,7 +207,6 @@ class AnnMGTest : public ::testing::TestWithParam> { 0.001, min_recall)); std::fill(indices_ann.begin(), indices_ann.end(), 0); - std::fill(indices_ann_int64.begin(), indices_ann_int64.end(), 0); std::fill(distances_ann.begin(), distances_ann.end(), 0); } } @@ -249,7 +261,7 @@ class AnnMGTest : public ::testing::TestWithParam> { rmm::device_uvector d_query_dataset; }; -const std::vector> inputs = { +const std::vector> inputs = { {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, }; -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/test/neighbors/ann_mg/test_ann_mg.cu b/cpp/test/neighbors/ann_mg/test_ann_mg.cu index c34cdabb36..b392f629bc 100644 --- a/cpp/test/neighbors/ann_mg/test_ann_mg.cu +++ b/cpp/test/neighbors/ann_mg/test_ann_mg.cu @@ -4,7 +4,7 @@ namespace raft::neighbors::mg { -typedef AnnMGTest AnnMGTestF_float; +typedef AnnMGTest AnnMGTestF_float; TEST_P(AnnMGTestF_float, AnnMG) { this->testAnnMG(); } INSTANTIATE_TEST_CASE_P(AnnMGTest, AnnMGTestF_float, ::testing::ValuesIn(inputs)); } From 2c3fbc2bacfbc91138ed09ee3e3583b782abdf80 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 2 Feb 2024 16:37:19 +0100 Subject: [PATCH 09/22] Add distribution feature --- cpp/include/raft/neighbors/ann_mg.cuh | 25 ++++ cpp/include/raft/neighbors/detail/ann_mg.cuh | 63 +++++++++- cpp/test/neighbors/ann_mg.cuh | 117 +++++++++++++++++++ 3 files changed, 204 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/neighbors/ann_mg.cuh b/cpp/include/raft/neighbors/ann_mg.cuh index 9f9b37e22c..1b764a2216 100644 --- a/cpp/include/raft/neighbors/ann_mg.cuh +++ b/cpp/include/raft/neighbors/ann_mg.cuh @@ -140,4 +140,29 @@ detail::ann_mg_index, T, IdxT> deserialize_cagra(const raf { return mg::detail::deserialize_cagra(handle, filename); } + +template +detail::ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return mg::detail::distribute_flat(handle, dev_list, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return mg::detail::distribute_pq(handle, dev_list, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return mg::detail::distribute_cagra(handle, dev_list, filename); +} + } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index ff6f35affc..12fb97376a 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -209,6 +209,23 @@ class ann_interface { } } + void deserialize(raft::resources const& handle, + const std::string& filename) + { + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_flat::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(ivf_pq::deserialize(handle, is))); + } else if constexpr (std::is_same>::value) { + index_.emplace(std::move(cagra::deserialize(handle, is))); + } + + is.close(); + } + const IdxT size() const { if constexpr (std::is_same>::value) { @@ -238,7 +255,27 @@ class ann_mg_index { init_nccl_clique(); } - // deserialization + // index deserialization and distribution + ann_mg_index(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) + : mode_(INDEX_DUPLICATION), + root_rank_(0), + num_ranks_(dev_list.size()), + dev_ids_(dev_list), + nccl_comms_(dev_list.size()) + { + init_device_resources(); + init_nccl_clique(); + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.deserialize(dev_resources_[rank], filename); + } + } + + // MG index deserialization ann_mg_index(const raft::resources& handle, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); @@ -643,4 +680,28 @@ ann_mg_index, T, IdxT> deserialize_cagra(const raft::resou return ann_mg_index, T, IdxT>(handle, filename); } +template +ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return ann_mg_index, T, IdxT>(handle, dev_list, filename); +} + +template +ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return ann_mg_index, T, IdxT>(handle, dev_list, filename); +} + +template +ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return ann_mg_index, T, IdxT>(handle, dev_list, filename); +} + } // namespace raft::neighbors::mg::detail \ No newline at end of file diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index d1d4b93a10..a9ac51030b 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -209,6 +209,123 @@ class AnnMGTest : public ::testing::TestWithParam> { std::fill(indices_ann.begin(), indices_ann.end(), 0); std::fill(distances_ann.begin(), distances_ann.end(), 0); } + + { + ivf_flat::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + + ivf_flat::search_params search_params; + search_params.n_probes = ps.nprobe; + + { + auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_flat::build(handle_, index_params, index_dataset); + ivf_flat::serialize(handle_, "local_ivf_flat_index", index); + } + + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, device_ids, "local_ivf_flat_index"); + raft::neighbors::mg::search(distributed_index, search_params, query_dataset, neighbors, distances); + + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + { + ivf_pq::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + + ivf_pq::search_params search_params; + search_params.n_probes = ps.nprobe; + + { + auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_pq::build(handle_, index_params, index_dataset); + ivf_pq::serialize(handle_, "local_ivf_pq_index", index); + } + + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, device_ids, "local_ivf_pq_index"); + raft::neighbors::mg::search(distributed_index, search_params, query_dataset, neighbors, distances); + + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + + { + cagra::index_params index_params; + index_params.intermediate_graph_degree = 128; + index_params.graph_degree = 64; + index_params.build_algo = cagra::graph_build_algo::IVF_PQ; + index_params.nn_descent_niter = 20; + + cagra::search_params search_params; + + { + auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index = cagra::build(handle_, index_params, index_dataset); + cagra::serialize(handle_, "local_cagra_index", index); + } + + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + + auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, device_ids, "local_cagra_index"); + raft::neighbors::mg::search(distributed_index, search_params, query_dataset, neighbors, distances); + + resource::sync_stream(handle_); + + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ann, + distances_naive, + distances_ann, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + std::fill(indices_ann.begin(), indices_ann.end(), 0); + std::fill(distances_ann.begin(), distances_ann.end(), 0); + } + } void SetUp() override From 4581fbd9563c9cba81dab7b6c1a5aad21a906018 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 23 Apr 2024 16:19:58 +0200 Subject: [PATCH 10/22] SNMG ANN bench update --- cpp/bench/ann/CMakeLists.txt | 14 ++--- .../src/raft/raft_ann_bench_param_parser.h | 24 +++++++++ cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h | 53 +++++++++++-------- cpp/bench/ann/src/raft/raft_benchmark.cu | 21 ++++++++ docs/source/ann_benchmarks_build.md | 1 + .../src/raft-ann-bench/run/algos.yaml | 3 ++ .../run/conf/algos/raft_ann_mg.yaml | 9 ++++ .../run/conf/mnist-784-euclidean.json | 18 +++++++ 8 files changed, 115 insertions(+), 28 deletions(-) create mode 100644 python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg.yaml diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index cb8e7cba27..dfa7f009ca 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -275,6 +275,13 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA) ) endif() +if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) + ConfigureAnnBench( + NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled + hnswlib::hnswlib + ) +endif() + if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG) ConfigureAnnBench( NAME @@ -286,13 +293,6 @@ if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG) ) endif() -if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) - ConfigureAnnBench( - NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled - hnswlib::hnswlib - ) -endif() - set(RAFT_FAISS_TARGETS faiss::faiss) if(TARGET faiss::faiss_avx2) set(RAFT_FAISS_TARGETS faiss::faiss_avx2) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 48bf1d70d8..4dfe5ed4eb 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -47,6 +47,12 @@ extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG +#include "raft_ann_mg_wrapper.h" +extern template class raft::bench::ann::RaftAnnMG; +extern template class raft::bench::ann::RaftAnnMG; +extern template class raft::bench::ann::RaftAnnMG; +#endif #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT template @@ -66,6 +72,24 @@ void parse_search_param(const nlohmann::json& conf, } #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftAnnMG::BuildParam& param) +{ + param.n_lists = conf.at("nlist"); + if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } + if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftAnnMG::SearchParam& param) +{ + param.ivf_flat_params.n_probes = conf.at("nprobe"); +} +#endif + #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) template diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h index 02d77b3400..47bf5ce2ab 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ namespace raft::bench::ann { template -class RaftAnnMG : public ANN { +class RaftAnnMG : public ANN, public AnnGPU { public: using typename ANN::AnnSearchParam; @@ -40,48 +40,52 @@ class RaftAnnMG : public ANN { RAFT_CUDA_TRY(cudaGetDevice(&device_)); } - ~RaftAnnMG() noexcept {} - - void build(const T* dataset, size_t nrow, cudaStream_t stream) final; + void build(const T* dataset, size_t nrow) final; void set_search_param(const AnnSearchParam& param) override; // TODO: if the number of results is less than k, the remaining elements of 'neighbors' // will be filled with (size_t)-1 - void search(const T* queries, - int batch_size, - int k, - size_t* neighbors, - float* distances, - cudaStream_t stream = 0) const override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + + [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override + { + return handle_.get_sync_stream(); + } // to enable dataset access from GPU memory AlgoProperty get_preference() const override { AlgoProperty property; - property.dataset_memory_type = MemoryType::Host; - property.query_memory_type = MemoryType::Host; + property.dataset_memory_type = MemoryType::HostMmap; + property.query_memory_type = MemoryType::Device; return property; } void save(const std::string& file) const override; void load(const std::string&) override; + std::unique_ptr> copy() override; private: - raft::device_resources handle_; + // handle_ must go first to make sure it dies last and all memory allocated in pool + configured_raft_resources handle_{}; BuildParam index_params_; raft::neighbors::ivf_flat::search_params search_params_; - std::optional, T, IdxT>> index_; + std::shared_ptr, T, IdxT>> index_; int device_; int dimension_; }; template -void RaftAnnMG::build(const T* dataset, size_t nrow, cudaStream_t) +void RaftAnnMG::build(const T* dataset, size_t nrow) { std::vector device_ids{0, 1}; raft::neighbors::mg::dist_mode d_mode = raft::neighbors::mg::dist_mode::INDEX_DUPLICATION; auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dimension_)); - index_ = neighbors::mg::build(device_ids, d_mode, index_params_, dataset_matrix); + auto idx = raft::neighbors::mg::build(device_ids, d_mode, index_params_, dataset_matrix); + index_ = std::make_shared, T, IdxT>>(std::move( + idx + )); return; } @@ -96,25 +100,32 @@ void RaftAnnMG::set_search_param(const AnnSearchParam& param) template void RaftAnnMG::save(const std::string& file) const { - raft::neighbors::mg::serialize(handle_, index_.value(), file); + raft::neighbors::mg::serialize(handle_, *index_, file); return; } template void RaftAnnMG::load(const std::string& file) { - index_.emplace(raft::neighbors::mg::deserialize_flat(handle_, file)); + index_ = std::make_shared, T, IdxT>>( + std::move(raft::neighbors::mg::deserialize_flat(handle_, file))); +} + +template +std::unique_ptr> RaftAnnMG::copy() +{ + return std::make_unique>(*this); // use copy constructor } template void RaftAnnMG::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(dimension_)); auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - raft::neighbors::mg::search(index_.value(), search_params_, query_matrix, neighbors_matrix, distances_matrix); + raft::neighbors::mg::search(*index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); resource::sync_stream(handle_); return; } diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 8bb4d9423c..ddf6b9f873 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -81,6 +81,16 @@ std::unique_ptr> create_algo(const std::string& algo, ann = std::make_unique>(metric, dim, param); } #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg") { + typename raft::bench::ann::RaftAnnMG::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } + } +#endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } @@ -122,6 +132,17 @@ std::unique_ptr::AnnSearchParam> create_search return param; } #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg") { + auto param = + std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + } +#endif // else throw std::runtime_error("invalid algo: '" + algo + "'"); diff --git a/docs/source/ann_benchmarks_build.md b/docs/source/ann_benchmarks_build.md index 80730c5d68..d06fa4fbe6 100644 --- a/docs/source/ann_benchmarks_build.md +++ b/docs/source/ann_benchmarks_build.md @@ -44,5 +44,6 @@ Available targets to use with `--limit-bench-ann` are: - RAFT_CAGRA_ANN_BENCH - RAFT_IVF_PQ_ANN_BENCH - RAFT_IVF_FLAT_ANN_BENCH +- RAFT_ANN_MG_ANN_BENCH By default, the `*_ANN_BENCH` executables program infer the dataset's datatype from the filename's extension. For example, an extension of `fbin` uses a `float` datatype, `f16bin` uses a `float16` datatype, extension of `i8bin` uses `int8_t` datatype, and `u8bin` uses `uint8_t` type. Currently, only `float`, `float16`, int8_t`, and `unit8_t` are supported. \ No newline at end of file diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml index e382bdcba6..c7bc62fa87 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml @@ -31,6 +31,9 @@ raft_cagra: raft_brute_force: executable: RAFT_BRUTE_FORCE_ANN_BENCH requires_gpu: true +raft_ann_mg: + executable: RAFT_ANN_MG_ANN_BENCH + requires_gpu: true ggnn: executable: GGNN_ANN_BENCH requires_gpu: true diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg.yaml new file mode 100644 index 0000000000..9f90879ec9 --- /dev/null +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg.yaml @@ -0,0 +1,9 @@ +name: raft_ann_mg +groups: + base: + build: + nlist: [1024, 2048, 4096, 8192, 16384, 32000, 64000] + ratio: [1, 2, 4] + niter: [20, 25] + search: + nprobe: [1, 5, 10, 50, 100, 200, 500, 1000, 2000] \ No newline at end of file diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json index 04e7ecb469..732ff109e7 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json @@ -1278,6 +1278,24 @@ ], "search_result_file": "result/mnist-784-euclidean/raft_ivf_flat/nlist1024" }, + { + "name": "raft_ann_mg.nlist1024", + "algo": "raft_ann_mg", + "build_param": { + "nlist": 1024, + "ratio": 1, + "niter": 25 + }, + "file": "index/mnist-784-euclidean/raft_ann_mg/nlist1024", + "dataset_memory_type": "host", + "query_memory_type": "host", + "search_params": [ + { + "nprobe": 5 + } + ], + "search_result_file": "result/mnist-784-euclidean/raft_ann_mg/nlist1024" + }, { "name": "raft_ivf_flat.nlist16384", "algo": "raft_ivf_flat", From e3b03b830fd72cf7f012e2496283a621818ff2ca Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 3 May 2024 14:31:13 +0200 Subject: [PATCH 11/22] OpenMP --- cpp/include/raft/neighbors/detail/ann_mg.cuh | 243 +++++++++---------- 1 file changed, 113 insertions(+), 130 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index cbb01439c9..3539afd8a0 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -69,6 +69,7 @@ class ann_interface { index_.emplace(std::move(raft::runtime::neighbors::cagra::build( handle, *static_cast(index_params), d_index_dataset_view))); } + resource::sync_stream(handle); } void extend(raft::resources const& handle, @@ -79,10 +80,11 @@ class ann_interface { IdxT n_dims = h_new_vectors.extent(1); auto d_new_vectors = raft::make_device_matrix(handle, n_rows, n_dims); raft::copy(d_new_vectors.data_handle(), h_new_vectors.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); - raft::device_matrix_view d_new_vectors_view = raft::make_device_matrix_view(d_new_vectors.data_handle(), n_rows, n_dims); + raft::device_matrix_view d_new_vectors_view = \ + raft::make_device_matrix_view(d_new_vectors.data_handle(), n_rows, n_dims); std::optional> new_indices_opt = std::nullopt; - if (h_new_indices) { + if (h_new_indices.has_value()) { auto d_new_indices = raft::make_device_vector(handle, n_rows); raft::copy(d_new_indices.data_handle(), h_new_indices.value().data_handle(), n_rows, resource::get_cuda_stream(handle)); auto d_new_indices_view = raft::device_vector_view(d_new_indices.data_handle(), n_rows); @@ -98,96 +100,43 @@ class ann_interface { } else if constexpr (std::is_same>::value) { RAFT_FAIL("CAGRA does not implement the extend method"); } + resource::sync_stream(handle); } - void search_impl(raft::resources const& handle, - const ann::search_params* search_params, - raft::device_matrix_view query_dataset, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) const + void search(raft::resources const& handle, + const ann::search_params* search_params, + raft::host_matrix_view h_query_dataset, + raft::device_matrix_view d_neighbors, + raft::device_matrix_view d_distances) const { + IdxT n_rows = h_query_dataset.extent(0); + IdxT n_dims = h_query_dataset.extent(1); + auto d_query_dataset = raft::make_device_matrix(handle, n_rows, n_dims); + raft::copy(d_query_dataset.data_handle(), h_query_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); + if constexpr (std::is_same>::value) { raft::runtime::neighbors::ivf_flat::search(handle, *reinterpret_cast(search_params), index_.value(), - query_dataset, - neighbors, - distances); + d_query_dataset.view(), + d_neighbors, + d_distances); } else if constexpr (std::is_same>::value) { raft::runtime::neighbors::ivf_pq::search(handle, *reinterpret_cast(search_params), index_.value(), - query_dataset, - neighbors, - distances); + d_query_dataset.view(), + d_neighbors, + d_distances); } else if constexpr (std::is_same>::value) { raft::runtime::neighbors::cagra::search(handle, *reinterpret_cast(search_params), index_.value(), - query_dataset, - neighbors, - distances); + d_query_dataset.view(), + d_neighbors, + d_distances); } - } - - // Index duplication, results stored on host memory without merge - void search(raft::resources const& handle, - const ann::search_params* search_params, - raft::host_matrix_view h_query_dataset, - raft::host_matrix_view h_neighbors, - raft::host_matrix_view h_distances) const - { - IdxT n_rows = h_query_dataset.extent(0); - IdxT n_dims = h_query_dataset.extent(1); - IdxT n_neighbors = h_neighbors.extent(1); - - auto d_query = raft::make_device_matrix(handle, n_rows, n_dims); - raft::copy(d_query.data_handle(), h_query_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); - raft::device_matrix_view d_query_view = raft::make_device_matrix_view(d_query.data_handle(), n_rows, n_dims); - - auto d_neighbors = raft::make_device_matrix(handle, n_rows, n_neighbors); - auto d_distances = raft::make_device_matrix(handle, n_rows, n_neighbors); - - search_impl(handle, search_params, d_query_view, d_neighbors.view(), d_distances.view()); - - raft::copy(h_neighbors.data_handle(), - d_neighbors.data_handle(), - n_rows * n_neighbors, - resource::get_cuda_stream(handle)); - raft::copy(h_distances.data_handle(), - d_distances.data_handle(), - n_rows * n_neighbors, - resource::get_cuda_stream(handle)); - } - - // Sharding, results sent to root rank, then merged by it - void search(raft::resources const& handle, - const ann::search_params* search_params, - raft::host_matrix_view h_query_dataset, - IdxT n_neighbors, - int root_rank) const - { - IdxT n_rows = h_query_dataset.extent(0); - IdxT n_dims = h_query_dataset.extent(1); - - auto d_query = raft::make_device_matrix(handle, n_rows, n_dims); - raft::copy(d_query.data_handle(), h_query_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); - raft::device_matrix_view d_query_view = raft::make_device_matrix_view(d_query.data_handle(), n_rows, n_dims); - - auto d_neighbors = raft::make_device_matrix(handle, n_rows, n_neighbors); - auto d_distances = raft::make_device_matrix(handle, n_rows, n_neighbors); - - search_impl(handle, search_params, d_query_view, d_neighbors.view(), d_distances.view()); - - const auto& comms = resource::get_comms(handle); - comms.device_send(d_neighbors.data_handle(), - n_rows * n_neighbors, - root_rank, - resource::get_cuda_stream(handle)); - comms.device_send(d_distances.data_handle(), - n_rows * n_neighbors, - root_rank, - resource::get_cuda_stream(handle)); + resource::sync_stream(handle); } void serialize(raft::resources const& handle, @@ -336,26 +285,36 @@ class ann_mg_index { raft::host_matrix_view index_dataset) { if (mode_ == INDEX_DUPLICATION) { + ann_interfaces_.resize(num_ranks_); + + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - auto& ann_if = ann_interfaces_.emplace_back(); + auto& ann_if = ann_interfaces_[rank]; ann_if.build(dev_resources_[rank], index_params, index_dataset); + resource::sync_stream(dev_resources_[rank]); } + #pragma omp barrier } else if (mode_ == SHARDING) { IdxT n_rows = index_dataset.extent(0); IdxT n_cols = index_dataset.extent(1); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT offset = 0; + + ann_interfaces_.resize(num_ranks_); + + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = rank * n_rows_per_shard; n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_per_shard, n_cols); - auto& ann_if = ann_interfaces_.emplace_back(); + auto& ann_if = ann_interfaces_[rank]; ann_if.build(dev_resources_[rank], index_params, partition); - offset += n_rows_per_shard; + resource::sync_stream(dev_resources_[rank]); } + #pragma omp barrier } set_current_device_to_root_rank(); } @@ -363,34 +322,38 @@ class ann_mg_index { void extend(raft::host_matrix_view new_vectors, std::optional> new_indices) { + IdxT n_rows = new_vectors.extent(0); if (mode_ == INDEX_DUPLICATION) { + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_resources_[rank], new_vectors, new_indices); + resource::sync_stream(dev_resources_[rank]); } + #pragma omp barrier } else if (mode_ == SHARDING) { - IdxT n_rows = new_vectors.extent(0); IdxT n_cols = new_vectors.extent(1); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT offset = 0; + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = rank * n_rows_per_shard; + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_per_shard, n_cols); std::optional> new_indices_part = std::nullopt; - if (new_indices) { + if (new_indices.has_value()) { const IdxT* new_indices_ptr = new_indices.value().data_handle() + offset; - new_indices_part = - raft::make_host_vector_view(new_indices_ptr, n_rows_per_shard); + new_indices_part = raft::make_host_vector_view(new_indices_ptr, n_rows_per_shard); } auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_resources_[rank], new_vectors_part, new_indices_part); - offset += n_rows_per_shard; + resource::sync_stream(dev_resources_[rank]); } + #pragma omp barrier } set_current_device_to_root_rank(); } @@ -404,77 +367,101 @@ class ann_mg_index { IdxT n_rows = query_dataset.extent(0); IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT offset = 0; - IdxT query_offset = 0; - IdxT output_offset = 0; + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT offset = rank * n_rows_per_shard; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_per_shard, n_cols); - auto neighbors_partition = raft::make_host_matrix_view( - neighbors.data_handle() + output_offset, n_rows_per_shard, n_neighbors); - auto distances_partition = raft::make_host_matrix_view( - distances.data_handle() + output_offset, n_rows_per_shard, n_neighbors); + + auto& handle = dev_resources_[rank]; + auto d_neighbors = raft::make_device_matrix(handle, n_rows_per_shard, n_neighbors); + auto d_distances = raft::make_device_matrix(handle, n_rows_per_shard, n_neighbors); auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_resources_[rank], + ann_if.search(handle, search_params, query_partition, - neighbors_partition, - distances_partition); - offset += n_rows_per_shard; - query_offset = offset * n_cols; - output_offset = offset * n_neighbors; + d_neighbors.view(), + d_distances.view()); + + raft::copy(neighbors.data_handle() + output_offset, + d_neighbors.data_handle(), + n_rows_per_shard * n_neighbors, + resource::get_cuda_stream(handle)); + raft::copy(distances.data_handle() + output_offset, + d_distances.data_handle(), + n_rows_per_shard * n_neighbors, + resource::get_cuda_stream(handle)); + resource::sync_stream(handle); } + #pragma omp barrier } else if (mode_ == SHARDING) { IdxT n_rows = query_dataset.extent(0); IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); - IdxT n_rows_per_batches = N_ROWS_PER_BATCH; - IdxT n_batches = (n_rows + n_rows_per_batches - 1) / n_rows_per_batches; + IdxT n_batches = (n_rows + N_ROWS_PER_BATCH - 1) / N_ROWS_PER_BATCH; const auto& root_handle = set_current_device_to_root_rank(); auto in_neighbors = raft::make_device_matrix( - root_handle, num_ranks_ * n_rows_per_batches, n_neighbors); + root_handle, num_ranks_ * N_ROWS_PER_BATCH, n_neighbors); auto in_distances = raft::make_device_matrix( - root_handle, num_ranks_ * n_rows_per_batches, n_neighbors); + root_handle, num_ranks_ * N_ROWS_PER_BATCH, n_neighbors); auto out_neighbors = raft::make_device_matrix( - root_handle, n_rows_per_batches, n_neighbors); + root_handle, N_ROWS_PER_BATCH, n_neighbors); auto out_distances = raft::make_device_matrix( - root_handle, n_rows_per_batches, n_neighbors); + root_handle, N_ROWS_PER_BATCH, n_neighbors); - IdxT offset = 0; - IdxT query_offset = 0; - IdxT output_offset = 0; for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - n_rows_per_batches = std::min(n_rows_per_batches, n_rows - offset); + IdxT offset = batch_idx * N_ROWS_PER_BATCH; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; + IdxT n_rows_per_batches = std::min((IdxT)N_ROWS_PER_BATCH, n_rows - offset); auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_per_batches, n_cols); - RAFT_NCCL_TRY(ncclGroupStart()); + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + auto& handle = dev_resources_[rank]; + auto d_neighbors = raft::make_device_matrix(handle, n_rows_per_batches, n_neighbors); + auto d_distances = raft::make_device_matrix(handle, n_rows_per_batches, n_neighbors); + auto& ann_if = ann_interfaces_[rank]; - ann_if.search( - dev_resources_[rank], search_params, query_partition, n_neighbors, root_rank_); + ann_if.search(handle, search_params, query_partition, d_neighbors.view(), d_distances.view()); - const auto& root_handle = set_current_device_to_root_rank(); - const auto& comms = resource::get_comms(root_handle); + RAFT_NCCL_TRY(ncclGroupStart()); uint64_t batch_offset = rank * n_rows_per_batches * n_neighbors; - comms.device_recv(in_neighbors.data_handle() + batch_offset, + + const auto& comms = resource::get_comms(handle); + comms.device_send(d_neighbors.data_handle(), n_rows_per_batches * n_neighbors, - rank, - resource::get_cuda_stream(root_handle)); - comms.device_recv(in_distances.data_handle() + batch_offset, + root_rank_, + resource::get_cuda_stream(handle)); + comms.device_send(d_distances.data_handle(), n_rows_per_batches * n_neighbors, - rank, - resource::get_cuda_stream(root_handle)); + root_rank_, + resource::get_cuda_stream(handle)); + + const auto& root_handle = set_current_device_to_root_rank(); + const auto& root_comms = resource::get_comms(root_handle); + root_comms.device_recv(in_neighbors.data_handle() + batch_offset, + n_rows_per_batches * n_neighbors, + rank, + resource::get_cuda_stream(root_handle)); + root_comms.device_recv(in_distances.data_handle() + batch_offset, + n_rows_per_batches * n_neighbors, + rank, + resource::get_cuda_stream(root_handle)); + RAFT_NCCL_TRY(ncclGroupEnd()); } - RAFT_NCCL_TRY(ncclGroupEnd()); + #pragma omp barrier auto in_neighbors_view = raft::make_device_matrix_view( in_neighbors.data_handle(), num_ranks_ * n_rows_per_batches, n_neighbors); @@ -511,10 +498,6 @@ class ann_mg_index { out_distances.data_handle(), n_rows_per_batches * n_neighbors, resource::get_cuda_stream(root_handle_)); - - offset += n_rows_per_batches; - query_offset = offset * n_cols; - output_offset = offset * n_neighbors; } } From a6707c35111da46584296eed20d4902e6da987e7 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 6 May 2024 16:04:44 +0200 Subject: [PATCH 12/22] Answering reviews --- .gitignore | 6 - cpp/CMakeLists.txt | 8 +- .../src/raft/raft_ann_bench_param_parser.h | 2 +- ...n_mg_wrapper.h => raft_ann_mg_wrapper.hpp} | 17 +- cpp/include/raft/neighbors/ann_mg.cuh | 168 ------------------ cpp/include/raft/neighbors/ann_mg_types.hpp | 46 +++++ cpp/include/raft/neighbors/cagra_mg.cuh | 44 +++++ .../raft/neighbors/cagra_mg_serialize.cuh | 47 +++++ cpp/include/raft/neighbors/detail/ann_mg.cuh | 68 +++---- cpp/include/raft/neighbors/ivf_flat_mg.cuh | 53 ++++++ .../raft/neighbors/ivf_flat_mg_serialize.cuh | 47 +++++ cpp/include/raft/neighbors/ivf_pq_mg.cuh | 61 +++++++ .../raft/neighbors/ivf_pq_mg_serialize.cuh | 39 ++++ cpp/test/neighbors/ann_mg.cuh | 70 +++++--- 14 files changed, 428 insertions(+), 248 deletions(-) rename cpp/bench/ann/src/raft/{raft_ann_mg_wrapper.h => raft_ann_mg_wrapper.hpp} (89%) delete mode 100644 cpp/include/raft/neighbors/ann_mg.cuh create mode 100644 cpp/include/raft/neighbors/ann_mg_types.hpp create mode 100644 cpp/include/raft/neighbors/cagra_mg.cuh create mode 100644 cpp/include/raft/neighbors/cagra_mg_serialize.cuh create mode 100644 cpp/include/raft/neighbors/ivf_flat_mg.cuh create mode 100644 cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh create mode 100644 cpp/include/raft/neighbors/ivf_pq_mg.cuh create mode 100644 cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh diff --git a/.gitignore b/.gitignore index f58e4950e2..ca1fc6b922 100644 --- a/.gitignore +++ b/.gitignore @@ -67,11 +67,5 @@ _text compile_commands.json .clangd/ -ann_mg_ivf_flat_index -ann_mg_ivf_pq_index datasets/ index/ -ivf_flat_index -local_cagra_index -local_ivf_flat_index -local_ivf_pq_index diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 5a7c189ccb..60df058de4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -55,6 +55,7 @@ option(BUILD_TESTS "Build raft unit-tests" ON) option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF) option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF) option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON) +option(RAFT_BUILD_ANN_MG_API "Build SNMG ANN interface" ON) option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF) option(CUDA_ENABLE_LINEINFO "Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF @@ -592,6 +593,10 @@ if(RAFT_COMPILE_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) + if(RAFT_BUILD_ANN_MG_API) + set(RAFT_BUILD_ANN_MG_DEP nccl ucp) + endif() + foreach(target raft_lib raft_lib_static raft_objs) target_link_libraries( ${target} @@ -599,8 +604,7 @@ if(RAFT_COMPILE_LIBRARY) ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this # will just be cublas $ - nccl - ucp + ${RAFT_BUILD_ANN_MG_DEP} ) # So consumers know when using libraft.so/libraft.a diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 4dfe5ed4eb..101d54107c 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -48,7 +48,7 @@ extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG -#include "raft_ann_mg_wrapper.h" +#include "raft_ann_mg_wrapper.hpp" extern template class raft::bench::ann::RaftAnnMG; extern template class raft::bench::ann::RaftAnnMG; extern template class raft::bench::ann::RaftAnnMG; diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp similarity index 89% rename from cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h rename to cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp index 47bf5ce2ab..a67dabbaa5 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp @@ -17,7 +17,8 @@ #include "../common/ann_types.hpp" #include "raft_ann_bench_utils.h" -#include +#include +#include namespace raft::bench::ann { @@ -30,13 +31,15 @@ class RaftAnnMG : public ANN, public AnnGPU { raft::neighbors::ivf_flat::search_params ivf_flat_params; }; - using BuildParam = raft::neighbors::ivf_flat::index_params; + using BuildParam = raft::neighbors::ivf_flat::dist_index_params; RaftAnnMG(Metric metric, int dim, const BuildParam& param) : ANN(metric, dim), index_params_(param), dimension_(dim) { index_params_.metric = parse_metric_type(metric); index_params_.conservative_memory_allocation = true; + index_params_.device_ids = {0, 1}; + index_params_.mode = raft::neighbors::mg::parallel_mode::REPLICATION; RAFT_CUDA_TRY(cudaGetDevice(&device_)); } @@ -79,13 +82,9 @@ class RaftAnnMG : public ANN, public AnnGPU { template void RaftAnnMG::build(const T* dataset, size_t nrow) { - std::vector device_ids{0, 1}; - raft::neighbors::mg::dist_mode d_mode = raft::neighbors::mg::dist_mode::INDEX_DUPLICATION; auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dimension_)); - auto idx = raft::neighbors::mg::build(device_ids, d_mode, index_params_, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move( - idx - )); + auto idx = raft::neighbors::mg::build(handle_, index_params_, dataset_matrix); + index_ = std::make_shared, T, IdxT>>(std::move(idx)); return; } @@ -125,7 +124,7 @@ void RaftAnnMG::search( auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(dimension_)); auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - raft::neighbors::mg::search(*index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + raft::neighbors::mg::search(handle_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); resource::sync_stream(handle_); return; } diff --git a/cpp/include/raft/neighbors/ann_mg.cuh b/cpp/include/raft/neighbors/ann_mg.cuh deleted file mode 100644 index 1b764a2216..0000000000 --- a/cpp/include/raft/neighbors/ann_mg.cuh +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::neighbors::mg { - -template -auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_flat::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, IdxT> -{ - return mg::detail::build(device_ids, mode, index_params, index_dataset); -} - -template -auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const ivf_pq::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, IdxT> -{ - return mg::detail::build(device_ids, mode, index_params, index_dataset); -} - -template -auto build(const std::vector device_ids, - raft::neighbors::mg::dist_mode mode, - const cagra::index_params& index_params, - raft::host_matrix_view index_dataset) - -> detail::ann_mg_index, T, IdxT> -{ - return mg::detail::build(device_ids, mode, index_params, index_dataset); -} - -template -void extend(detail::ann_mg_index, T, IdxT>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) -{ - mg::detail::extend(index, new_vectors, new_indices); -} - -template -void extend(detail::ann_mg_index, T, IdxT>& index, - raft::host_matrix_view new_vectors, - std::optional> new_indices) -{ - mg::detail::extend(index, new_vectors, new_indices); -} - -template -void search(const detail::ann_mg_index, T, IdxT>& index, - const ivf_flat::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) -{ - mg::detail::search(index, search_params, query_dataset, neighbors, distances); -} - -template -void search(const detail::ann_mg_index, T, IdxT>& index, - const ivf_pq::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) -{ - mg::detail::search(index, search_params, query_dataset, neighbors, distances); -} - -template -void search(const detail::ann_mg_index, T, IdxT>& index, - const cagra::search_params& search_params, - raft::host_matrix_view query_dataset, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) -{ - mg::detail::search(index, search_params, query_dataset, neighbors, distances); -} - -template -void serialize(const raft::resources& handle, - const detail::ann_mg_index, T, IdxT>& index, - const std::string& filename) -{ - mg::detail::serialize(handle, index, filename); -} - -template -void serialize(const raft::resources& handle, - const detail::ann_mg_index, T, IdxT>& index, - const std::string& filename) -{ - mg::detail::serialize(handle, index, filename); -} - -template -void serialize(const raft::resources& handle, - const detail::ann_mg_index, T, IdxT>& index, - const std::string& filename) -{ - mg::detail::serialize(handle, index, filename); -} - -template -detail::ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, - const std::string& filename) -{ - return mg::detail::deserialize_flat(handle, filename); -} - -template -detail::ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, - const std::string& filename) -{ - return mg::detail::deserialize_pq(handle, filename); -} - -template -detail::ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, - const std::string& filename) -{ - return mg::detail::deserialize_cagra(handle, filename); -} - -template -detail::ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, - const std::vector& dev_list, - const std::string& filename) -{ - return mg::detail::distribute_flat(handle, dev_list, filename); -} - -template -detail::ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, - const std::vector& dev_list, - const std::string& filename) -{ - return mg::detail::distribute_pq(handle, dev_list, filename); -} - -template -detail::ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, - const std::vector& dev_list, - const std::string& filename) -{ - return mg::detail::distribute_cagra(handle, dev_list, filename); -} - -} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ann_mg_types.hpp b/cpp/include/raft/neighbors/ann_mg_types.hpp new file mode 100644 index 0000000000..e3eeca470f --- /dev/null +++ b/cpp/include/raft/neighbors/ann_mg_types.hpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::neighbors::mg { + enum parallel_mode { REPLICATION, SHARDING }; +} + +namespace raft::neighbors::ivf_flat { + struct dist_index_params : raft::neighbors::ivf_flat::index_params { + std::vector device_ids; + raft::neighbors::mg::parallel_mode mode; + }; +} + +namespace raft::neighbors::ivf_pq { + struct dist_index_params : raft::neighbors::ivf_pq::index_params { + std::vector device_ids; + raft::neighbors::mg::parallel_mode mode; + }; +} + +namespace raft::neighbors::cagra { + struct dist_index_params : raft::neighbors::cagra::index_params { + std::vector device_ids; + raft::neighbors::mg::parallel_mode mode; + }; +} diff --git a/cpp/include/raft/neighbors/cagra_mg.cuh b/cpp/include/raft/neighbors/cagra_mg.cuh new file mode 100644 index 0000000000..1eb602429a --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_mg.cuh @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +auto build(const raft::resources& handle, + const cagra::dist_index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(handle, index_params, index_dataset); +} + +template +void search(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const cagra::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(handle, index, search_params, query_dataset, neighbors, distances); +} + +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh new file mode 100644 index 0000000000..063493eda0 --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +void serialize(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, index, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, + const std::string& filename) +{ + return mg::detail::deserialize_cagra(handle, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return mg::detail::distribute_cagra(handle, dev_list, filename); +} + +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 3539afd8a0..7da22fbb8b 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -16,11 +16,13 @@ #pragma once +#include #include #include #include #include #include +#include #undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include @@ -39,10 +41,6 @@ // Number of rows per batch (search on shards) #define N_ROWS_PER_BATCH 3000 -namespace raft::neighbors::mg { -enum dist_mode { SHARDING, INDEX_DUPLICATION }; -} - namespace raft::neighbors::mg::detail { using namespace raft::neighbors; @@ -198,7 +196,7 @@ class ann_interface { template class ann_mg_index { public: - ann_mg_index(const std::vector& dev_list, dist_mode mode = SHARDING) + ann_mg_index(const std::vector& dev_list, parallel_mode mode = SHARDING) : mode_(mode), root_rank_(0), num_ranks_(dev_list.size()), @@ -213,7 +211,7 @@ class ann_mg_index { ann_mg_index(const raft::resources& handle, const std::vector& dev_list, const std::string& filename) - : mode_(INDEX_DUPLICATION), + : mode_(REPLICATION), root_rank_(0), num_ranks_(dev_list.size()), dev_ids_(dev_list), @@ -235,7 +233,7 @@ class ann_mg_index { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - mode_ = (raft::neighbors::mg::dist_mode)deserialize_scalar(handle, is); + mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); root_rank_ = 0; num_ranks_ = deserialize_scalar(handle, is); dev_ids_.resize(num_ranks_); @@ -284,7 +282,7 @@ class ann_mg_index { void build(const ann::index_params* index_params, raft::host_matrix_view index_dataset) { - if (mode_ == INDEX_DUPLICATION) { + if (mode_ == REPLICATION) { ann_interfaces_.resize(num_ranks_); #pragma omp parallel for @@ -304,7 +302,7 @@ class ann_mg_index { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); IdxT offset = rank * n_rows_per_shard; n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); @@ -323,7 +321,7 @@ class ann_mg_index { std::optional> new_indices) { IdxT n_rows = new_vectors.extent(0); - if (mode_ == INDEX_DUPLICATION) { + if (mode_ == REPLICATION) { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); @@ -337,9 +335,9 @@ class ann_mg_index { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; - IdxT offset = rank * n_rows_per_shard; - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + IdxT offset = rank * n_rows_per_shard; + n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); auto new_vectors_part = raft::make_host_matrix_view( new_vectors_ptr, n_rows_per_shard, n_cols); @@ -363,7 +361,7 @@ class ann_mg_index { raft::host_matrix_view neighbors, raft::host_matrix_view distances) const { - if (mode_ == INDEX_DUPLICATION) { + if (mode_ == REPLICATION) { IdxT n_rows = query_dataset.extent(0); IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); @@ -371,7 +369,7 @@ class ann_mg_index { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - IdxT n_rows_per_shard = (n_rows + num_ranks_ - 1) / num_ranks_; + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); IdxT offset = rank * n_rows_per_shard; IdxT query_offset = offset * n_cols; IdxT output_offset = offset * n_neighbors; @@ -406,7 +404,7 @@ class ann_mg_index { IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); - IdxT n_batches = (n_rows + N_ROWS_PER_BATCH - 1) / N_ROWS_PER_BATCH; + IdxT n_batches = raft::ceildiv(n_rows, (IdxT)N_ROWS_PER_BATCH); const auto& root_handle = set_current_device_to_root_rank(); auto in_neighbors = raft::make_device_matrix( @@ -529,7 +527,7 @@ class ann_mg_index { } private: - dist_mode mode_; + parallel_mode mode_; int root_rank_; int num_ranks_; std::vector dev_ids_; @@ -540,42 +538,40 @@ class ann_mg_index { template ann_mg_index, T, IdxT> build( - const std::vector device_ids, - dist_mode mode, - const ivf_flat::index_params& index_params, + const raft::resources& handle, + const ivf_flat::dist_index_params& index_params, raft::host_matrix_view index_dataset) { - ann_mg_index, T, IdxT> index(device_ids, mode); + ann_mg_index, T, IdxT> index(index_params.device_ids, index_params.mode); index.build(static_cast(&index_params), index_dataset); return index; } template ann_mg_index, T, IdxT> build( - const std::vector device_ids, - dist_mode mode, - const ivf_pq::index_params& index_params, + const raft::resources& handle, + const ivf_pq::dist_index_params& index_params, raft::host_matrix_view index_dataset) { - ann_mg_index, T, IdxT> index(device_ids, mode); + ann_mg_index, T, IdxT> index(index_params.device_ids, index_params.mode); index.build(static_cast(&index_params), index_dataset); return index; } template ann_mg_index, T, IdxT> build( - const std::vector device_ids, - dist_mode mode, - const cagra::index_params& index_params, + const raft::resources& handle, + const cagra::dist_index_params& index_params, raft::host_matrix_view index_dataset) { - ann_mg_index, T, IdxT> index(device_ids, mode); + ann_mg_index, T, IdxT> index(index_params.device_ids, index_params.mode); index.build(static_cast(&index_params), index_dataset); return index; } template -void extend(ann_mg_index, T, IdxT>& index, +void extend(const raft::resources& handle, + ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { @@ -583,7 +579,8 @@ void extend(ann_mg_index, T, IdxT>& index, } template -void extend(ann_mg_index, T, IdxT>& index, +void extend(const raft::resources& handle, + ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { @@ -591,7 +588,8 @@ void extend(ann_mg_index, T, IdxT>& index, } template -void search(const ann_mg_index, T, IdxT>& index, +void search(const raft::resources& handle, + const ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, @@ -602,7 +600,8 @@ void search(const ann_mg_index, T, IdxT>& index, } template -void search(const ann_mg_index, T, IdxT>& index, +void search(const raft::resources& handle, + const ann_mg_index, T, IdxT>& index, const ivf_pq::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, @@ -613,7 +612,8 @@ void search(const ann_mg_index, T, IdxT>& index, } template -void search(const ann_mg_index, T, IdxT>& index, +void search(const raft::resources& handle, + const ann_mg_index, T, IdxT>& index, const cagra::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, diff --git a/cpp/include/raft/neighbors/ivf_flat_mg.cuh b/cpp/include/raft/neighbors/ivf_flat_mg.cuh new file mode 100644 index 0000000000..f900a930ee --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat_mg.cuh @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +auto build(const raft::resources& handle, + const ivf_flat::dist_index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(handle, index_params, index_dataset); +} + +template +void extend(const raft::resources& handle, + detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(handle, index, new_vectors, new_indices); +} + +template +void search(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const ivf_flat::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(handle, index, search_params, query_dataset, neighbors, distances); +} + +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh new file mode 100644 index 0000000000..d2bc212081 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +void serialize(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, index, filename); +} + +template +detail::ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, + const std::string& filename) +{ + return mg::detail::deserialize_flat(handle, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return mg::detail::distribute_flat(handle, dev_list, filename); +} + +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_pq_mg.cuh b/cpp/include/raft/neighbors/ivf_pq_mg.cuh new file mode 100644 index 0000000000..a3e7358c2a --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq_mg.cuh @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +auto build(const raft::resources& handle, + const ivf_pq::dist_index_params& index_params, + raft::host_matrix_view index_dataset) + -> detail::ann_mg_index, T, IdxT> +{ + return mg::detail::build(handle, index_params, index_dataset); +} + +template +void extend(const raft::resources& handle, + detail::ann_mg_index, T, IdxT>& index, + raft::host_matrix_view new_vectors, + std::optional> new_indices) +{ + mg::detail::extend(handle, index, new_vectors, new_indices); +} + +template +void search(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const ivf_pq::search_params& search_params, + raft::host_matrix_view query_dataset, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + mg::detail::search(handle, index, search_params, query_dataset, neighbors, distances); +} + +template +void serialize(const raft::resources& handle, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, index, filename); +} + +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh new file mode 100644 index 0000000000..94bd6584fe --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::mg { + +template +detail::ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, + const std::string& filename) +{ + return mg::detail::deserialize_pq(handle, filename); +} + +template +detail::ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, + const std::vector& dev_list, + const std::string& filename) +{ + return mg::detail::distribute_pq(handle, dev_list, filename); +} + +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index ddfa9ebfab..3474727c86 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -17,7 +17,15 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" -#include + +#include +#include +#include + +#include +#include +#include + #include namespace raft::neighbors::mg { @@ -78,14 +86,16 @@ class AnnMGTest : public ::testing::TestWithParam> { std::vector device_ids{0, 1}; // IVF-Flat - for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION, dist_mode::SHARDING}) { - ivf_flat::index_params index_params; + for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { + ivf_flat::dist_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.adaptive_centers = ps.adaptive_centers; index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; + index_params.device_ids = device_ids; + index_params.mode = d_mode; ivf_flat::search_params search_params; search_params.n_probes = ps.nprobe; @@ -100,12 +110,12 @@ class AnnMGTest : public ::testing::TestWithParam> { distances_ann.data(), ps.num_queries, ps.k); { - auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::extend(index, index_dataset, std::nullopt); - raft::neighbors::mg::serialize(handle_, index, "ann_mg_ivf_flat_index"); + auto index = raft::neighbors::mg::build(handle_, index_params, index_dataset); + raft::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, index, "./cpp/build/ann_mg_ivf_flat_index"); } - auto new_index = raft::neighbors::mg::deserialize_flat(handle_, "ann_mg_ivf_flat_index"); - raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); + auto new_index = raft::neighbors::mg::deserialize_flat(handle_, "./cpp/build/ann_mg_ivf_flat_index"); + raft::neighbors::mg::search(handle_, new_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -122,13 +132,15 @@ class AnnMGTest : public ::testing::TestWithParam> { } // IVF-PQ - for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION, dist_mode::SHARDING}) { - ivf_pq::index_params index_params; + for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { + ivf_pq::dist_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; + index_params.device_ids = device_ids; + index_params.mode = d_mode; ivf_pq::search_params search_params; search_params.n_probes = ps.nprobe; @@ -143,12 +155,12 @@ class AnnMGTest : public ::testing::TestWithParam> { distances_ann.data(), ps.num_queries, ps.k); { - auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::extend(index, index_dataset, std::nullopt); - raft::neighbors::mg::serialize(handle_, index, "ann_mg_ivf_pq_index"); + auto index = raft::neighbors::mg::build(handle_, index_params, index_dataset); + raft::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, index, "./cpp/build/ann_mg_ivf_pq_index"); } - auto new_index = raft::neighbors::mg::deserialize_pq(handle_, "ann_mg_ivf_pq_index"); - raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); + auto new_index = raft::neighbors::mg::deserialize_pq(handle_, "./cpp/build/ann_mg_ivf_pq_index"); + raft::neighbors::mg::search(handle_, new_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -165,12 +177,14 @@ class AnnMGTest : public ::testing::TestWithParam> { } // CAGRA - for (dist_mode d_mode : {dist_mode::INDEX_DUPLICATION, dist_mode::SHARDING}) { - cagra::index_params index_params; + for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { + cagra::dist_index_params index_params; index_params.intermediate_graph_degree = 128; index_params.graph_degree = 64; index_params.build_algo = cagra::graph_build_algo::IVF_PQ; index_params.nn_descent_niter = 20; + index_params.device_ids = device_ids; + index_params.mode = d_mode; cagra::search_params search_params; @@ -194,8 +208,8 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); */ - auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::search(index, search_params, query_dataset, neighbors, distances); + auto index = raft::neighbors::mg::build(handle_, index_params, index_dataset); + raft::neighbors::mg::search(handle_, index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); @@ -227,15 +241,15 @@ class AnnMGTest : public ::testing::TestWithParam> { { auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); auto index = raft::runtime::neighbors::ivf_flat::build(handle_, index_params, index_dataset); - ivf_flat::serialize(handle_, "local_ivf_flat_index", index); + ivf_flat::serialize(handle_, "./cpp/build/local_ivf_flat_index", index); } auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, device_ids, "local_ivf_flat_index"); - raft::neighbors::mg::search(distributed_index, search_params, query_dataset, neighbors, distances); + auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, device_ids, "./cpp/build/local_ivf_flat_index"); + raft::neighbors::mg::search(handle_, distributed_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); @@ -266,15 +280,15 @@ class AnnMGTest : public ::testing::TestWithParam> { { auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); auto index = raft::runtime::neighbors::ivf_pq::build(handle_, index_params, index_dataset); - ivf_pq::serialize(handle_, "local_ivf_pq_index", index); + ivf_pq::serialize(handle_, "./cpp/build/local_ivf_pq_index", index); } auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, device_ids, "local_ivf_pq_index"); - raft::neighbors::mg::search(distributed_index, search_params, query_dataset, neighbors, distances); + auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, device_ids, "./cpp/build/local_ivf_pq_index"); + raft::neighbors::mg::search(handle_, distributed_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); @@ -303,15 +317,15 @@ class AnnMGTest : public ::testing::TestWithParam> { { auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); auto index = raft::runtime::neighbors::cagra::build(handle_, index_params, index_dataset); - raft::neighbors::cagra::serialize(handle_, "local_cagra_index", index); + raft::neighbors::cagra::serialize(handle_, "./cpp/build/local_cagra_index", index); } auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); auto neighbors = raft::make_host_matrix_view(indices_ann_32bits.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, device_ids, "local_cagra_index"); - raft::neighbors::mg::search(distributed_index, search_params, query_dataset, neighbors, distances); + auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, device_ids, "./cpp/build/local_cagra_index"); + raft::neighbors::mg::search(handle_, distributed_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); From 4a91a7fe5c643062946086c43d43171eb9de4678 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 8 May 2024 18:03:29 +0200 Subject: [PATCH 13/22] NCCL clique helper --- .../ann/src/raft/raft_ann_mg_wrapper.hpp | 22 +- cpp/include/raft/neighbors/ann_mg_helpers.cuh | 65 +++++ cpp/include/raft/neighbors/ann_mg_types.hpp | 3 - cpp/include/raft/neighbors/cagra_mg.cuh | 6 +- .../raft/neighbors/cagra_mg_serialize.cuh | 10 +- cpp/include/raft/neighbors/detail/ann_mg.cuh | 274 +++++++++--------- cpp/include/raft/neighbors/ivf_flat_mg.cuh | 9 +- .../raft/neighbors/ivf_flat_mg_serialize.cuh | 10 +- cpp/include/raft/neighbors/ivf_pq_mg.cuh | 17 +- .../raft/neighbors/ivf_pq_mg_serialize.cuh | 18 +- cpp/test/neighbors/ann_mg.cuh | 62 ++-- 11 files changed, 294 insertions(+), 202 deletions(-) create mode 100644 cpp/include/raft/neighbors/ann_mg_helpers.cuh diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp index a67dabbaa5..eb962d6af0 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp @@ -38,7 +38,6 @@ class RaftAnnMG : public ANN, public AnnGPU { { index_params_.metric = parse_metric_type(metric); index_params_.conservative_memory_allocation = true; - index_params_.device_ids = {0, 1}; index_params_.mode = raft::neighbors::mg::parallel_mode::REPLICATION; RAFT_CUDA_TRY(cudaGetDevice(&device_)); } @@ -82,8 +81,11 @@ class RaftAnnMG : public ANN, public AnnGPU { template void RaftAnnMG::build(const T* dataset, size_t nrow) { + std::vector device_ids{0, 1}; + raft::neighbors::mg::nccl_clique clique(device_ids); + auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dimension_)); - auto idx = raft::neighbors::mg::build(handle_, index_params_, dataset_matrix); + auto idx = raft::neighbors::mg::build(handle_, clique, index_params_, dataset_matrix); index_ = std::make_shared, T, IdxT>>(std::move(idx)); return; } @@ -99,15 +101,21 @@ void RaftAnnMG::set_search_param(const AnnSearchParam& param) template void RaftAnnMG::save(const std::string& file) const { - raft::neighbors::mg::serialize(handle_, *index_, file); + std::vector device_ids{0, 1}; + raft::neighbors::mg::nccl_clique clique(device_ids); + + raft::neighbors::mg::serialize(handle_, clique, *index_, file); return; } template void RaftAnnMG::load(const std::string& file) { + std::vector device_ids{0, 1}; + raft::neighbors::mg::nccl_clique clique(device_ids); + index_ = std::make_shared, T, IdxT>>( - std::move(raft::neighbors::mg::deserialize_flat(handle_, file))); + std::move(raft::neighbors::mg::deserialize_flat(handle_, clique, file))); } template @@ -121,10 +129,14 @@ void RaftAnnMG::search( const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); + + std::vector device_ids{0, 1}; + raft::neighbors::mg::nccl_clique clique(device_ids); + auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(dimension_)); auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - raft::neighbors::mg::search(handle_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + raft::neighbors::mg::search(handle_, clique, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); resource::sync_stream(handle_); return; } diff --git a/cpp/include/raft/neighbors/ann_mg_helpers.cuh b/cpp/include/raft/neighbors/ann_mg_helpers.cuh new file mode 100644 index 0000000000..ebfb82a668 --- /dev/null +++ b/cpp/include/raft/neighbors/ann_mg_helpers.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + + +namespace raft::neighbors::mg { + +struct nccl_clique { + nccl_clique(const std::vector& device_ids) + : root_rank_(0), + num_ranks_(device_ids.size()), + device_ids_(device_ids), + nccl_comms_(device_ids.size()) + { + RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data())); + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(device_ids[rank])); + device_resources_.emplace_back(); + raft::comms::build_comms_nccl_only(&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); + } + } + + const raft::device_resources& set_current_device_to_root_rank() const + { + int root_device_id = device_ids_[root_rank_]; + RAFT_CUDA_TRY(cudaSetDevice(root_device_id)); + return device_resources_[root_rank_]; + } + + ~nccl_clique() + { + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(device_ids_[rank]); + ncclCommDestroy(nccl_comms_[rank]); + } + } + + int root_rank_; + int num_ranks_; + std::vector device_ids_; + std::vector nccl_comms_; + std::vector device_resources_; +}; + +} // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ann_mg_types.hpp b/cpp/include/raft/neighbors/ann_mg_types.hpp index e3eeca470f..f1509ae328 100644 --- a/cpp/include/raft/neighbors/ann_mg_types.hpp +++ b/cpp/include/raft/neighbors/ann_mg_types.hpp @@ -26,21 +26,18 @@ namespace raft::neighbors::mg { namespace raft::neighbors::ivf_flat { struct dist_index_params : raft::neighbors::ivf_flat::index_params { - std::vector device_ids; raft::neighbors::mg::parallel_mode mode; }; } namespace raft::neighbors::ivf_pq { struct dist_index_params : raft::neighbors::ivf_pq::index_params { - std::vector device_ids; raft::neighbors::mg::parallel_mode mode; }; } namespace raft::neighbors::cagra { struct dist_index_params : raft::neighbors::cagra::index_params { - std::vector device_ids; raft::neighbors::mg::parallel_mode mode; }; } diff --git a/cpp/include/raft/neighbors/cagra_mg.cuh b/cpp/include/raft/neighbors/cagra_mg.cuh index 1eb602429a..6287bbddea 100644 --- a/cpp/include/raft/neighbors/cagra_mg.cuh +++ b/cpp/include/raft/neighbors/cagra_mg.cuh @@ -23,22 +23,24 @@ namespace raft::neighbors::mg { template auto build(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const cagra::dist_index_params& index_params, raft::host_matrix_view index_dataset) -> detail::ann_mg_index, T, IdxT> { - return mg::detail::build(handle, index_params, index_dataset); + return mg::detail::build(handle, clique, index_params, index_dataset); } template void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const detail::ann_mg_index, T, IdxT>& index, const cagra::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - mg::detail::search(handle, index, search_params, query_dataset, neighbors, distances); + mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh index 063493eda0..fd8c4b7667 100644 --- a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh @@ -23,25 +23,27 @@ namespace raft::neighbors::mg { template void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const detail::ann_mg_index, T, IdxT>& index, const std::string& filename) { - mg::detail::serialize(handle, index, filename); + mg::detail::serialize(handle, clique, index, filename); } template detail::ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return mg::detail::deserialize_cagra(handle, filename); + return mg::detail::deserialize_cagra(handle, clique, filename); } template detail::ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, - const std::vector& dev_list, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return mg::detail::distribute_cagra(handle, dev_list, filename); + return mg::detail::distribute_cagra(handle, clique, filename); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 7da22fbb8b..2528128858 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include @@ -196,60 +197,15 @@ class ann_interface { template class ann_mg_index { public: - ann_mg_index(const std::vector& dev_list, parallel_mode mode = SHARDING) + ann_mg_index(parallel_mode mode, int num_ranks_) : mode_(mode), - root_rank_(0), - num_ranks_(dev_list.size()), - dev_ids_(dev_list), - nccl_comms_(dev_list.size()) - { - init_device_resources(); - init_nccl_clique(); - } - - // index deserialization and distribution - ann_mg_index(const raft::resources& handle, - const std::vector& dev_list, - const std::string& filename) - : mode_(REPLICATION), - root_rank_(0), - num_ranks_(dev_list.size()), - dev_ids_(dev_list), - nccl_comms_(dev_list.size()) - { - init_device_resources(); - init_nccl_clique(); - - for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - auto& ann_if = ann_interfaces_.emplace_back(); - ann_if.deserialize(dev_resources_[rank], filename); - } - } + num_ranks_(num_ranks_) + {} - // MG index deserialization ann_mg_index(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - std::ifstream is(filename, std::ios::in | std::ios::binary); - if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); - root_rank_ = 0; - num_ranks_ = deserialize_scalar(handle, is); - dev_ids_.resize(num_ranks_); - std::iota(std::begin(dev_ids_), std::end(dev_ids_), 0); - nccl_comms_.resize(num_ranks_); - - init_device_resources(); - init_nccl_clique(); - - for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - auto& ann_if = ann_interfaces_.emplace_back(); - ann_if.deserialize(dev_resources_[rank], is); - } - - is.close(); + deserialize_mg_index(handle, clique, filename); } ann_mg_index(const ann_mg_index&) = delete; @@ -257,29 +213,44 @@ class ann_mg_index { auto operator=(const ann_mg_index&) -> ann_mg_index& = delete; auto operator=(ann_mg_index&&) -> ann_mg_index& = default; - void init_device_resources() { + // local index deserialization and distribution + void deserialize_and_distribute(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) + { for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - dev_resources_.emplace_back(); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.deserialize(dev_res, filename); } } - void init_nccl_clique() { - RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, dev_ids_.data())); - for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - raft::comms::build_comms_nccl_only(&dev_resources_[rank], nccl_comms_[rank], num_ranks_, rank); - } - } + // MG index deserialization + void deserialize_mg_index(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) + { + std::ifstream is(filename, std::ios::in | std::ios::binary); + if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); + num_ranks_ = deserialize_scalar(handle, is); - void destroy_nccl_clique() { for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(dev_ids_[rank]); - ncclCommDestroy(nccl_comms_[rank]); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_.emplace_back(); + ann_if.deserialize(dev_res, is); } + + is.close(); } - void build(const ann::index_params* index_params, + void build(const raft::neighbors::mg::nccl_clique& clique, + const ann::index_params* index_params, raft::host_matrix_view index_dataset) { if (mode_ == REPLICATION) { @@ -287,10 +258,12 @@ class ann_mg_index { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_[rank]; - ann_if.build(dev_resources_[rank], index_params, index_dataset); - resource::sync_stream(dev_resources_[rank]); + ann_if.build(dev_res, index_params, index_dataset); + resource::sync_stream(dev_res); } #pragma omp barrier } else if (mode_ == SHARDING) { @@ -301,7 +274,9 @@ class ann_mg_index { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); IdxT offset = rank * n_rows_per_shard; n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); @@ -309,32 +284,36 @@ class ann_mg_index { auto partition = raft::make_host_matrix_view( partition_ptr, n_rows_per_shard, n_cols); auto& ann_if = ann_interfaces_[rank]; - ann_if.build(dev_resources_[rank], index_params, partition); - resource::sync_stream(dev_resources_[rank]); + ann_if.build(dev_res, index_params, partition); + resource::sync_stream(dev_res); } #pragma omp barrier } - set_current_device_to_root_rank(); } - void extend(raft::host_matrix_view new_vectors, + void extend(const raft::neighbors::mg::nccl_clique& clique, + raft::host_matrix_view new_vectors, std::optional> new_indices) { IdxT n_rows = new_vectors.extent(0); if (mode_ == REPLICATION) { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_[rank]; - ann_if.extend(dev_resources_[rank], new_vectors, new_indices); - resource::sync_stream(dev_resources_[rank]); + ann_if.extend(dev_res, new_vectors, new_indices); + resource::sync_stream(dev_res); } #pragma omp barrier } else if (mode_ == SHARDING) { IdxT n_cols = new_vectors.extent(1); #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); IdxT offset = rank * n_rows_per_shard; n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); @@ -348,15 +327,15 @@ class ann_mg_index { new_indices_part = raft::make_host_vector_view(new_indices_ptr, n_rows_per_shard); } auto& ann_if = ann_interfaces_[rank]; - ann_if.extend(dev_resources_[rank], new_vectors_part, new_indices_part); - resource::sync_stream(dev_resources_[rank]); + ann_if.extend(dev_res, new_vectors_part, new_indices_part); + resource::sync_stream(dev_res); } #pragma omp barrier } - set_current_device_to_root_rank(); } - void search(const ann::search_params* search_params, + void search(const raft::neighbors::mg::nccl_clique& clique, + const ann::search_params* search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) const @@ -368,7 +347,9 @@ class ann_mg_index { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); IdxT offset = rank * n_rows_per_shard; IdxT query_offset = offset * n_cols; @@ -377,12 +358,11 @@ class ann_mg_index { auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_per_shard, n_cols); - auto& handle = dev_resources_[rank]; - auto d_neighbors = raft::make_device_matrix(handle, n_rows_per_shard, n_neighbors); - auto d_distances = raft::make_device_matrix(handle, n_rows_per_shard, n_neighbors); + auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_per_shard, n_neighbors); + auto d_distances = raft::make_device_matrix(dev_res, n_rows_per_shard, n_neighbors); auto& ann_if = ann_interfaces_[rank]; - ann_if.search(handle, + ann_if.search(dev_res, search_params, query_partition, d_neighbors.view(), @@ -391,12 +371,12 @@ class ann_mg_index { raft::copy(neighbors.data_handle() + output_offset, d_neighbors.data_handle(), n_rows_per_shard * n_neighbors, - resource::get_cuda_stream(handle)); + resource::get_cuda_stream(dev_res)); raft::copy(distances.data_handle() + output_offset, d_distances.data_handle(), n_rows_per_shard * n_neighbors, - resource::get_cuda_stream(handle)); - resource::sync_stream(handle); + resource::get_cuda_stream(dev_res)); + resource::sync_stream(dev_res); } #pragma omp barrier } else if (mode_ == SHARDING) { @@ -406,7 +386,7 @@ class ann_mg_index { IdxT n_batches = raft::ceildiv(n_rows, (IdxT)N_ROWS_PER_BATCH); - const auto& root_handle = set_current_device_to_root_rank(); + const auto& root_handle = clique.set_current_device_to_root_rank(); auto in_neighbors = raft::make_device_matrix( root_handle, num_ranks_ * N_ROWS_PER_BATCH, n_neighbors); auto in_distances = raft::make_device_matrix( @@ -426,28 +406,29 @@ class ann_mg_index { #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); - auto& handle = dev_resources_[rank]; - auto d_neighbors = raft::make_device_matrix(handle, n_rows_per_batches, n_neighbors); - auto d_distances = raft::make_device_matrix(handle, n_rows_per_batches, n_neighbors); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_per_batches, n_neighbors); + auto d_distances = raft::make_device_matrix(dev_res, n_rows_per_batches, n_neighbors); auto& ann_if = ann_interfaces_[rank]; - ann_if.search(handle, search_params, query_partition, d_neighbors.view(), d_distances.view()); + ann_if.search(dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); RAFT_NCCL_TRY(ncclGroupStart()); uint64_t batch_offset = rank * n_rows_per_batches * n_neighbors; - const auto& comms = resource::get_comms(handle); + const auto& comms = resource::get_comms(dev_res); comms.device_send(d_neighbors.data_handle(), n_rows_per_batches * n_neighbors, - root_rank_, - resource::get_cuda_stream(handle)); + clique.root_rank_, + resource::get_cuda_stream(dev_res)); comms.device_send(d_distances.data_handle(), n_rows_per_batches * n_neighbors, - root_rank_, - resource::get_cuda_stream(handle)); + clique.root_rank_, + resource::get_cuda_stream(dev_res)); - const auto& root_handle = set_current_device_to_root_rank(); + const auto& root_handle = clique.set_current_device_to_root_rank(); const auto& root_comms = resource::get_comms(root_handle); root_comms.device_recv(in_neighbors.data_handle() + batch_offset, n_rows_per_batches * n_neighbors, @@ -470,7 +451,7 @@ class ann_mg_index { auto out_distances_view = raft::make_device_matrix_view( out_distances.data_handle(), n_rows_per_batches, n_neighbors); - const auto& root_handle_ = set_current_device_to_root_rank(); + const auto& root_handle_ = clique.set_current_device_to_root_rank(); auto h_trans = std::vector(num_ranks_); IdxT translation_offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { @@ -498,11 +479,10 @@ class ann_mg_index { resource::get_cuda_stream(root_handle_)); } } - - set_current_device_to_root_rank(); } void serialize(raft::resources const& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) const { std::ofstream of(filename, std::ios::out | std::ios::binary); @@ -511,185 +491,197 @@ class ann_mg_index { serialize_scalar(handle, of, (int)mode_); serialize_scalar(handle, of, num_ranks_); for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[rank])); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_[rank]; - ann_if.serialize(dev_resources_[rank], of); + ann_if.serialize(dev_res, of); } of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } - inline const raft::device_resources& set_current_device_to_root_rank() const - { - RAFT_CUDA_TRY(cudaSetDevice(dev_ids_[root_rank_])); - return dev_resources_[root_rank_]; - } - private: parallel_mode mode_; - int root_rank_; int num_ranks_; - std::vector dev_ids_; - std::vector dev_resources_; std::vector> ann_interfaces_; - std::vector nccl_comms_; }; template ann_mg_index, T, IdxT> build( const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ivf_flat::dist_index_params& index_params, raft::host_matrix_view index_dataset) { - ann_mg_index, T, IdxT> index(index_params.device_ids, index_params.mode); - index.build(static_cast(&index_params), index_dataset); + ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); + index.build(clique, static_cast(&index_params), index_dataset); return index; } template ann_mg_index, T, IdxT> build( const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ivf_pq::dist_index_params& index_params, raft::host_matrix_view index_dataset) { - ann_mg_index, T, IdxT> index(index_params.device_ids, index_params.mode); - index.build(static_cast(&index_params), index_dataset); + ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); + index.build(clique, static_cast(&index_params), index_dataset); return index; } template ann_mg_index, T, IdxT> build( const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const cagra::dist_index_params& index_params, raft::host_matrix_view index_dataset) { - ann_mg_index, T, IdxT> index(index_params.device_ids, index_params.mode); - index.build(static_cast(&index_params), index_dataset); + ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); + index.build(clique, static_cast(&index_params), index_dataset); return index; } template void extend(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - index.extend(new_vectors, new_indices); + index.extend(clique, new_vectors, new_indices); } template void extend(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - index.extend(new_vectors, new_indices); + index.extend(clique, new_vectors, new_indices); } template void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - index.search( - static_cast(&search_params), query_dataset, neighbors, distances); + index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances); } template void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ann_mg_index, T, IdxT>& index, const ivf_pq::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - index.search( - static_cast(&search_params), query_dataset, neighbors, distances); + index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances); } template void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ann_mg_index, T, IdxT>& index, const cagra::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - index.search( - static_cast(&search_params), query_dataset, neighbors, distances); + index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances); } template void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ann_mg_index, T, IdxT>& index, const std::string& filename) { - index.serialize(handle, filename); + index.serialize(handle, clique, filename); } template void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ann_mg_index, T, IdxT>& index, const std::string& filename) { - index.serialize(handle, filename); + index.serialize(handle, clique, filename); } template void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ann_mg_index, T, IdxT>& index, const std::string& filename) { - index.serialize(handle, filename); + index.serialize(handle, clique, filename); } template ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return ann_mg_index, T, IdxT>(handle, filename); + auto index = ann_mg_index, T, IdxT>(handle, clique, filename); + return index; } template ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return ann_mg_index, T, IdxT>(handle, filename); + auto index = ann_mg_index, T, IdxT>(handle, clique, filename); + return index; } template ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return ann_mg_index, T, IdxT>(handle, filename); + auto index = ann_mg_index, T, IdxT>(handle, clique, filename); + return index; } template ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, - const std::vector& dev_list, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return ann_mg_index, T, IdxT>(handle, dev_list, filename); + auto index = ann_mg_index, T, IdxT>(REPLICATION, clique.num_ranks_); + index.deserialize_and_distribute(handle, clique, filename); + return index; } template ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, - const std::vector& dev_list, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return ann_mg_index, T, IdxT>(handle, dev_list, filename); + auto index = ann_mg_index, T, IdxT>(REPLICATION, clique.num_ranks_); + index.deserialize_and_distribute(handle, clique, filename); + return index; } template ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, - const std::vector& dev_list, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return ann_mg_index, T, IdxT>(handle, dev_list, filename); + auto index = ann_mg_index, T, IdxT>(REPLICATION, clique.num_ranks_); + index.deserialize_and_distribute(handle, clique, filename); + return index; } } // namespace raft::neighbors::mg::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_flat_mg.cuh b/cpp/include/raft/neighbors/ivf_flat_mg.cuh index f900a930ee..5ca0084f74 100644 --- a/cpp/include/raft/neighbors/ivf_flat_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_mg.cuh @@ -23,31 +23,34 @@ namespace raft::neighbors::mg { template auto build(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ivf_flat::dist_index_params& index_params, raft::host_matrix_view index_dataset) -> detail::ann_mg_index, T, IdxT> { - return mg::detail::build(handle, index_params, index_dataset); + return mg::detail::build(handle, clique, index_params, index_dataset); } template void extend(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, detail::ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - mg::detail::extend(handle, index, new_vectors, new_indices); + mg::detail::extend(handle, clique, index, new_vectors, new_indices); } template void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const detail::ann_mg_index, T, IdxT>& index, const ivf_flat::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - mg::detail::search(handle, index, search_params, query_dataset, neighbors, distances); + mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh index d2bc212081..6f4c9d820f 100644 --- a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh @@ -23,25 +23,27 @@ namespace raft::neighbors::mg { template void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const detail::ann_mg_index, T, IdxT>& index, const std::string& filename) { - mg::detail::serialize(handle, index, filename); + mg::detail::serialize(handle, clique, index, filename); } template detail::ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return mg::detail::deserialize_flat(handle, filename); + return mg::detail::deserialize_flat(handle, clique, filename); } template detail::ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, - const std::vector& dev_list, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return mg::detail::distribute_flat(handle, dev_list, filename); + return mg::detail::distribute_flat(handle, clique, filename); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_pq_mg.cuh b/cpp/include/raft/neighbors/ivf_pq_mg.cuh index a3e7358c2a..991a013f0e 100644 --- a/cpp/include/raft/neighbors/ivf_pq_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_mg.cuh @@ -23,39 +23,34 @@ namespace raft::neighbors::mg { template auto build(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const ivf_pq::dist_index_params& index_params, raft::host_matrix_view index_dataset) -> detail::ann_mg_index, T, IdxT> { - return mg::detail::build(handle, index_params, index_dataset); + return mg::detail::build(handle, clique, index_params, index_dataset); } template void extend(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, detail::ann_mg_index, T, IdxT>& index, raft::host_matrix_view new_vectors, std::optional> new_indices) { - mg::detail::extend(handle, index, new_vectors, new_indices); + mg::detail::extend(handle, clique, index, new_vectors, new_indices); } template void search(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const detail::ann_mg_index, T, IdxT>& index, const ivf_pq::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - mg::detail::search(handle, index, search_params, query_dataset, neighbors, distances); -} - -template -void serialize(const raft::resources& handle, - const detail::ann_mg_index, T, IdxT>& index, - const std::string& filename) -{ - mg::detail::serialize(handle, index, filename); + mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh index 94bd6584fe..ca6f91e763 100644 --- a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh @@ -21,19 +21,29 @@ namespace raft::neighbors::mg { +template +void serialize(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const detail::ann_mg_index, T, IdxT>& index, + const std::string& filename) +{ + mg::detail::serialize(handle, clique, index, filename); +} + template detail::ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - return mg::detail::deserialize_pq(handle, filename); + return mg::detail::deserialize_pq(handle, clique, filename); } template detail::ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, - const std::vector& dev_list, - const std::string& filename) + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { - return mg::detail::distribute_pq(handle, dev_list, filename); + return mg::detail::distribute_pq(handle, clique, filename); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index 3474727c86..28b41c4922 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -18,6 +18,8 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include + #include #include #include @@ -94,7 +96,6 @@ class AnnMGTest : public ::testing::TestWithParam> { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - index_params.device_ids = device_ids; index_params.mode = d_mode; ivf_flat::search_params search_params; @@ -109,13 +110,14 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); + raft::neighbors::mg::nccl_clique clique(device_ids); { - auto index = raft::neighbors::mg::build(handle_, index_params, index_dataset); - raft::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - raft::neighbors::mg::serialize(handle_, index, "./cpp/build/ann_mg_ivf_flat_index"); + auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::extend(handle_, clique, index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_ivf_flat_index"); } - auto new_index = raft::neighbors::mg::deserialize_flat(handle_, "./cpp/build/ann_mg_ivf_flat_index"); - raft::neighbors::mg::search(handle_, new_index, search_params, query_dataset, neighbors, distances); + auto new_index = raft::neighbors::mg::deserialize_flat(handle_, clique, "./cpp/build/ann_mg_ivf_flat_index"); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -139,7 +141,6 @@ class AnnMGTest : public ::testing::TestWithParam> { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 1.0; index_params.metric_arg = 0; - index_params.device_ids = device_ids; index_params.mode = d_mode; ivf_pq::search_params search_params; @@ -154,13 +155,14 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); + raft::neighbors::mg::nccl_clique clique(device_ids); { - auto index = raft::neighbors::mg::build(handle_, index_params, index_dataset); - raft::neighbors::mg::extend(handle_, index, index_dataset, std::nullopt); - raft::neighbors::mg::serialize(handle_, index, "./cpp/build/ann_mg_ivf_pq_index"); + auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::extend(handle_, clique, index, index_dataset, std::nullopt); + raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_ivf_pq_index"); } - auto new_index = raft::neighbors::mg::deserialize_pq(handle_, "./cpp/build/ann_mg_ivf_pq_index"); - raft::neighbors::mg::search(handle_, new_index, search_params, query_dataset, neighbors, distances); + auto new_index = raft::neighbors::mg::deserialize_pq(handle_, clique, "./cpp/build/ann_mg_ivf_pq_index"); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -183,7 +185,6 @@ class AnnMGTest : public ::testing::TestWithParam> { index_params.graph_degree = 64; index_params.build_algo = cagra::graph_build_algo::IVF_PQ; index_params.nn_descent_niter = 20; - index_params.device_ids = device_ids; index_params.mode = d_mode; cagra::search_params search_params; @@ -208,8 +209,10 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); */ - auto index = raft::neighbors::mg::build(handle_, index_params, index_dataset); - raft::neighbors::mg::search(handle_, index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::nccl_clique clique(device_ids); + + auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::search(handle_, clique, index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); @@ -238,6 +241,8 @@ class AnnMGTest : public ::testing::TestWithParam> { ivf_flat::search_params search_params; search_params.n_probes = ps.nprobe; + RAFT_CUDA_TRY(cudaSetDevice(0)); + { auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); auto index = raft::runtime::neighbors::ivf_flat::build(handle_, index_params, index_dataset); @@ -248,8 +253,9 @@ class AnnMGTest : public ::testing::TestWithParam> { auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, device_ids, "./cpp/build/local_ivf_flat_index"); - raft::neighbors::mg::search(handle_, distributed_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::nccl_clique clique(device_ids); + auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, clique, "./cpp/build/local_ivf_flat_index"); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); @@ -277,6 +283,8 @@ class AnnMGTest : public ::testing::TestWithParam> { ivf_pq::search_params search_params; search_params.n_probes = ps.nprobe; + RAFT_CUDA_TRY(cudaSetDevice(0)); + { auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); auto index = raft::runtime::neighbors::ivf_pq::build(handle_, index_params, index_dataset); @@ -287,8 +295,9 @@ class AnnMGTest : public ::testing::TestWithParam> { auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, device_ids, "./cpp/build/local_ivf_pq_index"); - raft::neighbors::mg::search(handle_, distributed_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::nccl_clique clique(device_ids); + auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, clique, "./cpp/build/local_ivf_pq_index"); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); @@ -314,18 +323,21 @@ class AnnMGTest : public ::testing::TestWithParam> { cagra::search_params search_params; + RAFT_CUDA_TRY(cudaSetDevice(0)); + { - auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); + auto index_dataset = raft::make_device_matrix_view(d_index_dataset.data(), ps.num_db_vecs, ps.dim); auto index = raft::runtime::neighbors::cagra::build(handle_, index_params, index_dataset); raft::neighbors::cagra::serialize(handle_, "./cpp/build/local_cagra_index", index); } - auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); - auto neighbors = raft::make_host_matrix_view(indices_ann_32bits.data(), ps.num_queries, ps.k); - auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); + auto query_dataset = raft::make_host_matrix_view(h_query_dataset.data(), ps.num_queries, ps.dim); + auto neighbors = raft::make_host_matrix_view(indices_ann_32bits.data(), ps.num_queries, ps.k); + auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, device_ids, "./cpp/build/local_cagra_index"); - raft::neighbors::mg::search(handle_, distributed_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::nccl_clique clique(device_ids); + auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, clique, "./cpp/build/local_cagra_index"); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); From 0a37d636b8b04c59f5bc2496aa6cba54909d0413 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 24 May 2024 19:04:26 +0200 Subject: [PATCH 14/22] SNMG ANN IVF-Flat & IVF-PQ bench + fixes --- cpp/bench/ann/CMakeLists.txt | 26 ++- .../src/raft/raft_ann_bench_param_parser.h | 88 +++++++++- .../ann/src/raft/raft_ann_mg_ivf_flat.cu | 28 ++++ .../src/raft/raft_ann_mg_ivf_flat_wrapper.hpp | 113 +++++++++++++ cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu | 28 ++++ .../src/raft/raft_ann_mg_ivf_pq_wrapper.hpp | 117 ++++++++++++++ .../ann/src/raft/raft_ann_mg_wrapper.hpp | 153 +++++------------- cpp/bench/ann/src/raft/raft_benchmark.cu | 34 +++- cpp/include/raft/neighbors/ann_mg_helpers.cuh | 9 +- cpp/include/raft/neighbors/detail/ann_mg.cuh | 1 - cpp/test/neighbors/ann_mg.cuh | 22 ++- .../src/raft-ann-bench/run/algos.yaml | 7 +- ..._ann_mg.yaml => raft_ann_mg_ivf_flat.yaml} | 4 +- .../run/conf/algos/raft_ann_mg_ivf_pq.yaml | 17 ++ .../run/conf/mnist-784-euclidean.json | 20 +-- .../run/conf/sift-128-euclidean.json | 15 +- 16 files changed, 493 insertions(+), 189 deletions(-) create mode 100644 cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu create mode 100644 cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp create mode 100644 cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu create mode 100644 cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp rename python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/{raft_ann_mg.yaml => raft_ann_mg_ivf_flat.yaml} (63%) create mode 100644 python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_ivf_pq.yaml diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index dfa7f009ca..330a1e021f 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -29,7 +29,8 @@ option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchm option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON) -option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG "Include raft's MG ANN in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT "Include raft's MG ANN IVF-FLAT in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ "Include raft's MG ANN IVF-PQ in benchmark" ON) option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) option(RAFT_ANN_BENCH_SINGLE_EXE @@ -56,7 +57,8 @@ if(BUILD_CPU_ONLY) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF) set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF) - set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG OFF) + set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT OFF) + set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ OFF) set(RAFT_ANN_BENCH_USE_GGNN OFF) else() # Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled. @@ -92,7 +94,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_CAGRA OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB - OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG + OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT + OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -282,12 +285,25 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ) endif() -if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG) +if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT) ConfigureAnnBench( NAME - RAFT_ANN_MG + RAFT_ANN_MG_IVF_FLAT PATH bench/ann/src/raft/raft_benchmark.cu + bench/ann/src/raft/raft_ann_mg_ivf_flat.cu + LINKS + raft::compiled + ) +endif() + +if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) + ConfigureAnnBench( + NAME + RAFT_ANN_MG_IVF_PQ + PATH + bench/ann/src/raft/raft_benchmark.cu + bench/ann/src/raft/raft_ann_mg_ivf_pq.cu LINKS raft::compiled ) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 101d54107c..906ce95b5d 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -47,11 +47,17 @@ extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; #endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG -#include "raft_ann_mg_wrapper.hpp" -extern template class raft::bench::ann::RaftAnnMG; -extern template class raft::bench::ann::RaftAnnMG; -extern template class raft::bench::ann::RaftAnnMG; +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT +#include "raft_ann_mg_ivf_flat_wrapper.hpp" +extern template class raft::bench::ann::RaftAnnMG_IvfFlat; +extern template class raft::bench::ann::RaftAnnMG_IvfFlat; +extern template class raft::bench::ann::RaftAnnMG_IvfFlat; +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +#include "raft_ann_mg_ivf_pq_wrapper.hpp" +extern template class raft::bench::ann::RaftAnnMG_IvfPq; +extern template class raft::bench::ann::RaftAnnMG_IvfPq; +extern template class raft::bench::ann::RaftAnnMG_IvfPq; #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT @@ -72,10 +78,10 @@ void parse_search_param(const nlohmann::json& conf, } #endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT template void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftAnnMG::BuildParam& param) + typename raft::bench::ann::RaftAnnMG_IvfFlat::BuildParam& param) { param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -84,7 +90,7 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftAnnMG::SearchParam& param) + typename raft::bench::ann::RaftAnnMG_IvfFlat::SearchParam& param) { param.ivf_flat_params.n_probes = conf.at("nprobe"); } @@ -157,6 +163,72 @@ void parse_search_param(const nlohmann::json& conf, } #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam& param) +{ + if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } + if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } + if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } + if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } + if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } + if (conf.contains("codebook_kind")) { + std::string kind = conf.at("codebook_kind"); + if (kind == "cluster") { + param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + } else if (kind == "subspace") { + param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + } else { + throw std::runtime_error("codebook_kind: '" + kind + + "', should be either 'cluster' or 'subspace'"); + } + } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::RaftAnnMG_IvfPq::SearchParam& param) +{ + if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } + if (conf.contains("internalDistanceDtype")) { + std::string type = conf.at("internalDistanceDtype"); + if (type == "float") { + param.pq_param.internal_distance_dtype = CUDA_R_32F; + } else if (type == "half") { + param.pq_param.internal_distance_dtype = CUDA_R_16F; + } else { + throw std::runtime_error("internalDistanceDtype: '" + type + + "', should be either 'float' or 'half'"); + } + } else { + // set half as default type + param.pq_param.internal_distance_dtype = CUDA_R_16F; + } + + if (conf.contains("smemLutDtype")) { + std::string type = conf.at("smemLutDtype"); + if (type == "float") { + param.pq_param.lut_dtype = CUDA_R_32F; + } else if (type == "half") { + param.pq_param.lut_dtype = CUDA_R_16F; + } else if (type == "fp8") { + param.pq_param.lut_dtype = CUDA_R_8U; + } else { + throw std::runtime_error("smemLutDtype: '" + type + + "', should be either 'float', 'half' or 'fp8'"); + } + } else { + // set half as default + param.pq_param.lut_dtype = CUDA_R_16F; + } + if (conf.contains("refine_ratio")) { + param.refine_ratio = conf.at("refine_ratio"); + if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); } + } +} +#endif + #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) template void parse_build_param(const nlohmann::json& conf, diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu new file mode 100644 index 0000000000..4f9d04072c --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "raft_ann_mg_ivf_flat_wrapper.hpp" +#include +#include + +namespace raft::bench::ann { + + template class RaftAnnMG_IvfFlat; + template class RaftAnnMG_IvfFlat; + template class RaftAnnMG_IvfFlat; + +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp new file mode 100644 index 0000000000..b739fab81b --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "raft_ann_mg_wrapper.hpp" +#include +#include + +namespace raft::bench::ann { + +template +class RaftAnnMG_IvfFlat : public RaftAnnMG { + public: + using typename ANN::AnnSearchParam; + + struct SearchParam : public AnnSearchParam { + raft::neighbors::ivf_flat::search_params ivf_flat_params; + }; + + using BuildParam = raft::neighbors::ivf_flat::dist_index_params; + + RaftAnnMG_IvfFlat(Metric metric, int dim, const BuildParam& param) + : RaftAnnMG(metric, dim), index_params_(param) + { + this->init_nccl_clique(); + + index_params_.metric = parse_metric_type(metric); + index_params_.conservative_memory_allocation = true; + index_params_.mode = raft::neighbors::mg::parallel_mode::REPLICATION; + } + + void build(const T* dataset, size_t nrow) final; + void set_search_param(const AnnSearchParam& param) override; + void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void save(const std::string& file) const override; + void load(const std::string&) override; + std::unique_ptr> copy() override; + + private: + BuildParam index_params_; + raft::neighbors::ivf_flat::search_params search_params_; + std::shared_ptr, T, IdxT>> index_; +}; + +template +void RaftAnnMG_IvfFlat::build(const T* dataset, size_t nrow) +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); + index_ = std::make_shared, T, IdxT>>(std::move(idx)); + return; +} + +template +void RaftAnnMG_IvfFlat::set_search_param(const AnnSearchParam& param) +{ + auto search_param = dynamic_cast(param); + search_params_ = search_param.ivf_flat_params; + assert(search_params_.n_probes <= index_params_.n_lists); +} + +template +void RaftAnnMG_IvfFlat::save(const std::string& file) const +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + raft::neighbors::mg::serialize(handle, *this->clique_, *index_, file); + return; +} + +template +void RaftAnnMG_IvfFlat::load(const std::string& file) +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + index_ = std::make_shared, T, IdxT>>(std::move(raft::neighbors::mg::deserialize_flat(handle, *this->clique_, file))); +} + +template +std::unique_ptr> RaftAnnMG_IvfFlat::copy() +{ + return std::make_unique>(*this); // use copy constructor +} + +template +void RaftAnnMG_IvfFlat::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +{ + static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); + + const auto& handle = this->clique_->set_current_device_to_root_rank(); + + auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + resource::sync_stream(handle); + return; +} +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu new file mode 100644 index 0000000000..84fdc0da91 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "raft_ann_mg_ivf_pq_wrapper.hpp" +#include +#include + +namespace raft::bench::ann { + + template class RaftAnnMG_IvfPq; + template class RaftAnnMG_IvfPq; + template class RaftAnnMG_IvfPq; + +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp new file mode 100644 index 0000000000..1cdb361087 --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "raft_ann_mg_wrapper.hpp" +#include +#include + +namespace raft::bench::ann { + +template +class RaftAnnMG_IvfPq : public RaftAnnMG { + public: + using typename ANN::AnnSearchParam; + + struct SearchParam : public AnnSearchParam { + raft::neighbors::ivf_pq::search_params pq_param; + float refine_ratio = 1.0f; + auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; } + }; + + using BuildParam = raft::neighbors::ivf_pq::dist_index_params; + + RaftAnnMG_IvfPq(Metric metric, int dim, const BuildParam& param) + : RaftAnnMG(metric, dim), index_params_(param) + { + this->init_nccl_clique(); + + index_params_.metric = parse_metric_type(metric); + index_params_.conservative_memory_allocation = true; + index_params_.mode = raft::neighbors::mg::parallel_mode::REPLICATION; + } + + void build(const T* dataset, size_t nrow) final; + void set_search_param(const AnnSearchParam& param) override; + void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void save(const std::string& file) const override; + void load(const std::string&) override; + std::unique_ptr> copy() override; + + private: + BuildParam index_params_; + raft::neighbors::ivf_pq::search_params search_params_; + std::shared_ptr, T, IdxT>> index_; + float refine_ratio_ = 1.0; +}; + +template +void RaftAnnMG_IvfPq::build(const T* dataset, size_t nrow) +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); + index_ = std::make_shared, T, IdxT>>(std::move(idx)); + return; +} + +template +void RaftAnnMG_IvfPq::set_search_param(const AnnSearchParam& param) +{ + auto search_param = dynamic_cast(param); + search_params_ = search_param.pq_param; + refine_ratio_ = search_param.refine_ratio; + assert(search_params_.n_probes <= index_params_.n_lists); +} + +template +void RaftAnnMG_IvfPq::save(const std::string& file) const +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + raft::neighbors::mg::serialize(handle, *this->clique_, *index_, file); + return; +} + +template +void RaftAnnMG_IvfPq::load(const std::string& file) +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + index_ = std::make_shared, T, IdxT>>(std::move(raft::neighbors::mg::deserialize_pq(handle, *this->clique_, file))); +} + +template +std::unique_ptr> RaftAnnMG_IvfPq::copy() +{ + return std::make_unique>(*this); // use copy constructor +} + +template +void RaftAnnMG_IvfPq::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +{ + static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); + + const auto& handle = this->clique_->set_current_device_to_root_rank(); + + auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + resource::sync_stream(handle); + return; +} +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp index eb962d6af0..39cd2e9e9b 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,131 +13,50 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #pragma once #include "../common/ann_types.hpp" #include "raft_ann_bench_utils.h" -#include -#include +#include namespace raft::bench::ann { -template +template class RaftAnnMG : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - - struct SearchParam : public AnnSearchParam { - raft::neighbors::ivf_flat::search_params ivf_flat_params; - }; - - using BuildParam = raft::neighbors::ivf_flat::dist_index_params; - - RaftAnnMG(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim), index_params_(param), dimension_(dim) - { - index_params_.metric = parse_metric_type(metric); - index_params_.conservative_memory_allocation = true; - index_params_.mode = raft::neighbors::mg::parallel_mode::REPLICATION; - RAFT_CUDA_TRY(cudaGetDevice(&device_)); - } - - void build(const T* dataset, size_t nrow) final; - - void set_search_param(const AnnSearchParam& param) override; - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; - - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return handle_.get_sync_stream(); - } - - // to enable dataset access from GPU memory - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::HostMmap; - property.query_memory_type = MemoryType::Device; - return property; - } - void save(const std::string& file) const override; - void load(const std::string&) override; - std::unique_ptr> copy() override; - - private: - // handle_ must go first to make sure it dies last and all memory allocated in pool - configured_raft_resources handle_{}; - BuildParam index_params_; - raft::neighbors::ivf_flat::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; - int device_; - int dimension_; + public: + RaftAnnMG(Metric metric, int dim) + : ANN(metric, dim), dimension_(dim) + {} + + AlgoProperty get_preference() const override + { + AlgoProperty property; + property.dataset_memory_type = MemoryType::HostMmap; + property.query_memory_type = MemoryType::HostMmap; + return property; + } + + protected: + void init_nccl_clique() { + int n_devices; + cudaGetDeviceCount(&n_devices); + std::cout << n_devices << " GPUs detected" << std::endl; + + std::vector device_ids(n_devices); + std::iota(device_ids.begin(), device_ids.end(), 0); + clique_ = std::make_shared(device_ids); + } + + [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override + { + const auto& handle = clique_->set_current_device_to_root_rank(); + return resource::get_cuda_stream(handle); + } + + std::shared_ptr clique_; + int dimension_; }; -template -void RaftAnnMG::build(const T* dataset, size_t nrow) -{ - std::vector device_ids{0, 1}; - raft::neighbors::mg::nccl_clique clique(device_ids); - - auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(dimension_)); - auto idx = raft::neighbors::mg::build(handle_, clique, index_params_, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); - return; -} - -template -void RaftAnnMG::set_search_param(const AnnSearchParam& param) -{ - auto search_param = dynamic_cast(param); - search_params_ = search_param.ivf_flat_params; - assert(search_params_.n_probes <= index_params_.n_lists); -} - -template -void RaftAnnMG::save(const std::string& file) const -{ - std::vector device_ids{0, 1}; - raft::neighbors::mg::nccl_clique clique(device_ids); - - raft::neighbors::mg::serialize(handle_, clique, *index_, file); - return; -} - -template -void RaftAnnMG::load(const std::string& file) -{ - std::vector device_ids{0, 1}; - raft::neighbors::mg::nccl_clique clique(device_ids); - - index_ = std::make_shared, T, IdxT>>( - std::move(raft::neighbors::mg::deserialize_flat(handle_, clique, file))); -} - -template -std::unique_ptr> RaftAnnMG::copy() -{ - return std::make_unique>(*this); // use copy constructor -} - -template -void RaftAnnMG::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const -{ - static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); - - std::vector device_ids{0, 1}; - raft::neighbors::mg::nccl_clique clique(device_ids); - - auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(dimension_)); - auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); - auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - raft::neighbors::mg::search(handle_, clique, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); - resource::sync_stream(handle_); - return; } -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index ddf6b9f873..e136f47d3c 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -81,15 +81,25 @@ std::unique_ptr> create_algo(const std::string& algo, ann = std::make_unique>(metric, dim, param); } #endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - if (algo == "raft_ann_mg") { - typename raft::bench::ann::RaftAnnMG::BuildParam param; + if (algo == "raft_ann_mg_ivf_flat") { + typename raft::bench::ann::RaftAnnMG_IvfFlat::BuildParam param; parse_build_param(conf, param); - ann = std::make_unique>(metric, dim, param); + ann = std::make_unique>(metric, dim, param); } } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_pq") { + typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } +} #endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } @@ -132,12 +142,22 @@ std::unique_ptr::AnnSearchParam> create_search return param; } #endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - if (algo == "raft_ann_mg") { + if (algo == "raft_ann_mg_ivf_flat") { auto param = - std::make_unique::SearchParam>(); + std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + } +#endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_pq") { + auto param = std::make_unique::SearchParam>(); parse_search_param(conf, *param); return param; } diff --git a/cpp/include/raft/neighbors/ann_mg_helpers.cuh b/cpp/include/raft/neighbors/ann_mg_helpers.cuh index ebfb82a668..38a15363d1 100644 --- a/cpp/include/raft/neighbors/ann_mg_helpers.cuh +++ b/cpp/include/raft/neighbors/ann_mg_helpers.cuh @@ -19,17 +19,22 @@ #include #include #include -#include +#include +namespace raft::comms { + void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank); +} namespace raft::neighbors::mg { struct nccl_clique { + nccl_clique(const std::vector& device_ids) : root_rank_(0), num_ranks_(device_ids.size()), device_ids_(device_ids), - nccl_comms_(device_ids.size()) + nccl_comms_(device_ids.size()), + device_resources_(0) { RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data())); diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 2528128858..bbaee2f169 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index 28b41c4922..9efd526359 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -30,6 +30,9 @@ #include +#include + + namespace raft::neighbors::mg { template @@ -181,6 +184,7 @@ class AnnMGTest : public ::testing::TestWithParam> { // CAGRA for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { cagra::dist_index_params index_params; + index_params.add_data_on_build = true; index_params.intermediate_graph_degree = 128; index_params.graph_degree = 64; index_params.build_algo = cagra::graph_build_algo::IVF_PQ; @@ -198,22 +202,16 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); + raft::neighbors::mg::nccl_clique clique(device_ids); /* - TODO : fix CAGRA serialization issue - { - auto index = raft::neighbors::mg::build(device_ids, d_mode, index_params, index_dataset); - raft::neighbors::mg::serialize(handle_, index, "ann_mg_cagra_index"); + auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_cagra_index"); } - auto new_index = raft::neighbors::mg::deserialize_cagra(handle_, "ann_mg_cagra_index"); - raft::neighbors::mg::search(new_index, search_params, query_dataset, neighbors, distances); + auto new_index = raft::neighbors::mg::deserialize_cagra(handle_, clique, "./cpp/build/ann_mg_cagra_index"); */ - - raft::neighbors::mg::nccl_clique clique(device_ids); - - auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); - raft::neighbors::mg::search(handle_, clique, index, search_params, query_dataset, neighbors, distances); - + auto new_index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml index c7bc62fa87..dea6215ede 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml @@ -31,8 +31,11 @@ raft_cagra: raft_brute_force: executable: RAFT_BRUTE_FORCE_ANN_BENCH requires_gpu: true -raft_ann_mg: - executable: RAFT_ANN_MG_ANN_BENCH +raft_ann_mg_ivf_flat: + executable: RAFT_ANN_MG_IVF_FLAT_ANN_BENCH + requires_gpu: true +raft_ann_mg_ivf_pq: + executable: RAFT_ANN_MG_IVF_PQ_ANN_BENCH requires_gpu: true ggnn: executable: GGNN_ANN_BENCH diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_ivf_flat.yaml similarity index 63% rename from python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg.yaml rename to python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_ivf_flat.yaml index 9f90879ec9..760bb70ed8 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_ivf_flat.yaml @@ -1,4 +1,4 @@ -name: raft_ann_mg +name: raft_ann_mg_ivf_flat groups: base: build: @@ -6,4 +6,4 @@ groups: ratio: [1, 2, 4] niter: [20, 25] search: - nprobe: [1, 5, 10, 50, 100, 200, 500, 1000, 2000] \ No newline at end of file + nprobe: [1, 5, 10, 50, 100, 200, 500, 1000, 2000] diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_ivf_pq.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_ivf_pq.yaml new file mode 100644 index 0000000000..9f5979912c --- /dev/null +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_ivf_pq.yaml @@ -0,0 +1,17 @@ +name: raft_ann_mg_ivf_pq +constraints: + build: raft-ann-bench.constraints.raft_ivf_pq_build_constraints + search: raft-ann-bench.constraints.raft_ivf_pq_search_constraints +groups: + base: + build: + nlist: [1024, 2048, 4096, 8192] + pq_dim: [64, 32] + pq_bits: [8, 6, 5, 4] + ratio: [10, 25] + niter: [25] + search: + nprobe: [1, 5, 10, 50, 100, 200] + internalDistanceDtype: ["float"] + smemLutDtype: ["float", "fp8", "half"] + refine_ratio: [1, 2, 4] diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json index 732ff109e7..e7aecfd239 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/mnist-784-euclidean.json @@ -1278,24 +1278,6 @@ ], "search_result_file": "result/mnist-784-euclidean/raft_ivf_flat/nlist1024" }, - { - "name": "raft_ann_mg.nlist1024", - "algo": "raft_ann_mg", - "build_param": { - "nlist": 1024, - "ratio": 1, - "niter": 25 - }, - "file": "index/mnist-784-euclidean/raft_ann_mg/nlist1024", - "dataset_memory_type": "host", - "query_memory_type": "host", - "search_params": [ - { - "nprobe": 5 - } - ], - "search_result_file": "result/mnist-784-euclidean/raft_ann_mg/nlist1024" - }, { "name": "raft_ivf_flat.nlist16384", "algo": "raft_ivf_flat", @@ -1367,4 +1349,4 @@ "search_result_file" : "result/mnist-784-euclidean/raft_cagra/dim64" } ] -} +} \ No newline at end of file diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json b/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json index d4bd4dda15..3ca47a2566 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/sift-128-euclidean.json @@ -472,19 +472,6 @@ {"nprobe": 2000} ] }, - { - "name": "raft_ann_mg.nlist16384", - "algo": "raft_ann_mg", - "build_param": {"nlist": 16384, "ratio": 2, "niter": 20}, - "file": "sift-128-euclidean/raft_ann_mg/nlist16384", - "dataset_memory_type": "host", - "query_memory_type": "host", - "search_params": [ - {"nprobe": 100}, - {"nprobe": 200}, - {"nprobe": 500} - ] - }, { "name": "raft_cagra.dim32", "algo": "raft_cagra", @@ -508,4 +495,4 @@ ] } ] -} +} \ No newline at end of file From 84176848d2e622c5e0479fb215d78ac87a959b5d Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 28 May 2024 15:35:30 +0200 Subject: [PATCH 15/22] Fixes & improvements --- .../src/raft/raft_ann_mg_ivf_flat_wrapper.hpp | 6 +- .../src/raft/raft_ann_mg_ivf_pq_wrapper.hpp | 6 +- .../ann/src/raft/raft_ann_mg_wrapper.hpp | 7 +- cpp/include/raft/neighbors/ann_mg_types.hpp | 6 +- cpp/include/raft/neighbors/cagra_mg.cuh | 9 +- cpp/include/raft/neighbors/detail/ann_mg.cuh | 211 ++++++++++-------- cpp/include/raft/neighbors/ivf_flat_mg.cuh | 7 +- cpp/include/raft/neighbors/ivf_pq_mg.cuh | 7 +- cpp/test/neighbors/ann_mg.cuh | 29 ++- 9 files changed, 163 insertions(+), 125 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp index b739fab81b..6af51cb2f7 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp @@ -31,16 +31,14 @@ class RaftAnnMG_IvfFlat : public RaftAnnMG { raft::neighbors::ivf_flat::search_params ivf_flat_params; }; - using BuildParam = raft::neighbors::ivf_flat::dist_index_params; + using BuildParam = raft::neighbors::ivf_flat::mg_index_params; RaftAnnMG_IvfFlat(Metric metric, int dim, const BuildParam& param) : RaftAnnMG(metric, dim), index_params_(param) { - this->init_nccl_clique(); - index_params_.metric = parse_metric_type(metric); index_params_.conservative_memory_allocation = true; - index_params_.mode = raft::neighbors::mg::parallel_mode::REPLICATION; + index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDING; } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp index 1cdb361087..fc99ca3c94 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp @@ -33,16 +33,14 @@ class RaftAnnMG_IvfPq : public RaftAnnMG { auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; } }; - using BuildParam = raft::neighbors::ivf_pq::dist_index_params; + using BuildParam = raft::neighbors::ivf_pq::mg_index_params; RaftAnnMG_IvfPq(Metric metric, int dim, const BuildParam& param) : RaftAnnMG(metric, dim), index_params_(param) { - this->init_nccl_clique(); - index_params_.metric = parse_metric_type(metric); index_params_.conservative_memory_allocation = true; - index_params_.mode = raft::neighbors::mg::parallel_mode::REPLICATION; + index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDING; } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp index 39cd2e9e9b..0cade61f03 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp @@ -28,7 +28,9 @@ class RaftAnnMG : public ANN, public AnnGPU { public: RaftAnnMG(Metric metric, int dim) : ANN(metric, dim), dimension_(dim) - {} + { + this->init_nccl_clique(); + } AlgoProperty get_preference() const override { @@ -38,7 +40,7 @@ class RaftAnnMG : public ANN, public AnnGPU { return property; } - protected: + private: void init_nccl_clique() { int n_devices; cudaGetDeviceCount(&n_devices); @@ -55,6 +57,7 @@ class RaftAnnMG : public ANN, public AnnGPU { return resource::get_cuda_stream(handle); } + protected: std::shared_ptr clique_; int dimension_; }; diff --git a/cpp/include/raft/neighbors/ann_mg_types.hpp b/cpp/include/raft/neighbors/ann_mg_types.hpp index f1509ae328..31e8e75a9b 100644 --- a/cpp/include/raft/neighbors/ann_mg_types.hpp +++ b/cpp/include/raft/neighbors/ann_mg_types.hpp @@ -25,19 +25,19 @@ namespace raft::neighbors::mg { } namespace raft::neighbors::ivf_flat { - struct dist_index_params : raft::neighbors::ivf_flat::index_params { + struct mg_index_params : raft::neighbors::ivf_flat::index_params { raft::neighbors::mg::parallel_mode mode; }; } namespace raft::neighbors::ivf_pq { - struct dist_index_params : raft::neighbors::ivf_pq::index_params { + struct mg_index_params : raft::neighbors::ivf_pq::index_params { raft::neighbors::mg::parallel_mode mode; }; } namespace raft::neighbors::cagra { - struct dist_index_params : raft::neighbors::cagra::index_params { + struct mg_index_params : raft::neighbors::cagra::index_params { raft::neighbors::mg::parallel_mode mode; }; } diff --git a/cpp/include/raft/neighbors/cagra_mg.cuh b/cpp/include/raft/neighbors/cagra_mg.cuh index 6287bbddea..62be51ca7b 100644 --- a/cpp/include/raft/neighbors/cagra_mg.cuh +++ b/cpp/include/raft/neighbors/cagra_mg.cuh @@ -24,7 +24,7 @@ namespace raft::neighbors::mg { template auto build(const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const cagra::dist_index_params& index_params, + const cagra::mg_index_params& index_params, raft::host_matrix_view index_dataset) -> detail::ann_mg_index, T, IdxT> { @@ -38,9 +38,10 @@ void search(const raft::resources& handle, const cagra::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view distances, + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances); + mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } - + // 2^20 } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index bbaee2f169..6c6aa690c7 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -253,9 +253,11 @@ class ann_mg_index { raft::host_matrix_view index_dataset) { if (mode_ == REPLICATION) { - ann_interfaces_.resize(num_ranks_); + IdxT n_rows = index_dataset.extent(0); + std::cout << "REPLICATION BUILD: " << num_ranks_ << "x" << n_rows << "rows" << std::endl; - #pragma omp parallel for + ann_interfaces_.resize(num_ranks_); + #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -266,22 +268,22 @@ class ann_mg_index { } #pragma omp barrier } else if (mode_ == SHARDING) { - IdxT n_rows = index_dataset.extent(0); - IdxT n_cols = index_dataset.extent(1); + IdxT n_rows = index_dataset.extent(0); + IdxT n_cols = index_dataset.extent(1); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); - ann_interfaces_.resize(num_ranks_); + std::cout << "SHARDED BUILD: " << num_ranks_ << "x" << n_rows_per_shard << "rows" << std::endl; - #pragma omp parallel for + ann_interfaces_.resize(num_ranks_); + #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); - IdxT offset = rank * n_rows_per_shard; - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view( - partition_ptr, n_rows_per_shard, n_cols); + auto partition = raft::make_host_matrix_view(partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = ann_interfaces_[rank]; ann_if.build(dev_res, index_params, partition); resource::sync_stream(dev_res); @@ -296,7 +298,9 @@ class ann_mg_index { { IdxT n_rows = new_vectors.extent(0); if (mode_ == REPLICATION) { - #pragma omp parallel for + std::cout << "REPLICATION EXTEND: " << num_ranks_ << "x" << n_rows << "rows" << std::endl; + + #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -308,22 +312,24 @@ class ann_mg_index { #pragma omp barrier } else if (mode_ == SHARDING) { IdxT n_cols = new_vectors.extent(1); - #pragma omp parallel for + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + + std::cout << "SHARDED EXTEND: " << num_ranks_ << "x" << n_rows_per_shard << "rows" << std::endl; + + #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); - IdxT offset = rank * n_rows_per_shard; - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); + IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view( - new_vectors_ptr, n_rows_per_shard, n_cols); + auto new_vectors_part = raft::make_host_matrix_view(new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; if (new_indices.has_value()) { const IdxT* new_indices_ptr = new_indices.value().data_handle() + offset; - new_indices_part = raft::make_host_vector_view(new_indices_ptr, n_rows_per_shard); + new_indices_part = raft::make_host_vector_view(new_indices_ptr, n_rows_of_current_shard); } auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_res, new_vectors_part, new_indices_part); @@ -337,28 +343,35 @@ class ann_mg_index { const ann::search_params* search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) const + raft::host_matrix_view distances, + IdxT n_rows_per_batch) const { + IdxT n_rows = query_dataset.extent(0); + IdxT n_cols = query_dataset.extent(1); + IdxT n_neighbors = neighbors.extent(1); + + IdxT n_batches = raft::ceildiv(n_rows, (IdxT)n_rows_per_batch); + if (n_batches == 1) + n_rows_per_batch = n_rows; + if (mode_ == REPLICATION) { - IdxT n_rows = query_dataset.extent(0); - IdxT n_cols = query_dataset.extent(1); - IdxT n_neighbors = neighbors.extent(1); + std::cout << "REPLICATION SEARCH: " << n_batches << "x" << n_rows_per_batch << "rows" << std::endl; - #pragma omp parallel for - for (int rank = 0; rank < num_ranks_; rank++) { + #pragma omp parallel for num_threads(num_ranks_) // avoid oversubscribing any given GPU + for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { + int rank = batch_idx % num_ranks_; // alternate GPUs int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); - IdxT offset = rank * n_rows_per_shard; + + IdxT offset = batch_idx * n_rows_per_batch; IdxT query_offset = offset * n_cols; IdxT output_offset = offset * n_neighbors; - n_rows_per_shard = std::min(n_rows_per_shard, n_rows - offset); - auto query_partition = raft::make_host_matrix_view( - query_dataset.data_handle() + query_offset, n_rows_per_shard, n_cols); + IdxT n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset); - auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_per_shard, n_neighbors); - auto d_distances = raft::make_device_matrix(dev_res, n_rows_per_shard, n_neighbors); + auto query_partition = raft::make_host_matrix_view(query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); + auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); auto& ann_if = ann_interfaces_[rank]; ann_if.search(dev_res, @@ -369,86 +382,97 @@ class ann_mg_index { raft::copy(neighbors.data_handle() + output_offset, d_neighbors.data_handle(), - n_rows_per_shard * n_neighbors, + n_rows_of_current_batch * n_neighbors, resource::get_cuda_stream(dev_res)); raft::copy(distances.data_handle() + output_offset, d_distances.data_handle(), - n_rows_per_shard * n_neighbors, + n_rows_of_current_batch * n_neighbors, resource::get_cuda_stream(dev_res)); + resource::sync_stream(dev_res); } #pragma omp barrier } else if (mode_ == SHARDING) { - IdxT n_rows = query_dataset.extent(0); - IdxT n_cols = query_dataset.extent(1); - IdxT n_neighbors = neighbors.extent(1); - - IdxT n_batches = raft::ceildiv(n_rows, (IdxT)N_ROWS_PER_BATCH); + std::cout << "SHARDED SEARCH: " << n_batches << "x" << n_rows_per_batch << "rows" << std::endl; const auto& root_handle = clique.set_current_device_to_root_rank(); auto in_neighbors = raft::make_device_matrix( - root_handle, num_ranks_ * N_ROWS_PER_BATCH, n_neighbors); + root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); auto in_distances = raft::make_device_matrix( - root_handle, num_ranks_ * N_ROWS_PER_BATCH, n_neighbors); + root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); auto out_neighbors = raft::make_device_matrix( - root_handle, N_ROWS_PER_BATCH, n_neighbors); + root_handle, n_rows_per_batch, n_neighbors); auto out_distances = raft::make_device_matrix( - root_handle, N_ROWS_PER_BATCH, n_neighbors); + root_handle, n_rows_per_batch, n_neighbors); for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { IdxT offset = batch_idx * N_ROWS_PER_BATCH; IdxT query_offset = offset * n_cols; IdxT output_offset = offset * n_neighbors; - IdxT n_rows_per_batches = std::min((IdxT)N_ROWS_PER_BATCH, n_rows - offset); + IdxT n_rows_of_current_batch = std::min((IdxT)n_rows_per_batch, n_rows - offset); auto query_partition = raft::make_host_matrix_view( - query_dataset.data_handle() + query_offset, n_rows_per_batches, n_cols); + query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); - #pragma omp parallel for + #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_per_batches, n_neighbors); - auto d_distances = raft::make_device_matrix(dev_res, n_rows_per_batches, n_neighbors); - auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); - - RAFT_NCCL_TRY(ncclGroupStart()); - uint64_t batch_offset = rank * n_rows_per_batches * n_neighbors; - const auto& comms = resource::get_comms(dev_res); - comms.device_send(d_neighbors.data_handle(), - n_rows_per_batches * n_neighbors, - clique.root_rank_, - resource::get_cuda_stream(dev_res)); - comms.device_send(d_distances.data_handle(), - n_rows_per_batches * n_neighbors, - clique.root_rank_, - resource::get_cuda_stream(dev_res)); - - const auto& root_handle = clique.set_current_device_to_root_rank(); - const auto& root_comms = resource::get_comms(root_handle); - root_comms.device_recv(in_neighbors.data_handle() + batch_offset, - n_rows_per_batches * n_neighbors, - rank, - resource::get_cuda_stream(root_handle)); - root_comms.device_recv(in_distances.data_handle() + batch_offset, - n_rows_per_batches * n_neighbors, - rank, - resource::get_cuda_stream(root_handle)); - RAFT_NCCL_TRY(ncclGroupEnd()); + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + + if (rank == clique.root_rank_) { // root rank + uint64_t batch_offset = clique.root_rank_ * n_rows_of_current_batch * n_neighbors; + auto d_neighbors = raft::make_device_matrix_view(in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix_view(in_distances.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + ann_if.search(dev_res, search_params, query_partition, d_neighbors, d_distances); // write search results inplace + + // wait for results of other ranks + RAFT_NCCL_TRY(ncclGroupStart()); + for (int from_rank = 0; from_rank < num_ranks_; from_rank++) { + if (from_rank == clique.root_rank_) + continue; + + batch_offset = from_rank * n_rows_of_current_batch * n_neighbors; + comms.device_recv(in_neighbors.data_handle() + batch_offset, + n_rows_of_current_batch * n_neighbors, + from_rank, + resource::get_cuda_stream(dev_res)); + comms.device_recv(in_distances.data_handle() + batch_offset, + n_rows_of_current_batch * n_neighbors, + from_rank, + resource::get_cuda_stream(dev_res)); + } + RAFT_NCCL_TRY(ncclGroupEnd()); + resource::sync_stream(dev_res); + } else { // non-root ranks + auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); + ann_if.search(dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); + + RAFT_NCCL_TRY(ncclGroupStart()); + comms.device_send(d_neighbors.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + comms.device_send(d_distances.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + RAFT_NCCL_TRY(ncclGroupEnd()); + resource::sync_stream(dev_res); + } } #pragma omp barrier auto in_neighbors_view = raft::make_device_matrix_view( - in_neighbors.data_handle(), num_ranks_ * n_rows_per_batches, n_neighbors); + in_neighbors.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); auto in_distances_view = raft::make_device_matrix_view( - in_distances.data_handle(), num_ranks_ * n_rows_per_batches, n_neighbors); + in_distances.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); auto out_neighbors_view = raft::make_device_matrix_view( - out_neighbors.data_handle(), n_rows_per_batches, n_neighbors); + out_neighbors.data_handle(), n_rows_of_current_batch, n_neighbors); auto out_distances_view = raft::make_device_matrix_view( - out_distances.data_handle(), n_rows_per_batches, n_neighbors); + out_distances.data_handle(), n_rows_of_current_batch, n_neighbors); const auto& root_handle_ = clique.set_current_device_to_root_rank(); auto h_trans = std::vector(num_ranks_); @@ -465,17 +489,19 @@ class ann_mg_index { in_neighbors_view, out_distances_view, out_neighbors_view, - n_rows_per_batches, + n_rows_of_current_batch, translations); raft::copy(neighbors.data_handle() + output_offset, out_neighbors.data_handle(), - n_rows_per_batches * n_neighbors, + n_rows_of_current_batch * n_neighbors, resource::get_cuda_stream(root_handle_)); raft::copy(distances.data_handle() + output_offset, out_distances.data_handle(), - n_rows_per_batches * n_neighbors, + n_rows_of_current_batch * n_neighbors, resource::get_cuda_stream(root_handle_)); + + resource::sync_stream(root_handle_); } } } @@ -511,7 +537,7 @@ template ann_mg_index, T, IdxT> build( const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const ivf_flat::dist_index_params& index_params, + const ivf_flat::mg_index_params& index_params, raft::host_matrix_view index_dataset) { ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); @@ -523,7 +549,7 @@ template ann_mg_index, T, IdxT> build( const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const ivf_pq::dist_index_params& index_params, + const ivf_pq::mg_index_params& index_params, raft::host_matrix_view index_dataset) { ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); @@ -535,7 +561,7 @@ template ann_mg_index, T, IdxT> build( const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const cagra::dist_index_params& index_params, + const cagra::mg_index_params& index_params, raft::host_matrix_view index_dataset) { ann_mg_index, T, IdxT> index(index_params.mode, clique.num_ranks_); @@ -570,9 +596,10 @@ void search(const raft::resources& handle, const ivf_flat::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view distances, + uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances); + index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); } template @@ -582,9 +609,10 @@ void search(const raft::resources& handle, const ivf_pq::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view distances, + uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances); + index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); } template @@ -594,9 +622,10 @@ void search(const raft::resources& handle, const cagra::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view distances, + uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances); + index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); } template diff --git a/cpp/include/raft/neighbors/ivf_flat_mg.cuh b/cpp/include/raft/neighbors/ivf_flat_mg.cuh index 5ca0084f74..12f76babe7 100644 --- a/cpp/include/raft/neighbors/ivf_flat_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_mg.cuh @@ -24,7 +24,7 @@ namespace raft::neighbors::mg { template auto build(const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const ivf_flat::dist_index_params& index_params, + const ivf_flat::mg_index_params& index_params, raft::host_matrix_view index_dataset) -> detail::ann_mg_index, T, IdxT> { @@ -48,9 +48,10 @@ void search(const raft::resources& handle, const ivf_flat::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view distances, + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances); + mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_pq_mg.cuh b/cpp/include/raft/neighbors/ivf_pq_mg.cuh index 991a013f0e..7cad3449d0 100644 --- a/cpp/include/raft/neighbors/ivf_pq_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_mg.cuh @@ -24,7 +24,7 @@ namespace raft::neighbors::mg { template auto build(const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const ivf_pq::dist_index_params& index_params, + const ivf_pq::mg_index_params& index_params, raft::host_matrix_view index_dataset) -> detail::ann_mg_index, T, IdxT> { @@ -48,9 +48,10 @@ void search(const raft::resources& handle, const ivf_pq::search_params& search_params, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, - raft::host_matrix_view distances) + raft::host_matrix_view distances, + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances); + mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } } // namespace raft::neighbors::mg \ No newline at end of file diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index 9efd526359..3c3e135651 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -88,11 +88,18 @@ class AnnMGTest : public ::testing::TestWithParam> { resource::sync_stream(handle_); } - std::vector device_ids{0, 1}; + uint64_t n_rows_per_batch = 3000; // [3000, 3000, 1000] == 7000 rows + + int n_devices; + cudaGetDeviceCount(&n_devices); + std::cout << n_devices << " GPUs detected" << std::endl; + + std::vector device_ids(n_devices); + std::iota(device_ids.begin(), device_ids.end(), 0); // IVF-Flat for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { - ivf_flat::dist_index_params index_params; + ivf_flat::mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.adaptive_centers = ps.adaptive_centers; @@ -120,7 +127,7 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_ivf_flat_index"); } auto new_index = raft::neighbors::mg::deserialize_flat(handle_, clique, "./cpp/build/ann_mg_ivf_flat_index"); - raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -138,7 +145,7 @@ class AnnMGTest : public ::testing::TestWithParam> { // IVF-PQ for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { - ivf_pq::dist_index_params index_params; + ivf_pq::mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; index_params.add_data_on_build = false; @@ -165,7 +172,7 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_ivf_pq_index"); } auto new_index = raft::neighbors::mg::deserialize_pq(handle_, clique, "./cpp/build/ann_mg_ivf_pq_index"); - raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -183,7 +190,7 @@ class AnnMGTest : public ::testing::TestWithParam> { // CAGRA for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { - cagra::dist_index_params index_params; + cagra::mg_index_params index_params; index_params.add_data_on_build = true; index_params.intermediate_graph_degree = 128; index_params.graph_degree = 64; @@ -211,7 +218,7 @@ class AnnMGTest : public ::testing::TestWithParam> { auto new_index = raft::neighbors::mg::deserialize_cagra(handle_, clique, "./cpp/build/ann_mg_cagra_index"); */ auto new_index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); - raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); resource::sync_stream(handle_); double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); @@ -253,7 +260,7 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::neighbors::mg::nccl_clique clique(device_ids); auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, clique, "./cpp/build/local_ivf_flat_index"); - raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); resource::sync_stream(handle_); @@ -295,7 +302,7 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::neighbors::mg::nccl_clique clique(device_ids); auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, clique, "./cpp/build/local_ivf_pq_index"); - raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); resource::sync_stream(handle_); @@ -335,7 +342,7 @@ class AnnMGTest : public ::testing::TestWithParam> { raft::neighbors::mg::nccl_clique clique(device_ids); auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, clique, "./cpp/build/local_cagra_index"); - raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances); + raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); resource::sync_stream(handle_); @@ -405,6 +412,6 @@ class AnnMGTest : public ::testing::TestWithParam> { }; const std::vector> inputs = { - {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, + {7000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, }; } // namespace raft::neighbors::mg From 34d4fd3b921b29dd296ef1a825c3cefa80ac1153 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 28 May 2024 18:01:45 +0200 Subject: [PATCH 16/22] Setting NCCL init apart for bench --- .../src/raft/raft_ann_mg_ivf_flat_wrapper.hpp | 2 +- .../src/raft/raft_ann_mg_ivf_pq_wrapper.hpp | 2 +- cpp/include/raft/neighbors/ann_mg_helpers.cuh | 9 ++++++ cpp/include/raft/neighbors/ann_mg_types.hpp | 2 +- cpp/include/raft/neighbors/detail/ann_mg.cuh | 30 +++++++++---------- cpp/test/neighbors/ann_mg.cuh | 18 ++++------- 6 files changed, 33 insertions(+), 30 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp index 6af51cb2f7..5308f07161 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp @@ -38,7 +38,7 @@ class RaftAnnMG_IvfFlat : public RaftAnnMG { { index_params_.metric = parse_metric_type(metric); index_params_.conservative_memory_allocation = true; - index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDING; + index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDED; } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp index fc99ca3c94..8e7049d971 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp @@ -40,7 +40,7 @@ class RaftAnnMG_IvfPq : public RaftAnnMG { { index_params_.metric = parse_metric_type(metric); index_params_.conservative_memory_allocation = true; - index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDING; + index_params_.mode = raft::neighbors::mg::parallel_mode::SHARDED; } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/include/raft/neighbors/ann_mg_helpers.cuh b/cpp/include/raft/neighbors/ann_mg_helpers.cuh index 38a15363d1..24bb29767f 100644 --- a/cpp/include/raft/neighbors/ann_mg_helpers.cuh +++ b/cpp/include/raft/neighbors/ann_mg_helpers.cuh @@ -36,6 +36,8 @@ struct nccl_clique { nccl_comms_(device_ids.size()), device_resources_(0) { + RAFT_LOG_INFO("Starting NCCL initialization..."); + RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data())); for (int rank = 0; rank < num_ranks_; rank++) { @@ -43,6 +45,13 @@ struct nccl_clique { device_resources_.emplace_back(); raft::comms::build_comms_nccl_only(&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); } + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(device_ids[rank])); + resource::sync_stream(device_resources_[rank]); + } + + RAFT_LOG_INFO("NCCL initialization completed"); } const raft::device_resources& set_current_device_to_root_rank() const diff --git a/cpp/include/raft/neighbors/ann_mg_types.hpp b/cpp/include/raft/neighbors/ann_mg_types.hpp index 31e8e75a9b..4242f3fb9f 100644 --- a/cpp/include/raft/neighbors/ann_mg_types.hpp +++ b/cpp/include/raft/neighbors/ann_mg_types.hpp @@ -21,7 +21,7 @@ #include namespace raft::neighbors::mg { - enum parallel_mode { REPLICATION, SHARDING }; + enum parallel_mode { REPLICATED, SHARDED }; } namespace raft::neighbors::ivf_flat { diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 6c6aa690c7..563c922d45 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -252,9 +252,9 @@ class ann_mg_index { const ann::index_params* index_params, raft::host_matrix_view index_dataset) { - if (mode_ == REPLICATION) { + if (mode_ == REPLICATED) { IdxT n_rows = index_dataset.extent(0); - std::cout << "REPLICATION BUILD: " << num_ranks_ << "x" << n_rows << "rows" << std::endl; + RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", num_ranks_, n_rows); ann_interfaces_.resize(num_ranks_); #pragma omp parallel for num_threads(num_ranks_) @@ -267,12 +267,12 @@ class ann_mg_index { resource::sync_stream(dev_res); } #pragma omp barrier - } else if (mode_ == SHARDING) { + } else if (mode_ == SHARDED) { IdxT n_rows = index_dataset.extent(0); IdxT n_cols = index_dataset.extent(1); IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); - std::cout << "SHARDED BUILD: " << num_ranks_ << "x" << n_rows_per_shard << "rows" << std::endl; + RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", num_ranks_, n_rows_per_shard); ann_interfaces_.resize(num_ranks_); #pragma omp parallel for num_threads(num_ranks_) @@ -297,8 +297,8 @@ class ann_mg_index { std::optional> new_indices) { IdxT n_rows = new_vectors.extent(0); - if (mode_ == REPLICATION) { - std::cout << "REPLICATION EXTEND: " << num_ranks_ << "x" << n_rows << "rows" << std::endl; + if (mode_ == REPLICATED) { + RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", num_ranks_, n_rows); #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { @@ -310,11 +310,11 @@ class ann_mg_index { resource::sync_stream(dev_res); } #pragma omp barrier - } else if (mode_ == SHARDING) { + } else if (mode_ == SHARDED) { IdxT n_cols = new_vectors.extent(1); IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); - std::cout << "SHARDED EXTEND: " << num_ranks_ << "x" << n_rows_per_shard << "rows" << std::endl; + RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", num_ranks_, n_rows_per_shard); #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { @@ -354,8 +354,8 @@ class ann_mg_index { if (n_batches == 1) n_rows_per_batch = n_rows; - if (mode_ == REPLICATION) { - std::cout << "REPLICATION SEARCH: " << n_batches << "x" << n_rows_per_batch << "rows" << std::endl; + if (mode_ == REPLICATED) { + RAFT_LOG_INFO("REPLICATED SEARCH: %d*%drows", n_batches, n_rows_per_batch); #pragma omp parallel for num_threads(num_ranks_) // avoid oversubscribing any given GPU for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { @@ -392,8 +392,8 @@ class ann_mg_index { resource::sync_stream(dev_res); } #pragma omp barrier - } else if (mode_ == SHARDING) { - std::cout << "SHARDED SEARCH: " << n_batches << "x" << n_rows_per_batch << "rows" << std::endl; + } else if (mode_ == SHARDED) { + RAFT_LOG_INFO("SHARDED SEARCH: %d*%drows", n_batches, n_rows_per_batch); const auto& root_handle = clique.set_current_device_to_root_rank(); auto in_neighbors = raft::make_device_matrix( @@ -687,7 +687,7 @@ ann_mg_index, T, IdxT> distribute_flat(const raft::reso const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - auto index = ann_mg_index, T, IdxT>(REPLICATION, clique.num_ranks_); + auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); return index; } @@ -697,7 +697,7 @@ ann_mg_index, T, IdxT> distribute_pq(const raft::resources& const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - auto index = ann_mg_index, T, IdxT>(REPLICATION, clique.num_ranks_); + auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); return index; } @@ -707,7 +707,7 @@ ann_mg_index, T, IdxT> distribute_cagra(const raft::resour const raft::neighbors::mg::nccl_clique& clique, const std::string& filename) { - auto index = ann_mg_index, T, IdxT>(REPLICATION, clique.num_ranks_); + auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); return index; } diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index 3c3e135651..f7935102c5 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -88,17 +88,17 @@ class AnnMGTest : public ::testing::TestWithParam> { resource::sync_stream(handle_); } - uint64_t n_rows_per_batch = 3000; // [3000, 3000, 1000] == 7000 rows - int n_devices; cudaGetDeviceCount(&n_devices); std::cout << n_devices << " GPUs detected" << std::endl; - std::vector device_ids(n_devices); std::iota(device_ids.begin(), device_ids.end(), 0); + uint64_t n_rows_per_batch = 3000; // [3000, 3000, 1000] == 7000 rows + raft::neighbors::mg::nccl_clique clique(device_ids); + // IVF-Flat - for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { + for (parallel_mode d_mode : {parallel_mode::REPLICATED, parallel_mode::SHARDED}) { ivf_flat::mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; @@ -120,7 +120,6 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); - raft::neighbors::mg::nccl_clique clique(device_ids); { auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); raft::neighbors::mg::extend(handle_, clique, index, index_dataset, std::nullopt); @@ -144,7 +143,7 @@ class AnnMGTest : public ::testing::TestWithParam> { } // IVF-PQ - for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { + for (parallel_mode d_mode : {parallel_mode::REPLICATED, parallel_mode::SHARDED}) { ivf_pq::mg_index_params index_params; index_params.n_lists = ps.nlist; index_params.metric = ps.metric; @@ -165,7 +164,6 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); - raft::neighbors::mg::nccl_clique clique(device_ids); { auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); raft::neighbors::mg::extend(handle_, clique, index, index_dataset, std::nullopt); @@ -189,7 +187,7 @@ class AnnMGTest : public ::testing::TestWithParam> { } // CAGRA - for (parallel_mode d_mode : {parallel_mode::REPLICATION, parallel_mode::SHARDING}) { + for (parallel_mode d_mode : {parallel_mode::REPLICATED, parallel_mode::SHARDED}) { cagra::mg_index_params index_params; index_params.add_data_on_build = true; index_params.intermediate_graph_degree = 128; @@ -209,7 +207,6 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); - raft::neighbors::mg::nccl_clique clique(device_ids); /* { auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); @@ -258,7 +255,6 @@ class AnnMGTest : public ::testing::TestWithParam> { auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - raft::neighbors::mg::nccl_clique clique(device_ids); auto distributed_index = raft::neighbors::mg::distribute_flat(handle_, clique, "./cpp/build/local_ivf_flat_index"); raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); @@ -300,7 +296,6 @@ class AnnMGTest : public ::testing::TestWithParam> { auto neighbors = raft::make_host_matrix_view(indices_ann.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - raft::neighbors::mg::nccl_clique clique(device_ids); auto distributed_index = raft::neighbors::mg::distribute_pq(handle_, clique, "./cpp/build/local_ivf_pq_index"); raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); @@ -340,7 +335,6 @@ class AnnMGTest : public ::testing::TestWithParam> { auto neighbors = raft::make_host_matrix_view(indices_ann_32bits.data(), ps.num_queries, ps.k); auto distances = raft::make_host_matrix_view(distances_ann.data(), ps.num_queries, ps.k); - raft::neighbors::mg::nccl_clique clique(device_ids); auto distributed_index = raft::neighbors::mg::distribute_cagra(handle_, clique, "./cpp/build/local_cagra_index"); raft::neighbors::mg::search(handle_, clique, distributed_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); From 3f15c43dbe10e337edb0806a17604aac0e4ccdd2 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 12 Jun 2024 17:52:32 +0200 Subject: [PATCH 17/22] Mempool + NCCL fix --- cpp/include/raft/neighbors/ann_mg_helpers.cuh | 18 +++++++++++++++++- cpp/include/raft/neighbors/detail/ann_mg.cuh | 12 +++++++----- cpp/test/neighbors/ann_mg.cuh | 3 --- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/ann_mg_helpers.cuh b/cpp/include/raft/neighbors/ann_mg_helpers.cuh index 24bb29767f..ebf006050b 100644 --- a/cpp/include/raft/neighbors/ann_mg_helpers.cuh +++ b/cpp/include/raft/neighbors/ann_mg_helpers.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -27,6 +28,8 @@ namespace raft::comms { namespace raft::neighbors::mg { +using pool_mr = rmm::mr::pool_memory_resource; + struct nccl_clique { nccl_clique(const std::vector& device_ids) @@ -34,15 +37,25 @@ struct nccl_clique { num_ranks_(device_ids.size()), device_ids_(device_ids), nccl_comms_(device_ids.size()), + per_device_pools_(0), device_resources_(0) { RAFT_LOG_INFO("Starting NCCL initialization..."); - RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data())); for (int rank = 0; rank < num_ranks_; rank++) { RAFT_CUDA_TRY(cudaSetDevice(device_ids[rank])); + + // create a pool memory resource for each device + auto old_mr = rmm::mr::get_current_device_resource(); + per_device_pools_.push_back(std::make_unique(old_mr, rmm::percent_of_free_device_memory(80))); + rmm::cuda_device_id id(device_ids[rank]); + rmm::mr::set_per_device_resource(id, per_device_pools_.back().get()); + + // create a device resource handle for each device device_resources_.emplace_back(); + + // add NCCL communications to the device resource handle raft::comms::build_comms_nccl_only(&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); } @@ -66,6 +79,8 @@ struct nccl_clique { for (int rank = 0; rank < num_ranks_; rank++) { cudaSetDevice(device_ids_[rank]); ncclCommDestroy(nccl_comms_[rank]); + rmm::cuda_device_id id(device_ids_[rank]); + rmm::mr::set_per_device_resource(id, nullptr); } } @@ -73,6 +88,7 @@ struct nccl_clique { int num_ranks_; std::vector device_ids_; std::vector nccl_comms_; + std::vector> per_device_pools_; std::vector device_resources_; }; diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index 563c922d45..cd29e3eb87 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -257,7 +257,7 @@ class ann_mg_index { RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", num_ranks_, n_rows); ann_interfaces_.resize(num_ranks_); - #pragma omp parallel for num_threads(num_ranks_) + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -275,7 +275,7 @@ class ann_mg_index { RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", num_ranks_, n_rows_per_shard); ann_interfaces_.resize(num_ranks_); - #pragma omp parallel for num_threads(num_ranks_) + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -300,7 +300,7 @@ class ann_mg_index { if (mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", num_ranks_, n_rows); - #pragma omp parallel for num_threads(num_ranks_) + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -316,7 +316,7 @@ class ann_mg_index { RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", num_ranks_, n_rows_per_shard); - #pragma omp parallel for num_threads(num_ranks_) + #pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; @@ -357,7 +357,7 @@ class ann_mg_index { if (mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED SEARCH: %d*%drows", n_batches, n_rows_per_batch); - #pragma omp parallel for num_threads(num_ranks_) // avoid oversubscribing any given GPU + #pragma omp parallel for for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { int rank = batch_idx % num_ranks_; // alternate GPUs int dev_id = clique.device_ids_[rank]; @@ -413,6 +413,7 @@ class ann_mg_index { auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); + // should use at least num_ranks_ threads to avoid NCCL hang #pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { int dev_id = clique.device_ids_[rank]; @@ -450,6 +451,7 @@ class ann_mg_index { auto d_distances = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); ann_if.search(dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); + // send results to root rank RAFT_NCCL_TRY(ncclGroupStart()); comms.device_send(d_neighbors.data_handle(), n_rows_of_current_batch * n_neighbors, diff --git a/cpp/test/neighbors/ann_mg.cuh b/cpp/test/neighbors/ann_mg.cuh index f7935102c5..284871ef70 100644 --- a/cpp/test/neighbors/ann_mg.cuh +++ b/cpp/test/neighbors/ann_mg.cuh @@ -207,14 +207,11 @@ class AnnMGTest : public ::testing::TestWithParam> { auto distances = raft::make_host_matrix_view( distances_ann.data(), ps.num_queries, ps.k); - /* { auto index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); raft::neighbors::mg::serialize(handle_, clique, index, "./cpp/build/ann_mg_cagra_index"); } auto new_index = raft::neighbors::mg::deserialize_cagra(handle_, clique, "./cpp/build/ann_mg_cagra_index"); - */ - auto new_index = raft::neighbors::mg::build(handle_, clique, index_params, index_dataset); raft::neighbors::mg::search(handle_, clique, new_index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); resource::sync_stream(handle_); From b77f938f34ea3692545630bab256cf4f7eb99b4f Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 12 Jun 2024 17:52:49 +0200 Subject: [PATCH 18/22] SNMG cagra bench --- cpp/bench/ann/CMakeLists.txt | 15 ++ .../src/raft/raft_ann_bench_param_parser.h | 130 ++++++------------ cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu | 28 ++++ .../src/raft/raft_ann_mg_cagra_wrapper.hpp | 125 +++++++++++++++++ cpp/bench/ann/src/raft/raft_benchmark.cu | 21 ++- .../src/raft-ann-bench/run/algos.yaml | 3 + .../run/conf/algos/raft_ann_mg_cagra.yaml | 13 ++ 7 files changed, 245 insertions(+), 90 deletions(-) create mode 100644 cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu create mode 100644 cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp create mode 100644 python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_cagra.yaml diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 330a1e021f..ccf9bf01b4 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -31,6 +31,7 @@ option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in be option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT "Include raft's MG ANN IVF-FLAT in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ "Include raft's MG ANN IVF-PQ in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA "Include raft's MG ANN CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) option(RAFT_ANN_BENCH_SINGLE_EXE @@ -59,6 +60,7 @@ if(BUILD_CPU_ONLY) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF) set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT OFF) set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ OFF) + set(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA OFF) set(RAFT_ANN_BENCH_USE_GGNN OFF) else() # Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled. @@ -96,6 +98,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + OR RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -309,6 +312,18 @@ if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) ) endif() +if(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + ConfigureAnnBench( + NAME + RAFT_ANN_MG_CAGRA + PATH + bench/ann/src/raft/raft_benchmark.cu + bench/ann/src/raft/raft_ann_mg_cagra.cu + LINKS + raft::compiled + ) +endif() + set(RAFT_FAISS_TARGETS faiss::faiss) if(TARGET faiss::faiss_avx2) set(RAFT_FAISS_TARGETS faiss::faiss_avx2) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 906ce95b5d..1c30da41e2 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -30,7 +30,7 @@ extern template class raft::bench::ann::RaftIvfFlatGpu; extern template class raft::bench::ann::RaftIvfFlatGpu; #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) #include "raft_ivf_pq_wrapper.h" #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ @@ -38,7 +38,8 @@ extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; #endif -#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) #include "raft_cagra_wrapper.h" #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA @@ -59,29 +60,21 @@ extern template class raft::bench::ann::RaftAnnMG_IvfPq; extern template class raft::bench::ann::RaftAnnMG_IvfPq; extern template class raft::bench::ann::RaftAnnMG_IvfPq; #endif - -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) -{ - param.n_lists = conf.at("nlist"); - if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } - if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) -{ - param.ivf_flat_params.n_probes = conf.at("nprobe"); -} +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA +#include "raft_ann_mg_cagra_wrapper.hpp" +extern template class raft::bench::ann::RaftAnnMG_Cagra; +extern template class raft::bench::ann::RaftAnnMG_Cagra; +extern template class raft::bench::ann::RaftAnnMG_Cagra; #endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT +#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT) template void parse_build_param(const nlohmann::json& conf, + #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT + typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) + #else typename raft::bench::ann::RaftAnnMG_IvfFlat::BuildParam& param) + #endif { param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -90,83 +83,26 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, + #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT + typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) + #else typename raft::bench::ann::RaftAnnMG_IvfFlat::SearchParam& param) + #endif { param.ivf_flat_params.n_probes = conf.at("nprobe"); } #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfPQ::BuildParam& param) -{ - if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } - if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } - if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } - if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } - if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } - if (conf.contains("codebook_kind")) { - std::string kind = conf.at("codebook_kind"); - if (kind == "cluster") { - param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; - } else if (kind == "subspace") { - param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; - } else { - throw std::runtime_error("codebook_kind: '" + kind + - "', should be either 'cluster' or 'subspace'"); - } - } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfPQ::SearchParam& param) -{ - if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } - if (conf.contains("internalDistanceDtype")) { - std::string type = conf.at("internalDistanceDtype"); - if (type == "float") { - param.pq_param.internal_distance_dtype = CUDA_R_32F; - } else if (type == "half") { - param.pq_param.internal_distance_dtype = CUDA_R_16F; - } else { - throw std::runtime_error("internalDistanceDtype: '" + type + - "', should be either 'float' or 'half'"); - } - } else { - // set half as default type - param.pq_param.internal_distance_dtype = CUDA_R_16F; - } - - if (conf.contains("smemLutDtype")) { - std::string type = conf.at("smemLutDtype"); - if (type == "float") { - param.pq_param.lut_dtype = CUDA_R_32F; - } else if (type == "half") { - param.pq_param.lut_dtype = CUDA_R_16F; - } else if (type == "fp8") { - param.pq_param.lut_dtype = CUDA_R_8U; - } else { - throw std::runtime_error("smemLutDtype: '" + type + - "', should be either 'float', 'half' or 'fp8'"); - } - } else { - // set half as default - param.pq_param.lut_dtype = CUDA_R_16F; - } - if (conf.contains("refine_ratio")) { - param.refine_ratio = conf.at("refine_ratio"); - if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); } - } -} -#endif - -#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, + #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam& param) + #else + typename raft::bench::ann::RaftIvfPQ::BuildParam& param) + #endif { if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -188,7 +124,11 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, + #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ typename raft::bench::ann::RaftAnnMG_IvfPq::SearchParam& param) + #else + typename raft::bench::ann::RaftIvfPQ::SearchParam& param) + #endif { if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } if (conf.contains("internalDistanceDtype")) { @@ -229,7 +169,8 @@ void parse_search_param(const nlohmann::json& conf, } #endif -#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, raft::neighbors::experimental::nn_descent::index_params& param) @@ -276,7 +217,11 @@ nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, template void parse_build_param(const nlohmann::json& conf, + #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam& param) + #else typename raft::bench::ann::RaftCagra::BuildParam& param) + #endif { if (conf.contains("graph_degree")) { param.cagra_params.graph_degree = conf.at("graph_degree"); @@ -340,7 +285,11 @@ raft::bench::ann::AllocatorType parse_allocator(std::string mem_type) template void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftCagra::SearchParam& param) + #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + typename raft::bench::ann::RaftAnnMG_Cagra::SearchParam& param) + #else + typename raft::bench::ann::RaftCagra::SearchParam& param) + #endif { if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } @@ -359,12 +308,15 @@ void parse_search_param(const nlohmann::json& conf, THROW("Invalid value for algo: %s", tmp.c_str()); } } + + #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) if (conf.contains("graph_memory_type")) { param.graph_mem = parse_allocator(conf.at("graph_memory_type")); } if (conf.contains("internal_dataset_memory_type")) { param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type")); } + #endif // Same ratio as in IVF-PQ param.refine_ratio = conf.value("refine_ratio", 1.0f); } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu new file mode 100644 index 0000000000..84699e052b --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "raft_ann_mg_cagra_wrapper.hpp" +#include +#include + +namespace raft::bench::ann { + + template class RaftAnnMG_Cagra; + template class RaftAnnMG_Cagra; + template class RaftAnnMG_Cagra; + +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp new file mode 100644 index 0000000000..80f05e814c --- /dev/null +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "raft_ann_mg_wrapper.hpp" +#include +#include +#include + +namespace raft::bench::ann { + +enum class AllocatorType_t { HostPinned, HostHugePage, Device }; +template +class RaftAnnMG_Cagra : public RaftAnnMG { + public: + using typename ANN::AnnSearchParam; + + struct SearchParam : public AnnSearchParam { + raft::neighbors::cagra::search_params p; + float refine_ratio; + AllocatorType_t graph_mem = AllocatorType_t::Device; + AllocatorType_t dataset_mem = AllocatorType_t::Device; + auto needs_dataset() const -> bool override { return true; } + }; + + struct BuildParam { + raft::neighbors::cagra::mg_index_params cagra_params; + std::optional nn_descent_params = std::nullopt; + std::optional ivf_pq_refine_rate = std::nullopt; + std::optional ivf_pq_build_params = std::nullopt; + std::optional ivf_pq_search_params = std::nullopt; + }; + + RaftAnnMG_Cagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) + : RaftAnnMG(metric, dim), + index_params_(param), + dimension_(dim) + { + index_params_.cagra_params.metric = parse_metric_type(metric); + index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); + } + + void build(const T* dataset, size_t nrow) final; + void set_search_param(const AnnSearchParam& param) override; + void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void save(const std::string& file) const override; + void load(const std::string&) override; + std::unique_ptr> copy() override; + + private: + BuildParam index_params_; + raft::neighbors::cagra::search_params search_params_; + std::shared_ptr, T, IdxT>> index_; + float refine_ratio_ = 1.0; + int dimension_; +}; + +template +void RaftAnnMG_Cagra::build(const T* dataset, size_t nrow) +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_.cagra_params, dataset_matrix); + index_ = std::make_shared, T, IdxT>>(std::move(idx)); + return; +} + +template +void RaftAnnMG_Cagra::set_search_param(const AnnSearchParam& param) +{ + auto search_param = dynamic_cast(param); + search_params_ = search_param.p; + refine_ratio_ = search_param.refine_ratio; + assert(search_params_.n_probes <= index_params_.n_lists); +} + +template +void RaftAnnMG_Cagra::save(const std::string& file) const +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + raft::neighbors::mg::serialize(handle, *this->clique_, *index_, file); + return; +} + +template +void RaftAnnMG_Cagra::load(const std::string& file) +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto idx = raft::neighbors::mg::deserialize_cagra(handle, *this->clique_, file); + index_ = std::make_shared, T, IdxT>>(std::move(idx)); +} + +template +std::unique_ptr> RaftAnnMG_Cagra::copy() +{ + return std::make_unique>(*this); // use copy constructor +} + +template +void RaftAnnMG_Cagra::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +{ + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + resource::sync_stream(handle); + return; +} +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index e136f47d3c..54cb33fd59 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -101,7 +101,16 @@ if constexpr (std::is_same_v || std::is_same_v || } } #endif - +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA +if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_cagra") { + typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam param; + parse_build_param(conf, param); + ann = std::make_unique>(metric, dim, param); + } +} +#endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } return ann; @@ -163,6 +172,16 @@ std::unique_ptr::AnnSearchParam> create_search } } #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_cagra") { + auto param = std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + } +#endif // else throw std::runtime_error("invalid algo: '" + algo + "'"); diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml index dea6215ede..e63b9226b5 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/algos.yaml @@ -37,6 +37,9 @@ raft_ann_mg_ivf_flat: raft_ann_mg_ivf_pq: executable: RAFT_ANN_MG_IVF_PQ_ANN_BENCH requires_gpu: true +raft_ann_mg_cagra: + executable: RAFT_ANN_MG_CAGRA_ANN_BENCH + requires_gpu: true ggnn: executable: GGNN_ANN_BENCH requires_gpu: true diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_cagra.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_cagra.yaml new file mode 100644 index 0000000000..232010ec0e --- /dev/null +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ann_mg_cagra.yaml @@ -0,0 +1,13 @@ +name: raft_ann_mg_cagra +constraints: + build: raft-ann-bench.constraints.raft_cagra_build_constraints + search: raft-ann-bench.constraints.raft_cagra_search_constraints +groups: + base: + build: + graph_degree: [32, 64, 128, 256] + intermediate_graph_degree: [32, 64, 96, 128] + graph_build_algo: ["NN_DESCENT"] + search: + itopk: [32, 64, 128, 256, 512] + search_width: [1, 2, 4, 8, 16, 32, 64] From fc748aeb25e896f630a668ca43816ee61bd08fca Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 13 Jun 2024 16:02:50 +0200 Subject: [PATCH 19/22] SNMG CAGRA bench --- cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp | 6 ++++-- cpp/include/raft/neighbors/detail/ann_mg.cuh | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp index 80f05e814c..cab2605b4d 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp @@ -50,8 +50,10 @@ class RaftAnnMG_Cagra : public RaftAnnMG { index_params_(param), dimension_(dim) { - index_params_.cagra_params.metric = parse_metric_type(metric); - index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); + index_params_.cagra_params.add_data_on_build = true; + index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED; + index_params_.cagra_params.metric = parse_metric_type(metric); + index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); } void build(const T* dataset, size_t nrow) final; diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index cd29e3eb87..df04c395c6 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -145,7 +145,7 @@ class ann_interface { } else if constexpr (std::is_same>::value) { ivf_pq::serialize(handle, os, index_.value()); } else if constexpr (std::is_same>::value) { - cagra::serialize(handle, os, index_.value()); + cagra::serialize(handle, os, index_.value(), true); } } From 666d47fc2457d33a53edc0db4844bfe07b2aecbe Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 17 Jun 2024 18:49:19 +0200 Subject: [PATCH 20/22] Increase search batch size + fix build --- cpp/include/raft/neighbors/detail/ann_mg.cuh | 2 +- cpp/test/CMakeLists.txt | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index df04c395c6..c64338fe30 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -39,7 +39,7 @@ // Number of rows per batch (search on shards) -#define N_ROWS_PER_BATCH 3000 +#define N_ROWS_PER_BATCH 2^24 namespace raft::neighbors::mg::detail { using namespace raft::neighbors; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 2c74593d43..c3ba64d1f4 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -451,8 +451,9 @@ if(BUILD_TESTS) NEIGHBORS_ANN_MG_TEST PATH neighbors/ann_mg/test_ann_mg.cu - LIB + ADDITIONAL_LIBS ucp ucs ucxx nccl + LIB EXPLICIT_INSTANTIATE_ONLY GPUS 1 From 1a559a6eb6ac321e4eb73ee0cac7f3b16fb3ba43 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 8 Jul 2024 12:00:30 +0200 Subject: [PATCH 21/22] mdspan feature for build and extend --- cpp/include/raft/neighbors/detail/ann_mg.cuh | 81 +++++++++++--------- 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index c64338fe30..fe85a70901 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -26,6 +26,9 @@ #undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include +#include +#include +#include #define RAFT_EXPLICIT_INSTANTIATE_ONLY #include @@ -38,63 +41,65 @@ #include -// Number of rows per batch (search on shards) -#define N_ROWS_PER_BATCH 2^24 - namespace raft::neighbors::mg::detail { using namespace raft::neighbors; template class ann_interface { public: + + template void build(raft::resources const& handle, const ann::index_params* index_params, - raft::host_matrix_view h_index_dataset) + raft::mdspan, row_major, Accessor> index_dataset) { - IdxT n_rows = h_index_dataset.extent(0); - IdxT n_dims = h_index_dataset.extent(1); - auto d_index_dataset = raft::make_device_matrix(handle, n_rows, n_dims); - raft::copy(d_index_dataset.data_handle(), h_index_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); - raft::device_matrix_view d_index_dataset_view = raft::make_device_matrix_view(d_index_dataset.data_handle(), n_rows, n_dims); - if constexpr (std::is_same>::value) { - index_.emplace(std::move(raft::runtime::neighbors::ivf_flat::build( - handle, *static_cast(index_params), d_index_dataset_view))); + auto idx = raft::neighbors::ivf_flat::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); + index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - index_.emplace(std::move(raft::runtime::neighbors::ivf_pq::build( - handle, *static_cast(index_params), d_index_dataset_view))); + auto idx = raft::neighbors::ivf_pq::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); + index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - index_.emplace(std::move(raft::runtime::neighbors::cagra::build( - handle, *static_cast(index_params), d_index_dataset_view))); + auto extents = raft::make_extents(index_dataset.extent(0), index_dataset.extent(1)); + const bool host_acc = decltype(index_dataset)::accessor_type::is_host_type::value; + const bool device_acc = decltype(index_dataset)::accessor_type::is_device_type::value; + auto dataset = raft::make_mdspan(index_dataset.data_handle(), extents); + cagra::index idx(handle); + idx = raft::neighbors::cagra::build(handle, + *static_cast(index_params), + dataset); + index_.emplace(std::move(idx)); } resource::sync_stream(handle); } + template void extend(raft::resources const& handle, - raft::host_matrix_view h_new_vectors, - std::optional> h_new_indices) + raft::mdspan, row_major, Accessor1> new_vectors, + std::optional, layout_c_contiguous, Accessor2>> new_indices) { - IdxT n_rows = h_new_vectors.extent(0); - IdxT n_dims = h_new_vectors.extent(1); - auto d_new_vectors = raft::make_device_matrix(handle, n_rows, n_dims); - raft::copy(d_new_vectors.data_handle(), h_new_vectors.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); - raft::device_matrix_view d_new_vectors_view = \ - raft::make_device_matrix_view(d_new_vectors.data_handle(), n_rows, n_dims); - - std::optional> new_indices_opt = std::nullopt; - if (h_new_indices.has_value()) { - auto d_new_indices = raft::make_device_vector(handle, n_rows); - raft::copy(d_new_indices.data_handle(), h_new_indices.value().data_handle(), n_rows, resource::get_cuda_stream(handle)); - auto d_new_indices_view = raft::device_vector_view(d_new_indices.data_handle(), n_rows); - new_indices_opt = std::move(d_new_indices_view); - } - if constexpr (std::is_same>::value) { - index_.emplace(std::move(raft::runtime::neighbors::ivf_flat::extend( - handle, d_new_vectors_view, new_indices_opt, index_.value()))); + auto idx = raft::neighbors::ivf_flat::extend(handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); + index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - index_.emplace(std::move(raft::runtime::neighbors::ivf_pq::extend( - handle, d_new_vectors_view, new_indices_opt, index_.value()))); + auto idx = raft::neighbors::ivf_pq::extend(handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); + index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { RAFT_FAIL("CAGRA does not implement the extend method"); } @@ -406,7 +411,7 @@ class ann_mg_index { root_handle, n_rows_per_batch, n_neighbors); for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - IdxT offset = batch_idx * N_ROWS_PER_BATCH; + IdxT offset = batch_idx * n_rows_per_batch; IdxT query_offset = offset * n_cols; IdxT output_offset = offset * n_neighbors; IdxT n_rows_of_current_batch = std::min((IdxT)n_rows_per_batch, n_rows - offset); From 11d30da9c97307b1df81985b09e8c02da2adaf47 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 8 Jul 2024 12:35:01 +0200 Subject: [PATCH 22/22] style fix --- cpp/CMakeLists.txt | 4 +- .../src/raft/raft_ann_bench_param_parser.h | 54 +-- cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu | 10 +- .../src/raft/raft_ann_mg_cagra_wrapper.hpp | 61 ++- .../ann/src/raft/raft_ann_mg_ivf_flat.cu | 10 +- .../src/raft/raft_ann_mg_ivf_flat_wrapper.hpp | 46 +- cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu | 10 +- .../src/raft/raft_ann_mg_ivf_pq_wrapper.hpp | 46 +- .../ann/src/raft/raft_ann_mg_wrapper.hpp | 64 +-- cpp/bench/ann/src/raft/raft_benchmark.cu | 20 +- cpp/include/raft/neighbors/ann_mg_helpers.cuh | 22 +- cpp/include/raft/neighbors/ann_mg_types.hpp | 30 +- .../raft/neighbors/brute_force-inl.cuh | 2 +- cpp/include/raft/neighbors/cagra_mg.cuh | 10 +- .../raft/neighbors/cagra_mg_serialize.cuh | 18 +- cpp/include/raft/neighbors/detail/ann_mg.cuh | 413 ++++++++++-------- cpp/include/raft/neighbors/ivf_flat_mg.cuh | 7 +- .../raft/neighbors/ivf_flat_mg_serialize.cuh | 16 +- cpp/include/raft/neighbors/ivf_pq_mg.cuh | 7 +- .../raft/neighbors/ivf_pq_mg_serialize.cuh | 16 +- cpp/test/neighbors/ann_mg/test_ann_mg.cu | 24 +- .../run/conf/mnist-784-euclidean.json | 3 +- .../run/conf/sift-128-euclidean.json | 3 +- 23 files changed, 513 insertions(+), 383 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 0736abe637..fe9132b223 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -302,6 +302,8 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu @@ -592,8 +594,6 @@ if(RAFT_COMPILE_LIBRARY) INTERFACE_POSITION_INDEPENDENT_CODE ON ) - - foreach(target raft_lib raft_lib_static raft_objs) target_link_libraries( ${target} diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 1c30da41e2..da8eab76b5 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ extern template class raft::bench::ann::RaftIvfPQ; extern template class raft::bench::ann::RaftIvfPQ; #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) #include "raft_cagra_wrapper.h" #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA @@ -70,11 +70,11 @@ extern template class raft::bench::ann::RaftAnnMG_Cagra; #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_FLAT) template void parse_build_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) - #else +#else typename raft::bench::ann::RaftAnnMG_IvfFlat::BuildParam& param) - #endif +#endif { param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -83,26 +83,26 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT +#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) - #else +#else typename raft::bench::ann::RaftAnnMG_IvfFlat::SearchParam& param) - #endif +#endif { param.ivf_flat_params.n_probes = conf.at("nprobe"); } #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam& param) - #else +#else typename raft::bench::ann::RaftIvfPQ::BuildParam& param) - #endif +#endif { if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } @@ -124,11 +124,11 @@ void parse_build_param(const nlohmann::json& conf, template void parse_search_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ typename raft::bench::ann::RaftAnnMG_IvfPq::SearchParam& param) - #else +#else typename raft::bench::ann::RaftIvfPQ::SearchParam& param) - #endif +#endif { if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } if (conf.contains("internalDistanceDtype")) { @@ -170,7 +170,7 @@ void parse_search_param(const nlohmann::json& conf, #endif #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) + defined(RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA) template void parse_build_param(const nlohmann::json& conf, raft::neighbors::experimental::nn_descent::index_params& param) @@ -217,11 +217,11 @@ nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, template void parse_build_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam& param) - #else +#else typename raft::bench::ann::RaftCagra::BuildParam& param) - #endif +#endif { if (conf.contains("graph_degree")) { param.cagra_params.graph_degree = conf.at("graph_degree"); @@ -285,11 +285,11 @@ raft::bench::ann::AllocatorType parse_allocator(std::string mem_type) template void parse_search_param(const nlohmann::json& conf, - #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA - typename raft::bench::ann::RaftAnnMG_Cagra::SearchParam& param) - #else - typename raft::bench::ann::RaftCagra::SearchParam& param) - #endif +#ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA + typename raft::bench::ann::RaftAnnMG_Cagra::SearchParam& param) +#else + typename raft::bench::ann::RaftCagra::SearchParam& param) +#endif { if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } @@ -309,14 +309,14 @@ void parse_search_param(const nlohmann::json& conf, } } - #if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) +#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) if (conf.contains("graph_memory_type")) { param.graph_mem = parse_allocator(conf.at("graph_memory_type")); } if (conf.contains("internal_dataset_memory_type")) { param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type")); } - #endif +#endif // Same ratio as in IVF-PQ param.refine_ratio = conf.value("refine_ratio", 1.0f); } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu index 84699e052b..0243529a67 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra.cu @@ -14,15 +14,15 @@ * limitations under the License. */ - #include "raft_ann_mg_cagra_wrapper.hpp" -#include + #include +#include namespace raft::bench::ann { - template class RaftAnnMG_Cagra; - template class RaftAnnMG_Cagra; - template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; +template class RaftAnnMG_Cagra; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp index cab2605b4d..ef3ece2839 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_cagra_wrapper.hpp @@ -17,6 +17,7 @@ #pragma once #include "raft_ann_mg_wrapper.hpp" + #include #include #include @@ -39,26 +40,26 @@ class RaftAnnMG_Cagra : public RaftAnnMG { struct BuildParam { raft::neighbors::cagra::mg_index_params cagra_params; - std::optional nn_descent_params = std::nullopt; + std::optional nn_descent_params = + std::nullopt; std::optional ivf_pq_refine_rate = std::nullopt; std::optional ivf_pq_build_params = std::nullopt; std::optional ivf_pq_search_params = std::nullopt; }; RaftAnnMG_Cagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) - : RaftAnnMG(metric, dim), - index_params_(param), - dimension_(dim) + : RaftAnnMG(metric, dim), index_params_(param), dimension_(dim) { - index_params_.cagra_params.add_data_on_build = true; - index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED; - index_params_.cagra_params.metric = parse_metric_type(metric); - index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); + index_params_.cagra_params.add_data_on_build = true; + index_params_.cagra_params.mode = raft::neighbors::mg::parallel_mode::SHARDED; + index_params_.cagra_params.metric = parse_metric_type(metric); + index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); } void build(const T* dataset, size_t nrow) final; void set_search_param(const AnnSearchParam& param) override; - void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; void save(const std::string& file) const override; void load(const std::string&) override; std::unique_ptr> copy() override; @@ -66,7 +67,9 @@ class RaftAnnMG_Cagra : public RaftAnnMG { private: BuildParam index_params_; raft::neighbors::cagra::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + std::shared_ptr< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>> + index_; float refine_ratio_ = 1.0; int dimension_; }; @@ -74,10 +77,14 @@ class RaftAnnMG_Cagra : public RaftAnnMG { template void RaftAnnMG_Cagra::build(const T* dataset, size_t nrow) { - const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); - auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_.cagra_params, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view( + dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = raft::neighbors::mg::build( + handle, *this->clique_, index_params_.cagra_params, dataset_matrix); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); return; } @@ -103,7 +110,9 @@ void RaftAnnMG_Cagra::load(const std::string& file) { const auto& handle = this->clique_->set_current_device_to_root_rank(); auto idx = raft::neighbors::mg::deserialize_cagra(handle, *this->clique_, file); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); } template @@ -113,14 +122,24 @@ std::unique_ptr> RaftAnnMG_Cagra::copy() } template -void RaftAnnMG_Cagra::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +void RaftAnnMG_Cagra::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); - auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); - auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - - raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + auto query_matrix = raft::make_host_matrix_view( + queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = + raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = + raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, + *this->clique_, + *index_, + search_params_, + query_matrix, + neighbors_matrix, + distances_matrix); resource::sync_stream(handle); return; } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu index 4f9d04072c..a11460dea0 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat.cu @@ -14,15 +14,15 @@ * limitations under the License. */ - #include "raft_ann_mg_ivf_flat_wrapper.hpp" -#include + #include +#include namespace raft::bench::ann { - template class RaftAnnMG_IvfFlat; - template class RaftAnnMG_IvfFlat; - template class RaftAnnMG_IvfFlat; +template class RaftAnnMG_IvfFlat; +template class RaftAnnMG_IvfFlat; +template class RaftAnnMG_IvfFlat; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp index 5308f07161..02e55423be 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_flat_wrapper.hpp @@ -17,6 +17,7 @@ #pragma once #include "raft_ann_mg_wrapper.hpp" + #include #include @@ -43,7 +44,8 @@ class RaftAnnMG_IvfFlat : public RaftAnnMG { void build(const T* dataset, size_t nrow) final; void set_search_param(const AnnSearchParam& param) override; - void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; void save(const std::string& file) const override; void load(const std::string&) override; std::unique_ptr> copy() override; @@ -51,16 +53,22 @@ class RaftAnnMG_IvfFlat : public RaftAnnMG { private: BuildParam index_params_; raft::neighbors::ivf_flat::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + std::shared_ptr< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>> + index_; }; template void RaftAnnMG_IvfFlat::build(const T* dataset, size_t nrow) { - const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); - auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view( + dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = + raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); return; } @@ -84,7 +92,9 @@ template void RaftAnnMG_IvfFlat::load(const std::string& file) { const auto& handle = this->clique_->set_current_device_to_root_rank(); - index_ = std::make_shared, T, IdxT>>(std::move(raft::neighbors::mg::deserialize_flat(handle, *this->clique_, file))); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(raft::neighbors::mg::deserialize_flat(handle, *this->clique_, file))); } template @@ -94,17 +104,27 @@ std::unique_ptr> RaftAnnMG_IvfFlat::copy() } template -void RaftAnnMG_IvfFlat::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +void RaftAnnMG_IvfFlat::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); - auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); - auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - - raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + auto query_matrix = raft::make_host_matrix_view( + queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = + raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = + raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, + *this->clique_, + *index_, + search_params_, + query_matrix, + neighbors_matrix, + distances_matrix); resource::sync_stream(handle); return; } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu index 84fdc0da91..b0bf8a9d84 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq.cu @@ -14,15 +14,15 @@ * limitations under the License. */ - #include "raft_ann_mg_ivf_pq_wrapper.hpp" -#include + #include +#include namespace raft::bench::ann { - template class RaftAnnMG_IvfPq; - template class RaftAnnMG_IvfPq; - template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; +template class RaftAnnMG_IvfPq; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp index 8e7049d971..8c0867462d 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_ivf_pq_wrapper.hpp @@ -17,6 +17,7 @@ #pragma once #include "raft_ann_mg_wrapper.hpp" + #include #include @@ -45,7 +46,8 @@ class RaftAnnMG_IvfPq : public RaftAnnMG { void build(const T* dataset, size_t nrow) final; void set_search_param(const AnnSearchParam& param) override; - void search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; void save(const std::string& file) const override; void load(const std::string&) override; std::unique_ptr> copy() override; @@ -53,17 +55,23 @@ class RaftAnnMG_IvfPq : public RaftAnnMG { private: BuildParam index_params_; raft::neighbors::ivf_pq::search_params search_params_; - std::shared_ptr, T, IdxT>> index_; + std::shared_ptr< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>> + index_; float refine_ratio_ = 1.0; }; template void RaftAnnMG_IvfPq::build(const T* dataset, size_t nrow) { - const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto dataset_matrix = raft::make_host_matrix_view(dataset, IdxT(nrow), IdxT(this->dimension_)); - auto idx = raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); - index_ = std::make_shared, T, IdxT>>(std::move(idx)); + const auto& handle = this->clique_->set_current_device_to_root_rank(); + auto dataset_matrix = raft::make_host_matrix_view( + dataset, IdxT(nrow), IdxT(this->dimension_)); + auto idx = + raft::neighbors::mg::build(handle, *this->clique_, index_params_, dataset_matrix); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(idx)); return; } @@ -88,7 +96,9 @@ template void RaftAnnMG_IvfPq::load(const std::string& file) { const auto& handle = this->clique_->set_current_device_to_root_rank(); - index_ = std::make_shared, T, IdxT>>(std::move(raft::neighbors::mg::deserialize_pq(handle, *this->clique_, file))); + index_ = std::make_shared< + raft::neighbors::mg::detail::ann_mg_index, T, IdxT>>( + std::move(raft::neighbors::mg::deserialize_pq(handle, *this->clique_, file))); } template @@ -98,17 +108,27 @@ std::unique_ptr> RaftAnnMG_IvfPq::copy() } template -void RaftAnnMG_IvfPq::search(const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +void RaftAnnMG_IvfPq::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); const auto& handle = this->clique_->set_current_device_to_root_rank(); - auto query_matrix = raft::make_host_matrix_view(queries, IdxT(batch_size), IdxT(this->dimension_)); - auto neighbors_matrix = raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); - auto distances_matrix = raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); - - raft::neighbors::mg::search(handle, *this->clique_, *index_, search_params_, query_matrix, neighbors_matrix, distances_matrix); + auto query_matrix = raft::make_host_matrix_view( + queries, IdxT(batch_size), IdxT(this->dimension_)); + auto neighbors_matrix = + raft::make_host_matrix_view((IdxT*)neighbors, IdxT(batch_size), IdxT(k)); + auto distances_matrix = + raft::make_host_matrix_view(distances, IdxT(batch_size), IdxT(k)); + + raft::neighbors::mg::search(handle, + *this->clique_, + *index_, + search_params_, + query_matrix, + neighbors_matrix, + distances_matrix); resource::sync_stream(handle); return; } diff --git a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp index 0cade61f03..3e642fe0c8 100644 --- a/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp +++ b/cpp/bench/ann/src/raft/raft_ann_mg_wrapper.hpp @@ -18,48 +18,48 @@ #include "../common/ann_types.hpp" #include "raft_ann_bench_utils.h" + #include namespace raft::bench::ann { template class RaftAnnMG : public ANN, public AnnGPU { + public: + RaftAnnMG(Metric metric, int dim) : ANN(metric, dim), dimension_(dim) + { + this->init_nccl_clique(); + } - public: - RaftAnnMG(Metric metric, int dim) - : ANN(metric, dim), dimension_(dim) - { - this->init_nccl_clique(); - } - - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::HostMmap; - property.query_memory_type = MemoryType::HostMmap; - return property; - } + AlgoProperty get_preference() const override + { + AlgoProperty property; + property.dataset_memory_type = MemoryType::HostMmap; + property.query_memory_type = MemoryType::HostMmap; + return property; + } - private: - void init_nccl_clique() { - int n_devices; - cudaGetDeviceCount(&n_devices); - std::cout << n_devices << " GPUs detected" << std::endl; + private: + void init_nccl_clique() + { + int n_devices; + cudaGetDeviceCount(&n_devices); + std::cout << n_devices << " GPUs detected" << std::endl; - std::vector device_ids(n_devices); - std::iota(device_ids.begin(), device_ids.end(), 0); - clique_ = std::make_shared(device_ids); - } + std::vector device_ids(n_devices); + std::iota(device_ids.begin(), device_ids.end(), 0); + clique_ = std::make_shared(device_ids); + } - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - const auto& handle = clique_->set_current_device_to_root_rank(); - return resource::get_cuda_stream(handle); - } + [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override + { + const auto& handle = clique_->set_current_device_to_root_rank(); + return resource::get_cuda_stream(handle); + } - protected: - std::shared_ptr clique_; - int dimension_; + protected: + std::shared_ptr clique_; + int dimension_; }; -} +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 54cb33fd59..495cf639f8 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -92,24 +92,24 @@ std::unique_ptr> create_algo(const std::string& algo, } #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_IVF_PQ -if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { - if (algo == "raft_ann_mg_ivf_pq") { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if (algo == "raft_ann_mg_ivf_pq") { typename raft::bench::ann::RaftAnnMG_IvfPq::BuildParam param; parse_build_param(conf, param); ann = std::make_unique>(metric, dim, param); } -} + } #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_ANN_MG_CAGRA -if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { if (algo == "raft_ann_mg_cagra") { typename raft::bench::ann::RaftAnnMG_Cagra::BuildParam param; parse_build_param(conf, param); ann = std::make_unique>(metric, dim, param); } -} + } #endif if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } @@ -166,7 +166,8 @@ std::unique_ptr::AnnSearchParam> create_search if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { if (algo == "raft_ann_mg_ivf_pq") { - auto param = std::make_unique::SearchParam>(); + auto param = + std::make_unique::SearchParam>(); parse_search_param(conf, *param); return param; } @@ -176,7 +177,8 @@ std::unique_ptr::AnnSearchParam> create_search if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { if (algo == "raft_ann_mg_cagra") { - auto param = std::make_unique::SearchParam>(); + auto param = + std::make_unique::SearchParam>(); parse_search_param(conf, *param); return param; } diff --git a/cpp/include/raft/neighbors/ann_mg_helpers.cuh b/cpp/include/raft/neighbors/ann_mg_helpers.cuh index ebf006050b..4b1a5405a6 100644 --- a/cpp/include/raft/neighbors/ann_mg_helpers.cuh +++ b/cpp/include/raft/neighbors/ann_mg_helpers.cuh @@ -16,14 +16,17 @@ #pragma once -#include -#include -#include -#include #include +#include + +#include + +#include + +#include namespace raft::comms { - void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank); +void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank); } namespace raft::neighbors::mg { @@ -31,7 +34,6 @@ namespace raft::neighbors::mg { using pool_mr = rmm::mr::pool_memory_resource; struct nccl_clique { - nccl_clique(const std::vector& device_ids) : root_rank_(0), num_ranks_(device_ids.size()), @@ -48,7 +50,8 @@ struct nccl_clique { // create a pool memory resource for each device auto old_mr = rmm::mr::get_current_device_resource(); - per_device_pools_.push_back(std::make_unique(old_mr, rmm::percent_of_free_device_memory(80))); + per_device_pools_.push_back( + std::make_unique(old_mr, rmm::percent_of_free_device_memory(80))); rmm::cuda_device_id id(device_ids[rank]); rmm::mr::set_per_device_resource(id, per_device_pools_.back().get()); @@ -56,7 +59,8 @@ struct nccl_clique { device_resources_.emplace_back(); // add NCCL communications to the device resource handle - raft::comms::build_comms_nccl_only(&device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); + raft::comms::build_comms_nccl_only( + &device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); } for (int rank = 0; rank < num_ranks_; rank++) { @@ -92,4 +96,4 @@ struct nccl_clique { std::vector device_resources_; }; -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ann_mg_types.hpp b/cpp/include/raft/neighbors/ann_mg_types.hpp index 4242f3fb9f..bdc653bbdb 100644 --- a/cpp/include/raft/neighbors/ann_mg_types.hpp +++ b/cpp/include/raft/neighbors/ann_mg_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,28 +16,28 @@ #pragma once +#include #include #include -#include namespace raft::neighbors::mg { - enum parallel_mode { REPLICATED, SHARDED }; +enum parallel_mode { REPLICATED, SHARDED }; } namespace raft::neighbors::ivf_flat { - struct mg_index_params : raft::neighbors::ivf_flat::index_params { - raft::neighbors::mg::parallel_mode mode; - }; -} +struct mg_index_params : raft::neighbors::ivf_flat::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::ivf_flat namespace raft::neighbors::ivf_pq { - struct mg_index_params : raft::neighbors::ivf_pq::index_params { - raft::neighbors::mg::parallel_mode mode; - }; -} +struct mg_index_params : raft::neighbors::ivf_pq::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::ivf_pq namespace raft::neighbors::cagra { - struct mg_index_params : raft::neighbors::cagra::index_params { - raft::neighbors::mg::parallel_mode mode; - }; -} +struct mg_index_params : raft::neighbors::cagra::index_params { + raft::neighbors::mg::parallel_mode mode; +}; +} // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index 60bb2b133f..36941aa6cd 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/cagra_mg.cuh b/cpp/include/raft/neighbors/cagra_mg.cuh index 62be51ca7b..6cf51f580b 100644 --- a/cpp/include/raft/neighbors/cagra_mg.cuh +++ b/cpp/include/raft/neighbors/cagra_mg.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,9 +39,9 @@ void search(const raft::resources& handle, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances, - uint64_t n_rows_per_batch = 1 << 20) // 2^20 + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } - // 2^20 -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh index fd8c4b7667..afb640f8f5 100644 --- a/cpp/include/raft/neighbors/cagra_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_mg_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,19 +31,21 @@ void serialize(const raft::resources& handle, } template -detail::ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> deserialize_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::deserialize_cagra(handle, clique, filename); } template -detail::ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> distribute_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::distribute_cagra(handle, clique, filename); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/detail/ann_mg.cuh b/cpp/include/raft/neighbors/detail/ann_mg.cuh index fe85a70901..070821e2a6 100644 --- a/cpp/include/raft/neighbors/detail/ann_mg.cuh +++ b/cpp/include/raft/neighbors/detail/ann_mg.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,30 +16,28 @@ #pragma once -#include #include #include #include -#include -#include #include +#include +#include +#include #undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include +#include #include #include -#include #define RAFT_EXPLICIT_INSTANTIATE_ONLY -#include +#include #include - -#include #include #include -#include - +#include +#include namespace raft::neighbors::mg::detail { using namespace raft::neighbors; @@ -47,58 +45,63 @@ using namespace raft::neighbors; template class ann_interface { public: - template void build(raft::resources const& handle, const ann::index_params* index_params, raft::mdspan, row_major, Accessor> index_dataset) { if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_flat::build(handle, - *static_cast(index_params), - index_dataset.data_handle(), - index_dataset.extent(0), - index_dataset.extent(1)); + auto idx = + raft::neighbors::ivf_flat::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_pq::build(handle, - *static_cast(index_params), - index_dataset.data_handle(), - index_dataset.extent(0), - index_dataset.extent(1)); + auto idx = + raft::neighbors::ivf_pq::build(handle, + *static_cast(index_params), + index_dataset.data_handle(), + index_dataset.extent(0), + index_dataset.extent(1)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - auto extents = raft::make_extents(index_dataset.extent(0), index_dataset.extent(1)); + auto extents = raft::make_extents(index_dataset.extent(0), index_dataset.extent(1)); const bool host_acc = decltype(index_dataset)::accessor_type::is_host_type::value; const bool device_acc = decltype(index_dataset)::accessor_type::is_device_type::value; - auto dataset = raft::make_mdspan(index_dataset.data_handle(), extents); + auto dataset = raft::make_mdspan( + index_dataset.data_handle(), extents); cagra::index idx(handle); - idx = raft::neighbors::cagra::build(handle, - *static_cast(index_params), - dataset); + idx = raft::neighbors::cagra::build( + handle, *static_cast(index_params), dataset); index_.emplace(std::move(idx)); } resource::sync_stream(handle); } template - void extend(raft::resources const& handle, - raft::mdspan, row_major, Accessor1> new_vectors, - std::optional, layout_c_contiguous, Accessor2>> new_indices) + void extend( + raft::resources const& handle, + raft::mdspan, row_major, Accessor1> new_vectors, + std::optional, layout_c_contiguous, Accessor2>> + new_indices) { if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_flat::extend(handle, - index_.value(), - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); + auto idx = raft::neighbors::ivf_flat::extend( + handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { - auto idx = raft::neighbors::ivf_pq::extend(handle, - index_.value(), - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); + auto idx = raft::neighbors::ivf_pq::extend( + handle, + index_.value(), + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); index_.emplace(std::move(idx)); } else if constexpr (std::is_same>::value) { RAFT_FAIL("CAGRA does not implement the extend method"); @@ -112,39 +115,44 @@ class ann_interface { raft::device_matrix_view d_neighbors, raft::device_matrix_view d_distances) const { - IdxT n_rows = h_query_dataset.extent(0); - IdxT n_dims = h_query_dataset.extent(1); + IdxT n_rows = h_query_dataset.extent(0); + IdxT n_dims = h_query_dataset.extent(1); auto d_query_dataset = raft::make_device_matrix(handle, n_rows, n_dims); - raft::copy(d_query_dataset.data_handle(), h_query_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle)); + raft::copy(d_query_dataset.data_handle(), + h_query_dataset.data_handle(), + n_rows * n_dims, + resource::get_cuda_stream(handle)); if constexpr (std::is_same>::value) { - raft::runtime::neighbors::ivf_flat::search(handle, - *reinterpret_cast(search_params), - index_.value(), - d_query_dataset.view(), - d_neighbors, - d_distances); + raft::runtime::neighbors::ivf_flat::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); } else if constexpr (std::is_same>::value) { - raft::runtime::neighbors::ivf_pq::search(handle, - *reinterpret_cast(search_params), - index_.value(), - d_query_dataset.view(), - d_neighbors, - d_distances); + raft::runtime::neighbors::ivf_pq::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); } else if constexpr (std::is_same>::value) { - raft::runtime::neighbors::cagra::search(handle, - *reinterpret_cast(search_params), - index_.value(), - d_query_dataset.view(), - d_neighbors, - d_distances); + raft::runtime::neighbors::cagra::search( + handle, + *reinterpret_cast(search_params), + index_.value(), + d_query_dataset.view(), + d_neighbors, + d_distances); } resource::sync_stream(handle); } - void serialize(raft::resources const& handle, - std::ostream& os) const - { + void serialize(raft::resources const& handle, std::ostream& os) const + { if constexpr (std::is_same>::value) { ivf_flat::serialize(handle, os, index_.value()); } else if constexpr (std::is_same>::value) { @@ -154,8 +162,7 @@ class ann_interface { } } - void deserialize(raft::resources const& handle, - std::istream& is) + void deserialize(raft::resources const& handle, std::istream& is) { if constexpr (std::is_same>::value) { index_.emplace(std::move(ivf_flat::deserialize(handle, is))); @@ -166,8 +173,7 @@ class ann_interface { } } - void deserialize(raft::resources const& handle, - const std::string& filename) + void deserialize(raft::resources const& handle, const std::string& filename) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } @@ -201,14 +207,12 @@ class ann_interface { template class ann_mg_index { public: - ann_mg_index(parallel_mode mode, int num_ranks_) - : mode_(mode), - num_ranks_(num_ranks_) - {} + ann_mg_index(parallel_mode mode, int num_ranks_) : mode_(mode), num_ranks_(num_ranks_) {} ann_mg_index(const raft::resources& handle, const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) { + const std::string& filename) + { deserialize_mg_index(handle, clique, filename); } @@ -223,7 +227,7 @@ class ann_mg_index { const std::string& filename) { for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_.emplace_back(); @@ -239,11 +243,11 @@ class ann_mg_index { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); + mode_ = (raft::neighbors::mg::parallel_mode)deserialize_scalar(handle, is); num_ranks_ = deserialize_scalar(handle, is); for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_.emplace_back(); @@ -262,38 +266,39 @@ class ann_mg_index { RAFT_LOG_INFO("REPLICATED BUILD: %d*%drows", num_ranks_, n_rows); ann_interfaces_.resize(num_ranks_); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_[rank]; ann_if.build(dev_res, index_params, index_dataset); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } else if (mode_ == SHARDED) { - IdxT n_rows = index_dataset.extent(0); - IdxT n_cols = index_dataset.extent(1); - IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + IdxT n_rows = index_dataset.extent(0); + IdxT n_cols = index_dataset.extent(1); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); RAFT_LOG_INFO("SHARDED BUILD: %d*%drows", num_ranks_, n_rows_per_shard); ann_interfaces_.resize(num_ranks_); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT offset = rank * n_rows_per_shard; - IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); - auto partition = raft::make_host_matrix_view(partition_ptr, n_rows_of_current_shard, n_cols); + IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* partition_ptr = index_dataset.data_handle() + (offset * n_cols); + auto partition = raft::make_host_matrix_view( + partition_ptr, n_rows_of_current_shard, n_cols); auto& ann_if = ann_interfaces_[rank]; ann_if.build(dev_res, index_params, partition); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } } @@ -305,42 +310,44 @@ class ann_mg_index { if (mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED EXTEND: %d*%drows", num_ranks_, n_rows); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_res, new_vectors, new_indices); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } else if (mode_ == SHARDED) { IdxT n_cols = new_vectors.extent(1); - IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); + IdxT n_rows_per_shard = raft::ceildiv(n_rows, (IdxT)num_ranks_); RAFT_LOG_INFO("SHARDED EXTEND: %d*%drows", num_ranks_, n_rows_per_shard); - #pragma omp parallel for +#pragma omp parallel for for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT offset = rank * n_rows_per_shard; - IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); - const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); - auto new_vectors_part = raft::make_host_matrix_view(new_vectors_ptr, n_rows_of_current_shard, n_cols); + IdxT offset = rank * n_rows_per_shard; + IdxT n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset); + const T* new_vectors_ptr = new_vectors.data_handle() + (offset * n_cols); + auto new_vectors_part = raft::make_host_matrix_view( + new_vectors_ptr, n_rows_of_current_shard, n_cols); std::optional> new_indices_part = std::nullopt; if (new_indices.has_value()) { const IdxT* new_indices_ptr = new_indices.value().data_handle() + offset; - new_indices_part = raft::make_host_vector_view(new_indices_ptr, n_rows_of_current_shard); + new_indices_part = + raft::make_host_vector_view(new_indices_ptr, n_rows_of_current_shard); } auto& ann_if = ann_interfaces_[rank]; ann_if.extend(dev_res, new_vectors_part, new_indices_part); resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } } @@ -355,35 +362,34 @@ class ann_mg_index { IdxT n_cols = query_dataset.extent(1); IdxT n_neighbors = neighbors.extent(1); - IdxT n_batches = raft::ceildiv(n_rows, (IdxT)n_rows_per_batch); - if (n_batches == 1) - n_rows_per_batch = n_rows; + IdxT n_batches = raft::ceildiv(n_rows, (IdxT)n_rows_per_batch); + if (n_batches == 1) n_rows_per_batch = n_rows; if (mode_ == REPLICATED) { RAFT_LOG_INFO("REPLICATED SEARCH: %d*%drows", n_batches, n_rows_per_batch); - #pragma omp parallel for +#pragma omp parallel for for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - int rank = batch_idx % num_ranks_; // alternate GPUs - int dev_id = clique.device_ids_[rank]; + int rank = batch_idx % num_ranks_; // alternate GPUs + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - IdxT offset = batch_idx * n_rows_per_batch; - IdxT query_offset = offset * n_cols; - IdxT output_offset = offset * n_neighbors; - IdxT n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset); + IdxT offset = batch_idx * n_rows_per_batch; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; + IdxT n_rows_of_current_batch = std::min(n_rows_per_batch, n_rows - offset); - auto query_partition = raft::make_host_matrix_view(query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); - auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); - auto d_distances = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); + auto query_partition = raft::make_host_matrix_view( + query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); + auto d_neighbors = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); auto& ann_if = ann_interfaces_[rank]; - ann_if.search(dev_res, - search_params, - query_partition, - d_neighbors.view(), - d_distances.view()); + ann_if.search( + dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); raft::copy(neighbors.data_handle() + output_offset, d_neighbors.data_handle(), @@ -396,7 +402,7 @@ class ann_mg_index { resource::sync_stream(dev_res); } - #pragma omp barrier +#pragma omp barrier } else if (mode_ == SHARDED) { RAFT_LOG_INFO("SHARDED SEARCH: %d*%drows", n_batches, n_rows_per_batch); @@ -405,39 +411,44 @@ class ann_mg_index { root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); auto in_distances = raft::make_device_matrix( root_handle, num_ranks_ * n_rows_per_batch, n_neighbors); - auto out_neighbors = raft::make_device_matrix( - root_handle, n_rows_per_batch, n_neighbors); + auto out_neighbors = + raft::make_device_matrix(root_handle, n_rows_per_batch, n_neighbors); auto out_distances = raft::make_device_matrix( root_handle, n_rows_per_batch, n_neighbors); for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) { - IdxT offset = batch_idx * n_rows_per_batch; - IdxT query_offset = offset * n_cols; - IdxT output_offset = offset * n_neighbors; + IdxT offset = batch_idx * n_rows_per_batch; + IdxT query_offset = offset * n_cols; + IdxT output_offset = offset * n_neighbors; IdxT n_rows_of_current_batch = std::min((IdxT)n_rows_per_batch, n_rows - offset); - auto query_partition = raft::make_host_matrix_view( + auto query_partition = raft::make_host_matrix_view( query_dataset.data_handle() + query_offset, n_rows_of_current_batch, n_cols); - // should use at least num_ranks_ threads to avoid NCCL hang - #pragma omp parallel for num_threads(num_ranks_) +// should use at least num_ranks_ threads to avoid NCCL hang +#pragma omp parallel for num_threads(num_ranks_) for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; + int dev_id = clique.device_ids_[rank]; const raft::device_resources& dev_res = clique.device_resources_[rank]; - auto& ann_if = ann_interfaces_[rank]; - const auto& comms = resource::get_comms(dev_res); + auto& ann_if = ann_interfaces_[rank]; + const auto& comms = resource::get_comms(dev_res); RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - if (rank == clique.root_rank_) { // root rank + if (rank == clique.root_rank_) { // root rank uint64_t batch_offset = clique.root_rank_ * n_rows_of_current_batch * n_neighbors; - auto d_neighbors = raft::make_device_matrix_view(in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); - auto d_distances = raft::make_device_matrix_view(in_distances.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); - ann_if.search(dev_res, search_params, query_partition, d_neighbors, d_distances); // write search results inplace + auto d_neighbors = raft::make_device_matrix_view( + in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix_view( + in_distances.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors); + ann_if.search(dev_res, + search_params, + query_partition, + d_neighbors, + d_distances); // write search results inplace // wait for results of other ranks RAFT_NCCL_TRY(ncclGroupStart()); for (int from_rank = 0; from_rank < num_ranks_; from_rank++) { - if (from_rank == clique.root_rank_) - continue; + if (from_rank == clique.root_rank_) continue; batch_offset = from_rank * n_rows_of_current_batch * n_neighbors; comms.device_recv(in_neighbors.data_handle() + batch_offset, @@ -451,28 +462,31 @@ class ann_mg_index { } RAFT_NCCL_TRY(ncclGroupEnd()); resource::sync_stream(dev_res); - } else { // non-root ranks - auto d_neighbors = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); - auto d_distances = raft::make_device_matrix(dev_res, n_rows_of_current_batch, n_neighbors); - ann_if.search(dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); - - // send results to root rank - RAFT_NCCL_TRY(ncclGroupStart()); - comms.device_send(d_neighbors.data_handle(), - n_rows_of_current_batch * n_neighbors, - clique.root_rank_, - resource::get_cuda_stream(dev_res)); - comms.device_send(d_distances.data_handle(), - n_rows_of_current_batch * n_neighbors, - clique.root_rank_, - resource::get_cuda_stream(dev_res)); - RAFT_NCCL_TRY(ncclGroupEnd()); - resource::sync_stream(dev_res); - } + } else { // non-root ranks + auto d_neighbors = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + auto d_distances = raft::make_device_matrix( + dev_res, n_rows_of_current_batch, n_neighbors); + ann_if.search( + dev_res, search_params, query_partition, d_neighbors.view(), d_distances.view()); + + // send results to root rank + RAFT_NCCL_TRY(ncclGroupStart()); + comms.device_send(d_neighbors.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + comms.device_send(d_distances.data_handle(), + n_rows_of_current_batch * n_neighbors, + clique.root_rank_, + resource::get_cuda_stream(dev_res)); + RAFT_NCCL_TRY(ncclGroupEnd()); + resource::sync_stream(dev_res); + } } - #pragma omp barrier +#pragma omp barrier - auto in_neighbors_view = raft::make_device_matrix_view( + auto in_neighbors_view = raft::make_device_matrix_view( in_neighbors.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); auto in_distances_view = raft::make_device_matrix_view( in_distances.data_handle(), num_ranks_ * n_rows_of_current_batch, n_neighbors); @@ -482,15 +496,19 @@ class ann_mg_index { out_distances.data_handle(), n_rows_of_current_batch, n_neighbors); const auto& root_handle_ = clique.set_current_device_to_root_rank(); - auto h_trans = std::vector(num_ranks_); - IdxT translation_offset = 0; + auto h_trans = std::vector(num_ranks_); + IdxT translation_offset = 0; for (int rank = 0; rank < num_ranks_; rank++) { h_trans[rank] = translation_offset; translation_offset += ann_interfaces_[rank].size(); } auto d_trans = raft::make_device_vector(root_handle_, num_ranks_); - raft::copy(d_trans.data_handle(), h_trans.data(), num_ranks_, resource::get_cuda_stream(root_handle_)); - auto translations = std::make_optional>(d_trans.view()); + raft::copy(d_trans.data_handle(), + h_trans.data(), + num_ranks_, + resource::get_cuda_stream(root_handle_)); + auto translations = + std::make_optional>(d_trans.view()); raft::neighbors::brute_force::knn_merge_parts(root_handle_, in_distances_view, in_neighbors_view, @@ -523,11 +541,11 @@ class ann_mg_index { serialize_scalar(handle, of, (int)mode_); serialize_scalar(handle, of, num_ranks_); for (int rank = 0; rank < num_ranks_; rank++) { - int dev_id = clique.device_ids_[rank]; - const raft::device_resources& dev_res = clique.device_resources_[rank]; - RAFT_CUDA_TRY(cudaSetDevice(dev_id)); - auto& ann_if = ann_interfaces_[rank]; - ann_if.serialize(dev_res, of); + int dev_id = clique.device_ids_[rank]; + const raft::device_resources& dev_res = clique.device_resources_[rank]; + RAFT_CUDA_TRY(cudaSetDevice(dev_id)); + auto& ann_if = ann_interfaces_[rank]; + ann_if.serialize(dev_res, of); } of.close(); @@ -606,7 +624,12 @@ void search(const raft::resources& handle, raft::host_matrix_view distances, uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); } template @@ -619,7 +642,12 @@ void search(const raft::resources& handle, raft::host_matrix_view distances, uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); } template @@ -632,7 +660,12 @@ void search(const raft::resources& handle, raft::host_matrix_view distances, uint64_t n_rows_per_batch) { - index.search(clique, static_cast(&search_params), query_dataset, neighbors, distances, n_rows_per_batch); + index.search(clique, + static_cast(&search_params), + query_dataset, + neighbors, + distances, + n_rows_per_batch); } template @@ -663,36 +696,40 @@ void serialize(const raft::resources& handle, } template -ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> deserialize_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(handle, clique, filename); return index; } template -ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> deserialize_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(handle, clique, filename); return index; } template -ann_mg_index, T, IdxT> deserialize_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> deserialize_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(handle, clique, filename); return index; } template -ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> distribute_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); @@ -700,9 +737,10 @@ ann_mg_index, T, IdxT> distribute_flat(const raft::reso } template -ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> distribute_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); @@ -710,13 +748,14 @@ ann_mg_index, T, IdxT> distribute_pq(const raft::resources& } template -ann_mg_index, T, IdxT> distribute_cagra(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +ann_mg_index, T, IdxT> distribute_cagra( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { auto index = ann_mg_index, T, IdxT>(REPLICATED, clique.num_ranks_); index.deserialize_and_distribute(handle, clique, filename); return index; } -} // namespace raft::neighbors::mg::detail \ No newline at end of file +} // namespace raft::neighbors::mg::detail diff --git a/cpp/include/raft/neighbors/ivf_flat_mg.cuh b/cpp/include/raft/neighbors/ivf_flat_mg.cuh index 12f76babe7..f5170d063a 100644 --- a/cpp/include/raft/neighbors/ivf_flat_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_mg.cuh @@ -49,9 +49,10 @@ void search(const raft::resources& handle, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances, - uint64_t n_rows_per_batch = 1 << 20) // 2^20 + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh index 6f4c9d820f..9459243a5f 100644 --- a/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_mg_serialize.cuh @@ -31,19 +31,21 @@ void serialize(const raft::resources& handle, } template -detail::ann_mg_index, T, IdxT> deserialize_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> deserialize_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::deserialize_flat(handle, clique, filename); } template -detail::ann_mg_index, T, IdxT> distribute_flat(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> distribute_flat( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::distribute_flat(handle, clique, filename); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_pq_mg.cuh b/cpp/include/raft/neighbors/ivf_pq_mg.cuh index 7cad3449d0..104745485a 100644 --- a/cpp/include/raft/neighbors/ivf_pq_mg.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_mg.cuh @@ -49,9 +49,10 @@ void search(const raft::resources& handle, raft::host_matrix_view query_dataset, raft::host_matrix_view neighbors, raft::host_matrix_view distances, - uint64_t n_rows_per_batch = 1 << 20) // 2^20 + uint64_t n_rows_per_batch = 1 << 20) // 2^20 { - mg::detail::search(handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); + mg::detail::search( + handle, clique, index, search_params, query_dataset, neighbors, distances, n_rows_per_batch); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh index ca6f91e763..56e00576e9 100644 --- a/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_mg_serialize.cuh @@ -31,19 +31,21 @@ void serialize(const raft::resources& handle, } template -detail::ann_mg_index, T, IdxT> deserialize_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> deserialize_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::deserialize_pq(handle, clique, filename); } template -detail::ann_mg_index, T, IdxT> distribute_pq(const raft::resources& handle, - const raft::neighbors::mg::nccl_clique& clique, - const std::string& filename) +detail::ann_mg_index, T, IdxT> distribute_pq( + const raft::resources& handle, + const raft::neighbors::mg::nccl_clique& clique, + const std::string& filename) { return mg::detail::distribute_pq(handle, clique, filename); } -} // namespace raft::neighbors::mg \ No newline at end of file +} // namespace raft::neighbors::mg diff --git a/cpp/test/neighbors/ann_mg/test_ann_mg.cu b/cpp/test/neighbors/ann_mg/test_ann_mg.cu index b392f629bc..61ea54b808 100644 --- a/cpp/test/neighbors/ann_mg/test_ann_mg.cu +++ b/cpp/test/neighbors/ann_mg/test_ann_mg.cu @@ -1,10 +1,26 @@ -#include +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "../ann_mg.cuh" +#include + namespace raft::neighbors::mg { typedef AnnMGTest AnnMGTestF_float; - TEST_P(AnnMGTestF_float, AnnMG) { this->testAnnMG(); } - INSTANTIATE_TEST_CASE_P(AnnMGTest, AnnMGTestF_float, ::testing::ValuesIn(inputs)); -} +TEST_P(AnnMGTestF_float, AnnMG) { this->testAnnMG(); } +INSTANTIATE_TEST_CASE_P(AnnMGTest, AnnMGTestF_float, ::testing::ValuesIn(inputs)); +} // namespace raft::neighbors::mg diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json b/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json index e7aecfd239..59f6b7c5d8 100644 --- a/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/mnist-784-euclidean.json @@ -1349,4 +1349,5 @@ "search_result_file" : "result/mnist-784-euclidean/raft_cagra/dim64" } ] -} \ No newline at end of file +} + diff --git a/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json b/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json index 3ca47a2566..5803e9bf7e 100644 --- a/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json +++ b/python/raft-ann-bench/src/raft_ann_bench/run/conf/sift-128-euclidean.json @@ -495,4 +495,5 @@ ] } ] -} \ No newline at end of file +} +