From 9359e8997059ededd32e1c2833eb725bc2b4e1f5 Mon Sep 17 00:00:00 2001
From: achirkin
Date: Thu, 23 Nov 2023 16:51:23 +0100
Subject: [PATCH] Copy benchmark wrappers to avoid concurrently accessing not
 thread-safe resources

---
 cpp/bench/ann/src/common/ann_types.hpp        |  6 ++
 cpp/bench/ann/src/common/benchmark.hpp        |  2 +-
 cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h   |  1 +
 cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h   |  3 +-
 cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh       | 22 ++++---
 cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h   | 27 +++++----
 cpp/bench/ann/src/raft/raft_ann_bench_utils.h | 59 ++++++++++++++-----
 .../ann/src/raft/raft_cagra_hnswlib_wrapper.h |  1 +
 cpp/bench/ann/src/raft/raft_cagra_wrapper.h   | 58 ++++++++++--------
 .../ann/src/raft/raft_ivf_flat_wrapper.h      | 16 +++--
 cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h  | 20 +++++--
 11 files changed, 143 insertions(+), 72 deletions(-)

diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp
index 852c784552..9b77c9df91 100644
--- a/cpp/bench/ann/src/common/ann_types.hpp
+++ b/cpp/bench/ann/src/common/ann_types.hpp
@@ -18,6 +18,7 @@
 
 #include "cuda_stub.hpp"  // cudaStream_t
 
+#include <memory>
 #include <stdexcept>
 #include <string>
 #include <vector>
@@ -118,6 +119,11 @@ class ANN : public AnnBase {
   // The client code should call set_search_dataset() before searching,
   // and should not release dataset before searching is finished.
   virtual void set_search_dataset(const T* /*dataset*/, size_t /*nrow*/){};
+
+  /**
+   * Make a shallow copy of the ANN wrapper that shares the resources and ensures thread-safe access
+   * to them. */
+  virtual auto copy() -> std::unique_ptr<ANN<T>> = 0;
 };
 
 }  // namespace raft::bench::ann
diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp
index 60464104f5..e61de6745e 100644
--- a/cpp/bench/ann/src/common/benchmark.hpp
+++ b/cpp/bench/ann/src/common/benchmark.hpp
@@ -291,7 +291,7 @@ void bench_search(::benchmark::State& state,
   {
     nvtx_case nvtx{state.name()};
 
-    ANN<T>* algo = dynamic_cast<ANN<T>*>(current_algo.get());
+    auto algo = dynamic_cast<ANN<T>*>(current_algo.get())->copy();
     for (auto _ : state) {
       [[maybe_unused]] auto ntx_lap = nvtx.lap();
       [[maybe_unused]] auto gpu_lap = gpu_timer.lap();
diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
index 755fe9f197..85e3ec61a5 100644
--- a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
+++ b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
@@ -105,6 +105,7 @@ class FaissCpu : public ANN<T> {
     property.query_memory_type = MemoryType::Host;
     return property;
   }
+  std::unique_ptr<ANN<T>> copy() override;
 
  protected:
   template <typename Index, typename IndexParameter>
diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
index 4f13ff8a49..56ed7f6aad 100644
--- a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
+++ b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
@@ -133,6 +133,7 @@ class FaissGpu : public ANN<T> {
     property.query_memory_type = MemoryType::Host;
     return property;
   }
+  std::unique_ptr<ANN<T>> copy() override;
 
  protected:
   template <typename GpuIndex, typename CpuIndex>
@@ -432,4 +433,4 @@ class FaissGpuFlat : public FaissGpu<T> {
 
 }  // namespace raft::bench::ann
 
-#endif
\ No newline at end of file
+#endif
diff --git a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
index 664ec511dd..20c50a5119 100644
--- a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
+++ b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
@@ -52,7 +52,6 @@ class Ggnn : public ANN<T> {
   };
 
   Ggnn(Metric metric, int dim, const BuildParam& param);
-  ~Ggnn() { delete impl_; }
 
   void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) override
   {
@@ -72,6 +71,7 @@ class Ggnn : public ANN<T> {
 
   void save(const std::string& file) const override { impl_->save(file); }
   void load(const std::string& file) override { impl_->load(file); }
+  std::unique_ptr<ANN<T>> copy() override { return std::make_unique<Ggnn<T>>(*this); };
 
   AlgoProperty get_preference() const override { return impl_->get_preference(); }
 
@@ -81,7 +81,7 @@ class Ggnn : public ANN<T> {
   };
 
  private:
-  ANN<T>* impl_;
+  std::shared_ptr<ANN<T>> impl_;
 };
 
 template <typename T>
@@ -90,23 +90,23 @@ Ggnn<T>::Ggnn(Metric metric, int dim, const BuildParam& param) : ANN<T>(metric,
   // ggnn/src/sift1m.cu
   if (metric == Metric::kEuclidean && dim == 128 && param.k_build == 24 && param.k == 10 &&
       param.segment_size == 32) {
-    impl_ = new GgnnImpl<T, Euclidean, 128, 24, 10, 32>(metric, dim, param);
+    impl_ = std::make_shared<GgnnImpl<T, Euclidean, 128, 24, 10, 32>>(metric, dim, param);
   }
   // ggnn/src/deep1b_multi_gpu.cu, and adapt it deep1B
   else if (metric == Metric::kEuclidean && dim == 96 && param.k_build == 24 && param.k == 10 &&
            param.segment_size == 32) {
-    impl_ = new GgnnImpl<T, Euclidean, 96, 24, 10, 32>(metric, dim, param);
+    impl_ = std::make_shared<GgnnImpl<T, Euclidean, 96, 24, 10, 32>>(metric, dim, param);
   } else if (metric == Metric::kInnerProduct && dim == 96 && param.k_build == 24 && param.k == 10 &&
              param.segment_size == 32) {
-    impl_ = new GgnnImpl<T, Cosine, 96, 24, 10, 32>(metric, dim, param);
+    impl_ = std::make_shared<GgnnImpl<T, Cosine, 96, 24, 10, 32>>(metric, dim, param);
   } else if (metric == Metric::kInnerProduct && dim == 96 && param.k_build == 96 && param.k == 10 &&
              param.segment_size == 64) {
-    impl_ = new GgnnImpl<T, Cosine, 96, 96, 10, 64>(metric, dim, param);
+    impl_ = std::make_shared<GgnnImpl<T, Cosine, 96, 96, 10, 64>>(metric, dim, param);
   }
   // ggnn/src/glove200.cu, adapt it to glove100
   else if (metric == Metric::kInnerProduct && dim == 100 && param.k_build == 96 && param.k == 10 &&
            param.segment_size == 64) {
-    impl_ = new GgnnImpl<T, Cosine, 100, 96, 10, 64>(metric, dim, param);
+    impl_ = std::make_shared<GgnnImpl<T, Cosine, 100, 96, 10, 64>>(metric, dim, param);
   } else {
     throw std::runtime_error(
       "ggnn: not supported combination of metric, dim and build param; "
@@ -133,6 +133,10 @@ class GgnnImpl : public ANN<T> {
 
   void save(const std::string& file) const override;
   void load(const std::string& file) override;
+  std::unique_ptr<ANN<T>> copy() override
+  {
+    return std::make_unique<GgnnImpl<T, measure, D, KBuild, KQuery, S>>(*this);
+  };
 
   AlgoProperty get_preference() const override
   {
@@ -159,7 +163,7 @@ class GgnnImpl : public ANN<T> {
                                KBuild / 2 /* KF */,
                                KQuery,
                                S>;
-  std::unique_ptr<GGNNGPUInstance> ggnn_;
+  std::shared_ptr<GGNNGPUInstance> ggnn_;
   typename Ggnn<T>::BuildParam build_param_;
   typename Ggnn<T>::SearchParam search_param_;
 };
@@ -189,7 +193,7 @@ void GgnnImpl<T, measure, D, KBuild, KQuery, S>::build(const T* dataset,
 {
   int device;
   RAFT_CUDA_TRY(cudaGetDevice(&device));
-  ggnn_ = std::make_unique<GGNNGPUInstance>(
+  ggnn_ = std::make_shared<GGNNGPUInstance>(
     device, nrow, build_param_.num_layers, true, build_param_.tau);
 
   ggnn_->set_base_data(dataset);
diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
index 921d72decc..2a5177d295 100644
--- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
+++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
@@ -82,6 +82,7 @@ class HnswLib : public ANN<T> {
 
   void save(const std::string& path_to_index) const override;
   void load(const std::string& path_to_index) override;
+  std::unique_ptr<ANN<T>> copy() override { return std::make_unique<HnswLib<T>>(*this); };
 
   AlgoProperty get_preference() const override
   {
@@ -96,15 +97,15 @@ class HnswLib : public ANN<T> {
  private:
   void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const;
 
-  std::unique_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
-  std::unique_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
+  std::shared_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
+  std::shared_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
 
   using ANN<T>::metric_;
   using ANN<T>::dim_;
   int ef_construction_;
   int m_;
   int num_threads_;
-  std::unique_ptr<FixedThreadPool> thread_pool_;
+  std::shared_ptr<FixedThreadPool> thread_pool_;
   Objective metric_objective_;
 };
 template <typename T>
@@ -129,18 +130,18 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
 {
   if constexpr (std::is_same_v<T, float>) {
     if (metric_ == Metric::kInnerProduct) {
-      space_ = std::make_unique<hnswlib::InnerProductSpace>(dim_);
+      space_ = std::make_shared<hnswlib::InnerProductSpace>(dim_);
     } else {
-      space_ = std::make_unique<hnswlib::L2Space>(dim_);
+      space_ = std::make_shared<hnswlib::L2Space>(dim_);
     }
   } else if constexpr (std::is_same_v<T, uint8_t>) {
-    space_ = std::make_unique<hnswlib::L2SpaceI>(dim_);
+    space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
   }
 
-  appr_alg_ = std::make_unique<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
+  appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
     space_.get(), nrow, m_, ef_construction_);
 
-  thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_);
+  thread_pool_ = std::make_shared<FixedThreadPool>(num_threads_);
   const size_t items_per_thread = nrow / (num_threads_ + 1);
 
   thread_pool_->submit(
@@ -168,7 +169,7 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)
 
   // Create a pool if multiple query threads have been set and the pool hasn't been created already
   bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_);
-  if (create_pool) { thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_); }
+  if (create_pool) { thread_pool_ = std::make_shared<FixedThreadPool>(num_threads_); }
 }
 
 template <typename T>
@@ -199,15 +200,15 @@ void HnswLib<T>::load(const std::string& path_to_index)
 {
   if constexpr (std::is_same_v<T, float>) {
     if (metric_ == Metric::kInnerProduct) {
-      space_ = std::make_unique<hnswlib::InnerProductSpace>(dim_);
+      space_ = std::make_shared<hnswlib::InnerProductSpace>(dim_);
     } else {
-      space_ = std::make_unique<hnswlib::L2Space>(dim_);
+      space_ = std::make_shared<hnswlib::L2Space>(dim_);
     }
   } else if constexpr (std::is_same_v<T, uint8_t>) {
-    space_ = std::make_unique<hnswlib::L2SpaceI>(dim_);
+    space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
   }
 
-  appr_alg_ = std::make_unique<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
+  appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
     space_.get(), path_to_index);
 }
 
diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h
index 5fd27f4500..51732d63e0 100644
--- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h
+++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h
@@ -44,21 +44,50 @@ inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric m
 
 class configured_raft_resources {
  public:
+  explicit configured_raft_resources(
+    const std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>> mr)
+    : mr_{mr},
+      sync_{[]() {
+              auto* ev = new cudaEvent_t;
+              RAFT_CUDA_TRY(cudaEventCreate(ev, cudaEventDisableTiming));
+              return ev;
+            }(),
+            [](cudaEvent_t* ev) {
+              RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(*ev));
+              delete ev;
+            }},
+      res_{cudaStreamPerThread}
+  {
+  }
+
   configured_raft_resources()
-    : mr_{rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull},
-      res_{cudaStreamPerThread},
-      sync_{nullptr}
+    : configured_raft_resources{
+        {[]() {
+           auto* mr = new rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>{
+             rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull};
+           rmm::mr::set_current_device_resource(mr);
+           return mr;
+         }(),
+         [](rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>* mr) {
+           if (rmm::mr::get_current_device_resource()->is_equal(*mr)) {
+             rmm::mr::set_current_device_resource(mr->get_upstream());
+           }
+           delete mr;
+         }}}
   {
-    rmm::mr::set_current_device_resource(&mr_);
-    RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
   }
 
-  ~configured_raft_resources() noexcept
+  configured_raft_resources(configured_raft_resources&&) = default;
+  configured_raft_resources& operator=(configured_raft_resources&&) = default;
+  ~configured_raft_resources() = default;
+  configured_raft_resources(const configured_raft_resources& res)
+    : configured_raft_resources{res.mr_}
+  {
+  }
+  configured_raft_resources& operator=(const configured_raft_resources& other)
   {
-    RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_));
-    if (rmm::mr::get_current_device_resource()->is_equal(mr_)) {
-      rmm::mr::set_current_device_resource(mr_.get_upstream());
-    }
+    this->mr_ = other.mr_;
+    return *this;
   }
 
   operator raft::resources&() noexcept { return res_; }
@@ -67,17 +96,17 @@ class configured_raft_resources {
   /** Make the given stream wait on all work submitted to the resource. */
   void stream_wait(cudaStream_t stream) const
   {
-    RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(res_)));
-    RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_));
+    RAFT_CUDA_TRY(cudaEventRecord(*sync_, resource::get_cuda_stream(res_)));
+    RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, *sync_));
   }
 
   /** Get the internal sync event (which otherwise used only in `stream_wait`). */
-  cudaEvent_t get_sync_event() const { return sync_; }
+  cudaEvent_t get_sync_event() const { return *sync_; }
 
  private:
-  rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> mr_;
+  std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>> mr_;
+  std::unique_ptr<cudaEvent_t, std::function<void(cudaEvent_t*)>> sync_;
   raft::device_resources res_;
-  cudaEvent_t sync_;
 };
 
 }  // namespace raft::bench::ann
diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
index 432caecfcc..e42cb5e7f2 100644
--- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
+++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
@@ -62,6 +62,7 @@ class RaftCagraHnswlib : public ANN<T> {
   }
   void save(const std::string& file) const override;
   void load(const std::string&) override;
+  std::unique_ptr<ANN<T>> copy() override;
 
  private:
   raft::device_resources handle_;
diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h
index f04ab59e19..ec71de9cff 100644
--- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h
+++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h
@@ -77,9 +77,12 @@ class RaftCagra : public ANN<T> {
       index_params_(param),
       dimension_(dim),
       need_dataset_update_(true),
-      dataset_(make_device_matrix<T, int64_t>(handle_, 0, 0)),
-      graph_(make_device_matrix<IdxT, int64_t>(handle_, 0, 0)),
-      input_dataset_v_(nullptr, 0, 0),
+      dataset_(std::make_shared<raft::device_matrix<T, int64_t, row_major>>(
+        std::move(make_device_matrix<T, int64_t>(handle_, 0, 0)))),
+      graph_(std::make_shared<raft::device_matrix<IdxT, int64_t, row_major>>(
+        std::move(make_device_matrix<IdxT, int64_t>(handle_, 0, 0)))),
+      input_dataset_v_(
+        std::make_shared<raft::device_matrix_view<const T, int64_t, row_major>>(nullptr, 0, 0)),
       graph_mem_(AllocatorType::Device),
       dataset_mem_(AllocatorType::Device)
   {
@@ -113,6 +116,7 @@ class RaftCagra : public ANN<T> {
   void save(const std::string& file) const override;
   void load(const std::string&) override;
   void save_to_hnswlib(const std::string& file) const;
+  std::unique_ptr<ANN<T>> copy() override;
 
  private:
   // handle_ must go first to make sure it dies last and all memory allocated in pool
@@ -124,11 +128,11 @@ class RaftCagra : public ANN<T> {
   BuildParam index_params_;
   bool need_dataset_update_;
   raft::neighbors::cagra::search_params search_params_;
-  std::optional<raft::neighbors::cagra::index<T, IdxT>> index_;
+  std::shared_ptr<raft::neighbors::cagra::index<T, IdxT>> index_;
   int dimension_;
-  raft::device_matrix<IdxT, int64_t, row_major> graph_;
-  raft::device_matrix<T, int64_t, row_major> dataset_;
-  raft::device_matrix_view<const T, int64_t, row_major> input_dataset_v_;
+  std::shared_ptr<raft::device_matrix<IdxT, int64_t, row_major>> graph_;
+  std::shared_ptr<raft::device_matrix<T, int64_t, row_major>> dataset_;
+  std::shared_ptr<raft::device_matrix_view<const T, int64_t, row_major>> input_dataset_v_;
 
   inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type)
   {
@@ -148,13 +152,14 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t strea
 
   auto& params = index_params_.cagra_params;
 
-  index_.emplace(raft::neighbors::cagra::detail::build(handle_,
-                                                       params,
-                                                       dataset_view,
-                                                       index_params_.nn_descent_params,
-                                                       index_params_.ivf_pq_refine_rate,
-                                                       index_params_.ivf_pq_build_params,
-                                                       index_params_.ivf_pq_search_params));
+  index_ = std::make_shared<raft::neighbors::cagra::index<T, IdxT>>(
+    std::move(raft::neighbors::cagra::detail::build(handle_,
+                                                    params,
+                                                    dataset_view,
+                                                    index_params_.nn_descent_params,
+                                                    index_params_.ivf_pq_refine_rate,
+                                                    index_params_.ivf_pq_build_params,
+                                                    index_params_.ivf_pq_search_params)));
   handle_.stream_wait(stream);  // RAFT stream -> bench stream
 }
 
@@ -192,24 +197,24 @@ void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
 
     index_->update_graph(handle_, make_const_mdspan(new_graph.view()));
     // update_graph() only stores a view in the index. We need to keep the graph object alive.
-    graph_ = std::move(new_graph);
+    *graph_ = std::move(new_graph);
   }
 
   if (search_param.dataset_mem != dataset_mem_ || need_dataset_update_) {
     dataset_mem_ = search_param.dataset_mem;
 
     // First free up existing memory
-    dataset_ = make_device_matrix<T, int64_t>(handle_, 0, 0);
-    index_->update_dataset(handle_, make_const_mdspan(dataset_.view()));
+    *dataset_ = make_device_matrix<T, int64_t>(handle_, 0, 0);
+    index_->update_dataset(handle_, make_const_mdspan(dataset_->view()));
 
     // Allocate space using the correct memory resource.
     RAFT_LOG_INFO("moving dataset to new memory space: %s",
                   allocator_to_string(dataset_mem_).c_str());
 
     auto mr = get_mr(dataset_mem_);
-    raft::neighbors::cagra::detail::copy_with_padding(handle_, dataset_, input_dataset_v_, mr);
+    raft::neighbors::cagra::detail::copy_with_padding(handle_, *dataset_, *input_dataset_v_, mr);
 
-    index_->update_dataset(handle_, make_const_mdspan(dataset_.view()));
+    index_->update_dataset(handle_, make_const_mdspan(dataset_->view()));
 
     // Ideally, instead of dataset_.view(), we should pass a strided matrix view to update.
     // See Issue https://github.com/rapidsai/raft/issues/1972 for details.
@@ -225,9 +230,9 @@ void RaftCagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
 {
   // It can happen that we are re-using a previous algo object which already has
   // the dataset set. Check if we need update.
-  if (static_cast<size_t>(input_dataset_v_.extent(0)) != nrow ||
-      input_dataset_v_.data_handle() != dataset) {
-    input_dataset_v_ = make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
+  if (static_cast<size_t>(input_dataset_v_->extent(0)) != nrow ||
+      input_dataset_v_->data_handle() != dataset) {
+    *input_dataset_v_ = make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
     need_dataset_update_ = true;
   }
 }
@@ -247,7 +252,14 @@ void RaftCagra<T, IdxT>::save_to_hnswlib(const std::string& file) const
 template <typename T, typename IdxT>
 void RaftCagra<T, IdxT>::load(const std::string& file)
 {
-  index_ = raft::neighbors::cagra::deserialize<T, IdxT>(handle_, file);
+  index_ = std::make_shared<raft::neighbors::cagra::index<T, IdxT>>(
+    std::move(raft::neighbors::cagra::deserialize<T, IdxT>(handle_, file)));
+}
+
+template <typename T, typename IdxT>
+std::unique_ptr<ANN<T>> RaftCagra<T, IdxT>::copy()
+{
+  return std::make_unique<RaftCagra<T, IdxT>>(*this);  // use copy constructor
 }
 
 template <typename T, typename IdxT>
diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
index 78d44ab1ab..4bb9a89af3 100644
--- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
+++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
@@ -82,13 +82,14 @@ class RaftIvfFlatGpu : public ANN<T> {
   }
   void save(const std::string& file) const override;
   void load(const std::string&) override;
+  std::unique_ptr<ANN<T>> copy() override;
 
  private:
   // handle_ must go first to make sure it dies last and all memory allocated in pool
   configured_raft_resources handle_{};
   BuildParam index_params_;
   raft::neighbors::ivf_flat::search_params search_params_;
-  std::optional<raft::neighbors::ivf_flat::index<T, IdxT>> index_;
+  std::shared_ptr<raft::neighbors::ivf_flat::index<T, IdxT>> index_;
   int device_;
   int dimension_;
 };
@@ -96,8 +97,8 @@ class RaftIvfFlatGpu : public ANN<T> {
 template <typename T, typename IdxT>
 void RaftIvfFlatGpu<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t stream)
 {
-  index_.emplace(
-    raft::neighbors::ivf_flat::build(handle_, index_params_, dataset, IdxT(nrow), dimension_));
+  index_ = std::make_shared<raft::neighbors::ivf_flat::index<T, IdxT>>(std::move(
+    raft::neighbors::ivf_flat::build(handle_, index_params_, dataset, IdxT(nrow), dimension_)));
   handle_.stream_wait(stream);  // RAFT stream -> bench stream
 }
 
@@ -119,10 +120,17 @@ void RaftIvfFlatGpu<T, IdxT>::save(const std::string& file) const
 template <typename T, typename IdxT>
 void RaftIvfFlatGpu<T, IdxT>::load(const std::string& file)
 {
-  index_ = raft::neighbors::ivf_flat::deserialize<T, IdxT>(handle_, file);
+  index_ = std::make_shared<raft::neighbors::ivf_flat::index<T, IdxT>>(
+    std::move(raft::neighbors::ivf_flat::deserialize<T, IdxT>(handle_, file)));
   return;
 }
 
+template <typename T, typename IdxT>
+std::unique_ptr<ANN<T>> RaftIvfFlatGpu<T, IdxT>::copy()
+{
+  return std::make_unique<RaftIvfFlatGpu<T, IdxT>>(*this);  // use copy constructor
+}
+
 template <typename T, typename IdxT>
 void RaftIvfFlatGpu<T, IdxT>::search(const T* queries,
                                      int batch_size,
diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
index dfff781ce2..9a373787ac 100644
--- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
+++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
@@ -83,13 +83,14 @@ class RaftIvfPQ : public ANN<T> {
   }
   void save(const std::string& file) const override;
   void load(const std::string&) override;
+  std::unique_ptr<ANN<T>> copy() override;
 
  private:
   // handle_ must go first to make sure it dies last and all memory allocated in pool
   configured_raft_resources handle_{};
   BuildParam index_params_;
   raft::neighbors::ivf_pq::search_params search_params_;
-  std::optional<raft::neighbors::ivf_pq::index<IdxT>> index_;
+  std::shared_ptr<raft::neighbors::ivf_pq::index<IdxT>> index_;
   int dimension_;
   float refine_ratio_ = 1.0;
   raft::device_matrix_view<const T, IdxT> dataset_;
@@ -104,9 +105,9 @@ void RaftIvfPQ<T, IdxT>::save(const std::string& file) const
 template <typename T, typename IdxT>
 void RaftIvfPQ<T, IdxT>::load(const std::string& file)
 {
-  auto index_tmp = raft::neighbors::ivf_pq::index<IdxT>(handle_, index_params_, dimension_);
-  raft::runtime::neighbors::ivf_pq::deserialize(handle_, file, &index_tmp);
-  index_.emplace(std::move(index_tmp));
+  std::make_shared<raft::neighbors::ivf_pq::index<IdxT>>(handle_, index_params_, dimension_)
+    .swap(index_);
+  raft::runtime::neighbors::ivf_pq::deserialize(handle_, file, index_.get());
   return;
 }
 
@@ -114,11 +115,18 @@ template <typename T, typename IdxT>
 void RaftIvfPQ<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t stream)
 {
   auto dataset_v = raft::make_device_matrix_view<const T, IdxT>(dataset, IdxT(nrow), dim_);
-
-  index_.emplace(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v));
+  std::make_shared<raft::neighbors::ivf_pq::index<IdxT>>(
+    std::move(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v)))
+    .swap(index_);
   handle_.stream_wait(stream);  // RAFT stream -> bench stream
 }
 
+template <typename T, typename IdxT>
+std::unique_ptr<ANN<T>> RaftIvfPQ<T, IdxT>::copy()
+{
+  return std::make_unique<RaftIvfPQ<T, IdxT>>(*this);  // use copy constructor
+}
+
 template <typename T, typename IdxT>
 void RaftIvfPQ<T, IdxT>::set_search_param(const AnnSearchParam& param)
 {