From d9a7290b60d1037a7fbc00b4b6e5c371b8b86ca8 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Wed, 13 Dec 2023 17:30:36 +0100 Subject: [PATCH] Fix ann-bench multithreading (#2021) In the current state, ann-benchmarks running in the `--throughput` mode (multi-threaded) share ANN wrappers among CPU threads. This is not thread-safe and may result in incorrectly measured time (e.g. sharing cuda events among CPU threads) or various exceptions and segfaults (e.g. doing state-changing cublas calls from multiple CPU threads). This PR makes the search benchmarks copy ANN wrappers in each thread. The copies of the wrappers then selectively: - share thread-safe resources (e.g. rmm memory pool) and large objects that are not expected to change during search (e.g. index data); - duplicate the resources that are not thread-safe or carry the thread-specific state (e.g. cublas handles, CUDA events and streams). Alongside, the PR adds a few small changes, including: - enables ann-bench NVTX annotations for the non-common-executable mode (shows benchmark labels and iterations in nsys timeline); - fixes compile errors for the common-executable mode. Authors: - Artem M. Chirkin (https://github.com/achirkin) - William Hicks (https://github.com/wphicks) Approvers: - William Hicks (https://github.com/wphicks) - Mark Harris (https://github.com/harrism) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2021 --- cpp/bench/ann/CMakeLists.txt | 45 ++++++--- cpp/bench/ann/src/common/ann_types.hpp | 15 ++- cpp/bench/ann/src/common/benchmark.hpp | 16 ++- cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h | 50 +++++++--- cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h | 50 +++++++--- cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh | 22 +++-- cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h | 27 ++--- cpp/bench/ann/src/raft/raft_ann_bench_utils.h | 99 +++++++++++++++++++ cpp/bench/ann/src/raft/raft_benchmark.cu | 12 +-- .../ann/src/raft/raft_cagra_hnswlib_wrapper.h | 55 +++++------ cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 98 ++++++++++-------- .../ann/src/raft/raft_ivf_flat_wrapper.h | 36 ++++--- cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 45 ++++----- 13 files changed, 364 insertions(+), 206 deletions(-) diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 5919de07e7..c144d1399e 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -116,6 +116,21 @@ if(RAFT_ANN_BENCH_USE_FAISS) include(cmake/thirdparty/get_faiss.cmake) endif() +# ################################################################################################## +# * Enable NVTX if available + +# Note: ANN_BENCH wrappers have extra NVTX code not related to raft::nvtx.They track gbench +# benchmark cases and iterations. This is to make limited NVTX available to all algos, not just +# raft. +if(TARGET CUDA::nvtx3) + set(_CMAKE_REQUIRED_INCLUDES_ORIG ${CMAKE_REQUIRED_INCLUDES}) + get_target_property(CMAKE_REQUIRED_INCLUDES CUDA::nvtx3 INTERFACE_INCLUDE_DIRECTORIES) + unset(NVTX3_HEADERS_FOUND CACHE) + # Check the headers explicitly to make sure the cpu-only build succeeds + CHECK_INCLUDE_FILE_CXX(nvtx3/nvToolsExt.h NVTX3_HEADERS_FOUND) + set(CMAKE_REQUIRED_INCLUDES ${_CMAKE_REQUIRED_INCLUDES_ORIG}) +endif() + # ################################################################################################## # * Configure tests function------------------------------------------------------------- @@ -141,8 +156,13 @@ function(ConfigureAnnBench) add_dependencies(${BENCH_NAME} ANN_BENCH) else() add_executable(${BENCH_NAME} ${ConfigureAnnBench_PATH}) - target_compile_definitions(${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN) - target_link_libraries(${BENCH_NAME} PRIVATE benchmark::benchmark) + target_compile_definitions( + ${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN + $<$:ANN_BENCH_NVTX3_HEADERS_FOUND> + ) + target_link_libraries( + ${BENCH_NAME} PRIVATE benchmark::benchmark $<$:CUDA::nvtx3> + ) endif() target_link_libraries( @@ -340,8 +360,16 @@ if(RAFT_ANN_BENCH_SINGLE_EXE) target_include_directories(ANN_BENCH PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_link_libraries( - ANN_BENCH PRIVATE nlohmann_json::nlohmann_json benchmark_static dl -static-libgcc - -static-libstdc++ CUDA::nvtx3 + ANN_BENCH + PRIVATE raft::raft + nlohmann_json::nlohmann_json + benchmark_static + dl + -static-libgcc + fmt::fmt-header-only + spdlog::spdlog_header_only + -static-libstdc++ + $<$:CUDA::nvtx3> ) set_target_properties( ANN_BENCH @@ -355,17 +383,10 @@ if(RAFT_ANN_BENCH_SINGLE_EXE) BUILD_RPATH "\$ORIGIN" INSTALL_RPATH "\$ORIGIN" ) - - # Disable NVTX when the nvtx3 headers are missing - set(_CMAKE_REQUIRED_INCLUDES_ORIG ${CMAKE_REQUIRED_INCLUDES}) - get_target_property(CMAKE_REQUIRED_INCLUDES ANN_BENCH INCLUDE_DIRECTORIES) - CHECK_INCLUDE_FILE_CXX(nvtx3/nvToolsExt.h NVTX3_HEADERS_FOUND) - set(CMAKE_REQUIRED_INCLUDES ${_CMAKE_REQUIRED_INCLUDES_ORIG}) target_compile_definitions( ANN_BENCH PRIVATE - $<$:ANN_BENCH_LINK_CUDART="libcudart.so.${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}.${CUDAToolkit_VERSION_PATCH} - "> + $<$:ANN_BENCH_LINK_CUDART="libcudart.so.${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}.${CUDAToolkit_VERSION_PATCH}"> $<$:ANN_BENCH_NVTX3_HEADERS_FOUND> ) diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index e964a81efa..9b77c9df91 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -18,6 +18,7 @@ #include "cuda_stub.hpp" // cudaStream_t +#include #include #include #include @@ -64,17 +65,10 @@ inline auto parse_memory_type(const std::string& memory_type) -> MemoryType } } -class AlgoProperty { - public: - inline AlgoProperty() {} - inline AlgoProperty(MemoryType dataset_memory_type_, MemoryType query_memory_type_) - : dataset_memory_type(dataset_memory_type_), query_memory_type(query_memory_type_) - { - } +struct AlgoProperty { MemoryType dataset_memory_type; // neighbors/distances should have same memory type as queries MemoryType query_memory_type; - virtual ~AlgoProperty() = default; }; class AnnBase { @@ -125,6 +119,11 @@ class ANN : public AnnBase { // The client code should call set_search_dataset() before searching, // and should not release dataset before searching is finished. virtual void set_search_dataset(const T* /*dataset*/, size_t /*nrow*/){}; + + /** + * Make a shallow copy of the ANN wrapper that shares the resources and ensures thread-safe access + * to them. */ + virtual auto copy() -> std::unique_ptr> = 0; }; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index a2e77323c1..e61de6745e 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -45,7 +45,7 @@ std::condition_variable cond_var; std::atomic_int processed_threads{0}; static inline std::unique_ptr current_algo{nullptr}; -static inline std::shared_ptr current_algo_props{nullptr}; +static inline std::unique_ptr current_algo_props{nullptr}; using kv_series = std::vector>>; @@ -241,9 +241,8 @@ void bench_search(::benchmark::State& state, return; } - auto algo_property = parse_algo_property(algo->get_preference(), sp_json); - current_algo_props = std::make_shared(algo_property.dataset_memory_type, - algo_property.query_memory_type); + current_algo_props = std::make_unique( + std::move(parse_algo_property(algo->get_preference(), sp_json))); if (search_param->needs_dataset()) { try { @@ -277,23 +276,22 @@ void bench_search(::benchmark::State& state, // We are accessing shared variables (like current_algo, current_algo_probs) before the // benchmark loop, therefore the synchronization here is necessary. } - const auto algo_property = *current_algo_props; - query_set = dataset->query_set(algo_property.query_memory_type); + query_set = dataset->query_set(current_algo_props->query_memory_type); /** * Each thread will manage its own outputs */ std::shared_ptr> distances = - std::make_shared>(algo_property.query_memory_type, k * query_set_size); + std::make_shared>(current_algo_props->query_memory_type, k * query_set_size); std::shared_ptr> neighbors = - std::make_shared>(algo_property.query_memory_type, k * query_set_size); + std::make_shared>(current_algo_props->query_memory_type, k * query_set_size); cuda_timer gpu_timer; auto start = std::chrono::high_resolution_clock::now(); { nvtx_case nvtx{state.name()}; - ANN* algo = dynamic_cast*>(current_algo.get()); + auto algo = dynamic_cast*>(current_algo.get())->copy(); for (auto _ : state) { [[maybe_unused]] auto ntx_lap = nvtx.lap(); [[maybe_unused]] auto gpu_lap = gpu_timer.lap(); diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h index 755fe9f197..3cc4e10b49 100644 --- a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h @@ -73,8 +73,6 @@ class FaissCpu : public ANN { static_assert(std::is_same_v, "faiss support only float type"); } - virtual ~FaissCpu() noexcept {} - void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) final; void set_search_param(const AnnSearchParam& param) override; @@ -82,9 +80,9 @@ class FaissCpu : public ANN { void init_quantizer(int dim) { if (this->metric_type_ == faiss::MetricType::METRIC_L2) { - this->quantizer_ = std::make_unique(dim); + this->quantizer_ = std::make_shared(dim); } else if (this->metric_type_ == faiss::MetricType::METRIC_INNER_PRODUCT) { - this->quantizer_ = std::make_unique(dim); + this->quantizer_ = std::make_shared(dim); } } @@ -113,15 +111,15 @@ class FaissCpu : public ANN { template void load_(const std::string& file); - std::unique_ptr index_; - std::unique_ptr quantizer_; - std::unique_ptr index_refine_; + std::shared_ptr index_; + std::shared_ptr quantizer_; + std::shared_ptr index_refine_; faiss::MetricType metric_type_; int nlist_; double training_sample_fraction_; int num_threads_; - std::unique_ptr thread_pool_; + std::shared_ptr thread_pool_; }; template @@ -152,7 +150,7 @@ void FaissCpu::build(const T* dataset, size_t nrow, cudaStream_t stream) index_->train(nrow, dataset); // faiss::IndexFlat::train() will do nothing assert(index_->is_trained); index_->add(nrow, dataset); - index_refine_ = std::make_unique(this->index_.get(), dataset); + index_refine_ = std::make_shared(this->index_.get(), dataset); } template @@ -169,7 +167,7 @@ void FaissCpu::set_search_param(const AnnSearchParam& param) if (!thread_pool_ || num_threads_ != search_param.num_threads) { num_threads_ = search_param.num_threads; - thread_pool_ = std::make_unique(num_threads_); + thread_pool_ = std::make_shared(num_threads_); } } @@ -203,7 +201,7 @@ template template void FaissCpu::load_(const std::string& file) { - index_ = std::unique_ptr(dynamic_cast(faiss::read_index(file.c_str()))); + index_ = std::shared_ptr(dynamic_cast(faiss::read_index(file.c_str()))); } template @@ -214,7 +212,7 @@ class FaissCpuIVFFlat : public FaissCpu { FaissCpuIVFFlat(Metric metric, int dim, const BuildParam& param) : FaissCpu(metric, dim, param) { this->init_quantizer(dim); - this->index_ = std::make_unique( + this->index_ = std::make_shared( this->quantizer_.get(), dim, param.nlist, this->metric_type_); } @@ -223,6 +221,11 @@ class FaissCpuIVFFlat : public FaissCpu { this->template save_(file); } void load(const std::string& file) override { this->template load_(file); } + + std::unique_ptr> copy() + { + return std::make_unique>(*this); // use copy constructor + } }; template @@ -237,7 +240,7 @@ class FaissCpuIVFPQ : public FaissCpu { FaissCpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissCpu(metric, dim, param) { this->init_quantizer(dim); - this->index_ = std::make_unique( + this->index_ = std::make_shared( this->quantizer_.get(), dim, param.nlist, param.M, param.bitsPerCode, this->metric_type_); } @@ -246,6 +249,11 @@ class FaissCpuIVFPQ : public FaissCpu { this->template save_(file); } void load(const std::string& file) override { this->template load_(file); } + + std::unique_ptr> copy() + { + return std::make_unique>(*this); // use copy constructor + } }; // TODO: Enable this in cmake @@ -270,7 +278,7 @@ class FaissCpuIVFSQ : public FaissCpu { } this->init_quantizer(dim); - this->index_ = std::make_unique( + this->index_ = std::make_shared( this->quantizer_.get(), dim, param.nlist, qtype, this->metric_type_, true); } @@ -282,6 +290,11 @@ class FaissCpuIVFSQ : public FaissCpu { { this->template load_(file); } + + std::unique_ptr> copy() + { + return std::make_unique>(*this); // use copy constructor + } }; template @@ -290,7 +303,7 @@ class FaissCpuFlat : public FaissCpu { FaissCpuFlat(Metric metric, int dim) : FaissCpu(metric, dim, typename FaissCpu::BuildParam{}) { - this->index_ = std::make_unique(dim, this->metric_type_); + this->index_ = std::make_shared(dim, this->metric_type_); } // class FaissCpu is more like a IVF class, so need special treating here @@ -299,7 +312,7 @@ class FaissCpuFlat : public FaissCpu { auto search_param = dynamic_cast::SearchParam&>(param); if (!this->thread_pool_ || this->num_threads_ != search_param.num_threads) { this->num_threads_ = search_param.num_threads; - this->thread_pool_ = std::make_unique(this->num_threads_); + this->thread_pool_ = std::make_shared(this->num_threads_); } }; @@ -308,6 +321,11 @@ class FaissCpuFlat : public FaissCpu { this->template save_(file); } void load(const std::string& file) override { this->template load_(file); } + + std::unique_ptr> copy() + { + return std::make_unique>(*this); // use copy constructor + } }; } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h index 4f13ff8a49..ad51dd4e68 100644 --- a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h @@ -105,7 +105,6 @@ class FaissGpu : public ANN { RAFT_CUDA_TRY(cudaGetDevice(&device_)); RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming)); faiss_default_stream_ = gpu_resource_.getDefaultStream(device_); - raft::resource::set_cuda_stream(handle_, faiss_default_stream_); } virtual ~FaissGpu() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); } @@ -147,18 +146,33 @@ class FaissGpu : public ANN { RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_)); } + /** [NOTE Multithreading] + * + * `gpu_resource_` is a shared resource: + * 1. It uses a shared_ptr under the hood, so the copies of it refer to the same + * resource implementation instance + * 2. GpuIndex is probably keeping a reference to it, as it's passed to the constructor + * + * To avoid copying the index (database) in each thread, we make both the index and + * the gpu_resource shared. + * This means faiss GPU streams are possibly shared among the CPU threads; + * the throughput search mode may be inaccurate. + * + * WARNING: we haven't investigated whether faiss::gpu::GpuIndex or + * faiss::gpu::StandardGpuResources are thread-safe. + * + */ mutable faiss::gpu::StandardGpuResources gpu_resource_; - std::unique_ptr index_; - std::unique_ptr index_refine_{nullptr}; + std::shared_ptr index_; + std::shared_ptr index_refine_{nullptr}; faiss::MetricType metric_type_; int nlist_; int device_; cudaEvent_t sync_{nullptr}; cudaStream_t faiss_default_stream_{nullptr}; double training_sample_fraction_; - std::unique_ptr search_params_; + std::shared_ptr search_params_; const T* dataset_; - raft::device_resources handle_; float refine_ratio_ = 1.0; }; @@ -263,7 +277,7 @@ class FaissGpuIVFFlat : public FaissGpu { { faiss::gpu::GpuIndexIVFFlatConfig config; config.device = this->device_; - this->index_ = std::make_unique( + this->index_ = std::make_shared( &(this->gpu_resource_), dim, param.nlist, this->metric_type_, config); } @@ -275,7 +289,7 @@ class FaissGpuIVFFlat : public FaissGpu { faiss::IVFSearchParameters faiss_search_params; faiss_search_params.nprobe = nprobe; - this->search_params_ = std::make_unique(faiss_search_params); + this->search_params_ = std::make_shared(faiss_search_params); this->refine_ratio_ = search_param.refine_ratio; } @@ -287,6 +301,7 @@ class FaissGpuIVFFlat : public FaissGpu { { this->template load_(file); } + std::unique_ptr> copy() override { return std::make_unique>(*this); }; }; template @@ -306,7 +321,7 @@ class FaissGpuIVFPQ : public FaissGpu { config.device = this->device_; this->index_ = - std::make_unique(&(this->gpu_resource_), + std::make_shared(&(this->gpu_resource_), dim, param.nlist, param.M, @@ -324,11 +339,11 @@ class FaissGpuIVFPQ : public FaissGpu { faiss::IVFPQSearchParameters faiss_search_params; faiss_search_params.nprobe = nprobe; - this->search_params_ = std::make_unique(faiss_search_params); + this->search_params_ = std::make_shared(faiss_search_params); if (search_param.refine_ratio > 1.0) { this->index_refine_ = - std::make_unique(this->index_.get(), this->dataset_); + std::make_shared(this->index_.get(), this->dataset_); this->index_refine_.get()->k_factor = search_param.refine_ratio; } } @@ -341,6 +356,7 @@ class FaissGpuIVFPQ : public FaissGpu { { this->template load_(file); } + std::unique_ptr> copy() override { return std::make_unique>(*this); }; }; // TODO: Enable this in cmake @@ -366,7 +382,7 @@ class FaissGpuIVFSQ : public FaissGpu { faiss::gpu::GpuIndexIVFScalarQuantizerConfig config; config.device = this->device_; - this->index_ = std::make_unique( + this->index_ = std::make_shared( &(this->gpu_resource_), dim, param.nlist, qtype, this->metric_type_, true, config); } @@ -379,11 +395,11 @@ class FaissGpuIVFSQ : public FaissGpu { faiss::IVFSearchParameters faiss_search_params; faiss_search_params.nprobe = nprobe; - this->search_params_ = std::make_unique(faiss_search_params); + this->search_params_ = std::make_shared(faiss_search_params); this->refine_ratio_ = search_param.refine_ratio; if (search_param.refine_ratio > 1.0) { this->index_refine_ = - std::make_unique(this->index_.get(), this->dataset_); + std::make_shared(this->index_.get(), this->dataset_); this->index_refine_.get()->k_factor = search_param.refine_ratio; } } @@ -398,6 +414,7 @@ class FaissGpuIVFSQ : public FaissGpu { this->template load_( file); } + std::unique_ptr> copy() override { return std::make_unique>(*this); }; }; template @@ -408,7 +425,7 @@ class FaissGpuFlat : public FaissGpu { { faiss::gpu::GpuIndexFlatConfig config; config.device = this->device_; - this->index_ = std::make_unique( + this->index_ = std::make_shared( &(this->gpu_resource_), dim, this->metric_type_, config); } void set_search_param(const typename FaissGpu::AnnSearchParam& param) override @@ -417,7 +434,7 @@ class FaissGpuFlat : public FaissGpu { int nprobe = search_param.nprobe; assert(nprobe <= nlist_); - this->search_params_ = std::make_unique(); + this->search_params_ = std::make_shared(); } void save(const std::string& file) const override @@ -428,8 +445,9 @@ class FaissGpuFlat : public FaissGpu { { this->template load_(file); } + std::unique_ptr> copy() override { return std::make_unique>(*this); }; }; } // namespace raft::bench::ann -#endif \ No newline at end of file +#endif diff --git a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh index 664ec511dd..20c50a5119 100644 --- a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh +++ b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh @@ -52,7 +52,6 @@ class Ggnn : public ANN { }; Ggnn(Metric metric, int dim, const BuildParam& param); - ~Ggnn() { delete impl_; } void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) override { @@ -72,6 +71,7 @@ class Ggnn : public ANN { void save(const std::string& file) const override { impl_->save(file); } void load(const std::string& file) override { impl_->load(file); } + std::unique_ptr> copy() override { return std::make_unique>(*this); }; AlgoProperty get_preference() const override { return impl_->get_preference(); } @@ -81,7 +81,7 @@ class Ggnn : public ANN { }; private: - ANN* impl_; + std::shared_ptr> impl_; }; template @@ -90,23 +90,23 @@ Ggnn::Ggnn(Metric metric, int dim, const BuildParam& param) : ANN(metric, // ggnn/src/sift1m.cu if (metric == Metric::kEuclidean && dim == 128 && param.k_build == 24 && param.k == 10 && param.segment_size == 32) { - impl_ = new GgnnImpl(metric, dim, param); + impl_ = std::make_shared>(metric, dim, param); } // ggnn/src/deep1b_multi_gpu.cu, and adapt it deep1B else if (metric == Metric::kEuclidean && dim == 96 && param.k_build == 24 && param.k == 10 && param.segment_size == 32) { - impl_ = new GgnnImpl(metric, dim, param); + impl_ = std::make_shared>(metric, dim, param); } else if (metric == Metric::kInnerProduct && dim == 96 && param.k_build == 24 && param.k == 10 && param.segment_size == 32) { - impl_ = new GgnnImpl(metric, dim, param); + impl_ = std::make_shared>(metric, dim, param); } else if (metric == Metric::kInnerProduct && dim == 96 && param.k_build == 96 && param.k == 10 && param.segment_size == 64) { - impl_ = new GgnnImpl(metric, dim, param); + impl_ = std::make_shared>(metric, dim, param); } // ggnn/src/glove200.cu, adapt it to glove100 else if (metric == Metric::kInnerProduct && dim == 100 && param.k_build == 96 && param.k == 10 && param.segment_size == 64) { - impl_ = new GgnnImpl(metric, dim, param); + impl_ = std::make_shared>(metric, dim, param); } else { throw std::runtime_error( "ggnn: not supported combination of metric, dim and build param; " @@ -133,6 +133,10 @@ class GgnnImpl : public ANN { void save(const std::string& file) const override; void load(const std::string& file) override; + std::unique_ptr> copy() override + { + return std::make_unique>(*this); + }; AlgoProperty get_preference() const override { @@ -159,7 +163,7 @@ class GgnnImpl : public ANN { KBuild / 2 /* KF */, KQuery, S>; - std::unique_ptr ggnn_; + std::shared_ptr ggnn_; typename Ggnn::BuildParam build_param_; typename Ggnn::SearchParam search_param_; }; @@ -189,7 +193,7 @@ void GgnnImpl::build(const T* dataset, { int device; RAFT_CUDA_TRY(cudaGetDevice(&device)); - ggnn_ = std::make_unique( + ggnn_ = std::make_shared( device, nrow, build_param_.num_layers, true, build_param_.tau); ggnn_->set_base_data(dataset); diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index 921d72decc..2a5177d295 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -82,6 +82,7 @@ class HnswLib : public ANN { void save(const std::string& path_to_index) const override; void load(const std::string& path_to_index) override; + std::unique_ptr> copy() override { return std::make_unique>(*this); }; AlgoProperty get_preference() const override { @@ -96,15 +97,15 @@ class HnswLib : public ANN { private: void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const; - std::unique_ptr::type>> appr_alg_; - std::unique_ptr::type>> space_; + std::shared_ptr::type>> appr_alg_; + std::shared_ptr::type>> space_; using ANN::metric_; using ANN::dim_; int ef_construction_; int m_; int num_threads_; - std::unique_ptr thread_pool_; + std::shared_ptr thread_pool_; Objective metric_objective_; }; @@ -129,18 +130,18 @@ void HnswLib::build(const T* dataset, size_t nrow, cudaStream_t) { if constexpr (std::is_same_v) { if (metric_ == Metric::kInnerProduct) { - space_ = std::make_unique(dim_); + space_ = std::make_shared(dim_); } else { - space_ = std::make_unique(dim_); + space_ = std::make_shared(dim_); } } else if constexpr (std::is_same_v) { - space_ = std::make_unique(dim_); + space_ = std::make_shared(dim_); } - appr_alg_ = std::make_unique::type>>( + appr_alg_ = std::make_shared::type>>( space_.get(), nrow, m_, ef_construction_); - thread_pool_ = std::make_unique(num_threads_); + thread_pool_ = std::make_shared(num_threads_); const size_t items_per_thread = nrow / (num_threads_ + 1); thread_pool_->submit( @@ -168,7 +169,7 @@ void HnswLib::set_search_param(const AnnSearchParam& param_) // Create a pool if multiple query threads have been set and the pool hasn't been created already bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_); - if (create_pool) { thread_pool_ = std::make_unique(num_threads_); } + if (create_pool) { thread_pool_ = std::make_shared(num_threads_); } } template @@ -199,15 +200,15 @@ void HnswLib::load(const std::string& path_to_index) { if constexpr (std::is_same_v) { if (metric_ == Metric::kInnerProduct) { - space_ = std::make_unique(dim_); + space_ = std::make_shared(dim_); } else { - space_ = std::make_unique(dim_); + space_ = std::make_shared(dim_); } } else if constexpr (std::is_same_v) { - space_ = std::make_unique(dim_); + space_ = std::make_shared(dim_); } - appr_alg_ = std::make_unique::type>>( + appr_alg_ = std::make_shared::type>>( space_.get(), path_to_index); } diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h index cb30c2693f..2b91c2588c 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h @@ -41,4 +41,103 @@ inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric m throw std::runtime_error("raft supports only metric type of inner product and L2"); } } + +/** + * This struct is used by multiple raft benchmark wrappers. It serves as a thread-safe keeper of + * shared and private GPU resources (see below). + * + * - Accessing the same `configured_raft_resources` from concurrent threads is not safe. + * - Accessing the copies of `configured_raft_resources` from concurrent threads is safe. + * - There must be at most one "original" `configured_raft_resources` at any time, but as many + * copies of it as needed (modifies the program static state). + */ +class configured_raft_resources { + public: + using device_mr_t = rmm::mr::pool_memory_resource; + /** + * This constructor has the shared state passed unmodified but creates the local state anew. + * It's used by the copy constructor. + */ + explicit configured_raft_resources(const std::shared_ptr& mr) + : mr_{mr}, + sync_{[]() { + auto* ev = new cudaEvent_t; + RAFT_CUDA_TRY(cudaEventCreate(ev, cudaEventDisableTiming)); + return ev; + }(), + [](cudaEvent_t* ev) { + RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(*ev)); + delete ev; + }}, + res_{cudaStreamPerThread} + { + } + + /** Default constructor creates all resources anew. */ + configured_raft_resources() + : configured_raft_resources{ + {[]() { + auto* mr = + new device_mr_t{rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull}; + rmm::mr::set_current_device_resource(mr); + return mr; + }(), + [](device_mr_t* mr) { + if (mr == nullptr) { return; } + auto* cur_mr = dynamic_cast(rmm::mr::get_current_device_resource()); + if (cur_mr != nullptr && (*cur_mr) == (*mr)) { + // Normally, we'd always want to set the rmm resource back to the upstream of the pool + // here. However, we expect some implementations may be buggy and mess up the rmm + // resource, especially during development. This extra check here adds a little bit of + // resilience: let the program crash/fail somewhere else rather than in the destructor + // of the shared pointer. + rmm::mr::set_current_device_resource(mr->get_upstream()); + } + delete mr; + }}} + { + } + + configured_raft_resources(configured_raft_resources&&) = default; + configured_raft_resources& operator=(configured_raft_resources&&) = default; + ~configured_raft_resources() = default; + configured_raft_resources(const configured_raft_resources& res) + : configured_raft_resources{res.mr_} + { + } + configured_raft_resources& operator=(const configured_raft_resources& other) + { + this->mr_ = other.mr_; + return *this; + } + + operator raft::resources&() noexcept { return res_; } + operator const raft::resources&() const noexcept { return res_; } + + /** Make the given stream wait on all work submitted to the resource. */ + void stream_wait(cudaStream_t stream) const + { + RAFT_CUDA_TRY(cudaEventRecord(*sync_, resource::get_cuda_stream(res_))); + RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, *sync_)); + } + + /** Get the internal sync event (which otherwise used only in `stream_wait`). */ + cudaEvent_t get_sync_event() const { return *sync_; } + + private: + /** + * This pool is set as the RMM current device, hence its shared among all users of RMM resources. + * Its lifetime must be longer than that of any other cuda resources. It's not exposed and not + * used by anyone directly. + */ + std::shared_ptr mr_; + /** Each benchmark wrapper must have its own copy of the synchronization event. */ + std::unique_ptr> sync_; + /** + * Until we make the use of copies of raft::resources thread-safe, each benchmark wrapper must + * have its own copy of it. + */ + raft::device_resources res_; +}; + } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index f8c65a2d6e..b776a9fafb 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -126,15 +126,5 @@ REGISTER_ALGO_INSTANCE(std::uint8_t); #ifdef ANN_BENCH_BUILD_MAIN #include "../common/benchmark.hpp" -int main(int argc, char** argv) -{ - rmm::mr::cuda_memory_resource cuda_mr; - // Construct a resource that uses a coalescing best-fit pool allocator - rmm::mr::pool_memory_resource pool_mr{&cuda_mr}; - rmm::mr::set_current_device_resource( - &pool_mr); // Updates the current device resource pointer to `pool_mr` - rmm::mr::device_memory_resource* mr = - rmm::mr::get_current_device_resource(); // Points to `pool_mr` - return raft::bench::ann::run_main(argc, argv); -} +int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } #endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h index 432caecfcc..3fd0a374b7 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -30,15 +30,12 @@ class RaftCagraHnswlib : public ANN { RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) : ANN(metric, dim), - metric_(metric), - index_params_(param), - dimension_(dim), - handle_(cudaStreamPerThread) + cagra_build_{metric, dim, param, concurrent_searches}, + // HnswLib param values don't matter since we don't build with HnswLib + hnswlib_search_{metric, dim, typename HnswLib::BuildParam{50, 100}} { } - ~RaftCagraHnswlib() noexcept {} - void build(const T* dataset, size_t nrow, cudaStream_t stream) final; void set_search_param(const AnnSearchParam& param) override; @@ -60,61 +57,53 @@ class RaftCagraHnswlib : public ANN { property.query_memory_type = MemoryType::Host; return property; } + void save(const std::string& file) const override; void load(const std::string&) override; + std::unique_ptr> copy() override + { + return std::make_unique>(*this); + } private: - raft::device_resources handle_; - Metric metric_; - BuildParam index_params_; - int dimension_; - - std::unique_ptr> cagra_build_; - std::unique_ptr> hnswlib_search_; - - Objective metric_objective_; + RaftCagra cagra_build_; + HnswLib hnswlib_search_; }; template void RaftCagraHnswlib::build(const T* dataset, size_t nrow, cudaStream_t stream) { - if (not cagra_build_) { - cagra_build_ = std::make_unique>(metric_, dimension_, index_params_); - } - cagra_build_->build(dataset, nrow, stream); + cagra_build_.build(dataset, nrow, stream); } template void RaftCagraHnswlib::set_search_param(const AnnSearchParam& param_) { - hnswlib_search_->set_search_param(param_); + hnswlib_search_.set_search_param(param_); } template void RaftCagraHnswlib::save(const std::string& file) const { - cagra_build_->save_to_hnswlib(file); + cagra_build_.save_to_hnswlib(file); } template void RaftCagraHnswlib::load(const std::string& file) { - typename HnswLib::BuildParam param; - // these values don't matter since we don't build with HnswLib - param.M = 50; - param.ef_construction = 100; - if (not hnswlib_search_) { - hnswlib_search_ = std::make_unique>(metric_, dimension_, param); - } - hnswlib_search_->load(file); - hnswlib_search_->set_base_layer_only(); + hnswlib_search_.load(file); + hnswlib_search_.set_base_layer_only(); } template -void RaftCagraHnswlib::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const +void RaftCagraHnswlib::search(const T* queries, + int batch_size, + int k, + size_t* neighbors, + float* distances, + cudaStream_t stream) const { - hnswlib_search_->search(queries, batch_size, k, neighbors, distances); + hnswlib_search_.search(queries, batch_size, k, neighbors, distances, stream); } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index a3e481ec5a..ec71de9cff 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -76,21 +76,20 @@ class RaftCagra : public ANN { : ANN(metric, dim), index_params_(param), dimension_(dim), - handle_(cudaStreamPerThread), need_dataset_update_(true), - dataset_(make_device_matrix(handle_, 0, 0)), - graph_(make_device_matrix(handle_, 0, 0)), - input_dataset_v_(nullptr, 0, 0), + dataset_(std::make_shared>( + std::move(make_device_matrix(handle_, 0, 0)))), + graph_(std::make_shared>( + std::move(make_device_matrix(handle_, 0, 0)))), + input_dataset_v_( + std::make_shared>(nullptr, 0, 0)), graph_mem_(AllocatorType::Device), dataset_mem_(AllocatorType::Device) { index_params_.cagra_params.metric = parse_metric_type(metric); index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); - RAFT_CUDA_TRY(cudaGetDevice(&device_)); } - ~RaftCagra() noexcept {} - void build(const T* dataset, size_t nrow, cudaStream_t stream) final; void set_search_param(const AnnSearchParam& param) override; @@ -117,8 +116,24 @@ class RaftCagra : public ANN { void save(const std::string& file) const override; void load(const std::string&) override; void save_to_hnswlib(const std::string& file) const; + std::unique_ptr> copy() override; private: + // handle_ must go first to make sure it dies last and all memory allocated in pool + configured_raft_resources handle_{}; + raft::mr::cuda_pinned_resource mr_pinned_; + raft::mr::cuda_huge_page_resource mr_huge_page_; + AllocatorType graph_mem_; + AllocatorType dataset_mem_; + BuildParam index_params_; + bool need_dataset_update_; + raft::neighbors::cagra::search_params search_params_; + std::shared_ptr> index_; + int dimension_; + std::shared_ptr> graph_; + std::shared_ptr> dataset_; + std::shared_ptr> input_dataset_v_; + inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type) { switch (mem_type) { @@ -127,38 +142,26 @@ class RaftCagra : public ANN { default: return rmm::mr::get_current_device_resource(); } } - raft ::mr::cuda_pinned_resource mr_pinned_; - raft ::mr::cuda_huge_page_resource mr_huge_page_; - raft::device_resources handle_; - AllocatorType graph_mem_; - AllocatorType dataset_mem_; - BuildParam index_params_; - bool need_dataset_update_; - raft::neighbors::cagra::search_params search_params_; - std::optional> index_; - int device_; - int dimension_; - raft::device_matrix graph_; - raft::device_matrix dataset_; - raft::device_matrix_view input_dataset_v_; }; template -void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t) +void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t stream) { auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), dimension_); auto& params = index_params_.cagra_params; - index_.emplace(raft::neighbors::cagra::detail::build(handle_, - params, - dataset_view, - index_params_.nn_descent_params, - index_params_.ivf_pq_refine_rate, - index_params_.ivf_pq_build_params, - index_params_.ivf_pq_search_params)); - return; + index_ = std::make_shared>( + std::move(raft::neighbors::cagra::detail::build(handle_, + params, + dataset_view, + index_params_.nn_descent_params, + index_params_.ivf_pq_refine_rate, + index_params_.ivf_pq_build_params, + index_params_.ivf_pq_search_params))); + + handle_.stream_wait(stream); // RAFT stream -> bench stream } inline std::string allocator_to_string(AllocatorType mem_type) @@ -194,24 +197,24 @@ void RaftCagra::set_search_param(const AnnSearchParam& param) index_->update_graph(handle_, make_const_mdspan(new_graph.view())); // update_graph() only stores a view in the index. We need to keep the graph object alive. - graph_ = std::move(new_graph); + *graph_ = std::move(new_graph); } if (search_param.dataset_mem != dataset_mem_ || need_dataset_update_) { dataset_mem_ = search_param.dataset_mem; // First free up existing memory - dataset_ = make_device_matrix(handle_, 0, 0); - index_->update_dataset(handle_, make_const_mdspan(dataset_.view())); + *dataset_ = make_device_matrix(handle_, 0, 0); + index_->update_dataset(handle_, make_const_mdspan(dataset_->view())); // Allocate space using the correct memory resource. RAFT_LOG_INFO("moving dataset to new memory space: %s", allocator_to_string(dataset_mem_).c_str()); auto mr = get_mr(dataset_mem_); - raft::neighbors::cagra::detail::copy_with_padding(handle_, dataset_, input_dataset_v_, mr); + raft::neighbors::cagra::detail::copy_with_padding(handle_, *dataset_, *input_dataset_v_, mr); - index_->update_dataset(handle_, make_const_mdspan(dataset_.view())); + index_->update_dataset(handle_, make_const_mdspan(dataset_->view())); // Ideally, instead of dataset_.view(), we should pass a strided matrix view to update. // See Issue https://github.com/rapidsai/raft/issues/1972 for details. @@ -227,9 +230,9 @@ void RaftCagra::set_search_dataset(const T* dataset, size_t nrow) { // It can happen that we are re-using a previous algo object which already has // the dataset set. Check if we need update. - if (static_cast(input_dataset_v_.extent(0)) != nrow || - input_dataset_v_.data_handle() != dataset) { - input_dataset_v_ = make_device_matrix_view(dataset, nrow, this->dim_); + if (static_cast(input_dataset_v_->extent(0)) != nrow || + input_dataset_v_->data_handle() != dataset) { + *input_dataset_v_ = make_device_matrix_view(dataset, nrow, this->dim_); need_dataset_update_ = true; } } @@ -249,12 +252,23 @@ void RaftCagra::save_to_hnswlib(const std::string& file) const template void RaftCagra::load(const std::string& file) { - index_ = raft::neighbors::cagra::deserialize(handle_, file); + index_ = std::make_shared>( + std::move(raft::neighbors::cagra::deserialize(handle_, file))); +} + +template +std::unique_ptr> RaftCagra::copy() +{ + return std::make_unique>(*this); // use copy constructor } template -void RaftCagra::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const +void RaftCagra::search(const T* queries, + int batch_size, + int k, + size_t* neighbors, + float* distances, + cudaStream_t stream) const { IdxT* neighbors_IdxT; rmm::device_uvector neighbors_storage(0, resource::get_cuda_stream(handle_)); @@ -281,6 +295,6 @@ void RaftCagra::search( raft::resource::get_cuda_stream(handle_)); } - handle_.sync_stream(); + handle_.stream_wait(stream); // RAFT stream -> bench stream } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 13ea20d483..51b8b67f37 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -59,8 +59,6 @@ class RaftIvfFlatGpu : public ANN { RAFT_CUDA_TRY(cudaGetDevice(&device_)); } - ~RaftIvfFlatGpu() noexcept {} - void build(const T* dataset, size_t nrow, cudaStream_t stream) final; void set_search_param(const AnnSearchParam& param) override; @@ -84,22 +82,24 @@ class RaftIvfFlatGpu : public ANN { } void save(const std::string& file) const override; void load(const std::string&) override; + std::unique_ptr> copy() override; private: - raft::device_resources handle_; + // handle_ must go first to make sure it dies last and all memory allocated in pool + configured_raft_resources handle_{}; BuildParam index_params_; raft::neighbors::ivf_flat::search_params search_params_; - std::optional> index_; + std::shared_ptr> index_; int device_; int dimension_; }; template -void RaftIvfFlatGpu::build(const T* dataset, size_t nrow, cudaStream_t) +void RaftIvfFlatGpu::build(const T* dataset, size_t nrow, cudaStream_t stream) { - index_.emplace( - raft::neighbors::ivf_flat::build(handle_, index_params_, dataset, IdxT(nrow), dimension_)); - return; + index_ = std::make_shared>(std::move( + raft::neighbors::ivf_flat::build(handle_, index_params_, dataset, IdxT(nrow), dimension_))); + handle_.stream_wait(stream); // RAFT stream -> bench stream } template @@ -120,18 +120,28 @@ void RaftIvfFlatGpu::save(const std::string& file) const template void RaftIvfFlatGpu::load(const std::string& file) { - index_ = raft::neighbors::ivf_flat::deserialize(handle_, file); + index_ = std::make_shared>( + std::move(raft::neighbors::ivf_flat::deserialize(handle_, file))); return; } template -void RaftIvfFlatGpu::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const +std::unique_ptr> RaftIvfFlatGpu::copy() +{ + return std::make_unique>(*this); // use copy constructor +} + +template +void RaftIvfFlatGpu::search(const T* queries, + int batch_size, + int k, + size_t* neighbors, + float* distances, + cudaStream_t stream) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); raft::neighbors::ivf_flat::search( handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances); - resource::sync_stream(handle_); - return; + handle_.stream_wait(stream); // RAFT stream -> bench stream } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index e4004b0007..9a373787ac 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -57,12 +57,8 @@ class RaftIvfPQ : public ANN { : ANN(metric, dim), index_params_(param), dimension_(dim) { index_params_.metric = parse_metric_type(metric); - RAFT_CUDA_TRY(cudaGetDevice(&device_)); - RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming)); } - ~RaftIvfPQ() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); } - void build(const T* dataset, size_t nrow, cudaStream_t stream) final; void set_search_param(const AnnSearchParam& param) override; @@ -87,23 +83,17 @@ class RaftIvfPQ : public ANN { } void save(const std::string& file) const override; void load(const std::string&) override; + std::unique_ptr> copy() override; private: - raft::device_resources handle_; - cudaEvent_t sync_{nullptr}; + // handle_ must go first to make sure it dies last and all memory allocated in pool + configured_raft_resources handle_{}; BuildParam index_params_; raft::neighbors::ivf_pq::search_params search_params_; - std::optional> index_; - int device_; + std::shared_ptr> index_; int dimension_; float refine_ratio_ = 1.0; raft::device_matrix_view dataset_; - - void stream_wait(cudaStream_t stream) const - { - RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_))); - RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_)); - } }; template @@ -115,9 +105,9 @@ void RaftIvfPQ::save(const std::string& file) const template void RaftIvfPQ::load(const std::string& file) { - auto index_tmp = raft::neighbors::ivf_pq::index(handle_, index_params_, dimension_); - raft::runtime::neighbors::ivf_pq::deserialize(handle_, file, &index_tmp); - index_.emplace(std::move(index_tmp)); + std::make_shared>(handle_, index_params_, dimension_) + .swap(index_); + raft::runtime::neighbors::ivf_pq::deserialize(handle_, file, index_.get()); return; } @@ -125,9 +115,16 @@ template void RaftIvfPQ::build(const T* dataset, size_t nrow, cudaStream_t stream) { auto dataset_v = raft::make_device_matrix_view(dataset, IdxT(nrow), dim_); + std::make_shared>( + std::move(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v))) + .swap(index_); + handle_.stream_wait(stream); // RAFT stream -> bench stream +} - index_.emplace(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v)); - stream_wait(stream); +template +std::unique_ptr> RaftIvfPQ::copy() +{ + return std::make_unique>(*this); // use copy constructor } template @@ -176,7 +173,7 @@ void RaftIvfPQ::search(const T* queries, neighbors_v, distances_v, index_->metric()); - stream_wait(stream); // RAFT stream -> bench stream + handle_.stream_wait(stream); // RAFT stream -> bench stream } else { auto queries_host = raft::make_host_matrix(batch_size, index_->dim()); auto candidates_host = raft::make_host_matrix(batch_size, k0); @@ -193,9 +190,9 @@ void RaftIvfPQ::search(const T* queries, dataset_.data_handle(), dataset_.extent(0), dataset_.extent(1)); // wait for the queries to copy to host in 'stream` and for IVF-PQ::search to finish - RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_))); - RAFT_CUDA_TRY(cudaEventRecord(sync_, stream)); - RAFT_CUDA_TRY(cudaEventSynchronize(sync_)); + RAFT_CUDA_TRY(cudaEventRecord(handle_.get_sync_event(), resource::get_cuda_stream(handle_))); + RAFT_CUDA_TRY(cudaEventRecord(handle_.get_sync_event(), stream)); + RAFT_CUDA_TRY(cudaEventSynchronize(handle_.get_sync_event())); raft::runtime::neighbors::refine(handle_, dataset_v, queries_host.view(), @@ -215,7 +212,7 @@ void RaftIvfPQ::search(const T* queries, raft::runtime::neighbors::ivf_pq::search( handle_, search_params_, *index_, queries_v, neighbors_v, distances_v); - stream_wait(stream); // RAFT stream -> bench stream + handle_.stream_wait(stream); // RAFT stream -> bench stream } } } // namespace raft::bench::ann