
Commit

Merge pull request #1847 from rapidsai/branch-23.10
Forward-merge branch-23.10 to branch-23.12
GPUtester authored Sep 25, 2023
2 parents 1175523 + dfde3b4 commit 3ac8835
Showing 2 changed files with 28 additions and 10 deletions.
cpp/include/raft/neighbors/cagra_types.hpp (15 additions, 4 deletions)

@@ -165,9 +165,10 @@ struct index : ann::index {
   ~index() = default;
 
   /** Construct an empty index. */
-  index(raft::resources const& res)
+  index(raft::resources const& res,
+        raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
     : ann::index(),
-      metric_(raft::distance::DistanceType::L2Expanded),
+      metric_(metric),
       dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
       graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
   {
@@ -296,7 +297,11 @@ struct index : ann::index {
                     raft::host_matrix_view<const IdxT, int64_t, row_major> knn_graph)
   {
     RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device");
-    graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
+    if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) {
+      // clear existing memory before allocating to prevent OOM errors on large graphs
+      if (graph_.size()) { graph_ = make_device_matrix<IdxT, int64_t>(res, 0, 0); }
+      graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
+    }
     raft::copy(graph_.data_handle(),
                knn_graph.data_handle(),
                knn_graph.size(),
@@ -311,7 +316,13 @@
                     mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
   {
     size_t padded_dim = round_up_safe<size_t>(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
-    dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
+
+    if ((dataset_.extent(0) != dataset.extent(0)) ||
+        (static_cast<size_t>(dataset_.extent(1)) != padded_dim)) {
+      // clear existing memory before allocating to prevent OOM errors on large datasets
+      if (dataset_.size()) { dataset_ = make_device_matrix<T, int64_t>(res, 0, 0); }
+      dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
+    }
     if (dataset_.extent(1) == dataset.extent(1)) {
       raft::copy(dataset_.data_handle(),
                  dataset.data_handle(),
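
Taken together, the cagra_types.hpp changes let an empty index carry a caller-chosen metric and make update_graph / update_dataset reuse their device buffers when the input shape has not changed. Below is a minimal, untested sketch of how a caller might exercise that behavior; the include paths, the raft::neighbors::cagra namespace, the host-view overloads of the two setters, and all matrix sizes are assumptions for illustration, not part of this commit.

#include <cstdint>

#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/neighbors/cagra_types.hpp>

void refresh_index(raft::device_resources const& res)
{
  using raft::neighbors::cagra::index;

  // New constructor overload: an empty index that records its metric up front.
  index<float, std::uint32_t> idx(res, raft::distance::DistanceType::L2Expanded);

  // Placeholder host-side graph (degree 32) and dataset (dim 128) for 1000 rows.
  auto graph   = raft::make_host_matrix<std::uint32_t, int64_t>(1000, 32);
  auto dataset = raft::make_host_matrix<float, int64_t>(1000, 128);

  // The first calls allocate device memory for the graph and the padded dataset.
  idx.update_graph(res, raft::make_const_mdspan(graph.view()));
  idx.update_dataset(res, raft::make_const_mdspan(dataset.view()));

  // Calling again with identically shaped inputs reuses the existing buffers;
  // only a shape change triggers the free-then-reallocate path added above.
  idx.update_graph(res, raft::make_const_mdspan(graph.view()));
  idx.update_dataset(res, raft::make_const_mdspan(dataset.view()));
}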
cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh (13 additions, 6 deletions)

@@ -130,15 +130,22 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
   auto graph_degree = deserialize_scalar<std::uint32_t>(res, is);
   auto metric = deserialize_scalar<raft::distance::DistanceType>(res, is);
 
-  auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
-  auto graph   = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
+  auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
   deserialize_mdspan(res, is, graph.view());
 
   bool has_dataset = deserialize_scalar<bool>(res, is);
-  if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); }
-
-  return index<T, IdxT>(
-    res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
+  if (has_dataset) {
+    auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
+    deserialize_mdspan(res, is, dataset.view());
+    return index<T, IdxT>(
+      res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
+  } else {
+    // create a new index with no dataset - the user must supply via update_dataset themselves
+    // later (this avoids allocating GPU memory in the meantime)
+    index<T, IdxT> idx(res, metric);
+    idx.update_graph(res, raft::make_const_mdspan(graph.view()));
+    return idx;
+  }
 }
 
 template <typename T, typename IdxT>
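
With the serializer change, an index that was saved without its dataset now deserializes into a graph-only index, and the caller is expected to attach the dataset afterwards through update_dataset. The sketch below illustrates that flow by calling the detail-level deserialize shown in this diff directly; the detail namespace path, the file name, the explicit template arguments, and the dataset view are illustrative assumptions (applications would normally go through the public cagra serialization API).

#include <cstdint>
#include <fstream>

#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/neighbors/detail/cagra/cagra_serialize.cuh>

auto load_graph_only_index(raft::device_resources const& res,
                           raft::host_matrix_view<const float, int64_t> dataset)
{
  std::ifstream is("cagra_index.bin", std::ios::binary);

  // If has_dataset was serialized as false, this returns an index that holds
  // only the graph; no device memory is allocated for the missing dataset.
  auto idx =
    raft::neighbors::cagra::detail::deserialize<float, std::uint32_t>(res, is);

  // Attach the dataset explicitly, as the new code comment instructs.
  idx.update_dataset(res, dataset);
  return idx;
}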
