diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 4ec977700d..79bb571d0d 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -189,6 +189,7 @@ void bench_search(::benchmark::State& state, ANN* algo; std::unique_ptr::AnnSearchParam> search_param; try { + search_param = ann::create_search_param(index.algo, sp_json); if (!current_algo || (algo = dynamic_cast*>(current_algo.get())) == nullptr) { auto ualgo = ann::create_algo( index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); @@ -196,7 +197,19 @@ void bench_search(::benchmark::State& state, algo->load(index_file); current_algo = std::move(ualgo); } - search_param = ann::create_search_param(index.algo, sp_json); + + if (search_param->needs_dataset()) { + try { + const auto algo_property = parse_algo_property(algo->get_preference(), sp_json); + algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type), + dataset->base_set_size()); + } catch (const std::exception& ex) { + state.SkipWithError("The algorithm '" + index.name + + "' requires the base set, but it's not available. " + + "Exception: " + std::string(ex.what())); + return; + } + } } catch (const std::exception& e) { return state.SkipWithError("Failed to create an algo: " + std::string(e.what())); } @@ -207,18 +220,6 @@ void bench_search(::benchmark::State& state, buf distances{algo_property.query_memory_type, k * query_set_size}; buf neighbors{algo_property.query_memory_type, k * query_set_size}; - if (search_param->needs_dataset()) { - try { - algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type), - dataset->base_set_size()); - } catch (const std::exception& ex) { - state.SkipWithError("The algorithm '" + index.name + - "' requires the base set, but it's not available. " + - "Exception: " + std::string(ex.what())); - return; - } - } - std::ptrdiff_t batch_offset = 0; std::size_t queries_processed = 0; cuda_timer gpu_timer; diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 02e3f5338e..9ef4babb72 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -165,9 +165,10 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. */ - index(raft::resources const& res) + index(raft::resources const& res, + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) : ann::index(), - metric_(raft::distance::DistanceType::L2Expanded), + metric_(metric), dataset_(make_device_matrix(res, 0, 0)), graph_(make_device_matrix(res, 0, 0)) { @@ -296,7 +297,11 @@ struct index : ann::index { raft::host_matrix_view knn_graph) { RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device"); - graph_ = make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1)); + if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) { + // clear existing memory before allocating to prevent OOM errors on large graphs + if (graph_.size()) { graph_ = make_device_matrix(res, 0, 0); } + graph_ = make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1)); + } raft::copy(graph_.data_handle(), knn_graph.data_handle(), knn_graph.size(), @@ -311,7 +316,12 @@ struct index : ann::index { mdspan, row_major, data_accessor> dataset) { size_t padded_dim = round_up_safe(dataset.extent(1) * sizeof(T), 16) / sizeof(T); - dataset_ = make_device_matrix(res, dataset.extent(0), padded_dim); + + if ((dataset_.extent(0) != dataset.extent(0)) || (dataset_.extent(1) != padded_dim)) { + // clear existing memory before allocating to prevent OOM errors on large datasets + if (dataset_.size()) { dataset_ = make_device_matrix(res, 0, 0); } + dataset_ = make_device_matrix(res, dataset.extent(0), padded_dim); + } if (dataset_.extent(1) == dataset.extent(1)) { raft::copy(dataset_.data_handle(), dataset.data_handle(), diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 2c9cbd2563..7ad44482e7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -125,15 +125,22 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto graph_degree = deserialize_scalar(res, is); auto metric = deserialize_scalar(res, is); - auto dataset = raft::make_host_matrix(n_rows, dim); - auto graph = raft::make_host_matrix(n_rows, graph_degree); + auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); bool has_dataset = deserialize_scalar(res, is); - if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); } - - return index( - res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); + if (has_dataset) { + auto dataset = raft::make_host_matrix(n_rows, dim); + deserialize_mdspan(res, is, dataset.view()); + return index( + res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); + } else { + // create a new index with no dataset - the user must supply via update_dataset themselves + // later (this avoids allocating GPU memory in the meantime) + index idx(res, metric); + idx.update_graph(res, raft::make_const_mdspan(graph.view())); + return idx; + } } template