Skip to content

Commit

Permalink
[Opt] Optimizing the performance of bitmap_to_csr
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Dec 4, 2024
1 parent 0e6d35f commit d9a42d2
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 156 deletions.
26 changes: 24 additions & 2 deletions cpp/bench/prims/sparse/bitmap_to_csr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct BitmapToCsrBench : public fixture {
index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bitmap_t>& bitmap)
{
index_t total = static_cast<index_t>(m * n);
index_t num_ones = static_cast<index_t>((total * 1.0f) * sparsity);
index_t num_ones = static_cast<index_t>((total * 1.0f) * (1.0f - sparsity));
index_t res = num_ones;

for (auto& item : bitmap) {
Expand Down Expand Up @@ -141,7 +141,27 @@ const std::vector<bench_param<index_t>> getInputs()
};

const std::vector<TestParams> params_group = raft::util::itertools::product<TestParams>(
{index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.01f, 0.1f, 0.2f, 0.5f});
{index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.99f, 0.9f, 0.8f, 0.5f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
param_vec.push_back(bench_param<index_t>({params.m, params.n, params.sparsity}));
}
return param_vec;
}

template <typename index_t = int64_t>
const std::vector<bench_param<index_t>> getLargeInputs()
{
std::vector<bench_param<index_t>> param_vec;
struct TestParams {
index_t m;
index_t n;
float sparsity;
};

const std::vector<TestParams> params_group = raft::util::itertools::product<TestParams>(
{index_t(1), index_t(100)}, {index_t(100 * 1000000)}, {0.95f, 0.99f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
Expand All @@ -153,4 +173,6 @@ const std::vector<bench_param<index_t>> getInputs()
RAFT_BENCH_REGISTER((BitmapToCsrBench<uint32_t, int, float>), "", getInputs<int>());
RAFT_BENCH_REGISTER((BitmapToCsrBench<uint64_t, int, double>), "", getInputs<int>());

RAFT_BENCH_REGISTER((BitmapToCsrBench<uint32_t, int64_t, float>), "", getLargeInputs<int64_t>());

} // namespace raft::bench::sparse
Loading

0 comments on commit d9a42d2

Please sign in to comment.