From d5ca51e5c799f0a0ad3220539c7c7763c70bd826 Mon Sep 17 00:00:00 2001 From: Tarang Jain <40517122+tarang-jain@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:28:49 -0700 Subject: [PATCH] CosineExpanded Metric for IVF-PQ (normalize inputs) (#346) Authors: - Tarang Jain (https://github.com/tarang-jain) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Micka (https://github.com/lowener) URL: https://github.com/rapidsai/cuvs/pull/346 --- .../ivf_flat/ivf_flat_interleaved_scan.cuh | 4 +- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 34 +++++++++++++++ .../ivf_pq/ivf_pq_compute_similarity_impl.cuh | 3 ++ cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh | 42 +++++++++++++++++-- cpp/test/neighbors/ann_ivf_pq.cuh | 24 ++++++++++- .../ann_ivf_pq/test_float_int64_t.cu | 8 +++- .../ann_ivf_pq/test_int8_t_int64_t.cu | 5 ++- .../ann_ivf_pq/test_uint8_t_int64_t.cu | 8 +++- 8 files changed, 115 insertions(+), 13 deletions(-) diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index 9626b2ce5..f5a4267cd 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -1206,8 +1206,8 @@ void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... arg inner_prod_dist>( {}, raft::compose_op(raft::add_const_op{1.0f}, raft::mul_const_op{-1.0f}), - std::forward(args)...); - // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. + std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when + // adding here a new metric. default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index c65ea8108..4c9867126 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -41,6 +42,8 @@ #include #include #include +#include +#include #include #include #include @@ -1466,6 +1469,13 @@ void extend(raft::resources const& handle, std::is_same_v, "Unsupported data type"); + if (index->metric() == distance::DistanceType::CosineExpanded) { + if constexpr (std::is_same_v || std::is_same_v) + RAFT_FAIL( + "CosineExpanded distance metric is currently not supported for uint8_t and int8_t data " + "type"); + } + rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle); rmm::device_async_resource_ref large_memory = raft::resource::get_large_workspace_resource(handle); @@ -1632,6 +1642,14 @@ void extend(raft::resources const& handle, vec_batches.prefetch_next_batch(); for (const auto& vec_batch : vec_batches) { const auto& idx_batch = *idx_batches++; + if (index->metric() == CosineExpanded) { + auto vec_batch_view = raft::make_device_matrix_view( + const_cast(vec_batch.data()), vec_batch.size(), index->dim()); + raft::linalg::row_normalize(handle, + raft::make_const_mdspan(vec_batch_view), + vec_batch_view, + raft::linalg::NormType::L2Norm); + } process_and_fill_codes(handle, *index, vec_batch.data(), @@ -1683,6 +1701,13 @@ auto build(raft::resources const& handle, << (int)params.pq_dim << std::endl; RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); + if (params.metric == distance::DistanceType::CosineExpanded) { + // TODO: support int8_t and uint8_t types (https://github.com/rapidsai/cuvs/issues/389) + if constexpr (std::is_same_v || std::is_same_v) + RAFT_FAIL( + "CosineExpanded distance metric is currently not supported for uint8_t and int8_t data " + "type"); + } auto stream = raft::resource::get_cuda_stream(handle); @@ -1755,6 +1780,11 @@ auto build(raft::resources const& handle, cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = static_cast((int)index.metric()); + + if (index.metric() == distance::DistanceType::CosineExpanded) { + raft::linalg::row_normalize( + handle, trainset_const_view, trainset.view(), raft::linalg::NormType::L2Norm); + } cuvs::cluster::kmeans_balanced::fit( handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); @@ -1762,6 +1792,10 @@ auto build(raft::resources const& handle, rmm::device_uvector labels(n_rows_train, stream, big_memory_resource); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); + if (index.metric() == distance::DistanceType::CosineExpanded) { + raft::linalg::row_normalize( + handle, centers_const_view, centers_view, raft::linalg::NormType::L2Norm); + } auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); cuvs::cluster::kmeans_balanced::predict(handle, diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh index 8404ca1f9..fbbdd06c2 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh @@ -369,6 +369,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim, reinterpret_cast(lut_end)[i] = query[i] - cluster_center[i]; } } break; + case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { float2 pvals; for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { @@ -408,6 +409,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim, diff -= pq_c; score += diff * diff; } break; + case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { // NB: we negate the scores as we hardcoded select-topk to always compute the minimum float q; @@ -485,6 +487,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim, reinterpret_cast(pq_thread_data), lut_scores, early_stop_limit); + if (metric == distance::DistanceType::CosineExpanded) { score = OutT(1) + score; } } if constexpr (kManageLocalTopK) { block_topk.add(score, sample_offset + i); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index e185f18dc..db8f9fbd3 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -37,6 +37,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -104,12 +107,21 @@ void select_clusters(raft::resources const& handle, This is a negative inner-product distance. We minimize it to find the similar clusters. + NB: qc_distances is NOT used further in ivfpq_search. + + Cosine distance: + `qc_distances[i, j] = - (queries[i], cluster_centers[j])` + + This is a negative inner-product distance. The queries and cluster centers are row normalized. + We minimize it to find the similar clusters. + NB: qc_distances is NOT used further in ivfpq_search. */ float norm_factor; switch (metric) { case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break; + case cuvs::distance::DistanceType::CosineExpanded: case cuvs::distance::DistanceType::InnerProduct: norm_factor = 0.0; break; default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); } @@ -133,6 +145,7 @@ void select_clusters(raft::resources const& handle, gemm_k = dim + 1; RAFT_EXPECTS(gemm_k <= dim_ext, "unexpected gemm_k or dim_ext"); } break; + case cuvs::distance::DistanceType::CosineExpanded: case cuvs::distance::DistanceType::InnerProduct: { alpha = -1.0; beta = 0.0; @@ -363,8 +376,9 @@ void ivfpq_search_worker(raft::resources const& handle, // stores basediff (query[i] - center[i]) precomp_data_count = index.rot_dim(); } break; + case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { - // stores two components (query[i] * center[i], query[i] * center[i]) + // stores two components (query[i], query[i] * center[i]) precomp_data_count = index.rot_dim() * 2; } break; default: { @@ -457,8 +471,14 @@ void ivfpq_search_worker(raft::resources const& handle, num_samples_vector); // Postprocessing - ivf::detail::postprocess_distances( - distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, true, stream); + ivf::detail::postprocess_distances(distances, + topk_dists.data(), + index.metric(), + n_queries, + topK, + scaling_factor, + index.metric() != distance::DistanceType::CosineExpanded, + stream); ivf::detail::postprocess_neighbors(neighbors, neighbors_uint32, index.inds_ptrs().data_handle(), @@ -508,6 +528,7 @@ struct ivfpq_search { { bool signed_metric = false; switch (metric) { + case cuvs::distance::DistanceType::CosineExpanded: signed_metric = true; break; case cuvs::distance::DistanceType::InnerProduct: signed_metric = true; break; default: break; } @@ -606,6 +627,12 @@ inline void search(raft::resources const& handle, static_assert(std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported element type."); + if (index.metric() == distance::DistanceType::CosineExpanded) { + if constexpr (std::is_same_v || std::is_same_v) + RAFT_FAIL( + "CosineExpanded distance metric is currently not supported for uint8_t and int8_t data " + "type"); + } raft::common::nvtx::range fun_scope( "ivf_pq::search(n_queries = %u, n_probes = %u, k = %u, dim = %zu)", n_queries, @@ -698,7 +725,14 @@ inline void search(raft::resources const& handle, rot_queries.data(), index.rot_dim(), stream); - + if (index.metric() == distance::DistanceType::CosineExpanded) { + auto rot_queries_view = raft::make_device_matrix_view( + rot_queries.data(), max_queries, index.rot_dim()); + raft::linalg::row_normalize(handle, + raft::make_const_mdspan(rot_queries_view), + rot_queries_view, + raft::linalg::NormType::L2Norm); + } for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) { uint32_t batch_size = min(max_batch_size, queries_batch - offset_b); /* The distance calculation is done in the rotated/transformed space; diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index f02568b74..fd4e330db 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -282,6 +282,8 @@ class ivf_pq_test : public ::testing::TestWithParam { uint32_t n_take, uint32_t n_skip) { + // the original data cannot be reconstructed since the dataset was normalized + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { return; } auto& rec_list = index.lists()[label]; auto dim = index.dim(); n_take = std::min(n_take, rec_list->size.load()); @@ -313,6 +315,7 @@ class ivf_pq_test : public ::testing::TestWithParam { auto old_list = index->lists()[label]; auto n_rows = old_list->size.load(); if (n_rows == 0) { return; } + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { return; } auto vectors_1 = raft::make_device_matrix(handle_, n_rows, index->dim()); auto indices = raft::make_device_vector(handle_, n_rows); @@ -374,7 +377,7 @@ class ivf_pq_test : public ::testing::TestWithParam { cuvs::Compare{})); // Pack a few vectors back to the list. - int row_offset = 9; + int row_offset = 5; int n_vec = 3; ASSERT_TRUE(row_offset + n_vec < n_rows); size_t offset = row_offset * index->pq_dim(); @@ -884,6 +887,25 @@ inline auto enum_variety_l2sqrt() -> test_cases_t }); } +inline auto enum_variety_cosine() -> test_cases_t +{ + return map(enum_variety(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + if (y.min_recall.has_value()) { + if (y.search_params.lut_dtype == CUDA_R_8U) { + // TODO: Increase this recall threshold for 8 bit lut + // (https://github.com/rapidsai/cuvs/issues/390) + y.min_recall = y.min_recall.value() * 0.70; + } else { + // In other cases it seems to perform a little bit better, still worse than L2 + y.min_recall = y.min_recall.value() * 0.94; + } + } + y.index_params.metric = distance::DistanceType::CosineExpanded; + return y; + }); +} + /** * Try different number of n_probes, some of which may trigger the non-fused version of the search * kernel. diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu index cdc6c1b7e..834fdb3d0 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu @@ -25,9 +25,13 @@ TEST_BUILD_HOST_INPUT_SEARCH(f32_f32_i64) TEST_BUILD_HOST_INPUT_OVERLAP_SEARCH(f32_f32_i64) TEST_BUILD_EXTEND_SEARCH(f32_f32_i64) TEST_BUILD_SERIALIZE_SEARCH(f32_f32_i64) -INSTANTIATE(f32_f32_i64, defaults() + small_dims() + big_dims_moderate_lut()); +INSTANTIATE(f32_f32_i64, + defaults() + small_dims() + big_dims_moderate_lut() + enum_variety_l2() + + enum_variety_l2sqrt() + enum_variety_ip() + enum_variety_cosine()); TEST_BUILD_SEARCH(f32_f32_i64_filter) -INSTANTIATE(f32_f32_i64_filter, defaults() + small_dims() + big_dims_moderate_lut()); +INSTANTIATE(f32_f32_i64_filter, + defaults() + small_dims() + big_dims_moderate_lut() + enum_variety_l2() + + enum_variety_l2sqrt() + enum_variety_ip() + enum_variety_cosine()); } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu index 80b0e2ccb..c9e5d4f01 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu @@ -25,8 +25,9 @@ TEST_BUILD_SEARCH(f32_i08_i64) TEST_BUILD_HOST_INPUT_SEARCH(f32_i08_i64) TEST_BUILD_HOST_INPUT_OVERLAP_SEARCH(f32_i08_i64) TEST_BUILD_SERIALIZE_SEARCH(f32_i08_i64) -INSTANTIATE(f32_i08_i64, defaults() + big_dims() + var_k()); +INSTANTIATE(f32_i08_i64, defaults() + big_dims() + var_k() + enum_variety_l2() + enum_variety_ip()); TEST_BUILD_SEARCH(f32_i08_i64_filter) -INSTANTIATE(f32_i08_i64_filter, defaults() + big_dims() + var_k()); +INSTANTIATE(f32_i08_i64_filter, + defaults() + big_dims() + var_k() + enum_variety_l2() + enum_variety_ip()); } // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu index 0216a1e80..6e0732227 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu @@ -25,8 +25,12 @@ TEST_BUILD_SEARCH(f32_u08_i64) TEST_BUILD_HOST_INPUT_SEARCH(f32_u08_i64) TEST_BUILD_HOST_INPUT_OVERLAP_SEARCH(f32_u08_i64) TEST_BUILD_EXTEND_SEARCH(f32_u08_i64) -INSTANTIATE(f32_u08_i64, small_dims_per_cluster() + enum_variety()); +INSTANTIATE(f32_u08_i64, + small_dims_per_cluster() + enum_variety() + enum_variety_l2() + enum_variety_l2sqrt() + + enum_variety_ip()); TEST_BUILD_SEARCH(f32_u08_i64_filter) -INSTANTIATE(f32_u08_i64_filter, small_dims_per_cluster() + enum_variety()); +INSTANTIATE(f32_u08_i64_filter, + small_dims_per_cluster() + enum_variety() + enum_variety_l2() + enum_variety_l2sqrt() + + enum_variety_ip()); } // namespace cuvs::neighbors::ivf_pq