From 165b27866f0ede81ee2a0c07c7f89cccb0a2fe9e Mon Sep 17 00:00:00 2001 From: rhdong Date: Sun, 21 Jan 2024 13:48:15 -0800 Subject: [PATCH 01/16] [FEA] Add support for bitmap_view & the API of `bitmap_to_csr` - This PR is one part of the Feature of [FEA] Pre-filtered brute-force KNN #1969 Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) --- cpp/bench/prims/CMakeLists.txt | 9 +- cpp/bench/prims/sparse/bitmap_to_csr.cu | 155 +++++++++++++ cpp/include/raft/core/bitmap.hpp | 112 +++++++++ cpp/include/raft/sparse/convert/csr.cuh | 31 ++- .../sparse/convert/detail/bitmap_to_csr.cuh | 151 +++++++++++++ cpp/test/sparse/convert_csr.cu | 213 +++++++++++++++++- docs/source/cpp_api/core_bitmap.rst | 15 ++ 7 files changed, 683 insertions(+), 3 deletions(-) create mode 100644 cpp/bench/prims/sparse/bitmap_to_csr.cu create mode 100644 cpp/include/raft/core/bitmap.hpp create mode 100644 cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh create mode 100644 docs/source/cpp_api/core_bitmap.rst diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 3a2431cd34..d7cb8dbd59 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -139,7 +139,14 @@ if(BUILD_PRIMS_BENCH) bench/prims/random/rng.cu bench/prims/main.cpp ) - ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) + ConfigureBench( + NAME + SPARSE_BENCH + PATH + bench/prims/sparse/bitmap_to_csr.cu + bench/prims/sparse/convert_csr.cu + bench/prims/main.cpp + ) ConfigureBench( NAME diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu new file mode 100644 index 0000000000..6bab064062 --- /dev/null +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_rows; + index_t n_cols; + float sparsity; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << " rows*cols=" << params.n_rows << "*" << params.n_cols << "\tsparsity=" << params.sparsity; + return os; +} + +template +struct BitmapToCsrTest : public fixture { + BitmapToCsrTest(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + bitmap_d(0, stream), + nnz(0), + indptr_d(0, stream), + indices_d(0, stream), + values_d(0, stream) + { + index_t element = raft::ceildiv(params.n_rows * params.n_cols, index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, bitmap_h); + + bitmap_d.resize(bitmap_h.size(), stream); + indptr_d.resize(params.n_rows + 1, stream); + indices_d.resize(nnz, stream); + values_d.resize(nnz, stream); + + update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + + resource::sync_stream(handle); + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : 
bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) >> bit_position); + num_ones--; + } + } + return res; + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto bitmap = + raft::core::bitmap_view(bitmap_d.data(), params.n_rows, params.n_cols); + + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); + + raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr); + + resource::sync_stream(handle); + loop_on_state(state, [this, &bitmap, &csr]() { + raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + + rmm::device_uvector bitmap_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + index_t nnz; +}; // struct BitmapToCsrTest + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + float sparsity; + }; + + const std::vector params_group = raft::util::itertools::product( + {index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.01f, 0.1f, 0.2f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.sparsity})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((BitmapToCsrTest), "", getInputs()); + +} // namespace raft::bench::sparse diff --git 
a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp new file mode 100644 index 0000000000..15aa7e2b40 --- /dev/null +++ b/cpp/include/raft/core/bitmap.hpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::core { +/** + * @defgroup bitmap Bitmap + * @{ + */ +/** + * @brief View of a RAFT Bitmap. + * + * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view + * with row major order. This class provides functionality for handling a matrix where each element + * is represented as a bit in a bitmap. + * + * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitmap_view : public bitset_view { + /** + * @brief Create a bitmap view from a device raw pointer. + * + * @param bitset_ptr Device raw pointer + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) + : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + { + } + + /** + * @brief Create a bitmap view from a device vector view of the bitset. + * + * @param bitmap_span Device vector view of the bitmap + * @param rows Number of row in the matrix. 
+ * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t rows, + index_t cols) + : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + { + } + + private: + // Hide the constructors of bitset_view. + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) + : bitset_view(bitmap_ptr, bitmap_len) + { + } + + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t bitmap_len) + : bitset_view(bitmap_span, bitmap_len) + { + } + + public: + /** + * @brief Device function to test if a given row and col are set in the bitmap. + * + * @param row Row index of the bit to test + * @param col Col index of the bit to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_DEVICE auto test(const index_t row, const index_t col) const -> bool + { + return test(row * cols_ + col); + } + + /** + * @brief Device function to set a given row and col to set_value in the bitset. + * + * @param row Row index of the bit to set + * @param col Col index of the bit to set + * @param new_value Value to set the bit to (true or false) + */ + inline _RAFT_DEVICE void set(const index_t row, const index_t col, bool new_value) const + { + set(row * cols_ + col, &new_value); + } + + private: + index_t rows_; + index_t cols_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 999e64cb0b..02bb46222b 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,7 +18,10 @@ #pragma once +#include +#include #include +#include #include #include @@ -102,6 +105,32 @@ void adj_to_csr(raft::resources const& handle, detail::adj_to_csr(handle, adj, row_ind, num_rows, num_cols, tmp, out_col_ind); } +/** + * @brief Converts a bitmap matrix into unsorted CSR format matrix. + * + * @tparam bitmap_t Underlying type of the bitmap. + * @tparam index_t Indexing type used. + * @tparam value_t Data type of CSR + * @tparam nnz_t Type of CSR + * + * @param[in] handle RAFT handle + * @param[in] bitmap input raft::bitmap_view + * @param[inout] csr output raft::device_csr_matrix_view + */ +template +void bitmap_to_csr(raft::resources const& handle, + raft::core::bitmap_view bitmap, + raft::device_csr_matrix_view csr) +{ + auto csr_view = csr.structure_view(); + detail::bitmap_to_csr(handle, + bitmap.data(), + csr_view.get_n_rows(), + csr_view.get_n_cols(), + csr_view.get_indptr().data(), + csr_view.get_indices().data()); +} + }; // end NAMESPACE convert }; // end NAMESPACE sparse }; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh new file mode 100644 index 0000000000..f906564ca9 --- /dev/null +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace convert { +namespace detail { + +// Threads per block in bitmap_to_any_kernel. +static const constexpr int bitmap_to_any_tpb = 512; + +template +RAFT_KERNEL __launch_bounds__(bitmap_to_any_tpb) + bitmap_to_any_kernel(const bitmap_t* bitmap, const index_t num_bits, any_t* index_array) +{ + index_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; + for (index_t idx = thread_idx; idx < num_bits; idx += blockDim.x * gridDim.x) { + bitmap_t element = bitmap[idx / (8 * sizeof(bitmap_t))]; + index_t bit_position = idx % (8 * sizeof(bitmap_t)); + index_array[idx] = static_cast((element >> bit_position) & 1); + } +} + +template +void bitmap_to_any(raft::resources const& handle, + const bitmap_t* bitmap, + const index_t num_bits, + any_t* any_array) +{ + auto stream = resource::get_cuda_stream(handle); + + int dev_id, sm_count, blocks_per_sm; + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, bitmap_to_any_kernel, bitmap_to_any_tpb, 0); + + auto grid = sm_count * blocks_per_sm; + auto block = bitmap_to_any_tpb; + + bitmap_to_any_kernel + <<>>(bitmap, num_bits, any_array); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +// Threads per block in init_row_indicator_kernel. 
+static const constexpr int init_row_indicator_tpb = 512; + +template +RAFT_KERNEL __launch_bounds__(init_row_indicator_tpb) + init_row_indicator_kernel(index_t num_cols, index_t total, index_t* row_indicator) +{ + index_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; + for (index_t idx = thread_idx; idx < total; idx += blockDim.x * gridDim.x) { + row_indicator[idx] = idx / num_cols; + } +} + +template +void init_row_indicator(raft::resources const& handle, + index_t num_cols, + index_t total, + index_t* row_indicator) +{ + auto stream = resource::get_cuda_stream(handle); + + int dev_id, sm_count, blocks_per_sm; + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, init_row_indicator_kernel, init_row_indicator_tpb, 0); + + auto grid = sm_count * blocks_per_sm; + auto block = init_row_indicator_tpb; + + init_row_indicator_kernel<<>>(num_cols, total, row_indicator); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void bitmap_to_csr(raft::resources const& handle, + const bitmap_t* bitmap, + index_t num_rows, + index_t num_cols, + index_t* indptr, + index_t* indices) +{ + const index_t total = num_rows * num_cols; + if (total == 0) { return; } + + auto thrust_policy = resource::get_thrust_policy(handle); + auto stream = resource::get_cuda_stream(handle); + + rmm::device_uvector bool_matrix(total, resource::get_cuda_stream(handle)); + rmm::device_uvector int_matrix(total, resource::get_cuda_stream(handle)); + rmm::device_uvector row_indicator(total, resource::get_cuda_stream(handle)); + + bitmap_to_any(handle, bitmap, total, bool_matrix.data()); + bitmap_to_any(handle, bitmap, total, int_matrix.data()); + + init_row_indicator(handle, num_cols, total, row_indicator.data()); + + thrust::reduce_by_key(thrust_policy, + row_indicator.data(), + row_indicator.data() + total, + int_matrix.data(), + thrust::make_discard_iterator(), + indptr + 1, + 
thrust::equal_to(), + thrust::plus()); + // compute indptr + thrust::inclusive_scan(thrust_policy, indptr, indptr + num_rows + 1, indptr); + + // compute indices + adj_to_csr(handle, bool_matrix.data(), indptr, num_rows, num_cols, row_indicator.data(), indices); +} + +}; // end NAMESPACE detail +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 8be107ef3e..90b062281b 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "../test_utils.cuh" #include +#include #include #include @@ -217,5 +218,215 @@ INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, CSRAdjGraphTestL, ::testing::ValuesIn(csradjgraph_inputs_l)); +/******************************** bitmap to csr ********************************/ + +template +struct BitmapToCSRInputs { + index_t n_rows; + index_t n_cols; + float sparsity; +}; + +template +class BitmapToCSRTest : public ::testing::TestWithParam> { + public: + BitmapToCSRTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + bitmap_d(0, stream), + indices_d(0, stream), + indptr_d(0, stream), + values_d(0, stream), + indptr_expected_d(0, stream), + indices_expected_d(0, stream) + { + } + + protected: + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) 
{ + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) >> bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitmap, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitmap_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitmap[index / (8 * sizeof(bitmap_t))]; + bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + bool csr_compare(const std::vector& row_ptrs1, + const std::vector& col_indices1, + const std::vector& row_ptrs2, + const std::vector& col_indices2) + { + if (row_ptrs1.size() != row_ptrs2.size()) { return false; } + + if (col_indices1.size() != col_indices2.size()) { return false; } + + if (!std::equal(row_ptrs1.begin(), row_ptrs1.end(), row_ptrs2.begin())) { return false; } + + for (size_t i = 0; i < row_ptrs1.size() - 1; ++i) { + size_t start_idx = row_ptrs1[i]; + size_t end_idx = row_ptrs1[i + 1]; + + std::vector cols1(col_indices1.begin() + start_idx, col_indices1.begin() + end_idx); + std::vector cols2(col_indices2.begin() + start_idx, col_indices2.begin() + end_idx); + + std::sort(cols1.begin(), cols1.end()); + std::sort(cols2.begin(), cols2.end()); + + if (cols1 != cols2) { return false; } + } + + return true; + } + + void SetUp() override + { + index_t element = raft::ceildiv(params.n_rows * params.n_cols, index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + nnz = 
create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, bitmap_h); + + std::vector indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + cpu_convert_to_csr(bitmap_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + bitmap_d.resize(bitmap_h.size(), stream); + indptr_d.resize(params.n_rows + 1, stream); + indices_d.resize(nnz, stream); + indptr_expected_d.resize(params.n_rows + 1, stream); + indices_expected_d.resize(nnz, stream); + values_d.resize(nnz, stream); + + update_device(indices_expected_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_expected_d.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto bitmap = + raft::core::bitmap_view(bitmap_d.data(), params.n_rows, params.n_cols); + + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); + + convert::bitmap_to_csr(handle, bitmap, csr); + + resource::sync_stream(handle); + + ASSERT_EQ(csr_view.get_indptr().size(), indptr_expected_d.size()); + ASSERT_EQ(csr_view.get_indices().size(), indices_expected_d.size()); + ASSERT_EQ(csr_view.get_nnz(), nnz); + + std::vector indices_h(indices_expected_d.size(), 0); + std::vector indices_expected_h(indices_expected_d.size(), 0); + update_host(indices_h.data(), csr_view.get_indices().data(), indices_h.size(), stream); + update_host(indices_expected_h.data(), indices_expected_d.data(), indices_h.size(), stream); + + std::vector indptr_h(indptr_expected_d.size(), 0); + std::vector indptr_expected_h(indptr_expected_d.size(), 0); + update_host(indptr_h.data(), csr_view.get_indptr().data(), indptr_h.size(), stream); + update_host(indptr_expected_h.data(), indptr_expected_d.data(), indptr_h.size(), stream); + + 
resource::sync_stream(handle); + + ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + BitmapToCSRInputs params; + + rmm::device_uvector bitmap_d; + + index_t nnz; + + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + rmm::device_uvector indptr_expected_d; + rmm::device_uvector indices_expected_d; +}; + +using BitmapToCSRTestI = BitmapToCSRTest; +TEST_P(BitmapToCSRTestI, Result) { Run(); } + +using BitmapToCSRTestL = BitmapToCSRTest; +TEST_P(BitmapToCSRTestL, Result) { Run(); } + +template +const std::vector> bitmaptocsr_inputs = { + {0, 0, 0.2}, + {10, 32, 0.2}, + {32, 1024, 0.4}, + {1024, 1024, 0.4}, + {64 * 1024 + 10, 2, 0.3}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3}, // No peeling-remainder + {17, 16, 0.3}, // Check peeling-remainder + {18, 16, 0.3}, // Check peeling-remainder + {32 + 9, 33, 0.2}, // Check peeling-remainder +}; + +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitmapToCSRTestI, + ::testing::ValuesIn(bitmaptocsr_inputs)); +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitmapToCSRTestL, + ::testing::ValuesIn(bitmaptocsr_inputs)); + } // namespace sparse } // namespace raft diff --git a/docs/source/cpp_api/core_bitmap.rst b/docs/source/cpp_api/core_bitmap.rst new file mode 100644 index 0000000000..6c1dc607bf --- /dev/null +++ b/docs/source/cpp_api/core_bitmap.rst @@ -0,0 +1,15 @@ +Bitmap +====== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. 
doxygengroup:: bitmap + :project: RAFT + :members: + :content-only: \ No newline at end of file From 7dea38d682999cddd7a38eca118afbeeca61d751 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 22 Jan 2024 20:13:32 -0800 Subject: [PATCH 02/16] fix doc build CI error 35fe8a --- cpp/include/raft/core/bitmap.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp index 15aa7e2b40..6d64cb0b38 100644 --- a/cpp/include/raft/core/bitmap.hpp +++ b/cpp/include/raft/core/bitmap.hpp @@ -42,7 +42,7 @@ struct bitmap_view : public bitset_view { /** * @brief Create a bitmap view from a device raw pointer. * - * @param bitset_ptr Device raw pointer + * @param bitmap_ptr Device raw pointer * @param rows Number of row in the matrix. * @param cols Number of col in the matrix. */ From 67f06505c6e512b8622d64abd2aef4802794011f Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 23 Jan 2024 18:32:03 -0800 Subject: [PATCH 03/16] try to fix the CI failure --- docs/source/cpp_api/core.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 39e57fd69a..4122a18506 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -21,4 +21,5 @@ expose in public APIs. 
core_interruptible.rst core_operators.rst core_math.rst - core_bitset.rst \ No newline at end of file + core_bitset.rst + core_bitmap.rst \ No newline at end of file From 20f76af5d656d5440dc7e213ee02b773b7ddd96e Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 20 Feb 2024 22:48:18 -0800 Subject: [PATCH 04/16] fix a ut& benchmark error --- cpp/bench/prims/sparse/bitmap_to_csr.cu | 2 +- cpp/test/sparse/convert_csr.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu index 6bab064062..9881e797bf 100644 --- a/cpp/bench/prims/sparse/bitmap_to_csr.cu +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -89,7 +89,7 @@ struct BitmapToCsrTest : public fixture { index_t bit_position = index % (8 * sizeof(bitmap_t)); if (((element >> bit_position) & 1) == 0) { - element |= (static_cast(1) >> bit_position); + element |= (static_cast(1) << bit_position); num_ones--; } } diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 90b062281b..4db988be94 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -264,7 +264,7 @@ class BitmapToCSRTest : public ::testing::TestWithParam> bit_position) & 1) == 0) { - element |= (static_cast(1) >> bit_position); + element |= (static_cast(1) << bit_position); num_ones--; } } From 7dc7cf866a41c413424b453ede2a9de1e231a148 Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 22 Feb 2024 20:49:40 -0800 Subject: [PATCH 05/16] Improve performance & eliminate the temp buffer. 
--- cpp/bench/prims/sparse/bitmap_to_csr.cu | 8 +- cpp/include/raft/core/bitmap.hpp | 12 + cpp/include/raft/sparse/convert/csr.cuh | 11 +- .../sparse/convert/detail/bitmap_to_csr.cuh | 208 +++++++++++++----- cpp/test/sparse/convert_csr.cu | 5 +- 5 files changed, 178 insertions(+), 66 deletions(-) diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu index 9881e797bf..37ab184db9 100644 --- a/cpp/bench/prims/sparse/bitmap_to_csr.cu +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -43,8 +43,8 @@ inline auto operator<<(std::ostream& os, const bench_param& params) -> } template -struct BitmapToCsrTest : public fixture { - BitmapToCsrTest(const bench_param& p) +struct BitmapToCsrBench : public fixture { + BitmapToCsrBench(const bench_param& p) : fixture(true), params(p), handle(stream), @@ -128,7 +128,7 @@ struct BitmapToCsrTest : public fixture { rmm::device_uvector values_d; index_t nnz; -}; // struct BitmapToCsrTest +}; // struct BitmapToCsrBench template const std::vector> getInputs() @@ -150,6 +150,6 @@ const std::vector> getInputs() return param_vec; } -RAFT_BENCH_REGISTER((BitmapToCsrTest), "", getInputs()); +RAFT_BENCH_REGISTER((BitmapToCsrBench), "", getInputs()); } // namespace raft::bench::sparse diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.hpp index 6d64cb0b38..94c8e400bf 100644 --- a/cpp/include/raft/core/bitmap.hpp +++ b/cpp/include/raft/core/bitmap.hpp @@ -103,6 +103,18 @@ struct bitmap_view : public bitset_view { set(row * cols_ + col, &new_value); } + /** + * @brief Get the total number of rows + * @return index_t The total number of rows + */ + inline index_t get_n_rows() const { return rows_; } + + /** + * @brief Get the total number of columns + * @return index_t The total number of columns + */ + inline index_t get_n_cols() const { return cols_; } + private: index_t rows_; index_t cols_; diff --git a/cpp/include/raft/sparse/convert/csr.cuh 
b/cpp/include/raft/sparse/convert/csr.cuh index 02bb46222b..87cef3218a 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -115,7 +115,7 @@ void adj_to_csr(raft::resources const& handle, * * @param[in] handle RAFT handle * @param[in] bitmap input raft::bitmap_view - * @param[inout] csr output raft::device_csr_matrix_view + * @param[out] csr output raft::device_csr_matrix_view */ template void bitmap_to_csr(raft::resources const& handle, @@ -123,6 +123,15 @@ void bitmap_to_csr(raft::resources const& handle, raft::device_csr_matrix_view csr) { auto csr_view = csr.structure_view(); + + RAFT_EXPECTS(bitmap.get_n_rows() == csr_view.get_n_rows(), + "Number of rows in bitmap must be equal to " + "number of rows in csr"); + + RAFT_EXPECTS(bitmap.get_n_cols() == csr_view.get_n_cols(), + "Number of columns in bitmap must be equal to " + "number of columns in csr"); + detail::bitmap_to_csr(handle, bitmap.data(), csr_view.get_n_rows(), diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index f906564ca9..e05ea026fe 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -18,6 +18,7 @@ #include +#include // detail::popc #include #include #include @@ -36,74 +37,178 @@ namespace sparse { namespace convert { namespace detail { -// Threads per block in bitmap_to_any_kernel. -static const constexpr int bitmap_to_any_tpb = 512; +// Threads per block in calc_nnz_by_rows_kernel. 
+static const constexpr int calc_nnz_by_rows_tpb = 32; -template -RAFT_KERNEL __launch_bounds__(bitmap_to_any_tpb) - bitmap_to_any_kernel(const bitmap_t* bitmap, const index_t num_bits, any_t* index_array) +template +RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(const bitmap_t* bitmap, + index_t num_rows, + index_t num_cols, + index_t bitmap_num, + nnz_t* nnz_per_row) { - index_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; - for (index_t idx = thread_idx; idx < num_bits; idx += blockDim.x * gridDim.x) { - bitmap_t element = bitmap[idx / (8 * sizeof(bitmap_t))]; - index_t bit_position = idx % (8 * sizeof(bitmap_t)); - index_array[idx] = static_cast((element >> bit_position) & 1); + constexpr bitmap_t FULL_MASK = ~bitmap_t(0u); + constexpr bitmap_t ONE = bitmap_t(1u); + constexpr index_t BITS_PER_BITMAP = sizeof(bitmap_t) * 8; + + int lane_id = threadIdx.x & 0x1f; + + for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { + index_t offset = 0; + index_t s_bit = row * num_cols; + index_t e_bit = s_bit + num_cols; + auto l_sum = 0; + + while (offset < num_cols) { + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + bitmap_t l_bitmap = bitmap_t(0); + + if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } + + if (s_bit > bitmap_idx * BITS_PER_BITMAP) { + l_bitmap >>= (s_bit - bitmap_idx * BITS_PER_BITMAP); + l_bitmap <<= (s_bit - bitmap_idx * BITS_PER_BITMAP); + } + + if ((bitmap_idx + 1) * BITS_PER_BITMAP > e_bit) { + l_bitmap <<= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + } + + l_sum += static_cast(raft::detail::popc(l_bitmap)); + offset += BITS_PER_BITMAP * warpSize; + } + + l_sum = __reduce_add_sync(0xffffffff, l_sum); + + if (lane_id == 0) { *(nnz_per_row + row) += static_cast(l_sum); } } } -template -void bitmap_to_any(raft::resources const& handle, - const bitmap_t* bitmap, - const index_t num_bits, - any_t* 
any_array) +template +void calc_nnz_by_rows(raft::resources const& handle, + const bitmap_t* bitmap, + index_t num_rows, + index_t num_cols, + nnz_t* nnz_per_row) { - auto stream = resource::get_cuda_stream(handle); + auto stream = resource::get_cuda_stream(handle); + const index_t total = num_rows * num_cols; + const index_t bitmap_num = raft::ceildiv(total, index_t(sizeof(bitmap_t) * 8)); int dev_id, sm_count, blocks_per_sm; + cudaGetDevice(&dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, bitmap_to_any_kernel, bitmap_to_any_tpb, 0); + &blocks_per_sm, calc_nnz_by_rows_kernel, calc_nnz_by_rows_tpb, 0); - auto grid = sm_count * blocks_per_sm; - auto block = bitmap_to_any_tpb; + index_t max_active_blocks = sm_count * blocks_per_sm; + auto grid = std::min(max_active_blocks, raft::ceildiv(bitmap_num, index_t(calc_nnz_by_rows_tpb))); + auto block = calc_nnz_by_rows_tpb; - bitmap_to_any_kernel - <<>>(bitmap, num_bits, any_array); + calc_nnz_by_rows_kernel + <<>>(bitmap, num_rows, num_cols, bitmap_num, nnz_per_row); RAFT_CUDA_TRY(cudaPeekAtLastError()); } -// Threads per block in init_row_indicator_kernel. -static const constexpr int init_row_indicator_tpb = 512; +template +__device__ inline value_t warp_exclusive(value_t value) +{ + int lane_id = threadIdx.x & 0x1f; + value_t shifted_value = __shfl_up_sync(0xffffffff, value, 1, warpSize); + if (lane_id == 0) shifted_value = 0; -template -RAFT_KERNEL __launch_bounds__(init_row_indicator_tpb) - init_row_indicator_kernel(index_t num_cols, index_t total, index_t* row_indicator) + value_t sum = shifted_value; + + for (int i = 1; i < warpSize; i *= 2) { + value_t n = __shfl_up_sync(0xffffffff, sum, i, warpSize); + if (lane_id >= i) { sum += n; } + } + return sum; +} + +// Threads per block in fill_indices_by_rows_kernel. 
+static const constexpr int fill_indices_by_rows_tpb = 32; + +template +RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) + fill_indices_by_rows_kernel(const bitmap_t* bitmap, + const index_t* indptr, + index_t num_rows, + index_t num_cols, + index_t bitmap_num, + index_t* indices) { - index_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; - for (index_t idx = thread_idx; idx < total; idx += blockDim.x * gridDim.x) { - row_indicator[idx] = idx / num_cols; + constexpr bitmap_t FULL_MASK = ~bitmap_t(0u); + constexpr bitmap_t ONE = bitmap_t(1u); + constexpr index_t BITS_PER_BITMAP = sizeof(bitmap_t) * 8; + + int lane_id = threadIdx.x & 0x1f; + + for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { + index_t offset = 0; + index_t g_sum = 0; + index_t s_bit = row * num_cols; + index_t e_bit = s_bit + num_cols; + index_t indptr_row = indptr[row]; + + while (offset < num_cols) { + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + bitmap_t l_bitmap = bitmap_t(0); + index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); + + if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } + + if (s_bit > bitmap_idx * BITS_PER_BITMAP) { + l_bitmap >>= (s_bit - bitmap_idx * BITS_PER_BITMAP); + l_bitmap <<= (s_bit - bitmap_idx * BITS_PER_BITMAP); + } + + if ((bitmap_idx + 1) * BITS_PER_BITMAP > e_bit) { + l_bitmap <<= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + } + + index_t l_sum = g_sum + warp_exclusive(static_cast(raft::detail::popc(l_bitmap))); + + for (int i = 0; i < BITS_PER_BITMAP; i++) { + if (l_bitmap & (ONE << i)) { + indices[indptr_row + l_sum] = l_offset + i; + l_sum++; + } + } + offset += BITS_PER_BITMAP * warpSize; + g_sum = __shfl_sync(0xffffffff, l_sum, warpSize - 1); + } } } -template -void init_row_indicator(raft::resources const& handle, - index_t num_cols, - index_t total, - index_t* row_indicator) +template 
+void fill_indices_by_rows(raft::resources const& handle, + const bitmap_t* bitmap, + const index_t* indptr, + index_t num_rows, + index_t num_cols, + index_t* indices) { - auto stream = resource::get_cuda_stream(handle); + auto stream = resource::get_cuda_stream(handle); + const index_t total = num_rows * num_cols; + const index_t bitmap_num = raft::ceildiv(total, index_t(sizeof(bitmap_t) * 8)); int dev_id, sm_count, blocks_per_sm; + cudaGetDevice(&dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, init_row_indicator_kernel, init_row_indicator_tpb, 0); + &blocks_per_sm, fill_indices_by_rows_kernel, fill_indices_by_rows_tpb, 0); - auto grid = sm_count * blocks_per_sm; - auto block = init_row_indicator_tpb; + index_t max_active_blocks = sm_count * blocks_per_sm; + auto grid = std::min(max_active_blocks, num_rows); + auto block = fill_indices_by_rows_tpb; - init_row_indicator_kernel<<>>(num_cols, total, row_indicator); + fill_indices_by_rows_kernel + <<>>(bitmap, indptr, num_rows, num_cols, bitmap_num, indices); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -121,28 +226,11 @@ void bitmap_to_csr(raft::resources const& handle, auto thrust_policy = resource::get_thrust_policy(handle); auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector bool_matrix(total, resource::get_cuda_stream(handle)); - rmm::device_uvector int_matrix(total, resource::get_cuda_stream(handle)); - rmm::device_uvector row_indicator(total, resource::get_cuda_stream(handle)); - - bitmap_to_any(handle, bitmap, total, bool_matrix.data()); - bitmap_to_any(handle, bitmap, total, int_matrix.data()); - - init_row_indicator(handle, num_cols, total, row_indicator.data()); - - thrust::reduce_by_key(thrust_policy, - row_indicator.data(), - row_indicator.data() + total, - int_matrix.data(), - thrust::make_discard_iterator(), - indptr + 1, - thrust::equal_to(), - thrust::plus()); - // compute indptr - 
thrust::inclusive_scan(thrust_policy, indptr, indptr + num_rows + 1, indptr); + RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (num_rows + 1) * sizeof(index_t), stream)); - // compute indices - adj_to_csr(handle, bool_matrix.data(), indptr, num_rows, num_cols, row_indicator.data(), indices); + calc_nnz_by_rows(handle, bitmap, num_rows, num_cols, indptr); + thrust::exclusive_scan(thrust_policy, indptr, indptr + num_rows + 1, indptr); + fill_indices_by_rows(handle, bitmap, indptr, num_rows, num_cols, indices); } }; // end NAMESPACE detail diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 4db988be94..8719542c94 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -411,14 +411,17 @@ TEST_P(BitmapToCSRTestL, Result) { Run(); } template const std::vector> bitmaptocsr_inputs = { {0, 0, 0.2}, - {10, 32, 0.2}, + {10, 32, 0.4}, + {10, 3, 0.2}, {32, 1024, 0.4}, + {1024, 1048576, 0.01}, {1024, 1024, 0.4}, {64 * 1024 + 10, 2, 0.3}, // 64K + 10 is slightly over maximum of blockDim.y {16, 16, 0.3}, // No peeling-remainder {17, 16, 0.3}, // Check peeling-remainder {18, 16, 0.3}, // Check peeling-remainder {32 + 9, 33, 0.2}, // Check peeling-remainder + {2, 33, 0.2}, // Check peeling-remainder }; INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, From cab8691f39e98601c315f3cc7dfd5f9c5bab7322 Mon Sep 17 00:00:00 2001 From: rhdong Date: Fri, 23 Feb 2024 16:35:19 -0800 Subject: [PATCH 06/16] fix : compatible with devices with compute capability < 8.0 --- .../raft/sparse/convert/detail/bitmap_to_csr.cuh | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index e05ea026fe..d2f4a9b59d 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include // detail::popc 
#include @@ -32,6 +33,8 @@ #include #include +namespace cg = cooperative_groups; + namespace raft { namespace sparse { namespace convert { @@ -51,13 +54,16 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons constexpr bitmap_t ONE = bitmap_t(1u); constexpr index_t BITS_PER_BITMAP = sizeof(bitmap_t) * 8; + auto block = cg::this_thread_block(); + auto tile = cg::tiled_partition<32>(block); + int lane_id = threadIdx.x & 0x1f; for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { index_t offset = 0; index_t s_bit = row * num_cols; index_t e_bit = s_bit + num_cols; - auto l_sum = 0; + index_t l_sum = 0; while (offset < num_cols) { index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; @@ -79,7 +85,7 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons offset += BITS_PER_BITMAP * warpSize; } - l_sum = __reduce_add_sync(0xffffffff, l_sum); + l_sum = cg::reduce(tile, l_sum, cg::plus()); if (lane_id == 0) { *(nnz_per_row + row) += static_cast(l_sum); } } @@ -146,14 +152,15 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) int lane_id = threadIdx.x & 0x1f; +#pragma unroll for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { - index_t offset = 0; index_t g_sum = 0; index_t s_bit = row * num_cols; index_t e_bit = s_bit + num_cols; index_t indptr_row = indptr[row]; - while (offset < num_cols) { +#pragma unroll + for (index_t offset = 0; offset < num_cols; offset += BITS_PER_BITMAP * warpSize) { index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; bitmap_t l_bitmap = bitmap_t(0); index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); @@ -178,7 +185,6 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) l_sum++; } } - offset += BITS_PER_BITMAP * warpSize; g_sum = __shfl_sync(0xffffffff, l_sum, warpSize - 1); } } From f015ed5e3e97be4d6e6b80659a2900b73d7aca7a Mon Sep 17 00:00:00 2001 From: hrong Date: Tue, 5 Mar 2024 14:06:19 
-0800 Subject: [PATCH 07/16] Optimize based on review comments --- .../raft/core/{bitmap.hpp => bitmap.cuh} | 76 ++++++++++++++++++- cpp/include/raft/sparse/convert/csr.cuh | 10 ++- .../sparse/convert/detail/bitmap_to_csr.cuh | 20 +++-- cpp/test/sparse/convert_csr.cu | 22 ++++-- 4 files changed, 111 insertions(+), 17 deletions(-) rename cpp/include/raft/core/{bitmap.hpp => bitmap.cuh} (56%) diff --git a/cpp/include/raft/core/bitmap.hpp b/cpp/include/raft/core/bitmap.cuh similarity index 56% rename from cpp/include/raft/core/bitmap.hpp rename to cpp/include/raft/core/bitmap.cuh index 94c8e400bf..cdf8ddd4db 100644 --- a/cpp/include/raft/core/bitmap.hpp +++ b/cpp/include/raft/core/bitmap.cuh @@ -39,6 +39,10 @@ namespace raft::core { */ template struct bitmap_view : public bitset_view { + static constexpr index_t bitmap_element_size = sizeof(bitmap_t) * 8; + // static_assert((std::is_same::value || + // std::is_same::value), + // "The bitmap_t must be uint32_t or uint64_t."); /** * @brief Create a bitmap view from a device raw pointer. * @@ -47,7 +51,10 @@ struct bitmap_view : public bitset_view { * @param cols Number of col in the matrix. 
*/ _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) - : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + : bitset_view(bitmap_ptr, rows * cols), + bitmap_ptr_{bitmap_ptr}, + rows_(rows), + cols_(cols) { } @@ -61,7 +68,10 @@ struct bitmap_view : public bitset_view { _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, index_t rows, index_t cols) - : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + : bitset_view(bitmap_span, rows * cols), + bitmap_ptr_{bitmap_span.data_handle()}, + rows_(rows), + cols_(cols) { } @@ -107,17 +117,75 @@ struct bitmap_view : public bitset_view { * @brief Get the total number of rows * @return index_t The total number of rows */ - inline index_t get_n_rows() const { return rows_; } + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } /** * @brief Get the total number of columns * @return index_t The total number of columns */ - inline index_t get_n_cols() const { return cols_; } + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } + + /** + * @brief Returns the number of non-zero bits in nnz_gpu_scalar. + * + * @param[in] res RAFT resources + * @param[out] nnz_gpu_scalar Device scalar to store the nnz + */ + void get_nnz(const raft::resources& res, raft::device_scalar_view nnz_gpu_scalar) + { + auto n_elements_ = raft::ceildiv(rows_ * cols_, bitmap_element_size); + auto nnz_gpu = raft::make_device_vector_view(nnz_gpu_scalar.data_handle(), 1); + auto bitmap_matrix_view = raft::make_device_matrix_view( + bitmap_ptr_, n_elements_, 1); + + bitmap_t n_last_element = ((rows_ * cols_) % bitmap_element_size); + bitmap_t last_element_mask = + n_last_element ? 
(bitmap_t)((bitmap_t{1} << n_last_element) - bitmap_t{1}) : ~bitmap_t{0}; + raft::linalg::coalesced_reduction( + res, + bitmap_matrix_view, + nnz_gpu, + index_t{0}, + false, + [last_element_mask, n_elements_] __device__(bitmap_t element, index_t index) { + index_t result = 0; + if constexpr (bitmap_element_size == 64) { + if (index == n_elements_ - 1) + result = index_t(raft::detail::popc(element & last_element_mask)); + else + result = index_t(raft::detail::popc(element)); + } else { // Needed because popc is not overloaded for 16 and 8 bit elements + if (index == n_elements_ - 1) + result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); + else + result = index_t(raft::detail::popc(uint32_t{element})); + } + + return result; + }); + } + + /** + * @brief Returns the number of non-zero bits. + * + * @param res RAFT resources + * @return index_t Number of non-zero bits + */ + auto get_nnz(const raft::resources& res) -> index_t + { + auto nnz_gpu_scalar = raft::make_device_scalar(res, 0.0); + get_nnz(res, nnz_gpu_scalar.view()); + index_t nnz_gpu = 0; + raft::update_host(&nnz_gpu, nnz_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + return nnz_gpu; + } private: index_t rows_; index_t cols_; + + bitmap_t* bitmap_ptr_; }; /** @} */ diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 87cef3218a..cb84e5c8dd 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -18,7 +18,7 @@ #pragma once -#include +#include #include #include #include @@ -132,12 +132,18 @@ void bitmap_to_csr(raft::resources const& handle, "Number of columns in bitmap must be equal to " "number of columns in csr"); + RAFT_EXPECTS(csr_view.get_nnz() >= bitmap.get_nnz(handle), + "Number of elements in csr must be equal or larger than " + "number of non-zero bits in bitmap"); + detail::bitmap_to_csr(handle, bitmap.data(), csr_view.get_n_rows(), 
csr_view.get_n_cols(), + csr_view.get_nnz(), csr_view.get_indptr().data(), - csr_view.get_indices().data()); + csr_view.get_indices().data(), + csr.get_elements().data()); } }; // end NAMESPACE convert diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index d2f4a9b59d..21b31d7717 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -118,8 +118,14 @@ void calc_nnz_by_rows(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } +/* + Execute the exclusive_scan within one warp with no inter-warp communication. + This function calculates the exclusive prefix sum of `value` across threads within the same warp. + Each thread in the warp will end up with the sum of all the values of the threads with lower IDs + in the same warp, with the first thread always getting a sum of 0. +*/ template -__device__ inline value_t warp_exclusive(value_t value) +RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value) { int lane_id = threadIdx.x & 0x1f; value_t shifted_value = __shfl_up_sync(0xffffffff, value, 1, warpSize); @@ -177,7 +183,8 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); } - index_t l_sum = g_sum + warp_exclusive(static_cast(raft::detail::popc(l_bitmap))); + index_t l_sum = + g_sum + warp_exclusive_scan(static_cast(raft::detail::popc(l_bitmap))); for (int i = 0; i < BITS_PER_BITMAP; i++) { if (l_bitmap & (ONE << i)) { @@ -218,16 +225,18 @@ void fill_indices_by_rows(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template +template void bitmap_to_csr(raft::resources const& handle, const bitmap_t* bitmap, index_t num_rows, index_t num_cols, + nnz_t nnz, index_t* indptr, - index_t* indices) + index_t* indices, + value_t* values) { const index_t total = num_rows * num_cols; - if (total == 0) { 
return; } + if (total == 0 || nnz == 0) { return; } auto thrust_policy = resource::get_thrust_policy(handle); auto stream = resource::get_cuda_stream(handle); @@ -237,6 +246,7 @@ void bitmap_to_csr(raft::resources const& handle, calc_nnz_by_rows(handle, bitmap, num_rows, num_cols, indptr); thrust::exclusive_scan(thrust_policy, indptr, indptr + num_rows + 1, indptr); fill_indices_by_rows(handle, bitmap, indptr, num_rows, num_cols, indices); + thrust::fill_n(thrust_policy, values, nnz, value_t{1}); } }; // end NAMESPACE detail diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 8719542c94..c9545801a4 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -16,7 +16,7 @@ #include "../test_utils.cuh" #include -#include +#include #include #include @@ -227,7 +227,7 @@ struct BitmapToCSRInputs { float sparsity; }; -template +template class BitmapToCSRTest : public ::testing::TestWithParam> { public: BitmapToCSRTest() @@ -238,7 +238,8 @@ class BitmapToCSRTest : public ::testing::TestWithParam( indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); - auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); + auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); convert::bitmap_to_csr(handle, bitmap, csr); @@ -382,6 +388,9 @@ class BitmapToCSRTest : public ::testing::TestWithParam( + values_d.data(), values_expected_d.data(), nnz, raft::Compare(), stream)); } protected: @@ -400,12 +409,13 @@ class BitmapToCSRTest : public ::testing::TestWithParam indptr_expected_d; rmm::device_uvector indices_expected_d; + rmm::device_uvector values_expected_d; }; -using BitmapToCSRTestI = BitmapToCSRTest; +using BitmapToCSRTestI = BitmapToCSRTest; TEST_P(BitmapToCSRTestI, Result) { Run(); } -using BitmapToCSRTestL = BitmapToCSRTest; +using BitmapToCSRTestL = BitmapToCSRTest; TEST_P(BitmapToCSRTestL, Result) { Run(); } template From 308eb8ba363393ec2b89870afa2bbeade0bdb7ca Mon Sep 17 
00:00:00 2001 From: hrong Date: Tue, 5 Mar 2024 14:14:53 -0800 Subject: [PATCH 08/16] Ensure the `bitmap_t` to be {uint64_t, uint32_t} - Add test case for `uint64_t` bitmap_t --- cpp/bench/prims/sparse/bitmap_to_csr.cu | 1 + cpp/include/raft/core/bitmap.cuh | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu index 37ab184db9..1736bf941e 100644 --- a/cpp/bench/prims/sparse/bitmap_to_csr.cu +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -151,5 +151,6 @@ const std::vector> getInputs() } RAFT_BENCH_REGISTER((BitmapToCsrBench), "", getInputs()); +RAFT_BENCH_REGISTER((BitmapToCsrBench), "", getInputs()); } // namespace raft::bench::sparse diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index cdf8ddd4db..fff86b0c05 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -40,9 +40,9 @@ namespace raft::core { template struct bitmap_view : public bitset_view { static constexpr index_t bitmap_element_size = sizeof(bitmap_t) * 8; - // static_assert((std::is_same::value || - // std::is_same::value), - // "The bitmap_t must be uint32_t or uint64_t."); + static_assert((std::is_same::value || + std::is_same::value), + "The bitmap_t must be uint32_t or uint64_t."); /** * @brief Create a bitmap view from a device raw pointer. 
* From 42d53559e52328aebfaa7267b230b590dfb3eb56 Mon Sep 17 00:00:00 2001 From: hrong Date: Wed, 6 Mar 2024 21:04:44 -0800 Subject: [PATCH 09/16] Optimize based on review comments-3rd round --- cpp/bench/prims/sparse/bitmap_to_csr.cu | 8 +- cpp/include/raft/core/bitmap.cuh | 62 +-------------- cpp/include/raft/sparse/convert/csr.cuh | 23 +----- .../sparse/convert/detail/bitmap_to_csr.cuh | 75 +++++++++++++------ 4 files changed, 59 insertions(+), 109 deletions(-) diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu index 1736bf941e..b9c7638fcc 100644 --- a/cpp/bench/prims/sparse/bitmap_to_csr.cu +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -14,14 +14,14 @@ * limitations under the License. */ #include -#include -#include #include -#include - #include #include +#include +#include + +#include #include #include diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index fff86b0c05..f43879c8ba 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -117,75 +117,17 @@ struct bitmap_view : public bitset_view { * @brief Get the total number of rows * @return index_t The total number of rows */ - inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } + RAFT_INLINE_FUNCTION index_t get_n_rows() const { return rows_; } /** * @brief Get the total number of columns * @return index_t The total number of columns */ - inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } - - /** - * @brief Returns the number of non-zero bits in nnz_gpu_scalar. 
- * - * @param[in] res RAFT resources - * @param[out] nnz_gpu_scalar Device scalar to store the nnz - */ - void get_nnz(const raft::resources& res, raft::device_scalar_view nnz_gpu_scalar) - { - auto n_elements_ = raft::ceildiv(rows_ * cols_, bitmap_element_size); - auto nnz_gpu = raft::make_device_vector_view(nnz_gpu_scalar.data_handle(), 1); - auto bitmap_matrix_view = raft::make_device_matrix_view( - bitmap_ptr_, n_elements_, 1); - - bitmap_t n_last_element = ((rows_ * cols_) % bitmap_element_size); - bitmap_t last_element_mask = - n_last_element ? (bitmap_t)((bitmap_t{1} << n_last_element) - bitmap_t{1}) : ~bitmap_t{0}; - raft::linalg::coalesced_reduction( - res, - bitmap_matrix_view, - nnz_gpu, - index_t{0}, - false, - [last_element_mask, n_elements_] __device__(bitmap_t element, index_t index) { - index_t result = 0; - if constexpr (bitmap_element_size == 64) { - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(element & last_element_mask)); - else - result = index_t(raft::detail::popc(element)); - } else { // Needed because popc is not overloaded for 16 and 8 bit elements - if (index == n_elements_ - 1) - result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); - else - result = index_t(raft::detail::popc(uint32_t{element})); - } - - return result; - }); - } - - /** - * @brief Returns the number of non-zero bits. 
- * - * @param res RAFT resources - * @return index_t Number of non-zero bits - */ - auto get_nnz(const raft::resources& res) -> index_t - { - auto nnz_gpu_scalar = raft::make_device_scalar(res, 0.0); - get_nnz(res, nnz_gpu_scalar.view()); - index_t nnz_gpu = 0; - raft::update_host(&nnz_gpu, nnz_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); - return nnz_gpu; - } + RAFT_INLINE_FUNCTION index_t get_n_cols() const { return cols_; } private: index_t rows_; index_t cols_; - - bitmap_t* bitmap_ptr_; }; /** @} */ diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index cb84e5c8dd..424e6eee12 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -122,28 +122,7 @@ void bitmap_to_csr(raft::resources const& handle, raft::core::bitmap_view bitmap, raft::device_csr_matrix_view csr) { - auto csr_view = csr.structure_view(); - - RAFT_EXPECTS(bitmap.get_n_rows() == csr_view.get_n_rows(), - "Number of rows in bitmap must be equal to " - "number of rows in csr"); - - RAFT_EXPECTS(bitmap.get_n_cols() == csr_view.get_n_cols(), - "Number of columns in bitmap must be equal to " - "number of columns in csr"); - - RAFT_EXPECTS(csr_view.get_nnz() >= bitmap.get_nnz(handle), - "Number of elements in csr must be equal or larger than " - "number of non-zero bits in bitmap"); - - detail::bitmap_to_csr(handle, - bitmap.data(), - csr_view.get_n_rows(), - csr_view.get_n_cols(), - csr_view.get_nnz(), - csr_view.get_indptr().data(), - csr_view.get_indices().data(), - csr.get_elements().data()); + detail::bitmap_to_csr(handle, bitmap, csr); } }; // end NAMESPACE convert diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index 21b31d7717..bc07d853da 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -16,9 
+16,6 @@ #pragma once -#include -#include - #include // detail::popc #include #include @@ -27,12 +24,17 @@ #include +#include +#include #include #include #include #include #include +#undef NDEBUG +#include + namespace cg = cooperative_groups; namespace raft { @@ -143,12 +145,13 @@ RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value) // Threads per block in fill_indices_by_rows_kernel. static const constexpr int fill_indices_by_rows_tpb = 32; -template +template RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) fill_indices_by_rows_kernel(const bitmap_t* bitmap, const index_t* indptr, index_t num_rows, index_t num_cols, + nnz_t nnz, index_t bitmap_num, index_t* indices) { @@ -158,6 +161,13 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) int lane_id = threadIdx.x & 0x1f; + // Ensure the HBM allocated for CSR values is sufficient to handle all non-zero bitmap bits. + // An assert will trigger if the allocated HBM is insufficient. + if (nnz < indptr[num_rows]) { + int csr_nnz_is_too_small = 0; + assert(csr_nnz_is_too_small); + } + #pragma unroll for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { index_t g_sum = 0; @@ -197,12 +207,13 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) } } -template +template void fill_indices_by_rows(raft::resources const& handle, const bitmap_t* bitmap, const index_t* indptr, index_t num_rows, index_t num_cols, + nnz_t nnz, index_t* indices) { auto stream = resource::get_cuda_stream(handle); @@ -214,39 +225,57 @@ void fill_indices_by_rows(raft::resources const& handle, cudaGetDevice(&dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, fill_indices_by_rows_kernel, fill_indices_by_rows_tpb, 0); + &blocks_per_sm, + fill_indices_by_rows_kernel, + fill_indices_by_rows_tpb, + 0); index_t max_active_blocks = sm_count * blocks_per_sm; auto grid = std::min(max_active_blocks, num_rows); auto 
block = fill_indices_by_rows_tpb; - fill_indices_by_rows_kernel - <<>>(bitmap, indptr, num_rows, num_cols, bitmap_num, indices); + fill_indices_by_rows_kernel + <<>>(bitmap, indptr, num_rows, num_cols, nnz, bitmap_num, indices); RAFT_CUDA_TRY(cudaPeekAtLastError()); } template void bitmap_to_csr(raft::resources const& handle, - const bitmap_t* bitmap, - index_t num_rows, - index_t num_cols, - nnz_t nnz, - index_t* indptr, - index_t* indices, - value_t* values) + raft::core::bitmap_view bitmap, + raft::device_csr_matrix_view csr) { - const index_t total = num_rows * num_cols; - if (total == 0 || nnz == 0) { return; } + auto csr_view = csr.structure_view(); + + if (csr_view.get_n_rows() == 0 || csr_view.get_n_cols() == 0 || csr_view.get_nnz() == 0) { + return; + } + + RAFT_EXPECTS(bitmap.get_n_rows() == csr_view.get_n_rows(), + "Number of rows in bitmap must be equal to " + "number of rows in csr"); + + RAFT_EXPECTS(bitmap.get_n_cols() == csr_view.get_n_cols(), + "Number of columns in bitmap must be equal to " + "number of columns in csr"); auto thrust_policy = resource::get_thrust_policy(handle); auto stream = resource::get_cuda_stream(handle); - RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (num_rows + 1) * sizeof(index_t), stream)); - - calc_nnz_by_rows(handle, bitmap, num_rows, num_cols, indptr); - thrust::exclusive_scan(thrust_policy, indptr, indptr + num_rows + 1, indptr); - fill_indices_by_rows(handle, bitmap, indptr, num_rows, num_cols, indices); - thrust::fill_n(thrust_policy, values, nnz, value_t{1}); + index_t* indptr = csr_view.get_indptr().data(); + index_t* indices = csr_view.get_indices().data(); + + RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream)); + + calc_nnz_by_rows(handle, bitmap.data(), csr_view.get_n_rows(), csr_view.get_n_cols(), indptr); + thrust::exclusive_scan(thrust_policy, indptr, indptr + csr_view.get_n_rows() + 1, indptr); + fill_indices_by_rows(handle, + bitmap.data(), + indptr, + 
csr_view.get_n_rows(), + csr_view.get_n_cols(), + csr_view.get_nnz(), + indices); + thrust::fill_n(thrust_policy, csr.get_elements().data(), csr_view.get_nnz(), value_t{1}); } }; // end NAMESPACE detail From 1aff844a186b14df29858663d0e8b5b6ab36edcf Mon Sep 17 00:00:00 2001 From: hrong Date: Wed, 6 Mar 2024 14:34:33 -0800 Subject: [PATCH 10/16] [Fix] `std::vector` compilation error of in the `nvtx.hpp` --- cpp/include/raft/core/detail/nvtx.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/core/detail/nvtx.hpp b/cpp/include/raft/core/detail/nvtx.hpp index 8afd1f16c6..82db75de84 100644 --- a/cpp/include/raft/core/detail/nvtx.hpp +++ b/cpp/include/raft/core/detail/nvtx.hpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace raft::common::nvtx::detail { From 59533e0b49337762c68a8857a9b5d69590633de5 Mon Sep 17 00:00:00 2001 From: hrong Date: Wed, 6 Mar 2024 21:32:18 -0800 Subject: [PATCH 11/16] Remove `#undef NDEBUG` --- cpp/include/raft/core/bitmap.cuh | 11 ++--------- .../raft/sparse/convert/detail/bitmap_to_csr.cuh | 9 +++------ 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index f43879c8ba..da009ef4cc 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -39,7 +39,6 @@ namespace raft::core { */ template struct bitmap_view : public bitset_view { - static constexpr index_t bitmap_element_size = sizeof(bitmap_t) * 8; static_assert((std::is_same::value || std::is_same::value), "The bitmap_t must be uint32_t or uint64_t."); @@ -51,10 +50,7 @@ struct bitmap_view : public bitset_view { * @param cols Number of col in the matrix. 
*/ _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) - : bitset_view(bitmap_ptr, rows * cols), - bitmap_ptr_{bitmap_ptr}, - rows_(rows), - cols_(cols) + : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) { } @@ -68,10 +64,7 @@ struct bitmap_view : public bitset_view { _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, index_t rows, index_t cols) - : bitset_view(bitmap_span, rows * cols), - bitmap_ptr_{bitmap_span.data_handle()}, - rows_(rows), - cols_(cols) + : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) { } diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index bc07d853da..e706e11b32 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -32,7 +32,6 @@ #include #include -#undef NDEBUG #include namespace cg = cooperative_groups; @@ -162,11 +161,9 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) int lane_id = threadIdx.x & 0x1f; // Ensure the HBM allocated for CSR values is sufficient to handle all non-zero bitmap bits. - // An assert will trigger if the allocated HBM is insufficient. - if (nnz < indptr[num_rows]) { - int csr_nnz_is_too_small = 0; - assert(csr_nnz_is_too_small); - } + // An assert will trigger if the allocated HBM is insufficient when `NDEBUG` isn't defined. + // Note: Assertion is active only if `NDEBUG` is undefined. 
+ if (lane_id == 0) { assert(nnz < indptr[num_rows]); } #pragma unroll for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { From 7d960bd97ca378562cbc7094e243507828cf10c8 Mon Sep 17 00:00:00 2001 From: hrong Date: Fri, 8 Mar 2024 09:06:42 -0800 Subject: [PATCH 12/16] Optimize based on review comments-4rd round --- cpp/include/raft/core/bitmap.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh index da009ef4cc..829c84ed25 100644 --- a/cpp/include/raft/core/bitmap.cuh +++ b/cpp/include/raft/core/bitmap.cuh @@ -110,13 +110,13 @@ struct bitmap_view : public bitset_view { * @brief Get the total number of rows * @return index_t The total number of rows */ - RAFT_INLINE_FUNCTION index_t get_n_rows() const { return rows_; } + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } /** * @brief Get the total number of columns * @return index_t The total number of columns */ - RAFT_INLINE_FUNCTION index_t get_n_cols() const { return cols_; } + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } private: index_t rows_; From 08397c6c99e879ef317ef2826bef5c8ac26dc3ea Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 14 Mar 2024 16:31:01 -0700 Subject: [PATCH 13/16] API changes to Owning/perserving --- cpp/include/raft/sparse/convert/csr.cuh | 20 ++++---- .../sparse/convert/detail/bitmap_to_csr.cuh | 49 +++++++++++++------ cpp/test/sparse/convert_csr.cu | 7 ++- 3 files changed, 48 insertions(+), 28 deletions(-) diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 424e6eee12..51ad47ee0d 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -108,19 +108,21 @@ void adj_to_csr(raft::resources const& handle, /** * @brief Converts a bitmap matrix into unsorted CSR format matrix. * - * @tparam bitmap_t Underlying type of the bitmap. 
- * @tparam index_t Indexing type used. - * @tparam value_t Data type of CSR - * @tparam nnz_t Type of CSR + * @tparam bitmap_t Underlying type of the bitmap. + * @tparam index_t Indexing type used. + * @tparam csr_matrix_t Reference Type of CSR Matrix, raft::device_csr_matrix * - * @param[in] handle RAFT handle - * @param[in] bitmap input raft::bitmap_view - * @param[out] csr output raft::device_csr_matrix_view + * @param[in] handle RAFT handle + * @param[in] bitmap input raft::bitmap_view + * @param[out] csr output raft::device_csr_matrix */ -template +template >> void bitmap_to_csr(raft::resources const& handle, raft::core::bitmap_view bitmap, - raft::device_csr_matrix_view csr) + csr_matrix_t& csr) { detail::bitmap_to_csr(handle, bitmap, csr); } diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index e706e11b32..c0d79d9310 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -144,7 +144,7 @@ RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value) // Threads per block in fill_indices_by_rows_kernel. static const constexpr int fill_indices_by_rows_tpb = 32; -template +template RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) fill_indices_by_rows_kernel(const bitmap_t* bitmap, const index_t* indptr, @@ -163,7 +163,9 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) // Ensure the HBM allocated for CSR values is sufficient to handle all non-zero bitmap bits. // An assert will trigger if the allocated HBM is insufficient when `NDEBUG` isn't defined. // Note: Assertion is active only if `NDEBUG` is undefined. 
- if (lane_id == 0) { assert(nnz < indptr[num_rows]); } + if constexpr (check_nnz) { + if (lane_id == 0) { assert(nnz >= indptr[num_rows]); } + } #pragma unroll for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { @@ -204,7 +206,7 @@ RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) } } -template +template void fill_indices_by_rows(raft::resources const& handle, const bitmap_t* bitmap, const index_t* indptr, @@ -223,7 +225,7 @@ void fill_indices_by_rows(raft::resources const& handle, cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaOccupancyMaxActiveBlocksPerMultiprocessor( &blocks_per_sm, - fill_indices_by_rows_kernel, + fill_indices_by_rows_kernel, fill_indices_by_rows_tpb, 0); @@ -231,15 +233,18 @@ void fill_indices_by_rows(raft::resources const& handle, auto grid = std::min(max_active_blocks, num_rows); auto block = fill_indices_by_rows_tpb; - fill_indices_by_rows_kernel + fill_indices_by_rows_kernel <<>>(bitmap, indptr, num_rows, num_cols, nnz, bitmap_num, indices); RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template +template >> void bitmap_to_csr(raft::resources const& handle, raft::core::bitmap_view bitmap, - raft::device_csr_matrix_view csr) + csr_matrix_t& csr) { auto csr_view = csr.structure_view(); @@ -265,14 +270,28 @@ void bitmap_to_csr(raft::resources const& handle, calc_nnz_by_rows(handle, bitmap.data(), csr_view.get_n_rows(), csr_view.get_n_cols(), indptr); thrust::exclusive_scan(thrust_policy, indptr, indptr + csr_view.get_n_rows() + 1, indptr); - fill_indices_by_rows(handle, - bitmap.data(), - indptr, - csr_view.get_n_rows(), - csr_view.get_n_cols(), - csr_view.get_nnz(), - indices); - thrust::fill_n(thrust_policy, csr.get_elements().data(), csr_view.get_nnz(), value_t{1}); + + if constexpr (is_device_csr_sparsity_owning_v) { + index_t nnz = 0; + RAFT_CUDA_TRY(cudaMemcpyAsync( + &nnz, indptr + csr_view.get_n_rows(), sizeof(index_t), cudaMemcpyDeviceToHost, stream)); + resource::sync_stream(handle); + 
csr.initialize_sparsity(nnz); + } + constexpr bool check_nnz = is_device_csr_sparsity_preserving_v; + fill_indices_by_rows( + handle, + bitmap.data(), + indptr, + csr_view.get_n_rows(), + csr_view.get_n_cols(), + csr_view.get_nnz(), + indices); + + thrust::fill_n(thrust_policy, + csr.get_elements().data(), + csr_view.get_nnz(), + typename csr_matrix_t::element_t(1)); } }; // end NAMESPACE detail diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 5ac8ea600e..429684850f 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -366,9 +366,9 @@ class BitmapToCSRTest : public ::testing::TestWithParam( indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); - auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); + auto csr = raft::make_device_csr_matrix(handle, csr_view); - convert::bitmap_to_csr(handle, bitmap, csr); + convert::bitmap_to_csr(handle, bitmap, csr); resource::sync_stream(handle); @@ -389,9 +389,8 @@ class BitmapToCSRTest : public ::testing::TestWithParam( - values_d.data(), values_expected_d.data(), nnz, raft::Compare(), stream)); + csr.get_elements().data(), values_expected_d.data(), nnz, raft::Compare(), stream)); } protected: From b2ad06e81cad98db25f0f5486bd5c97da000795a Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 14 Mar 2024 17:28:01 -0700 Subject: [PATCH 14/16] d1 --- cpp/include/raft/sparse/convert/csr.cuh | 13 +++++++------ .../raft/sparse/convert/detail/bitmap_to_csr.cuh | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 51ad47ee0d..f4c513fc90 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -108,13 +108,14 @@ void adj_to_csr(raft::resources const& handle, /** * @brief Converts a bitmap matrix into unsorted CSR format matrix. * - * @tparam bitmap_t Underlying type of the bitmap. 
- * @tparam index_t Indexing type used. - * @tparam csr_matrix_t Reference Type of CSR Matrix, raft::device_csr_matrix + * @tparam bitmap_t Underlying type of the bitmap. + * @tparam index_t Indexing type used. + * @tparam csr_matrix_t Type of CSR Matrix, must be raft::device_csr_matrix * - * @param[in] handle RAFT handle - * @param[in] bitmap input raft::bitmap_view - * @param[out] csr output raft::device_csr_matrix + * @param[in] handle RAFT handle, containing the CUDA stream on which to schedule work + * @param[in] bitmap Bitmap view, need to be converted to CSR Matrix. + * @param[out] csr A CSR Matrix, containing the result of converting. Each of 1 in + * bitmap will be a non-zero element of csr. */ template Date: Fri, 15 Mar 2024 12:13:46 -0700 Subject: [PATCH 15/16] Add the UT cases for owning scenario --- cpp/include/raft/sparse/convert/csr.cuh | 17 +++--- cpp/test/sparse/convert_csr.cu | 70 ++++++++++++++++--------- 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index f4c513fc90..081192ed44 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -106,16 +106,17 @@ void adj_to_csr(raft::resources const& handle, } /** - * @brief Converts a bitmap matrix into unsorted CSR format matrix. + * @brief Converts a bitmap matrix to a Compressed Sparse Row (CSR) format matrix. * - * @tparam bitmap_t Underlying type of the bitmap. - * @tparam index_t Indexing type used. - * @tparam csr_matrix_t Type of CSR Matrix, must be raft::device_csr_matrix + * @tparam bitmap_t The data type of the elements in the bitmap matrix. + * @tparam index_t The data type used for indexing the elements in the matrices. + * @tparam csr_matrix_t Specifies the CSR matrix type, constrained to + * raft::device_csr_matrix. 
* - * @param[in] handle RAFT handle, containing the CUDA stream on which to schedule work - * @param[in] bitmap Bitmap view, need to be converted to CSR Matrix. - * @param[out] csr A CSR Matrix, containing the result of converting. Each of 1 in - * bitmap will be a non-zero element of csr. + * @param[in] handle The RAFT handle containing the CUDA stream for operations. + * @param[in] bitmap The bitmap matrix view, to be converted to CSR format. + * @param[out] csr Output parameter where the resulting CSR matrix is stored. In the + * bitmap, each '1' bit corresponds to a non-zero element in the CSR matrix. */ template @@ -364,33 +365,40 @@ class BitmapToCSRTest : public ::testing::TestWithParam(bitmap_d.data(), params.n_rows, params.n_cols); - auto csr_view = raft::make_device_compressed_structure_view( - indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); - auto csr = raft::make_device_csr_matrix(handle, csr_view); - - convert::bitmap_to_csr(handle, bitmap, csr); - + if (params.owning) { + auto csr = + raft::make_device_csr_matrix(handle, params.n_rows, params.n_cols, nnz); + auto csr_view = csr.structure_view(); + + convert::bitmap_to_csr(handle, bitmap, csr); + raft::copy(indptr_d.data(), csr_view.get_indptr().data(), indptr_d.size(), stream); + raft::copy(indices_d.data(), csr_view.get_indices().data(), indices_d.size(), stream); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } else { + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + convert::bitmap_to_csr(handle, bitmap, csr); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } resource::sync_stream(handle); - ASSERT_EQ(csr_view.get_indptr().size(), indptr_expected_d.size()); - ASSERT_EQ(csr_view.get_indices().size(), indices_expected_d.size()); - ASSERT_EQ(csr_view.get_nnz(), nnz); - std::vector 
indices_h(indices_expected_d.size(), 0); std::vector indices_expected_h(indices_expected_d.size(), 0); - update_host(indices_h.data(), csr_view.get_indices().data(), indices_h.size(), stream); + update_host(indices_h.data(), indices_d.data(), indices_h.size(), stream); update_host(indices_expected_h.data(), indices_expected_d.data(), indices_h.size(), stream); std::vector indptr_h(indptr_expected_d.size(), 0); std::vector indptr_expected_h(indptr_expected_d.size(), 0); - update_host(indptr_h.data(), csr_view.get_indptr().data(), indptr_h.size(), stream); + update_host(indptr_h.data(), indptr_d.data(), indptr_h.size(), stream); update_host(indptr_expected_h.data(), indptr_expected_d.data(), indptr_h.size(), stream); resource::sync_stream(handle); ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h)); ASSERT_TRUE(raft::devArrMatch( - csr.get_elements().data(), values_expected_d.data(), nnz, raft::Compare(), stream)); + values_expected_d.data(), values_d.data(), nnz, raft::Compare(), stream)); } protected: @@ -420,18 +428,30 @@ TEST_P(BitmapToCSRTestL, Result) { Run(); } template const std::vector> bitmaptocsr_inputs = { - {0, 0, 0.2}, - {10, 32, 0.4}, - {10, 3, 0.2}, - {32, 1024, 0.4}, - {1024, 1048576, 0.01}, - {1024, 1024, 0.4}, - {64 * 1024 + 10, 2, 0.3}, // 64K + 10 is slightly over maximum of blockDim.y - {16, 16, 0.3}, // No peeling-remainder - {17, 16, 0.3}, // Check peeling-remainder - {18, 16, 0.3}, // Check peeling-remainder - {32 + 9, 33, 0.2}, // Check peeling-remainder - {2, 33, 0.2}, // Check peeling-remainder + {0, 0, 0.2, false}, + {10, 32, 0.4, false}, + {10, 3, 0.2, false}, + {32, 1024, 0.4, false}, + {1024, 1048576, 0.01, false}, + {1024, 1024, 0.4, false}, + {64 * 1024 + 10, 2, 0.3, false}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, false}, // No peeling-remainder + {17, 16, 0.3, false}, // Check peeling-remainder + {18, 16, 0.3, false}, // Check peeling-remainder + {32 + 9, 33, 0.2, false}, 
// Check peeling-remainder + {2, 33, 0.2, false}, // Check peeling-remainder + {0, 0, 0.2, true}, + {10, 32, 0.4, true}, + {10, 3, 0.2, true}, + {32, 1024, 0.4, true}, + {1024, 1048576, 0.01, true}, + {1024, 1024, 0.4, true}, + {64 * 1024 + 10, 2, 0.3, true}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, true}, // No peeling-remainder + {17, 16, 0.3, true}, // Check peeling-remainder + {18, 16, 0.3, true}, // Check peeling-remainder + {32 + 9, 33, 0.2, true}, // Check peeling-remainder + {2, 33, 0.2, true}, // Check peeling-remainder }; INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, From 8838a36d32bb285f969839b5a29c43a34c129220 Mon Sep 17 00:00:00 2001 From: rhdong Date: Fri, 15 Mar 2024 20:30:31 -0700 Subject: [PATCH 16/16] fix benchmark compile error --- cpp/bench/prims/sparse/bitmap_to_csr.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu index b9c7638fcc..ed53df3265 100644 --- a/cpp/bench/prims/sparse/bitmap_to_csr.cu +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -107,7 +107,7 @@ struct BitmapToCsrBench : public fixture { auto csr_view = raft::make_device_compressed_structure_view( indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); - auto csr = raft::make_device_csr_matrix_view(values_d.data(), csr_view); + auto csr = raft::make_device_csr_matrix(handle, csr_view); raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr);