diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.cuh similarity index 56% rename from cpp/include/raft/core/bitmap.hpp rename to cpp/include/raft/core/bitmap.cuh index 94c8e400bf..cdf8ddd4db 100644 --- a/cpp/include/raft/core/bitmap.hpp +++ b/cpp/include/raft/core/bitmap.cuh @@ -39,6 +39,10 @@ namespace raft::core { */ template struct bitmap_view : public bitset_view { + static constexpr index_t bitmap_element_size = sizeof(bitmap_t) * 8; + // static_assert((std::is_same::value || + // std::is_same::value), + // "The bitmap_t must be uint32_t or uint64_t."); /** * @brief Create a bitmap view from a device raw pointer. * @@ -47,7 +51,10 @@ struct bitmap_view : public bitset_view { * @param cols Number of col in the matrix. */ _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) - : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + : bitset_view(bitmap_ptr, rows * cols), + bitmap_ptr_{bitmap_ptr}, + rows_(rows), + cols_(cols) { } @@ -61,7 +68,10 @@ struct bitmap_view : public bitset_view { _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, index_t rows, index_t cols) - : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + : bitset_view(bitmap_span, rows * cols), + bitmap_ptr_{bitmap_span.data_handle()}, + rows_(rows), + cols_(cols) { } @@ -107,17 +117,75 @@ struct bitmap_view : public bitset_view { * @brief Get the total number of rows * @return index_t The total number of rows */ - inline index_t get_n_rows() const { return rows_; } + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } /** * @brief Get the total number of columns * @return index_t The total number of columns */ - inline index_t get_n_cols() const { return cols_; } + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } + + /** + * @brief Returns the number of non-zero bits in nnz_gpu_scalar. + * + * @param[in] res RAFT resources + * @param[out] nnz_gpu_scalar Device scalar to store the nnz + */ + void get_nnz(const raft::resources& res, raft::device_scalar_view nnz_gpu_scalar) + { + auto n_elements_ = raft::ceildiv(rows_ * cols_, bitmap_element_size); + auto nnz_gpu = raft::make_device_vector_view(nnz_gpu_scalar.data_handle(), 1); + auto bitmap_matrix_view = raft::make_device_matrix_view( + bitmap_ptr_, n_elements_, 1); + + bitmap_t n_last_element = ((rows_ * cols_) % bitmap_element_size); + bitmap_t last_element_mask = + n_last_element ? (bitmap_t)((bitmap_t{1} << n_last_element) - bitmap_t{1}) : ~bitmap_t{0}; + raft::linalg::coalesced_reduction( + res, + bitmap_matrix_view, + nnz_gpu, + index_t{0}, + false, + [last_element_mask, n_elements_] __device__(bitmap_t element, index_t index) { + index_t result = 0; + if constexpr (bitmap_element_size == 64) { + if (index == n_elements_ - 1) + result = index_t(raft::detail::popc(element & last_element_mask)); + else + result = index_t(raft::detail::popc(element)); + } else { // Needed because popc is not overloaded for 16 and 8 bit elements + if (index == n_elements_ - 1) + result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); + else + result = index_t(raft::detail::popc(uint32_t{element})); + } + + return result; + }); + } + + /** + * @brief Returns the number of non-zero bits. + * + * @param res RAFT resources + * @return index_t Number of non-zero bits + */ + auto get_nnz(const raft::resources& res) -> index_t + { + auto nnz_gpu_scalar = raft::make_device_scalar(res, 0.0); + get_nnz(res, nnz_gpu_scalar.view()); + index_t nnz_gpu = 0; + raft::update_host(&nnz_gpu, nnz_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + return nnz_gpu; + } private: index_t rows_; index_t cols_; + + bitmap_t* bitmap_ptr_; }; /** @} */ diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 87cef3218a..cb84e5c8dd 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -18,7 +18,7 @@ #pragma once -#include +#include #include #include #include @@ -132,12 +132,18 @@ void bitmap_to_csr(raft::resources const& handle, "Number of columns in bitmap must be equal to " "number of columns in csr"); + RAFT_EXPECTS(csr_view.get_nnz() >= bitmap.get_nnz(handle), + "Number of elements in csr must be equal or larger than " + "number of non-zero bits in bitmap"); + detail::bitmap_to_csr(handle, bitmap.data(), csr_view.get_n_rows(), csr_view.get_n_cols(), + csr_view.get_nnz(), csr_view.get_indptr().data(), - csr_view.get_indices().data()); + csr_view.get_indices().data(), + csr.get_elements().data()); } }; // end NAMESPACE convert diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index d2f4a9b59d..21b31d7717 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -118,8 +118,14 @@ void calc_nnz_by_rows(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } +/* + Execute the exclusive_scan within one warp with no inter-warp communication. + This function calculates the exclusive prefix sum of `value` across threads within the same warp. + Each thread in the warp will end up with the sum of all the values of the threads with lower IDs + in the same warp, with the first thread always getting a sum of 0. +*/ template -__device__ inline value_t warp_exclusive(value_t value) +RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value) { int lane_id = threadIdx.x & 0x1f; value_t shifted_value = __shfl_up_sync(0xffffffff, value, 1, warpSize); @@ -177,7 +183,8 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); } - index_t l_sum = g_sum + warp_exclusive(static_cast(raft::detail::popc(l_bitmap))); + index_t l_sum = + g_sum + warp_exclusive_scan(static_cast(raft::detail::popc(l_bitmap))); for (int i = 0; i < BITS_PER_BITMAP; i++) { if (l_bitmap & (ONE << i)) { @@ -218,16 +225,18 @@ void fill_indices_by_rows(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template +template void bitmap_to_csr(raft::resources const& handle, const bitmap_t* bitmap, index_t num_rows, index_t num_cols, + nnz_t nnz, index_t* indptr, - index_t* indices) + index_t* indices, + value_t* values) { const index_t total = num_rows * num_cols; - if (total == 0) { return; } + if (total == 0 || nnz == 0) { return; } auto thrust_policy = resource::get_thrust_policy(handle); auto stream = resource::get_cuda_stream(handle); @@ -237,6 +246,7 @@ void bitmap_to_csr(raft::resources const& handle, calc_nnz_by_rows(handle, bitmap, num_rows, num_cols, indptr); thrust::exclusive_scan(thrust_policy, indptr, indptr + num_rows + 1, indptr); fill_indices_by_rows(handle, bitmap, indptr, num_rows, num_cols, indices); + thrust::fill_n(thrust_policy, values, nnz, value_t{1}); } }; // end NAMESPACE detail diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 8719542c94..c9545801a4 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -16,7 +16,7 @@ #include "../test_utils.cuh" #include -#include +#include #include #include @@ -227,7 +227,7 @@ struct BitmapToCSRInputs { float sparsity; }; -template +template class BitmapToCSRTest : public ::testing::TestWithParam> { public: BitmapToCSRTest() @@ -238,7 +238,8 @@ class BitmapToCSRTest : public ::testing::TestWithParam( indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); - auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); + auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); convert::bitmap_to_csr(handle, bitmap, csr); @@ -382,6 +388,9 @@ class BitmapToCSRTest : public ::testing::TestWithParam( + values_d.data(), values_expected_d.data(), nnz, raft::Compare(), stream)); } protected: @@ -400,12 +409,13 @@ class BitmapToCSRTest : public ::testing::TestWithParam indptr_expected_d; rmm::device_uvector indices_expected_d; + rmm::device_uvector values_expected_d; }; -using BitmapToCSRTestI = BitmapToCSRTest; +using BitmapToCSRTestI = BitmapToCSRTest; TEST_P(BitmapToCSRTestI, Result) { Run(); } -using BitmapToCSRTestL = BitmapToCSRTest; +using BitmapToCSRTestL = BitmapToCSRTest; TEST_P(BitmapToCSRTestL, Result) { Run(); } template