From 0f6ce2223b847b2992692763a07af9654e5feeb3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 16 Nov 2023 15:16:03 +0100 Subject: [PATCH] ANN bench options to specify CAGRA graph and dataset locations (#1896) CAGRA index can be constructed using existing device_mdarrays, in which case just reference to these arrays are stored. This way allocations are managed outside the index, and we can customize how these allocations are made. This PR - modifies the CAGRA ANN bench wrapper to manage the allocations locally, - add options for the json file to specify whether the graph / dataset is allocated in device / host_pinned / host_huge_page memory. Authors: - Corey J. Nolet (https://github.com/cjnolet) - Tamas Bela Feher (https://github.com/tfeher) - Robert Maynard (https://github.com/robertmaynard) Approvers: - Micka (https://github.com/lowener) - Divye Gala (https://github.com/divyegala) - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1896 --- cpp/bench/ann/src/common/benchmark.hpp | 20 ++- .../src/common/cuda_huge_page_resource.hpp | 132 ++++++++++++++++++ .../ann/src/common/cuda_pinned_resource.hpp | 130 +++++++++++++++++ .../src/raft/raft_ann_bench_param_parser.h | 21 +++ cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 99 ++++++++++++- cpp/include/raft/neighbors/cagra_types.hpp | 27 +--- .../raft/neighbors/detail/cagra/utils.hpp | 34 +++++ cpp/include/raft/util/integer_utils.hpp | 2 +- docs/source/ann_benchmarks_param_tuning.md | 9 +- .../src/raft-ann-bench/run/__main__.py | 87 +++++++++++- 10 files changed, 522 insertions(+), 39 deletions(-) create mode 100644 cpp/bench/ann/src/common/cuda_huge_page_resource.hpp create mode 100644 cpp/bench/ann/src/common/cuda_pinned_resource.hpp diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 7db5eab194..a2e77323c1 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -21,6 +21,7 @@ #include "util.hpp" #include +#include #include #include @@ -131,7 +132,7 @@ void bench_build(::benchmark::State& state, log_info("Overwriting file: %s", index.file.c_str()); } else { return state.SkipWithMessage( - "Index file already exists (use --overwrite to overwrite the index)."); + "Index file already exists (use --force to overwrite the index)."); } } @@ -380,7 +381,7 @@ inline void printf_usage() ::benchmark::PrintDefaultHelp(); fprintf(stdout, " [--build|--search] \n" - " [--overwrite]\n" + " [--force]\n" " [--data_prefix=]\n" " [--index_prefix=]\n" " [--override_kv=]\n" @@ -392,7 +393,7 @@ inline void printf_usage() " --build: build mode, will build index\n" " --search: search mode, will search using the built index\n" " one and only one of --build and --search should be specified\n" - " --overwrite: force overwriting existing index files\n" + " --force: force overwriting existing index files\n" " --data_prefix=:" " prepend to dataset file paths specified in the .json (default = " "'data/').\n" @@ -572,6 +573,8 @@ inline auto run_main(int argc, char** argv) -> int std::string mode = "latency"; std::string threads_arg_txt = ""; std::vector threads = {1, -1}; // min_thread, max_thread + std::string log_level_str = ""; + int raft_log_level = raft::logger::get(RAFT_NAME).get_level(); kv_series override_kv{}; char arg0_default[] = "benchmark"; // NOLINT @@ -589,14 +592,19 @@ inline auto run_main(int argc, char** argv) -> int std::ifstream conf_stream(conf_path); for (int i = 1; i < argc; i++) { - if (parse_bool_flag(argv[i], "--overwrite", force_overwrite) || + if (parse_bool_flag(argv[i], "--force", force_overwrite) || parse_bool_flag(argv[i], "--build", build_mode) || parse_bool_flag(argv[i], "--search", search_mode) || parse_string_flag(argv[i], "--data_prefix", data_prefix) || parse_string_flag(argv[i], "--index_prefix", index_prefix) || parse_string_flag(argv[i], "--mode", mode) || parse_string_flag(argv[i], "--override_kv", new_override_kv) || - parse_string_flag(argv[i], "--threads", threads_arg_txt)) { + parse_string_flag(argv[i], "--threads", threads_arg_txt) || + parse_string_flag(argv[i], "--raft_log_level", log_level_str)) { + if (!log_level_str.empty()) { + raft_log_level = std::stoi(log_level_str); + log_level_str = ""; + } if (!threads_arg_txt.empty()) { auto threads_arg = split(threads_arg_txt, ':'); threads[0] = std::stoi(threads_arg[0]); @@ -625,6 +633,8 @@ inline auto run_main(int argc, char** argv) -> int } } + raft::logger::get(RAFT_NAME).set_level(raft_log_level); + Objective metric_objective = Objective::LATENCY; if (mode == "throughput") { metric_objective = Objective::THROUGHPUT; } diff --git a/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp b/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp new file mode 100644 index 0000000000..9132db7c04 --- /dev/null +++ b/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include + +namespace raft::mr { +/** + * @brief `device_memory_resource` derived class that uses mmap to allocate memory. + * This class enables memory allocation using huge pages. + * It is assumed that the allocated memory is directly accessible on device. This currently only + * works on GH systems. + * + * TODO(tfeher): consider improving or removing this helper once we made progress with + * https://github.com/rapidsai/raft/issues/1819 + */ +class cuda_huge_page_resource final : public rmm::mr::device_memory_resource { + public: + cuda_huge_page_resource() = default; + ~cuda_huge_page_resource() override = default; + cuda_huge_page_resource(cuda_huge_page_resource const&) = default; + cuda_huge_page_resource(cuda_huge_page_resource&&) = default; + cuda_huge_page_resource& operator=(cuda_huge_page_resource const&) = default; + cuda_huge_page_resource& operator=(cuda_huge_page_resource&&) = default; + + /** + * @brief Query whether the resource supports use of non-null CUDA streams for + * allocation/deallocation. `cuda_huge_page_resource` does not support streams. + * + * @returns bool false + */ + [[nodiscard]] bool supports_streams() const noexcept override { return false; } + + /** + * @brief Query whether the resource supports the get_mem_info API. + * + * @return true + */ + [[nodiscard]] bool supports_get_mem_info() const noexcept override { return true; } + + private: + /** + * @brief Allocates memory of size at least `bytes` using cudaMalloc. + * + * The returned pointer has at least 256B alignment. + * + * @note Stream argument is ignored + * + * @throws `rmm::bad_alloc` if the requested allocation could not be fulfilled + * + * @param bytes The size, in bytes, of the allocation + * @return void* Pointer to the newly allocated memory + */ + void* do_allocate(std::size_t bytes, rmm::cuda_stream_view) override + { + void* _addr{nullptr}; + _addr = mmap(NULL, bytes, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (_addr == MAP_FAILED) { RAFT_FAIL("huge_page_resource::MAP FAILED"); } + if (madvise(_addr, bytes, MADV_HUGEPAGE) == -1) { + munmap(_addr, bytes); + RAFT_FAIL("huge_page_resource::madvise MADV_HUGEPAGE"); + } + memset(_addr, 0, bytes); + return _addr; + } + + /** + * @brief Deallocate memory pointed to by \p p. + * + * @note Stream argument is ignored. + * + * @throws Nothing. + * + * @param p Pointer to be deallocated + */ + void do_deallocate(void* ptr, std::size_t size, rmm::cuda_stream_view) override + { + if (munmap(ptr, size) == -1) { RAFT_FAIL("huge_page_resource::munmap"); } + } + + /** + * @brief Compare this resource to another. + * + * Two cuda_huge_page_resources always compare equal, because they can each + * deallocate memory allocated by the other. + * + * @throws Nothing. + * + * @param other The other resource to compare to + * @return true If the two resources are equivalent + * @return false If the two resources are not equal + */ + [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override + { + return dynamic_cast(&other) != nullptr; + } + + /** + * @brief Get free and available memory for memory resource + * + * @throws `rmm::cuda_error` if unable to retrieve memory info. + * + * @return std::pair contaiing free_size and total_size of memory + */ + [[nodiscard]] std::pair do_get_mem_info( + rmm::cuda_stream_view) const override + { + std::size_t free_size{}; + std::size_t total_size{}; + RMM_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size)); + return std::make_pair(free_size, total_size); + } +}; +} // namespace raft::mr \ No newline at end of file diff --git a/cpp/bench/ann/src/common/cuda_pinned_resource.hpp b/cpp/bench/ann/src/common/cuda_pinned_resource.hpp new file mode 100644 index 0000000000..28ca691f86 --- /dev/null +++ b/cpp/bench/ann/src/common/cuda_pinned_resource.hpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include +#include + +#include + +namespace raft::mr { +/** + * @brief `device_memory_resource` derived class that uses cudaMallocHost/Free for + * allocation/deallocation. + * + * This is almost the same as rmm::mr::host::pinned_memory_resource, but it has + * device_memory_resource as base class. Pinned memory can be accessed from device, + * and using this allocator we can create device_mdarray backed by pinned allocator. + * + * TODO(tfeher): it would be preferred to just rely on the existing allocator from rmm + * (pinned_memory_resource), but that is incompatible with the container_policy class + * for device matrix, because the latter expects a device_memory_resource. We shall + * revise this once we progress with Issue https://github.com/rapidsai/raft/issues/1819 + */ +class cuda_pinned_resource final : public rmm::mr::device_memory_resource { + public: + cuda_pinned_resource() = default; + ~cuda_pinned_resource() override = default; + cuda_pinned_resource(cuda_pinned_resource const&) = default; + cuda_pinned_resource(cuda_pinned_resource&&) = default; + cuda_pinned_resource& operator=(cuda_pinned_resource const&) = default; + cuda_pinned_resource& operator=(cuda_pinned_resource&&) = default; + + /** + * @brief Query whether the resource supports use of non-null CUDA streams for + * allocation/deallocation. `cuda_pinned_resource` does not support streams. + * + * @returns bool false + */ + [[nodiscard]] bool supports_streams() const noexcept override { return false; } + + /** + * @brief Query whether the resource supports the get_mem_info API. + * + * @return true + */ + [[nodiscard]] bool supports_get_mem_info() const noexcept override { return true; } + + private: + /** + * @brief Allocates memory of size at least `bytes` using cudaMalloc. + * + * The returned pointer has at least 256B alignment. + * + * @note Stream argument is ignored + * + * @throws `rmm::bad_alloc` if the requested allocation could not be fulfilled + * + * @param bytes The size, in bytes, of the allocation + * @return void* Pointer to the newly allocated memory + */ + void* do_allocate(std::size_t bytes, rmm::cuda_stream_view) override + { + void* ptr{nullptr}; + RMM_CUDA_TRY_ALLOC(cudaMallocHost(&ptr, bytes)); + return ptr; + } + + /** + * @brief Deallocate memory pointed to by \p p. + * + * @note Stream argument is ignored. + * + * @throws Nothing. + * + * @param p Pointer to be deallocated + */ + void do_deallocate(void* ptr, std::size_t, rmm::cuda_stream_view) override + { + RMM_ASSERT_CUDA_SUCCESS(cudaFreeHost(ptr)); + } + + /** + * @brief Compare this resource to another. + * + * Two cuda_pinned_resources always compare equal, because they can each + * deallocate memory allocated by the other. + * + * @throws Nothing. + * + * @param other The other resource to compare to + * @return true If the two resources are equivalent + * @return false If the two resources are not equal + */ + [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override + { + return dynamic_cast(&other) != nullptr; + } + + /** + * @brief Get free and available memory for memory resource + * + * @throws `rmm::cuda_error` if unable to retrieve memory info. + * + * @return std::pair contaiing free_size and total_size of memory + */ + [[nodiscard]] std::pair do_get_mem_info( + rmm::cuda_stream_view) const override + { + std::size_t free_size{}; + std::size_t total_size{}; + RMM_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size)); + return std::make_pair(free_size, total_size); + } +}; +} // namespace raft::mr \ No newline at end of file diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 479a90e3b5..1eb0e53cc5 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -206,6 +206,21 @@ void parse_build_param(const nlohmann::json& conf, } } +raft::bench::ann::AllocatorType parse_allocator(std::string mem_type) +{ + if (mem_type == "device") { + return raft::bench::ann::AllocatorType::Device; + } else if (mem_type == "host_pinned") { + return raft::bench::ann::AllocatorType::HostPinned; + } else if (mem_type == "host_huge_page") { + return raft::bench::ann::AllocatorType::HostHugePage; + } + THROW( + "Invalid value for memory type %s, must be one of [\"device\", \"host_pinned\", " + "\"host_huge_page\"", + mem_type.c_str()); +} + template void parse_search_param(const nlohmann::json& conf, typename raft::bench::ann::RaftCagra::SearchParam& param) @@ -227,5 +242,11 @@ void parse_search_param(const nlohmann::json& conf, THROW("Invalid value for algo: %s", tmp.c_str()); } } + if (conf.contains("graph_memory_type")) { + param.graph_mem = parse_allocator(conf.at("graph_memory_type")); + } + if (conf.contains("internal_dataset_memory_type")) { + param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type")); + } } #endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index bf526101be..a3e481ec5a 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -42,8 +42,15 @@ #include "raft_ann_bench_utils.h" #include +#include "../common/cuda_huge_page_resource.hpp" +#include "../common/cuda_pinned_resource.hpp" + +#include +#include + namespace raft::bench::ann { +enum class AllocatorType { HostPinned, HostHugePage, Device }; template class RaftCagra : public ANN { public: @@ -51,6 +58,8 @@ class RaftCagra : public ANN { struct SearchParam : public AnnSearchParam { raft::neighbors::experimental::cagra::search_params p; + AllocatorType graph_mem = AllocatorType::Device; + AllocatorType dataset_mem = AllocatorType::Device; auto needs_dataset() const -> bool override { return true; } }; @@ -64,7 +73,16 @@ class RaftCagra : public ANN { }; RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) - : ANN(metric, dim), index_params_(param), dimension_(dim), handle_(cudaStreamPerThread) + : ANN(metric, dim), + index_params_(param), + dimension_(dim), + handle_(cudaStreamPerThread), + need_dataset_update_(true), + dataset_(make_device_matrix(handle_, 0, 0)), + graph_(make_device_matrix(handle_, 0, 0)), + input_dataset_v_(nullptr, 0, 0), + graph_mem_(AllocatorType::Device), + dataset_mem_(AllocatorType::Device) { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); @@ -101,12 +119,28 @@ class RaftCagra : public ANN { void save_to_hnswlib(const std::string& file) const; private: + inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type) + { + switch (mem_type) { + case (AllocatorType::HostPinned): return &mr_pinned_; + case (AllocatorType::HostHugePage): return &mr_huge_page_; + default: return rmm::mr::get_current_device_resource(); + } + } + raft ::mr::cuda_pinned_resource mr_pinned_; + raft ::mr::cuda_huge_page_resource mr_huge_page_; raft::device_resources handle_; + AllocatorType graph_mem_; + AllocatorType dataset_mem_; BuildParam index_params_; + bool need_dataset_update_; raft::neighbors::cagra::search_params search_params_; std::optional> index_; int device_; int dimension_; + raft::device_matrix graph_; + raft::device_matrix dataset_; + raft::device_matrix_view input_dataset_v_; }; template @@ -127,18 +161,77 @@ void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t) return; } +inline std::string allocator_to_string(AllocatorType mem_type) +{ + if (mem_type == AllocatorType::Device) { + return "device"; + } else if (mem_type == AllocatorType::HostPinned) { + return "host_pinned"; + } else if (mem_type == AllocatorType::HostHugePage) { + return "host_huge_page"; + } + return ""; +} + template void RaftCagra::set_search_param(const AnnSearchParam& param) { auto search_param = dynamic_cast(param); search_params_ = search_param.p; + if (search_param.graph_mem != graph_mem_) { + // Move graph to correct memory space + graph_mem_ = search_param.graph_mem; + RAFT_LOG_INFO("moving graph to new memory space: %s", allocator_to_string(graph_mem_).c_str()); + // We create a new graph and copy to it from existing graph + auto mr = get_mr(graph_mem_); + auto new_graph = make_device_mdarray( + handle_, mr, make_extents(index_->graph().extent(0), index_->graph_degree())); + + raft::copy(new_graph.data_handle(), + index_->graph().data_handle(), + index_->graph().size(), + resource::get_cuda_stream(handle_)); + + index_->update_graph(handle_, make_const_mdspan(new_graph.view())); + // update_graph() only stores a view in the index. We need to keep the graph object alive. + graph_ = std::move(new_graph); + } + + if (search_param.dataset_mem != dataset_mem_ || need_dataset_update_) { + dataset_mem_ = search_param.dataset_mem; + + // First free up existing memory + dataset_ = make_device_matrix(handle_, 0, 0); + index_->update_dataset(handle_, make_const_mdspan(dataset_.view())); + + // Allocate space using the correct memory resource. + RAFT_LOG_INFO("moving dataset to new memory space: %s", + allocator_to_string(dataset_mem_).c_str()); + + auto mr = get_mr(dataset_mem_); + raft::neighbors::cagra::detail::copy_with_padding(handle_, dataset_, input_dataset_v_, mr); + + index_->update_dataset(handle_, make_const_mdspan(dataset_.view())); + + // Ideally, instead of dataset_.view(), we should pass a strided matrix view to update. + // See Issue https://github.com/rapidsai/raft/issues/1972 for details. + // auto dataset_view = make_device_strided_matrix_view( + // dataset_.data_handle(), dataset_.extent(0), this->dim_, dataset_.extent(1)); + // index_->update_dataset(handle_, dataset_view); + need_dataset_update_ = false; + } } template void RaftCagra::set_search_dataset(const T* dataset, size_t nrow) { - index_->update_dataset(handle_, - raft::make_host_matrix_view(dataset, nrow, this->dim_)); + // It can happen that we are re-using a previous algo object which already has + // the dataset set. Check if we need update. + if (static_cast(input_dataset_v_.extent(0)) != nrow || + input_dataset_v_.data_handle() != dataset) { + input_dataset_v_ = make_device_matrix_view(dataset, nrow, this->dim_); + need_dataset_update_ = true; + } } template diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 4db08110b9..e8a0b8a7bd 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -330,32 +331,8 @@ struct index : ann::index { void copy_padded(raft::resources const& res, mdspan, row_major, data_accessor> dataset) { - size_t padded_dim = round_up_safe(dataset.extent(1) * sizeof(T), 16) / sizeof(T); + detail::copy_with_padding(res, dataset_, dataset); - if ((dataset_.extent(0) != dataset.extent(0)) || - (static_cast(dataset_.extent(1)) != padded_dim)) { - // clear existing memory before allocating to prevent OOM errors on large datasets - if (dataset_.size()) { dataset_ = make_device_matrix(res, 0, 0); } - dataset_ = make_device_matrix(res, dataset.extent(0), padded_dim); - } - if (dataset_.extent(1) == dataset.extent(1)) { - raft::copy(dataset_.data_handle(), - dataset.data_handle(), - dataset.size(), - resource::get_cuda_stream(res)); - } else { - // copy with padding - RAFT_CUDA_TRY(cudaMemsetAsync( - dataset_.data_handle(), 0, dataset_.size() * sizeof(T), resource::get_cuda_stream(res))); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(dataset_.data_handle(), - sizeof(T) * dataset_.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.extent(1), - sizeof(T) * dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - } dataset_view_ = make_device_strided_matrix_view( dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1)); RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu", diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp index 22cbe6bbac..5e57a9589f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include namespace raft::neighbors::cagra::detail { @@ -245,4 +247,36 @@ class host_matrix_view_from_device { device_matrix_view device_view_; T* host_ptr; }; + +// Copy matrix src to dst. pad rows with 0 if necessary to make them 16 byte aligned. +template +void copy_with_padding(raft::resources const& res, + raft::device_matrix& dst, + mdspan, row_major, data_accessor> src, + rmm::mr::device_memory_resource* mr = nullptr) +{ + if (!mr) { mr = rmm::mr::get_current_device_resource(); } + size_t padded_dim = round_up_safe(src.extent(1) * sizeof(T), 16) / sizeof(T); + + if ((dst.extent(0) != src.extent(0)) || (static_cast(dst.extent(1)) != padded_dim)) { + // clear existing memory before allocating to prevent OOM errors on large datasets + if (dst.size()) { dst = make_device_matrix(res, 0, 0); } + dst = make_device_mdarray(res, mr, make_extents(src.extent(0), padded_dim)); + } + if (dst.extent(1) == src.extent(1)) { + raft::copy(dst.data_handle(), src.data_handle(), src.size(), resource::get_cuda_stream(res)); + } else { + // copy with padding + RAFT_CUDA_TRY(cudaMemsetAsync( + dst.data_handle(), 0, dst.size() * sizeof(T), resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), + sizeof(T) * dst.extent(1), + src.data_handle(), + sizeof(T) * src.extent(1), + sizeof(T) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + } +} } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/util/integer_utils.hpp b/cpp/include/raft/util/integer_utils.hpp index c14d4327c3..e9a386c426 100644 --- a/cpp/include/raft/util/integer_utils.hpp +++ b/cpp/include/raft/util/integer_utils.hpp @@ -36,7 +36,7 @@ namespace raft { * `modulus` is positive. */ template -inline S round_up_safe(S number_to_round, S modulus) +constexpr inline S round_up_safe(S number_to_round, S modulus) { auto remainder = number_to_round % modulus; if (remainder == 0) { return number_to_round; } diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index 4c95b9e520..dd74f030ad 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -53,12 +53,16 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of | `graph_degree` | `build_param` | N | Positive Integer >0 | 64 | Degree of the final kNN graph index. | | `intermediate_graph_degree` | `build_param` | N | Positive Integer >0 | 128 | Degree of the intermediate kNN graph. | | `graph_build_algo` | `build_param` | N | ["IVF_PQ", "NN_DESCENT"] | "IVF_PQ" | Algorithm to use for search | -| `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside? | +| `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside while constructing the index? | | `query_memory_type` | `search_params` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? | | `itopk` | `search_wdith` | N | Positive Integer >0 | 64 | Number of intermediate search results retained during the search. Higher values improve search accuracy at the cost of speed. | | `search_width` | `search_param` | N | Positive Integer >0 | 1 | Number of graph nodes to select as the starting point for the search in each iteration. | | `max_iterations` | `search_param` | N | Integer >=0 | 0 | Upper limit of search iterations. Auto select when 0. | | `algo` | `search_param` | N | string | "auto" | Algorithm to use for search. Possible values: {"auto", "single_cta", "multi_cta", "multi_kernel"} | +| `graph_memory_type` | `search_param` | N | string | "device" | Memory type to store gaph. Must be one of {"device", "host_pinned", "host_huge_page"}. | +| `internal_dataset_memory_type` | `search_param` | N | string | "device" | Memory type to store dataset in the index. Must be one of {"device", "host_pinned", "host_huge_page"}. | + +The `graph_memory_type` or `internal_dataset_memory_type` options can be useful for large datasets that do not fit the device memory. Setting `internal_dataset_memory_type` other than `device` has negative impact on search speed. Using `host_huge_page` option is only supported on systems with Heterogeneous Memory Management or on platforms that natively support GPU access to system allocated memory, for example Grace Hopper. To fine tune CAGRA index building we can customize IVF-PQ index builder options using the following settings. These take effect only if `graph_build_algo == "IVF_PQ"`. It is recommended to experiment using a separate IVF-PQ index to find the config that gives the largest QPS for large batch. Recall does not need to be very high, since CAGRA further optimizes the kNN neighbor graph. Some of the default values are derived from the dataset size which is assumed to be [n_vecs, dim]. @@ -76,6 +80,7 @@ To fine tune CAGRA index building we can customize IVF-PQ index builder options | `ivf_pq_search_refine_ratio` | `build_params` | N| Positive Number >=0 | 2 | `refine_ratio * k` nearest neighbors are queried from the index initially and an additional refinement step improves recall by selecting only the best `k` neighbors. | Alternatively, if `graph_build_algo == "NN_DESCENT"`, then we can customize the following parameters + | Parameter | Type | Required | Data Type | Default | Description | |-----------------------------|----------------|----------|----------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | `nn_descent_niter` | `build_param` | N | Positive Integer>0 | 20 | Number of NN Descent iterations. | @@ -170,4 +175,4 @@ Use FAISS IVF-PQ index on CPU | `ef` | `search_param` | Y | Positive Integer >0 | | Size of the dynamic list for the nearest neighbors used for search. Higher value leads to more accurate but slower search. Cannot be lower than `k`. | | `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. | -Please refer to [HNSW algorithm parameters guide] from `hnswlib` to learn more about these arguments. \ No newline at end of file +Please refer to [HNSW algorithm parameters guide](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) from `hnswlib` to learn more about these arguments. \ No newline at end of file diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py index a33467b554..4611f39264 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py @@ -25,6 +25,21 @@ import yaml +log_levels = { + "off": 0, + "error": 1, + "warn": 2, + "info": 3, + "debug": 4, + "trace": 5, +} + + +def parse_log_level(level_str): + if level_str not in log_levels: + raise ValueError("Invalid log level: %s" % level_str) + return log_levels[level_str.lower()] + def positive_int(input_str: str) -> int: try: @@ -39,6 +54,53 @@ def positive_int(input_str: str) -> int: return i +def merge_build_files(build_dir, build_file, temp_build_file): + + build_dict = {} + + # If build file exists, read it + build_json_path = os.path.join(build_dir, build_file) + tmp_build_json_path = os.path.join(build_dir, temp_build_file) + if os.path.isfile(build_json_path): + try: + with open(build_json_path, "r") as f: + build_dict = json.load(f) + except Exception as e: + print( + "Error loading existing build file: %s (%s)" + % (build_json_path, e) + ) + + temp_build_dict = {} + if os.path.isfile(tmp_build_json_path): + with open(tmp_build_json_path, "r") as f: + temp_build_dict = json.load(f) + else: + raise ValueError("Temp build file not found: %s" % tmp_build_json_path) + + tmp_benchmarks = ( + temp_build_dict["benchmarks"] + if "benchmarks" in temp_build_dict + else {} + ) + benchmarks = build_dict["benchmarks"] if "benchmarks" in build_dict else {} + + # If the build time is absolute 0 then an error occurred + final_bench_dict = {} + for b in benchmarks: + if b["real_time"] > 0: + final_bench_dict[b["name"]] = b + + for tmp_bench in tmp_benchmarks: + if tmp_bench["real_time"] > 0: + final_bench_dict[tmp_bench["name"]] = tmp_bench + + temp_build_dict["benchmarks"] = [v for k, v in final_bench_dict.items()] + with open(build_json_path, "w") as f: + json_str = json.dumps(temp_build_dict, indent=2) + f.write(json_str) + + def validate_algorithm(algos_conf, algo, gpu_present): algos_conf_keys = set(algos_conf.keys()) if gpu_present: @@ -88,6 +150,7 @@ def run_build_and_search( batch_size, search_threads, mode="throughput", + raft_log_level="info", ): for executable, ann_executable_path, algo in executables_to_run.keys(): # Need to write temporary configuration @@ -109,6 +172,8 @@ def run_build_and_search( if build: build_folder = os.path.join(legacy_result_folder, "build") os.makedirs(build_folder, exist_ok=True) + build_file = f"{algo}.json" + temp_build_file = f"{build_file}.lock" cmd = [ ann_executable_path, "--build", @@ -116,10 +181,11 @@ def run_build_and_search( "--benchmark_out_format=json", "--benchmark_counters_tabular=true", "--benchmark_out=" - + f"{os.path.join(build_folder, f'{algo}.json')}", + + f"{os.path.join(build_folder, temp_build_file)}", + "--raft_log_level=" + f"{parse_log_level(raft_log_level)}", ] if force: - cmd = cmd + ["--overwrite"] + cmd = cmd + ["--force"] cmd = cmd + [temp_conf_filename] if dry_run: @@ -129,9 +195,13 @@ def run_build_and_search( else: try: subprocess.run(cmd, check=True) + merge_build_files( + build_folder, build_file, temp_build_file + ) except Exception as e: print("Error occurred running benchmark: %s" % e) finally: + os.remove(os.path.join(build_folder, temp_build_file)) if not search: os.remove(temp_conf_filename) @@ -150,9 +220,10 @@ def run_build_and_search( "--mode=%s" % mode, "--benchmark_out=" + f"{os.path.join(search_folder, f'{algo}.json')}", + "--raft_log_level=" + f"{parse_log_level(raft_log_level)}", ] if force: - cmd = cmd + ["--overwrite"] + cmd = cmd + ["--force"] if search_threads: cmd = cmd + ["--threads=%s" % search_threads] @@ -294,6 +365,15 @@ def main(): "the command.", action="store_true", ) + parser.add_argument( + "--raft-log-level", + help="Log level, possible values are " + "[off, error, warn, info, debug, trace]. " + "Default: 'info'. Note that 'debug' or more detailed " + "logging level requires that the library is compiled with " + "-DRAFT_ACTIVE_LEVEL= where >= ", + default="info", + ) if len(sys.argv) == 1: parser.print_help() @@ -511,6 +591,7 @@ def add_algo_group(group_list): batch_size, args.search_threads, mode, + args.raft_log_level, )