Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 22, 2024
1 parent 485c7e7 commit 00584a2
Showing 1 changed file with 51 additions and 13 deletions.
64 changes: 51 additions & 13 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

#include <cstddef>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

Expand Down Expand Up @@ -144,6 +145,7 @@ struct AnnCagraInputs {
bool include_serialized_dataset;
// std::optional<double>
double min_recall; // = std::nullopt;
std::optional<vpq_params> compression = std::nullopt;
};

inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p)
Expand All @@ -154,7 +156,13 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p)
<< ", k=" << p.k << ", " << algo.at((int)p.algo) << ", max_queries=" << p.max_queries
<< ", itopk_size=" << p.itopk_size << ", search_width=" << p.search_width
<< ", metric=" << static_cast<int>(p.metric) << (p.host_dataset ? ", host" : ", device")
<< ", build_algo=" << build_algo.at((int)p.build_algo) << '}' << std::endl;
<< ", build_algo=" << build_algo.at((int)p.build_algo);
if (p.compression.has_value()) {
auto vpq = p.compression.value();
os << ", pq_bits=" << vpq.pq_bits << ", pq_dim=" << vpq.pq_dim
<< ", vq_n_centers=" << vpq.vq_n_centers;
}
os << '}' << std::endl;
return os;
}

Expand Down Expand Up @@ -204,7 +212,8 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
cagra::index_params index_params;
index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is
// not used for knn_graph building.
index_params.build_algo = ps.build_algo;
index_params.build_algo = ps.build_algo;
index_params.compression = ps.compression;
cagra::search_params search_params;
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
Expand Down Expand Up @@ -261,17 +270,20 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
ps.k,
0.003,
min_recall));
EXPECT_TRUE(eval_distances(handle_,
database.data(),
search_queries.data(),
indices_dev.data(),
distances_dev.data(),
ps.n_rows,
ps.dim,
ps.n_queries,
ps.k,
ps.metric,
1.0e-4));
if (!ps.compression.has_value()) {
// Don't evaluate distances for CAGRA-Q for now as the error can be somewhat large
EXPECT_TRUE(eval_distances(handle_,
database.data(),
search_queries.data(),
indices_dev.data(),
distances_dev.data(),
ps.n_rows,
ps.dim,
ps.n_queries,
ps.k,
ps.metric,
1.0e-4));
}
}
}

Expand Down Expand Up @@ -394,6 +406,32 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

// a few PQ configurations
inputs2 = raft::util::itertools::product<AnnCagraInputs>(
{100},
{10000},
{64, 128, 192, 256, 512, 1024}, // dim
{16}, // k
{graph_build_algo::IVF_PQ},
{search_algo::AUTO},
{10},
{0},
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{false},
{true},
{0.6}); // don't demand high recall without refinement
for (uint32_t pq_len : {2}) { // for now, only pq_len = 2 is supported, more options coming soon
for (uint32_t vq_n_centers : {100, 1000}) {
for (auto input : inputs2) {
input.compression.emplace(
vpq_params{.pq_dim = input.dim / pq_len, .vq_n_centers = vq_n_centers});
inputs.push_back(input);
}
}
}

return inputs;
}

Expand Down

0 comments on commit 00584a2

Please sign in to comment.