diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 9e625af49e..c561a144e5 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -342,6 +342,7 @@ if(RAFT_ANN_BENCH_USE_DISKANN) ConfigureAnnBench( NAME DISKANN PATH bench/ann/src/diskann/diskann_benchmark.cpp LINKS diskann::diskann ) +endif() # ################################################################################################## # * Dynamically-loading ANN_BENCH executable ------------------------------------------------------- diff --git a/cpp/bench/ann/src/diskann/diskann_benchmark.cpp b/cpp/bench/ann/src/diskann/diskann_benchmark.cpp new file mode 100644 index 0000000000..46652bb092 --- /dev/null +++ b/cpp/bench/ann/src/diskann/diskann_benchmark.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../common/ann_types.hpp" +#include "diskann_wrapper.h" +#include "hnswlib_wrapper.h" + +#define JSON_DIAGNOSTICS 1 +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace raft::bench::ann { + +template +void parse_build_param(const nlohmann::json& conf, + typename raft::bench::ann::DiskANNMemory::BuildParam& param) +{ + param.ef_construction = conf.at("efConstruction"); + param.M = conf.at("M"); + param.R = conf.at("R"); + param.Lb = conf.at("Lb"); + param.alpha = conf.at("alpha"); + if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } + param.use_raft_cagra = conf.at("use_raft_cagra"); + param.filtered_index = false; + if (param.use_raft_cagra) { + raft::neighbors::cagra::index_params cagra_index_params; + cagra_index_params.graph_degree = conf.at("cagra_params_graph_degree"); + cagra_index_params.intermediate_graph_degree = + conf.at("cagra_params_intermediate_graph_degree"); + } +} + +template +void parse_search_param(const nlohmann::json& conf, + typename raft::bench::ann::DiskANNMemory::SearchParam& param) +{ + param.Ls = conf.at("Ls"); + if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } +} + +template class Algo> +std::unique_ptr> make_algo(raft::bench::ann::Metric metric, + int dim, + const nlohmann::json& conf) +{ + typename Algo::BuildParam param; + parse_build_param(conf, param); + return std::make_unique>(metric, dim, param); +} + +template class Algo> +std::unique_ptr> make_algo(raft::bench::ann::Metric metric, + int dim, + const nlohmann::json& conf, + const std::vector& dev_list) +{ + typename Algo::BuildParam param; + parse_build_param(conf, param); + + (void)dev_list; + return std::make_unique>(metric, dim, param); +} + +template +std::unique_ptr> create_algo(const std::string& algo, + const std::string& distance, + int dim, + const nlohmann::json& conf, + const std::vector& dev_list) +{ + // stop compiler warning; not all algorithms support multi-GPU so it may not be used + (void)dev_list; + + raft::bench::ann::Metric metric = parse_metric(distance); + std::unique_ptr> ann; + + if constexpr (std::is_same_v) { + if (algo == "hnswlib") { ann = make_algo(metric, dim, conf); } + } + + if constexpr (std::is_same_v) { + if (algo == "hnswlib") { ann = make_algo(metric, dim, conf); } + } + + if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } + return ann; +} + +template +std::unique_ptr::AnnSearchParam> create_search_param( + const std::string& algo, const nlohmann::json& conf) +{ + if (algo == "hnswlib") { + auto param = std::make_unique::SearchParam>(); + parse_search_param(conf, *param); + return param; + } + // else + throw std::runtime_error("invalid algo: '" + algo + "'"); +} + +}; // namespace raft::bench::ann + +REGISTER_ALGO_INSTANCE(float); +REGISTER_ALGO_INSTANCE(std::int8_t); +REGISTER_ALGO_INSTANCE(std::uint8_t); + +#ifdef ANN_BENCH_BUILD_MAIN +#include "../common/benchmark.hpp" +int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } +#endif \ No newline at end of file diff --git a/cpp/bench/ann/src/diskann/diskann_wrapper.cuh b/cpp/bench/ann/src/diskann/diskann_wrapper.cuh deleted file mode 100644 index ee79541be1..0000000000 --- a/cpp/bench/ann/src/diskann/diskann_wrapper.cuh +++ /dev/null @@ -1,304 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/ann_types.hpp" -#include "../common/thread_pool.hpp" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "program_options_utils.hpp" -#include "raft/neighbors/cagra_types.hpp" - -#include -#include -#include -#include - -#include - -#ifndef _WINDOWS -#include -#include -#else -#include -#endif - -#include "ann_exception.h" -#include "memory_mapper.h" - -#include - -namespace raft::bench::ann { - -diskann::Metric parse_metric_type(raft::bench::ann::Metric metric) -{ - if (metric == raft::bench::ann::Metric::kInnerProduct) { - return diskann::Metric::INNER_PRODUCT; - } else if (metric == raft::bench::ann::Metric::kEuclidean) { - return diskann::Metric::L2; - } else { - throw std::runtime_error("currently only inner product and L2 supported for benchmarking"); - } -} - -template -class DiskANNMemory : public ANN { - public: - optional_configs.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_degree,R", - po::value(&R)->default_value(64), - program_options_utils::MAX_BUILD_DEGREE); - optional_configs.add_options()("cagra_"); - optional_configs.add_options()("Lbuild,L", - po::value(&L)->default_value(100), - program_options_utils::GRAPH_BUILD_COMPLEXITY); - optional_configs.add_options()("alpha", - po::value(&alpha)->default_value(1.2f), - program_options_utils::GRAPH_BUILD_ALPHA); - optional_configs.add_options()("build_PQ_bytes", - po::value(&build_PQ_bytes)->default_value(0), - program_options_utils::BUIlD_GRAPH_PQ_BYTES); - optional_configs.add_options()("use_opq", - po::bool_switch()->default_value(false), - program_options_utils::USE_OPQ); - optional_configs.add_options()("label_file", - po::value(&label_file)->default_value(""), - program_options_utils::LABEL_FILE); - optional_configs.add_options()("universal_label", - po::value(&universal_label)->default_value(""), - program_options_utils::UNIVERSAL_LABEL); - - optional_configs.add_options()("FilteredLbuild", - po::value(&Lf)->default_value(0), - program_options_utils::FILTERED_LBUILD); - optional_configs.add_options()("label_type", - po::value(&label_type)->default_value("uint"), - program_options_utils::LABEL_TYPE_DESCRIPTION); - - struct BuildParam { - uint32_t R; - uint32_t Lb; - float alpha; - int num_threads = omp_get_num_procs(); - bool use_raft_cagra; - bool filtered_index; - raft::neighbors::cagra::index_params cagra_params; - }; - - using typename ANN::AnnSearchParam; - struct SearchParam : public AnnSearchParam { - uint32_t Ls; - int num_threads = omp_get_num_procs(); - }; - - DiskANNMemory(Metric metric, int dim, const BuildParam& param); - - void build(const T* dataset, size_t nrow) override; - - void set_search_param(const AnnSearchParam& param) override; - void search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const override; - - void save(const std::string& path_to_index) const override; - void load(const std::string& path_to_index) override; - - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::Host; - property.query_memory_type = MemoryType::Host; - return property; - } - - private: - void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const; - - using ANN::metric_; - using ANN::dim_; - int num_threads_; -}; - -template -DiskANNMemory::DiskANNMemory(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim) -{ - assert(dim_ > 0); - - this->index_build_params = std::make_shared(diskann::IndexWriteParametersBuilder(param.L, param.R) - .with_filter_list_size(0) - .with_alpha(param.alpha) - .with_saturate_graph(false) - .with_num_threads(param.num_threads) - .build()); - - bool use_pq_build = param.build_PQ_bytes > 0; - this->index_build_config_ = diskann::IndexConfigBuilder() - .with_metric(parse_metric_type(metric_)) - .with_dimension(dim_) - .with_max_points(0) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .with_data_type(diskann::diskann_type_to_name()) - .with_label_type(diskann::diskann_type_to_name()) - .is_dynamic_index(false) - .with_index_write_params(this->index_build_params_) - .is_enable_tags(false) - .is_use_opq(false) - .is_pq_dist_build(use_pq_build) - .with_num_pq_chunks(this->build_PQ_bytes) - .build(); -} - -template -void DiskANNMemory::build(const T* dataset, size_t nrow) -{ - this->index_build_config_.with_max_points(nrow) - - Index::Index(metric_, - dim_, - nrow, - this->index_build_params, - nullptr, - 0, - false, - false, - false, - use_pq_build, - this->build_PQ_bytes, - false); -} - -template -void HnswLib::set_search_param(const AnnSearchParam& param_) -{ - auto param = dynamic_cast(param_); - appr_alg_->ef_ = param.ef; - metric_objective_ = param.metric_objective; - num_threads_ = param.num_threads; - - // Create a pool if multiple query threads have been set and the pool hasn't been created already - bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_); - if (create_pool) { thread_pool_ = std::make_shared(num_threads_); } -} - -template -void DiskANNMemory::search( - const T* query, int batch_size, int k, size_t* indices, float* distances) const -{ - auto f = [&](int i) { - // hnsw can only handle a single vector at a time. - get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k); - }; - if (metric_objective_ == Objective::LATENCY && num_threads_ > 1) { - thread_pool_->submit(f, batch_size); - } else { - for (int i = 0; i < batch_size; i++) { - f(i); - } - } - -#pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < (int64_t)query_num; i++) { - if (filtered_search && !tags) { - std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - - auto retval = index->search_with_filters(query + i * query_aligned_dim, - raw_filter, - recall_at, - L, - query_result_ids[test_id].data() + i * recall_at, - query_result_dists[test_id].data() + i * recall_at); - cmp_stats[i] = retval.second; - } else if (metric == diskann::FAST_L2) { - index->search_with_optimized_layout(query + i * query_aligned_dim, - recall_at, - L, - query_result_ids[test_id].data() + i * recall_at); - } else if (tags) { - if (!filtered_search) { - index->search_with_tags(query + i * query_aligned_dim, - recall_at, - L, - query_result_tags.data() + i * recall_at, - nullptr, - res); - } else { - std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - - index->search_with_tags(query + i * query_aligned_dim, - recall_at, - L, - query_result_tags.data() + i * recall_at, - nullptr, - res, - true, - raw_filter); - } - - for (int64_t r = 0; r < (int64_t)recall_at; r++) { - query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; - } - } else { - cmp_stats[i] = index - ->search(query + i * query_aligned_dim, - recall_at, - L, - query_result_ids[test_id].data() + i * recall_at) - .second; - } - auto qe = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = qe - qs; - latency_stats[i] = (float)(diff.count() * 1000000); - } -} - -template -void DiskANNMemory::save(const std::string& path_to_index) const -{ - index_->save(path_to_index.c_str()); -} - -template -void DiskANNMemory::load(const std::string& path_to_index) -{ - index_->load(path_to_index.c_str(), num_threads, search_l); -} - -}; // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/diskann/diskann_wrapper.h b/cpp/bench/ann/src/diskann/diskann_wrapper.h new file mode 100644 index 0000000000..606d0acc6e --- /dev/null +++ b/cpp/bench/ann/src/diskann/diskann_wrapper.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../common/ann_types.hpp" +#include "../common/thread_pool.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "program_options_utils.hpp" +#include "raft/neighbors/cagra_types.hpp" + +#include +#include +#include +#include + +#include + +#ifndef _WINDOWS +#include +#include +#else +#include +#endif + +#include "ann_exception.h" +#include "memory_mapper.h" + +#include + +namespace raft::bench::ann { + +diskann::Metric parse_metric_type(raft::bench::ann::Metric metric) +{ + if (metric == raft::bench::ann::Metric::kInnerProduct) { + return diskann::Metric::INNER_PRODUCT; + } else if (metric == raft::bench::ann::Metric::kEuclidean) { + return diskann::Metric::L2; + } else { + throw std::runtime_error("currently only inner product and L2 supported for benchmarking"); + } +} + +template +class DiskANNMemory : public ANN { + public: + struct BuildParam { + uint32_t R; + uint32_t L_build; + float alpha; + int num_threads = omp_get_num_procs(); + bool use_raft_cagra; + bool filtered_index; + raft::neighbors::cagra::index_params cagra_index_params; + }; + + using typename ANN::AnnSearchParam; + struct SearchParam : public AnnSearchParam { + uint32_t L_search; + uint32_t L_load; + int num_threads = omp_get_num_procs(); + }; + + DiskANNMemory(Metric metric, int dim, const BuildParam& param); + + void build(const T* dataset, size_t nrow) override; + + void set_search_param(const AnnSearchParam& param) override; + void search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + + void save(const std::string& path_to_index) const override; + void load(const std::string& path_to_index) override; + + AlgoProperty get_preference() const override + { + AlgoProperty property; + property.dataset_memory_type = MemoryType::Host; + property.query_memory_type = MemoryType::Host; + return property; + } + + private: + bool use_pq_build_; + uint32_t build_PQ_bytes_; + std::shared_ptr diskann_index_write_params_{nullptr}; + std::shared_ptr diskann_index_search_params_{nullptr}; + std::unique_ptr> diskann_index_{nullptr}; + int num_threads_search_; + uint32_t L_search_; +}; + +template +DiskANNMemory::DiskANNMemory(Metric metric, int dim, const BuildParam& param) + : ANN(metric, dim) +{ + assert(dim_ > 0); + + this->diskann_index_write_params_ = std::make_shared( + diskann::IndexWriteParametersBuilder(param.L_build, param.R) + .with_filter_list_size(0) + .with_alpha(param.alpha) + .with_saturate_graph(false) + .with_num_threads(param.num_threads) + .build()); +} + +template +void DiskANNMemory::build(const T* dataset, size_t nrow) +{ + this->diskann_index_ = + std::make_unique>(diskann::Index(parse_metric_type(this->metric_), + this->dim_, + nrow, + this->diskann_index_write_params_, + nullptr, + 0, + false, + false, + false, + this->use_pq_build_, + this->build_PQ_bytes_, + false)); +} + +template +void DiskANNMemory::set_search_param(const AnnSearchParam& param_) +{ + auto param = dynamic_cast(param_); + this->num_threads_search_ = param.num_threads; + L_search_ = param.L_search; +} + +template +void DiskANNMemory::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +{ + omp_set_num_threads(num_threads_search_); +#pragma omp parallel for schedule(dynamic, 1) + for (int64_t i = 0; i < (int64_t)batch_size; i++) { + diskann_index_->search(queries + i * dim_, k, L_search_, neighbors, distances); + } +} + +template +void DiskANNMemory::save(const std::string& path_to_index) const +{ + this->diskann_index_->save(path_to_index.c_str()); +} + +template +void DiskANNMemory::load(const std::string& path_to_index) +{ + this->diskann_index_->load(path_to_index.c_str(), num_threads_search_, L_search_); +} + +}; // namespace raft::bench::ann diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 0c5521d447..3150659650 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -128,17 +128,12 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME RANDOM_BENCH PATH bench/prims/random/make_blobs.cu bench/prims/random/permute.cu - bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp + bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) ConfigureBench( - NAME - SPARSE_BENCH - PATH - bench/prims/sparse/bitmap_to_csr.cu - bench/prims/sparse/convert_csr.cu - bench/prims/sparse/select_k_csr.cu - bench/prims/main.cpp + NAME SPARSE_BENCH PATH bench/prims/sparse/bitmap_to_csr.cu bench/prims/sparse/convert_csr.cu + bench/prims/sparse/select_k_csr.cu bench/prims/main.cpp ) ConfigureBench( diff --git a/cpp/cmake/patches/diskann.diff b/cpp/cmake/patches/diskann.diff new file mode 100644 index 0000000000..bae50b8246 --- /dev/null +++ b/cpp/cmake/patches/diskann.diff @@ -0,0 +1,783 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 3d3d2b8..c775d07 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -23,6 +23,28 @@ set(CMAKE_STANDARD 17) + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + ++cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) ++ ++# ------------- configure rapids-cmake --------------# ++ ++include(cmake/thirdparty/fetch_rapids.cmake) ++include(rapids-cmake) ++include(rapids-cpm) ++include(rapids-cuda) ++include(rapids-export) ++include(rapids-find) ++ ++# ------------- configure project --------------# ++ ++rapids_cuda_init_architectures(${PROJECT_NAME}) ++ ++project(${PROJECT_NAME} LANGUAGES CXX CUDA) ++ ++# ------------- configure raft -----------------# ++ ++rapids_cpm_init() ++include(cmake/thirdparty/get_raft.cmake) ++ + if(NOT MSVC) + set(CMAKE_CXX_COMPILER g++) + endif() +@@ -331,3 +353,7 @@ include(clang-format.cmake) + if(PYBIND) + add_subdirectory(python) + endif() ++ ++if(NOT TARGET raft::raft) ++ find_package(raft COMPONENTS compiled distributed) ++endif() +diff --git a/apps/CMakeLists.txt b/apps/CMakeLists.txt +index e42c0b6..2401163 100644 +--- a/apps/CMakeLists.txt ++++ b/apps/CMakeLists.txt +@@ -2,7 +2,7 @@ + # Licensed under the MIT license. + + set(CMAKE_CXX_STANDARD 17) +-set(CMAKE_COMPILE_WARNING_AS_ERROR ON) ++set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) + + add_executable(build_memory_index build_memory_index.cpp) + target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) +diff --git a/apps/restapi/CMakeLists.txt b/apps/restapi/CMakeLists.txt +index c73b427..de0b794 100644 +--- a/apps/restapi/CMakeLists.txt ++++ b/apps/restapi/CMakeLists.txt +@@ -37,4 +37,4 @@ if(MSVC) + target_link_libraries(client optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options) + else() + target_link_libraries(client ${PROJECT_NAME} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options) +-endif() +\ No newline at end of file ++endif() +diff --git a/build.sh b/build.sh +new file mode 100755 +index 0000000..fd20a3b +--- /dev/null ++++ b/build.sh +@@ -0,0 +1,36 @@ ++#!/bin/bash ++ ++# NOTE: This file is temporary for the proof-of-concept branch and will be removed before this PR is merged ++ ++BUILD_TYPE=Release ++BUILD_DIR=build/ ++ ++RAFT_REPO_REL="" ++EXTRA_CMAKE_ARGS="" ++set -e ++ ++if [[ ${RAFT_REPO_REL} != "" ]]; then ++ RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`" ++ EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}" ++fi ++ ++if [ "$1" == "clean" ]; then ++ rm -rf build ++ rm -rf .cache ++ exit 0 ++fi ++ ++mkdir -p $BUILD_DIR ++cd $BUILD_DIR ++ ++cmake \ ++ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ ++ -DCMAKE_CUDA_ARCHITECTURES="NATIVE" \ ++ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ ++ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ ++ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ ++ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ ++ ${EXTRA_CMAKE_ARGS} \ ++ ../ ++ ++make -j30 +diff --git a/cmake/thirdparty/fetch_rapids.cmake b/cmake/thirdparty/fetch_rapids.cmake +new file mode 100644 +index 0000000..11d2403 +--- /dev/null ++++ b/cmake/thirdparty/fetch_rapids.cmake +@@ -0,0 +1,21 @@ ++# ============================================================================= ++# Copyright (c) 2023, NVIDIA CORPORATION. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except ++# in compliance with the License. You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software distributed under the License ++# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express ++# or implied. See the License for the specific language governing permissions and limitations under ++# the License. ++ ++# Use this variable to update RAPIDS and RAFT versions ++set(RAPIDS_VERSION "24.06") ++ ++if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) ++ file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake ++ ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) ++endif() ++include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) +diff --git a/cmake/thirdparty/get_raft.cmake b/cmake/thirdparty/get_raft.cmake +new file mode 100644 +index 0000000..6128b5c +--- /dev/null ++++ b/cmake/thirdparty/get_raft.cmake +@@ -0,0 +1,63 @@ ++# ============================================================================= ++# Copyright (c) 2023, NVIDIA CORPORATION. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except ++# in compliance with the License. You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software distributed under the License ++# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express ++# or implied. See the License for the specific language governing permissions and limitations under ++# the License. ++ ++# Use RAPIDS_VERSION from cmake/thirdparty/fetch_rapids.cmake ++set(RAFT_VERSION "${RAPIDS_VERSION}") ++set(RAFT_FORK "rapidsai") ++set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") ++ ++function(find_and_configure_raft) ++ set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES) ++ cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" ++ "${multiValueArgs}" ${ARGN} ) ++ ++ set(RAFT_COMPONENTS "") ++ if(PKG_COMPILE_LIBRARY) ++ string(APPEND RAFT_COMPONENTS " compiled") ++ endif() ++ ++ if(PKG_ENABLE_MNMG_DEPENDENCIES) ++ string(APPEND RAFT_COMPONENTS " distributed") ++ endif() ++ ++ #----------------------------------------------------- ++ # Invoke CPM find_package() ++ #----------------------------------------------------- ++ rapids_cpm_find(raft ${PKG_VERSION} ++ GLOBAL_TARGETS raft::raft ++ BUILD_EXPORT_SET raft-template-exports ++ INSTALL_EXPORT_SET raft-template-exports ++ COMPONENTS ${RAFT_COMPONENTS} ++ CPM_ARGS ++ GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git ++ GIT_TAG ${PKG_PINNED_TAG} ++ SOURCE_SUBDIR cpp ++ OPTIONS ++ "BUILD_TESTS OFF" ++ "BUILD_PRIMS_BENCH OFF" ++ "BUILD_ANN_BENCH OFF" ++ "RAFT_NVTX ${ENABLE_NVTX}" ++ "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" ++ ) ++endfunction() ++ ++# Change pinned tag here to test a commit in CI ++# To use a different RAFT locally, set the CMake variable ++# CPM_raft_SOURCE=/path/to/local/raft ++find_and_configure_raft(VERSION ${RAFT_VERSION}.00 ++ FORK ${RAFT_FORK} ++ PINNED_TAG ${RAFT_PINNED_TAG} ++ COMPILE_LIBRARY ON ++ ENABLE_MNMG_DEPENDENCIES OFF ++ ENABLE_NVTX OFF ++) +diff --git a/include/distance.h b/include/distance.h +index f3b1de2..4e92738 100644 +--- a/include/distance.h ++++ b/include/distance.h +@@ -77,6 +77,7 @@ class DistanceCosineInt8 : public Distance + DistanceCosineInt8() : Distance(diskann::Metric::COSINE) + { + } ++ // using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const; + }; + +@@ -86,6 +87,7 @@ class DistanceL2Int8 : public Distance + DistanceL2Int8() : Distance(diskann::Metric::L2) + { + } ++ // using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size) const; + }; + +@@ -96,6 +98,7 @@ class AVXDistanceL2Int8 : public Distance + AVXDistanceL2Int8() : Distance(diskann::Metric::L2) + { + } ++ // using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const; + }; + +@@ -105,6 +108,7 @@ class DistanceCosineFloat : public Distance + DistanceCosineFloat() : Distance(diskann::Metric::COSINE) + { + } ++ // using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; + }; + +@@ -114,6 +118,7 @@ class DistanceL2Float : public Distance + DistanceL2Float() : Distance(diskann::Metric::L2) + { + } ++ // using Distance::compare; + + #ifdef _WINDOWS + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const; +@@ -128,6 +133,7 @@ class AVXDistanceL2Float : public Distance + AVXDistanceL2Float() : Distance(diskann::Metric::L2) + { + } ++ // using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; + }; + +@@ -146,6 +152,7 @@ class SlowDistanceCosineUInt8 : public Distance + SlowDistanceCosineUInt8() : Distance(diskann::Metric::COSINE) + { + } ++ using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length) const; + }; + +@@ -155,6 +162,7 @@ class DistanceL2UInt8 : public Distance + DistanceL2UInt8() : Distance(diskann::Metric::L2) + { + } ++ // using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size) const; + }; + +@@ -170,6 +178,8 @@ template class DistanceInnerProduct : public Distance + } + inline float inner_product(const T *a, const T *b, unsigned size) const; + ++ // using Distance::compare; ++ + inline float compare(const T *a, const T *b, unsigned size) const + { + float result = inner_product(a, b, size); +@@ -198,6 +208,7 @@ class AVXDistanceInnerProductFloat : public Distance + AVXDistanceInnerProductFloat() : Distance(diskann::Metric::INNER_PRODUCT) + { + } ++ using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; + }; + +@@ -213,6 +224,7 @@ class AVXNormalizedCosineDistanceFloat : public Distance + AVXNormalizedCosineDistanceFloat() : Distance(diskann::Metric::COSINE) + { + } ++ using Distance::compare; + DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const + { + // Inner product returns negative values to indicate distance. +diff --git a/include/index.h b/include/index.h +index b9bf4f3..bd2c07e 100644 +--- a/include/index.h ++++ b/include/index.h +@@ -29,6 +29,11 @@ + #define EXPAND_IF_FULL 0 + #define DEFAULT_MAXC 750 + ++// namespace raft::neighbors::cagra{ ++// template ++// class index; ++// } ++ + namespace diskann + { + +@@ -66,7 +71,7 @@ template clas + const size_t num_frozen_pts = 0, const bool dynamic_index = false, + const bool enable_tags = false, const bool concurrent_consolidate = false, + const bool pq_dist_build = false, const size_t num_pq_chunks = 0, +- const bool use_opq = false, const bool filtered_index = false); ++ const bool use_opq = false, const bool filtered_index = false, const bool raft_cagra_index = false, const std::shared_ptr raft_cagra_index_params = nullptr); + + DISKANN_DLLEXPORT ~Index(); + +@@ -236,6 +241,9 @@ template clas + Index(const Index &) = delete; + Index &operator=(const Index &) = delete; + ++ // Build the raft CAGRA index ++ void build_raft_cagra_index(const T* data); ++ + // Use after _data and _nd have been populated + // Acquire exclusive _update_lock before calling + void build_with_data_populated(const std::vector &tags); +@@ -286,6 +294,8 @@ template clas + // Acquire exclusive _update_lock before calling + void link(); + ++ void add_raft_cagra_nbrs(); ++ + // Acquire exclusive _tag_lock and _delete_lock before calling + int reserve_location(); + +@@ -444,5 +454,11 @@ template clas + std::vector _locks; + + static const float INDEX_GROWTH_FACTOR; ++ ++ // optional around the Raft Cagra index ++ // raft::neighbors::cagra::index* raft_knn_index; ++ bool _raft_cagra_index = false; ++ std::shared_ptr _raft_cagra_index_params = nullptr; ++ std::vector host_cagra_graph; + }; + } // namespace diskann +diff --git a/include/index_config.h b/include/index_config.h +index 452498b..c9110da 100644 +--- a/include/index_config.h ++++ b/include/index_config.h +@@ -1,5 +1,10 @@ + #include "common_includes.h" + #include "parameters.h" ++#include ++ ++namespace raft::neighbors::cagra{ ++struct index_params; ++} + + namespace diskann + { +@@ -41,18 +46,23 @@ struct IndexConfig + // Params for searching index + std::shared_ptr index_search_params; + ++ bool raft_cagra_index; ++ std::shared_ptr raft_cagra_index_params; ++ + private: + IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension, + size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags, +- bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index, ++ bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index, bool raft_cagra_index, + std::string &data_type, const std::string &tag_type, const std::string &label_type, + std::shared_ptr index_write_params, +- std::shared_ptr index_search_params) ++ std::shared_ptr index_search_params, ++ std::shared_ptr raft_cagra_index_params ++ ) + : data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension), + max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build), +- concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index), ++ concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index), raft_cagra_index(raft_cagra_index), + num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type), +- data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params) ++ data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params), raft_cagra_index_params{raft_cagra_index_params} + { + } + +@@ -194,6 +204,18 @@ class IndexConfigBuilder + return *this; + } + ++ IndexConfigBuilder &is_raft_cagra_index(bool is_raft_cagra_index) ++ { ++ this->_raft_cagra_index = is_raft_cagra_index; ++ return *this; ++ } ++ ++ IndexConfigBuilder &with_raft_cagra_index_params(std::shared_ptr raft_cagra_index_params_ptr) ++ { ++ this->_raft_cagra_index_params = raft_cagra_index_params_ptr; ++ return *this; ++ } ++ + IndexConfig build() + { + if (_data_type == "" || _data_type.empty()) +@@ -218,9 +240,9 @@ class IndexConfigBuilder + } + + return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks, +- _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate, ++ _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate, _raft_cagra_index, + _use_opq, _filtered_index, _data_type, _tag_type, _label_type, _index_write_params, +- _index_search_params); ++ _index_search_params, _raft_cagra_index_params); + } + + IndexConfigBuilder(const IndexConfigBuilder &) = delete; +@@ -240,6 +262,7 @@ class IndexConfigBuilder + bool _concurrent_consolidate = false; + bool _use_opq = false; + bool _filtered_index{defaults::HAS_LABELS}; ++ bool _raft_cagra_index = false; + + size_t _num_pq_chunks = 0; + size_t _num_frozen_pts{defaults::NUM_FROZEN_POINTS_STATIC}; +@@ -250,5 +273,6 @@ class IndexConfigBuilder + + std::shared_ptr _index_write_params; + std::shared_ptr _index_search_params; ++ std::shared_ptr _raft_cagra_index_params; + }; + } // namespace diskann +diff --git a/include/index_factory.h b/include/index_factory.h +index 80bc40d..138adcb 100644 +--- a/include/index_factory.h ++++ b/include/index_factory.h +@@ -46,4 +46,4 @@ class IndexFactory + std::unique_ptr _config; + }; + +-} // namespace diskann ++} // namespace diskann +\ No newline at end of file +diff --git a/include/utils.h b/include/utils.h +index d3af5c3..2cb2181 100644 +--- a/include/utils.h ++++ b/include/utils.h +@@ -1,4 +1,4 @@ +-// Copyright (c) Microsoft Corporation. All rights reserved. ++// Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT license. + + #pragma once +@@ -29,6 +29,7 @@ typedef int FileHandle; + #include "types.h" + #include "tag_uint128.h" + #include ++#include + + #ifdef EXEC_ENV_OLS + #include "content_buf.h" +diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt +index cbca264..5bec876 100644 +--- a/src/CMakeLists.txt ++++ b/src/CMakeLists.txt +@@ -2,14 +2,14 @@ + #Licensed under the MIT license. + + set(CMAKE_CXX_STANDARD 17) +-set(CMAKE_COMPILE_WARNING_AS_ERROR ON) ++set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) + + if(MSVC) + add_subdirectory(dll) + else() + #file(GLOB CPP_SOURCES *.cpp) + set(CPP_SOURCES abstract_data_store.cpp ann_exception.cpp disk_utils.cpp +- distance.cpp index.cpp in_mem_graph_store.cpp in_mem_data_store.cpp ++ distance.cpp index.cu in_mem_graph_store.cpp in_mem_data_store.cpp + linux_aligned_file_reader.cpp math_utils.cpp natural_number_map.cpp + in_mem_data_store.cpp in_mem_graph_store.cpp + natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp +@@ -19,6 +19,9 @@ else() + endif() + add_library(${PROJECT_NAME} ${CPP_SOURCES}) + add_library(${PROJECT_NAME}_s STATIC ${CPP_SOURCES}) ++ ++ target_link_libraries(${PROJECT_NAME} PRIVATE raft::raft raft::compiled) ++ target_link_libraries(${PROJECT_NAME}_s PRIVATE raft::raft raft::compiled) + endif() + + if (NOT MSVC) +diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt +index 096d1b7..e36fe7c 100644 +--- a/src/dll/CMakeLists.txt ++++ b/src/dll/CMakeLists.txt +@@ -2,7 +2,7 @@ + #Licensed under the MIT license. + + add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp +- ../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp ++ ../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cu + ../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + +@@ -32,4 +32,4 @@ foreach(RUNTIME_FILE ${RUNTIME_FILES_TO_COPY}) + add_custom_command(TARGET ${PROJECT_NAME} + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy "${RUNTIME_FILE}" "${TARGET_DIR}") +-endforeach() +\ No newline at end of file ++endforeach() +diff --git a/src/index.cpp b/src/index.cu +similarity index 96% +rename from src/index.cpp +rename to src/index.cu +index bf93344..0f9571d 100644 +--- a/src/index.cpp ++++ b/src/index.cu +@@ -1,6 +1,7 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT license. + ++#include + #include + + #include +@@ -22,11 +23,27 @@ + #endif + + #include "index.h" ++#include ++#include + + #define MAX_POINTS_FOR_USING_BITSET 10000000 + + namespace diskann + { ++ ++raft::distance::DistanceType parse_metric_to_raft(diskann::Metric m) ++{ ++ switch (m) ++ { ++ case diskann::Metric::L2: ++ return raft::distance::DistanceType::L2Expanded; ++ case diskann::Metric::INNER_PRODUCT: ++ return raft::distance::DistanceType::InnerProduct; ++ default: ++ throw ANNException("ERROR: RAFT only supports L2 and INNER_PRODUCT.", -1, __FUNCSIG__, __FILE__, __LINE__); ++ } ++} ++ + // Initialize an index with metric m, load the data of type T with filename + // (bin), and initialize max_points + template +@@ -38,7 +55,8 @@ Index::Index(const IndexConfig &index_config, std::shared_ptr), _conc_consolidate(index_config.concurrent_consolidate) ++ _delete_set(new tsl::robin_set), _conc_consolidate(index_config.concurrent_consolidate), ++ _raft_cagra_index(index_config.raft_cagra_index) + { + if (_dynamic_index && !_enable_tags) + { +@@ -109,6 +127,21 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrget_dims()); + } + } ++ ++ if (_raft_cagra_index) ++ { ++ if (index_config.raft_cagra_index_params != nullptr) ++ { ++ assert(parse_metric_to_raft(_dist_metric) == raft_cagra_index_params->metric); ++ _raft_cagra_index_params = index_config.raft_cagra_index_params; ++ } ++ else ++ { ++ raft::neighbors::cagra::index_params raft_cagra_index_params; ++ raft_cagra_index_params.metric = parse_metric_to_raft(_dist_metric); ++ _raft_cagra_index_params = std::make_shared(raft_cagra_index_params); ++ } ++ } + } + + template +@@ -117,7 +150,8 @@ Index::Index(Metric m, const size_t dim, const size_t max_point + const std::shared_ptr index_search_params, const size_t num_frozen_pts, + const bool dynamic_index, const bool enable_tags, const bool concurrent_consolidate, + const bool pq_dist_build, const size_t num_pq_chunks, const bool use_opq, +- const bool filtered_index) ++ const bool filtered_index, const bool raft_cagra_index, ++ const std::shared_ptr raft_cagra_index_params) + : Index( + IndexConfigBuilder() + .with_metric(m) +@@ -134,6 +168,8 @@ Index::Index(Metric m, const size_t dim, const size_t max_point + .is_use_opq(use_opq) + .is_filtered(filtered_index) + .with_data_type(diskann_type_to_name()) ++ .is_raft_cagra_index(raft_cagra_index) ++ .with_raft_cagra_index_params(raft_cagra_index_params) + .build(), + IndexFactory::construct_datastore(DataStoreStrategy::MEMORY, + (max_points == 0 ? (size_t)1 : max_points) + +@@ -732,6 +768,7 @@ template int Index + + template uint32_t Index::calculate_entry_point() + { ++ std::cout << "inside calculate entry point" << std::endl; + // REFACTOR TODO: This function does not support multi-threaded calculation of medoid. + // Must revisit if perf is a concern. + return _data_store->calculate_medoid(); +@@ -739,6 +776,7 @@ template uint32_t Index std::vector Index::get_init_ids() + { ++ // std::cout << "num_frozen_pts" << _num_frozen_pts << std::endl; + std::vector init_ids; + init_ids.reserve(1 + _num_frozen_pts); + +@@ -839,6 +877,8 @@ std::pair Index::iterate_to_fixed_point( + _pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch); + }; + ++ // raft::print_host_vector("init_ids", init_ids.data(), init_ids.size(), std::cout); ++ + // Initialize the candidate pool with starting points + for (auto id : init_ids) + { +@@ -1371,6 +1411,56 @@ template void Index void Index::add_raft_cagra_nbrs() ++{ ++ uint32_t num_threads = _indexingThreads; ++ if (num_threads != 0) ++ omp_set_num_threads(num_threads); ++ ++ assert(_num_frozen_pts == 0); ++ ++ /* visit_order is a vector that is initialized to the entire graph */ ++ std::vector visit_order; ++ tsl::robin_set visited; ++ visit_order.reserve(_nd + _num_frozen_pts); ++ for (uint32_t i = 0; i < (uint32_t)_nd; i++) ++ { ++ visit_order.emplace_back(i); ++ } ++ ++ // if there are frozen points, the first such one is set to be the _start ++ if (_num_frozen_pts > 0) ++ _start = (uint32_t)_max_points; ++ else ++ _start = calculate_entry_point(); ++ ++#pragma omp parallel for schedule(dynamic, 2048) ++ for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); node_ctr++) ++ { ++ auto node = visit_order[node_ctr]; ++ ++ std::vector cagra_nbrs(_indexingRange); ++ uint32_t *nbr_start_ptr = host_cagra_graph.data() + node * _indexingRange; ++ uint32_t *nbr_end_ptr = nbr_start_ptr + _indexingRange; ++ std::copy(nbr_start_ptr, nbr_end_ptr, cagra_nbrs.data()); ++ ++ assert(cagra_nbrs.size() > 0); ++ ++ { ++ LockGuard guard(_locks[node]); ++ ++ _graph_store->set_neighbours(node, cagra_nbrs); ++ assert(_graph_store->get_neighbours((location_t)node).size() <= _indexingRange); ++ } ++ ++ if (node_ctr % 100000 == 0) ++ { ++ diskann::cout << "\r" << (100.0 * node_ctr) / (visit_order.size()) << "% of index build completed." ++ << std::flush; ++ } ++ } ++} ++ + template + void Index::prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion_size, + const float alpha) +@@ -1448,8 +1538,6 @@ void Index::set_start_points(const T *data, size_t data_count) + if (data_count != _num_frozen_pts * _dim) + throw ANNException("Invalid number of points", -1, __FUNCSIG__, __FILE__, __LINE__); + +- // memcpy(_data + _aligned_dim * _max_points, data, _aligned_dim * +- // sizeof(T) * _num_frozen_pts); + for (location_t i = 0; i < _num_frozen_pts; i++) + { + _data_store->set_vector((location_t)(i + _max_points), data + i * _dim); +@@ -1505,6 +1593,24 @@ void Index::set_start_points_at_random(T radius, uint32_t rando + set_start_points(points_data.data(), points_data.size()); + } + ++template void Index::build_raft_cagra_index(const T *data) ++{ ++ raft::device_resources handle; ++ auto dataset_view = raft::make_host_matrix_view(data, int64_t(_nd), _dim); ++ auto raft_knn_index = raft::neighbors::cagra::build(handle, *_raft_cagra_index_params, dataset_view); ++ ++ auto stream = handle.get_stream(); ++ auto device_graph = raft_knn_index.graph(); ++ host_cagra_graph.resize(device_graph.extent(0) * device_graph.extent(1)); ++ ++ std::cout << "host_cagra_graph_size" << host_cagra_graph.size() << std::endl; ++ ++ thrust::copy(thrust::device_ptr(device_graph.data_handle()), ++ thrust::device_ptr(device_graph.data_handle() + device_graph.size()), ++ host_cagra_graph.data()); ++ handle.sync_stream(); ++} ++ + template + void Index::build_with_data_populated(const std::vector &tags) + { +@@ -1542,7 +1648,14 @@ void Index::build_with_data_populated(const std::vector & + } + + generate_frozen_point(); +- link(); ++ if (_raft_cagra_index) ++ { ++ add_raft_cagra_nbrs(); ++ } ++ else ++ { ++ link(); ++ } + + size_t max = 0, min = SIZE_MAX, total = 0, cnt = 0; + for (size_t i = 0; i < _nd; i++) +@@ -1559,6 +1672,7 @@ void Index::build_with_data_populated(const std::vector & + + _has_built = true; + } ++ + template + void Index::_build(const DataType &data, const size_t num_points_to_load, TagVector &tags) + { +@@ -1597,6 +1711,8 @@ void Index::build(const T *data, const size_t num_points_to_loa + _data_store->populate_data(data, (location_t)num_points_to_load); + } + ++ build_raft_cagra_index(data); ++ + build_with_data_populated(tags); + } + +@@ -1683,6 +1799,9 @@ void Index::build(const char *filename, const size_t num_points + std::unique_lock tl(_tag_lock); + _nd = num_points_to_load; + } ++ ++ auto _in_mem_data_store = std::static_pointer_cast>(_data_store); ++ build_raft_cagra_index(_in_mem_data_store->_data); + build_with_data_populated(tags); + } + +diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt +index 6af8405..a44caab 100644 +--- a/tests/CMakeLists.txt ++++ b/tests/CMakeLists.txt +@@ -38,4 +38,3 @@ add_executable(${PROJECT_NAME}_unit_tests ${DISKANN_SOURCES} ${DISKANN_UNIT_TEST + target_link_libraries(${PROJECT_NAME}_unit_tests ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::unit_test_framework) + + add_test(NAME ${PROJECT_NAME}_unit_tests COMMAND ${PROJECT_NAME}_unit_tests) +- diff --git a/cpp/cmake/thirdparty/get_diskann.cmake b/cpp/cmake/thirdparty/get_diskann.cmake index 5cbae0ec18..0309d18c3b 100644 --- a/cpp/cmake/thirdparty/get_diskann.cmake +++ b/cpp/cmake/thirdparty/get_diskann.cmake @@ -18,21 +18,33 @@ function(find_and_configure_diskann) set(oneValueArgs VERSION REPOSITORY PINNED_TAG) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) + + set(patch_files_to_run "${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/diskann.diff") + set(patch_issues_to_ref "fix compile issues") + set(patch_script "${CMAKE_BINARY_DIR}/rapids-cmake/patches/diskann/patch.cmake") + set(log_file "${CMAKE_BINARY_DIR}/rapids-cmake/patches/diskann/log") + string(TIMESTAMP current_year "%Y" UTC) + configure_file(${rapids-cmake-dir}/cpm/patches/command_template.cmake.in "${patch_script}" + @ONLY) - rapids_cpm_find(diskann ${PKG_VERSION} + rapids_cpm_find(diskann GLOBAL_TARGETS diskann::diskann CPM_ARGS GIT_REPOSITORY ${PKG_REPOSITORY} GIT_TAG ${PKG_PINNED_TAG} ) + + if(TARGET diskann AND NOT TARGET diskann::diskann) + add_library(diskann::diskann ALIAS diskann) + endif() endfunction() if(NOT RAFT_DISKANN_GIT_TAG) - set(RAFT_DISKANN_GIT_TAG cagra_int) + set(RAFT_DISKANN_GIT_TAG main) endif() if(NOT RAFT_DISKANN_GIT_REPOSITORY) - # set(RAFT_FAISS_GIT_REPOSITORY https://github.com/tarang-jain/DiskANN.git) + set(RAFT_FAISS_GIT_REPOSITORY https://github.com/microsoft/DiskANN.git) endif() find_and_configure_diskann(VERSION 0.7.0 diff --git a/rapids_config.cmake b/rapids_config.cmake index c8077f7f4b..a40d7130c0 100644 --- a/rapids_config.cmake +++ b/rapids_config.cmake @@ -22,13 +22,15 @@ else() string(REPLACE "\n" "\n " _rapids_version_formatted " ${_rapids_version}") message( FATAL_ERROR - "Could not determine RAPIDS version. Contents of VERSION file:\n${_rapids_version_formatted}") + "Could not determine RAPIDS version. Contents of VERSION file:\n${_rapids_version_formatted}" + ) endif() if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake") file( DOWNLOAD "https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION_MAJOR_MINOR}/RAPIDS.cmake" - "${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake") + "${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake" + ) endif() include("${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake")