Skip to content

Commit

Permalink
CosineExpanded Metric for IVF-PQ (normalize inputs) (#346)
Browse files Browse the repository at this point in the history
Authors:
  - Tarang Jain (https://github.com/tarang-jain)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Micka (https://github.com/lowener)

URL: #346
  • Loading branch information
tarang-jain authored Oct 3, 2024
1 parent cc86ffc commit d5ca51e
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 13 deletions.
4 changes: 2 additions & 2 deletions cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1206,8 +1206,8 @@ void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... arg
inner_prod_dist<Veclen, T, AccT>>(
{},
raft::compose_op(raft::add_const_op<float>{1.0f}, raft::mul_const_op<float>{-1.0f}),
std::forward<Args>(args)...);
// NB: update the description of `knn::ivf_flat::build` when adding here a new metric.
std::forward<Args>(args)...); // NB: update the description of `knn::ivf_flat::build` when
// adding here a new metric.
default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric));
}
}
Expand Down
34 changes: 34 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <raft/core/device_mdarray.hpp>
#include <raft/core/logger-ext.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cuda_stream_pool.hpp>
Expand All @@ -41,6 +42,8 @@
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/norm_types.hpp>
#include <raft/linalg/normalize.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/matrix/linewise_op.cuh>
Expand Down Expand Up @@ -1466,6 +1469,13 @@ void extend(raft::resources const& handle,
std::is_same_v<T, int8_t>,
"Unsupported data type");

if (index->metric() == distance::DistanceType::CosineExpanded) {
if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>)
RAFT_FAIL(
"CosineExpanded distance metric is currently not supported for uint8_t and int8_t data "
"type");
}

rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle);
rmm::device_async_resource_ref large_memory =
raft::resource::get_large_workspace_resource(handle);
Expand Down Expand Up @@ -1632,6 +1642,14 @@ void extend(raft::resources const& handle,
vec_batches.prefetch_next_batch();
for (const auto& vec_batch : vec_batches) {
const auto& idx_batch = *idx_batches++;
if (index->metric() == CosineExpanded) {
auto vec_batch_view = raft::make_device_matrix_view<T, internal_extents_t>(
const_cast<T*>(vec_batch.data()), vec_batch.size(), index->dim());
raft::linalg::row_normalize(handle,
raft::make_const_mdspan(vec_batch_view),
vec_batch_view,
raft::linalg::NormType::L2Norm);
}
process_and_fill_codes(handle,
*index,
vec_batch.data(),
Expand Down Expand Up @@ -1683,6 +1701,13 @@ auto build(raft::resources const& handle,
<< (int)params.pq_dim << std::endl;
RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");
RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists");
if (params.metric == distance::DistanceType::CosineExpanded) {
// TODO: support int8_t and uint8_t types (https://github.com/rapidsai/cuvs/issues/389)
if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>)
RAFT_FAIL(
"CosineExpanded distance metric is currently not supported for uint8_t and int8_t data "
"type");
}

auto stream = raft::resource::get_cuda_stream(handle);

Expand Down Expand Up @@ -1755,13 +1780,22 @@ auto build(raft::resources const& handle,
cuvs::cluster::kmeans::balanced_params kmeans_params;
kmeans_params.n_iters = params.kmeans_n_iters;
kmeans_params.metric = static_cast<cuvs::distance::DistanceType>((int)index.metric());

if (index.metric() == distance::DistanceType::CosineExpanded) {
raft::linalg::row_normalize(
handle, trainset_const_view, trainset.view(), raft::linalg::NormType::L2Norm);
}
cuvs::cluster::kmeans_balanced::fit(
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});

// Trainset labels are needed for training PQ codebooks
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, big_memory_resource);
auto centers_const_view = raft::make_device_matrix_view<const float, internal_extents_t>(
cluster_centers, index.n_lists(), index.dim());
if (index.metric() == distance::DistanceType::CosineExpanded) {
raft::linalg::row_normalize(
handle, centers_const_view, centers_view, raft::linalg::NormType::L2Norm);
}
auto labels_view =
raft::make_device_vector_view<uint32_t, internal_extents_t>(labels.data(), n_rows_train);
cuvs::cluster::kmeans_balanced::predict(handle,
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim,
reinterpret_cast<float*>(lut_end)[i] = query[i] - cluster_center[i];
}
} break;
case distance::DistanceType::CosineExpanded:
case distance::DistanceType::InnerProduct: {
float2 pvals;
for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
Expand Down Expand Up @@ -408,6 +409,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim,
diff -= pq_c;
score += diff * diff;
} break;
case distance::DistanceType::CosineExpanded:
case distance::DistanceType::InnerProduct: {
// NB: we negate the scores as we hardcoded select-topk to always compute the minimum
float q;
Expand Down Expand Up @@ -485,6 +487,7 @@ RAFT_KERNEL compute_similarity_kernel(uint32_t dim,
reinterpret_cast<const vec_t::io_t*>(pq_thread_data),
lut_scores,
early_stop_limit);
if (metric == distance::DistanceType::CosineExpanded) { score = OutT(1) + score; }
}
if constexpr (kManageLocalTopK) {
block_topk.add(score, sample_offset + i);
Expand Down
42 changes: 38 additions & 4 deletions cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
#include <raft/core/resources.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm_types.hpp>
#include <raft/linalg/normalize.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/util/cache.hpp>
Expand Down Expand Up @@ -104,12 +107,21 @@ void select_clusters(raft::resources const& handle,
This is a negative inner-product distance. We minimize it to find the similar clusters.
NB: qc_distances is NOT used further in ivfpq_search.
Cosine distance:
`qc_distances[i, j] = - (queries[i], cluster_centers[j])`
This is a negative inner-product distance. The queries and cluster centers are row normalized.
We minimize it to find the similar clusters.
NB: qc_distances is NOT used further in ivfpq_search.
*/
float norm_factor;
switch (metric) {
case cuvs::distance::DistanceType::L2SqrtExpanded:
case cuvs::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break;
case cuvs::distance::DistanceType::CosineExpanded:
case cuvs::distance::DistanceType::InnerProduct: norm_factor = 0.0; break;
default: RAFT_FAIL("Unsupported distance type %d.", int(metric));
}
Expand All @@ -133,6 +145,7 @@ void select_clusters(raft::resources const& handle,
gemm_k = dim + 1;
RAFT_EXPECTS(gemm_k <= dim_ext, "unexpected gemm_k or dim_ext");
} break;
case cuvs::distance::DistanceType::CosineExpanded:
case cuvs::distance::DistanceType::InnerProduct: {
alpha = -1.0;
beta = 0.0;
Expand Down Expand Up @@ -363,8 +376,9 @@ void ivfpq_search_worker(raft::resources const& handle,
// stores basediff (query[i] - center[i])
precomp_data_count = index.rot_dim();
} break;
case distance::DistanceType::CosineExpanded:
case distance::DistanceType::InnerProduct: {
// stores two components (query[i] * center[i], query[i] * center[i])
// stores two components (query[i], query[i] * center[i])
precomp_data_count = index.rot_dim() * 2;
} break;
default: {
Expand Down Expand Up @@ -457,8 +471,14 @@ void ivfpq_search_worker(raft::resources const& handle,
num_samples_vector);

// Postprocessing
ivf::detail::postprocess_distances(
distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, true, stream);
ivf::detail::postprocess_distances(distances,
topk_dists.data(),
index.metric(),
n_queries,
topK,
scaling_factor,
index.metric() != distance::DistanceType::CosineExpanded,
stream);
ivf::detail::postprocess_neighbors(neighbors,
neighbors_uint32,
index.inds_ptrs().data_handle(),
Expand Down Expand Up @@ -508,6 +528,7 @@ struct ivfpq_search {
{
bool signed_metric = false;
switch (metric) {
case cuvs::distance::DistanceType::CosineExpanded: signed_metric = true; break;
case cuvs::distance::DistanceType::InnerProduct: signed_metric = true; break;
default: break;
}
Expand Down Expand Up @@ -606,6 +627,12 @@ inline void search(raft::resources const& handle,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, uint8_t> ||
std::is_same_v<T, int8_t>,
"Unsupported element type.");
if (index.metric() == distance::DistanceType::CosineExpanded) {
if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>)
RAFT_FAIL(
"CosineExpanded distance metric is currently not supported for uint8_t and int8_t data "
"type");
}
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope(
"ivf_pq::search(n_queries = %u, n_probes = %u, k = %u, dim = %zu)",
n_queries,
Expand Down Expand Up @@ -698,7 +725,14 @@ inline void search(raft::resources const& handle,
rot_queries.data(),
index.rot_dim(),
stream);

if (index.metric() == distance::DistanceType::CosineExpanded) {
auto rot_queries_view = raft::make_device_matrix_view<float, uint32_t>(
rot_queries.data(), max_queries, index.rot_dim());
raft::linalg::row_normalize(handle,
raft::make_const_mdspan(rot_queries_view),
rot_queries_view,
raft::linalg::NormType::L2Norm);
}
for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) {
uint32_t batch_size = min(max_batch_size, queries_batch - offset_b);
/* The distance calculation is done in the rotated/transformed space;
Expand Down
24 changes: 23 additions & 1 deletion cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
uint32_t n_take,
uint32_t n_skip)
{
// the original data cannot be reconstructed since the dataset was normalized
if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { return; }
auto& rec_list = index.lists()[label];
auto dim = index.dim();
n_take = std::min<uint32_t>(n_take, rec_list->size.load());
Expand Down Expand Up @@ -313,6 +315,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
auto old_list = index->lists()[label];
auto n_rows = old_list->size.load();
if (n_rows == 0) { return; }
if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { return; }

auto vectors_1 = raft::make_device_matrix<EvalT>(handle_, n_rows, index->dim());
auto indices = raft::make_device_vector<IdxT>(handle_, n_rows);
Expand Down Expand Up @@ -374,7 +377,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
cuvs::Compare<uint8_t>{}));

// Pack a few vectors back to the list.
int row_offset = 9;
int row_offset = 5;
int n_vec = 3;
ASSERT_TRUE(row_offset + n_vec < n_rows);
size_t offset = row_offset * index->pq_dim();
Expand Down Expand Up @@ -884,6 +887,25 @@ inline auto enum_variety_l2sqrt() -> test_cases_t
});
}

inline auto enum_variety_cosine() -> test_cases_t
{
return map<ivf_pq_inputs>(enum_variety(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
if (y.min_recall.has_value()) {
if (y.search_params.lut_dtype == CUDA_R_8U) {
// TODO: Increase this recall threshold for 8 bit lut
// (https://github.com/rapidsai/cuvs/issues/390)
y.min_recall = y.min_recall.value() * 0.70;
} else {
// In other cases it seems to perform a little bit better, still worse than L2
y.min_recall = y.min_recall.value() * 0.94;
}
}
y.index_params.metric = distance::DistanceType::CosineExpanded;
return y;
});
}

/**
* Try different number of n_probes, some of which may trigger the non-fused version of the search
* kernel.
Expand Down
8 changes: 6 additions & 2 deletions cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ TEST_BUILD_HOST_INPUT_SEARCH(f32_f32_i64)
TEST_BUILD_HOST_INPUT_OVERLAP_SEARCH(f32_f32_i64)
TEST_BUILD_EXTEND_SEARCH(f32_f32_i64)
TEST_BUILD_SERIALIZE_SEARCH(f32_f32_i64)
INSTANTIATE(f32_f32_i64, defaults() + small_dims() + big_dims_moderate_lut());
INSTANTIATE(f32_f32_i64,
defaults() + small_dims() + big_dims_moderate_lut() + enum_variety_l2() +
enum_variety_l2sqrt() + enum_variety_ip() + enum_variety_cosine());

TEST_BUILD_SEARCH(f32_f32_i64_filter)
INSTANTIATE(f32_f32_i64_filter, defaults() + small_dims() + big_dims_moderate_lut());
INSTANTIATE(f32_f32_i64_filter,
defaults() + small_dims() + big_dims_moderate_lut() + enum_variety_l2() +
enum_variety_l2sqrt() + enum_variety_ip() + enum_variety_cosine());

} // namespace cuvs::neighbors::ivf_pq
5 changes: 3 additions & 2 deletions cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ TEST_BUILD_SEARCH(f32_i08_i64)
TEST_BUILD_HOST_INPUT_SEARCH(f32_i08_i64)
TEST_BUILD_HOST_INPUT_OVERLAP_SEARCH(f32_i08_i64)
TEST_BUILD_SERIALIZE_SEARCH(f32_i08_i64)
INSTANTIATE(f32_i08_i64, defaults() + big_dims() + var_k());
INSTANTIATE(f32_i08_i64, defaults() + big_dims() + var_k() + enum_variety_l2() + enum_variety_ip());

TEST_BUILD_SEARCH(f32_i08_i64_filter)
INSTANTIATE(f32_i08_i64_filter, defaults() + big_dims() + var_k());
INSTANTIATE(f32_i08_i64_filter,
defaults() + big_dims() + var_k() + enum_variety_l2() + enum_variety_ip());
} // namespace cuvs::neighbors::ivf_pq
8 changes: 6 additions & 2 deletions cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ TEST_BUILD_SEARCH(f32_u08_i64)
TEST_BUILD_HOST_INPUT_SEARCH(f32_u08_i64)
TEST_BUILD_HOST_INPUT_OVERLAP_SEARCH(f32_u08_i64)
TEST_BUILD_EXTEND_SEARCH(f32_u08_i64)
INSTANTIATE(f32_u08_i64, small_dims_per_cluster() + enum_variety());
INSTANTIATE(f32_u08_i64,
small_dims_per_cluster() + enum_variety() + enum_variety_l2() + enum_variety_l2sqrt() +
enum_variety_ip());

TEST_BUILD_SEARCH(f32_u08_i64_filter)
INSTANTIATE(f32_u08_i64_filter, small_dims_per_cluster() + enum_variety());
INSTANTIATE(f32_u08_i64_filter,
small_dims_per_cluster() + enum_variety() + enum_variety_l2() + enum_variety_l2sqrt() +
enum_variety_ip());
} // namespace cuvs::neighbors::ivf_pq

0 comments on commit d5ca51e

Please sign in to comment.