Skip to content

Commit

Permalink
Fixing workaround for cuSPARSE bug with correct copy dimensions (#2185)
Browse files Browse the repository at this point in the history
This PR tries to fix the workaround that introduce a temporary copy of the Z matrix for SpMM. 
Note that the correct sizing for the copy is crucial - this is covered by the SPARSE_DIST_TEST.

However, I did NOT test whether this fixes the cuSparse bug. CC @cjnolet @trxcllnt

Authors:
  - Malte Förster (https://github.com/mfoerste4)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #2185
  • Loading branch information
mfoerste4 authored Feb 21, 2024
1 parent 573a034 commit c95bf6a
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions cpp/include/raft/sparse/linalg/spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,19 @@ void spmm(raft::resources const& handle,
{
bool is_row_major = detail::is_row_major(y, z);

// WARNING: The following copy is working around a bug in cusparse which causes an alignment issue
// and incorrect results. This bug is fixed in CUDA 12.5+ so this workaround shouldn't be removed
// until that version is supported.
auto size = is_row_major ? (z.extent(0) - 1) * z.stride(0) + z.extent(1)
: (z.extent(1) - 1) * z.stride(1) + z.extent(0);
rmm::device_uvector<ValueType> z_tmp(size, raft::resource::get_cuda_stream(handle));
raft::copy(z_tmp.data(), z.data_handle(), z_tmp.size(), raft::resource::get_cuda_stream(handle));

auto z_tmp_view =
is_row_major ? raft::make_device_strided_matrix_view<ValueType, IndexType, layout_c_contiguous>(
z.data_handle(), z.extent(0), z.extent(1), z.stride(0))
z_tmp.data(), z.extent(0), z.extent(1), z.stride(0))
: raft::make_device_strided_matrix_view<ValueType, IndexType, layout_f_contiguous>(
z.data_handle(), z.extent(0), z.extent(1), z.stride(1));
z_tmp.data(), z.extent(0), z.extent(1), z.stride(1));

auto descr_x = detail::create_descriptor(x);
auto descr_y = detail::create_descriptor(y);
Expand All @@ -74,10 +82,7 @@ void spmm(raft::resources const& handle,

// WARNING: Do not remove the following copy unless you can, with certainty, say that
// the underlying cuSPARSE issue affecting CUDA 12.2+ has been resolved.
raft::copy(z.data_handle(),
z_tmp_view.data_handle(),
z_tmp_view.size(),
raft::resource::get_cuda_stream(handle));
raft::copy(z.data_handle(), z_tmp.data(), z_tmp.size(), raft::resource::get_cuda_stream(handle));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y));
RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z));
Expand Down

0 comments on commit c95bf6a

Please sign in to comment.