diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 6a7847d8a0..506cbffcb9 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 8f40e6ae00..93d233152b 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle, * whether to make sure selected pairs are sorted by value * @param[in] algo * the selection algorithm to use + * @param[in] len_i + * array of size (batch_size) providing lengths for each individual row + * only radix select-k supported */ template void select_k(raft::resources const& handle, @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - true // fused_last_filter - ); - + true, // fused_last_filter + len_i); } else { bool fused_last_filter = algo == SelectAlgo::kRadix11bits; detail::select::radix::select_k(handle, @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - fused_last_filter); + fused_last_filter, + len_i); } if (sorted) { auto offsets = make_device_mdarray( diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 82983b7cd2..36a346fda3 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -557,6 +557,7 @@ RAFT_KERNEL radix_kernel(const T* in, Counter* counters, IdxT* histograms, const IdxT len, + const IdxT* len_i, const IdxT k, const bool select_min, const int pass) @@ -598,6 +599,14 @@ RAFT_KERNEL radix_kernel(const T* in, in_buf += batch_id * buf_len; in_idx_buf += batch_id * buf_len; } + + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + // "current_len > buf_len" means current pass will skip writing buffer if (pass == 0 || current_len > buf_len) { out_buf = nullptr; @@ -829,6 +838,7 @@ void radix_topk(const T* in, IdxT* out_idx, bool select_min, bool fused_last_filter, + const IdxT* len_i, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, @@ -868,6 +878,7 @@ void radix_topk(const T* in, const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; T* chunk_out = out + offset * k; IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; const T* in_buf = nullptr; const IdxT* in_idx_buf = nullptr; @@ -905,6 +916,7 @@ void radix_topk(const T* in, counters.data(), histograms.data(), len, + chunk_len_i, k, select_min, pass); @@ -1007,6 +1019,7 @@ template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, + const IdxT* len_i, const IdxT k, T* out, IdxT* out_idx, @@ -1057,6 +1070,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_idx_buf = nullptr; } + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + filter_and_histogram_for_one_block(in_buf, in_idx_buf, out_buf, @@ -1106,6 +1126,7 @@ void radix_topk_one_block(const T* in, T* out, IdxT* out_idx, bool select_min, + const IdxT* len_i, int sm_cnt, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -1121,10 +1142,12 @@ void radix_topk_one_block(const T* in, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { - int chunk_size = std::min(max_chunk_size, batch_size - offset); + int chunk_size = std::min(max_chunk_size, batch_size - offset); + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; kernel<<>>(in + offset * len, in_idx ? (in_idx + offset * len) : nullptr, len, + chunk_len_i, k, out + offset * k, out_idx + offset * k, @@ -1188,6 +1211,8 @@ void radix_topk_one_block(const T* in, * blocks is called. The later case is preferable when leading bits of input data are almost the * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. + * @param len_i + * optional array of size (batch_size) providing lengths for each individual row */ template void select_k(raft::resources const& res, @@ -1199,7 +1224,8 @@ void select_k(raft::resources const& res, T* out, IdxT* out_idx, bool select_min, - bool fused_last_filter) + bool fused_last_filter, + const IdxT* len_i) { auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); @@ -1223,13 +1249,13 @@ void select_k(raft::resources const& res, if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { impl::radix_topk(in, in_idx, @@ -1240,6 +1266,7 @@ void select_k(raft::resources const& res, out_idx, select_min, fused_last_filter, + len_i, grid_dim, sm_cnt, stream, diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh index ef7ae7c804..df0319e181 100644 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -147,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT return ix_min; } -template +template __launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] + postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -170,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); const bool valid = chunk_ix < n_probes; neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; + valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; } /** @@ -180,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL * probed clusters / defined by the `chunk_indices`. * We assume the searched sample sizes (for a single query) fit into `uint32_t`. */ -template -void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] +template +void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -193,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to { constexpr int kPNThreads = 256; const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel + postprocess_neighbors_kernel <<>>(neighbors_out, neighbors_in, db_indices, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 7c2d1d2157..140a9f17c8 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) RAFT_EXPLICIT; @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 6fc528e26b..9cd8b70148 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, const uint32_t queries_offset, @@ -700,7 +699,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t* chunk_indices, const uint32_t dim, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances) { extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; @@ -719,8 +718,8 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) distances += query_id * k * gridDim.x + blockIdx.x * k; } else { distances += query_id * uint64_t(max_samples); - chunk_indices += (n_probes * query_id); } + chunk_indices += (n_probes * query_id); coarse_index += query_id * n_probes; } @@ -728,7 +727,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - using local_topk_t = block_sort_t; + using local_topk_t = block_sort_t; local_topk_t queue(k); { using align_warp = Pow2; @@ -752,11 +751,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 uint32_t sample_offset = 0; - if constexpr (!kManageLocalTopK) { - if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } - assert(list_length == chunk_indices[probe_id] - sample_offset); - assert(sample_offset + list_length <= max_samples); - } + if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } + assert(list_length == chunk_indices[probe_id] - sample_offset); + assert(sample_offset + list_length <= max_samples); constexpr int kUnroll = WarpSize / Veclen; constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; @@ -806,8 +803,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; if constexpr (kManageLocalTopK) { - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); + queue.add(val, sample_offset + vec_id); } else { if (vec_id < list_length) distances[sample_offset + vec_id] = val; } @@ -873,7 +869,7 @@ void launch_kernel(Lambda lambda, const uint32_t max_samples, const uint32_t* chunk_indices, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) @@ -927,7 +923,6 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), queries_offset + query_offset, @@ -945,8 +940,8 @@ void launch_kernel(Lambda lambda, distances += grid_dim_y * grid_dim_x * k; } else { distances += grid_dim_y * max_samples; - chunk_indices += grid_dim_y * n_probes; } + chunk_indices += grid_dim_y * n_probes; coarse_index += grid_dim_y * n_probes; } } @@ -1161,7 +1156,7 @@ void ivfflat_interleaved_scan(const index& index, const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 98bdeda42f..441fb76b2f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -67,13 +67,16 @@ void search_impl(raft::resources const& handle, // Optional structures if postprocessing is required // The topk distance value of candidate vectors from each cluster(list) rmm::device_uvector distances_tmp_dev(0, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector indices_tmp_dev(0, stream, search_mr); // Number of samples for each query rmm::device_uvector num_samples(0, stream, search_mr); // Offsets per probe for each query rmm::device_uvector chunk_index(0, stream, search_mr); + // The topk index of candidate vectors from each cluster(list), local index offset + // also we might need additional storage for select_k + rmm::device_uvector indices_tmp_dev(0, stream, search_mr); + rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); + size_t float_query_size; if constexpr (std::is_integral_v) { float_query_size = n_queries * index.dim(); @@ -175,23 +178,29 @@ void search_impl(raft::resources const& handle, grid_dim_x = 1; } + num_samples.resize(n_queries, stream); + chunk_index.resize(n_queries_probes, stream); + + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); + auto distances_dev_ptr = distances; - auto indices_dev_ptr = neighbors; + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(neighbors); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), stream); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + + uint32_t* indices_dev_ptr = nullptr; bool manage_local_topk = is_local_topk_feasible(k); if (!manage_local_topk || grid_dim_x > 1) { - if (!manage_local_topk) { - num_samples.resize(n_queries, stream); - chunk_index.resize(n_queries_probes, stream); - - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( - index.list_sizes().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - num_samples.data(), - stream); - } - auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples); distances_tmp_dev.resize(target_size, stream); @@ -199,6 +208,8 @@ void search_impl(raft::resources const& handle, distances_dev_ptr = distances_tmp_dev.data(); indices_dev_ptr = indices_tmp_dev.data(); + } else { + indices_dev_ptr = neighbors_uint32; } ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( @@ -224,31 +235,33 @@ void search_impl(raft::resources const& handle, // Merge topk values from different blocks if (!manage_local_topk || grid_dim_x > 1) { - matrix::detail::select_k(handle, - distances_tmp_dev.data(), - indices_tmp_dev.data(), - n_queries, - manage_local_topk ? (k * grid_dim_x) : max_samples, - k, - distances, - neighbors, - select_min); - - if (!manage_local_topk) { - // post process distances && neighbor IDs - ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); - ivf::detail::postprocess_neighbors(neighbors, - neighbors, - index.inds_ptrs().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - n_queries, - n_probes, - k, - stream); - } + matrix::detail::select_k(handle, + distances_tmp_dev.data(), + indices_tmp_dev.data(), + n_queries, + manage_local_topk ? (k * grid_dim_x) : max_samples, + k, + distances, + neighbors_uint32, + select_min, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); + } + if (!manage_local_topk) { + // post process distances && neighbor IDs + ivf::detail::postprocess_distances( + distances, distances, index.metric(), n_queries, k, 1.0, false, stream); } + ivf::detail::postprocess_neighbors(neighbors, + neighbors_uint32, + index.inds_ptrs().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + n_queries, + n_probes, + k, + stream); } /** See raft::neighbors::ivf_flat::search docs */ diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d445f909e5..4c5da38092 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -447,7 +447,10 @@ void ivfpq_search_worker(raft::resources const& handle, topK, topk_dists.data(), neighbors_uint32, - true); + true, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); // Postprocessing ivf::detail::postprocess_distances( diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index e76e52657b..bdc29ca121 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -88,6 +88,27 @@ void refine_device(raft::resources const& handle, n_queries, n_candidates); uint32_t grid_dim_x = 1; + + // the neighbor ids will be computed in uint32_t as offset + rmm::device_uvector neighbors_uint32_buf(0, resource::get_cuda_stream(handle)); + // Offsets per probe for each query [n_queries] as n_probes = 1 + rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); + + // we know that each cluster has exactly n_candidates entries + thrust::fill(resource::get_thrust_policy(handle), + chunk_index.data(), + chunk_index.data() + n_queries, + uint32_t(n_candidates)); + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(indices.data_handle()); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), + resource::get_cuda_stream(handle)); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, @@ -100,13 +121,24 @@ void refine_device(raft::resources const& handle, 1, k, 0, - nullptr, + chunk_index.data(), raft::distance::is_min_close(metric), raft::neighbors::filtering::none_ivf_sample_filter(), - indices.data_handle(), + neighbors_uint32, distances.data_handle(), grid_dim_x, resource::get_cuda_stream(handle)); + + // postprocessing -- neighbors from position to actual id + ivf::detail::postprocess_neighbors(indices.data_handle(), + neighbors_uint32, + refinement_index.inds_ptrs().data_handle(), + fake_coarse_idx.data(), + chunk_index.data(), + n_queries, + 1, + k, + resource::get_cuda_stream(handle)); } } // namespace raft::neighbors::detail diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index e32b4ef6f0..bf234aacbf 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 21c954ca46..7f0511a76a 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -29,7 +29,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 7f163a0b0d..e68b1e32df 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 87b6525356..5aa40d8c9d 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index e698f811d8..9aba147edf 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index 0eee20b1fa..bc513e4aeb 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index f4e6bae21f..e46c7d46bb 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index def33e493e..5ac820e0dd 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu index e96600ee02..4d847cdeb1 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu @@ -35,7 +35,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 13c9d2e283..8a0e8f0118 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 51f02343fc..7cad992e2b 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index a111de0762..7278f71a24 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -549,6 +549,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { EXPECT_FALSE(unacceptable_node); double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -556,7 +557,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), @@ -668,6 +670,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { } double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -675,7 +678,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index a3373d79fd..3e0bead665 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -35,6 +35,7 @@ #include #include +#include namespace raft::neighbors { @@ -153,13 +154,40 @@ auto calc_recall(const std::vector& expected_idx, static_cast(match_count) / static_cast(total_count), match_count, total_count); } +/** check uniqueness of indices + */ +template +auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +{ + size_t max_count; + std::set unique_indices; + for (size_t i = 0; i < rows; ++i) { + unique_indices.clear(); + max_count = 0; + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + if (act_idx == std::numeric_limits::max()) { + max_count++; + } else if (unique_indices.find(act_idx) == unique_indices.end()) { + unique_indices.insert(act_idx); + } else { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + } + } + } + return testing::AssertionSuccess(); +} + template auto eval_recall(const std::vector& expected_idx, const std::vector& actual_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, rows, cols); @@ -176,7 +204,10 @@ auto eval_recall(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } /** Overload of calc_recall to account for distances @@ -224,7 +255,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -241,7 +273,10 @@ auto eval_neighbours(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } template