diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 715a94403f..cc787d3e57 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -26,11 +26,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -200,6 +202,67 @@ void GenerateRoundingErrorFreeDataset( GenerateRoundingErrorFreeDataset_kernel <<>>(ptr, size, resolution); } + +template +void InitDataset(const raft::resources& handle, + DataT* const datatset_ptr, + std::uint32_t size, + std::uint32_t dim, + raft::distance::DistanceType metric, + raft::random::RngState& r) +{ + if constexpr (std::is_same_v || std::is_same_v) { + GenerateRoundingErrorFreeDataset(handle, datatset_ptr, size, dim, r, true); + + if (metric == raft::distance::InnerProduct) { + auto dataset_view = raft::make_device_matrix_view(datatset_ptr, size, dim); + raft::linalg::row_normalize( + handle, raft::make_const_mdspan(dataset_view), dataset_view, raft::linalg::L2Norm); + } + } else if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v) { + raft::random::uniformInt(handle, r, datatset_ptr, size * dim, DataT(-10), DataT(10)); + } else { + raft::random::uniformInt(handle, r, datatset_ptr, size * dim, DataT(1), DataT(20)); + } + + if (metric == raft::distance::InnerProduct) { + // TODO (enp1s0): Change this once row_normalize supports (u)int8 matrices. + // https://github.com/rapidsai/raft/issues/2291 + + using ComputeT = float; + auto dataset_view = raft::make_device_matrix_view(datatset_ptr, size, dim); + auto dev_row_norm = raft::make_device_vector(handle, size); + const auto normalized_norm = + (std::is_same_v ? 40 : 20) * std::sqrt(static_cast(dim)); + + raft::linalg::reduce(dev_row_norm.data_handle(), + datatset_ptr, + dim, + size, + 0.f, + true, + true, + resource::get_cuda_stream(handle), + false, + raft::sq_op(), + raft::add_op(), + raft::sqrt_op()); + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(dataset_view), + raft::make_const_mdspan(dev_row_norm.view()), + dataset_view, + raft::linalg::Apply::ALONG_COLUMNS, + [normalized_norm] __device__(DataT elm, ComputeT norm) { + const ComputeT v = elm / norm * normalized_norm; + const ComputeT max_v_range = std::numeric_limits::max(); + const ComputeT min_v_range = std::numeric_limits::min(); + return static_cast(std::min(max_v_range, std::max(min_v_range, v))); + }); + } + } +} } // namespace struct AnnCagraInputs { @@ -360,16 +423,8 @@ class AnnCagraTest : public ::testing::TestWithParam { database.resize(((size_t)ps.n_rows) * ps.dim, stream_); search_queries.resize(ps.n_queries * ps.dim, stream_); raft::random::RngState r(1234ULL); - if constexpr (std::is_same_v || std::is_same_v) { - GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, true); - GenerateRoundingErrorFreeDataset( - handle_, search_queries.data(), ps.n_queries, ps.dim, r, true); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); - raft::random::uniformInt( - handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20)); - } + InitDataset(handle_, database.data(), ps.n_rows, ps.dim, ps.metric, r); + InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r); resource::sync_stream(handle_); } @@ -744,16 +799,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { database.resize(((size_t)ps.n_rows) * ps.dim, stream_); search_queries.resize(ps.n_queries * ps.dim, stream_); raft::random::RngState r(1234ULL); - if constexpr (std::is_same_v || std::is_same_v) { - GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, true); - GenerateRoundingErrorFreeDataset( - handle_, search_queries.data(), ps.n_queries, ps.dim, r, true); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); - raft::random::uniformInt( - handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20)); - } + InitDataset(handle_, database.data(), ps.n_rows, ps.dim, ps.metric, r); + InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r); resource::sync_stream(handle_); }