diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index dd661c71ac..2eb1e58f1c 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -60,11 +60,19 @@ void spmm(raft::resources const& handle, { bool is_row_major = detail::is_row_major(y, z); + // WARNING: The following copy is working around a bug in cusparse which causes an alignment issue + // and incorrect results. This bug is fixed in CUDA 12.5+ so this workaround shouldn't be removed + // until that version is supported. + auto size = is_row_major ? (z.extent(0) - 1) * z.stride(0) + z.extent(1) + : (z.extent(1) - 1) * z.stride(1) + z.extent(0); + rmm::device_uvector z_tmp(size, raft::resource::get_cuda_stream(handle)); + raft::copy(z_tmp.data(), z.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle)); + auto z_tmp_view = is_row_major ? raft::make_device_strided_matrix_view( - z.data_handle(), z.extent(0), z.extent(1), z.stride(0)) + z_tmp.data(), z.extent(0), z.extent(1), z.stride(0)) : raft::make_device_strided_matrix_view( - z.data_handle(), z.extent(0), z.extent(1), z.stride(1)); + z_tmp.data(), z.extent(0), z.extent(1), z.stride(1)); auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); @@ -74,10 +82,7 @@ void spmm(raft::resources const& handle, // WARNING: Do not remove the following copy unless you can, with certainty, say that // the underlying cuSPARSE issue affecting CUDA 12.2+ has been resolved. - raft::copy(z.data_handle(), - z_tmp_view.data_handle(), - z_tmp_view.size(), - raft::resource::get_cuda_stream(handle)); + raft::copy(z.data_handle(), z_tmp.data(), z_tmp.size(), raft::resource::get_cuda_stream(handle)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));