From 0e498a85c32ec275875a973b2b6d579c7bc302b8 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 23 Apr 2024 16:41:26 -0700 Subject: [PATCH] change helper to constructor --- cpp/include/raft/neighbors/cagra.cuh | 3 +-- cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh | 5 +---- cpp/include/raft/neighbors/ivf_pq_types.hpp | 8 +++----- cpp/test/neighbors/ann_cagra.cuh | 3 +-- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 8fbd8e6ee6..7c1a1f6582 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -55,8 +55,7 @@ namespace raft::neighbors::cagra { * @code{.cpp} * using namespace raft::neighbors; * // use default index parameters - * ivf_pq::index_params build_params; - * build_params.initialize_from_dataset(dataset); + * ivf_pq::index_params build_params(dataset); * ivf_pq::search_params search_params; * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index fad720a03b..fcecd4fa3c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -59,10 +59,7 @@ void build_knn_graph(raft::resources const& res, size_t(dataset.extent(1)), node_degree); - if (!build_params) { - build_params = ivf_pq::index_params{}; - build_params.value().initialize_from_dataset(dataset); - } + if (!build_params) { build_params = ivf_pq::index_params(dataset); } // Make model name const std::string model_name = [&]() { diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 5a3e6caa24..4effc373f2 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -106,12 +106,11 @@ struct index_params : ann::index_params { bool conservative_memory_allocation = false; /** - * Helper that sets values according to the extents of the dataset mdspan. + * Constructor that sets values according to the extents of the dataset mdspan. */ template - void initialize_from_dataset( - mdspan, row_major, Accessor> dataset, - raft::distance::DistanceType metric = raft::distance::L2Expanded) + explicit index_params(mdspan, row_major, Accessor> dataset, + raft::distance::DistanceType metric = raft::distance::L2Expanded) { n_lists = dataset.extent(0) < 4 * 2500 ? 4 : static_cast(std::sqrt(dataset.extent(0))); @@ -161,7 +160,6 @@ struct search_params : ann::search_params { double preferred_shmem_carveout = 1.0; }; -static_assert(std::is_aggregate_v); static_assert(std::is_aggregate_v); /** Size of the interleaved group. */ diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 499abd7e26..313a88553b 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -417,8 +417,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam { raft::make_host_matrix(ps.n_rows, index_params.intermediate_graph_degree); if (ps.build_algo == graph_build_algo::IVF_PQ) { - auto build_params = ivf_pq::index_params{}; - build_params.initialize_from_dataset(database_view, ps.metric); + auto build_params = ivf_pq::index_params(database_view, ps.metric); if (ps.host_dataset) { cagra::build_knn_graph( handle_, database_host_view, knn_graph.view(), 2, build_params);