From 515ac621426c8ee08efedc439fd243f064dc5bee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:16:44 +0100 Subject: [PATCH] IVF-FLAT support k > 256 (#2169) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for topk > 256 for ivf_flat (Issue [#1555](https://github.com/rapidsai/raft/issues/1555)) The PR adds a non-fused version of topk that is utilized if k > 256. FYI, @tfeher Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2169 --- .../neighbors/detail/cagra/cagra_search.cuh | 17 +- .../raft/neighbors/detail/ivf_common.cuh | 325 ++++++++++++++++++ .../raft/neighbors/detail/ivf_flat_build.cuh | 9 +- .../detail/ivf_flat_interleaved_scan-ext.cuh | 6 + .../detail/ivf_flat_interleaved_scan-inl.cuh | 132 +++++-- .../neighbors/detail/ivf_flat_search-inl.cuh | 103 ++++-- .../neighbors/detail/ivf_flat_serialize.cuh | 5 +- .../raft/neighbors/detail/ivf_pq_build.cuh | 64 +--- .../detail/ivf_pq_compute_similarity-inl.cuh | 22 +- .../detail/ivf_pq_dummy_block_sort.cuh | 39 --- .../raft/neighbors/detail/ivf_pq_search.cuh | 252 ++------------ .../neighbors/detail/ivf_pq_serialize.cuh | 4 +- .../raft/neighbors/detail/refine_device.cuh | 3 + .../raft/neighbors/ivf_flat_helpers.cuh | 35 +- cpp/include/raft/neighbors/ivf_flat_types.hpp | 89 +++-- cpp/include/raft/neighbors/ivf_pq_helpers.cuh | 6 +- ...at_interleaved_scan_float_float_int64_t.cu | 4 +- ...flat_interleaved_scan_half_half_int64_t.cu | 2 + ...interleaved_scan_int8_t_int32_t_int64_t.cu | 4 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 4 +- cpp/test/neighbors/ann_ivf_flat.cuh | 21 +- .../pylibraft/pylibraft/test/test_ivf_flat.py | 10 +- 22 files changed, 694 insertions(+), 462 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/ivf_common.cuh delete mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 40cc7c76fb..bfacceae29 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include @@ -181,13 +181,14 @@ void search_main(raft::resources const& res, // and divide the values by kDivisor. Here we restore the original scale. constexpr float kScale = spatial::knn::detail::utils::config::kDivisor / spatial::knn::detail::utils::config::kDivisor; - ivf_pq::detail::postprocess_distances(dist_out, - dist_in, - index.metric(), - distances.extent(0), - distances.extent(1), - kScale, - resource::get_cuda_stream(res)); + ivf::detail::postprocess_distances(dist_out, + dist_in, + index.metric(), + distances.extent(0), + distances.extent(1), + kScale, + true, + resource::get_cuda_stream(res)); } /** @} */ // end group cagra diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh new file mode 100644 index 0000000000..d7eb80084e --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include // matrix::detail::select::warpsort::warp_sort_distributed + +namespace raft::neighbors::ivf::detail { + +/** + * Default value returned by `search` when the `n_probes` is too small and top-k is too large. + * One may encounter it if the combined size of probed clusters is smaller than the requested + * number of results per query. + */ +template +constexpr static IdxT kOutOfBoundsRecord = std::numeric_limits::max(); + +template +struct dummy_block_sort_t { + using queue_t = + matrix::detail::select::warpsort::warp_sort_distributed; + template + __device__ dummy_block_sort_t(int k, Args...){}; +}; + +/** + * For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that + * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total + * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples. + */ +template +__launch_bounds__(BlockDim) RAFT_KERNEL + calc_chunk_indices_kernel(uint32_t n_probes, + const uint32_t* cluster_sizes, // [n_clusters] + const uint32_t* clusters_to_probe, // [n_queries, n_probes] + uint32_t* chunk_indices, // [n_queries, n_probes] + uint32_t* n_samples // [n_queries] + ) +{ + using block_scan = cub::BlockScan; + __shared__ typename block_scan::TempStorage shm; + + // locate the query data + clusters_to_probe += n_probes * blockIdx.x; + chunk_indices += n_probes * blockIdx.x; + + // block scan + const uint32_t n_probes_aligned = Pow2::roundUp(n_probes); + uint32_t total = 0; + for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) { + auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u; + auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u; + if (threadIdx.x == 0) { chunk += total; } + block_scan(shm).InclusiveSum(chunk, chunk, total); + __syncthreads(); + if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; } + } + // save the total size + if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; } +} + +struct calc_chunk_indices { + public: + struct configured { + void* kernel; + dim3 block_dim; + dim3 grid_dim; + uint32_t n_probes; + + inline void operator()(const uint32_t* cluster_sizes, + const uint32_t* clusters_to_probe, + uint32_t* chunk_indices, + uint32_t* n_samples, + rmm::cuda_stream_view stream) + { + void* args[] = // NOLINT + {&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples}; + RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream)); + } + }; + + static inline auto configure(uint32_t n_probes, uint32_t n_queries) -> configured + { + return try_block_dim<1024>(n_probes, n_queries); + } + + private: + template + static auto try_block_dim(uint32_t n_probes, uint32_t n_queries) -> configured + { + if constexpr (BlockDim >= WarpSize * 2) { + if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); } + } + return {reinterpret_cast(calc_chunk_indices_kernel), + dim3(BlockDim, 1, 1), + dim3(n_queries, 1, 1), + n_probes}; + } +}; + +/** + * Look up the chunk id corresponding to the sample index. + * + * Each query vector was compared to all the vectors from n_probes clusters, and sample_ix is an + * ordered number of one of such vectors. This function looks up to which chunk it belongs, + * and returns the index within the chunk (which is also an index within a cluster). + * + * @param[inout] sample_ix + * input: the offset of the sample in the batch; + * output: the offset inside the chunk (probe) / selected cluster. + * @param[in] n_probes number of probes + * @param[in] chunk_indices offsets of the chunks within the batch [n_probes] + * @return chunk index (== n_probes when the input index is not in the valid range, + * which can happen if there is not enough data to output in the selected clusters). + */ +__device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT + uint32_t n_probes, + const uint32_t* chunk_indices) -> uint32_t +{ + uint32_t ix_min = 0; + uint32_t ix_max = n_probes; + do { + uint32_t i = (ix_min + ix_max) / 2; + if (chunk_indices[i] <= sample_ix) { + ix_min = i + 1; + } else { + ix_max = i; + } + } while (ix_min < ix_max); + if (ix_min > 0) { sample_ix -= chunk_indices[ix_min - 1]; } + return ix_min; +} + +template +__launch_bounds__(BlockDim) RAFT_KERNEL + postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk] + const IdxT2* neighbors_in, // [n_queries, topk] + const IdxT1* const* db_indices, // [n_clusters][..] + const uint32_t* clusters_to_probe, // [n_queries, n_probes] + const uint32_t* chunk_indices, // [n_queries, n_probes] + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk) +{ + const uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x); + const uint32_t query_ix = i / uint64_t(topk); + if (query_ix >= n_queries) { return; } + const uint32_t k = i % uint64_t(topk); + neighbors_in += query_ix * topk; + neighbors_out += query_ix * topk; + chunk_indices += query_ix * n_probes; + clusters_to_probe += query_ix * n_probes; + uint32_t data_ix = neighbors_in[k]; + const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); + const bool valid = chunk_ix < n_probes; + neighbors_out[k] = + valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; +} + +/** + * Transform found sample indices into the corresponding database indices + * (as stored in index.indices()). + * The sample indices are the record indices as they appear in the database view formed by the + * probed clusters / defined by the `chunk_indices`. + * We assume the searched sample sizes (for a single query) fit into `uint32_t`. + */ +template +void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk] + const IdxT2* neighbors_in, // [n_queries, topk] + const IdxT1* const* db_indices, // [n_clusters][..] + const uint32_t* clusters_to_probe, // [n_queries, n_probes] + const uint32_t* chunk_indices, // [n_queries, n_probes] + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk, + rmm::cuda_stream_view stream) +{ + constexpr int kPNThreads = 256; + const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); + postprocess_neighbors_kernel + <<>>(neighbors_out, + neighbors_in, + db_indices, + clusters_to_probe, + chunk_indices, + n_queries, + n_probes, + topk); +} + +/** + * Post-process the scores depending on the metric type; + * translate the element type if necessary. + */ +template +void postprocess_distances(ScoreOutT* out, // [n_queries, topk] + const ScoreInT* in, // [n_queries, topk] + distance::DistanceType metric, + uint32_t n_queries, + uint32_t topk, + float scaling_factor, + bool account_for_max_close, + rmm::cuda_stream_view stream) +{ + constexpr bool needs_cast = !std::is_same::value; + const bool needs_copy = ((void*)in) != ((void*)out); + size_t len = size_t(n_queries) * size_t(topk); + switch (metric) { + case distance::DistanceType::L2Unexpanded: + case distance::DistanceType::L2Expanded: { + if (scaling_factor != 1.0) { + linalg::unaryOp( + out, + in, + len, + raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, + raft::cast_op{}), + stream); + } else if (needs_cast || needs_copy) { + linalg::unaryOp(out, in, len, raft::cast_op{}, stream); + } + } break; + case distance::DistanceType::L2SqrtUnexpanded: + case distance::DistanceType::L2SqrtExpanded: { + if (scaling_factor != 1.0) { + linalg::unaryOp(out, + in, + len, + raft::compose_op{raft::mul_const_op{scaling_factor}, + raft::sqrt_op{}, + raft::cast_op{}}, + stream); + } else if (needs_cast) { + linalg::unaryOp( + out, in, len, raft::compose_op{raft::sqrt_op{}, raft::cast_op{}}, stream); + } else { + linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); + } + } break; + case distance::DistanceType::InnerProduct: { + float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor; + if (factor != 1.0) { + linalg::unaryOp( + out, + in, + len, + raft::compose_op(raft::mul_const_op{factor}, raft::cast_op{}), + stream); + } else if (needs_cast || needs_copy) { + linalg::unaryOp(out, in, len, raft::cast_op{}, stream); + } + } break; + default: RAFT_FAIL("Unexpected metric."); + } +} + +/** Update the state of the dependent index members. */ +template +void recompute_internal_state(const raft::resources& res, Index& index) +{ + auto stream = resource::get_cuda_stream(res); + auto tmp_res = resource::get_workspace_resource(res); + rmm::device_uvector sorted_sizes(index.n_lists(), stream, tmp_res); + + // Actualize the list pointers + auto data_ptrs = index.data_ptrs(); + auto inds_ptrs = index.inds_ptrs(); + for (uint32_t label = 0; label < index.n_lists(); label++) { + auto& list = index.lists()[label]; + const auto data_ptr = list ? list->data.data_handle() : nullptr; + const auto inds_ptr = list ? list->indices.data_handle() : nullptr; + copy(&data_ptrs(label), &data_ptr, 1, stream); + copy(&inds_ptrs(label), &inds_ptr, 1, stream); + } + + // Sort the cluster sizes in the descending order. + int begin_bit = 0; + int end_bit = sizeof(uint32_t) * 8; + size_t cub_workspace_size = 0; + cub::DeviceRadixSort::SortKeysDescending(nullptr, + cub_workspace_size, + index.list_sizes().data_handle(), + sorted_sizes.data(), + index.n_lists(), + begin_bit, + end_bit, + stream); + rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res); + cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(), + cub_workspace_size, + index.list_sizes().data_handle(), + sorted_sizes.data(), + index.n_lists(), + begin_bit, + end_bit, + stream); + // copy the results to CPU + std::vector sorted_sizes_host(index.n_lists()); + copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream); + resource::sync_stream(res); + + // accumulate the sorted cluster sizes + auto accum_sorted_sizes = index.accum_sorted_sizes(); + accum_sorted_sizes(0) = 0; + for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) { + accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label]; + } +} + +} // namespace raft::neighbors::ivf::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index fa11d9236f..55184cc615 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -75,7 +76,7 @@ auto clone(const raft::resources& res, const index& source) -> indexrecompute_internal_state(handle); + ivf::detail::recompute_internal_state(handle, *index); // Copy the old sizes, so we can start from the current state of the index; // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); @@ -355,6 +356,8 @@ inline auto build(raft::resources const& handle, RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); index index(handle, params, dim); + utils::memzero( + index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); @@ -442,7 +445,7 @@ inline void fill_refinement_index(raft::resources const& handle, ivf::resize_list(handle, lists[label], list_device_spec, n_candidates, uint32_t(0)); } // Update the pointers and the sizes - refinement_index->recompute_internal_state(handle); + ivf::detail::recompute_internal_state(handle, *refinement_index); RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 58e94ee7aa..e87c2c0d56 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -28,6 +28,8 @@ namespace raft::neighbors::ivf_flat::detail { +auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool; + template void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, const T* queries, @@ -37,6 +39,8 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, + const uint32_t max_samples, + const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, IdxT* neighbors, @@ -60,6 +64,8 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const raft::distance::DistanceType metric, \ const uint32_t n_probes, \ const uint32_t k, \ + const uint32_t max_samples, \ + const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ IdxT* neighbors, \ diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 51cd2876d8..715f046bd4 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,11 @@ using namespace raft::spatial::knn::detail; // NOLINT constexpr int kThreadsPerBlock = 128; +auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool +{ + return k <= matrix::detail::select::warpsort::kMaxCapacity; +} + /** * @brief Copy `n` elements per block from one place to another. * @@ -627,6 +633,23 @@ struct loadAndComputeDist { } }; +// switch to dummy blocksort when Capacity is 0 this explicit dummy is chosen +// to support access to warpsort constants like ::queue_t::kDummy +template +struct flat_block_sort { + using type = matrix::detail::select::warpsort:: + block_sort; +}; + +template +struct flat_block_sort<0, Ascending, T, IdxT> + : ivf::detail::dummy_block_sort_t { + using type = ivf::detail::dummy_block_sort_t; +}; + +template +using block_sort_t = typename flat_block_sort::type; + /** * Scan clusters for nearest neighbors of the query vectors. * See `ivfflat_interleaved_scan` for more information. @@ -672,12 +695,15 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, + const uint32_t max_samples, + const uint32_t* chunk_indices, const uint32_t dim, IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances) { extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; + constexpr bool kManageLocalTopK = Capacity > 0; // Using shared memory for the (part of the) query; // This allows to save on global memory bandwidth when reading index and query // data at the same time. @@ -687,8 +713,13 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) { const int query_id = blockIdx.y; query += query_id * dim; - neighbors += query_id * k * gridDim.x + blockIdx.x * k; - distances += query_id * k * gridDim.x + blockIdx.x * k; + if constexpr (kManageLocalTopK) { + neighbors += query_id * k * gridDim.x + blockIdx.x * k; + distances += query_id * k * gridDim.x + blockIdx.x * k; + } else { + distances += query_id * uint64_t(max_samples); + chunk_indices += (n_probes * query_id); + } coarse_index += query_id * n_probes; } @@ -696,14 +727,8 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - using block_sort_t = matrix::detail::select::warpsort::block_sort< - matrix::detail::select::warpsort::warp_sort_filtered, - Capacity, - Ascending, - float, - IdxT>; - block_sort_t queue(k); - + using local_topk_t = block_sort_t; + local_topk_t queue(k); { using align_warp = Pow2; const int lane_id = align_warp::mod(threadIdx.x); @@ -725,6 +750,13 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t num_groups = align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 + uint32_t sample_offset = 0; + if constexpr (!kManageLocalTopK) { + if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } + assert(list_length == chunk_indices[probe_id] - sample_offset); + assert(sample_offset + list_length <= max_samples); + } + constexpr int kUnroll = WarpSize / Veclen; constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; // Every warp reads WarpSize vectors and computes the distances to them. @@ -771,17 +803,33 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) } // Enqueue one element per thread - const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); + const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; + if constexpr (kManageLocalTopK) { + const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; + queue.add(val, idx); + } else { + if (vec_id < list_length) distances[sample_offset + vec_id] = val; + } + } + + // fill up unused slots for current query + if constexpr (!kManageLocalTopK) { + if (probe_id + 1 == n_probes) { + for (uint32_t i = threadIdx.x + sample_offset + list_length; i < max_samples; + i += blockDim.x) { + distances[i] = local_topk_t::queue_t::kDummy; + } + } } } } // finalize and store selected neighbours - __syncthreads(); - queue.done(interleaved_scan_kernel_smem); - queue.store(distances, neighbors, post_process); + if constexpr (kManageLocalTopK) { + __syncthreads(); + queue.done(interleaved_scan_kernel_smem); + queue.store(distances, neighbors, post_process); + } } /** @@ -821,6 +869,8 @@ void launch_kernel(Lambda lambda, const uint32_t queries_offset, const uint32_t n_probes, const uint32_t k, + const uint32_t max_samples, + const uint32_t* chunk_indices, IvfSampleFilterT sample_filter, IdxT* neighbors, float* distances, @@ -841,12 +891,15 @@ void launch_kernel(Lambda lambda, const int max_query_smem = 16384; int query_smem_elems = std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); - int smem_size = query_smem_elems * sizeof(T); - constexpr int kSubwarpSize = std::min(Capacity, WarpSize); - auto block_merge_mem = - raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( - kThreadsPerBlock / kSubwarpSize, k); - smem_size += std::max(smem_size, block_merge_mem); + int smem_size = query_smem_elems * sizeof(T); + + if constexpr (Capacity > 0) { + constexpr int kSubwarpSize = std::min(Capacity, WarpSize); + auto block_merge_mem = + raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( + kThreadsPerBlock / kSubwarpSize, k); + smem_size += std::max(smem_size, block_merge_mem); + } // power-of-two less than cuda limit (for better addr alignment) constexpr uint32_t kMaxGridY = 32768; @@ -879,13 +932,20 @@ void launch_kernel(Lambda lambda, queries_offset + query_offset, n_probes, k, + max_samples, + chunk_indices, index.dim(), sample_filter, neighbors, distances); queries += grid_dim_y * index.dim(); - neighbors += grid_dim_y * grid_dim_x * k; - distances += grid_dim_y * grid_dim_x * k; + if constexpr (Capacity > 0) { + neighbors += grid_dim_y * grid_dim_x * k; + distances += grid_dim_y * grid_dim_x * k; + } else { + distances += grid_dim_y * max_samples; + chunk_indices += grid_dim_y * n_probes; + } coarse_index += grid_dim_y * n_probes; } } @@ -1010,16 +1070,22 @@ struct select_interleaved_scan_kernel { * two parameters and ends with both values equal to 1. */ template - static inline void run(int capacity, int veclen, bool select_min, Args&&... args) + static inline void run(int k_max, int veclen, bool select_min, Args&&... args) { + if constexpr (Capacity > 0) { + if (k_max == 0 || k_max > Capacity) { + return select_interleaved_scan_kernel::run( + k_max, veclen, select_min, std::forward(args)...); + } + } if constexpr (Capacity > 1) { - if (capacity * 2 <= Capacity) { + if (k_max * 2 <= Capacity) { return select_interleaved_scan_kernel::run(capacity, + Veclen>::run(k_max, veclen, select_min, std::forward(args)...); @@ -1028,14 +1094,14 @@ struct select_interleaved_scan_kernel { if constexpr (Veclen > 1) { if (veclen % Veclen != 0) { return select_interleaved_scan_kernel::run( - capacity, 1, select_min, std::forward(args)...); + k_max, 1, select_min, std::forward(args)...); } } // NB: this is the limitation of the warpsort structures that use a huge number of // registers (used in the main kernel here). - RAFT_EXPECTS(capacity == Capacity, - "Capacity must be power-of-two not bigger than the maximum allowed size " - "matrix::detail::select::warpsort::kMaxCapacity (%d).", + RAFT_EXPECTS(Capacity == 0 || k_max == Capacity, + "Capacity must be either 0 or a power-of-two not bigger than the maximum " + "allowed size matrix::detail::select::warpsort::kMaxCapacity (%d).", matrix::detail::select::warpsort::kMaxCapacity); RAFT_EXPECTS( veclen == Veclen, @@ -1090,6 +1156,8 @@ void ivfflat_interleaved_scan(const index& index, const raft::distance::DistanceType metric, const uint32_t n_probes, const uint32_t k, + const uint32_t max_samples, + const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, IdxT* neighbors, @@ -1112,6 +1180,8 @@ void ivfflat_interleaved_scan(const index& index, queries_offset, n_probes, k, + max_samples, + chunk_indices, filter_adapter, neighbors, distances, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 29d521566d..0f359a0260 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -24,6 +24,7 @@ #include // raft::linalg::norm #include // raft::linalg::unary_op #include // matrix::detail::select_k +#include // raft::neighbors::detail::ivf #include // interleaved_scan #include // raft::neighbors::ivf_flat::index #include // none_ivf_sample_filter @@ -42,6 +43,7 @@ void search_impl(raft::resources const& handle, uint32_t queries_offset, uint32_t k, uint32_t n_probes, + uint32_t max_samples, bool select_min, IdxT* neighbors, AccT* distances, @@ -49,18 +51,27 @@ void search_impl(raft::resources const& handle, IvfSampleFilterT sample_filter) { auto stream = resource::get_cuda_stream(handle); + + std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes); + // The norm of query rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); // The distance value of cluster(list) and queries rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); // The topk distance value of cluster(list) and queries - rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); + rmm::device_uvector coarse_distances_dev(n_queries_probes, stream, search_mr); // The topk index of cluster(list) and queries - rmm::device_uvector coarse_indices_dev(n_queries * n_probes, stream, search_mr); + rmm::device_uvector coarse_indices_dev(n_queries_probes, stream, search_mr); + + // Optional structures if postprocessing is required // The topk distance value of candidate vectors from each cluster(list) - rmm::device_uvector refined_distances_dev(n_queries * n_probes * k, stream, search_mr); + rmm::device_uvector distances_tmp_dev(0, stream, search_mr); // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector refined_indices_dev(n_queries * n_probes * k, stream, search_mr); + rmm::device_uvector indices_tmp_dev(0, stream, search_mr); + // Number of samples for each query + rmm::device_uvector num_samples(0, stream, search_mr); + // Offsets per probe for each query + rmm::device_uvector chunk_index(0, stream, search_mr); size_t float_query_size; if constexpr (std::is_integral_v) { @@ -139,9 +150,6 @@ void search_impl(raft::resources const& handle, RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); - auto distances_dev_ptr = refined_distances_dev.data(); - auto indices_dev_ptr = refined_indices_dev.data(); - uint32_t grid_dim_x = 0; if (n_probes > 1) { // query the gridDimX size to store probes topK output @@ -154,6 +162,8 @@ void search_impl(raft::resources const& handle, index.metric(), n_probes, k, + 0, + nullptr, select_min, sample_filter, nullptr, @@ -164,9 +174,30 @@ void search_impl(raft::resources const& handle, grid_dim_x = 1; } - if (grid_dim_x == 1) { - distances_dev_ptr = distances; - indices_dev_ptr = neighbors; + auto distances_dev_ptr = distances; + auto indices_dev_ptr = neighbors; + + bool manage_local_topk = is_local_topk_feasible(k); + if (!manage_local_topk || grid_dim_x > 1) { + if (!manage_local_topk) { + num_samples.resize(n_queries, stream); + chunk_index.resize(n_queries_probes, stream); + + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( + index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); + } + + auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples); + + distances_tmp_dev.resize(target_size, stream); + if (manage_local_topk) indices_tmp_dev.resize(target_size, stream); + + distances_dev_ptr = distances_tmp_dev.data(); + indices_dev_ptr = indices_tmp_dev.data(); } ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( @@ -178,6 +209,8 @@ void search_impl(raft::resources const& handle, index.metric(), n_probes, k, + max_samples, + chunk_index.data(), select_min, sample_filter, indices_dev_ptr, @@ -186,19 +219,34 @@ void search_impl(raft::resources const& handle, stream); RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); - RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); + if (indices_dev_ptr != nullptr) { RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); } // Merge topk values from different blocks - if (grid_dim_x > 1) { + if (!manage_local_topk || grid_dim_x > 1) { matrix::detail::select_k(handle, - refined_distances_dev.data(), - refined_indices_dev.data(), + distances_tmp_dev.data(), + indices_tmp_dev.data(), n_queries, - k * grid_dim_x, + manage_local_topk ? (k * grid_dim_x) : max_samples, k, distances, neighbors, select_min); + + if (!manage_local_topk) { + // post process distances && neighbor IDs + ivf::detail::postprocess_distances( + distances, distances, index.metric(), n_queries, k, 1.0, false, stream); + ivf::detail::postprocess_neighbors(neighbors, + neighbors, + index.inds_ptrs().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + n_queries, + n_probes, + k, + stream); + } } } @@ -223,14 +271,28 @@ inline void search(raft::resources const& handle, if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } RAFT_EXPECTS(params.n_probes > 0, "n_probes (number of clusters to probe in the search) must be positive."); - auto n_probes = std::min(params.n_probes, index.n_lists()); + auto n_probes = std::min(params.n_probes, index.n_lists()); + bool manage_local_topk = is_local_topk_feasible(k); + + uint32_t max_samples = 0; + if (!manage_local_topk) { + IdxT ms = + Pow2<128 / sizeof(float)>::roundUp(std::max(index.accum_sorted_sizes()(n_probes), k)); + RAFT_EXPECTS(ms <= IdxT(std::numeric_limits::max()), + "The maximum sample size is too big."); + max_samples = ms; + } // a batch size heuristic: try to keep the workspace within the specified size - constexpr uint32_t kExpectedWsSize = 1024 * 1024 * 1024; + constexpr uint64_t kExpectedWsSize = 1024 * 1024 * 1024; + uint64_t max_ws_size = std::min(resource::get_workspace_free_bytes(handle), kExpectedWsSize); + + uint64_t ws_size_per_query = 4ull * (2 * n_probes + index.n_lists() + index.dim() + 1) + + (manage_local_topk ? ((sizeof(IdxT) + 4) * n_probes * k) + : (4ull * (max_samples + n_probes + 1))); + const uint32_t max_queries = - std::min(n_queries, - raft::div_rounding_up_safe( - kExpectedWsSize, 16ull * uint64_t{n_probes} * k + 4ull * index.dim())); + std::min(n_queries, raft::div_rounding_up_safe(max_ws_size, ws_size_per_query)); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); @@ -242,6 +304,7 @@ inline void search(raft::resources const& handle, offset_q, k, n_probes, + max_samples, raft::distance::is_min_close(index.metric()), neighbors + offset_q * k, distances + offset_q * k, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index aaf48ae830..3897b83aa6 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -153,7 +154,7 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index #include +#include #include #include #include @@ -1363,59 +1364,6 @@ void process_and_fill_codes(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } -/** Update the state of the dependent index members. */ -template -void recompute_internal_state(const raft::resources& res, index& index) -{ - auto stream = resource::get_cuda_stream(res); - auto tmp_res = resource::get_workspace_resource(res); - rmm::device_uvector sorted_sizes(index.n_lists(), stream, tmp_res); - - // Actualize the list pointers - auto data_ptrs = index.data_ptrs(); - auto inds_ptrs = index.inds_ptrs(); - for (uint32_t label = 0; label < index.n_lists(); label++) { - auto& list = index.lists()[label]; - const auto data_ptr = list ? list->data.data_handle() : nullptr; - const auto inds_ptr = list ? list->indices.data_handle() : nullptr; - copy(&data_ptrs(label), &data_ptr, 1, stream); - copy(&inds_ptrs(label), &inds_ptr, 1, stream); - } - - // Sort the cluster sizes in the descending order. - int begin_bit = 0; - int end_bit = sizeof(uint32_t) * 8; - size_t cub_workspace_size = 0; - cub::DeviceRadixSort::SortKeysDescending(nullptr, - cub_workspace_size, - index.list_sizes().data_handle(), - sorted_sizes.data(), - index.n_lists(), - begin_bit, - end_bit, - stream); - rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res); - cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(), - cub_workspace_size, - index.list_sizes().data_handle(), - sorted_sizes.data(), - index.n_lists(), - begin_bit, - end_bit, - stream); - // copy the results to CPU - std::vector sorted_sizes_host(index.n_lists()); - copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream); - resource::sync_stream(res); - - // accumulate the sorted cluster sizes - auto accum_sorted_sizes = index.accum_sorted_sizes(); - accum_sorted_sizes(0) = 0; - for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) { - accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label]; - } -} - /** * Helper function: allocate enough space in the list, compute the offset, at which to start * writing, and fill-in indices. @@ -1463,7 +1411,7 @@ void extend_list_with_codes(raft::resources const& res, // Pack the data pack_list_data(res, index, new_codes, label, offset); // Update the pointers and the sizes - recompute_internal_state(res, *index); + ivf::detail::recompute_internal_state(res, *index); } /** @@ -1482,7 +1430,7 @@ void extend_list(raft::resources const& res, // Encode the data encode_list_data(res, index, new_vectors, label, offset); // Update the pointers and the sizes - recompute_internal_state(res, *index); + ivf::detail::recompute_internal_state(res, *index); } /** @@ -1495,7 +1443,7 @@ void erase_list(raft::resources const& res, index* index, uint32_t label) uint32_t zero = 0; copy(index->list_sizes().data_handle() + label, &zero, 1, resource::get_cuda_stream(res)); index->lists()[label].reset(); - recompute_internal_state(res, *index); + ivf::detail::recompute_internal_state(res, *index); } /** Copy the state of an index into a new index, but share the list data among the two. */ @@ -1539,7 +1487,7 @@ auto clone(const raft::resources& res, const index& source) -> index target.lists() = source.lists(); // Make sure the device pointers point to the new lists - recompute_internal_state(res, target); + ivf::detail::recompute_internal_state(res, target); return target; } @@ -1688,7 +1636,7 @@ void extend(raft::resources const& handle, } // Update the pointers and the sizes - recompute_internal_state(handle, *index); + ivf::detail::recompute_internal_state(handle, *index); // Recover old cluster sizes: they are used as counters in the fill-codes kernel copy(list_sizes, orig_list_sizes.data(), n_clusters, stream); diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh index bd88c029e1..193cd31485 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,14 +18,14 @@ #include // raft::distance::DistanceType #include // matrix::detail::select::warpsort::warp_sort_distributed -#include // dummy_block_sort_t -#include // codebook_gen -#include // none_ivf_sample_filter -#include // RAFT_CUDA_TRY -#include // raft::atomicMin -#include // raft::Pow2 -#include // raft::TxN_t -#include // rmm::cuda_stream_view +#include // dummy_block_sort_t +#include // codebook_gen +#include // none_ivf_sample_filter +#include // RAFT_CUDA_TRY +#include // raft::atomicMin +#include // raft::Pow2 +#include // raft::TxN_t +#include // rmm::cuda_stream_view namespace raft::neighbors::ivf_pq::detail { @@ -72,8 +72,8 @@ struct pq_block_sort { }; template -struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { - using type = dummy_block_sort_t; +struct pq_block_sort<0, T, IdxT> : ivf::detail::dummy_block_sort_t { + using type = ivf::detail::dummy_block_sort_t; static auto mem_required(uint32_t) -> size_t { return 0; } static auto get_mem_required(uint32_t) { return mem_required; } }; diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh deleted file mode 100644 index a00b6a50ff..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // matrix::detail::select::warpsort::warp_sort_distributed - -/* - * This header file is a bit of an ugly duckling. The type dummy_block_sort is - * needed by both ivf_pq_search.cuh and ivf_pq_compute_similarity.cuh. - * - * I have decided to move it to it's own header file, which is overkill. Perhaps - * there is a nicer solution. - * - */ - -namespace raft::neighbors::ivf_pq::detail { - -template -struct dummy_block_sort_t { - using queue_t = matrix::detail::select::warpsort::warp_sort_distributed; - template - __device__ dummy_block_sort_t(int k, Args...){}; -}; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d000a1a4d3..e5294c1fa1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -20,8 +20,8 @@ #include #include +#include #include -#include #include #include #include @@ -171,217 +171,6 @@ void select_clusters(raft::resources const& handle, true); } -/** - * For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that - * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total - * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples. - */ -template -__launch_bounds__(BlockDim) RAFT_KERNEL - calc_chunk_indices_kernel(uint32_t n_probes, - const uint32_t* cluster_sizes, // [n_clusters] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t* n_samples // [n_queries] - ) -{ - using block_scan = cub::BlockScan; - __shared__ typename block_scan::TempStorage shm; - - // locate the query data - clusters_to_probe += n_probes * blockIdx.x; - chunk_indices += n_probes * blockIdx.x; - - // block scan - const uint32_t n_probes_aligned = Pow2::roundUp(n_probes); - uint32_t total = 0; - for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) { - auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u; - auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u; - if (threadIdx.x == 0) { chunk += total; } - block_scan(shm).InclusiveSum(chunk, chunk, total); - __syncthreads(); - if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; } - } - // save the total size - if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; } -} - -struct calc_chunk_indices { - public: - struct configured { - void* kernel; - dim3 block_dim; - dim3 grid_dim; - uint32_t n_probes; - - inline void operator()(const uint32_t* cluster_sizes, - const uint32_t* clusters_to_probe, - uint32_t* chunk_indices, - uint32_t* n_samples, - rmm::cuda_stream_view stream) - { - void* args[] = // NOLINT - {&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples}; - RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream)); - } - }; - - static inline auto configure(uint32_t n_probes, uint32_t n_queries) -> configured - { - return try_block_dim<1024>(n_probes, n_queries); - } - - private: - template - static auto try_block_dim(uint32_t n_probes, uint32_t n_queries) -> configured - { - if constexpr (BlockDim >= WarpSize * 2) { - if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); } - } - return {reinterpret_cast(calc_chunk_indices_kernel), - dim3(BlockDim, 1, 1), - dim3(n_queries, 1, 1), - n_probes}; - } -}; - -/** - * Look up the chunk id corresponding to the sample index. - * - * Each query vector was compared to all the vectors from n_probes clusters, and sample_ix is an - * ordered number of one of such vectors. This function looks up to which chunk it belongs, - * and returns the index within the chunk (which is also an index within a cluster). - * - * @param[inout] sample_ix - * input: the offset of the sample in the batch; - * output: the offset inside the chunk (probe) / selected cluster. - * @param[in] n_probes number of probes - * @param[in] chunk_indices offsets of the chunks within the batch [n_probes] - * @return chunk index (== n_probes when the input index is not in the valid range, - * which can happen if there is not enough data to output in the selected clusters). - */ -__device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT - uint32_t n_probes, - const uint32_t* chunk_indices) -> uint32_t -{ - uint32_t ix_min = 0; - uint32_t ix_max = n_probes; - do { - uint32_t i = (ix_min + ix_max) / 2; - if (chunk_indices[i] <= sample_ix) { - ix_min = i + 1; - } else { - ix_max = i; - } - } while (ix_min < ix_max); - if (ix_min > 0) { sample_ix -= chunk_indices[ix_min - 1]; } - return ix_min; -} - -template -__launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] - const uint32_t* neighbors_in, // [n_queries, topk] - const IdxT* const* db_indices, // [n_clusters][..] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) -{ - const uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x); - const uint32_t query_ix = i / uint64_t(topk); - if (query_ix >= n_queries) { return; } - const uint32_t k = i % uint64_t(topk); - neighbors_in += query_ix * topk; - neighbors_out += query_ix * topk; - chunk_indices += query_ix * n_probes; - clusters_to_probe += query_ix * n_probes; - uint32_t data_ix = neighbors_in[k]; - const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); - const bool valid = chunk_ix < n_probes; - neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : ivf_pq::kOutOfBoundsRecord; -} - -/** - * Transform found sample indices into the corresponding database indices - * (as stored in index.indices()). - * The sample indices are the record indices as they appear in the database view formed by the - * probed clusters / defined by the `chunk_indices`. - * We assume the searched sample sizes (for a single query) fit into `uint32_t`. - */ -template -void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] - const uint32_t* neighbors_in, // [n_queries, topk] - const IdxT* const* db_indices, // [n_clusters][..] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk, - rmm::cuda_stream_view stream) -{ - constexpr int kPNThreads = 256; - const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel - <<>>(neighbors_out, - neighbors_in, - db_indices, - clusters_to_probe, - chunk_indices, - n_queries, - n_probes, - topk); -} - -/** - * Post-process the scores depending on the metric type; - * translate the element type if necessary. - */ -template -void postprocess_distances(float* out, // [n_queries, topk] - const ScoreT* in, // [n_queries, topk] - distance::DistanceType metric, - uint32_t n_queries, - uint32_t topk, - float scaling_factor, - rmm::cuda_stream_view stream) -{ - size_t len = size_t(n_queries) * size_t(topk); - switch (metric) { - case distance::DistanceType::L2Unexpanded: - case distance::DistanceType::L2Expanded: { - linalg::unaryOp(out, - in, - len, - raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, - raft::cast_op{}), - stream); - } break; - case distance::DistanceType::L2SqrtUnexpanded: - case distance::DistanceType::L2SqrtExpanded: { - linalg::unaryOp( - out, - in, - len, - raft::compose_op{ - raft::mul_const_op{scaling_factor}, raft::sqrt_op{}, raft::cast_op{}}, - stream); - } break; - case distance::DistanceType::InnerProduct: { - linalg::unaryOp(out, - in, - len, - raft::compose_op(raft::mul_const_op{-scaling_factor * scaling_factor}, - raft::cast_op{}), - stream); - } break; - default: RAFT_FAIL("Unexpected metric."); - } -} - /** * An approximation to the number of times each cluster appears in a batched sample. * @@ -520,11 +309,11 @@ void ivfpq_search_worker(raft::resources const& handle, neighbors_uint32 = neighbors_uint32_buf.data(); } - calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), - clusters_to_probe, - chunk_index.data(), - num_samples.data(), - stream); + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + clusters_to_probe, + chunk_index.data(), + num_samples.data(), + stream); auto coresidency = expected_probe_coresidency(index.n_lists(), n_probes, n_queries); @@ -621,9 +410,10 @@ void ivfpq_search_worker(raft::resources const& handle, if (manage_local_topk) { query_kths_buf.emplace( make_device_mdarray(handle, mr, make_extents(n_queries))); - linalg::map(handle, - query_kths_buf->view(), - raft::const_op{dummy_block_sort_t::queue_t::kDummy}); + linalg::map( + handle, + query_kths_buf->view(), + raft::const_op{ivf::detail::dummy_block_sort_t::queue_t::kDummy}); query_kths = query_kths_buf->data_handle(); } compute_similarity_run(search_instance, @@ -663,17 +453,17 @@ void ivfpq_search_worker(raft::resources const& handle, true); // Postprocessing - postprocess_distances( - distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, stream); - postprocess_neighbors(neighbors, - neighbors_uint32, - index.inds_ptrs().data_handle(), - clusters_to_probe, - chunk_index.data(), - n_queries, - n_probes, - topK, - stream); + ivf::detail::postprocess_distances( + distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, true, stream); + ivf::detail::postprocess_neighbors(neighbors, + neighbors_uint32, + index.inds_ptrs().data_handle(), + clusters_to_probe, + chunk_index.data(), + n_queries, + n_probes, + topK, + stream); } /** diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 038bf8d7cc..b23534748e 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -160,7 +160,7 @@ auto deserialize(raft::resources const& handle_, std::istream& is) -> index(indices.extent(1)); + // TODO: this restriction could be lifted with some effort RAFT_EXPECTS(k <= raft::matrix::detail::select::warpsort::kMaxCapacity, "k must be less than topk::kMaxCapacity (%d).", raft::matrix::detail::select::warpsort::kMaxCapacity); @@ -98,6 +99,8 @@ void refine_device(raft::resources const& handle, refinement_index.metric(), 1, k, + 0, + nullptr, raft::distance::is_min_close(metric), raft::neighbors::filtering::none_ivf_sample_filter(), indices.data_handle(), diff --git a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh b/cpp/include/raft/neighbors/ivf_flat_helpers.cuh index 7a05c9991c..18e03ec62a 100644 --- a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -137,9 +137,42 @@ void reset_index(const raft::resources& res, index* index) { auto stream = resource::get_cuda_stream(res); + utils::memzero( + index->accum_sorted_sizes().data_handle(), index->accum_sorted_sizes().size(), stream); utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); } + +/** + * @brief Helper exposing the re-computation of list sizes and related arrays if IVF lists have been + * modified. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * raft::resources res; + * // use default index parameters + * ivf_flat::index_params index_params; + * // initialize an empty index + * ivf_flat::index index(res, index_params, D); + * ivf_flat::helpers::reset_index(res, &index); + * // recompute the internal state of the index + * ivf_flat::helpers::recompute_internal_state(res, &index); + * @endcode + * + * @tparam T + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-FLAT index + */ +template +void recompute_internal_state(const raft::resources& res, index* index) +{ + auto& list = index->lists()[0]; + ivf::detail::recompute_internal_state(res, *index); +} + /** @} */ } // namespace raft::neighbors::ivf_flat::helpers diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 180fe2e21b..1b5bb9f10c 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,8 +30,6 @@ #include #include -#include - #include // std::max #include #include @@ -222,8 +220,31 @@ struct index : ann::index { } } + /** + * Accumulated list sizes, sorted in descending order [n_lists + 1]. + * The last value contains the total length of the index. + * The value at index zero is always zero. + * + * That is, the content of this span is as if the `list_sizes` was sorted and then accumulated. + * + * This span is used during search to estimate the maximum size of the workspace. + */ + inline auto accum_sorted_sizes() noexcept -> host_vector_view + { + return accum_sorted_sizes_.view(); + } + [[nodiscard]] inline auto accum_sorted_sizes() const noexcept + -> host_vector_view + { + return accum_sorted_sizes_.view(); + } + /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return total_size_; } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT + { + return accum_sorted_sizes()(n_lists()); + } + /** Dimensionality of the data. */ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { @@ -257,9 +278,10 @@ struct index : ann::index { list_sizes_{make_device_vector(res, n_lists)}, data_ptrs_{make_device_vector(res, n_lists)}, inds_ptrs_{make_device_vector(res, n_lists)}, - total_size_{0} + accum_sorted_sizes_{make_host_vector(n_lists + 1)} { check_consistency(); + accum_sorted_sizes_(n_lists) = 0; } /** Construct an empty index. It needs to be trained and then populated. */ @@ -298,33 +320,6 @@ struct index : ann::index { return conservative_memory_allocation_; } - /** - * Update the state of the dependent index members. - */ - void recompute_internal_state(raft::resources const& res) - { - auto stream = resource::get_cuda_stream(res); - - // Actualize the list pointers - auto this_lists = lists(); - auto this_data_ptrs = data_ptrs(); - auto this_inds_ptrs = inds_ptrs(); - for (uint32_t label = 0; label < this_lists.size(); label++) { - auto& list = this_lists[label]; - const auto data_ptr = list ? list->data.data_handle() : nullptr; - const auto inds_ptr = list ? list->indices.data_handle() : nullptr; - copy(&this_data_ptrs(label), &data_ptr, 1, stream); - copy(&this_inds_ptrs(label), &inds_ptr, 1, stream); - } - auto this_list_sizes = list_sizes().data_handle(); - total_size_ = thrust::reduce(resource::get_thrust_policy(res), - this_list_sizes, - this_list_sizes + this_lists.size(), - 0, - raft::add_op{}); - check_consistency(); - } - void allocate_center_norms(raft::resources const& res) { switch (metric_) { @@ -349,6 +344,20 @@ struct index : ann::index { return lists_; } + /** Throw an error if the index content is inconsistent. */ + void check_consistency() + { + auto n_lists = lists_.size(); + RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen"); + RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); + RAFT_EXPECTS( // + (centers_.extent(0) == list_sizes_.extent(0)) && // + (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), + "inconsistent number of lists (clusters)"); + } + private: /** * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum @@ -366,21 +375,7 @@ struct index : ann::index { // Computed members device_vector data_ptrs_; device_vector inds_ptrs_; - IdxT total_size_; - - /** Throw an error if the index content is inconsistent. */ - void check_consistency() - { - auto n_lists = lists_.size(); - RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen"); - RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS( // - (centers_.extent(0) == list_sizes_.extent(0)) && // - (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), - "inconsistent number of lists (clusters)"); - } + host_vector accum_sorted_sizes_; static auto calculate_veclen(uint32_t dim) -> uint32_t { diff --git a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh index fec31f1c61..9b97a04a83 100644 --- a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -733,7 +733,7 @@ void set_centers(raft::resources const& res, * ivf::resize_list(res, list, spec, new_size, 0); * raft::update_device(index.list_sizes(), &new_size, 1, stream); * // recompute the internal state of the index - * ivf_pq::recompute_internal_state(res, &index); + * ivf_pq::helpers::recompute_internal_state(res, &index); * @endcode * * @tparam IdxT @@ -745,7 +745,7 @@ template void recompute_internal_state(const raft::resources& res, index* index) { auto& list = index->lists()[0]; - ivf_pq::detail::recompute_internal_state(res, *index); + ivf::detail::recompute_internal_state(res, *index); } /** diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index a1d6cca7d5..def33e493e 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ const raft::distance::DistanceType metric, \ const uint32_t n_probes, \ const uint32_t k, \ + const uint32_t max_samples, \ + const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ IdxT* neighbors, \ diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu index 3c467a12d8..e96600ee02 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu @@ -31,6 +31,8 @@ const raft::distance::DistanceType metric, \ const uint32_t n_probes, \ const uint32_t k, \ + const uint32_t max_samples, \ + const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ IdxT* neighbors, \ diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 514301562d..13c9d2e283 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ const raft::distance::DistanceType metric, \ const uint32_t n_probes, \ const uint32_t k, \ + const uint32_t max_samples, \ + const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ IdxT* neighbors, \ diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 32698a8e80..51f02343fc 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ const raft::distance::DistanceType metric, \ const uint32_t n_probes, \ const uint32_t k, \ + const uint32_t max_samples, \ + const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ IdxT* neighbors, \ diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 39439d392d..26be743eec 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -354,7 +354,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ivf::resize_list(handle_, lists[label], list_device_spec, list_size, 0); } - idx.recompute_internal_state(handle_); + helpers::recompute_internal_state(handle_, &idx); using interleaved_group = Pow2; @@ -608,6 +608,23 @@ const std::vector> inputs = { {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false}, + // various combinations with k>raft::matrix::detail::select::warpsort::kMaxCapacity + {1000, 10000, 16, 1024, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 2053, 512, 50, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 2049, 2048, 70, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false}, + {10, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false}, + {10, 10000, 16, 4000, 120, 2048, raft::distance::DistanceType::L2SqrtExpanded, true}, + {20, 100000, 16, 257, 20, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 100000, 16, 259, 20, 1024, raft::distance::DistanceType::L2Expanded, true, true}, + {10000, 131072, 8, 280, 20, 1024, raft::distance::DistanceType::InnerProduct, false}, + {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2Expanded, false}, + {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2SqrtExpanded, false}, + {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::InnerProduct, false}, + {100000, 1024, 16, 300, 20, 60, raft::distance::DistanceType::L2Expanded, false}, + {100000, 1024, 16, 500, 20, 60, raft::distance::DistanceType::L2SqrtExpanded, false}, + {100000, 1024, 16, 700, 20, 60, raft::distance::DistanceType::InnerProduct, false}, + // host input data {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false, true}, {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false, true}, diff --git a/python/pylibraft/pylibraft/test/test_ivf_flat.py b/python/pylibraft/pylibraft/test/test_ivf_flat.py index 23140073f1..2e38dab7bc 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_flat.py +++ b/python/pylibraft/pylibraft/test/test_ivf_flat.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -297,6 +297,14 @@ def test_ivf_flat_params(params): "k": 129, "n_probes": 100, }, + { + "k": 257, + "n_probes": 100, + }, + { + "k": 4096, + "n_probes": 100, + }, ], ) def test_ivf_pq_search_params(params):