From ab4787286bb50a3a233b841fee28bba6769fbc68 Mon Sep 17 00:00:00 2001 From: achirkin Date: Wed, 13 Dec 2023 18:51:43 +0100 Subject: [PATCH] Make faiss wrapper create a new sync event on index copy --- cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h | 33 +++++++++++++-------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h index ad51dd4e68..7879530753 100644 --- a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h @@ -80,6 +80,19 @@ class OmpSingleThreadScope { namespace raft::bench::ann { +struct copyable_event { + copyable_event() { RAFT_CUDA_TRY(cudaEventCreate(&value_, cudaEventDisableTiming)); } + ~copyable_event() { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(value_)); } + copyable_event(copyable_event&&) = default; + copyable_event& operator=(copyable_event&&) = default; + copyable_event(const copyable_event& res) : copyable_event{} {} + copyable_event& operator=(const copyable_event& other) = delete; + operator cudaEvent_t() const noexcept { return value_; } + + private: + cudaEvent_t value_{nullptr}; +}; + template class FaissGpu : public ANN { public: @@ -97,18 +110,15 @@ class FaissGpu : public ANN { FaissGpu(Metric metric, int dim, const BuildParam& param) : ANN(metric, dim), + gpu_resource_{std::make_shared()}, metric_type_(parse_metric_type(metric)), nlist_{param.nlist}, training_sample_fraction_{1.0 / double(param.ratio)} { static_assert(std::is_same_v, "faiss support only float type"); RAFT_CUDA_TRY(cudaGetDevice(&device_)); - RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming)); - faiss_default_stream_ = gpu_resource_.getDefaultStream(device_); } - virtual ~FaissGpu() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); } - void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) final; virtual void set_search_param(const FaissGpu::AnnSearchParam& param) {} @@ -142,7 +152,7 @@ class FaissGpu : public ANN { void stream_wait(cudaStream_t stream) const { - RAFT_CUDA_TRY(cudaEventRecord(sync_, faiss_default_stream_)); + RAFT_CUDA_TRY(cudaEventRecord(sync_, gpu_resource_->getDefaultStream(device_))); RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_)); } @@ -162,14 +172,13 @@ class FaissGpu : public ANN { * faiss::gpu::StandardGpuResources are thread-safe. * */ - mutable faiss::gpu::StandardGpuResources gpu_resource_; + mutable std::shared_ptr gpu_resource_; std::shared_ptr index_; std::shared_ptr index_refine_{nullptr}; faiss::MetricType metric_type_; int nlist_; int device_; - cudaEvent_t sync_{nullptr}; - cudaStream_t faiss_default_stream_{nullptr}; + copyable_event sync_{}; double training_sample_fraction_; std::shared_ptr search_params_; const T* dataset_; @@ -278,7 +287,7 @@ class FaissGpuIVFFlat : public FaissGpu { faiss::gpu::GpuIndexIVFFlatConfig config; config.device = this->device_; this->index_ = std::make_shared( - &(this->gpu_resource_), dim, param.nlist, this->metric_type_, config); + this->gpu_resource_.get(), dim, param.nlist, this->metric_type_, config); } void set_search_param(const typename FaissGpu::AnnSearchParam& param) override @@ -321,7 +330,7 @@ class FaissGpuIVFPQ : public FaissGpu { config.device = this->device_; this->index_ = - std::make_shared(&(this->gpu_resource_), + std::make_shared(this->gpu_resource_.get(), dim, param.nlist, param.M, @@ -383,7 +392,7 @@ class FaissGpuIVFSQ : public FaissGpu { faiss::gpu::GpuIndexIVFScalarQuantizerConfig config; config.device = this->device_; this->index_ = std::make_shared( - &(this->gpu_resource_), dim, param.nlist, qtype, this->metric_type_, true, config); + this->gpu_resource_.get(), dim, param.nlist, qtype, this->metric_type_, true, config); } void set_search_param(const typename FaissGpu::AnnSearchParam& param) override @@ -426,7 +435,7 @@ class FaissGpuFlat : public FaissGpu { faiss::gpu::GpuIndexFlatConfig config; config.device = this->device_; this->index_ = std::make_shared( - &(this->gpu_resource_), dim, this->metric_type_, config); + this->gpu_resource_.get(), dim, this->metric_type_, config); } void set_search_param(const typename FaissGpu::AnnSearchParam& param) override {