Skip to content

Commit

Permalink
Normalize dataset vectors in the CAGRA InnerProduct tests (#2287)
Browse files Browse the repository at this point in the history
This PR updates the CAGRA test to normalize the dataset and query vectors in the CAGRA test when the metric is InnerProduct. If we don't normalize them, large L2 norm dataset vectors tend to be included in the search result across all queries. This means that only a part of the graph nodes may be traversed in the search process, leading to test incompleteness.

Authors:
  - tsuki (https://github.com/enp1s0)

Approvers:
  - Tarang Jain (https://github.com/tarang-jain)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #2287
  • Loading branch information
enp1s0 authored May 7, 2024
1 parent ef28628 commit 97e38eb
Showing 1 changed file with 67 additions and 20 deletions.
87 changes: 67 additions & 20 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/normalize.cuh>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/itertools.hpp>

#include <raft_internal/neighbors/naive_knn.cuh>
Expand Down Expand Up @@ -200,6 +202,67 @@ void GenerateRoundingErrorFreeDataset(
GenerateRoundingErrorFreeDataset_kernel<T>
<<<grid_size, block_size, 0, cuda_stream>>>(ptr, size, resolution);
}

template <class DataT>
void InitDataset(const raft::resources& handle,
DataT* const datatset_ptr,
std::uint32_t size,
std::uint32_t dim,
raft::distance::DistanceType metric,
raft::random::RngState& r)
{
if constexpr (std::is_same_v<DataT, float> || std::is_same_v<DataT, half>) {
GenerateRoundingErrorFreeDataset(handle, datatset_ptr, size, dim, r, true);

if (metric == raft::distance::InnerProduct) {
auto dataset_view = raft::make_device_matrix_view(datatset_ptr, size, dim);
raft::linalg::row_normalize(
handle, raft::make_const_mdspan(dataset_view), dataset_view, raft::linalg::L2Norm);
}
} else if constexpr (std::is_same_v<DataT, std::uint8_t> || std::is_same_v<DataT, std::int8_t>) {
if constexpr (std::is_same_v<DataT, std::int8_t>) {
raft::random::uniformInt(handle, r, datatset_ptr, size * dim, DataT(-10), DataT(10));
} else {
raft::random::uniformInt(handle, r, datatset_ptr, size * dim, DataT(1), DataT(20));
}

if (metric == raft::distance::InnerProduct) {
// TODO (enp1s0): Change this once row_normalize supports (u)int8 matrices.
// https://github.com/rapidsai/raft/issues/2291

using ComputeT = float;
auto dataset_view = raft::make_device_matrix_view(datatset_ptr, size, dim);
auto dev_row_norm = raft::make_device_vector<ComputeT>(handle, size);
const auto normalized_norm =
(std::is_same_v<DataT, std::uint8_t> ? 40 : 20) * std::sqrt(static_cast<ComputeT>(dim));

raft::linalg::reduce(dev_row_norm.data_handle(),
datatset_ptr,
dim,
size,
0.f,
true,
true,
resource::get_cuda_stream(handle),
false,
raft::sq_op(),
raft::add_op(),
raft::sqrt_op());
raft::linalg::matrix_vector_op(
handle,
raft::make_const_mdspan(dataset_view),
raft::make_const_mdspan(dev_row_norm.view()),
dataset_view,
raft::linalg::Apply::ALONG_COLUMNS,
[normalized_norm] __device__(DataT elm, ComputeT norm) {
const ComputeT v = elm / norm * normalized_norm;
const ComputeT max_v_range = std::numeric_limits<DataT>::max();
const ComputeT min_v_range = std::numeric_limits<DataT>::min();
return static_cast<DataT>(std::min(max_v_range, std::max(min_v_range, v)));
});
}
}
}
} // namespace

struct AnnCagraInputs {
Expand Down Expand Up @@ -360,16 +423,8 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
search_queries.resize(ps.n_queries * ps.dim, stream_);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same_v<DataT, float> || std::is_same_v<DataT, half>) {
GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, true);
GenerateRoundingErrorFreeDataset(
handle_, search_queries.data(), ps.n_queries, ps.dim, r, true);
} else {
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20));
}
InitDataset(handle_, database.data(), ps.n_rows, ps.dim, ps.metric, r);
InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r);
resource::sync_stream(handle_);
}

Expand Down Expand Up @@ -744,16 +799,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
search_queries.resize(ps.n_queries * ps.dim, stream_);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same_v<DataT, float> || std::is_same_v<DataT, half>) {
GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, true);
GenerateRoundingErrorFreeDataset(
handle_, search_queries.data(), ps.n_queries, ps.dim, r, true);
} else {
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20));
}
InitDataset(handle_, database.data(), ps.n_rows, ps.dim, ps.metric, r);
InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r);
resource::sync_stream(handle_);
}

Expand Down

0 comments on commit 97e38eb

Please sign in to comment.