Skip to content

Commit

Permalink
Fixes for OOM during CAGRA benchmarks
Browse files Browse the repository at this point in the history
When running the CAGRA benchmarks with large datasets, GPU out-of-memory (OOM) errors can occur.
This is caused by holding multiple copies of the dataset in GPU memory. Fix by:

* Free existing memory for the dataset/graph before allocating new memory during update_dataset/update_graph
* On deserialize, if the serialized index doesn't contain the dataset, don't allocate GPU memory for it
* Don't call update_dataset repeatedly in the benchmarking code with the same dataset
  • Loading branch information
benfred committed Sep 19, 2023
1 parent b9cf917 commit e42a272
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 23 deletions.
27 changes: 14 additions & 13 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,27 @@ void bench_search(::benchmark::State& state,
ANN<T>* algo;
std::unique_ptr<typename ANN<T>::AnnSearchParam> search_param;
try {
search_param = ann::create_search_param<T>(index.algo, sp_json);
if (!current_algo || (algo = dynamic_cast<ANN<T>*>(current_algo.get())) == nullptr) {
auto ualgo = ann::create_algo<T>(
index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list);
algo = ualgo.get();
algo->load(index_file);
current_algo = std::move(ualgo);
}
search_param = ann::create_search_param<T>(index.algo, sp_json);

if (search_param->needs_dataset()) {
try {
const auto algo_property = parse_algo_property(algo->get_preference(), sp_json);
algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type),
dataset->base_set_size());
} catch (const std::exception& ex) {
state.SkipWithError("The algorithm '" + index.name +
"' requires the base set, but it's not available. " +
"Exception: " + std::string(ex.what()));
return;
}
}
} catch (const std::exception& e) {
return state.SkipWithError("Failed to create an algo: " + std::string(e.what()));
}
Expand All @@ -207,18 +220,6 @@ void bench_search(::benchmark::State& state,
buf<float> distances{algo_property.query_memory_type, k * query_set_size};
buf<std::size_t> neighbors{algo_property.query_memory_type, k * query_set_size};

if (search_param->needs_dataset()) {
try {
algo->set_search_dataset(dataset->base_set(algo_property.dataset_memory_type),
dataset->base_set_size());
} catch (const std::exception& ex) {
state.SkipWithError("The algorithm '" + index.name +
"' requires the base set, but it's not available. " +
"Exception: " + std::string(ex.what()));
return;
}
}

std::ptrdiff_t batch_offset = 0;
std::size_t queries_processed = 0;
cuda_timer gpu_timer;
Expand Down
18 changes: 14 additions & 4 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ struct index : ann::index {
~index() = default;

/** Construct an empty index. */
index(raft::resources const& res)
index(raft::resources const& res,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
metric_(raft::distance::DistanceType::L2Expanded),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
Expand Down Expand Up @@ -296,7 +297,11 @@ struct index : ann::index {
raft::host_matrix_view<const IdxT, int64_t, row_major> knn_graph)
{
RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device");
graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) {
// clear existing memory before allocating to prevent OOM errors on large graphs
if (graph_.size()) { graph_ = make_device_matrix<IdxT, int64_t>(res, 0, 0); }
graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
}
raft::copy(graph_.data_handle(),
knn_graph.data_handle(),
knn_graph.size(),
Expand All @@ -311,7 +316,12 @@ struct index : ann::index {
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
size_t padded_dim = round_up_safe<size_t>(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);

if ((dataset_.extent(0) != dataset.extent(0)) || (dataset_.extent(1) != padded_dim)) {
// clear existing memory before allocating to prevent OOM errors on large datasets
if (dataset_.size()) { dataset_ = make_device_matrix<T, int64_t>(res, 0, 0); }
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
}
if (dataset_.extent(1) == dataset.extent(1)) {
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
Expand Down
19 changes: 13 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,22 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
auto graph_degree = deserialize_scalar<std::uint32_t>(res, is);
auto metric = deserialize_scalar<raft::distance::DistanceType>(res, is);

auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
deserialize_mdspan(res, is, graph.view());

bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); }

return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
if (has_dataset) {
auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
deserialize_mdspan(res, is, dataset.view());
return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
} else {
// create a new index with no dataset - the user must supply via update_dataset themselves
// later (this avoids allocating GPU memory in the meantime)
index<T, IdxT> idx(res, metric);
idx.update_graph(res, raft::make_const_mdspan(graph.view()));
return idx;
}
}

template <typename T, typename IdxT>
Expand Down

0 comments on commit e42a272

Please sign in to comment.