From e42a2728cf8b6fea7e4170377fb326b7b3ee68bb Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 18 Sep 2023 21:27:53 -0700 Subject: [PATCH] Fixes for OOM during CAGRA benchmarks Running the CAGRA benchmarks and there could be OOM errors on GPU memory with large datasets. This is caused by holding multiple copies of the dataset in GPU memory. Fix by: * Free existing memory for the dataset/graph before allocating new memory during update_dataset/update_grph * On deserialize, if the serialized index doesn't contain the dataset - don't allocate GPU memory for it * Don't call update_dataset repeatedly in the benchmarking code with the same dataset --- cpp/bench/ann/src/common/benchmark.hpp | 27 ++++++++++--------- cpp/include/raft/neighbors/cagra_types.hpp | 18 ++++++++++--- .../detail/cagra/cagra_serialize.cuh | 19 ++++++++----- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 4ec977700d..79bb571d0d 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -189,6 +189,7 @@ void bench_search(::benchmark::State& state, ANN* algo; std::unique_ptr::AnnSearchParam> search_param; try { + search_param = ann::create_search_param(index.algo, sp_json); if (!current_algo || (algo = dynamic_cast*>(current_algo.get())) == nullptr) { auto ualgo = ann::create_algo( index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); @@ -196,7 +197,19 @@ void bench_search(::benchmark::State& state, algo->load(index_file); current_algo = std::move(ualgo); } - search_param = ann::create_search_param(index.algo, sp_json); + + if (search_param->needs_dataset()) { + try { + const auto algo_property = parse_algo_property(algo->get_preference(), sp_json); + algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type), + dataset->base_set_size()); + } catch (const std::exception& ex) { + state.SkipWithError("The algorithm '" + index.name + + "' requires the base set, but it's not available. " + + "Exception: " + std::string(ex.what())); + return; + } + } } catch (const std::exception& e) { return state.SkipWithError("Failed to create an algo: " + std::string(e.what())); } @@ -207,18 +220,6 @@ void bench_search(::benchmark::State& state, buf distances{algo_property.query_memory_type, k * query_set_size}; buf neighbors{algo_property.query_memory_type, k * query_set_size}; - if (search_param->needs_dataset()) { - try { - algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type), - dataset->base_set_size()); - } catch (const std::exception& ex) { - state.SkipWithError("The algorithm '" + index.name + - "' requires the base set, but it's not available. " + - "Exception: " + std::string(ex.what())); - return; - } - } - std::ptrdiff_t batch_offset = 0; std::size_t queries_processed = 0; cuda_timer gpu_timer; diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 02e3f5338e..9ef4babb72 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -165,9 +165,10 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. */ - index(raft::resources const& res) + index(raft::resources const& res, + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) : ann::index(), - metric_(raft::distance::DistanceType::L2Expanded), + metric_(metric), dataset_(make_device_matrix(res, 0, 0)), graph_(make_device_matrix(res, 0, 0)) { @@ -296,7 +297,11 @@ struct index : ann::index { raft::host_matrix_view knn_graph) { RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device"); - graph_ = make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1)); + if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) { + // clear existing memory before allocating to prevent OOM errors on large graphs + if (graph_.size()) { graph_ = make_device_matrix(res, 0, 0); } + graph_ = make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1)); + } raft::copy(graph_.data_handle(), knn_graph.data_handle(), knn_graph.size(), @@ -311,7 +316,12 @@ struct index : ann::index { mdspan, row_major, data_accessor> dataset) { size_t padded_dim = round_up_safe(dataset.extent(1) * sizeof(T), 16) / sizeof(T); - dataset_ = make_device_matrix(res, dataset.extent(0), padded_dim); + + if ((dataset_.extent(0) != dataset.extent(0)) || (dataset_.extent(1) != padded_dim)) { + // clear existing memory before allocating to prevent OOM errors on large datasets + if (dataset_.size()) { dataset_ = make_device_matrix(res, 0, 0); } + dataset_ = make_device_matrix(res, dataset.extent(0), padded_dim); + } if (dataset_.extent(1) == dataset.extent(1)) { raft::copy(dataset_.data_handle(), dataset.data_handle(), diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 2c9cbd2563..7ad44482e7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -125,15 +125,22 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto graph_degree = deserialize_scalar(res, is); auto metric = deserialize_scalar(res, is); - auto dataset = raft::make_host_matrix(n_rows, dim); - auto graph = raft::make_host_matrix(n_rows, graph_degree); + auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); bool has_dataset = deserialize_scalar(res, is); - if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); } - - return index( - res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); + if (has_dataset) { + auto dataset = raft::make_host_matrix(n_rows, dim); + deserialize_mdspan(res, is, dataset.view()); + return index( + res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); + } else { + // create a new index with no dataset - the user must supply via update_dataset themselves + // later (this avoids allocating GPU memory in the meantime) + index idx(res, metric); + idx.update_graph(res, raft::make_const_mdspan(graph.view())); + return idx; + } } template