diff --git a/cpp/bench/prims/sparse/bitset_to_csr.cu b/cpp/bench/prims/sparse/bitset_to_csr.cu new file mode 100644 index 0000000000..fef2d44d3e --- /dev/null +++ b/cpp/bench/prims/sparse/bitset_to_csr.cu @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_repeat; + index_t n_cols; + float sparsity; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << " rows*cols=" << params.n_repeat << "*" << params.n_cols + << "\tsparsity=" << params.sparsity; + return os; +} + +template +struct BitsetToCsrBench : public fixture { + BitsetToCsrBench(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + bitset_d(0, stream), + nnz(0), + indptr_d(0, stream), + indices_d(0, stream), + values_d(0, stream) + { + index_t element = raft::ceildiv(1 * params.n_cols, index_t(sizeof(bitset_t) * 8)); + std::vector bitset_h(element); + nnz = create_sparse_matrix(1, params.n_cols, params.sparsity, bitset_h); + + bitset_d.resize(bitset_h.size(), stream); + indptr_d.resize(params.n_repeat + 1, stream); + indices_d.resize(nnz, stream); + values_d.resize(nnz, stream); + + update_device(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream); + + resource::sync_stream(handle); + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitset) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * (1.0f - sparsity)); + index_t res = num_ones; + + for (auto& item : bitset) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitset_t& element = bitset[index / (8 * sizeof(bitset_t))]; + index_t bit_position = index % (8 * sizeof(bitset_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto bitset = raft::core::bitset_view(bitset_d.data(), 1 * params.n_cols); + + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_repeat, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + raft::sparse::convert::bitset_to_csr(handle, bitset, csr); + + resource::sync_stream(handle); + loop_on_state(state, [this, &bitset, &csr]() { + raft::sparse::convert::bitset_to_csr(handle, bitset, csr); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + + rmm::device_uvector bitset_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + index_t nnz; +}; // struct BitsetToCsrBench + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + float sparsity; + }; + + const std::vector params_group = raft::util::itertools::product( + {index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.99f, 0.9f, 0.8f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.sparsity})); + } + return param_vec; +} + +template +const std::vector> getLargeInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + float sparsity; + }; + + const std::vector params_group = raft::util::itertools::product( + {index_t(1), index_t(100)}, {index_t(100 * 1000000)}, {0.95f, 0.99f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.sparsity})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((BitsetToCsrBench), "", getInputs()); +RAFT_BENCH_REGISTER((BitsetToCsrBench), "", getInputs()); + +RAFT_BENCH_REGISTER((BitsetToCsrBench), "", getLargeInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 081192ed44..818b572a23 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -129,6 +130,33 @@ void bitmap_to_csr(raft::resources const& handle, detail::bitmap_to_csr(handle, bitmap, csr); } +/** + * @brief Converts a bitset matrix to a Compressed Sparse Row (CSR) format matrix. + * + * The bitset format inherently supports only a single-row matrix (rows=1). If the CSR matrix + * requires multiple rows, the data from the bitset will be repeated for each row in the output. + * + * @tparam bitset_t The data type of the elements in the bitset matrix. + * @tparam index_t The data type used for indexing the elements in the matrices. + * @tparam csr_matrix_t Specifies the CSR matrix type, constrained to + * raft::device_csr_matrix. + * + * @param[in] handle The RAFT handle containing the CUDA stream for operations. + * @param[in] bitset The bitset matrix view, to be converted to CSR format. + * @param[out] csr Output parameter where the resulting CSR matrix is stored. In the + * bitset, each '1' bit corresponds to a non-zero element in the CSR matrix. + */ +template >> +void bitset_to_csr(raft::resources const& handle, + raft::core::bitset_view bitset, + csr_matrix_t& csr) +{ + detail::bitset_to_csr(handle, bitset, csr); +} + }; // end NAMESPACE convert }; // end NAMESPACE sparse }; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh new file mode 100644 index 0000000000..f4660f4ecf --- /dev/null +++ b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // detail::popc +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace sparse { +namespace convert { +namespace detail { + +template +__global__ void repeat_csr_kernel(const index_t* indptr, + const index_t* indices, + index_t* repeated_indptr, + index_t* repeated_indices, + nnz_t nnz, + index_t repeat_count) +{ + int global_id = blockIdx.x * blockDim.x + threadIdx.x; + bool guard = global_id < nnz; + index_t* repeated_indices_addr = repeated_indices + global_id; + + for (index_t i = global_id; i < repeat_count; i += gridDim.x * blockDim.x) { + repeated_indptr[i] = (i + 2) * nnz; + } + + __syncthreads(); + + int block_offset = blockIdx.x * blockDim.x; + + index_t item; + int idx = block_offset + threadIdx.x; + item = (idx < nnz) ? indices[idx] : -1; + + __syncthreads(); + + for (index_t row = 0; row < repeat_count; ++row) { + index_t start_offset = row * nnz; + if (guard) { repeated_indices_addr[start_offset] = item; } + } +} + +template +void gpu_repeat_csr(raft::resources const& handle, + const index_t* d_indptr, + const index_t* d_indices, + nnz_t nnz, + index_t repeat_count, + index_t* d_repeated_indptr, + index_t* d_repeated_indices) +{ + auto stream = resource::get_cuda_stream(handle); + index_t repeat_csr_tpb = 256; + index_t grid = (nnz + repeat_csr_tpb - 1) / (repeat_csr_tpb); + + repeat_csr_kernel<<>>( + d_indptr, d_indices, d_repeated_indptr, d_repeated_indices, nnz, repeat_count); +} + +template >> +void bitset_to_csr(raft::resources const& handle, + raft::core::bitset_view bitset, + csr_matrix_t& csr) +{ + using row_t = typename csr_matrix_t::row_type; + using nnz_t = typename csr_matrix_t::nnz_type; + + auto csr_view = csr.structure_view(); + + if (csr_view.get_n_rows() == 0 || csr_view.get_n_cols() == 0 || csr_view.get_nnz() == 0) { + return; + } + + RAFT_EXPECTS(bitset.size() == csr_view.get_n_cols(), + "Number of size in bitset must be equal to " + "number of columns in csr"); + + auto thrust_policy = resource::get_thrust_policy(handle); + auto stream = resource::get_cuda_stream(handle); + + index_t* indptr = csr_view.get_indptr().data(); + index_t* indices = csr_view.get_indices().data(); + + RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream)); + + calc_nnz_by_rows(handle, bitset.data(), row_t(1), csr_view.get_n_cols(), indptr); + thrust::exclusive_scan(thrust_policy, indptr, indptr + 2, indptr); + + index_t bitset_nnz = 0; + + if constexpr (is_device_csr_sparsity_owning_v) { + RAFT_CUDA_TRY( + cudaMemcpyAsync(&bitset_nnz, indptr + 1, sizeof(index_t), cudaMemcpyDeviceToHost, stream)); + resource::sync_stream(handle); + csr.initialize_sparsity(bitset_nnz * csr_view.get_n_rows()); + } else { + bitset_nnz = csr_view.get_nnz() / csr_view.get_n_rows(); + } + + constexpr bool check_nnz = is_device_csr_sparsity_preserving_v; + fill_indices_by_rows( + handle, bitset.data(), indptr, 1, csr_view.get_n_cols(), bitset_nnz, indices); + + if (csr_view.get_n_rows() > 1) { + gpu_repeat_csr(handle, + indptr, + indices, + bitset_nnz, + csr_view.get_n_rows() - 1, + indptr + 2, + indices + bitset_nnz); + } + + thrust::fill_n(thrust_policy, + csr.get_elements().data(), + csr_view.get_nnz(), + typename csr_matrix_t::element_type(1)); +} + +}; // end NAMESPACE detail +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 1cd49b0bbd..4ecd7a4ac8 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -17,6 +17,7 @@ #include "../test_utils.cuh" #include +#include #include #include #include @@ -461,5 +462,289 @@ INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, BitmapToCSRTestL, ::testing::ValuesIn(bitmaptocsr_inputs)); +/******************************** bitset to csr ********************************/ + +template +struct BitsetToCSRInputs { + index_t n_repeat; + index_t n_cols; + float sparsity; + bool owning; +}; + +template +class BitsetToCSRTest : public ::testing::TestWithParam> { + public: + BitsetToCSRTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + bitset_d(0, stream), + indices_d(0, stream), + indptr_d(0, stream), + values_d(0, stream), + indptr_expected_d(0, stream), + indices_expected_d(0, stream), + values_expected_d(0, stream) + { + } + + protected: + void repeat_cpu_bitset(std::vector& input, + size_t input_bits, + size_t repeat, + std::vector& output) + { + const size_t output_bits = input_bits * repeat; + const size_t output_units = (output_bits + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8); + + std::memset(output.data(), 0, output_units * sizeof(bitset_t)); + + size_t output_bit_index = 0; + + for (size_t r = 0; r < repeat; ++r) { + for (size_t i = 0; i < input_bits; ++i) { + size_t input_unit_index = i / (sizeof(bitset_t) * 8); + size_t input_bit_offset = i % (sizeof(bitset_t) * 8); + bool bit = (input[input_unit_index] >> input_bit_offset) & 1; + + size_t output_unit_index = output_bit_index / (sizeof(bitset_t) * 8); + size_t output_bit_offset = output_bit_index % (sizeof(bitset_t) * 8); + + output[output_unit_index] |= (static_cast(bit) << output_bit_offset); + + ++output_bit_index; + } + } + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitset) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitset) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitset_t& element = bitset[index / (8 * sizeof(bitset_t))]; + index_t bit_position = index % (8 * sizeof(bitset_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitset, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitset_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitset[index / (8 * sizeof(bitset_t))]; + bit_position = index % (8 * sizeof(bitset_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + bool csr_compare(const std::vector& row_ptrs1, + const std::vector& col_indices1, + const std::vector& row_ptrs2, + const std::vector& col_indices2) + { + if (row_ptrs1.size() != row_ptrs2.size()) { return false; } + + if (col_indices1.size() != col_indices2.size()) { return false; } + + if (!std::equal(row_ptrs1.begin(), row_ptrs1.end(), row_ptrs2.begin())) { return false; } + + for (size_t i = 0; i < row_ptrs1.size() - 1; ++i) { + size_t start_idx = row_ptrs1[i]; + size_t end_idx = row_ptrs1[i + 1]; + + std::vector cols1(col_indices1.begin() + start_idx, col_indices1.begin() + end_idx); + std::vector cols2(col_indices2.begin() + start_idx, col_indices2.begin() + end_idx); + + std::sort(cols1.begin(), cols1.end()); + std::sort(cols2.begin(), cols2.end()); + + if (cols1 != cols2) { return false; } + } + + return true; + } + + void SetUp() override + { + index_t element = raft::ceildiv(1 * params.n_cols, index_t(sizeof(bitset_t) * 8)); + std::vector bitset_h(element); + std::vector bitset_repeat_h(element * params.n_repeat); + + nnz = create_sparse_matrix(1, params.n_cols, params.sparsity, bitset_h); + + repeat_cpu_bitset(bitset_h, size_t(params.n_cols), size_t(params.n_repeat), bitset_repeat_h); + nnz *= params.n_repeat; + + std::vector indices_h(nnz); + std::vector indptr_h(params.n_repeat + 1); + + cpu_convert_to_csr(bitset_repeat_h, params.n_repeat, params.n_cols, indices_h, indptr_h); + + bitset_d.resize(bitset_h.size(), stream); + indptr_d.resize(params.n_repeat + 1, stream); + indices_d.resize(nnz, stream); + + indptr_expected_d.resize(params.n_repeat + 1, stream); + indices_expected_d.resize(nnz, stream); + values_expected_d.resize(nnz, stream); + + thrust::fill_n(resource::get_thrust_policy(handle), values_expected_d.data(), nnz, value_t{1}); + + values_d.resize(nnz, stream); + + update_device(indices_expected_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_expected_d.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto bitset = raft::core::bitset_view(bitset_d.data(), params.n_cols); + + if (params.owning) { + auto csr = + raft::make_device_csr_matrix(handle, params.n_repeat, params.n_cols, nnz); + auto csr_view = csr.structure_view(); + + convert::bitset_to_csr(handle, bitset, csr); + raft::copy(indptr_d.data(), csr_view.get_indptr().data(), indptr_d.size(), stream); + raft::copy(indices_d.data(), csr_view.get_indices().data(), indices_d.size(), stream); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } else { + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_repeat, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + convert::bitset_to_csr(handle, bitset, csr); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } + resource::sync_stream(handle); + + std::vector indices_h(indices_expected_d.size(), 0); + std::vector indices_expected_h(indices_expected_d.size(), 0); + update_host(indices_h.data(), indices_d.data(), indices_h.size(), stream); + update_host(indices_expected_h.data(), indices_expected_d.data(), indices_h.size(), stream); + + std::vector indptr_h(indptr_expected_d.size(), 0); + std::vector indptr_expected_h(indptr_expected_d.size(), 0); + update_host(indptr_h.data(), indptr_d.data(), indptr_h.size(), stream); + update_host(indptr_expected_h.data(), indptr_expected_d.data(), indptr_h.size(), stream); + + resource::sync_stream(handle); + + ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h)); + ASSERT_TRUE(raft::devArrMatch( + values_expected_d.data(), values_d.data(), nnz, raft::Compare(), stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + BitsetToCSRInputs params; + + rmm::device_uvector bitset_d; + + index_t nnz; + + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + rmm::device_uvector indptr_expected_d; + rmm::device_uvector indices_expected_d; + rmm::device_uvector values_expected_d; +}; + +using BitsetToCSRTestI = BitsetToCSRTest; +TEST_P(BitsetToCSRTestI, Result) { Run(); } + +using BitsetToCSRTestL = BitsetToCSRTest; +TEST_P(BitsetToCSRTestL, Result) { Run(); } + +using BitsetToCSRTestLOnLargeSize = BitsetToCSRTest; +TEST_P(BitsetToCSRTestLOnLargeSize, Result) { Run(); } + +template +const std::vector> bitsettocsr_inputs = { + {0, 0, 0.2, false}, + {10, 32, 0.4, false}, + {10, 3, 0.2, false}, + {32, 1024, 0.4, false}, + {1024, 1048576, 0.01, false}, + {1024, 1024, 0.4, false}, + {64 * 1024 + 10, 2, 0.3, false}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, false}, // No peeling-remainder + {17, 16, 0.3, false}, // Check peeling-remainder + {18, 16, 0.3, false}, // Check peeling-remainder + {32 + 9, 33, 0.2, false}, // Check peeling-remainder + {2, 33, 0.2, false}, // Check peeling-remainder + {0, 0, 0.2, true}, + {10, 32, 0.4, true}, + {10, 3, 0.2, true}, + {32, 1024, 0.4, true}, + {1024, 1048576, 0.01, true}, + {1024, 1024, 0.4, true}, + {64 * 1024 + 10, 2, 0.3, true}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, true}, // No peeling-remainder + {17, 16, 0.3, true}, // Check peeling-remainder + {18, 16, 0.3, true}, // Check peeling-remainder + {32 + 9, 33, 0.2, true}, // Check peeling-remainder + {2, 33, 0.2, true}, // Check peeling-remainder +}; + +template +const std::vector> bitsettocsr_large_inputs = { + {100, 100000000, 0.01, true}, {100, 100000000, 0.05, false}, {100, 100000000 + 17, 0.05, false}}; + +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitsetToCSRTestI, + ::testing::ValuesIn(bitsettocsr_inputs)); +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitsetToCSRTestL, + ::testing::ValuesIn(bitsettocsr_inputs)); +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitsetToCSRTestLOnLargeSize, + ::testing::ValuesIn(bitsettocsr_large_inputs)); + } // namespace sparse } // namespace raft