diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index dbe788b15c..245ecc76e3 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -143,6 +143,7 @@ RAFT_KERNEL __launch_bounds__(bitmap_to_csr_tpb) index_t* indptr, size_t num_rows, size_t num_cols, + nnz_t nnz, index_t* indices, nnz_t* sub_col_nnz, index_t bits_per_sub_col) @@ -158,14 +159,15 @@ RAFT_KERNEL __launch_bounds__(bitmap_to_csr_tpb) const auto tid = threadIdx.x; const auto row = blockIdx.x; + const auto num_sub_cols = gridDim.y; + const auto sub_col = blockIdx.y; + // Ensure the HBM allocated for CSR values is sufficient to handle all non-zero bitmap bits. // An assert will trigger if the allocated HBM is insufficient when `NDEBUG` isn't defined. // Note: Assertion is active only if `NDEBUG` is undefined. if constexpr (check_nnz) { - if (tid == 0) { assert(nnz < indptr[num_rows]); } + if (tid == 0) { assert(nnz < sub_col_nnz[num_rows * num_sub_cols]); } } - const auto num_sub_cols = gridDim.y; - const auto sub_col = blockIdx.y; size_t s_bit = size_t(row) * num_cols + sub_col * bits_per_sub_col; size_t e_bit = min(s_bit + bits_per_sub_col, size_t(num_cols) * (row + 1)); @@ -227,8 +229,8 @@ RAFT_KERNEL __launch_bounds__(bitmap_to_csr_tpb) int l_bits = raft::detail::popc(l_bitmap); int l_sum_32b = 0; - BlockScan(scan_storage).ExclusiveSum(l_bits, l_sum_32b); - l_sum = l_sum_32b + g_sum; + BlockScan(scan_storage).InclusiveSum(l_bits, l_sum_32b); + l_sum = l_sum_32b + g_sum - l_bits; #pragma unroll for (int i = 0; i < BITS_PER_BITMAP; i++) { @@ -240,7 +242,7 @@ RAFT_KERNEL __launch_bounds__(bitmap_to_csr_tpb) l_sum += guard[i]; } - if (threadIdx.x == (bitmap_to_csr_tpb - 1)) { g_sum += (l_sum_32b + l_bits); } + if (threadIdx.x == (bitmap_to_csr_tpb - 1)) { g_sum += (l_sum_32b); } g_bits += BITS_PER_BITMAP * blockDim.x; } } @@ -251,6 +253,7 @@ void fill_indices_by_rows(raft::resources const& handle, index_t* indptr, index_t num_rows, index_t num_cols, + nnz_t nnz, index_t* indices, nnz_t* sub_col_nnz, index_t bits_per_sub_col, @@ -264,7 +267,7 @@ void fill_indices_by_rows(raft::resources const& handle, auto block = bitmap_to_csr_tpb; fill_indices_by_rows_kernel<<>>( - bitmap, indptr, num_rows, num_cols, indices, sub_col_nnz, bits_per_sub_col); + bitmap, indptr, num_rows, num_cols, nnz, indices, sub_col_nnz, bits_per_sub_col); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -338,6 +341,7 @@ void bitmap_to_csr(raft::resources const& handle, indptr, csr_view.get_n_rows(), csr_view.get_n_cols(), + csr_view.get_nnz(), indices, sub_nnz.data(), bits_per_sub_col,