Skip to content

Commit

Permalink
Optimize based on review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Mar 5, 2024
1 parent 3edeafd commit f015ed5
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ namespace raft::core {
*/
template <typename bitmap_t = uint32_t, typename index_t = uint32_t>
struct bitmap_view : public bitset_view<bitmap_t, index_t> {
static constexpr index_t bitmap_element_size = sizeof(bitmap_t) * 8;
// static_assert((std::is_same<bitmap_t, uint32_t>::value ||
// std::is_same<bitmap_t, uint64_t>::value),
// "The bitmap_t must be uint32_t or uint64_t.");
/**
* @brief Create a bitmap view from a device raw pointer.
*
Expand All @@ -47,7 +51,10 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
* @param cols Number of columns in the matrix.
*/
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols), rows_(rows), cols_(cols)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols),
bitmap_ptr_{bitmap_ptr},
rows_(rows),
cols_(cols)
{
}

Expand All @@ -61,7 +68,10 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span,
index_t rows,
index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols), rows_(rows), cols_(cols)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols),
bitmap_ptr_{bitmap_span.data_handle()},
rows_(rows),
cols_(cols)
{
}

Expand Down Expand Up @@ -107,17 +117,75 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
* @brief Get the total number of rows
* @return index_t The total number of rows
*/
inline index_t get_n_rows() const { return rows_; }
inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; }

/**
* @brief Get the total number of columns
* @return index_t The total number of columns
*/
inline index_t get_n_cols() const { return cols_; }
inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; }

/**
* @brief Counts the non-zero bits of the bitmap and stores the count in nnz_gpu_scalar.
*
* @param[in] res RAFT resources
* @param[out] nnz_gpu_scalar Device scalar to store the nnz
*/
void get_nnz(const raft::resources& res, raft::device_scalar_view<index_t> nnz_gpu_scalar)
{
  // Total number of bits described by the view and the number of bitmap_t words holding them.
  const auto total_bits = rows_ * cols_;
  auto word_count       = raft::ceildiv(total_bits, bitmap_element_size);

  // Expose the output scalar as a length-1 vector and the bitmap words as a
  // (word_count x 1) column-major matrix, as passed to coalesced_reduction below.
  auto count_out =
    raft::make_device_vector_view<index_t, index_t>(nnz_gpu_scalar.data_handle(), 1);
  auto words = raft::make_device_matrix_view<const bitmap_t, index_t, col_major>(
    bitmap_ptr_, word_count, 1);

  // Bits in use inside the last word; zero means the last word is fully used,
  // in which case the mask keeps every bit.
  bitmap_t tail_bits = (total_bits % bitmap_element_size);
  bitmap_t tail_mask = ~bitmap_t{0};
  if (tail_bits) { tail_mask = (bitmap_t)((bitmap_t{1} << tail_bits) - bitmap_t{1}); }

  raft::linalg::coalesced_reduction(
    res,
    words,
    count_out,
    index_t{0},
    false,
    [tail_mask, word_count] __device__(bitmap_t element, index_t index) {
      // Only the final word can contain bits beyond rows * cols; mask those out.
      const bool is_tail_word = (index == word_count - 1);
      if constexpr (bitmap_element_size == 64) {
        const bitmap_t bits = is_tail_word ? (element & tail_mask) : element;
        return index_t(raft::detail::popc(bits));
      } else {
        // Widen first: popc is not overloaded for 16 and 8 bit elements.
        uint32_t bits = uint32_t{element};
        if (is_tail_word) { bits &= tail_mask; }
        return index_t(raft::detail::popc(bits));
      }
    });
}

/**
* @brief Returns the number of non-zero bits.
*
* @param res RAFT resources
* @return index_t Number of non-zero bits
*/
auto get_nnz(const raft::resources& res) -> index_t
{
  // Device-side accumulator, initialised with a zero of the element type itself
  // (previously the floating-point literal 0.0 was implicitly converted).
  auto nnz_gpu_scalar = raft::make_device_scalar<index_t>(res, index_t{0});
  get_nnz(res, nnz_gpu_scalar.view());

  // Copy the count back on the resource's CUDA stream, then synchronize that
  // stream so the host value is valid before it is returned.
  index_t nnz_host = 0;
  raft::update_host(&nnz_host, nnz_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
  resource::sync_stream(res);
  return nnz_host;
}

private:
index_t rows_;
index_t cols_;

bitmap_t* bitmap_ptr_;
};

/** @} */
Expand Down
10 changes: 8 additions & 2 deletions cpp/include/raft/sparse/convert/csr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#pragma once

#include <raft/core/bitmap.hpp>
#include <raft/core/bitmap.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/sparse/convert/detail/adj_to_csr.cuh>
#include <raft/sparse/convert/detail/bitmap_to_csr.cuh>
Expand Down Expand Up @@ -132,12 +132,18 @@ void bitmap_to_csr(raft::resources const& handle,
"Number of columns in bitmap must be equal to "
"number of columns in csr");

RAFT_EXPECTS(csr_view.get_nnz() >= bitmap.get_nnz(handle),
"Number of elements in csr must be equal or larger than "
"number of non-zero bits in bitmap");

detail::bitmap_to_csr(handle,
bitmap.data(),
csr_view.get_n_rows(),
csr_view.get_n_cols(),
csr_view.get_nnz(),
csr_view.get_indptr().data(),
csr_view.get_indices().data());
csr_view.get_indices().data(),
csr.get_elements().data());
}

}; // end NAMESPACE convert
Expand Down
20 changes: 15 additions & 5 deletions cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,14 @@ void calc_nnz_by_rows(raft::resources const& handle,
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/*
Execute the exclusive_scan within one warp with no inter-warp communication.
This function calculates the exclusive prefix sum of `value` across threads within the same warp.
Each thread in the warp will end up with the sum of all the values of the threads with lower IDs
in the same warp, with the first thread always getting a sum of 0.
*/
template <typename value_t>
__device__ inline value_t warp_exclusive(value_t value)
RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value)
{
int lane_id = threadIdx.x & 0x1f;
value_t shifted_value = __shfl_up_sync(0xffffffff, value, 1, warpSize);
Expand Down Expand Up @@ -177,7 +183,8 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb)
l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit);
}

index_t l_sum = g_sum + warp_exclusive(static_cast<index_t>(raft::detail::popc(l_bitmap)));
index_t l_sum =
g_sum + warp_exclusive_scan(static_cast<index_t>(raft::detail::popc(l_bitmap)));

for (int i = 0; i < BITS_PER_BITMAP; i++) {
if (l_bitmap & (ONE << i)) {
Expand Down Expand Up @@ -218,16 +225,18 @@ void fill_indices_by_rows(raft::resources const& handle,
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

template <typename bitmap_t, typename index_t>
template <typename bitmap_t, typename index_t, typename nnz_t, typename value_t>
void bitmap_to_csr(raft::resources const& handle,
const bitmap_t* bitmap,
index_t num_rows,
index_t num_cols,
nnz_t nnz,
index_t* indptr,
index_t* indices)
index_t* indices,
value_t* values)
{
const index_t total = num_rows * num_cols;
if (total == 0) { return; }
if (total == 0 || nnz == 0) { return; }

auto thrust_policy = resource::get_thrust_policy(handle);
auto stream = resource::get_cuda_stream(handle);
Expand All @@ -237,6 +246,7 @@ void bitmap_to_csr(raft::resources const& handle,
calc_nnz_by_rows(handle, bitmap, num_rows, num_cols, indptr);
thrust::exclusive_scan(thrust_policy, indptr, indptr + num_rows + 1, indptr);
fill_indices_by_rows(handle, bitmap, indptr, num_rows, num_cols, indices);
thrust::fill_n(thrust_policy, values, nnz, value_t{1});
}

}; // end NAMESPACE detail
Expand Down
22 changes: 16 additions & 6 deletions cpp/test/sparse/convert_csr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include "../test_utils.cuh"
#include <gtest/gtest.h>
#include <raft/core/bitmap.hpp>
#include <raft/core/bitmap.cuh>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cuda_utils.cuh>

Expand Down Expand Up @@ -227,7 +227,7 @@ struct BitmapToCSRInputs {
float sparsity;
};

template <typename bitmap_t, typename index_t>
template <typename bitmap_t, typename index_t, typename value_t>
class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_t>> {
public:
BitmapToCSRTest()
Expand All @@ -238,7 +238,8 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
indptr_d(0, stream),
values_d(0, stream),
indptr_expected_d(0, stream),
indices_expected_d(0, stream)
indices_expected_d(0, stream),
values_expected_d(0, stream)
{
}

Expand Down Expand Up @@ -341,8 +342,13 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
bitmap_d.resize(bitmap_h.size(), stream);
indptr_d.resize(params.n_rows + 1, stream);
indices_d.resize(nnz, stream);

indptr_expected_d.resize(params.n_rows + 1, stream);
indices_expected_d.resize(nnz, stream);
values_expected_d.resize(nnz, stream);

thrust::fill_n(resource::get_thrust_policy(handle), values_expected_d.data(), nnz, value_t{1});

values_d.resize(nnz, stream);

update_device(indices_expected_d.data(), indices_h.data(), indices_h.size(), stream);
Expand All @@ -359,7 +365,7 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_

auto csr_view = raft::make_device_compressed_structure_view<index_t, index_t, index_t>(
indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz);
auto csr = raft::make_device_csr_matrix_view<float, index_t>(values_d.data(), csr_view);
auto csr = raft::make_device_csr_matrix_view<value_t, index_t>(values_d.data(), csr_view);

convert::bitmap_to_csr<bitmap_t, index_t>(handle, bitmap, csr);

Expand All @@ -382,6 +388,9 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
resource::sync_stream(handle);

ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h));

ASSERT_TRUE(raft::devArrMatch<value_t>(
values_d.data(), values_expected_d.data(), nnz, raft::Compare<value_t>(), stream));
}

protected:
Expand All @@ -400,12 +409,13 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_

rmm::device_uvector<index_t> indptr_expected_d;
rmm::device_uvector<index_t> indices_expected_d;
rmm::device_uvector<float> values_expected_d;
};

using BitmapToCSRTestI = BitmapToCSRTest<uint32_t, int>;
using BitmapToCSRTestI = BitmapToCSRTest<uint32_t, int, float>;
TEST_P(BitmapToCSRTestI, Result) { Run(); }

using BitmapToCSRTestL = BitmapToCSRTest<uint32_t, int64_t>;
using BitmapToCSRTestL = BitmapToCSRTest<uint32_t, int64_t, float>;
TEST_P(BitmapToCSRTestL, Result) { Run(); }

template <typename index_t>
Expand Down

0 comments on commit f015ed5

Please sign in to comment.