Skip to content

Commit

Permalink
Fixes for OOM during CAGRA benchmarks (#1832)
Browse files Browse the repository at this point in the history
Running the CAGRA benchmarks and there could be OOM errors on GPU memory with large datasets. This is caused by holding multiple copies of the dataset in GPU memory. Fix by:

* Free existing memory for the dataset/graph before allocating new memory during update_dataset/update_grph
* On deserialize, if the serialized index doesn't contain the dataset - don't allocate GPU memory for it
* Don't call update_dataset repeatedly in the benchmarking code with the same dataset

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1832
  • Loading branch information
benfred authored Sep 25, 2023
1 parent ba923cc commit dfde3b4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
19 changes: 15 additions & 4 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ struct index : ann::index {
~index() = default;

/** Construct an empty index. */
index(raft::resources const& res)
index(raft::resources const& res,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
metric_(raft::distance::DistanceType::L2Expanded),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
Expand Down Expand Up @@ -296,7 +297,11 @@ struct index : ann::index {
raft::host_matrix_view<const IdxT, int64_t, row_major> knn_graph)
{
RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device");
graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) {
// clear existing memory before allocating to prevent OOM errors on large graphs
if (graph_.size()) { graph_ = make_device_matrix<IdxT, int64_t>(res, 0, 0); }
graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
}
raft::copy(graph_.data_handle(),
knn_graph.data_handle(),
knn_graph.size(),
Expand All @@ -311,7 +316,13 @@ struct index : ann::index {
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
size_t padded_dim = round_up_safe<size_t>(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);

if ((dataset_.extent(0) != dataset.extent(0)) ||
(static_cast<size_t>(dataset_.extent(1)) != padded_dim)) {
// clear existing memory before allocating to prevent OOM errors on large datasets
if (dataset_.size()) { dataset_ = make_device_matrix<T, int64_t>(res, 0, 0); }
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
}
if (dataset_.extent(1) == dataset.extent(1)) {
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
Expand Down
19 changes: 13 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,22 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
auto graph_degree = deserialize_scalar<std::uint32_t>(res, is);
auto metric = deserialize_scalar<raft::distance::DistanceType>(res, is);

auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
deserialize_mdspan(res, is, graph.view());

bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); }

return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
if (has_dataset) {
auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
deserialize_mdspan(res, is, dataset.view());
return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
} else {
// create a new index with no dataset - the user must supply via update_dataset themselves
// later (this avoids allocating GPU memory in the meantime)
index<T, IdxT> idx(res, metric);
idx.update_graph(res, raft::make_const_mdspan(graph.view()));
return idx;
}
}

template <typename T, typename IdxT>
Expand Down

0 comments on commit dfde3b4

Please sign in to comment.