From edba189053c1b830ca2b02c63d7a9a1c2f10010f Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 26 Sep 2024 15:50:21 -0700 Subject: [PATCH 1/2] [Feat] Relative change with `bitset` API feature #2439 in raft (#350) Authors: - rhdong (https://github.com/rhdong) Approvers: - Micka (https://github.com/lowener) URL: https://github.com/rapidsai/cuvs/pull/350 --- cpp/src/neighbors/detail/knn_brute_force.cuh | 2 +- cpp/test/neighbors/brute_force_prefiltered.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 88986af7d..cf27bcde7 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -595,7 +595,7 @@ void brute_force_search_filtered( auto filter_view = raft::make_device_vector_view(filter.data(), filter.n_elements()); IdxT size_h = n_queries * n_dataset; - auto size_view = raft::make_host_scalar_view(&size_h); + auto size_view = raft::make_host_scalar_view(&size_h); raft::popc(res, filter_view, size_view, nnz_view); raft::copy(&nnz_h, nnz.data(), 1, stream); diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu index 9304ee045..ae9111ea1 100644 --- a/cpp/test/neighbors/brute_force_prefiltered.cu +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -203,7 +203,7 @@ class PrefilteredBruteForceTest auto filter_view = raft::make_device_vector_view(filter_d.data(), filter_d.size()); index_t size_h = m * n; - auto size_view = raft::make_host_scalar_view(&size_h); + auto size_view = raft::make_host_scalar_view(&size_h); set_bitmap(src, dst, bitmap, n_edges, n, stream); From b93b8f639b2e4caa8c542b643d615035a6dee754 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Fri, 27 Sep 2024 03:04:34 +0200 Subject: [PATCH 2/2] Persistent CAGRA kernel (#215) An experimental version of the single-cta CAGRA kernel that run persistently while allowing many CPU threads submit the queries in small batches very efficiently.

CAGRA throughput @ Recall = 0.94, n_queries = 1 CAGRA throughput @ Recall = 0.94, n_queries = 10

## API In the current implementation, the public API does not change. An extra parameter `persistent` is added to the `ann::cagra::search_params` (only valid when `algo == SINGLE_CTA`). The persistent kernel is managed by a global runner object in a `shared_ptr`; the first CPU thread to call the kernel spawns the runner, subsequent calls/threads only update a global "heartbeat" atomic variable (`runner_base_t::last_touch`). When there's no heartbeat in the last few seconds (`kLiveInterval`), the runner shuts down the kernel and cleans up the associated resources. An alternative solution would be to control the kernel explicitly, in a client-server style. This would be more controllable, but would require significant re-thinking of the RAFT/cuVS API. ### Synchronization behavior and CUDA streams The kernel is managed in a dedicated thread & a non-blocking stream; it's independent of any other (i.e. calling) threads. Although we pass a CUDA stream to the search function to preserve the api, this **CUDA stream is never used**; in fact, there are no CUDA API calls happening in the calling thread. All communication between the host calling thread and GPU workers happens via atomic variables. **The search function blocks the CPU thread**, i.e. it waits till the results are back before returning. ### Exceptions and safety The kernel runner object is stored in a shared pointer. Hence, it provides all the same safety guarantees as smart pointers. For example, if a C++ exception is raised in the runner thread, the kernel is stopped during the destruction of the runner/last shared pointer. It's hard to detect if something happens to the kernel or CUDA context. If the kernel does not return the results to the calling thread within the configured kernel lifetime (`persistent_lifetime` ), the calling thread abandons the request and throws an exception. The designed behavior here is that all components can gracefully shutdown within the configured kernel lifetime independently. ## Integration notes ### lightweight_uvector RMM memory resources and device buffers are not zero-cost, even when the allocation size is zero (a common pattern for conditionally-used buffers). They do at least couple `cudaGetDevice` calls during initialization. Normally, the overhead of this is negligible. However, when the number of concurrent threads is large (hundreds of threads), any CUDA call can become a bottleneck due to a single mutex guarding a critical section somewhere in the driver. To workaround this, I introduce a `lightweight_uvector` in `/detail/cagra/search_plan.cuh` for several buffers used in CAGRA kernels. This is a stripped "multi-device-unsafe" version of `rmm::uvector`: it does not check during resize/destruction whether the current device has changed since construction. We may consider putting this in a common folder to use across other RAFT/cuVS algorithms. ### Shared resource queues / ring buffers `resource_queue_t` is an atomic counter-based ring buffer used to distribute the worker resources (CTAs) and pre-allocated job descriptors across CPU I/O threads. We may consider putting this in a common public namespace in raft if we envision more uses for it. ### Persistent runner structs `launcher_t` and `persistent_runner_base_t` look like they could be abstracted from the cagra kernel and re-used in other algos. The code in its current state, however, is not ready for this. ### Adjusted benchmarks 1. I introduced a global temporary buffer for keeping the intermediate results (e.g. neighbor candidates before refinement). This is needed to avoid unnecessary allocations alongside the persistent kernel (but also positively affects performance of the original non-persistent implementation) 2. I adjusted cuvs common benchmark utils to avoid extra d2h copies and syncs during refinement. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuvs/pull/215 --- cpp/bench/ann/src/common/ann_types.hpp | 3 + cpp/bench/ann/src/common/dataset.hpp | 56 +- cpp/bench/ann/src/common/util.hpp | 133 +- .../src/cuvs/cuvs_ann_bench_param_parser.h | 10 + cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h | 55 +- cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h | 88 +- cpp/bench/ann/src/cuvs/cuvs_wrapper.h | 1 + cpp/include/cuvs/neighbors/cagra.hpp | 24 + .../detail/cagra/compute_distance-ext.cuh | 5 +- .../detail/cagra/compute_distance.hpp | 62 +- .../cagra/compute_distance_00_generate.py | 5 +- .../cagra/compute_distance_standard-impl.cuh | 30 +- .../cagra/compute_distance_standard.hpp | 14 +- .../cagra/compute_distance_vpq-impl.cuh | 42 +- .../detail/cagra/compute_distance_vpq.hpp | 9 +- cpp/src/neighbors/detail/cagra/factory.cuh | 4 +- .../detail/cagra/search_multi_cta.cuh | 2 +- .../detail/cagra/search_multi_cta_inst.cuh | 2 +- .../cagra/search_multi_cta_kernel-inl.cuh | 4 +- .../detail/cagra/search_multi_cta_kernel.cuh | 2 +- .../detail/cagra/search_multi_kernel.cuh | 4 +- .../neighbors/detail/cagra/search_plan.cuh | 89 +- .../detail/cagra/search_single_cta.cuh | 8 +- .../detail/cagra/search_single_cta_inst.cuh | 2 +- .../cagra/search_single_cta_kernel-inl.cuh | 1220 ++++++++++++++++- .../detail/cagra/search_single_cta_kernel.cuh | 2 +- examples/cpp/CMakeLists.txt | 5 + examples/cpp/src/cagra_persistent_example.cu | 258 ++++ 28 files changed, 1896 insertions(+), 243 deletions(-) create mode 100644 examples/cpp/src/cagra_persistent_example.cu diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index 4b17885c0..c2f85e539 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -35,6 +35,7 @@ enum class Mode { enum class MemoryType { kHost, kHostMmap, + kHostPinned, kDevice, }; @@ -60,6 +61,8 @@ inline auto parse_memory_type(const std::string& memory_type) -> MemoryType return MemoryType::kHost; } else if (memory_type == "mmap") { return MemoryType::kHostMmap; + } else if (memory_type == "pinned") { + return MemoryType::kHostPinned; } else if (memory_type == "device") { return MemoryType::kDevice; } else { diff --git a/cpp/bench/ann/src/common/dataset.hpp b/cpp/bench/ann/src/common/dataset.hpp index 95f1a82a2..49020fe36 100644 --- a/cpp/bench/ann/src/common/dataset.hpp +++ b/cpp/bench/ann/src/common/dataset.hpp @@ -286,7 +286,28 @@ class dataset { { switch (memory_type) { case MemoryType::kDevice: return query_set_on_gpu(); - default: return query_set(); + case MemoryType::kHost: { + auto r = query_set(); +#ifndef BUILD_CPU_ONLY + if (query_set_pinned_) { + cudaHostUnregister(const_cast(r)); + query_set_pinned_ = false; + } +#endif + return r; + } + case MemoryType::kHostPinned: { + auto r = query_set(); +#ifndef BUILD_CPU_ONLY + if (!query_set_pinned_) { + cudaHostRegister( + const_cast(r), query_set_size() * dim() * sizeof(T), cudaHostRegisterDefault); + query_set_pinned_ = true; + } +#endif + return r; + } + default: return nullptr; } } @@ -294,7 +315,27 @@ class dataset { { switch (memory_type) { case MemoryType::kDevice: return base_set_on_gpu(); - case MemoryType::kHost: return base_set(); + case MemoryType::kHost: { + auto r = base_set(); +#ifndef BUILD_CPU_ONLY + if (base_set_pinned_) { + cudaHostUnregister(const_cast(r)); + base_set_pinned_ = false; + } +#endif + return r; + } + case MemoryType::kHostPinned: { + auto r = base_set(); +#ifndef BUILD_CPU_ONLY + if (!base_set_pinned_) { + cudaHostRegister( + const_cast(r), base_set_size() * dim() * sizeof(T), cudaHostRegisterDefault); + base_set_pinned_ = true; + } +#endif + return r; + } case MemoryType::kHostMmap: return mapped_base_set(); default: return nullptr; } @@ -315,18 +356,23 @@ class dataset { mutable T* d_query_set_ = nullptr; mutable T* mapped_base_set_ = nullptr; mutable int32_t* gt_set_ = nullptr; + + mutable bool base_set_pinned_ = false; + mutable bool query_set_pinned_ = false; }; template dataset::~dataset() { - delete[] base_set_; - delete[] query_set_; - delete[] gt_set_; #ifndef BUILD_CPU_ONLY if (d_base_set_) { cudaFree(d_base_set_); } if (d_query_set_) { cudaFree(d_query_set_); } + if (base_set_pinned_) { cudaHostUnregister(base_set_); } + if (query_set_pinned_) { cudaHostUnregister(query_set_); } #endif + delete[] base_set_; + delete[] query_set_; + delete[] gt_set_; } template diff --git a/cpp/bench/ann/src/common/util.hpp b/cpp/bench/ann/src/common/util.hpp index e01e3847b..c3db2bb4b 100644 --- a/cpp/bench/ann/src/common/util.hpp +++ b/cpp/bench/ann/src/common/util.hpp @@ -198,42 +198,71 @@ inline auto get_stream_from_global_pool() -> cudaStream_t #endif } -struct result_buffer { - explicit result_buffer(size_t size, cudaStream_t stream) : size_{size}, stream_{stream} +/** The workspace buffer for use thread-locally. */ +struct ws_buffer { + explicit ws_buffer(size_t size, cudaStream_t stream) : size_{size}, stream_{stream} {} + ws_buffer() = delete; + ws_buffer(ws_buffer&&) = delete; + auto operator=(ws_buffer&&) -> ws_buffer& = delete; + ws_buffer(const ws_buffer&) = delete; + auto operator=(const ws_buffer&) -> ws_buffer& = delete; + ~ws_buffer() noexcept { - if (size_ == 0) { return; } - data_host_ = malloc(size_); #ifndef BUILD_CPU_ONLY - cudaMallocAsync(&data_device_, size_, stream_); - cudaStreamSynchronize(stream_); -#endif - } - result_buffer() = delete; - result_buffer(result_buffer&&) = delete; - auto operator=(result_buffer&&) -> result_buffer& = delete; - result_buffer(const result_buffer&) = delete; - auto operator=(const result_buffer&) -> result_buffer& = delete; - ~result_buffer() noexcept - { - if (size_ == 0) { return; } -#ifndef BUILD_CPU_ONLY - cudaFreeAsync(data_device_, stream_); - cudaStreamSynchronize(stream_); + if (data_device_ != nullptr) { + cudaFreeAsync(data_device_, stream_); + cudaStreamSynchronize(stream_); + } + if (data_host_ != nullptr) { cudaFreeHost(data_host_); } +#else + if (data_host_ != nullptr) { free(data_host_); } #endif - free(data_host_); } [[nodiscard]] auto size() const noexcept { return size_; } - [[nodiscard]] auto data(MemoryType loc) const noexcept + [[nodiscard]] auto data(MemoryType loc) const noexcept -> void* { + if (size_ == 0) { return nullptr; } switch (loc) { - case MemoryType::kDevice: return data_device_; - default: return data_host_; +#ifndef BUILD_CPU_ONLY + case MemoryType::kDevice: { + if (data_device_ == nullptr) { + cudaMallocAsync(&data_device_, size_, stream_); + cudaStreamSynchronize(stream_); + needs_cleanup_device_ = false; + } else if (needs_cleanup_device_) { + cudaMemsetAsync(data_device_, 0, size_, stream_); + cudaStreamSynchronize(stream_); + needs_cleanup_device_ = false; + } + return data_device_; + } +#endif + default: { + if (data_host_ == nullptr) { +#ifndef BUILD_CPU_ONLY + cudaMallocHost(&data_host_, size_); +#else + data_host_ = malloc(size_); +#endif + needs_cleanup_host_ = false; + } else if (needs_cleanup_host_) { + memset(data_host_, 0, size_); + needs_cleanup_host_ = false; + } + return data_host_; + } } } void transfer_data(MemoryType dst, MemoryType src) { + // The destination is overwritten and thus does not need cleanup + if (dst == MemoryType::kDevice) { + needs_cleanup_device_ = false; + } else { + needs_cleanup_host_ = false; + } auto dst_ptr = data(dst); auto src_ptr = data(src); if (dst_ptr == src_ptr) { return; } @@ -243,15 +272,25 @@ struct result_buffer { #endif } + /** Mark the buffer for reuse - it needs to be cleared to make sure the previous results are not + * leaked to the new iteration. */ + void reuse() + { + needs_cleanup_host_ = true; + needs_cleanup_device_ = true; + } + private: size_t size_{0}; - cudaStream_t stream_ = nullptr; - void* data_host_ = nullptr; - void* data_device_ = nullptr; + cudaStream_t stream_ = nullptr; + mutable void* data_host_ = nullptr; + mutable void* data_device_ = nullptr; + mutable bool needs_cleanup_host_ = false; + mutable bool needs_cleanup_device_ = false; }; namespace detail { -inline std::vector> global_result_buffer_pool(0); +inline std::vector> global_result_buffer_pool(0); inline std::mutex grp_mutex; } // namespace detail @@ -262,24 +301,47 @@ inline std::mutex grp_mutex; * This reduces the setup overhead and number of times the context is being blocked * (this is relevant if there is a persistent kernel running across multiples benchmark cases). */ -inline auto get_result_buffer_from_global_pool(size_t size) -> result_buffer& +inline auto get_result_buffer_from_global_pool(size_t size) -> ws_buffer& { auto stream = get_stream_from_global_pool(); - auto& rb = [stream, size]() -> result_buffer& { + auto& rb = [stream, size]() -> ws_buffer& { std::lock_guard guard(detail::grp_mutex); if (static_cast(detail::global_result_buffer_pool.size()) < benchmark_n_threads) { detail::global_result_buffer_pool.resize(benchmark_n_threads); } auto& rb = detail::global_result_buffer_pool[benchmark_thread_id]; - if (!rb || rb->size() < size) { rb = std::make_unique(size, stream); } + if (!rb || rb->size() < size) { + rb = std::make_unique(size, stream); + } else { + rb->reuse(); + } return *rb; }(); + return rb; +} - memset(rb.data(MemoryType::kHost), 0, size); -#ifndef BUILD_CPU_ONLY - cudaMemsetAsync(rb.data(MemoryType::kDevice), 0, size, stream); - cudaStreamSynchronize(stream); -#endif +namespace detail { +inline std::vector> global_tmp_buffer_pool(0); +inline std::mutex gtp_mutex; +} // namespace detail + +/** + * Global temporary buffer pool for use by algorithms. + * In contrast to `get_result_buffer_from_global_pool`, the content of these buffers is never + * initialized. + */ +inline auto get_tmp_buffer_from_global_pool(size_t size) -> ws_buffer& +{ + auto stream = get_stream_from_global_pool(); + auto& rb = [stream, size]() -> ws_buffer& { + std::lock_guard guard(detail::gtp_mutex); + if (static_cast(detail::global_tmp_buffer_pool.size()) < benchmark_n_threads) { + detail::global_tmp_buffer_pool.resize(benchmark_n_threads); + } + auto& rb = detail::global_tmp_buffer_pool[benchmark_thread_id]; + if (!rb || rb->size() < size) { rb = std::make_unique(size, stream); } + return *rb; + }(); return rb; } @@ -293,6 +355,7 @@ inline void reset_global_device_resources() { #ifndef BUILD_CPU_ONLY std::lock_guard guard(detail::gsp_mutex); + detail::global_tmp_buffer_pool.resize(0); detail::global_result_buffer_pool.resize(0); detail::global_stream_pool.resize(0); #endif diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h index 67f8ed39d..22f0cab6f 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h @@ -247,6 +247,16 @@ void parse_search_param(const nlohmann::json& conf, if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); } + if (conf.contains("persistent")) { param.p.persistent = conf.at("persistent"); } + if (conf.contains("persistent_lifetime")) { + param.p.persistent_lifetime = conf.at("persistent_lifetime"); + } + if (conf.contains("persistent_device_usage")) { + param.p.persistent_device_usage = conf.at("persistent_device_usage"); + } + if (conf.contains("thread_block_size")) { + param.p.thread_block_size = conf.at("thread_block_size"); + } if (conf.contains("algo")) { if (conf.at("algo") == "single_cta") { param.p.algo = cuvs::neighbors::cagra::search_algo::SINGLE_CTA; diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h index b92785943..92274e263 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h @@ -218,27 +218,46 @@ void refine_helper(const raft::resources& res, } else { auto dataset_host = raft::make_host_matrix_view( dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - auto queries_host = raft::make_host_matrix(batch_size, dim); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); + if (raft::get_device_for_address(queries.data_handle()) >= 0) { + // Queries & results are on the device - auto stream = raft::resource::get_cuda_stream(res); - raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); + auto queries_host = raft::make_host_matrix(batch_size, dim); + auto candidates_host = raft::make_host_matrix(batch_size, k0); + auto neighbors_host = raft::make_host_matrix(batch_size, k); + auto distances_host = raft::make_host_matrix(batch_size, k); - raft::resource::sync_stream(res); // wait for the queries and candidates - cuvs::neighbors::refine(res, - dataset_host, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - metric); + auto stream = raft::resource::get_cuda_stream(res); + raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream); + raft::copy( + candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); + + raft::resource::sync_stream(res); // wait for the queries and candidates + cuvs::neighbors::refine(res, + dataset_host, + queries_host.view(), + candidates_host.view(), + neighbors_host.view(), + distances_host.view(), + metric); + + raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); + raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + + } else { + // Queries & results are on the host - no device sync / copy needed + + auto queries_host = raft::make_host_matrix_view( + queries.data_handle(), batch_size, dim); + auto candidates_host = raft::make_host_matrix_view( + candidates.data_handle(), batch_size, k0); + auto neighbors_host = + raft::make_host_matrix_view(neighbors, batch_size, k); + auto distances_host = + raft::make_host_matrix_view(distances, batch_size, k); - raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + cuvs::neighbors::refine( + res, dataset_host, queries_host, candidates_host, neighbors_host, distances_host, metric); + } } } diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h index 53db717a6..9ca41cab0 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h @@ -107,13 +107,23 @@ class cuvs_cagra : public algo, public algo_gpu { int batch_size, int k, algo_base::index_type* neighbors, - float* distances) const; + float* distances, + IdxT* neighbors_idx_t) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return handle_.get_sync_stream(); } + [[nodiscard]] auto uses_stream() const noexcept -> bool override + { + // If the algorithm uses persistent kernel, the CPU has to synchronize by the end of computing + // the result. Hence it guarantees the benchmark CUDA stream is empty by the end of the + // execution. Hence we inform the benchmark to not waste the time on recording & synchronizing + // the event. + return !search_params_.persistent; + } + // to enable dataset access from GPU memory [[nodiscard]] auto get_preference() const -> algo_property override { @@ -269,7 +279,11 @@ void cuvs_cagra::set_search_dataset(const T* dataset, size_t nrow) template void cuvs_cagra::save(const std::string& file) const { - cuvs::neighbors::cagra::serialize(handle_, file, *index_); + using ds_idx_type = decltype(index_->data().n_rows()); + bool is_vpq = + dynamic_cast*>(&index_->data()) || + dynamic_cast*>(&index_->data()); + cuvs::neighbors::cagra::serialize(handle_, file, *index_, is_vpq); } template @@ -292,19 +306,18 @@ std::unique_ptr> cuvs_cagra::copy() } template -void cuvs_cagra::search_base( - const T* queries, int batch_size, int k, algo_base::index_type* neighbors, float* distances) const +void cuvs_cagra::search_base(const T* queries, + int batch_size, + int k, + algo_base::index_type* neighbors, + float* distances, + IdxT* neighbors_idx_t) const { static_assert(std::is_integral_v); static_assert(std::is_integral_v); - IdxT* neighbors_idx_t; - std::optional> neighbors_storage{std::nullopt}; if constexpr (sizeof(IdxT) == sizeof(algo_base::index_type)) { neighbors_idx_t = reinterpret_cast(neighbors); - } else { - neighbors_storage.emplace(batch_size * k, raft::resource::get_cuda_stream(handle_)); - neighbors_idx_t = neighbors_storage->data(); } auto queries_view = @@ -317,11 +330,23 @@ void cuvs_cagra::search_base( handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); if constexpr (sizeof(IdxT) != sizeof(algo_base::index_type)) { - raft::linalg::unaryOp(neighbors, - neighbors_idx_t, - batch_size * k, - raft::cast_op(), - raft::resource::get_cuda_stream(handle_)); + if (raft::get_device_for_address(neighbors) < 0 && + raft::get_device_for_address(neighbors_idx_t) < 0) { + // Both pointers on the host, let's use host-side mapping + if (uses_stream()) { + // Need to wait for GPU to finish filling source + raft::resource::sync_stream(handle_); + } + for (int i = 0; i < batch_size * k; i++) { + neighbors[i] = algo_base::index_type(neighbors_idx_t[i]); + } + } else { + raft::linalg::unaryOp(neighbors, + neighbors_idx_t, + batch_size * k, + raft::cast_op(), + raft::resource::get_cuda_stream(handle_)); + } } } @@ -329,21 +354,42 @@ template void cuvs_cagra::search( const T* queries, int batch_size, int k, algo_base::index_type* neighbors, float* distances) const { + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + constexpr bool kNeedsIoMapping = sizeof(IdxT) != sizeof(algo_base::index_type); + auto k0 = static_cast(refine_ratio_ * k); const bool disable_refinement = k0 <= static_cast(k); const raft::resources& res = handle_; + auto mem_type = + raft::get_device_for_address(neighbors) >= 0 ? MemoryType::kDevice : MemoryType::kHostPinned; + auto& tmp_buf = get_tmp_buffer_from_global_pool( + ((disable_refinement ? 0 : (sizeof(float) + sizeof(algo_base::index_type))) + + (kNeedsIoMapping ? sizeof(IdxT) : 0)) * + batch_size * k0); + auto* candidates_ptr = reinterpret_cast(tmp_buf.data(mem_type)); + auto* candidate_dists_ptr = + reinterpret_cast(candidates_ptr + (disable_refinement ? 0 : batch_size * k0)); + auto* neighbors_idx_t = + reinterpret_cast(candidate_dists_ptr + (disable_refinement ? 0 : batch_size * k0)); if (disable_refinement) { - search_base(queries, batch_size, k, neighbors, distances); + search_base(queries, batch_size, k, neighbors, distances, neighbors_idx_t); } else { + search_base(queries, batch_size, k0, candidates_ptr, candidate_dists_ptr, neighbors_idx_t); + + if (mem_type == MemoryType::kHostPinned && uses_stream()) { + // If the algorithm uses a stream to synchronize (non-persistent kernel), but the data is in + // the pinned host memory, we need to synchronize before the refinement operation to wait for + // the data being available for the host. + raft::resource::sync_stream(res); + } + + auto candidate_ixs = + raft::make_device_matrix_view( + candidates_ptr, batch_size, k0); auto queries_v = raft::make_device_matrix_view( queries, batch_size, dimension_); - auto candidate_ixs = - raft::make_device_matrix(res, batch_size, k0); - auto candidate_dists = - raft::make_device_matrix(res, batch_size, k0); - search_base( - queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); refine_helper( res, *input_dataset_v_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); } diff --git a/cpp/bench/ann/src/cuvs/cuvs_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_wrapper.h index 0954e6051..ea052533d 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_wrapper.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 5f77eb8a3..fec95b563 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -205,6 +205,30 @@ struct search_params : cuvs::neighbors::search_params { uint32_t num_random_samplings = 1; /** Bit mask used for initial random seed node selection. */ uint64_t rand_xor_mask = 0x128394; + + /** Whether to use the persistent version of the kernel (only SINGLE_CTA is supported a.t.m.) */ + bool persistent = false; + /** Persistent kernel: time in seconds before the kernel stops if no requests received. */ + float persistent_lifetime = 2; + /** + * Set the fraction of maximum grid size used by persistent kernel. + * Value 1.0 means the kernel grid size is maximum possible for the selected device. + * The value must be greater than 0.0 and not greater than 1.0. + * + * One may need to run other kernels alongside this persistent kernel. This parameter can + * be used to reduce the grid size of the persistent kernel to leave a few SMs idle. + * Note: running any other work on GPU alongside with the persistent kernel makes the setup + * fragile. + * - Running another kernel in another thread usually works, but no progress guaranteed + * - Any CUDA allocations block the context (this issue may be obscured by using pools) + * - Memory copies to not-pinned host memory may block the context + * + * Even when we know there are no other kernels working at the same time, setting + * kDeviceUsage to 1.0 surprisingly sometimes hurts performance. Proceed with care. + * If you suspect this is an issue, you can reduce this number to ~0.9 without a significant + * impact on the throughput. + */ + float persistent_device_usage = 1.0; }; /** diff --git a/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh b/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh index 8407ef055..df447d196 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance-ext.cuh @@ -496,8 +496,7 @@ using descriptor_instances = instance_selector< template auto dataset_descriptor_init(const cagra::search_params& params, const DatasetT& dataset, - cuvs::distance::DistanceType metric, - rmm::cuda_stream_view stream) + cuvs::distance::DistanceType metric) -> dataset_descriptor_host { auto [init, priority] = @@ -505,7 +504,7 @@ auto dataset_descriptor_init(const cagra::search_params& params, if (init == nullptr || priority < 0) { RAFT_FAIL("No dataset descriptor instance compiled for this parameter combination."); } - return init(params, dataset, metric, stream); + return init(params, dataset, metric); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 4bed275ab..297eb1f55 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -34,6 +34,7 @@ #include #include #include +#include namespace cuvs::neighbors::cagra::detail { @@ -222,31 +223,61 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t { * The host struct manages the lifetime of the associated device pointer and a couple parameters * affecting the search kernel launch config. * + * [Note: lazy initialization] + * Initialization of the descriptor involves allocating device memory and calling a kernel. + * This can interfere with other workloads (such as the persistent kernel) and generally adds + * overhead. To mitigate this, we don't call any CUDA api at the construction of the descriptor + * host. Instead, we postpone the initialization till the device pointer is requested. + * */ template struct dataset_descriptor_host { - using dev_descriptor_t = dataset_descriptor_base_t; + using dev_descriptor_t = dataset_descriptor_base_t; + using dd_ptr_t = std::shared_ptr; + using init_f = + std::tuple, size_t>; uint32_t smem_ws_size_in_bytes = 0; uint32_t team_size = 0; - template - dataset_descriptor_host(const DescriptorImpl& dd_host, rmm::cuda_stream_view stream) - : dev_ptr_{[stream]() { - dev_descriptor_t* p; - RAFT_CUDA_TRY(cudaMallocAsync(&p, sizeof(DescriptorImpl), stream)); - return p; - }(), - [stream](dev_descriptor_t* p) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream)); }}, + template + dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init) + : value_{std::make_tuple(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, team_size{dd_host.team_size()} { } - [[nodiscard]] auto dev_ptr() const -> const dev_descriptor_t* { return dev_ptr_.get(); } - [[nodiscard]] auto dev_ptr() -> dev_descriptor_t* { return dev_ptr_.get(); } + /** + * Return the device pointer, possibly evaluating it in the given thread. + */ + [[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) const -> const dev_descriptor_t* + { + if (std::holds_alternative(value_)) { value_ = eval(std::get(value_), stream); } + return std::get(value_).get(); + } + [[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) -> dev_descriptor_t* + { + if (std::holds_alternative(value_)) { value_ = eval(std::get(value_), stream); } + return std::get(value_).get(); + } private: - std::unique_ptr> dev_ptr_; + mutable std::variant value_; + + static auto eval(init_f init, rmm::cuda_stream_view stream) -> dd_ptr_t + { + using raft::RAFT_NAME; + auto& [fun, size] = init; + dd_ptr_t dev_ptr{ + [stream, s = size]() { + dev_descriptor_t* p; + RAFT_CUDA_TRY(cudaMallocAsync(&p, s, stream)); + return p; + }(), + [stream](dev_descriptor_t* p) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream)); }}; + fun(dev_ptr.get(), stream); + return dev_ptr; + } }; /** @@ -257,11 +288,8 @@ struct dataset_descriptor_host { * */ template -using init_desc_type = - dataset_descriptor_host (*)(const cagra::search_params&, - const DatasetT&, - cuvs::distance::DistanceType, - rmm::cuda_stream_view); +using init_desc_type = dataset_descriptor_host (*)( + const cagra::search_params&, const DatasetT&, cuvs::distance::DistanceType); /** * @brief Descriptor instance specification. diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py index 52a15e2a1..f8584c62e 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py @@ -135,15 +135,14 @@ template auto dataset_descriptor_init(const cagra::search_params& params, const DatasetT& dataset, - cuvs::distance::DistanceType metric, - rmm::cuda_stream_view stream) + cuvs::distance::DistanceType metric) -> dataset_descriptor_host {{ auto [init, priority] = descriptor_instances::select(params, dataset, metric); if (init == nullptr || priority < 0) {{ RAFT_FAIL("No dataset descriptor instance compiled for this parameter combination."); }} - return init(params, dataset, metric, stream); + return init(params, dataset, metric); }} ''' f.write(template.format(includes=includes, content=contents)) diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh index b0205508a..877d83fff 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh @@ -252,28 +252,24 @@ template dataset_descriptor_host standard_descriptor_spec::init_( - const cagra::search_params& params, - const DataT* ptr, - IndexT size, - uint32_t dim, - uint32_t ld, - rmm::cuda_stream_view stream) + const cagra::search_params& params, const DataT* ptr, IndexT size, uint32_t dim, uint32_t ld) { using desc_type = standard_dataset_descriptor_t; using base_type = typename desc_type::base_type; desc_type dd_host{nullptr, nullptr, ptr, size, dim, ld}; - host_type result{dd_host, stream}; - - standard_dataset_descriptor_init_kernel - <<<1, 1, 0, stream>>>(result.dev_ptr(), ptr, size, dim, desc_type::ld(dd_host.args)); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - return result; + return host_type{dd_host, + [=](dataset_descriptor_base_t* dev_ptr, + rmm::cuda_stream_view stream) { + standard_dataset_descriptor_init_kernel + <<<1, 1, 0, stream>>>(dev_ptr, ptr, size, dim, ld); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }}; } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp index df1b77e86..fec14d713 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard.hpp @@ -45,15 +45,13 @@ struct standard_descriptor_spec : public instance_spec template static auto init(const cagra::search_params& params, const DatasetT& dataset, - cuvs::distance::DistanceType metric, - rmm::cuda_stream_view stream) -> host_type + cuvs::distance::DistanceType metric) -> host_type { return init_(params, dataset.view().data_handle(), IndexT(dataset.n_rows()), dataset.dim(), - dataset.stride(), - stream); + dataset.stride()); } template @@ -69,12 +67,8 @@ struct standard_descriptor_spec : public instance_spec } private: - static dataset_descriptor_host init_(const cagra::search_params& params, - const DataT* ptr, - IndexT size, - uint32_t dim, - uint32_t ld, - rmm::cuda_stream_view stream); + static dataset_descriptor_host init_( + const cagra::search_params& params, const DataT* ptr, IndexT size, uint32_t dim, uint32_t ld); }; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index 86c592502..6caa173f2 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -421,8 +421,7 @@ vpq_descriptor_spec<<<1, 1, 0, stream>>>(result.dev_ptr(), - encoded_dataset_ptr, - encoded_dataset_dim, - vq_code_book_ptr, - pq_code_book_ptr, - size, - dim); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - return result; + return host_type{dd_host, + [=](dataset_descriptor_base_t* dev_ptr, + rmm::cuda_stream_view stream) { + vpq_dataset_descriptor_init_kernel + <<<1, 1, 0, stream>>>(dev_ptr, + encoded_dataset_ptr, + encoded_dataset_dim, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }}; } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 378d2943e..4f7d24f17 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -57,8 +57,7 @@ struct vpq_descriptor_spec : public instance_spec { template static auto init(const cagra::search_params& params, const DatasetT& dataset, - cuvs::distance::DistanceType metric, - rmm::cuda_stream_view stream) -> host_type + cuvs::distance::DistanceType metric) -> host_type { return init_(params, dataset.data.data_handle(), @@ -66,8 +65,7 @@ struct vpq_descriptor_spec : public instance_spec { dataset.vq_code_book.data_handle(), dataset.pq_code_book.data_handle(), IndexT(dataset.n_rows()), - dataset.dim(), - stream); + dataset.dim()); } template @@ -93,8 +91,7 @@ struct vpq_descriptor_spec : public instance_spec { const CodebookT* vq_code_book_ptr, const CodebookT* pq_code_book_ptr, IndexT size, - uint32_t dim, - rmm::cuda_stream_view stream); + uint32_t dim); }; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 1c99f72f7..2f201de3b 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -168,8 +168,8 @@ auto dataset_descriptor_init_with_cache(const raft::resources& res, ->value; std::shared_ptr desc{nullptr}; if (!cache.get(key, &desc)) { - desc = std::make_shared(std::move(dataset_descriptor_init( - params, dataset, metric, raft::resource::get_cuda_stream(res)))); + desc = std::make_shared( + std::move(dataset_descriptor_init(params, dataset, metric))); cache.set(key, desc); } return *desc; diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 9bcccd9f9..0003f2495 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -209,7 +209,7 @@ struct search : public search_plan_impl( \ - const dataset_descriptor_base_t* dataset_desc, \ + const dataset_descriptor_host& dataset_desc, \ raft::device_matrix_view graph, \ IndexT* topk_indices_ptr, \ DistanceT* topk_distances_ptr, \ diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index dd74ba44b..4dfc46256 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -413,7 +413,7 @@ struct search_kernel_config { }; template -void select_and_run(const dataset_descriptor_base_t* dataset_desc, +void select_and_run(const dataset_descriptor_host& dataset_desc, raft::device_matrix_view graph, IndexT* topk_indices_ptr, // [num_queries, topk] DistanceT* topk_distances_ptr, // [num_queries, topk] @@ -455,7 +455,7 @@ void select_and_run(const dataset_descriptor_base_t* d kernel<<>>(topk_indices_ptr, topk_distances_ptr, - dataset_desc, + dataset_desc.dev_ptr(stream), queries_ptr, graph.data_handle(), graph.extent(1), diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh index 1ef35f947..1a1dcd579 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh @@ -22,7 +22,7 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search { template -void select_and_run(const dataset_descriptor_base_t* dataset_desc, +void select_and_run(const dataset_descriptor_host& dataset_desc, raft::device_matrix_view graph, IndexT* topk_indices_ptr, // [num_queries, topk] DistanceT* topk_distances_ptr, // [num_queries, topk] diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index 7b3ecabf3..0daae17b3 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -175,7 +175,7 @@ void random_pickup(const dataset_descriptor_host& data num_queries); random_pickup_kernel<<>>( - dataset_desc.dev_ptr(), + dataset_desc.dev_ptr(cuda_stream), queries_ptr, num_pickup, num_distilation, @@ -410,7 +410,7 @@ void compute_distance_to_child_nodes( parent_distance_ptr, lds, search_width, - dataset_desc.dev_ptr(), + dataset_desc.dev_ptr(cuda_stream), neighbor_graph_ptr, graph_degree, query_ptr, diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 16864ed19..6ecbbc2e8 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -32,8 +32,81 @@ #include #include +#include +#include +#include + namespace cuvs::neighbors::cagra::detail { +/** + * A lightweight version of rmm::device_uvector. + * This version avoids calling cudaSetDevice / cudaGetDevice, and therefore it is required that + * the current cuda device does not change during the lifetime of this object. This is expected + * to be useful in multi-threaded scenarios where we want to minimize overhead due to + * thread sincronization during cuda API calls. + * If the size stays at zero, this struct never calls any CUDA driver / RAFT resource functions. + */ +template +struct lightweight_uvector { + private: + using raft_res_type = const raft::resources*; + using rmm_res_type = std::tuple; + static constexpr size_t kAlign = 256; + + std::variant res_; + T* ptr_; + size_t size_; + + public: + explicit lightweight_uvector(const raft::resources& res) : res_(&res), ptr_{nullptr}, size_{0} {} + + [[nodiscard]] auto data() noexcept -> T* { return ptr_; } + [[nodiscard]] auto data() const noexcept -> const T* { return ptr_; } + [[nodiscard]] auto size() const noexcept -> size_t { return size_; } + + void resize(size_t new_size) + { + if (new_size == size_) { return; } + if (std::holds_alternative(res_)) { + auto& h = std::get(res_); + res_ = rmm_res_type{raft::resource::get_workspace_resource(*h), + raft::resource::get_cuda_stream(*h)}; + } + auto& [r, s] = std::get(res_); + T* new_ptr = nullptr; + if (new_size > 0) { + new_ptr = reinterpret_cast(r.allocate_async(new_size * sizeof(T), kAlign, s)); + } + auto copy_size = std::min(size_, new_size); + if (copy_size > 0) { + cudaMemcpyAsync(new_ptr, ptr_, copy_size * sizeof(T), cudaMemcpyDefault, s); + } + if (size_ > 0) { r.deallocate_async(ptr_, size_ * sizeof(T), kAlign, s); } + ptr_ = new_ptr; + size_ = new_size; + } + + void resize(size_t new_size, rmm::cuda_stream_view stream) + { + if (new_size == size_) { return; } + if (std::holds_alternative(res_)) { + auto& h = std::get(res_); + res_ = rmm_res_type{raft::resource::get_workspace_resource(*h), stream}; + } else { + std::get(std::get(res_)) = stream; + } + resize(new_size); + } + + ~lightweight_uvector() noexcept + { + if (size_ > 0) { + auto& [r, s] = std::get(res_); + r.deallocate_async(ptr_, size_ * sizeof(T), kAlign, s); + } + } +}; + struct search_plan_impl_base : public search_params { int64_t dim; int64_t graph_degree; @@ -75,9 +148,9 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t topk; uint32_t num_seeds; - rmm::device_uvector hashmap; - rmm::device_uvector num_executed_iterations; // device or managed? - rmm::device_uvector dev_seed; + lightweight_uvector hashmap; + lightweight_uvector num_executed_iterations; // device or managed? + lightweight_uvector dev_seed; const dataset_descriptor_host& dataset_desc; search_plan_impl(raft::resources const& res, @@ -87,16 +160,18 @@ struct search_plan_impl : public search_plan_impl_base { int64_t graph_degree, uint32_t topk) : search_plan_impl_base(params, dim, graph_degree, topk), - hashmap(0, raft::resource::get_cuda_stream(res)), - num_executed_iterations(0, raft::resource::get_cuda_stream(res)), - dev_seed(0, raft::resource::get_cuda_stream(res)), + hashmap(res), + num_executed_iterations(res), + dev_seed(res), num_seeds(0), dataset_desc(dataset_desc) { adjust_search_params(); check_params(); calc_hashmap_params(res); - num_executed_iterations.resize(max_queries, raft::resource::get_cuda_stream(res)); + if (!persistent) { // Persistent kernel does not provide this functionality + num_executed_iterations.resize(max_queries, raft::resource::get_cuda_stream(res)); + } RAFT_LOG_DEBUG("# algo = %d", static_cast(algo)); } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 4abed6760..2bed19009 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -37,8 +37,6 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp -#include - #include #include #include @@ -199,8 +197,8 @@ struct search : search_plan_impl { } RAFT_LOG_DEBUG("# smem_size: %u", smem_size); hashmap_size = 0; - if (small_hash_bitlen == 0) { - hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); + if (small_hash_bitlen == 0 && !this->persistent) { + hashmap_size = max_queries * hashmap::get_size(hash_bitlen); hashmap.resize(hashmap_size, raft::resource::get_cuda_stream(res)); } RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); @@ -218,7 +216,7 @@ struct search : search_plan_impl { SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = raft::resource::get_cuda_stream(res); - select_and_run(dataset_desc.dev_ptr(), + select_and_run(dataset_desc, graph, result_indices_ptr, result_distances_ptr, diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh index 26ca7b672..f734b0582 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh @@ -23,7 +23,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search { #define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ template void select_and_run( \ - const dataset_descriptor_base_t* dataset_desc, \ + const dataset_descriptor_host& dataset_desc, \ raft::device_matrix_view graph, \ IndexT* topk_indices_ptr, \ DistanceT* topk_distances_ptr, \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index d10313c5b..21a0f6bb2 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -39,21 +39,32 @@ #include "../ann_utils.cuh" #include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp +#include +#include #include +#include +#include + +#include +#include #include +#include #include -#include +#include #include #include +#include #include #include +#include +#include #include namespace cuvs::neighbors::cagra::detail { namespace single_cta_search { +using raft::RAFT_NAME; // TODO: this is required for RAFT_LOG_XXX messages. // #define _CLK_BREAKDOWN @@ -463,7 +474,7 @@ template -RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( +__device__ void search_core( typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, top_k] typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] const std::uint32_t top_k, @@ -485,6 +496,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( const std::uint32_t hash_bitlen, const std::uint32_t small_hash_bitlen, const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id, SAMPLE_FILTER_T sample_filter) { using LOAD_T = device::LOAD_128BIT_T; @@ -493,8 +505,6 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - const auto query_id = blockIdx.y; - #ifdef _CLK_BREAKDOWN std::uint64_t clk_init = 0; std::uint64_t clk_compute_1st_distance = 0; @@ -552,7 +562,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( if (small_hash_bitlen) { local_visited_hashmap_ptr = visited_hash_buffer; } else { - local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * blockIdx.y); } hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); __syncthreads(); @@ -796,37 +806,292 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( #endif } -template +template +RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( + typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, top_k] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] + const std::uint32_t top_k, + const DATASET_DESCRIPTOR_T* dataset_desc, + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + SAMPLE_FILTER_T sample_filter) +{ + const auto query_id = blockIdx.y; + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + dataset_desc, + queries_ptr, + knn_graph, + graph_degree, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + sample_filter); +} + +// To make sure we avoid false sharing on both CPU and GPU, we enforce cache line size to the +// maximum of the two. +// This makes sync atomic significantly faster. +constexpr size_t kCacheLineBytes = 64; + +constexpr uint32_t kMaxJobsNum = 8192; +constexpr uint32_t kMaxWorkersNum = 4096; +constexpr uint32_t kMaxWorkersPerThread = 256; +constexpr uint32_t kSoftMaxWorkersPerThread = 16; + +template +struct alignas(kCacheLineBytes) job_desc_t { + using index_type = typename DATASET_DESCRIPTOR_T::INDEX_T; + using distance_type = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using data_type = typename DATASET_DESCRIPTOR_T::DATA_T; + // The algorithm input parameters + struct value_t { + index_type* result_indices_ptr; // [num_queries, top_k] + distance_type* result_distances_ptr; // [num_queries, top_k] + const data_type* queries_ptr; // [num_queries, dataset_dim] + uint32_t top_k; + uint32_t n_queries; + }; + using blob_elem_type = uint4; + constexpr static inline size_t kBlobSize = + raft::div_rounding_up_safe(sizeof(value_t), sizeof(blob_elem_type)); + // Union facilitates loading the input by a warp in a single request + union input_t { + blob_elem_type blob[kBlobSize]; // NOLINT + value_t value; + } input; + // Last thread triggers this flag. + cuda::atomic completion_flag; +}; + +struct alignas(kCacheLineBytes) worker_handle_t { + using handle_t = uint64_t; + struct value_t { + uint32_t desc_id; + uint32_t query_id; + }; + union data_t { + handle_t handle; + value_t value; + }; + cuda::atomic data; +}; +static_assert(sizeof(worker_handle_t::value_t) == sizeof(worker_handle_t::handle_t)); +static_assert( + cuda::atomic::is_always_lock_free); + +constexpr worker_handle_t::handle_t kWaitForWork = std::numeric_limits::max(); +constexpr worker_handle_t::handle_t kNoMoreWork = kWaitForWork - 1; + +constexpr auto is_worker_busy(worker_handle_t::handle_t h) -> bool +{ + return (h != kWaitForWork) && (h != kNoMoreWork); +} + +template +RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p( + const DATASET_DESCRIPTOR_T* dataset_desc, + worker_handle_t* worker_handles, + job_desc_t* job_descriptors, + uint32_t* completion_counters, + const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + SAMPLE_FILTER_T sample_filter) +{ + using job_desc_type = job_desc_t; + __shared__ typename job_desc_type::input_t job_descriptor; + __shared__ worker_handle_t::data_t worker_data; + + auto& worker_handle = worker_handles[blockIdx.y].data; + uint32_t job_ix; + + while (true) { + // wait the writing phase + if (threadIdx.x == 0) { + worker_handle_t::data_t worker_data_local; + do { + worker_data_local = worker_handle.load(cuda::memory_order_relaxed); + } while (worker_data_local.handle == kWaitForWork); + if (worker_data_local.handle != kNoMoreWork) { + worker_handle.store({kWaitForWork}, cuda::memory_order_relaxed); + } + job_ix = worker_data_local.value.desc_id; + cuda::atomic_thread_fence(cuda::memory_order_acquire, cuda::thread_scope_system); + worker_data = worker_data_local; + } + if (threadIdx.x < raft::WarpSize) { + // Sync one warp and copy descriptor data + static_assert(job_desc_type::kBlobSize <= raft::WarpSize); + job_ix = raft::shfl(job_ix, 0); + if (threadIdx.x < job_desc_type::kBlobSize && job_ix < kMaxJobsNum) { + job_descriptor.blob[threadIdx.x] = job_descriptors[job_ix].input.blob[threadIdx.x]; + } + } + __syncthreads(); + if (worker_data.handle == kNoMoreWork) { break; } + + // reading phase + auto* result_indices_ptr = job_descriptor.value.result_indices_ptr; + auto* result_distances_ptr = job_descriptor.value.result_distances_ptr; + auto* queries_ptr = job_descriptor.value.queries_ptr; + auto top_k = job_descriptor.value.top_k; + auto n_queries = job_descriptor.value.n_queries; + auto query_id = worker_data.value.query_id; + + // work phase + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + dataset_desc, + queries_ptr, + knn_graph, + graph_degree, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + sample_filter); + + // make sure all writes are visible even for the host + // (e.g. when result buffers are in pinned memory) + cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_system); + + // arrive to mark the end of the work phase + __syncthreads(); + if (threadIdx.x == 0) { + auto completed_count = atomicInc(completion_counters + job_ix, n_queries - 1) + 1; + if (completed_count >= n_queries) { + job_descriptors[job_ix].completion_flag.store(true, cuda::memory_order_relaxed); + } + } + } +} + +template +auto dispatch_kernel = []() { + if constexpr (Persistent) { + return search_kernel_p; + } else { + return search_kernel; + } +}(); + +template struct search_kernel_config { - using kernel_t = decltype(&search_kernel<64, 64, 0, DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>); + using kernel_t = + decltype(dispatch_kernel); template static auto choose_search_kernel(unsigned itopk_size) -> kernel_t { if (itopk_size <= 64) { - return search_kernel<64, - MAX_CANDIDATES, - USE_BITONIC_SORT, - DATASET_DESCRIPTOR_T, - SAMPLE_FILTER_T>; + return dispatch_kernel; } else if (itopk_size <= 128) { - return search_kernel<128, - MAX_CANDIDATES, - USE_BITONIC_SORT, - DATASET_DESCRIPTOR_T, - SAMPLE_FILTER_T>; + return dispatch_kernel; } else if (itopk_size <= 256) { - return search_kernel<256, - MAX_CANDIDATES, - USE_BITONIC_SORT, - DATASET_DESCRIPTOR_T, - SAMPLE_FILTER_T>; + return dispatch_kernel; } else if (itopk_size <= 512) { - return search_kernel<512, - MAX_CANDIDATES, - USE_BITONIC_SORT, - DATASET_DESCRIPTOR_T, - SAMPLE_FILTER_T>; + return dispatch_kernel; } THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); } @@ -846,9 +1111,19 @@ struct search_kernel_config { // Radix-based topk is used constexpr unsigned max_candidates = 32; // to avoid build failure if (itopk_size <= 256) { - return search_kernel<256, max_candidates, 0, DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>; + return dispatch_kernel; } else if (itopk_size <= 512) { - return search_kernel<512, max_candidates, 0, DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>; + return dispatch_kernel; } } THROW("No kernel for parametels itopk_size %u, num_itopk_candidates %u", @@ -857,8 +1132,797 @@ struct search_kernel_config { } }; +/** + * @brief Resource queue + * + * @tparam T the element type + * @tparam Size the maximum capacity of the queue (power-of-two) + * @tparam Empty a special element value designating an empty queue slot. NB: storing `Empty` is UB. + * + * A shared atomic ring buffer based queue optimized for throughput when bottlenecked on `pop` + * operation. + * + * @code{.cpp} + * // allocate the queue + * resource_queue_t resource_ids; + * + * // store couple values + * resource_ids.push(42); + * resource_ids.push(7); + * + * // wait to get the value from the queue + * auto id_x = resource_ids.pop().wait(); + * + * // stand in line to get the value from the queue, but don't wait + * auto ticket_y = resource_ids.pop(); + * // do other stuff and check if the value is available + * int32_t id_y; + * while (!ticket_y.test(id_y)) { + * do_some_important_business(...); + * std::this_thread::sleep_for(std::chrono::microseconds(10); + * } + * // `id_y` is set by now and `ticket_y.wait()` won't block anymore + * assert(ticket_y.wait() == id_y); + * @endcode + */ +template ::max()> +struct alignas(kCacheLineBytes) resource_queue_t { + using value_type = T; + static constexpr uint32_t kSize = Size; + static constexpr value_type kEmpty = Empty; + static_assert(cuda::std::atomic::is_always_lock_free, + "The value type must be lock-free."); + static_assert(raft::is_a_power_of_two(kSize), "The size must be a power-of-two for efficiency."); + static constexpr uint32_t kElemsPerCacheLine = + raft::div_rounding_up_safe(kCacheLineBytes, sizeof(value_type)); + /* [Note: cache-friendly indexing] + To avoid false sharing, the queue pushes and pops values not sequentially, but with an + increment that is larger than the cache line size. + Hence we introduce the `kCounterIncrement > kCacheLineBytes`. + However, to make sure all indices are used, we choose the increment to be coprime with the + buffer size. We also require that the buffer size is a power-of-two for two reasons: + 1) Fast modulus operation - reduces to binary `and` (with `kCounterLocMask`). + 2) Easy to ensure GCD(kCounterIncrement, kSize) == 1 by construction + (see the definition below). + */ + static constexpr uint32_t kCounterIncrement = raft::bound_by_power_of_two(kElemsPerCacheLine) + 1; + static constexpr uint32_t kCounterLocMask = kSize - 1; + // These props hold by design, but we add them here as a documentation and a sanity check. + static_assert( + kCounterIncrement * sizeof(value_type) >= kCacheLineBytes, + "The counter increment should be larger than the cache line size to avoid false sharing."); + static_assert( + std::gcd(kCounterIncrement, kSize) == 1, + "The counter increment and the size must be coprime to allow using all of the queue slots."); + + static constexpr auto kMemOrder = cuda::std::memory_order_relaxed; + + explicit resource_queue_t(uint32_t capacity = Size) noexcept : capacity_{capacity} + { + head_.store(0, kMemOrder); + tail_.store(0, kMemOrder); + for (uint32_t i = 0; i < kSize; i++) { + buf_[i].store(kEmpty, kMemOrder); + } + } + + /** Nominal capacity of the queue. */ + [[nodiscard]] auto capacity() const { return capacity_; } + + /** This does not affect the queue behavior, but merely declares a nominal capacity. */ + void set_capacity(uint32_t capacity) { capacity_ = capacity; } + + /** + * A slot in the queue to take the value from. + * Once it's obtained, the corresponding value in the queue is lost for other users. + */ + struct promise_t { + explicit promise_t(cuda::std::atomic& loc) : loc_{loc}, val_{Empty} {} + ~promise_t() noexcept { wait(); } + + auto test() noexcept -> bool + { + if (val_ != Empty) { return true; } + val_ = loc_.exchange(kEmpty, kMemOrder); + return val_ != Empty; + } + + auto test(value_type& e) noexcept -> bool + { + if (test()) { + e = val_; + return true; + } + return false; + } + + auto wait() noexcept -> value_type + { + if (val_ == Empty) { + // [HOT SPOT] + // Optimize for the case of contention: expect the loc is empty. + do { + loc_.wait(kEmpty, kMemOrder); + val_ = loc_.exchange(kEmpty, kMemOrder); + } while (val_ == kEmpty); + } + return val_; + } + + private: + cuda::std::atomic& loc_; + value_type val_; + }; + + void push(value_type x) noexcept + { + auto& loc = buf_[head_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; + /* [NOT A HOT SPOT] + We expect there's always enough place in the queue to push the item, + but also we expect a few pop waiters - notify them the data is available. + */ + value_type e = kEmpty; + while (!loc.compare_exchange_weak(e, x, kMemOrder, kMemOrder)) { + e = kEmpty; + } + loc.notify_one(); + } + + auto pop() noexcept -> promise_t + { + auto& loc = buf_[tail_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; + return promise_t{loc}; + } + + private: + alignas(kCacheLineBytes) cuda::std::atomic head_{}; + alignas(kCacheLineBytes) cuda::std::atomic tail_{}; + alignas(kCacheLineBytes) std::array, kSize> buf_{}; + alignas(kCacheLineBytes) uint32_t capacity_; +}; + +/** Primitive fixed-size deque for single-threaded use. */ +template +struct local_deque_t { + explicit local_deque_t(uint32_t size) : store_(size) {} + + [[nodiscard]] auto capacity() const -> uint32_t { return store_.size(); } + [[nodiscard]] auto size() const -> uint32_t { return end_ - start_; } + + void push_back(T x) { store_[end_++ % capacity()] = x; } + + void push_front(T x) + { + if (start_ == 0) { + start_ += capacity(); + end_ += capacity(); + } + store_[--start_ % capacity()] = x; + } + + // NB: unsafe functions - do not check if the queue is full/empty. + auto pop_back() -> T { return store_[--end_ % capacity()]; } + auto pop_front() -> T { return store_[start_++ % capacity()]; } + + auto try_push_back(T x) -> bool + { + if (size() >= capacity()) { return false; } + push_back(x); + return true; + } + + auto try_push_front(T x) -> bool + { + if (size() >= capacity()) { return false; } + push_front(x); + return true; + } + + auto try_pop_back(T& x) -> bool + { + if (start_ >= end_) { return false; } + x = pop_back(); + return true; + } + + auto try_pop_front(T& x) -> bool + { + if (start_ >= end_) { return false; } + x = pop_front(); + return true; + } + + private: + std::vector store_; + uint32_t start_{0}; + uint32_t end_{0}; +}; + +struct persistent_runner_base_t { + using job_queue_type = resource_queue_t; + using worker_queue_type = resource_queue_t; + rmm::mr::pinned_host_memory_resource worker_handles_mr; + rmm::mr::pinned_host_memory_resource job_descriptor_mr; + rmm::mr::cuda_memory_resource device_mr; + cudaStream_t stream{}; + job_queue_type job_queue{}; + worker_queue_type worker_queue{}; + // This should be large enough to make the runner live through restarts of the benchmark cases. + // Otherwise, the benchmarks slowdown significantly. + std::chrono::milliseconds lifetime; + + persistent_runner_base_t(float persistent_lifetime) + : lifetime(size_t(persistent_lifetime * 1000)), job_queue(), worker_queue() + { + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } + virtual ~persistent_runner_base_t() noexcept { cudaStreamDestroy(stream); }; +}; + +struct alignas(kCacheLineBytes) launcher_t { + using job_queue_type = persistent_runner_base_t::job_queue_type; + using worker_queue_type = persistent_runner_base_t::worker_queue_type; + using pending_reads_queue_type = local_deque_t; + using completion_flag_type = cuda::atomic; + + pending_reads_queue_type pending_reads; + job_queue_type& job_ids; + worker_queue_type& idle_worker_ids; + worker_handle_t* worker_handles; + uint32_t job_id; + completion_flag_type* completion_flag; + bool all_done = false; + + /* [Note: sleeping] + When the number of threads is greater than the number of cores, the threads start to fight for + the CPU time, which reduces the throughput. + To ease the competition, we track the expected GPU latency and let a thread sleep for some + time, and only start to spin when it's about a time to get the result. + */ + static inline constexpr auto kDefaultLatency = std::chrono::nanoseconds(50000); + /* This is the base for computing maximum time a thread is allowed to sleep. */ + static inline constexpr auto kMaxExpectedLatency = + kDefaultLatency * std::max(10, kMaxJobsNum / 128); + static inline thread_local auto expected_latency = kDefaultLatency; + const std::chrono::time_point start; + std::chrono::time_point now; + const int64_t pause_factor; + int pause_count = 0; + /** + * Beyond this threshold, the launcher (calling thread) does not wait for the results anymore and + * throws an exception. + */ + std::chrono::time_point deadline; + + template + launcher_t(job_queue_type& job_ids, + worker_queue_type& idle_worker_ids, + worker_handle_t* worker_handles, + uint32_t n_queries, + std::chrono::milliseconds max_wait_time, + RecordWork record_work) + : pending_reads{std::min(n_queries, kMaxWorkersPerThread)}, + job_ids{job_ids}, + idle_worker_ids{idle_worker_ids}, + worker_handles{worker_handles}, + job_id{job_ids.pop().wait()}, + completion_flag{record_work(job_id)}, + start{std::chrono::system_clock::now()}, + pause_factor{calc_pause_factor(n_queries)}, + now{start}, + deadline{start + max_wait_time + expected_latency} + { + // Wait for the first worker and submit the query immediately. + submit_query(idle_worker_ids.pop().wait(), 0); + // Submit the rest of the queries in the batch + for (uint32_t i = 1; i < n_queries; i++) { + auto promised_worker = idle_worker_ids.pop(); + uint32_t worker_id; + while (!promised_worker.test(worker_id)) { + if (pending_reads.try_pop_front(worker_id)) { + bool returned_some = false; + for (bool keep_returning = true; keep_returning;) { + if (try_return_worker(worker_id)) { + keep_returning = pending_reads.try_pop_front(worker_id); + returned_some = true; + } else { + pending_reads.push_front(worker_id); + keep_returning = false; + } + } + if (!returned_some) { pause(); } + } else { + // Calmly wait for the promised worker instead of spinning. + worker_id = promised_worker.wait(); + break; + } + } + pause_count = 0; // reset the pause behavior + submit_query(worker_id, i); + // Try to not hold too many workers in one thread + if (i >= kSoftMaxWorkersPerThread && pending_reads.try_pop_front(worker_id)) { + if (!try_return_worker(worker_id)) { pending_reads.push_front(worker_id); } + } + } + } + + inline ~launcher_t() noexcept // NOLINT + { + // bookkeeping: update the expected latency to wait more efficiently later + constexpr size_t kWindow = 100; // moving average memory + expected_latency = std::min( + ((kWindow - 1) * expected_latency + now - start) / kWindow, kMaxExpectedLatency); + + // Try to gracefully cleanup the queue resources if the launcher is being destructed after an + // exception. + if (job_id != job_queue_type::kEmpty) { job_ids.push(job_id); } + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + idle_worker_ids.push(worker_id); + } + } + + inline void submit_query(uint32_t worker_id, uint32_t query_id) + { + worker_handles[worker_id].data.store(worker_handle_t::data_t{.value = {job_id, query_id}}, + cuda::memory_order_relaxed); + + while (!pending_reads.try_push_back(worker_id)) { + // The only reason pending_reads cannot push is that the queue is full. + // It's local, so we must pop and wait for the returned worker to finish its work. + auto pending_worker_id = pending_reads.pop_front(); + while (!try_return_worker(pending_worker_id)) { + pause(); + } + } + pause_count = 0; // reset the pause behavior + } + + /** Check if the worker has finished the work; if so, return it to the shared pool. */ + inline auto try_return_worker(uint32_t worker_id) -> bool + { + // Use the cached `all_done` - makes sense when called from the `wait()` routine. + if (all_done || + !is_worker_busy(worker_handles[worker_id].data.load(cuda::memory_order_relaxed).handle)) { + idle_worker_ids.push(worker_id); + return true; + } else { + return false; + } + } + + /** Check if all workers finished their work. */ + inline auto is_all_done() + { + // Cache the result of the check to avoid doing unnecessary atomic loads. + if (all_done) { return true; } + all_done = completion_flag->load(cuda::memory_order_relaxed); + return all_done; + } + + /** The launcher shouldn't attempt to wait past the returned time. */ + [[nodiscard]] inline auto sleep_limit() const + { + constexpr auto kMinWakeTime = std::chrono::nanoseconds(10000); + constexpr double kSleepLimit = 0.6; + return start + expected_latency * kSleepLimit - kMinWakeTime; + } + + /** + * When the latency is much larger than expected, it's a sign that there is a thread contention. + * Then we switch to sleeping instead of waiting to give the cpu cycles to other threads. + */ + [[nodiscard]] inline auto overtime_threshold() const + { + constexpr auto kOvertimeFactor = 3; + return start + expected_latency * kOvertimeFactor; + } + + /** + * Calculate the fraction of time can be spent sleeping in a single call to `pause()`. + * Naturally it depends on the number of queries in a batch and the number of parallel workers. + */ + [[nodiscard]] inline auto calc_pause_factor(uint32_t n_queries) const -> uint32_t + { + constexpr uint32_t kMultiplier = 10; + return kMultiplier * raft::div_rounding_up_safe(n_queries, idle_worker_ids.capacity()); + } + + /** Wait a little bit (called in a loop). */ + inline void pause() + { + // Don't sleep this many times hoping for smoother run + constexpr auto kSpinLimit = 3; + // It doesn't make much sense to slee less than this + constexpr auto kPauseTimeMin = std::chrono::nanoseconds(1000); + // Bound sleeping time + constexpr auto kPauseTimeMax = std::chrono::nanoseconds(50000); + if (pause_count++ < kSpinLimit) { + std::this_thread::yield(); + return; + } + now = std::chrono::system_clock::now(); + auto pause_time_base = std::max(now - start, expected_latency); + auto pause_time = std::clamp(pause_time_base / pause_factor, kPauseTimeMin, kPauseTimeMax); + if (now + pause_time < sleep_limit()) { + // It's too early: sleep for a bit + std::this_thread::sleep_for(pause_time); + } else if (now <= overtime_threshold()) { + // It's about time to check the results, don't sleep + std::this_thread::yield(); + } else if (now <= deadline) { + // Too late; perhaps the system is too busy - sleep again + std::this_thread::sleep_for(pause_time); + } else { + // Missed the deadline: throw an exception + throw raft::exception( + "The calling thread didn't receive the results from the persistent CAGRA kernel within the " + "expected kernel lifetime. Here are possible reasons of this failure:\n" + " (1) `persistent_lifetime` search parameter is too small - increase it;\n" + " (2) there is other work being executed on the same device and the kernel failed to " + "progress - decreasing `persistent_device_usage` may help (but not guaranteed);\n" + " (3) there is a bug in the implementation - please report it to cuVS team."); + } + } + + /** Wait for all work to finish and don't forget to return the workers to the shared pool. */ + inline void wait() + { + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + while (!try_return_worker(worker_id)) { + if (!is_all_done()) { pause(); } + } + } + pause_count = 0; // reset the pause behavior + // terminal state, should be engaged only after the `pending_reads` is empty + // and `queries_submitted == n_queries` + now = std::chrono::system_clock::now(); + while (!is_all_done()) { + auto till_time = sleep_limit(); + if (now < till_time) { + std::this_thread::sleep_until(till_time); + now = std::chrono::system_clock::now(); + } else { + pause(); + } + } + + // Return the job descriptor + job_ids.push(job_id); + job_id = job_queue_type::kEmpty; + } +}; + +template +struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_base_t { + using descriptor_base_type = dataset_descriptor_base_t; + using index_type = IndexT; + using distance_type = DistanceT; + using data_type = DataT; + using kernel_config_type = search_kernel_config; + using kernel_type = typename kernel_config_type::kernel_t; + using job_desc_type = job_desc_t; + kernel_type kernel; + uint32_t block_size; + dataset_descriptor_host dd_host; + rmm::device_uvector worker_handles; + rmm::device_uvector job_descriptors; + rmm::device_uvector completion_counters; + rmm::device_uvector hashmap; + std::atomic> last_touch; + uint64_t param_hash; + + /** + * Calculate the hash of the parameters to detect if they've changed across the calls. + * NB: this must have the same argument types as the constructor. + */ + static inline auto calculate_parameter_hash( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view graph, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage) -> uint64_t + { + return uint64_t(graph.data_handle()) ^ dataset_desc.get().team_size ^ num_itopk_candidates ^ + block_size ^ smem_size ^ hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ + rand_xor_mask ^ num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ + uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000); + } + + persistent_runner_t( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view graph, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage) + : persistent_runner_base_t{persistent_lifetime}, + kernel{kernel_config_type::choose_itopk_and_mx_candidates( + itopk_size, num_itopk_candidates, block_size)}, + block_size{block_size}, + worker_handles(0, stream, worker_handles_mr), + job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), + completion_counters(kMaxJobsNum, stream, device_mr), + hashmap(0, stream, device_mr), + dd_host{dataset_desc.get()}, + param_hash(calculate_parameter_hash(dd_host, + graph, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + num_random_samplings, + rand_xor_mask, + num_seeds, + itopk_size, + search_width, + min_iterations, + max_iterations, + sample_filter, + persistent_lifetime, + persistent_device_usage)) + { + // initialize the dataset/distance descriptor + auto* dd_dev_ptr = dd_host.dev_ptr(stream); + + // set kernel attributes same as in normal kernel + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // set kernel launch parameters + dim3 gs = calc_coop_grid_size(block_size, smem_size, persistent_device_usage); + dim3 bs(block_size, 1, 1); + RAFT_LOG_DEBUG( + "Launching persistent kernel with %u threads, %u block %u smem", bs.x, gs.y, smem_size); + + // initialize the job queue + auto* completion_counters_ptr = completion_counters.data(); + auto* job_descriptors_ptr = job_descriptors.data(); + for (uint32_t i = 0; i < kMaxJobsNum; i++) { + auto& jd = job_descriptors_ptr[i].input.value; + jd.result_indices_ptr = nullptr; + jd.result_distances_ptr = nullptr; + jd.queries_ptr = nullptr; + jd.top_k = 0; + jd.n_queries = 0; + job_descriptors_ptr[i].completion_flag.store(false); + job_queue.push(i); + } + + // initialize the worker queue + worker_queue.set_capacity(gs.y); + worker_handles.resize(gs.y, stream); + auto* worker_handles_ptr = worker_handles.data(); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (uint32_t i = 0; i < gs.y; i++) { + worker_handles_ptr[i].data.store({kWaitForWork}); + worker_queue.push(i); + } + + index_type* hashmap_ptr = nullptr; + if (small_hash_bitlen == 0) { + hashmap.resize(gs.y * hashmap::get_size(hash_bitlen), stream); + hashmap_ptr = hashmap.data(); + } + + // launch the kernel + auto* graph_ptr = graph.data_handle(); + uint32_t graph_degree = graph.extent(1); + uint32_t* num_executed_iterations = nullptr; // optional arg [num_queries] + const index_type* dev_seed_ptr = nullptr; // optional arg [num_queries, num_seeds] + + void* args[] = // NOLINT + {&dd_dev_ptr, + &worker_handles_ptr, + &job_descriptors_ptr, + &completion_counters_ptr, + &graph_ptr, // [dataset_size, graph_degree] + &graph_degree, + &num_random_samplings, + &rand_xor_mask, + &dev_seed_ptr, + &num_seeds, + &hashmap_ptr, // visited_hashmap_ptr: [num_queries, 1 << hash_bitlen] + &itopk_size, + &search_width, + &min_iterations, + &max_iterations, + &num_executed_iterations, + &hash_bitlen, + &small_hash_bitlen, + &small_hash_reset_interval, + &sample_filter}; + cuda::atomic_thread_fence(cuda::memory_order_seq_cst, cuda::thread_scope_system); + RAFT_CUDA_TRY(cudaLaunchCooperativeKernel>( + kernel, gs, bs, args, smem_size, stream)); + RAFT_LOG_INFO( + "Initialized the kernel %p in stream %zd; job_queue size = %u; worker_queue size = %u", + reinterpret_cast(kernel), + int64_t((cudaStream_t)stream), + job_queue.capacity(), + worker_queue.capacity()); + last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); + } + + ~persistent_runner_t() noexcept override + { + auto whs = worker_handles.data(); + for (auto i = worker_handles.size(); i > 0; i--) { + whs[worker_queue.pop().wait()].data.store({kNoMoreWork}, cuda::memory_order_relaxed); + } + RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream)); + RAFT_LOG_INFO("Destroyed the persistent runner."); + } + + void launch(index_type* result_indices_ptr, // [num_queries, top_k] + distance_type* result_distances_ptr, // [num_queries, top_k] + const data_type* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + uint32_t top_k) + { + // submit all queries + launcher_t launcher{job_queue, + worker_queue, + worker_handles.data(), + num_queries, + this->lifetime, + [=](uint32_t job_ix) { + auto& jd = job_descriptors.data()[job_ix].input.value; + auto* cflag = &job_descriptors.data()[job_ix].completion_flag; + jd.result_indices_ptr = result_indices_ptr; + jd.result_distances_ptr = result_distances_ptr; + jd.queries_ptr = queries_ptr; + jd.top_k = top_k; + jd.n_queries = num_queries; + cflag->store(false, cuda::memory_order_relaxed); + cuda::atomic_thread_fence(cuda::memory_order_release, + cuda::thread_scope_system); + return cflag; + }}; + + // Update the state of the keep-alive atomic in the meanwhile + auto prev_touch = last_touch.load(std::memory_order_relaxed); + if (prev_touch + lifetime / 10 < launcher.now) { + // to avoid congestion at this atomic, we only update it if a significant fraction of the live + // interval has passed. + last_touch.store(launcher.now, std::memory_order_relaxed); + } + // wait for the results to arrive + launcher.wait(); + } + + auto calc_coop_grid_size(uint32_t block_size, uint32_t smem_size, float persistent_device_usage) + -> dim3 + { + // determine the grid size + int ctas_per_sm = 1; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, block_size, smem_size); + int num_sm = raft::getMultiProcessorCount(); + auto n_blocks = static_cast(persistent_device_usage * (ctas_per_sm * num_sm)); + if (n_blocks > kMaxWorkersNum) { + RAFT_LOG_WARN("Limiting the grid size limit due to the size of the queue: %u -> %u", + n_blocks, + kMaxWorkersNum); + n_blocks = kMaxWorkersNum; + } + + return {1, n_blocks, 1}; + } +}; + +struct alignas(kCacheLineBytes) persistent_state { + std::shared_ptr runner{nullptr}; + std::mutex lock; +}; + +inline persistent_state persistent{}; + +template +auto create_runner(Args... args) -> std::shared_ptr // it's ok.. pass everything by values +{ + std::lock_guard guard(persistent.lock); + // Check if the runner has already been created + std::shared_ptr runner_outer = std::dynamic_pointer_cast(persistent.runner); + if (runner_outer) { + if (runner_outer->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner_outer; + } else { + runner_outer.reset(); + } + } + // Runner has not yet been created (or it's incompatible): + // create it in another thread and only then release the lock. + // Free the resources (if any) in advance + persistent.runner.reset(); + + cuda::std::atomic_flag ready{}; + ready.clear(cuda::std::memory_order_relaxed); + std::thread( + [&runner_outer, &ready](Args... thread_args) { // pass everything by values + // create the runner (the lock is acquired in the parent thread). + runner_outer = std::make_shared(thread_args...); + auto lifetime = runner_outer->lifetime; + persistent.runner = std::static_pointer_cast(runner_outer); + std::weak_ptr runner_weak = runner_outer; + ready.test_and_set(cuda::std::memory_order_release); + ready.notify_one(); + // NB: runner_outer is passed by reference and may be dead by this time. + + while (true) { + std::this_thread::sleep_for(lifetime); + auto runner = runner_weak.lock(); // runner_weak is local - thread-safe + if (!runner) { + return; // dead already + } + if (runner->last_touch.load(std::memory_order_relaxed) + lifetime < + std::chrono::system_clock::now()) { + std::lock_guard guard(persistent.lock); + if (runner == persistent.runner) { persistent.runner.reset(); } + return; + } + } + }, + args...) + .detach(); + ready.wait(false, cuda::std::memory_order_acquire); + return runner_outer; +} + +template +auto get_runner(Args... args) -> std::shared_ptr +{ + // Using a thread-local weak pointer allows us to avoid using locks/atomics, + // since the control block of weak/shared pointers is thread-safe. + static thread_local std::weak_ptr weak; + auto runner = weak.lock(); + if (runner) { + if (runner->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner; + } else { + weak.reset(); + runner.reset(); + } + } + // Thread-local variable expected_latency makes sense only for a current RunnerT configuration. + // If `weak` is not alive, it's a hint the configuration has changed and we should reset our + // estimate of the expected launch latency. + launcher_t::expected_latency = launcher_t::kDefaultLatency; + runner = create_runner(args...); + weak = runner; + return runner; +} + template -void select_and_run(const dataset_descriptor_base_t* dataset_desc, +void select_and_run(const dataset_descriptor_host& dataset_desc, raft::device_matrix_view graph, IndexT* topk_indices_ptr, // [num_queries, topk] DistanceT* topk_distances_ptr, // [num_queries, topk] @@ -879,40 +1943,66 @@ void select_and_run(const dataset_descriptor_base_t* d SampleFilterT sample_filter, cudaStream_t stream) { - auto kernel = - search_kernel_config, - SampleFilterT>::choose_itopk_and_mx_candidates(ps.itopk_size, - num_itopk_candidates, - block_size); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - dim3 thread_dims(block_size, 1, 1); - dim3 block_dims(1, num_queries, 1); - RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - topk, - dataset_desc, - queries_ptr, - graph.data_handle(), - graph.extent(1), - ps.num_random_samplings, - ps.rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - ps.itopk_size, - ps.search_width, - ps.min_iterations, - ps.max_iterations, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - sample_filter); - // RAFT_CUDA_TRY(cudaPeekAtLastError()); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + if (ps.persistent) { + using runner_type = persistent_runner_t; + + get_runner(/* +Note, we're passing the descriptor by reference here, and this reference is going to be passed to a +new spawned thread, which is dangerous. However, the descriptor is copied in that thread before the +control is returned in this thread (in persistent_runner_t constructor), so we're safe. +*/ + std::cref(dataset_desc), + graph, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + ps.num_random_samplings, + ps.rand_xor_mask, + num_seeds, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + sample_filter, + ps.persistent_lifetime, + ps.persistent_device_usage) + ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); + } else { + using descriptor_base_type = dataset_descriptor_base_t; + auto kernel = search_kernel_config:: + choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size); + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + dim3 thread_dims(block_size, 1, 1); + dim3 block_dims(1, num_queries, 1); + RAFT_LOG_DEBUG( + "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); + kernel<<>>(topk_indices_ptr, + topk_distances_ptr, + topk, + dataset_desc.dev_ptr(stream), + queries_ptr, + graph.data_handle(), + graph.extent(1), + ps.num_random_samplings, + ps.rand_xor_mask, + dev_seed_ptr, + num_seeds, + hashmap_ptr, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + sample_filter); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } } } // namespace single_cta_search } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh index 7b7f44db7..4d8b72b41 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh @@ -22,7 +22,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search { template -void select_and_run(const dataset_descriptor_base_t* dataset_desc, +void select_and_run(const dataset_descriptor_host& dataset_desc, raft::device_matrix_view graph, IndexT* topk_indices_ptr, // [num_queries, topk] DistanceT* topk_distances_ptr, // [num_queries, topk] diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index d744a8178..15f494d3d 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -27,6 +27,7 @@ include(rapids-find) rapids_cuda_init_architectures(test_cuvs) project(test_cuvs LANGUAGES CXX CUDA) +find_package(Threads) # ------------- configure cuvs -----------------# @@ -36,11 +37,15 @@ include(../cmake/thirdparty/get_cuvs.cmake) # -------------- compile tasks ----------------- # add_executable(CAGRA_EXAMPLE src/cagra_example.cu) +add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu) add_executable(IVF_FLAT_EXAMPLE src/ivf_flat_example.cu) add_executable(IVF_PQ_EXAMPLE src/ivf_pq_example.cu) # `$` is a generator expression that ensures that targets are # installed in a conda environment, if one exists target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $) +target_link_libraries( + CAGRA_PERSISTENT_EXAMPLE PRIVATE cuvs::cuvs $ Threads::Threads +) target_link_libraries(IVF_PQ_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries(IVF_FLAT_EXAMPLE PRIVATE cuvs::cuvs $) diff --git a/examples/cpp/src/cagra_persistent_example.cu b/examples/cpp/src/cagra_persistent_example.cu new file mode 100644 index 000000000..9258a7311 --- /dev/null +++ b/examples/cpp/src/cagra_persistent_example.cu @@ -0,0 +1,258 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common.cuh" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// A helper to split the dataset into chunks +template +auto slice_matrix(DeviceMatrixOrView source, + typename DeviceMatrixOrView::index_type offset_rows, + typename DeviceMatrixOrView::index_type count_rows) { + auto n_cols = source.extent(1); + return raft::make_device_matrix_view< + typename DeviceMatrixOrView::element_type, + typename DeviceMatrixOrView::index_type>( + source.data_handle() + offset_rows * n_cols, count_rows, n_cols); +} + +// A helper to measure the execution time of a function +template +void time_it(std::string label, F f, Args &&...xs) { + auto start = std::chrono::system_clock::now(); + f(std::forward(xs)...); + auto end = std::chrono::system_clock::now(); + auto t = std::chrono::duration_cast(end - start); + auto t_ms = double(t.count()) / 1000.0; + std::cout << "[" << label << "] execution time: " << t_ms << " ms" + << std::endl; +} + +void cagra_build_search_variants( + raft::device_resources const &res, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) { + using namespace cuvs::neighbors; + + // Number of neighbors to search + int64_t topk = 100; + // We split the queries set into three subsets for our experiment, one for a + // sanity check and two for measuring the performance. + int64_t n_queries_a = queries.extent(0) / 2; + int64_t n_queries_b = queries.extent(0) - n_queries_a; + + auto queries_a = slice_matrix(queries, 0, n_queries_a); + auto queries_b = slice_matrix(queries, n_queries_a, n_queries_b); + + // create output arrays + auto neighbors = + raft::make_device_matrix(res, queries.extent(0), topk); + auto distances = + raft::make_device_matrix(res, queries.extent(0), topk); + // slice them same as queries + auto neighbors_a = slice_matrix(neighbors, 0, n_queries_a); + auto distances_a = slice_matrix(distances, 0, n_queries_a); + auto neighbors_b = slice_matrix(neighbors, n_queries_a, n_queries_b); + auto distances_b = slice_matrix(distances, n_queries_a, n_queries_b); + + // use default index parameters + cagra::index_params index_params; + + std::cout << "Building CAGRA index (search graph)" << std::endl; + auto index = cagra::build(res, index_params, dataset); + + std::cout << "CAGRA index has " << index.size() << " vectors" << std::endl; + std::cout << "CAGRA graph has degree " << index.graph_degree() + << ", graph size [" << index.graph().extent(0) << ", " + << index.graph().extent(1) << "]" << std::endl; + + // use default search parameters + cagra::search_params search_params; + // get a decent recall by increasing the internal topk list + search_params.itopk_size = 512; + + // Another copy of search parameters to enable persistent kernel + cagra::search_params search_params_persistent = search_params; + search_params_persistent.persistent = true; + // Persistent kernel only support single-cta search algorithm for now. + search_params_persistent.algo = cagra::search_algo::SINGLE_CTA; + // Slightly reduce the kernel grid size to make this example program work + // smooth on workstations, which use the same GPU for other tasks (e.g. + // rendering GUI). + search_params_persistent.persistent_device_usage = 0.95; + + /* + Define the big-batch setting as a baseline for measuring the throughput. + + Note, this lambda can be used by the standard and the persistent + implementation interchangeably: the index stays the same, only search + parameters need some adjustment. + */ + auto search_batch = + [&res, &index](bool needs_sync, const cagra::search_params &ps, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) { + cagra::search(res, ps, index, queries, neighbors, distances); + /* + To make a fair comparison, standard implementation needs to synchronize + with the device to make sure the kernel has finished the work. + Persistent kernel does not make any use of CUDA streams and blocks till + the results are available. Hence, synchronizing with the stream is a + waste of time in this case. + */ + if (needs_sync) { + raft::resource::sync_stream(res); + } + }; + + /* + Define the asynchronous small-batch search setting. + The same lambda is used for both the standard and the persistent + implementations. + + There are a few things to remember about this example though: + 1. The standard kernel is launched in the given stream (behind the `res`); + The persistent kernel is launched implicitly; the public api call does + not touch the stream and blocks till the results are returned. (Hence the + optional sync at the end of the lambda.) + 2. When launched asynchronously, the standard kernel should actually have a + separate raft::resource per-thread to achieve best performance. However, + this requires extra management of the resource/stream pools, we don't + include that for simplicity. + The persistent implementation does not require any special care; you can + safely pass a single raft::resources to all threads. + 3. This example relies on the compiler implementation to launch the async + jobs in separate threads. This is not guaranteed, however. + In the real world, we'd advise to use a custom thread pool for managing + the requests. + 4. Although the API defines the arguments as device-side mdspans, we advise + to use the host-side buffers accessible from the device, such as + allocated by cudaHostAlloc/cudaHostRegister (or any host memory if + HMM/ATS is enabled). + This way, you can save some GPU resources by not manually copying the + data in cuda streams. + */ + auto search_async = + [&res, &index](bool needs_sync, const cagra::search_params &ps, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) { + auto work_size = queries.extent(0); + using index_type = typeof(work_size); + // Limit the maximum number of concurrent jobs + constexpr index_type kMaxJobs = 1000; + std::array, kMaxJobs> futures; + for (index_type i = 0; i < work_size + kMaxJobs; i++) { + // wait for previous job in the same slot to finish + if (i >= kMaxJobs) { + futures[i % kMaxJobs].wait(); + } + // submit a new job + if (i < work_size) { + futures[i % kMaxJobs] = std::async(std::launch::async, [&]() { + cagra::search(res, ps, index, slice_matrix(queries, i, 1), + slice_matrix(neighbors, i, 1), + slice_matrix(distances, i, 1)); + }); + } + } + /* See the remark for search_batch */ + if (needs_sync) { + raft::resource::sync_stream(res); + } + }; + + // Launch the baseline search: check the big-batch performance + time_it("standard/batch A", search_batch, true, search_params, queries_a, + neighbors_a, distances_a); + time_it("standard/batch B", search_batch, true, search_params, queries_b, + neighbors_b, distances_b); + + // Try to handle the same amount of work in the async setting using the + // standard implementation. + // (Warning: suboptimal - it uses a single stream for all async jobs) + time_it("standard/async A", search_async, true, search_params, queries_a, + neighbors_a, distances_a); + time_it("standard/async B", search_async, true, search_params, queries_b, + neighbors_b, distances_b); + + // Do the same using persistent kernel. + time_it("persistent/async A", search_async, false, search_params_persistent, + queries_a, neighbors_a, distances_a); + time_it("persistent/async B", search_async, false, search_params_persistent, + queries_b, neighbors_b, distances_b); + /* +Here's an example output, which shows the wall time of processing the same +amount of data in a single batch vs in async mode (1 query per job): +``` +CAGRA index has 1000000 vectors +CAGRA graph has degree 64, graph size [1000000, 64] +[standard/batch A] execution time: 854.645 ms +[standard/batch B] execution time: 698.58 ms +[standard/async A] execution time: 19190.6 ms +[standard/async B] execution time: 18292 ms +[I] [15:56:49.756754] Initialized the kernel 0x7ea4e55a5350 in stream + 139227270582864; job_queue size = 8192; worker_queue size = 155 +[persistent/async A] execution time: 1285.65 ms +[persistent/async B] execution time: 1316.97 ms +[I] [15:56:55.756952] Destroyed the persistent runner. +``` +Note, while the persistent kernel provides minimal latency for each search +request, the wall time to process all the queries in async mode (1 query per +job) is up to 2x slower than the standard kernel with the huge batch +size (100K queries). One reason for this is the non-optimal CTA size: CAGRA +kernels are automatically tuned for latency and so use large CTA sizes when the +batch size is small. Try explicitly setting the search parameter +`thread_block_size` to a small value, such as `64` or `128` if this is an issue +for you. This increases the latency of individual jobs though. + */ +} + +int main() { + raft::device_resources res; + + // Set pool memory resource with 1 GiB initial pool size. All allocations use + // the same pool. + rmm::mr::pool_memory_resource pool_mr( + rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(&pool_mr); + + // Create input arrays. + int64_t n_samples = 1000000; + int64_t n_dim = 128; + int64_t n_queries = 100000; + auto dataset = + raft::make_device_matrix(res, n_samples, n_dim); + auto queries = + raft::make_device_matrix(res, n_queries, n_dim); + generate_dataset(res, dataset.view(), queries.view()); + + // run the interesting part of the program + cagra_build_search_variants(res, raft::make_const_mdspan(dataset.view()), + raft::make_const_mdspan(queries.view())); +}