diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index ad16f3e11d..9c37ee146d 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -19,11 +19,14 @@ #include "../nn_descent_types.hpp" #include +#include #include #include +#include #include #include #include +#include #include #include #include // raft::util::arch::SM_* @@ -1365,12 +1368,22 @@ void GNND::build(Data_t* data, static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); if (return_distances) { - for (size_t i = 0; i < (size_t)nrow_; i++) { - raft::copy(output_distances + i * build_config_.output_graph_degree, - graph_.h_dists.data_handle() + i * build_config_.node_degree, - build_config_.output_graph_degree, - raft::resource::get_cuda_stream(res)); - } + auto graph_d_dists = raft::make_device_matrix( + res, nrow_, build_config_.node_degree); + raft::copy(graph_d_dists.data_handle(), + graph_.h_dists.data_handle(), + nrow_ * build_config_.node_degree, + raft::resource::get_cuda_stream(res)); + + auto output_dist_view = raft::make_device_matrix_view( + output_distances, nrow_, build_config_.output_graph_degree); + + raft::matrix::slice_coordinates coords{static_cast(0), + static_cast(0), + static_cast(nrow_), + static_cast(build_config_.output_graph_degree)}; + raft::matrix::slice( + res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords); } Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle();