Skip to content

Commit

Permalink
Add benchmark for cagra remove function
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Sep 26, 2023
1 parent 39fc683 commit 121603e
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/neighbors/cagra.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>
#include <thrust/sequence.h>

#include <optional>

Expand All @@ -40,6 +41,8 @@ struct params {
int block_size;
int search_width;
int max_iterations;
/** Ratio of removed indices. */
double removed_ratio;
};

template <typename T, typename IdxT>
Expand All @@ -49,7 +52,8 @@ struct CagraBench : public fixture {
params_(ps),
queries_(make_device_matrix<T, int64_t>(handle, ps.n_queries, ps.n_dims)),
dataset_(make_device_matrix<T, int64_t>(handle, ps.n_samples, ps.n_dims)),
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree))
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree)),
removed_indices_(make_device_vector<IdxT, int64_t>(handle, ps.removed_ratio * ps.n_samples))
{
// Generate random dataset and queriees
raft::random::RngState state{42};
Expand All @@ -74,8 +78,14 @@ struct CagraBench : public fixture {

auto metric = raft::distance::DistanceType::L2Expanded;

index_.emplace(raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())));
auto index = raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()));
thrust::sequence(
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices_.data_handle()),
thrust::device_pointer_cast(removed_indices_.data_handle() + removed_indices_.extent(0)));
raft::neighbors::cagra::remove(handle, index, raft::make_const_mdspan(removed_indices_.view()));
index_.emplace(std::move(index));
}

void run_benchmark(::benchmark::State& state) override
Expand Down Expand Up @@ -120,6 +130,7 @@ struct CagraBench : public fixture {
state.counters["block_size"] = params_.block_size;
state.counters["search_width"] = params_.search_width;
state.counters["iterations"] = iterations;
state.counters["removed_ratio"] = params_.removed_ratio;
}

private:
Expand All @@ -128,6 +139,7 @@ struct CagraBench : public fixture {
raft::device_matrix<T, int64_t, row_major> queries_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix<IdxT, int64_t, row_major> knn_graph_;
raft::device_vector<IdxT, int64_t> removed_indices_;
};

inline const std::vector<params> generate_inputs()
Expand All @@ -141,7 +153,8 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{0}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
auto inputs2 = raft::util::itertools::product<params>({2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
Expand All @@ -151,7 +164,22 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{64, 128, 256, 512, 1024}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

inputs2 = raft::util::itertools::product<params>(
{2000000ull}, // n_samples
{128}, // dataset dim
{1000}, // n_queries
{32}, // k
{64}, // knn graph degree
{64}, // itopk_size
{128, 256}, // block_size
{2}, // search_width
{0}, // max_iterations
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
return inputs;
Expand Down

0 comments on commit 121603e

Please sign in to comment.