From 9ad76faeec798a45c359fb358c414f1d3a19eb3f Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 28 Oct 2023 02:47:25 +0200 Subject: [PATCH] Adding `throughput` and `latency` modes to `raft-ann-bench` (#1920) Separating the way the benhcmarks are measured into `throughput` and `latency` modes. - `latency` mode accumulates the times for each batch to be processed and then estimates QPS and provides the average time spent doing processing on the GPU. For batch size of 1, this becomes a fairly estimate of average latency per query. For larger batches, it becomes a fairly accurate estimate of time spent per batch. - `throughput` mode pipelines the individual batches using a thread pool (and stream pool for the GPU algos). For both smaller and larger batches, this gives a good estimate of the amount of data we can push through the hardware in a period of time. A good comprehensive comparison will include both of these numbers. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1920 --- cpp/bench/ann/CMakeLists.txt | 6 +- cpp/bench/ann/src/common/ann_types.hpp | 17 +- cpp/bench/ann/src/common/benchmark.hpp | 296 ++++++++++++------ cpp/bench/ann/src/common/thread_pool.hpp | 2 + cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h | 29 +- cpp/bench/ann/src/raft/raft_benchmark.cu | 10 +- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 12 +- .../ann/src/raft/raft_ivf_flat_wrapper.h | 13 +- cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 14 +- docs/source/raft_ann_benchmarks.md | 37 +++ .../src/raft-ann-bench/run/__main__.py | 13 + 11 files changed, 299 insertions(+), 150 deletions(-) diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 502f371a25..d6a5fddb98 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -106,10 +106,8 @@ if(RAFT_ANN_BENCH_USE_GGNN) endif() if(RAFT_ANN_BENCH_USE_FAISS) - # We need to ensure that faiss has all the conda - # information. So we currently use the very ugly - # hammer of `link_libraries` to ensure that all - # targets in this directory and the faiss directory + # We need to ensure that faiss has all the conda information. So we currently use the very ugly + # hammer of `link_libraries` to ensure that all targets in this directory and the faiss directory # will have the conda includes/link dirs link_libraries($) include(cmake/thirdparty/get_faiss.cmake) diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index 33716bd45a..2c1105a272 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -24,6 +24,11 @@ namespace raft::bench::ann { +enum Objective { + THROUGHPUT, // See how many vectors we can push through + LATENCY // See how fast we can push a vector through +}; + enum class MemoryType { Host, HostMmap, @@ -59,10 +64,17 @@ inline auto parse_memory_type(const std::string& memory_type) -> MemoryType } } -struct AlgoProperty { +class AlgoProperty { + public: + inline AlgoProperty() {} + inline AlgoProperty(MemoryType dataset_memory_type_, MemoryType query_memory_type_) + : dataset_memory_type(dataset_memory_type_), query_memory_type(query_memory_type_) + { + } MemoryType dataset_memory_type; // neighbors/distances should have same memory type as queries MemoryType query_memory_type; + virtual ~AlgoProperty() = default; }; class AnnBase { @@ -79,7 +91,8 @@ template class ANN : public AnnBase { public: struct AnnSearchParam { - virtual ~AnnSearchParam() = default; + Objective metric_objective = Objective::LATENCY; + virtual ~AnnSearchParam() = default; [[nodiscard]] virtual auto needs_dataset() const -> bool { return false; }; }; diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 4ec977700d..10a256bd63 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -36,6 +37,7 @@ namespace raft::bench::ann { static inline std::unique_ptr current_algo{nullptr}; +static inline std::shared_ptr current_algo_props{nullptr}; using kv_series = std::vector>>; @@ -153,7 +155,7 @@ void bench_build(::benchmark::State& state, } } state.counters.insert( - {{"GPU Time", gpu_timer.total_time() / state.iterations()}, {"index_size", index_size}}); + {{"GPU", gpu_timer.total_time() / state.iterations()}, {"index_size", index_size}}); if (state.skipped()) { return; } make_sure_parent_dir_exists(index.file); @@ -162,12 +164,19 @@ void bench_build(::benchmark::State& state, template void bench_search(::benchmark::State& state, - std::shared_ptr> dataset, Configuration::Index index, - std::size_t search_param_ix) + std::size_t search_param_ix, + std::shared_ptr> dataset, + Objective metric_objective) { + std::ptrdiff_t batch_offset = 0; + std::size_t queries_processed = 0; + + double total_time = 0; + const auto& sp_json = index.search_params[search_param_ix]; - dump_parameters(state, sp_json); + + if (state.thread_index() == 0) { dump_parameters(state, sp_json); } // NB: `k` and `n_queries` are guaranteed to be populated in conf.cpp const std::uint32_t k = sp_json["k"]; @@ -180,129 +189,168 @@ void bench_search(::benchmark::State& state, state.SkipWithError("Index file is missing. Run the benchmark in the build mode first."); return; } - // algo is static to cache it between close search runs to save time on index loading - static std::string index_file = ""; - if (index.file != index_file) { - current_algo.reset(); - index_file = index.file; - } - ANN* algo; - std::unique_ptr::AnnSearchParam> search_param; - try { - if (!current_algo || (algo = dynamic_cast*>(current_algo.get())) == nullptr) { - auto ualgo = ann::create_algo( - index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); - algo = ualgo.get(); - algo->load(index_file); - current_algo = std::move(ualgo); - } - search_param = ann::create_search_param(index.algo, sp_json); - } catch (const std::exception& e) { - return state.SkipWithError("Failed to create an algo: " + std::string(e.what())); - } - algo->set_search_param(*search_param); - const auto algo_property = parse_algo_property(algo->get_preference(), sp_json); - const T* query_set = dataset->query_set(algo_property.query_memory_type); - buf distances{algo_property.query_memory_type, k * query_set_size}; - buf neighbors{algo_property.query_memory_type, k * query_set_size}; + /** + * Make sure the first thread loads the algo and dataset + */ + if (state.thread_index() == 0) { + // algo is static to cache it between close search runs to save time on index loading + static std::string index_file = ""; + if (index.file != index_file) { + current_algo.reset(); + index_file = index.file; + } - if (search_param->needs_dataset()) { + std::unique_ptr::AnnSearchParam> search_param; + ANN* algo; try { - algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type), - dataset->base_set_size()); - } catch (const std::exception& ex) { - state.SkipWithError("The algorithm '" + index.name + - "' requires the base set, but it's not available. " + - "Exception: " + std::string(ex.what())); - return; + if (!current_algo || (algo = dynamic_cast*>(current_algo.get())) == nullptr) { + auto ualgo = ann::create_algo( + index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); + algo = ualgo.get(); + algo->load(index_file); + current_algo = std::move(ualgo); + } + search_param = ann::create_search_param(index.algo, sp_json); + search_param->metric_objective = metric_objective; + } catch (const std::exception& e) { + state.SkipWithError("Failed to create an algo: " + std::string(e.what())); + } + algo->set_search_param(*search_param); + auto algo_property = parse_algo_property(algo->get_preference(), sp_json); + current_algo_props = std::make_shared(algo_property.dataset_memory_type, + algo_property.query_memory_type); + if (search_param->needs_dataset()) { + try { + algo->set_search_dataset(dataset->base_set(current_algo_props->dataset_memory_type), + dataset->base_set_size()); + } catch (const std::exception& ex) { + state.SkipWithError("The algorithm '" + index.name + + "' requires the base set, but it's not available. " + + "Exception: " + std::string(ex.what())); + return; + } } } - std::ptrdiff_t batch_offset = 0; - std::size_t queries_processed = 0; + const auto algo_property = *current_algo_props; + const T* query_set = dataset->query_set(algo_property.query_memory_type); + + /** + * Each thread will manage its own outputs + */ + std::shared_ptr> distances = + std::make_shared>(algo_property.query_memory_type, k * query_set_size); + std::shared_ptr> neighbors = + std::make_shared>(algo_property.query_memory_type, k * query_set_size); + + auto start = std::chrono::high_resolution_clock::now(); cuda_timer gpu_timer; { nvtx_case nvtx{state.name()}; + + // TODO: Have the odd threads load the queries backwards just to rule out caching. + ANN* algo = dynamic_cast*>(current_algo.get()); for (auto _ : state) { - // measure the GPU time using the RAII helper [[maybe_unused]] auto ntx_lap = nvtx.lap(); [[maybe_unused]] auto gpu_lap = gpu_timer.lap(); + + auto start = std::chrono::high_resolution_clock::now(); // run the search try { algo->search(query_set + batch_offset * dataset->dim(), n_queries, k, - neighbors.data + batch_offset * k, - distances.data + batch_offset * k, + neighbors->data + batch_offset * k, + distances->data + batch_offset * k, gpu_timer.stream()); } catch (const std::exception& e) { state.SkipWithError(std::string(e.what())); } + + auto end = std::chrono::high_resolution_clock::now(); + + auto elapsed_seconds = std::chrono::duration_cast>(end - start); // advance to the next batch batch_offset = (batch_offset + n_queries) % query_set_size; queries_processed += n_queries; + state.SetIterationTime(elapsed_seconds.count()); + total_time += elapsed_seconds.count(); } } + auto end = std::chrono::high_resolution_clock::now(); + if (state.thread_index() == 0) { + auto duration = std::chrono::duration_cast>(end - start).count(); + state.counters.insert({{"end_to_end", duration}}); + } state.SetItemsProcessed(queries_processed); - state.counters.insert({{"k", k}, {"n_queries", n_queries}}); if (cudart.found()) { - state.counters.insert({{"GPU Time", gpu_timer.total_time() / state.iterations()}, - {"GPU QPS", queries_processed / gpu_timer.total_time()}}); + state.counters.insert({{"GPU", gpu_timer.total_time() / double(state.iterations())}}); } + + // This will be the total number of queries across all threads + state.counters.insert({{"total_queries", queries_processed}}); + if (state.skipped()) { return; } - // evaluate recall - if (dataset->max_k() >= k) { - const std::int32_t* gt = dataset->gt_set(); - const std::uint32_t max_k = dataset->max_k(); - buf neighbors_host = neighbors.move(MemoryType::Host); - - std::size_t rows = std::min(queries_processed, query_set_size); - std::size_t match_count = 0; - std::size_t total_count = rows * static_cast(k); - for (std::size_t i = 0; i < rows; i++) { - for (std::uint32_t j = 0; j < k; j++) { - auto act_idx = std::int32_t(neighbors_host.data[i * k + j]); - for (std::uint32_t l = 0; l < k; l++) { - auto exp_idx = gt[i * max_k + l]; - if (act_idx == exp_idx) { - match_count++; - break; + // Use the last thread as a sanity check that all the threads are working. + if (state.thread_index() == state.threads() - 1) { + // evaluate recall + if (dataset->max_k() >= k) { + const std::int32_t* gt = dataset->gt_set(); + const std::uint32_t max_k = dataset->max_k(); + buf neighbors_host = neighbors->move(MemoryType::Host); + std::size_t rows = std::min(queries_processed, query_set_size); + std::size_t match_count = 0; + std::size_t total_count = rows * static_cast(k); + for (std::size_t i = 0; i < rows; i++) { + for (std::uint32_t j = 0; j < k; j++) { + auto act_idx = std::int32_t(neighbors_host.data[i * k + j]); + for (std::uint32_t l = 0; l < k; l++) { + auto exp_idx = gt[i * max_k + l]; + if (act_idx == exp_idx) { + match_count++; + break; + } } } } + double actual_recall = static_cast(match_count) / static_cast(total_count); + state.counters.insert({{"Recall", actual_recall}}); } - double actual_recall = static_cast(match_count) / static_cast(total_count); - state.counters.insert({{"Recall", actual_recall}}); } } inline void printf_usage() { ::benchmark::PrintDefaultHelp(); - fprintf( - stdout, - " [--build|--search] \n" - " [--overwrite]\n" - " [--data_prefix=]\n" - " [--index_prefix=]\n" - " [--override_kv=]\n" - " .json\n" - "\n" - "Note the non-standard benchmark parameters:\n" - " --build: build mode, will build index\n" - " --search: search mode, will search using the built index\n" - " one and only one of --build and --search should be specified\n" - " --overwrite: force overwriting existing index files\n" - " --data_prefix=:" - " prepend to dataset file paths specified in the .json (default = 'data/').\n" - " --index_prefix=:" - " prepend to index file paths specified in the .json (default = 'index/').\n" - " --override_kv=:" - " override a build/search key one or more times multiplying the number of configurations;" - " you can use this parameter multiple times to get the Cartesian product of benchmark" - " configs.\n"); + fprintf(stdout, + " [--build|--search] \n" + " [--overwrite]\n" + " [--data_prefix=]\n" + " [--index_prefix=]\n" + " [--override_kv=]\n" + " [--mode=\n" + " .json\n" + "\n" + "Note the non-standard benchmark parameters:\n" + " --build: build mode, will build index\n" + " --search: search mode, will search using the built index\n" + " one and only one of --build and --search should be specified\n" + " --overwrite: force overwriting existing index files\n" + " --data_prefix=:" + " prepend to dataset file paths specified in the .json (default = " + "'data/').\n" + " --index_prefix=:" + " prepend to index file paths specified in the .json (default = " + "'index/').\n" + " --override_kv=:" + " override a build/search key one or more times multiplying the number of configurations;" + " you can use this parameter multiple times to get the Cartesian product of benchmark" + " configs.\n" + " --mode=" + " run the benchmarks in latency (accumulate times spent in each batch) or " + " throughput (pipeline batches and measure end-to-end) mode\n"); } template @@ -319,22 +367,41 @@ void register_build(std::shared_ptr> dataset, auto* b = ::benchmark::RegisterBenchmark( index.name + suf, bench_build, dataset, index, force_overwrite); b->Unit(benchmark::kSecond); + b->MeasureProcessCPUTime(); b->UseRealTime(); } } template void register_search(std::shared_ptr> dataset, - std::vector indices) + std::vector indices, + Objective metric_objective) { for (auto index : indices) { for (std::size_t i = 0; i < index.search_params.size(); i++) { auto suf = static_cast(index.search_params[i]["override_suffix"]); index.search_params[i].erase("override_suffix"); - auto* b = - ::benchmark::RegisterBenchmark(index.name + suf, bench_search, dataset, index, i); - b->Unit(benchmark::kMillisecond); - b->UseRealTime(); + + int max_threads = + metric_objective == Objective::THROUGHPUT ? std::thread::hardware_concurrency() : 1; + + auto* b = ::benchmark::RegisterBenchmark( + index.name + suf, bench_search, index, i, dataset, metric_objective) + ->Unit(benchmark::kMillisecond) + ->ThreadRange(1, max_threads) + + /** + * The following are important for getting accuracy QPS measurements on both CPU + * and GPU These make sure that + * - `end_to_end` ~ (`Time` * `Iterations`) + * - `items_per_second` ~ (`total_queries` / `end_to_end`) + * - `Time` = `end_to_end` / `Iterations` + * + * - Latency = `Time` + * - Throughput = `items_per_second` + */ + ->MeasureProcessCPUTime() + ->UseRealTime(); } } } @@ -346,7 +413,8 @@ void dispatch_benchmark(const Configuration& conf, bool search_mode, std::string data_prefix, std::string index_prefix, - kv_series override_kv) + kv_series override_kv, + Objective metric_objective) { if (cudart.found()) { for (auto [key, value] : cuda_info()) { @@ -403,7 +471,8 @@ void dispatch_benchmark(const Configuration& conf, } else { log_warn( "Ground truth file is not provided; the recall won't be reported. NB: use " - "the 'groundtruth_neighbors_file' alongside the 'query_file' key to specify the path to " + "the 'groundtruth_neighbors_file' alongside the 'query_file' key to specify the " + "path to " "the ground truth in your conf.json."); } } else { @@ -414,7 +483,7 @@ void dispatch_benchmark(const Configuration& conf, index.search_params = apply_overrides(index.search_params, override_kv); index.file = combine_path(index_prefix, index.file); } - register_search(dataset, indices); + register_search(dataset, indices, metric_objective); } } @@ -445,6 +514,7 @@ inline auto run_main(int argc, char** argv) -> int std::string data_prefix = "data"; std::string index_prefix = "index"; std::string new_override_kv = ""; + std::string mode = "latency"; kv_series override_kv{}; char arg0_default[] = "benchmark"; // NOLINT @@ -467,6 +537,7 @@ inline auto run_main(int argc, char** argv) -> int parse_bool_flag(argv[i], "--search", search_mode) || parse_string_flag(argv[i], "--data_prefix", data_prefix) || parse_string_flag(argv[i], "--index_prefix", index_prefix) || + parse_string_flag(argv[i], "--mode", mode) || parse_string_flag(argv[i], "--override_kv", new_override_kv)) { if (!new_override_kv.empty()) { auto kvv = split(new_override_kv, ':'); @@ -486,6 +557,9 @@ inline auto run_main(int argc, char** argv) -> int } } + Objective metric_objective = Objective::LATENCY; + if (mode == "throughput") { metric_objective = Objective::THROUGHPUT; } + if (build_mode == search_mode) { log_error("One and only one of --build and --search should be specified"); printf_usage(); @@ -505,14 +579,32 @@ inline auto run_main(int argc, char** argv) -> int std::string dtype = conf.get_dataset_conf().dtype; if (dtype == "float") { - dispatch_benchmark( - conf, force_overwrite, build_mode, search_mode, data_prefix, index_prefix, override_kv); + dispatch_benchmark(conf, + force_overwrite, + build_mode, + search_mode, + data_prefix, + index_prefix, + override_kv, + metric_objective); } else if (dtype == "uint8") { - dispatch_benchmark( - conf, force_overwrite, build_mode, search_mode, data_prefix, index_prefix, override_kv); + dispatch_benchmark(conf, + force_overwrite, + build_mode, + search_mode, + data_prefix, + index_prefix, + override_kv, + metric_objective); } else if (dtype == "int8") { - dispatch_benchmark( - conf, force_overwrite, build_mode, search_mode, data_prefix, index_prefix, override_kv); + dispatch_benchmark(conf, + force_overwrite, + build_mode, + search_mode, + data_prefix, + index_prefix, + override_kv, + metric_objective); } else { log_error("datatype '%s' is not supported", dtype.c_str()); return -1; @@ -522,8 +614,8 @@ inline auto run_main(int argc, char** argv) -> int if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return -1; ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); - // Release a possibly cached ANN object, so that it cannot be alive longer than the handle to a - // shared library it depends on (dynamic benchmark executable). + // Release a possibly cached ANN object, so that it cannot be alive longer than the handle + // to a shared library it depends on (dynamic benchmark executable). current_algo.reset(); return 0; } diff --git a/cpp/bench/ann/src/common/thread_pool.hpp b/cpp/bench/ann/src/common/thread_pool.hpp index efea938d5b..c01fa2c32c 100644 --- a/cpp/bench/ann/src/common/thread_pool.hpp +++ b/cpp/bench/ann/src/common/thread_pool.hpp @@ -72,6 +72,7 @@ class FixedThreadPool { template void submit(Func f, IdxT len) { + // Run functions in main thread if thread pool has no threads if (threads_.empty()) { for (IdxT i = 0; i < len; ++i) { f(i); @@ -84,6 +85,7 @@ class FixedThreadPool { const IdxT items_per_thread = len / (num_threads + 1); std::atomic cnt(items_per_thread * num_threads); + // Wrap function auto wrapped_f = [&](IdxT start, IdxT end) { for (IdxT i = start; i < end; ++i) { f(i); diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index df44605493..23cae6352c 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -65,7 +65,7 @@ class HnswLib : public ANN { using typename ANN::AnnSearchParam; struct SearchParam : public AnnSearchParam { int ef; - int num_threads = omp_get_num_procs(); + int num_threads = 1; }; HnswLib(Metric metric, int dim, const BuildParam& param); @@ -103,6 +103,7 @@ class HnswLib : public ANN { int m_; int num_threads_; std::unique_ptr thread_pool_; + Objective metric_objective_; }; template @@ -159,10 +160,13 @@ void HnswLib::build(const T* dataset, size_t nrow, cudaStream_t) template void HnswLib::set_search_param(const AnnSearchParam& param_) { - auto param = dynamic_cast(param_); - appr_alg_->ef_ = param.ef; + auto param = dynamic_cast(param_); + appr_alg_->ef_ = param.ef; + metric_objective_ = param.metric_objective; - if (!thread_pool_ || num_threads_ != param.num_threads) { + bool use_pool = (metric_objective_ == Objective::LATENCY && param.num_threads > 1) && + (!thread_pool_ || num_threads_ != param.num_threads); + if (use_pool) { num_threads_ = param.num_threads; thread_pool_ = std::make_unique(num_threads_); } @@ -172,12 +176,17 @@ template void HnswLib::search( const T* query, int batch_size, int k, size_t* indices, float* distances, cudaStream_t) const { - thread_pool_->submit( - [&](int i) { - // hnsw can only handle a single vector at a time. - get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k); - }, - batch_size); + auto f = [&](int i) { + // hnsw can only handle a single vector at a time. + get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k); + }; + if (metric_objective_ == Objective::LATENCY) { + thread_pool_->submit(f, batch_size); + } else { + for (int i = 0; i < batch_size; i++) { + f(i); + } + } } template diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index fa20c5c223..3b9bcc7e15 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -272,5 +272,13 @@ REGISTER_ALGO_INSTANCE(std::uint8_t); #ifdef ANN_BENCH_BUILD_MAIN #include "../common/benchmark.hpp" -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } +int main(int argc, char** argv) +{ + rmm::mr::cuda_memory_resource cuda_mr; + // Construct a resource that uses a coalescing best-fit pool allocator + rmm::mr::pool_memory_resource pool_mr{&cuda_mr}; + rmm::mr::set_current_device_resource( + &pool_mr); // Updates the current device resource pointer to `pool_mr` + return raft::bench::ann::run_main(argc, argv); +} #endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 19c5151186..f1c8154b7c 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -53,17 +53,13 @@ class RaftCagra : public ANN { using BuildParam = raft::neighbors::cagra::index_params; RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) - : ANN(metric, dim), - index_params_(param), - dimension_(dim), - mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) + : ANN(metric, dim), index_params_(param), dimension_(dim), handle_(cudaStreamPerThread) { - rmm::mr::set_current_device_resource(&mr_); index_params_.metric = parse_metric_type(metric); RAFT_CUDA_TRY(cudaGetDevice(&device_)); } - ~RaftCagra() noexcept { rmm::mr::set_current_device_resource(mr_.get_upstream()); } + ~RaftCagra() noexcept {} void build(const T* dataset, size_t nrow, cudaStream_t stream) final; @@ -92,8 +88,6 @@ class RaftCagra : public ANN { void load(const std::string&) override; private: - // `mr_` must go first to make sure it dies last - rmm::mr::pool_memory_resource mr_; raft::device_resources handle_; BuildParam index_params_; raft::neighbors::cagra::search_params search_params_; @@ -170,7 +164,7 @@ void RaftCagra::search( neighbors_IdxT, batch_size * k, raft::cast_op(), - resource::get_cuda_stream(handle_)); + raft::resource::get_cuda_stream(handle_)); } handle_.sync_stream(); diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index b6df7de068..24b3c69bb6 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -52,18 +52,14 @@ class RaftIvfFlatGpu : public ANN { using BuildParam = raft::neighbors::ivf_flat::index_params; RaftIvfFlatGpu(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim), - index_params_(param), - dimension_(dim), - mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) + : ANN(metric, dim), index_params_(param), dimension_(dim) { index_params_.metric = parse_metric_type(metric); index_params_.conservative_memory_allocation = true; - rmm::mr::set_current_device_resource(&mr_); RAFT_CUDA_TRY(cudaGetDevice(&device_)); } - ~RaftIvfFlatGpu() noexcept { rmm::mr::set_current_device_resource(mr_.get_upstream()); } + ~RaftIvfFlatGpu() noexcept {} void build(const T* dataset, size_t nrow, cudaStream_t stream) final; @@ -90,8 +86,6 @@ class RaftIvfFlatGpu : public ANN { void load(const std::string&) override; private: - // `mr_` must go first to make sure it dies last - rmm::mr::pool_memory_resource mr_; raft::device_resources handle_; BuildParam index_params_; raft::neighbors::ivf_flat::search_params search_params_; @@ -134,10 +128,9 @@ template void RaftIvfFlatGpu::search( const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const { - rmm::mr::device_memory_resource* mr_ptr = &const_cast(this)->mr_; static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); raft::neighbors::ivf_flat::search( - handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances, mr_ptr); + handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances); resource::sync_stream(handle_); return; } diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 1b74dcf975..e4004b0007 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -54,22 +54,14 @@ class RaftIvfPQ : public ANN { using BuildParam = raft::neighbors::ivf_pq::index_params; RaftIvfPQ(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim), - index_params_(param), - dimension_(dim), - mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) + : ANN(metric, dim), index_params_(param), dimension_(dim) { - rmm::mr::set_current_device_resource(&mr_); index_params_.metric = parse_metric_type(metric); RAFT_CUDA_TRY(cudaGetDevice(&device_)); RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming)); } - ~RaftIvfPQ() noexcept - { - RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); - rmm::mr::set_current_device_resource(mr_.get_upstream()); - } + ~RaftIvfPQ() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); } void build(const T* dataset, size_t nrow, cudaStream_t stream) final; @@ -97,8 +89,6 @@ class RaftIvfPQ : public ANN { void load(const std::string&) override; private: - // `mr_` must go first to make sure it dies last - rmm::mr::pool_memory_resource mr_; raft::device_resources handle_; cudaEvent_t sync_{nullptr}; BuildParam index_params_; diff --git a/docs/source/raft_ann_benchmarks.md b/docs/source/raft_ann_benchmarks.md index fadca595fb..6a436a7213 100644 --- a/docs/source/raft_ann_benchmarks.md +++ b/docs/source/raft_ann_benchmarks.md @@ -16,6 +16,7 @@ This project provides a benchmark program for various ANN search implementations - [End to end: small-scale (<1M to 10M)](#end-to-end-small-scale-benchmarks-1m-to-10m) - [End to end: large-scale (>10M)](#end-to-end-large-scale-benchmarks-10m-vectors) - [Running with Docker containers](#running-with-docker-containers) + - [Evaluating the results](#evaluating-the-results) - [Creating and customizing dataset configurations](#creating-and-customizing-dataset-configurations) - [Adding a new ANN algorithm](#adding-a-new-ann-algorithm) - [Parameter tuning guide](https://docs.rapids.ai/api/raft/nightly/ann_benchmarks_param_tuning/) @@ -141,6 +142,10 @@ options: run only comma separated list of named algorithms (default: None) --indices INDICES run only comma separated list of named indices. parameter `algorithms` is ignored (default: None) -f, --force re-run algorithms even if their results already exist (default: False) + -m MODE, --search-mode MODE + run search in 'latency' (measure individual batches) or + 'throughput' (pipeline batches and measure end-to-end) mode. + (default: 'latency') ``` `configuration` and `dataset` : `configuration` is a path to a configuration file for a given dataset. @@ -355,6 +360,38 @@ This will drop you into a command line in the container, with the `raft-ann-benc Additionally, the containers can be run in detached mode without any issue. + +### Evaluating the results + +The benchmarks capture several different measurements. The table below describes each of the measurements for index build benchmarks: + +| Name | Description | +|------------|--------------------------------------------------------| +| Benchmark | A name that uniquely identifies the benchmark instance | +| Time | Wall-time spent training the index | +| CPU | CPU time spent training the index | +| Iterations | Number of iterations (this is usually 1) | +| GPU | GPU time spent building | +| index_size | Number of vectors used to train index | + + +The table below describes each of the measurements for the index search benchmarks: + +| Name | Description | +|------|-------------------------------------------------------------------------------------------------------------------------------------------------------| +| Benchmark | A name that uniquely identifies the benchmark instance | +| Time | The average runtime for each batch. This is approximately `end_to_end` / `Iterations` | +| CPU | The average `wall-time`. In `throughput` mode, this is the average `wall-time` spent in each thread. | +| Iterations | Total number of batches. This is going to be `total_queres` / `n_queries` | +| Recall | Proportion of correct neighbors to ground truth neighbors. Note this column is only present if groundtruth file is specified in dataset configuration | +| items_per_second | Total throughput. This is approximately `total_queries` / `end_to_end`. | +| k | Number of neighbors being queried in each iteration | +| end_to_end | Total time taken to run all batches for all iterations | +| n_queries | Total number of query vectors in each batch | +| total_queries | Total number of vectors queries across all iterations | + +Note that the actual table displayed on the screen may differ slightly as the hyper-parameters will also be displayed for each different combination being benchmarked. + ## Creating and customizing dataset configurations A single configuration file will often define a set of algorithms, with associated index and search parameters, for a specific dataset. A configuration file uses json format with 4 major parts: diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py index a0d4fabb77..30d642f3ac 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py @@ -78,6 +78,7 @@ def run_build_and_search( search, k, batch_size, + mode="throughput", ): for executable, ann_executable_path, algo in executables_to_run.keys(): # Need to write temporary configuration @@ -104,6 +105,7 @@ def run_build_and_search( "--build", "--data_prefix=" + dataset_path, "--benchmark_out_format=json", + "--benchmark_counters_tabular=true", "--benchmark_out=" + f"{os.path.join(build_folder, f'{algo}.json')}", ] @@ -126,6 +128,7 @@ def run_build_and_search( "--benchmark_out_format=json", "--benchmark_out=" + f"{os.path.join(search_folder, f'{algo}.json')}", + "--mode=%s" % mode, ] if force: cmd = cmd + ["--overwrite"] @@ -211,6 +214,14 @@ def main(): action="store_true", ) + parser.add_argument( + "-m", + "--search-mode", + help="run search in 'latency' (measure individual batches) or " + "'throughput' (pipeline batches and measure end-to-end) mode", + default="throughput", + ) + args = parser.parse_args() # If both build and search are not provided, @@ -222,6 +233,7 @@ def main(): build = args.build search = args.search + mode = args.search_mode k = args.count batch_size = args.batch_size @@ -316,6 +328,7 @@ def main(): search, k, batch_size, + mode, )