[FEA] Add support for bitmap_view & the API of bitmap_to_csr #2109

Merged
merged 30 commits into from
Mar 21, 2024
Changes from 1 commit
Commits
165b278
[FEA] Add support for bitmap_view & the API of `bitmap_to_csr`
rhdong Jan 21, 2024
62d9e34
Merge branch 'branch-24.02' into rhdong/bitmap
rhdong Jan 23, 2024
7dea38d
fix doc build CI error 35fe8a
rhdong Jan 23, 2024
1b6e784
Merge branch 'rhdong/bitmap' of https://github.com/rhdong/raft into r…
rhdong Jan 23, 2024
67f0650
try to fix the CI failure
rhdong Jan 24, 2024
1e5294c
Merge branch 'branch-24.02' into rhdong/bitmap
rhdong Jan 24, 2024
70cce47
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Jan 25, 2024
f8960d9
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Feb 23, 2024
20f76af
fix a ut& benchmark error
rhdong Feb 21, 2024
7dc7cf8
Improve performance & eliminate the temp buffer.
rhdong Feb 23, 2024
cab8691
fix : compatible with devices with compute capability < 8.0
rhdong Feb 24, 2024
3edeafd
Merge remote-tracking branch 'origin/branch-24.04' into rhdong/bitmap…
rhdong Mar 5, 2024
f015ed5
Optimize based on review comments
rhdong Mar 5, 2024
308eb8b
Ensure the `bitmap_t` to be {uint64_t, uint32_t}
rhdong Mar 5, 2024
0d5ec74
Merge remote-tracking branch 'origin/branch-24.04' into rhdong/bitmap…
rhdong Mar 6, 2024
5cf5a8b
Merge remote-tracking branch 'origin/branch-24.04' into rhdong/bitmap…
rhdong Mar 6, 2024
42d5355
Optimize based on review comments-3rd round
rhdong Mar 7, 2024
1aff844
[Fix] `std::vector` compilation error of in the `nvtx.hpp`
rhdong Mar 6, 2024
59533e0
Remove `#undef NDEBUG`
rhdong Mar 7, 2024
311f2d1
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Mar 7, 2024
7d960bd
Optimize based on review comments-4rd round
rhdong Mar 8, 2024
8a3b759
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Mar 11, 2024
7dea754
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Mar 11, 2024
08397c6
API changes to Owning/perserving
rhdong Mar 14, 2024
b2ad06e
d1
rhdong Mar 15, 2024
09e64ab
Add the UT cases for owning scenario
rhdong Mar 15, 2024
53202c6
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Mar 15, 2024
8838a36
fix benchmark compile error
rhdong Mar 16, 2024
1c91c74
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Mar 19, 2024
01dd903
Merge branch 'branch-24.04' into rhdong/bitmap
rhdong Mar 20, 2024
20 changes: 11 additions & 9 deletions cpp/include/raft/sparse/convert/csr.cuh
@@ -108,19 +108,21 @@ void adj_to_csr(raft::resources const& handle,
 /**
  * @brief Converts a bitmap matrix into unsorted CSR format matrix.
  *
- * @tparam bitmap_t Underlying type of the bitmap.
- * @tparam index_t Indexing type used.
- * @tparam value_t Data type of CSR
- * @tparam nnz_t Type of CSR
+ * @tparam bitmap_t Underlying type of the bitmap.
+ * @tparam index_t Indexing type used.
+ * @tparam csr_matrix_t Reference Type of CSR Matrix, raft::device_csr_matrix
  *
- * @param[in] handle RAFT handle
- * @param[in] bitmap input raft::bitmap_view
- * @param[out] csr output raft::device_csr_matrix_view
+ * @param[in] handle RAFT handle
+ * @param[in] bitmap input raft::bitmap_view
Member

In and out are already specified on the param so we shouldn't need to duplicate those in the description. Can you please expand the descriptions, though? Instead of just "handle" and "raft::bitmap_view", it would be helpful to explain their purpose.

Also, please don't include "reference type" in the description. The caller doesn't have to care that it's a reference type. That's automatic when they pass in the argument.

Member

@cjnolet cjnolet Mar 14, 2024

I think we should still be accepting device_structure_owning_csr_matrix here unless we intend to explicitly support both. If a user passes in a structure-preserving csr matrix, they will end up getting an error when you try to initialize the sparsity.

EDIT: Never mind. I see where you are supporting both below.

Member Author

Makes sense.

+ * @param[out] csr output raft::device_csr_matrix
  */
-template <typename bitmap_t, typename index_t, typename value_t, typename nnz_t>
+template <typename bitmap_t,
+          typename index_t,
+          typename csr_matrix_t,
+          typename = std::enable_if_t<raft::is_device_csr_matrix_v<csr_matrix_t>>>
 void bitmap_to_csr(raft::resources const& handle,
                    raft::core::bitmap_view<bitmap_t, index_t> bitmap,
-                   raft::device_csr_matrix_view<value_t, index_t, index_t, nnz_t> csr)
+                   csr_matrix_t& csr)
 {
   detail::bitmap_to_csr(handle, bitmap, csr);
 }
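
For readers unfamiliar with the pattern, the new signature constrains `csr_matrix_t` with `std::enable_if_t` so that one entry point can accept either kind of CSR matrix, and the implementation can then branch at compile time on whether the matrix owns its sparsity. A minimal host-only sketch of that shape follows. The names `owning_csr`, `preserving_csr`, `is_csr_v`, and `prepare_output` are hypothetical stand-ins invented for this illustration; they are not RAFT types or APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for a sparsity-owning matrix: it can size itself.
struct owning_csr {
  static constexpr bool is_owning = true;
  std::vector<int> indices;
  void initialize_sparsity(std::size_t nnz) { indices.assign(nnz, 0); }
};

// Hypothetical stand-in for a sparsity-preserving matrix: the caller sized it.
struct preserving_csr {
  static constexpr bool is_owning = false;
  std::vector<int> indices;
};

// Hypothetical trait: true for any type exposing a static `is_owning` member.
template <typename T, typename = void>
struct is_csr : std::false_type {};
template <typename T>
struct is_csr<T, std::void_t<decltype(T::is_owning)>> : std::true_type {};
template <typename T>
inline constexpr bool is_csr_v = is_csr<T>::value;

// One entry point for both kinds; other types are rejected at compile time.
template <typename csr_t, typename = std::enable_if_t<is_csr_v<csr_t>>>
std::size_t prepare_output(csr_t& csr, std::size_t counted_nnz)
{
  if constexpr (csr_t::is_owning) {
    // Owning: allocate exactly as many entries as were counted.
    csr.initialize_sparsity(counted_nnz);
  } else {
    // Preserving: the caller's allocation must already be large enough.
    assert(csr.indices.size() >= counted_nnz);
  }
  return csr.indices.size();
}
```

With this shape, the caller of an owning matrix does not need to know the number of nonzeros in advance, while a caller with a pre-sized matrix keeps its existing allocation.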
49 changes: 34 additions & 15 deletions cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh
@@ -144,7 +144,7 @@ RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value)
 // Threads per block in fill_indices_by_rows_kernel.
 static const constexpr int fill_indices_by_rows_tpb = 32;

-template <typename bitmap_t, typename index_t, typename nnz_t>
+template <typename bitmap_t, typename index_t, typename nnz_t, bool check_nnz>
 RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb)
   fill_indices_by_rows_kernel(const bitmap_t* bitmap,
                               const index_t* indptr,
@@ -163,7 +163,9 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb)
   // Ensure the HBM allocated for CSR values is sufficient to handle all non-zero bitmap bits.
   // An assert will trigger if the allocated HBM is insufficient when `NDEBUG` isn't defined.
   // Note: Assertion is active only if `NDEBUG` is undefined.
-  if (lane_id == 0) { assert(nnz < indptr[num_rows]); }
+  if constexpr (check_nnz) {
+    if (lane_id == 0) { assert(nnz < indptr[num_rows]); }
+  }

 #pragma unroll
   for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) {
@@ -204,7 +206,7 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb)
   }
 }

-template <typename bitmap_t, typename index_t, typename nnz_t>
+template <typename bitmap_t, typename index_t, typename nnz_t, bool check_nnz = false>
 void fill_indices_by_rows(raft::resources const& handle,
                           const bitmap_t* bitmap,
                           const index_t* indptr,
@@ -223,23 +225,26 @@ void fill_indices_by_rows(raft::resources const& handle,
   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
   cudaOccupancyMaxActiveBlocksPerMultiprocessor(
     &blocks_per_sm,
-    fill_indices_by_rows_kernel<bitmap_t, index_t, nnz_t>,
+    fill_indices_by_rows_kernel<bitmap_t, index_t, nnz_t, check_nnz>,
     fill_indices_by_rows_tpb,
     0);

   index_t max_active_blocks = sm_count * blocks_per_sm;
   auto grid = std::min(max_active_blocks, num_rows);
   auto block = fill_indices_by_rows_tpb;

-  fill_indices_by_rows_kernel<bitmap_t, index_t, nnz_t>
+  fill_indices_by_rows_kernel<bitmap_t, index_t, nnz_t, check_nnz>
     <<<grid, block, 0, stream>>>(bitmap, indptr, num_rows, num_cols, nnz, bitmap_num, indices);
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }

-template <typename bitmap_t, typename index_t, typename nnz_t, typename value_t>
+template <typename bitmap_t,
+          typename index_t,
+          typename csr_matrix_t,
+          typename = std::enable_if_t<raft::is_device_csr_matrix_v<csr_matrix_t>>>
 void bitmap_to_csr(raft::resources const& handle,
                    raft::core::bitmap_view<bitmap_t, index_t> bitmap,
-                   raft::device_csr_matrix_view<value_t, index_t, index_t, nnz_t> csr)
+                   csr_matrix_t& csr)
 {
   auto csr_view = csr.structure_view();

@@ -265,14 +270,28 @@ void bitmap_to_csr(raft::resources const& handle,

   calc_nnz_by_rows(handle, bitmap.data(), csr_view.get_n_rows(), csr_view.get_n_cols(), indptr);
   thrust::exclusive_scan(thrust_policy, indptr, indptr + csr_view.get_n_rows() + 1, indptr);
-  fill_indices_by_rows(handle,
-                       bitmap.data(),
-                       indptr,
-                       csr_view.get_n_rows(),
-                       csr_view.get_n_cols(),
-                       csr_view.get_nnz(),
-                       indices);
-  thrust::fill_n(thrust_policy, csr.get_elements().data(), csr_view.get_nnz(), value_t{1});
+
+  if constexpr (is_device_csr_sparsity_owning_v<csr_matrix_t>) {
+    index_t nnz = 0;
+    RAFT_CUDA_TRY(cudaMemcpyAsync(
+      &nnz, indptr + csr_view.get_n_rows(), sizeof(index_t), cudaMemcpyDeviceToHost, stream));
+    resource::sync_stream(handle);
+    csr.initialize_sparsity(nnz);
+  }
+  constexpr bool check_nnz = is_device_csr_sparsity_preserving_v<csr_matrix_t>;
+  fill_indices_by_rows<bitmap_t, index_t, typename csr_matrix_t::nnz_type, check_nnz>(
+    handle,
+    bitmap.data(),
+    indptr,
+    csr_view.get_n_rows(),
+    csr_view.get_n_cols(),
+    csr_view.get_nnz(),
+    indices);
+
+  thrust::fill_n(thrust_policy,
+                 csr.get_elements().data(),
+                 csr_view.get_nnz(),
+                 typename csr_matrix_t::element_t(1));
 }

 };  // end NAMESPACE detail
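
The flow in `detail::bitmap_to_csr` above (count set bits per row, exclusive scan into `indptr`, fill column indices, set every value to 1) can be sketched on the host as follows. This is an illustrative reference under stated assumptions, not the RAFT kernels: `host_csr`, `test_bit`, and `bitmap_to_csr_host` are hypothetical names, and the bit layout (row-major, least-significant bit first within each 32-bit word) is assumed for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side CSR container used only for this sketch.
struct host_csr {
  std::vector<int> indptr;    // n_rows + 1 row offsets
  std::vector<int> indices;   // column index of each nonzero
  std::vector<float> values;  // all ones for a bitmap
};

// Assumed layout: bit i lives in word i / 32 at position i % 32.
inline bool test_bit(const std::vector<std::uint32_t>& bitmap, std::size_t i)
{
  return ((bitmap[i / 32] >> (i % 32)) & 1u) != 0;
}

host_csr bitmap_to_csr_host(const std::vector<std::uint32_t>& bitmap, int n_rows, int n_cols)
{
  assert(bitmap.size() * 32 >= static_cast<std::size_t>(n_rows) * n_cols);

  host_csr csr;
  csr.indptr.assign(n_rows + 1, 0);

  // Step 1: store the number of set bits of each row in indptr[row].
  for (int r = 0; r < n_rows; ++r) {
    for (int c = 0; c < n_cols; ++c) {
      if (test_bit(bitmap, static_cast<std::size_t>(r) * n_cols + c)) { ++csr.indptr[r]; }
    }
  }

  // Step 2: in-place exclusive scan over n_rows + 1 entries,
  // so the last entry ends up holding the total nonzero count.
  int running = 0;
  for (int r = 0; r <= n_rows; ++r) {
    const int count = csr.indptr[r];
    csr.indptr[r]   = running;
    running += count;
  }

  // Step 3: write the column index of every set bit at its row offset.
  const int nnz = csr.indptr[n_rows];
  csr.indices.resize(nnz);
  for (int r = 0; r < n_rows; ++r) {
    int out = csr.indptr[r];
    for (int c = 0; c < n_cols; ++c) {
      if (test_bit(bitmap, static_cast<std::size_t>(r) * n_cols + c)) { csr.indices[out++] = c; }
    }
  }

  // Step 4: every nonzero of a bitmap has the value 1.
  csr.values.assign(nnz, 1.0f);
  return csr;
}
```

For a 2 x 3 bitmap with rows `1 0 1` and `0 1 0` (word value 21 under the assumed layout), this yields `indptr = {0, 2, 3}` and `indices = {0, 2, 1}`. Reading the final scan entry is also how the owning path in the diff obtains `nnz` before calling `initialize_sparsity`.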
7 changes: 3 additions & 4 deletions cpp/test/sparse/convert_csr.cu
@@ -366,9 +366,9 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_

     auto csr_view = raft::make_device_compressed_structure_view<index_t, index_t, index_t>(
Member

We shouldn't have to know the sparsity of the CSR when we call bitmap_to_csr. To provide a good user experience, we should accept a structure-owning CSR so we can count the number of nonzeros in the bitmap view for the user and then initialize the CSR with the sparsity. Please see the device_sparsity_owning_csr_matrix_view and the initialize_sparsity method.

       indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz);
-    auto csr = raft::make_device_csr_matrix_view<value_t, index_t>(values_d.data(), csr_view);
+    auto csr = raft::make_device_csr_matrix<value_t, index_t>(handle, csr_view);

-    convert::bitmap_to_csr<bitmap_t, index_t>(handle, bitmap, csr);
+    convert::bitmap_to_csr(handle, bitmap, csr);

     resource::sync_stream(handle);

@@ -389,9 +389,8 @@ class BitmapToCSRTest : public ::testing::TestWithParam<BitmapToCSRInputs<index_
     resource::sync_stream(handle);

     ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h));
-
     ASSERT_TRUE(raft::devArrMatch<value_t>(
-      values_d.data(), values_expected_d.data(), nnz, raft::Compare<value_t>(), stream));
+      csr.get_elements().data(), values_expected_d.data(), nnz, raft::Compare<value_t>(), stream));
   }

  protected:
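
Because `bitmap_to_csr` is documented as producing an unsorted CSR, a check like the `csr_compare` call above presumably has to ignore the order of column indices within a row. The body of `csr_compare` is not shown in this diff, so the following is only a hypothetical host-side sketch of such a comparison; `csr_equal_unsorted` is an invented name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical order-insensitive CSR structure comparison.
// The index vectors are taken by value so they can be sorted locally.
bool csr_equal_unsorted(const std::vector<int>& indptr_a,
                        std::vector<int> indices_a,
                        const std::vector<int>& indptr_b,
                        std::vector<int> indices_b)
{
  // Row offsets must match exactly, and both sides need the same nonzero count.
  if (indptr_a != indptr_b || indices_a.size() != indices_b.size()) { return false; }

  for (std::size_t r = 0; r + 1 < indptr_a.size(); ++r) {
    const int begin = indptr_a[r];
    const int end   = indptr_a[r + 1];
    // Reject malformed offsets instead of reading out of bounds.
    if (begin < 0 || end < begin || static_cast<std::size_t>(end) > indices_a.size()) {
      return false;
    }
    // Sort each row's column indices so their order no longer matters.
    std::sort(indices_a.begin() + begin, indices_a.begin() + end);
    std::sort(indices_b.begin() + begin, indices_b.begin() + end);
    if (!std::equal(indices_a.begin() + begin, indices_a.begin() + end, indices_b.begin() + begin)) {
      return false;
    }
  }
  return true;
}
```

Two structures with the same rows but differently ordered columns, such as indices `{2, 0, 1}` and `{0, 2, 1}` under `indptr = {0, 2, 3}`, compare equal with this sketch.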