diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp index b010063dee..70c502da6a 100644 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ b/cpp/bench/ann/src/common/ann_types.hpp @@ -33,6 +33,7 @@ enum Objective { enum class MemoryType { Host, HostMmap, + HostPinned, Device, }; @@ -58,6 +59,8 @@ inline auto parse_memory_type(const std::string& memory_type) -> MemoryType return MemoryType::Host; } else if (memory_type == "mmap") { return MemoryType::HostMmap; + } else if (memory_type == "pinned") { + return MemoryType::HostPinned; } else if (memory_type == "device") { return MemoryType::Device; } else { @@ -73,7 +76,7 @@ struct AlgoProperty { class AnnBase { public: - using index_type = size_t; + using index_type = uint32_t; inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {} virtual ~AnnBase() noexcept = default; diff --git a/cpp/bench/ann/src/common/dataset.hpp b/cpp/bench/ann/src/common/dataset.hpp index 8fcff77d3c..7e8e7ba8f8 100644 --- a/cpp/bench/ann/src/common/dataset.hpp +++ b/cpp/bench/ann/src/common/dataset.hpp @@ -283,7 +283,28 @@ class Dataset { { switch (memory_type) { case MemoryType::Device: return query_set_on_gpu(); - default: return query_set(); + case MemoryType::Host: { + auto r = query_set(); +#ifndef BUILD_CPU_ONLY + if (query_set_pinned_) { + cudaHostUnregister(const_cast(r)); + query_set_pinned_ = false; + } +#endif + return r; + } + case MemoryType::HostPinned: { + auto r = query_set(); +#ifndef BUILD_CPU_ONLY + if (!query_set_pinned_) { + cudaHostRegister( + const_cast(r), query_set_size() * dim() * sizeof(T), cudaHostRegisterDefault); + query_set_pinned_ = true; + } +#endif + return r; + } + default: return nullptr; } } @@ -291,7 +312,27 @@ class Dataset { { switch (memory_type) { case MemoryType::Device: return base_set_on_gpu(); - case MemoryType::Host: return base_set(); + case MemoryType::Host: { + auto r = base_set(); +#ifndef BUILD_CPU_ONLY + if (base_set_pinned_) { + cudaHostUnregister(const_cast(r)); + base_set_pinned_ = false; + } +#endif + return r; + } + case MemoryType::HostPinned: { + auto r = base_set(); +#ifndef BUILD_CPU_ONLY + if (!base_set_pinned_) { + cudaHostRegister( + const_cast(r), base_set_size() * dim() * sizeof(T), cudaHostRegisterDefault); + base_set_pinned_ = true; + } +#endif + return r; + } case MemoryType::HostMmap: return mapped_base_set(); default: return nullptr; } @@ -312,18 +353,23 @@ class Dataset { mutable T* d_query_set_ = nullptr; mutable T* mapped_base_set_ = nullptr; mutable int32_t* gt_set_ = nullptr; + + mutable bool base_set_pinned_ = false; + mutable bool query_set_pinned_ = false; }; template Dataset::~Dataset() { - delete[] base_set_; - delete[] query_set_; - delete[] gt_set_; #ifndef BUILD_CPU_ONLY if (d_base_set_) { cudaFree(d_base_set_); } if (d_query_set_) { cudaFree(d_query_set_); } + if (base_set_pinned_) { cudaHostUnregister(base_set_); } + if (query_set_pinned_) { cudaHostUnregister(query_set_); } #endif + delete[] base_set_; + delete[] query_set_; + delete[] gt_set_; } template diff --git a/cpp/bench/ann/src/common/util.hpp b/cpp/bench/ann/src/common/util.hpp index 96185c79eb..c481f589bd 100644 --- a/cpp/bench/ann/src/common/util.hpp +++ b/cpp/bench/ann/src/common/util.hpp @@ -197,10 +197,12 @@ struct result_buffer { explicit result_buffer(size_t size, cudaStream_t stream) : size_{size}, stream_{stream} { if (size_ == 0) { return; } - data_host_ = malloc(size_); #ifndef BUILD_CPU_ONLY cudaMallocAsync(&data_device_, size_, stream_); + 
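// Allocate the host-side half of the result buffer as pinned (page-locked) memory:
// copies between data_host_ and data_device_ can then run asynchronously on stream_,
// and device code (e.g. the persistent kernel) may write results to it directly.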
cudaMallocHost(&data_host_, size_); cudaStreamSynchronize(stream_); +#else + data_host_ = malloc(size_); #endif } result_buffer() = delete; @@ -213,9 +215,11 @@ struct result_buffer { if (size_ == 0) { return; } #ifndef BUILD_CPU_ONLY cudaFreeAsync(data_device_, stream_); + cudaFreeHost(data_host_); cudaStreamSynchronize(stream_); -#endif +#else free(data_host_); +#endif } [[nodiscard]] auto size() const noexcept { return size_; } @@ -278,6 +282,31 @@ inline auto get_result_buffer_from_global_pool(size_t size) -> result_buffer& return rb; } +namespace detail { +inline std::vector> global_tmp_buffer_pool(0); +inline std::mutex gtp_mutex; +} // namespace detail + +/** + * Global temporary buffer pool for use by algorithms. + * In contrast to `get_result_buffer_from_global_pool`, the content of these buffers is never + * initialized. + */ +inline auto get_tmp_buffer_from_global_pool(size_t size) -> result_buffer& +{ + auto stream = get_stream_from_global_pool(); + auto& rb = [stream, size]() -> result_buffer& { + std::lock_guard guard(detail::gtp_mutex); + if (static_cast(detail::global_tmp_buffer_pool.size()) < benchmark_n_threads) { + detail::global_tmp_buffer_pool.resize(benchmark_n_threads); + } + auto& rb = detail::global_tmp_buffer_pool[benchmark_thread_id]; + if (!rb || rb->size() < size) { rb = std::make_unique(size, stream); } + return *rb; + }(); + return rb; +} + /** * Delete all streams and memory allocations in the global pool. * It's called at the end of the `main` function - before global/static variables and cuda context diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 48bf1d70d8..289e7a293f 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -249,6 +249,10 @@ void parse_search_param(const nlohmann::json& conf, if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); } + if (conf.contains("persistent")) { param.p.persistent = conf.at("persistent"); } + if (conf.contains("thread_block_size")) { + param.p.thread_block_size = conf.at("thread_block_size"); + } if (conf.contains("algo")) { if (conf.at("algo") == "single_cta") { param.p.algo = raft::neighbors::experimental::cagra::search_algo::SINGLE_CTA; diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h index 9b086fdb23..f754faa17b 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h @@ -228,27 +228,47 @@ void refine_helper(const raft::resources& res, } else { auto dataset_host = raft::make_host_matrix_view( dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - auto queries_host = raft::make_host_matrix(batch_size, dim); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - auto stream = resource::get_cuda_stream(res); - raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); + if (raft::get_device_for_address(queries.data_handle()) >= 0) { + // Queries & results are on the device - 
raft::resource::sync_stream(res); // wait for the queries and candidates - raft::neighbors::refine(res, - dataset_host, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - metric); + auto queries_host = raft::make_host_matrix(batch_size, dim); + auto candidates_host = raft::make_host_matrix(batch_size, k0); + auto neighbors_host = raft::make_host_matrix(batch_size, k); + auto distances_host = raft::make_host_matrix(batch_size, k); + + auto stream = resource::get_cuda_stream(res); + raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream); + raft::copy( + candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); + + raft::resource::sync_stream(res); // wait for the queries and candidates + raft::neighbors::refine(res, + dataset_host, + queries_host.view(), + candidates_host.view(), + neighbors_host.view(), + distances_host.view(), + metric); + + raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); + raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + + } else { + // Queries & results are on the host - no device sync / copy needed + + auto queries_host = raft::make_host_matrix_view( + queries.data_handle(), batch_size, dim); + auto candidates_host = raft::make_host_matrix_view( + candidates.data_handle(), batch_size, k0); + auto neighbors_host = + raft::make_host_matrix_view(neighbors, batch_size, k); + auto distances_host = + raft::make_host_matrix_view(distances, batch_size, k); - raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + raft::neighbors::refine( + res, dataset_host, queries_host, candidates_host, neighbors_host, distances_host, metric); + } } } diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index b03f875a8e..e5d0622184 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -117,6 +117,15 @@ class RaftCagra : public ANN, public AnnGPU { return handle_.get_sync_stream(); } + [[nodiscard]] auto uses_stream() const noexcept -> bool override + { + // If the algorithm uses persistent kernel, the CPU has to synchronize by the end of computing + // the result. Hence it guarantees the benchmark CUDA stream is empty by the end of the + // execution. Hence we inform the benchmark to not waste the time on recording & synchronizing + // the event. + return !search_params_.persistent; + } + // to enable dataset access from GPU memory AlgoProperty get_preference() const override { @@ -326,14 +335,33 @@ void RaftCagra::search( } else { auto queries_v = raft::make_device_matrix_view(queries, batch_size, dimension_); - auto candidate_ixs = - raft::make_device_matrix(res, batch_size, k0); - auto candidate_dists = - raft::make_device_matrix(res, batch_size, k0); - search_base( - queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); - refine_helper( - res, *input_dataset_v_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); + + auto& tmp_buf = get_tmp_buffer_from_global_pool((sizeof(float) + sizeof(AnnBase::index_type)) * + batch_size * k0); + auto mem_type = + raft::get_device_for_address(neighbors) >= 0 ? 
MemoryType::Device : MemoryType::HostPinned; + + auto candidate_ixs = raft::make_device_matrix_view( + reinterpret_cast(tmp_buf.data(mem_type)), batch_size, k0); + auto candidate_dists = reinterpret_cast(candidate_ixs.data_handle() + batch_size * k0); + + search_base(queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists); + + if (mem_type == MemoryType::HostPinned && uses_stream()) { + // If the algorithm uses a stream to synchronize (non-persistent kernel), but the data is in + // the pinned host memory, we need to synchronize before the refinement operation to wait for + // the data to become available to the host. + raft::resource::sync_stream(res); + } + + refine_helper(res, + *input_dataset_v_, + queries_v, + raft::make_const_mdspan(candidate_ixs), + k, + neighbors, + distances, + index_->metric()); } } } // namespace raft::bench::ann diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 97c9c0d098..3ea75313c5 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -124,6 +124,8 @@ struct search_params : ann::search_params { uint32_t num_random_samplings = 1; /** Bit mask used for initial random seed node selection. */ uint64_t rand_xor_mask = 0x128394; + /** Whether to use the persistent version of the kernel (only SINGLE_CTA is supported at the moment) */ + bool persistent = false; }; static_assert(std::is_aggregate_v); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index b35d96e9f5..a9dc894587 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -29,14 +29,85 @@ #include #include +#include +#include +#include + namespace raft::neighbors::cagra::detail { +/** + * A lightweight version of rmm::device_uvector. + * This version ignores the current device on allocations and thus avoids calling + * cudaSetDevice/cudaGetDevice. + * If the size stays at zero, this struct never calls any CUDA driver / RAFT resource functions.
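+ *
+ * A minimal usage sketch (illustrative only; the variable names and the helper below are
+ * hypothetical, not part of this patch):
+ * @code{.cpp}
+ *   lightweight_uvector<uint32_t> buf(res);   // no allocation happens here
+ *   buf.resize(n, stream);                    // allocates from the workspace resource of `res`
+ *   do_something(buf.data(), buf.size(), stream);
+ *   // the destructor deallocates asynchronously on the last used stream
+ * @endcode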
+ */ +template +struct lightweight_uvector { + private: + using raft_res_type = const raft::resources*; + using rmm_res_type = std::tuple; + static constexpr size_t kAlign = 256; + + std::variant res_; + T* ptr_; + size_t size_; + + public: + explicit lightweight_uvector(const raft::resources& res) : res_(&res), ptr_{nullptr}, size_{0} {} + + [[nodiscard]] auto data() noexcept -> T* { return ptr_; } + [[nodiscard]] auto data() const noexcept -> const T* { return ptr_; } + [[nodiscard]] auto size() const noexcept -> size_t { return size_; } + + void resize(size_t new_size) + { + if (new_size == size_) { return; } + if (std::holds_alternative(res_)) { + auto& h = std::get(res_); + res_ = rmm_res_type{resource::get_workspace_resource(*h), resource::get_cuda_stream(*h)}; + } + auto& [r, s] = std::get(res_); + T* new_ptr = nullptr; + if (new_size > 0) { + new_ptr = reinterpret_cast(r.allocate_async(new_size * sizeof(T), kAlign, s)); + } + auto copy_size = std::min(size_, new_size); + if (copy_size > 0) { + cudaMemcpyAsync(new_ptr, ptr_, copy_size * sizeof(T), cudaMemcpyDefault, s); + } + if (size_ > 0) { r.deallocate_async(ptr_, size_ * sizeof(T), kAlign, s); } + ptr_ = new_ptr; + size_ = new_size; + } + + void resize(size_t new_size, rmm::cuda_stream_view stream) + { + if (new_size == size_) { return; } + if (std::holds_alternative(res_)) { + auto& h = std::get(res_); + res_ = rmm_res_type{resource::get_workspace_resource(*h), stream}; + } else { + std::get(std::get(res_)) = stream; + } + resize(new_size); + } + + ~lightweight_uvector() noexcept + { + if (size_ > 0) { + auto& [r, s] = std::get(res_); + r.deallocate_async(ptr_, size_ * sizeof(T), kAlign, s); + } + } +}; + struct search_plan_impl_base : public search_params { int64_t dataset_block_dim; int64_t dim; int64_t graph_degree; uint32_t topk; raft::distance::DistanceType metric; + search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, @@ -95,9 +166,9 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t topk; uint32_t num_seeds; - rmm::device_uvector hashmap; - rmm::device_uvector num_executed_iterations; // device or managed? - rmm::device_uvector dev_seed; + lightweight_uvector hashmap; + lightweight_uvector num_executed_iterations; // device or managed? 
+ lightweight_uvector dev_seed; search_plan_impl(raft::resources const& res, search_params params, @@ -106,16 +177,18 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t topk, raft::distance::DistanceType metric) : search_plan_impl_base(params, dim, graph_degree, topk, metric), - hashmap(0, resource::get_cuda_stream(res)), - num_executed_iterations(0, resource::get_cuda_stream(res)), - dev_seed(0, resource::get_cuda_stream(res)), + hashmap(res), + num_executed_iterations(res), + dev_seed(res), num_seeds(0) { adjust_search_params(); check_params(); calc_hashmap_params(res); set_dataset_block_and_team_size(dim); - num_executed_iterations.resize(max_queries, resource::get_cuda_stream(res)); + if (!persistent) { // Persistent kernel does not provide this functionality + num_executed_iterations.resize(max_queries, resource::get_cuda_stream(res)); + } RAFT_LOG_DEBUG("# algo = %d", static_cast(algo)); } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 0771652787..442296aa40 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -201,8 +201,8 @@ struct search : search_plan_impl { } RAFT_LOG_DEBUG("# smem_size: %u", smem_size); hashmap_size = 0; - if (small_hash_bitlen == 0) { - hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); + if (small_hash_bitlen == 0 && !this->persistent) { + hashmap_size = max_queries * hashmap::get_size(hash_bitlen); hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); } RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); @@ -221,6 +221,12 @@ struct search : search_plan_impl { SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); + + // Set the 'persistent' flag as the first bit of rand_xor_mask to avoid changing the signature + // of the select_and_run for now. + constexpr uint64_t kPMask = 0x8000000000000000LL; + auto rand_xor_mask_augmented = + this->persistent ? 
(rand_xor_mask | kPMask) : (rand_xor_mask & ~kPMask); select_and_run( dataset_desc, graph, @@ -239,7 +245,7 @@ struct search : search_plan_impl { small_hash_bitlen, small_hash_reset_interval, num_random_samplings, - rand_xor_mask, + rand_xor_mask_augmented, num_seeds, itopk_size, search_width, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 232dcb782a..fed7ca0fb4 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -33,17 +33,28 @@ #include #include #include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp +#include +#include #include +#include +#include +#include + +#include +#include #include +#include #include +#include #include #include #include +#include #include #include +#include #include namespace raft::neighbors::cagra::detail { @@ -457,7 +468,7 @@ template -__launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( +__device__ void search_core( typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, top_k] typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] const std::uint32_t top_k, @@ -479,6 +490,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( const std::uint32_t hash_bitlen, const std::uint32_t small_hash_bitlen, const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id, SAMPLE_FILTER_T sample_filter, raft::distance::DistanceType metric) { @@ -489,8 +501,6 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; using QUERY_T = typename DATASET_DESCRIPTOR_T::QUERY_T; - const auto query_id = blockIdx.y; - #ifdef _CLK_BREAKDOWN std::uint64_t clk_init = 0; std::uint64_t clk_compute_1st_distance = 0; @@ -557,7 +567,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( if (small_hash_bitlen) { local_visited_hashmap_ptr = visited_hash_buffer; } else { - local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); + local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * blockIdx.y); } hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); __syncthreads(); @@ -812,53 +822,330 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( } template +__launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( + typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, top_k] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] + const std::uint32_t top_k, + DATASET_DESCRIPTOR_T dataset_desc, + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const 
std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric) +{ + const auto query_id = blockIdx.y; + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + dataset_desc, + queries_ptr, + knn_graph, + graph_degree, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + sample_filter, + metric); +} + +// To make sure we avoid false sharing on both CPU and GPU, we enforce cache line size to the +// maximum of the two. +// This makes sync atomic significantly faster. +constexpr size_t kCacheLineBytes = 64; + +constexpr uint32_t kMaxJobsNum = 2048; +constexpr uint32_t kMaxWorkersNum = 2048; +constexpr uint32_t kMaxWorkersPerThread = 256; +constexpr uint32_t kSoftMaxWorkersPerThread = 16; + +template +struct alignas(kCacheLineBytes) job_desc_t { + using index_type = typename DATASET_DESCRIPTOR_T::INDEX_T; + using distance_type = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using data_type = typename DATASET_DESCRIPTOR_T::DATA_T; + // The algorithm input parameters + struct value_t { + index_type* result_indices_ptr; // [num_queries, top_k] + distance_type* result_distances_ptr; // [num_queries, top_k] + const data_type* queries_ptr; // [num_queries, dataset_dim] + uint32_t top_k; + uint32_t n_queries; + }; + using blob_elem_type = uint4; + constexpr static inline size_t kBlobSize = + raft::div_rounding_up_safe(sizeof(value_t), sizeof(blob_elem_type)); + // Union facilitates loading the input by a warp in a single request + union input_t { + blob_elem_type blob[kBlobSize]; // NOLINT + value_t value; + } input; + // Last thread triggers this flag. 
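+ // The host resets the flag to `false` when it publishes a new job (see
+ // persistent_runner_t::launch) and then polls it; the last worker block to finish a query of
+ // the batch sets it to `true` (see the completion-counter logic at the end of search_kernel_p).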
+ cuda::atomic completion_flag; +}; + +struct alignas(kCacheLineBytes) worker_handle_t { + using handle_t = uint64_t; + struct value_t { + uint32_t desc_id; + uint32_t query_id; + }; + union data_t { + handle_t handle; + value_t value; + }; + cuda::atomic data; +}; +static_assert(sizeof(worker_handle_t::value_t) == sizeof(worker_handle_t::handle_t)); +static_assert( + cuda::atomic::is_always_lock_free); + +constexpr worker_handle_t::handle_t kWaitForWork = std::numeric_limits::max(); +constexpr worker_handle_t::handle_t kNoMoreWork = kWaitForWork - 1; + +constexpr auto is_worker_busy(worker_handle_t::handle_t h) -> bool +{ + return (h != kWaitForWork) && (h != kNoMoreWork); +} + +template +__launch_bounds__(1024, 1) RAFT_KERNEL search_kernel_p( + DATASET_DESCRIPTOR_T dataset_desc, + worker_handle_t* worker_handles, + job_desc_t* job_descriptors, + uint32_t* completion_counters, + const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric) +{ + using job_desc_type = job_desc_t; + __shared__ typename job_desc_type::input_t job_descriptor; + __shared__ worker_handle_t::data_t worker_data; + + auto& worker_handle = worker_handles[blockIdx.y].data; + uint32_t job_ix; + + while (true) { + // wait the writing phase + if (threadIdx.x == 0) { + worker_handle_t::data_t worker_data_local; + do { + worker_data_local = worker_handle.load(cuda::memory_order_relaxed); + } while (worker_data_local.handle == kWaitForWork); + if (worker_data_local.handle != kNoMoreWork) { + worker_handle.store({kWaitForWork}, cuda::memory_order_relaxed); + } + job_ix = worker_data_local.value.desc_id; + cuda::atomic_thread_fence(cuda::memory_order_acquire, cuda::thread_scope_system); + worker_data = worker_data_local; + } + if (threadIdx.x < WarpSize) { + // Sync one warp and copy descriptor data + static_assert(job_desc_type::kBlobSize <= WarpSize); + job_ix = raft::shfl(job_ix, 0); + if (threadIdx.x < job_desc_type::kBlobSize && job_ix < kMaxJobsNum) { + job_descriptor.blob[threadIdx.x] = job_descriptors[job_ix].input.blob[threadIdx.x]; + } + } + __syncthreads(); + if (worker_data.handle == kNoMoreWork) { break; } + + // reading phase + auto* result_indices_ptr = job_descriptor.value.result_indices_ptr; + auto* result_distances_ptr = job_descriptor.value.result_distances_ptr; + auto* queries_ptr = job_descriptor.value.queries_ptr; + auto top_k = job_descriptor.value.top_k; + auto n_queries = job_descriptor.value.n_queries; + auto query_id = worker_data.value.query_id; + + // work phase + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + dataset_desc, + queries_ptr, + knn_graph, + graph_degree, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + internal_topk, + search_width, + min_iteration, + max_iteration, + 
num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + sample_filter, + metric); + + // make sure all writes are visible even for the host + // (e.g. when result buffers are in pinned memory) + cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_system); + + // arrive to mark the end of the work phase + __syncthreads(); + if (threadIdx.x == 0) { + auto completed_count = atomicInc(completion_counters + job_ix, n_queries - 1) + 1; + if (completed_count >= n_queries) { + // we may need a memory fence here: + // - device - if the queries are accessed by the device + // - system - e.g. if we put them into managed/pinned memory. + job_descriptors[job_ix].completion_flag.store(true, cuda::memory_order_relaxed); + } + } + } +} + +template +auto dispatch_kernel = []() { + if constexpr (Persistent) { + return search_kernel_p; + } else { + return search_kernel; + } +}(); + +template struct search_kernel_config { - using kernel_t = decltype(&search_kernel); + using kernel_t = decltype(dispatch_kernel); template static auto choose_search_kernel(unsigned itopk_size) -> kernel_t { if (itopk_size <= 64) { - return search_kernel; + return dispatch_kernel; } else if (itopk_size <= 128) { - return search_kernel; + return dispatch_kernel; } else if (itopk_size <= 256) { - return search_kernel; + return dispatch_kernel; } else if (itopk_size <= 512) { - return search_kernel; + return dispatch_kernel; } THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); } @@ -878,21 +1165,23 @@ struct search_kernel_config { // Radix-based topk is used constexpr unsigned max_candidates = 32; // to avoid build failure if (itopk_size <= 256) { - return search_kernel; + return dispatch_kernel; } else if (itopk_size <= 512) { - return search_kernel; + return dispatch_kernel; } } THROW("No kernel for parametels itopk_size %u, num_itopk_candidates %u", @@ -901,6 +1190,708 @@ struct search_kernel_config { } }; +/** + * @brief Resource queue + * + * A shared atomic ring buffer based queue optimized for throughput when bottlenecked on `pop` + * operation. + */ +template ::max()> +struct alignas(kCacheLineBytes) resource_queue_t { + using value_type = T; + static constexpr uint32_t kSize = Size; + static constexpr value_type kEmpty = Empty; + static_assert(cuda::std::atomic::is_always_lock_free, + "The value type must be lock-free."); + static_assert(raft::is_a_power_of_two(kSize), "The size must be a power-of-two for efficiency."); + static constexpr uint32_t kElemsPerCacheLine = + raft::div_rounding_up_safe(kCacheLineBytes, sizeof(value_type)); + static constexpr uint32_t kCounterIncrement = raft::bound_by_power_of_two(kElemsPerCacheLine) + 1; + static constexpr uint32_t kCounterLocMask = kSize - 1; + // These props hold by design, but we add them here as a documentation and a sanity check. 
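+ // For example, for a queue of uint32_t ids with kCacheLineBytes = 64:
+ //   kElemsPerCacheLine = 64 / 4 = 16 and kCounterIncrement = 16 + 1 = 17,
+ // so successive counter values address slots 17 * 4 = 68 bytes apart (different cache lines),
+ // and, because 17 is coprime with the power-of-two kSize, they still visit every slot.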
+ static_assert( + kCounterIncrement * sizeof(value_type) >= kCacheLineBytes, + "The counter increment should be larger than the cache line size to avoid false sharing."); + static_assert( + std::gcd(kCounterIncrement, kSize) == 1, + "The counter increment and the size must be coprime to allow using all of the queue slots."); + + static constexpr auto kMemOrder = cuda::std::memory_order_relaxed; + + explicit resource_queue_t(uint32_t capacity = Size) noexcept : capacity_{capacity} + { + head_.store(0, kMemOrder); + tail_.store(0, kMemOrder); + for (uint32_t i = 0; i < kSize; i++) { + buf_[i].store(kEmpty, kMemOrder); + } + } + + /** Nominal capacity of the queue. */ + [[nodiscard]] auto capacity() const { return capacity_; } + + /** This does not affect the queue behavior, but merely declares a nominal capacity. */ + void set_capacity(uint32_t capacity) { capacity_ = capacity; } + + /** + * A slot in the queue to take the value from. + * Once it's obtained, the corresponding value in the queue is lost for other users. + */ + struct promise_t { + explicit promise_t(cuda::std::atomic& loc) : loc_{loc}, val_{Empty} {} + ~promise_t() noexcept { wait(); } + + auto test() noexcept -> bool + { + if (val_ != Empty) { return true; } + val_ = loc_.exchange(kEmpty, kMemOrder); + return val_ != Empty; + } + + auto test(value_type& e) noexcept -> bool + { + if (test()) { + e = val_; + return true; + } + return false; + } + + auto wait() noexcept -> value_type + { + if (val_ == Empty) { + // [HOT SPOT] + // Optimize for the case of contention: expect the loc is empty. + do { + loc_.wait(kEmpty, kMemOrder); + val_ = loc_.exchange(kEmpty, kMemOrder); + } while (val_ == kEmpty); + } + return val_; + } + + private: + cuda::std::atomic& loc_; + value_type val_; + }; + + void push(value_type x) noexcept + { + auto& loc = buf_[head_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; + /* [NOT A HOT SPOT] + We expect there's always enough place in the queue to push the item, + but also we expect a few pop waiters - notify them the data is available. + */ + value_type e = kEmpty; + while (!loc.compare_exchange_weak(e, x, kMemOrder, kMemOrder)) { + e = kEmpty; + } + loc.notify_one(); + } + + auto pop() noexcept -> promise_t + { + auto& loc = buf_[tail_.fetch_add(kCounterIncrement, kMemOrder) & kCounterLocMask]; + return promise_t{loc}; + } + + private: + alignas(kCacheLineBytes) cuda::std::atomic head_{}; + alignas(kCacheLineBytes) cuda::std::atomic tail_{}; + alignas(kCacheLineBytes) std::array, kSize> buf_{}; + alignas(kCacheLineBytes) uint32_t capacity_; +}; + +/** Primitive fixed-size deque for single-threaded use. 
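+ * Each launcher thread owns one of these (`pending_reads`) to remember which workers it has
+ * submitted queries to and has not collected yet, so no synchronization is needed here.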
*/ +template +struct local_deque_t { + explicit local_deque_t(uint32_t size) : store_(size) {} + + [[nodiscard]] auto capacity() const -> uint32_t { return store_.size(); } + [[nodiscard]] auto size() const -> uint32_t { return end_ - start_; } + + void push_back(T x) { store_[end_++ % capacity()] = x; } + + void push_front(T x) + { + if (start_ == 0) { + start_ += capacity(); + end_ += capacity(); + } + store_[--start_ % capacity()] = x; + } + + // NB: non-blocking, unsafe functions + auto pop_back() -> T { return store_[--end_ % capacity()]; } + auto pop_front() -> T { return store_[start_++ % capacity()]; } + + auto try_push_back(T x) -> bool + { + if (size() >= capacity()) { return false; } + push_back(x); + return true; + } + + auto try_push_front(T x) -> bool + { + if (size() >= capacity()) { return false; } + push_front(x); + return true; + } + + auto try_pop_back(T& x) -> bool + { + if (start_ >= end_) { return false; } + x = pop_back(); + return true; + } + + auto try_pop_front(T& x) -> bool + { + if (start_ >= end_) { return false; } + x = pop_front(); + return true; + } + + private: + std::vector store_; + uint32_t start_{0}; + uint32_t end_{0}; +}; + +struct persistent_runner_base_t { + using job_queue_type = resource_queue_t; + using worker_queue_type = resource_queue_t; + rmm::mr::pinned_host_memory_resource worker_handles_mr; + rmm::mr::pinned_host_memory_resource job_descriptor_mr; + rmm::mr::cuda_memory_resource device_mr; + cudaStream_t stream{}; + job_queue_type job_queue{}; + worker_queue_type worker_queue{}; + persistent_runner_base_t() : job_queue(), worker_queue() + { + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } + virtual ~persistent_runner_base_t() noexcept { cudaStreamDestroy(stream); }; +}; + +struct alignas(kCacheLineBytes) launcher_t { + using job_queue_type = persistent_runner_base_t::job_queue_type; + using worker_queue_type = persistent_runner_base_t::worker_queue_type; + using pending_reads_queue_type = local_deque_t; + using completion_flag_type = cuda::atomic; + + pending_reads_queue_type pending_reads; + job_queue_type& job_ids; + worker_queue_type& idle_worker_ids; + worker_handle_t* worker_handles; + uint32_t job_id; + completion_flag_type* completion_flag; + bool all_done = false; + + /* [Note: sleeping] + When the number of threads is greater than the number of cores, the threads start to fight for + the CPU time, which reduces the throughput. + To ease the competition, we track the expected GPU latency and let a thread sleep for some + time, and only start to spin when it's about a time to get the result. + */ + static inline constexpr auto kDefaultLatency = std::chrono::nanoseconds(50000); + static inline thread_local auto expected_latency = kDefaultLatency; + const std::chrono::time_point start; + std::chrono::time_point now; + const int64_t pause_factor; + int pause_count = 0; + + template + launcher_t(job_queue_type& job_ids, + worker_queue_type& idle_worker_ids, + worker_handle_t* worker_handles, + uint32_t n_queries, + RecordWork record_work) + : pending_reads{std::min(n_queries, kMaxWorkersPerThread)}, + job_ids{job_ids}, + idle_worker_ids{idle_worker_ids}, + worker_handles{worker_handles}, + job_id{job_ids.pop().wait()}, + completion_flag{record_work(job_id)}, + start{std::chrono::system_clock::now()}, + pause_factor{calc_pause_factor(n_queries)}, + now{start} + { + // Wait for the first worker and submit the query immediately. 
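+ // Blocking on the first pop() is fine: nothing has been submitted yet, so this thread has
+ // nothing else to do. For the remaining queries below we poll with test() instead, so that
+ // while waiting we can check our own pending workers and return the finished ones to the pool.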
+ submit_query(idle_worker_ids.pop().wait(), 0); + // Submit the rest of the queries in the batch + for (uint32_t i = 1; i < n_queries; i++) { + auto promised_worker = idle_worker_ids.pop(); + uint32_t worker_id; + while (!promised_worker.test(worker_id)) { + if (pending_reads.try_pop_front(worker_id)) { + if (!try_return_worker(worker_id)) { pending_reads.push_front(worker_id); } + } else { + pause(); + } + } + submit_query(worker_id, i); + // Try to not hold too many workers in one thread + if (i >= kSoftMaxWorkersPerThread && pending_reads.try_pop_front(worker_id)) { + if (!try_return_worker(worker_id)) { pending_reads.push_front(worker_id); } + } + } + } + + inline ~launcher_t() noexcept // NOLINT + { + // bookkeeping: update the expected latency to wait more efficiently later + constexpr size_t kWindow = 100; // moving average memory + expected_latency = ((kWindow - 1) * expected_latency + now - start) / kWindow; + } + + inline void submit_query(uint32_t worker_id, uint32_t query_id) + { + worker_handles[worker_id].data.store(worker_handle_t::data_t{.value = {job_id, query_id}}, + cuda::memory_order_relaxed); + + while (!pending_reads.try_push_back(worker_id)) { + // The only reason pending_reads cannot push is that the queue is full. + // It's local, so we must pop and wait for the returned worker to finish its work. + auto pending_worker_id = pending_reads.pop_front(); + while (!try_return_worker(pending_worker_id)) { + pause(); + } + } + } + + /** Check if the worker has finished the work; if so, return it to the shared pool. */ + inline auto try_return_worker(uint32_t worker_id) -> bool + { + // Use the cached `all_done` - makes sense when called from the `wait()` routine. + if (all_done || + !is_worker_busy(worker_handles[worker_id].data.load(cuda::memory_order_relaxed).handle)) { + idle_worker_ids.push(worker_id); + return true; + } else { + return false; + } + } + + /** Check if all workers finished their work. */ + inline auto is_all_done() + { + // Cache the result of the check to avoid doing unnecessary atomic loads. + if (all_done) { return true; } + all_done = completion_flag->load(cuda::memory_order_relaxed); + return all_done; + } + + /** The launcher shouldn't attempt to wait past the returned time. */ + [[nodiscard]] inline auto sleep_limit() const + { + constexpr auto kMinWakeTime = std::chrono::nanoseconds(10000); + constexpr double kSleepLimit = 0.6; + return start + expected_latency * kSleepLimit - kMinWakeTime; + } + + /** + * When the latency is much larger than expected, it's a sign that there is a thread contention. + * Then we switch to sleeping instead of waiting to give the cpu cycles to other threads. + */ + [[nodiscard]] inline auto overtime_threshold() const + { + constexpr auto kOvertimeFactor = 3; + return start + expected_latency * kOvertimeFactor; + } + + /** + * Calculate the fraction of time can be spent sleeping in a single call to `pause()`. + * Naturally it depends on the number of queries in a batch and the number of parallel workers. + */ + [[nodiscard]] inline auto calc_pause_factor(uint32_t n_queries) const -> uint32_t + { + constexpr uint32_t kMultiplier = 10; + return kMultiplier * raft::div_rounding_up_safe(n_queries, idle_worker_ids.capacity()); + } + + /** Wait a little bit (called in a loop). 
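+ * Back-off strategy: yield for the first kSpinLimit calls; afterwards sleep for a fraction of
+ * the expected latency (clamped to [kPauseTimeMin, kPauseTimeMax]) while the expected completion
+ * time is still far away or long overdue (thread contention), otherwise keep yielding.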
*/ + inline void pause() + { + // For the first few calls, yield instead of sleeping, hoping for a smoother run + constexpr auto kSpinLimit = 3; + // It doesn't make much sense to sleep for less than this + constexpr auto kPauseTimeMin = std::chrono::nanoseconds(1000); + // Bound sleeping time + constexpr auto kPauseTimeMax = std::chrono::nanoseconds(10000000); + if (pause_count++ < kSpinLimit) { + std::this_thread::yield(); + return; + } + now = std::chrono::system_clock::now(); + auto pause_time_base = std::max(now - start, expected_latency); + auto pause_time = std::clamp(pause_time_base / pause_factor, kPauseTimeMin, kPauseTimeMax); + if (now + pause_time < sleep_limit() || now > overtime_threshold()) { + std::this_thread::sleep_for(pause_time); + } else { + std::this_thread::yield(); + } + } + + /** Wait for all work to finish and don't forget to return the workers to the shared pool. */ + inline void wait() + { + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + while (!try_return_worker(worker_id)) { + if (!is_all_done()) { pause(); } + } + } + // Terminal state: engaged only after `pending_reads` is empty + // and all the queries of the batch have been submitted + now = std::chrono::system_clock::now(); + while (!is_all_done()) { + auto till_time = sleep_limit(); + if (now < till_time) { + std::this_thread::sleep_until(till_time); + now = std::chrono::system_clock::now(); + } else { + pause(); + } + } + + // Return the job descriptor to the shared pool + job_ids.push(job_id); + } +}; + +template +struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_base_t { + using index_type = typename DATASET_DESCRIPTOR_T::INDEX_T; + using distance_type = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using data_type = typename DATASET_DESCRIPTOR_T::DATA_T; + using kernel_config_type = + search_kernel_config; + using kernel_type = typename kernel_config_type::kernel_t; + using job_desc_type = job_desc_t; + kernel_type kernel; + uint32_t block_size; + rmm::device_uvector worker_handles; + rmm::device_uvector job_descriptors; + rmm::device_uvector completion_counters; + rmm::device_uvector hashmap; + std::atomic> last_touch; + uint64_t param_hash; + + // This should be large enough to make the runner live through restarts of the benchmark cases. + // Otherwise, the benchmarks slow down significantly. + constexpr static auto kLiveInterval = std::chrono::milliseconds(2000); + + /** + * Calculate the hash of the parameters to detect if they've changed between calls. + * NB: this must have the same argument types as the constructor.
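+ * Note that not every argument participates in the hash: dataset_desc, small_hash_bitlen, and
+ * sample_filter are accepted only to keep the signature in sync with the constructor and are
+ * ignored, so changing them alone does not restart the persistent kernel.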
+ */ + static inline auto calculate_parameter_hash( + DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view graph, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric) -> uint64_t + { + return uint64_t(graph.data_handle()) ^ num_itopk_candidates ^ block_size ^ smem_size ^ + hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ rand_xor_mask ^ + num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ metric; + } + + persistent_runner_t(DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view graph, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric) + : persistent_runner_base_t{}, + kernel{kernel_config_type::choose_itopk_and_mx_candidates( + itopk_size, num_itopk_candidates, block_size)}, + block_size{block_size}, + worker_handles(0, stream, worker_handles_mr), + job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), + completion_counters(kMaxJobsNum, stream, device_mr), + hashmap(0, stream, device_mr), + param_hash(calculate_parameter_hash(dataset_desc, + graph, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + num_random_samplings, + rand_xor_mask, + num_seeds, + itopk_size, + search_width, + min_iterations, + max_iterations, + sample_filter, + metric)) + { + // set kernel attributes same as in normal kernel + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // set kernel launch parameters + dim3 gs = calc_coop_grid_size(block_size, smem_size); + dim3 bs(block_size, 1, 1); + RAFT_LOG_DEBUG( + "Launching persistent kernel with %u threads, %u block %u smem", bs.x, gs.y, smem_size); + + // initialize the job queue + auto* completion_counters_ptr = completion_counters.data(); + auto* job_descriptors_ptr = job_descriptors.data(); + for (uint32_t i = 0; i < kMaxJobsNum; i++) { + auto& jd = job_descriptors_ptr[i].input.value; + jd.result_indices_ptr = nullptr; + jd.result_distances_ptr = nullptr; + jd.queries_ptr = nullptr; + jd.top_k = 0; + jd.n_queries = 0; + job_descriptors_ptr[i].completion_flag.store(false); + job_queue.push(i); + } + + // initialize the worker queue + worker_queue.set_capacity(gs.y); + worker_handles.resize(gs.y, stream); + auto* worker_handles_ptr = worker_handles.data(); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (uint32_t i = 0; i < gs.y; i++) { + worker_handles_ptr[i].data.store({kWaitForWork}); + worker_queue.push(i); + } + + index_type* hashmap_ptr = nullptr; + if (small_hash_bitlen == 0) { + hashmap.resize(gs.y * hashmap::get_size(hash_bitlen), stream); + hashmap_ptr = hashmap.data(); + } + + // launch the kernel + auto* graph_ptr = graph.data_handle(); + uint32_t graph_degree = graph.extent(1); + uint32_t* num_executed_iterations = nullptr; // optional arg [num_queries] + 
const index_type* dev_seed_ptr = nullptr; // optional arg [num_queries, num_seeds] + + void* args[] = // NOLINT + {&dataset_desc, + &worker_handles_ptr, + &job_descriptors_ptr, + &completion_counters_ptr, + &graph_ptr, // [dataset_size, graph_degree] + &graph_degree, + &num_random_samplings, + &rand_xor_mask, + &dev_seed_ptr, + &num_seeds, + &hashmap_ptr, // visited_hashmap_ptr: [num_queries, 1 << hash_bitlen] + &itopk_size, + &search_width, + &min_iterations, + &max_iterations, + &num_executed_iterations, + &hash_bitlen, + &small_hash_bitlen, + &small_hash_reset_interval, + &sample_filter, + &metric}; + RAFT_CUDA_TRY(cudaLaunchCooperativeKernel>( + kernel, gs, bs, args, smem_size, stream)); + RAFT_LOG_INFO( + "Initialized the kernel %p in stream %zd; job_queue size = %u; worker_queue size = %u", + reinterpret_cast(kernel), + int64_t((cudaStream_t)stream), + job_queue.capacity(), + worker_queue.capacity()); + last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); + } + + ~persistent_runner_t() noexcept override + { + auto whs = worker_handles.data(); + for (auto i = worker_handles.size(); i > 0; i--) { + whs[worker_queue.pop().wait()].data.store({kNoMoreWork}, cuda::memory_order_relaxed); + } + RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream)); + RAFT_LOG_INFO("Destroyed the persistent runner."); + } + + void launch(index_type* result_indices_ptr, // [num_queries, top_k] + distance_type* result_distances_ptr, // [num_queries, top_k] + const data_type* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + uint32_t top_k) + { + // submit all queries + launcher_t launcher{ + job_queue, worker_queue, worker_handles.data(), num_queries, [=](uint32_t job_ix) { + auto& jd = job_descriptors.data()[job_ix].input.value; + auto cflag = &job_descriptors.data()[job_ix].completion_flag; + jd.result_indices_ptr = result_indices_ptr; + jd.result_distances_ptr = result_distances_ptr; + jd.queries_ptr = queries_ptr; + jd.top_k = top_k; + jd.n_queries = num_queries; + cflag->store(false, cuda::memory_order_relaxed); + cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_system); + return cflag; + }}; + + // Update the state of the keep-alive atomic in the meanwhile + auto prev_touch = last_touch.load(std::memory_order_relaxed); + if (prev_touch + kLiveInterval / 10 < launcher.now) { + // to avoid congestion at this atomic, we only update it if a significant fraction of the live + // interval has passed. + last_touch.store(launcher.now, std::memory_order_relaxed); + } + // wait for the results to arrive + launcher.wait(); + } + + auto calc_coop_grid_size(uint32_t block_size, uint32_t smem_size) -> dim3 + { + // We may need to run other kernels alongside this persistent kernel. + // So we can leave a few SMs idle. + // Note: running any other work on GPU alongside with the persistent kernel make the setup + // fragile. + // - Running another kernel in another thread usually works, but no progress guaranteed + // - Any CUDA allocations block the context (this issue may be obscured by using pools) + // - Memory copies to not-pinned host memory may block the context + // + // Even when we know there are no other kernels working at the same time, setting + // kDeviceUsage to 1.0 surprisingly sometimes hurts performance. Proceed with care. + // If you suspect this is an issue, you can reduce this number to ~0.9 without a significant + // impact on the throughput. 
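+ // For example, on a hypothetical GPU with 108 SMs and one resident CTA per SM
+ // (ctas_per_sm = 1), kDeviceUsage = 1.0 gives 108 worker blocks, while 0.9 would give 97
+ // and leave ~11 SMs free for other work.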
+ constexpr double kDeviceUsage = 1.0; + + // determine the grid size + int ctas_per_sm = 1; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, block_size, smem_size); + int num_sm = getMultiProcessorCount(); + auto n_blocks = static_cast(kDeviceUsage * (ctas_per_sm * num_sm)); + if (n_blocks > kMaxWorkersNum) { + RAFT_LOG_WARN("Limiting the grid size due to the size of the queue: %u -> %u", + n_blocks, + kMaxWorkersNum); + n_blocks = kMaxWorkersNum; + } + + return {1, n_blocks, 1}; + } +}; + +struct alignas(kCacheLineBytes) persistent_state { + std::shared_ptr runner{nullptr}; + std::mutex lock; +}; + +inline persistent_state persistent{}; + +template +auto create_runner(Args... args) -> std::shared_ptr // it's OK to pass everything by value +{ + std::lock_guard guard(persistent.lock); + // Check if the runner has already been created + std::shared_ptr runner_outer = std::dynamic_pointer_cast(persistent.runner); + if (runner_outer) { + if (runner_outer->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner_outer; + } else { + runner_outer.reset(); + } + } + // Runner has not yet been created (or it's incompatible): + // create it in another thread and only then release the lock. + // Free the resources (if any) in advance + persistent.runner.reset(); + + cuda::std::atomic_flag ready{}; + ready.clear(cuda::std::memory_order_relaxed); + std::thread( + [&runner_outer, &ready](Args... thread_args) { // pass everything by value + // create the runner (the lock is acquired in the parent thread). + runner_outer = std::make_shared(thread_args...); + persistent.runner = std::static_pointer_cast(runner_outer); + std::weak_ptr runner_weak = runner_outer; + ready.test_and_set(cuda::std::memory_order_release); + ready.notify_one(); + // NB: runner_outer is passed by reference and may be dead by this time. + + while (true) { + std::this_thread::sleep_for(RunnerT::kLiveInterval); + auto runner = runner_weak.lock(); // runner_weak is local - thread-safe + if (!runner) { + return; // dead already + } + if (runner->last_touch.load(std::memory_order_relaxed) + RunnerT::kLiveInterval < + std::chrono::system_clock::now()) { + std::lock_guard guard(persistent.lock); + if (runner == persistent.runner) { persistent.runner.reset(); } + return; + } + } + }, + args...) + .detach(); + ready.wait(false, cuda::std::memory_order_acquire); + return runner_outer; +} + +template +auto get_runner(Args... args) -> std::shared_ptr +{ + // Using a thread-local weak pointer allows us to avoid using locks/atomics, + // since the control block of weak/shared pointers is thread-safe. + static thread_local std::weak_ptr weak; + auto runner = weak.lock(); + if (runner) { + if (runner->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner; + } else { + weak.reset(); + runner.reset(); + } + } + // The thread-local expected_latency makes sense only for the current RunnerT configuration. + // If `weak` is not alive, it's a hint that the configuration has changed, and we should reset + // our estimate of the expected launch latency.
+ launcher_t::expected_latency = launcher_t::kDefaultLatency; + runner = create_runner(args...); + weak = runner; + return runner; +} + template :: - choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); - RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte)); - dim3 thread_dims(block_size, 1, 1); - dim3 block_dims(1, num_queries, 1); - RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - topk, - dataset_desc, - queries_ptr, - graph.data_handle(), - graph.extent(1), - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - itopk_size, - search_width, - min_iterations, - max_iterations, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - sample_filter, - metric); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + // hack: pass the 'is_persistent' flag in the highest bit of the `rand_xor_mask` + // to avoid changing the signature of `select_and_run` and updating all its + // instantiations... + uint64_t pmask = 0x8000000000000000LL; + bool is_persistent = rand_xor_mask & pmask; + rand_xor_mask &= ~pmask; + if (is_persistent) { + using runner_type = + persistent_runner_t; + get_runner(dataset_desc, + graph, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + num_random_samplings, + rand_xor_mask, + num_seeds, + itopk_size, + search_width, + min_iterations, + max_iterations, + sample_filter, + metric) + ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); + } else { + auto kernel = + search_kernel_config::choose_itopk_and_mx_candidates(itopk_size, + num_itopk_candidates, + block_size); + RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte)); + dim3 thread_dims(block_size, 1, 1); + dim3 block_dims(1, num_queries, 1); + RAFT_LOG_DEBUG( + "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); + kernel<<>>(topk_indices_ptr, + topk_distances_ptr, + topk, + dataset_desc, + queries_ptr, + graph.data_handle(), + graph.extent(1), + num_random_samplings, + rand_xor_mask, + dev_seed_ptr, + num_seeds, + hashmap_ptr, + itopk_size, + search_width, + min_iterations, + max_iterations, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + sample_filter, + metric); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } } } // namespace single_cta_search } // namespace raft::neighbors::cagra::detail
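
A minimal sketch of how the new flag is used from the public API side (illustrative only: the resources, index, and query/output device views are assumed to be set up as usual and are not part of this patch):

// Enable the persistent kernel through the regular CAGRA search parameters.
raft::neighbors::cagra::search_params params;
params.algo       = raft::neighbors::cagra::search_algo::SINGLE_CTA;  // the only algo supported
params.persistent = true;  // keep the kernel resident between calls
// Each calling thread submits its own (small) batch; the shared persistent kernel serves the
// queries of all threads concurrently and shuts itself down after a period of inactivity.
raft::neighbors::cagra::search(res, params, index, queries, neighbors, distances);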