diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index bb405088bb..c534d12c63 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include @@ -40,6 +41,8 @@ struct params { int block_size; int search_width; int max_iterations; + /** Ratio of removed indices. */ + double removed_ratio; }; template @@ -49,7 +52,8 @@ struct CagraBench : public fixture { params_(ps), queries_(make_device_matrix(handle, ps.n_queries, ps.n_dims)), dataset_(make_device_matrix(handle, ps.n_samples, ps.n_dims)), - knn_graph_(make_device_matrix(handle, ps.n_samples, ps.degree)) + knn_graph_(make_device_matrix(handle, ps.n_samples, ps.degree)), + removed_indices_(make_device_vector(handle, ps.removed_ratio * ps.n_samples)) { // Generate random dataset and queriees raft::random::RngState state{42}; @@ -74,8 +78,14 @@ struct CagraBench : public fixture { auto metric = raft::distance::DistanceType::L2Expanded; - index_.emplace(raft::neighbors::cagra::index( - handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); + auto index = raft::neighbors::cagra::index( + handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())); + thrust::sequence( + resource::get_thrust_policy(handle), + thrust::device_pointer_cast(removed_indices_.data_handle()), + thrust::device_pointer_cast(removed_indices_.data_handle() + removed_indices_.extent(0))); + raft::neighbors::cagra::remove(handle, index, raft::make_const_mdspan(removed_indices_.view())); + index_.emplace(std::move(index)); } void run_benchmark(::benchmark::State& state) override @@ -120,6 +130,7 @@ struct CagraBench : public fixture { state.counters["block_size"] = params_.block_size; state.counters["search_width"] = params_.search_width; state.counters["iterations"] = iterations; + state.counters["removed_ratio"] = params_.removed_ratio; } private: @@ -128,6 +139,7 @@ struct CagraBench : public fixture { raft::device_matrix queries_; raft::device_matrix dataset_; raft::device_matrix knn_graph_; + raft::device_vector removed_indices_; }; inline const std::vector generate_inputs() @@ -141,7 +153,8 @@ inline const std::vector generate_inputs() {64}, // itopk_size {0}, // block_size {1}, // search_width - {0} // max_iterations + {0}, // max_iterations + {0.0} // removed_ratio ); auto inputs2 = raft::util::itertools::product({2000000ull, 10000000ull}, // n_samples {128}, // dataset dim @@ -151,7 +164,22 @@ inline const std::vector generate_inputs() {64}, // itopk_size {64, 128, 256, 512, 1024}, // block_size {1}, // search_width - {0} // max_iterations + {0}, // max_iterations + {0.0} // removed_ratio + ); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + inputs2 = raft::util::itertools::product( + {2000000ull}, // n_samples + {128}, // dataset dim + {1000}, // n_queries + {32}, // k + {64}, // knn graph degree + {64}, // itopk_size + {128, 256}, // block_size + {2}, // search_width + {0}, // max_iterations + {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio ); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); return inputs;