diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp index 529569865b..dc60a4782d 100644 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -44,18 +44,21 @@ void serialize(const raft::resources& res, std::ostream& os, const strided_dataset& dataset) { - serialize_scalar(res, os, dataset.n_rows()); - serialize_scalar(res, os, dataset.dim()); - serialize_scalar(res, os, dataset.stride()); + auto n_rows = dataset.n_rows(); + auto dim = dataset.dim(); + auto stride = dataset.stride(); + serialize_scalar(res, os, n_rows); + serialize_scalar(res, os, dim); + serialize_scalar(res, os, stride); // Remove padding before saving the dataset auto src = dataset.view(); - auto dst = make_host_mdarray(src.extents()); + auto dst = make_host_matrix(n_rows, dim); RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), - sizeof(DataT) * dst.extent(1), + sizeof(DataT) * dim, src.data_handle(), - sizeof(DataT) * src.stride(0), - sizeof(DataT) * dst.extent(1), - src.extent(0), + sizeof(DataT) * stride, + sizeof(DataT) * dim, + n_rows, cudaMemcpyDefault, resource::get_cuda_stream(res))); resource::sync_stream(res); @@ -144,8 +147,10 @@ void deserialize(raft::resources const& res, using out_container_policy_type = typename out_mdarray_type::container_policy_type; using out_owning_type = owning_dataset; - auto host_arrray = make_host_mdarray(out_extents); + auto host_arrray = make_host_matrix(n_rows, dim); deserialize_mdspan(res, is, host_arrray.view()); + RAFT_CUDA_TRY(cudaMemsetAsync( + out_array.data_handle(), 0, sizeof(DataT) * out_array.size(), resource::get_cuda_stream(res))); RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), sizeof(DataT) * stride, host_arrray.data_handle(),