[Feat] Support bitset_to_csr
rhdong committed Dec 9, 2024
1 parent fc7818f commit 06f6f29
Showing 4 changed files with 649 additions and 0 deletions.
178 changes: 178 additions & 0 deletions cpp/bench/prims/sparse/bitset_to_csr.cu
@@ -0,0 +1,178 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <common/benchmark.hpp>

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/convert/csr.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

#include <random>
#include <sstream>
#include <vector>

namespace raft::bench::sparse {

template <typename index_t>
struct bench_param {
index_t n_repeat;
index_t n_cols;
float sparsity;
};

template <typename index_t>
inline auto operator<<(std::ostream& os, const bench_param<index_t>& params) -> std::ostream&
{
os << " rows*cols=" << params.n_repeat << "*" << params.n_cols
<< "\tsparsity=" << params.sparsity;
return os;
}

template <typename bitset_t, typename index_t, typename value_t = float>
struct BitsetToCsrBench : public fixture {
BitsetToCsrBench(const bench_param<index_t>& p)
: fixture(true),
params(p),
handle(stream),
bitset_d(0, stream),
nnz(0),
indptr_d(0, stream),
indices_d(0, stream),
values_d(0, stream)
{
index_t element = raft::ceildiv(1 * params.n_cols, index_t(sizeof(bitset_t) * 8));
std::vector<bitset_t> bitset_h(element);
// Total nnz of the output CSR: the single bitset row is repeated n_repeat times.
nnz = create_sparse_matrix(1, params.n_cols, params.sparsity, bitset_h) * params.n_repeat;

bitset_d.resize(bitset_h.size(), stream);
indptr_d.resize(params.n_repeat + 1, stream);
indices_d.resize(nnz, stream);
values_d.resize(nnz, stream);

update_device(bitset_d.data(), bitset_h.data(), bitset_h.size(), stream);

resource::sync_stream(handle);
}

index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bitset_t>& bitset)
{
index_t total = static_cast<index_t>(m * n);
index_t num_ones = static_cast<index_t>((total * 1.0f) * (1.0f - sparsity));
index_t res = num_ones;

for (auto& item : bitset) {
item = static_cast<bitset_t>(0);
}

std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<index_t> dis(0, total - 1);

while (num_ones > 0) {
index_t index = dis(gen);

bitset_t& element = bitset[index / (8 * sizeof(bitset_t))];
index_t bit_position = index % (8 * sizeof(bitset_t));

if (((element >> bit_position) & 1) == 0) {
element |= (static_cast<bitset_t>(1) << bit_position);
num_ones--;
}
}
return res;
}

void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
label_stream << params;
state.SetLabel(label_stream.str());

auto bitset = raft::core::bitset_view<bitset_t, index_t>(bitset_d.data(), 1 * params.n_cols);

auto csr_view = raft::make_device_compressed_structure_view<index_t, index_t, index_t>(
indptr_d.data(), indices_d.data(), params.n_repeat, params.n_cols, nnz);
auto csr = raft::make_device_csr_matrix<value_t, index_t>(handle, csr_view);

raft::sparse::convert::bitset_to_csr<bitset_t, index_t>(handle, bitset, csr);

resource::sync_stream(handle);
loop_on_state(state, [this, &bitset, &csr]() {
raft::sparse::convert::bitset_to_csr<bitset_t, index_t>(handle, bitset, csr);
});
}

protected:
const raft::device_resources handle;

bench_param<index_t> params;

rmm::device_uvector<bitset_t> bitset_d;
rmm::device_uvector<index_t> indptr_d;
rmm::device_uvector<index_t> indices_d;
rmm::device_uvector<value_t> values_d;

index_t nnz;
}; // struct BitsetToCsrBench

template <typename index_t>
const std::vector<bench_param<index_t>> getInputs()
{
std::vector<bench_param<index_t>> param_vec;
struct TestParams {
index_t m;
index_t n;
float sparsity;
};

const std::vector<TestParams> params_group = raft::util::itertools::product<TestParams>(
{index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.99f, 0.9f, 0.8f, 0.5f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
param_vec.push_back(bench_param<index_t>({params.m, params.n, params.sparsity}));
}
return param_vec;
}

template <typename index_t = int64_t>
const std::vector<bench_param<index_t>> getLargeInputs()
{
std::vector<bench_param<index_t>> param_vec;
struct TestParams {
index_t m;
index_t n;
float sparsity;
};

const std::vector<TestParams> params_group = raft::util::itertools::product<TestParams>(
{index_t(1), index_t(100)}, {index_t(100 * 1000000)}, {0.95f, 0.99f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
param_vec.push_back(bench_param<index_t>({params.m, params.n, params.sparsity}));
}
return param_vec;
}

RAFT_BENCH_REGISTER((BitsetToCsrBench<uint32_t, int, float>), "", getInputs<int>());
RAFT_BENCH_REGISTER((BitsetToCsrBench<uint64_t, int, double>), "", getInputs<int>());

RAFT_BENCH_REGISTER((BitsetToCsrBench<uint32_t, int64_t, float>), "", getLargeInputs<int64_t>());

} // namespace raft::bench::sparse
28 changes: 28 additions & 0 deletions cpp/include/raft/sparse/convert/csr.cuh
@@ -22,6 +22,7 @@
#include <raft/core/device_csr_matrix.hpp>
#include <raft/sparse/convert/detail/adj_to_csr.cuh>
#include <raft/sparse/convert/detail/bitmap_to_csr.cuh>
#include <raft/sparse/convert/detail/bitset_to_csr.cuh>
#include <raft/sparse/convert/detail/csr.cuh>
#include <raft/sparse/csr.hpp>

@@ -129,6 +130,33 @@ void bitmap_to_csr(raft::resources const& handle,
detail::bitmap_to_csr(handle, bitmap, csr);
}

/**
* @brief Converts a bitset matrix to a Compressed Sparse Row (CSR) format matrix.
*
* The bitset format inherently supports only a single-row matrix (rows=1). If the CSR matrix
* requires multiple rows, the data from the bitset will be repeated for each row in the output.
*
* @tparam bitset_t The data type of the elements in the bitset matrix.
* @tparam index_t The data type used for indexing the elements in the matrices.
* @tparam csr_matrix_t Specifies the CSR matrix type, constrained to
* raft::device_csr_matrix.
*
* @param[in] handle The RAFT handle containing the CUDA stream for operations.
* @param[in] bitset The bitset matrix view to be converted to CSR format.
* @param[out] csr Output parameter where the resulting CSR matrix is stored. In the
* bitset, each '1' bit corresponds to a non-zero element in the CSR matrix.
*/
template <typename bitset_t,
typename index_t,
typename csr_matrix_t,
typename = std::enable_if_t<raft::is_device_csr_matrix_v<csr_matrix_t>>>
void bitset_to_csr(raft::resources const& handle,
raft::core::bitset_view<bitset_t, index_t> bitset,
csr_matrix_t& csr)
{
detail::bitset_to_csr(handle, bitset, csr);
}

}; // end NAMESPACE convert
}; // end NAMESPACE sparse
}; // end NAMESPACE raft
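
For reference, a condensed usage sketch of the new bitset_to_csr API, mirroring the benchmark above; the function name, buffer names, and the total_nnz parameter are illustrative assumptions, not part of this commit:

  // Sketch only: indptr/indices are assumed to be preallocated device buffers,
  // and total_nnz the expected number of non-zeros across all n_rows output rows.
  void bitset_to_csr_example(raft::device_resources const& handle,
                             raft::core::bitset_view<uint32_t, int> bitset,
                             int* indptr, int* indices, int n_rows, int total_nnz)
  {
    auto csr_view = raft::make_device_compressed_structure_view<int, int, int>(
      indptr, indices, n_rows, bitset.size(), total_nnz);
    auto csr = raft::make_device_csr_matrix<float, int>(handle, csr_view);
    raft::sparse::convert::bitset_to_csr<uint32_t, int>(handle, bitset, csr);
  }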
158 changes: 158 additions & 0 deletions cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh
@@ -0,0 +1,158 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/bitset.hpp>
#include <raft/core/detail/mdspan_util.cuh> // detail::popc
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/convert/detail/adj_to_csr.cuh>
#include <raft/sparse/convert/detail/bitmap_to_csr.cuh>

#include <rmm/device_uvector.hpp>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <thrust/copy.h>
#include <thrust/fill.h>
#include <thrust/functional.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include <assert.h>

namespace raft {
namespace sparse {
namespace convert {
namespace detail {

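// Helper kernel for bitset_to_csr: given the CSR of the single bitset row
// (indptr[0..1] and its `nnz` column indices), fills the indptr entries of the
// remaining `repeat_count` rows with the corresponding multiples of `nnz` and
// copies the first row's column indices into each repeated row's segment of
// the indices array.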
template <typename index_t, typename nnz_t>
__global__ void repeat_csr_kernel(const index_t* indptr,
const index_t* indices,
index_t* repeated_indptr,
index_t* repeated_indices,
nnz_t nnz,
index_t repeat_count)
{
int global_id = blockIdx.x * blockDim.x + threadIdx.x;
bool guard = global_id < nnz;
index_t* repeated_indices_addr = repeated_indices + global_id;

for (index_t i = global_id; i < repeat_count; i += gridDim.x * blockDim.x) {
repeated_indptr[i] = (i + 2) * nnz;
}

__syncthreads();

int block_offset = blockIdx.x * blockDim.x;

index_t item;
int idx = block_offset + threadIdx.x;
item = (idx < nnz) ? indices[idx] : -1;

__syncthreads();

for (index_t row = 0; row < repeat_count; ++row) {
index_t start_offset = row * nnz;
if (guard) { repeated_indices_addr[start_offset] = item; }
}
}

template <typename index_t, typename nnz_t>
void gpu_repeat_csr(raft::resources const& handle,
const index_t* d_indptr,
const index_t* d_indices,
nnz_t nnz,
index_t repeat_count,
index_t* d_repeated_indptr,
index_t* d_repeated_indices)
{
auto stream = resource::get_cuda_stream(handle);
index_t repeat_csr_tpb = 256;
index_t grid = (nnz + repeat_csr_tpb - 1) / (repeat_csr_tpb);

repeat_csr_kernel<<<grid, repeat_csr_tpb, 0, stream>>>(
d_indptr, d_indices, d_repeated_indptr, d_repeated_indices, nnz, repeat_count);
}

template <typename bitset_t,
typename index_t,
typename csr_matrix_t,
typename = std::enable_if_t<raft::is_device_csr_matrix_v<csr_matrix_t>>>
void bitset_to_csr(raft::resources const& handle,
raft::core::bitset_view<bitset_t, index_t> bitset,
csr_matrix_t& csr)
{
using row_t = typename csr_matrix_t::row_type;
using nnz_t = typename csr_matrix_t::nnz_type;

auto csr_view = csr.structure_view();

if (csr_view.get_n_rows() == 0 || csr_view.get_n_cols() == 0 || csr_view.get_nnz() == 0) {
return;
}

RAFT_EXPECTS(bitset.size() == csr_view.get_n_cols(),
"Size of the bitset must be equal to the "
"number of columns in the csr");

auto thrust_policy = resource::get_thrust_policy(handle);
auto stream = resource::get_cuda_stream(handle);

index_t* indptr = csr_view.get_indptr().data();
index_t* indices = csr_view.get_indices().data();

RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream));

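// Count the set bits of the single bitset row, then exclusive-scan the first
// two indptr entries so that indptr = {0, nnz_of_the_bitset_row}.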
calc_nnz_by_rows(handle, bitset.data(), row_t(1), csr_view.get_n_cols(), indptr);
thrust::exclusive_scan(thrust_policy, indptr, indptr + 2, indptr);

index_t bitset_nnz = 0;

if constexpr (is_device_csr_sparsity_owning_v<csr_matrix_t>) {
RAFT_CUDA_TRY(
cudaMemcpyAsync(&bitset_nnz, indptr + 1, sizeof(index_t), cudaMemcpyDeviceToHost, stream));
resource::sync_stream(handle);
csr.initialize_sparsity(bitset_nnz * csr_view.get_n_rows());
} else {
bitset_nnz = csr_view.get_nnz() / csr_view.get_n_rows();
}

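// Write the column index of every set bit into the first row of the indices array.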
constexpr bool check_nnz = is_device_csr_sparsity_preserving_v<csr_matrix_t>;
fill_indices_by_rows<bitset_t, index_t, nnz_t, check_nnz>(
handle, bitset.data(), indptr, 1, csr_view.get_n_cols(), bitset_nnz, indices);

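// If the output CSR has more than one row, replicate the first row's offsets
// and column indices for the remaining rows.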
if (csr_view.get_n_rows() > 1) {
gpu_repeat_csr<index_t, nnz_t>(handle,
indptr,
indices,
bitset_nnz,
csr_view.get_n_rows() - 1,
indptr + 2,
indices + bitset_nnz);
}

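// All stored values of the CSR matrix are set to 1.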
thrust::fill_n(thrust_policy,
csr.get_elements().data(),
csr_view.get_nnz(),
typename csr_matrix_t::element_type(1));
}

}; // end NAMESPACE detail
}; // end NAMESPACE convert
}; // end NAMESPACE sparse
}; // end NAMESPACE raft