Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Support bitset_to_csr #2523

Open
wants to merge 8 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 77 additions & 34 deletions cpp/bench/prims/linalg/masked_matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/linalg/masked_matmul.hpp>
#include <raft/sparse/linalg/masked_matmul.cuh>
#include <raft/util/itertools.hpp>

#include <cusparse_v2.h>
Expand All @@ -49,11 +49,14 @@ inline auto operator<<(std::ostream& os, const MaskedMatmulBenchParams<value_t>&
{
os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n
<< "\tsparsity=" << params.sparsity;
if (params.sparsity == 1.0) { os << "<-inner product for comparison"; }
if (params.sparsity == 0.0) { os << "<-inner product for comparison"; }
return os;
}

template <typename value_t, typename index_t = int64_t, typename bitmap_t = uint32_t>
template <typename value_t,
bool bitmap_or_bitset = true,
typename index_t = int64_t,
typename bits_t = uint32_t>
struct MaskedMatmulBench : public fixture {
MaskedMatmulBench(const MaskedMatmulBenchParams<value_t>& p)
: fixture(true),
Expand All @@ -64,15 +67,15 @@ struct MaskedMatmulBench : public fixture {
c_indptr_d(0, stream),
c_indices_d(0, stream),
c_data_d(0, stream),
bitmap_d(0, stream),
bits_d(0, stream),
c_dense_data_d(0, stream)
{
index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bitmap_t) * 8));
std::vector<bitmap_t> bitmap_h(element);
index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bits_t) * 8));
std::vector<bits_t> bits_h(element);

a_data_d.resize(params.m * params.k, stream);
b_data_d.resize(params.k * params.n, stream);
bitmap_d.resize(element, stream);
bits_d.resize(element, stream);

raft::random::RngState rng(2024ULL);
raft::random::uniform(
Expand All @@ -82,7 +85,13 @@ struct MaskedMatmulBench : public fixture {

std::vector<bool> c_dense_data_h(params.m * params.n);

c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h);
if constexpr (bitmap_or_bitset) {
c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bits_h);
} else {
c_true_nnz = create_sparse_matrix(1, params.n, params.sparsity, bits_h);
repeat_cpu_bitset_inplace(bits_h, params.n, params.m - 1);
c_true_nnz *= params.m;
}

std::vector<value_t> values(c_true_nnz);
std::vector<index_t> indices(c_true_nnz);
Expand All @@ -93,24 +102,49 @@ struct MaskedMatmulBench : public fixture {
c_indices_d.resize(c_true_nnz, stream);
c_dense_data_d.resize(params.m * params.n, stream);

cpu_convert_to_csr(bitmap_h, params.m, params.n, indices, indptr);
cpu_convert_to_csr(bits_h, params.m, params.n, indices, indptr);
RAFT_EXPECTS(c_true_nnz == c_indices_d.size(),
"Something wrong. The c_true_nnz != c_indices_d.size()!");

update_device(c_data_d.data(), values.data(), c_true_nnz, stream);
update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream);
update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream);
update_device(bitmap_d.data(), bitmap_h.data(), element, stream);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to support both bitmap and bitset inputs but it appears we're removing the bitmap support.

Copy link
Member Author

@rhdong rhdong Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bits is only a naming; the code is needed to be compatible with bitset and bitmap, so I need to change bitmap to bits, as the compatible control point is here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not easy to understand and it’s not obvious which one is selected when true or false (future eyes are going to be confused too) . Let’s create an enum for this that we can share across benchmarks and tests. It’ll make this more straightforward for future eyes too.

update_device(bits_d.data(), bits_h.data(), element, stream);
}

void repeat_cpu_bitset_inplace(std::vector<bits_t>& inout, size_t input_bits, size_t repeat)
{
size_t output_bit_index = input_bits;

for (size_t r = 0; r < repeat; ++r) {
for (size_t i = 0; i < input_bits; ++i) {
size_t input_unit_index = i / (sizeof(bits_t) * 8);
size_t input_bit_offset = i % (sizeof(bits_t) * 8);
bool bit = (inout[input_unit_index] >> input_bit_offset) & 1;

size_t output_unit_index = output_bit_index / (sizeof(bits_t) * 8);
size_t output_bit_offset = output_bit_index % (sizeof(bits_t) * 8);

inout[output_unit_index] |= (static_cast<bits_t>(bit) << output_bit_offset);

++output_bit_index;
}
}
}

index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bitmap_t>& bitmap)
index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bits_t>& bits)
{
index_t total = static_cast<index_t>(m * n);
index_t num_ones = static_cast<index_t>((total * 1.0f) * sparsity);
index_t num_ones = static_cast<index_t>((total * 1.0f) * (1.0f - sparsity));
index_t res = num_ones;

for (auto& item : bitmap) {
item = static_cast<bitmap_t>(0);
if (sparsity == 0.0f) {
std::fill(bits.begin(), bits.end(), 0xffffffff);
return num_ones;
}

for (auto& item : bits) {
item = static_cast<bits_t>(0);
}

std::random_device rd;
Expand All @@ -120,8 +154,8 @@ struct MaskedMatmulBench : public fixture {
while (num_ones > 0) {
index_t index = dis(gen);

bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))];
index_t bit_position = index % (8 * sizeof(bitmap_t));
bits_t& element = bits[index / (8 * sizeof(bits_t))];
index_t bit_position = index % (8 * sizeof(bits_t));

if (((element >> bit_position) & 1) == 0) {
element |= (static_cast<index_t>(1) << bit_position);
Expand All @@ -131,7 +165,7 @@ struct MaskedMatmulBench : public fixture {
return res;
}

void cpu_convert_to_csr(std::vector<bitmap_t>& bitmap,
void cpu_convert_to_csr(std::vector<bits_t>& bits,
index_t rows,
index_t cols,
std::vector<index_t>& indices,
Expand All @@ -142,14 +176,14 @@ struct MaskedMatmulBench : public fixture {
indptr[offset_indptr++] = 0;

index_t index = 0;
bitmap_t element = 0;
bits_t element = 0;
index_t bit_position = 0;

for (index_t i = 0; i < rows; ++i) {
for (index_t j = 0; j < cols; ++j) {
index = i * cols + j;
element = bitmap[index / (8 * sizeof(bitmap_t))];
bit_position = index % (8 * sizeof(bitmap_t));
element = bits[index / (8 * sizeof(bits_t))];
bit_position = index % (8 * sizeof(bits_t));

if (((element >> bit_position) & 1)) {
indices[offset_values] = static_cast<index_t>(j);
Expand Down Expand Up @@ -181,13 +215,17 @@ struct MaskedMatmulBench : public fixture {
params.n,
static_cast<index_t>(c_indices_d.size()));

auto mask =
raft::core::bitmap_view<const bitmap_t, index_t>(bitmap_d.data(), params.m, params.n);

auto c = raft::make_device_csr_matrix_view<value_t>(c_data_d.data(), c_structure);

if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
if (params.sparsity > 0.0) {
if constexpr (bitmap_or_bitset) {
auto mask =
raft::core::bitmap_view<const bits_t, index_t>(bits_d.data(), params.m, params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
auto mask = raft::core::bitset_view<const bits_t, index_t>(bits_d.data(), params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
}
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
Expand All @@ -201,12 +239,16 @@ struct MaskedMatmulBench : public fixture {
}
resource::sync_stream(handle);

raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
resource::sync_stream(handle);

loop_on_state(state, [this, &a, &b, &mask, &c]() {
if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
loop_on_state(state, [this, &a, &b, &c]() {
if (params.sparsity > 0.0) {
if constexpr (bitmap_or_bitset) {
auto mask =
raft::core::bitmap_view<const bits_t, index_t>(bits_d.data(), params.m, params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
auto mask = raft::core::bitset_view<const bits_t, index_t>(bits_d.data(), params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
}
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
Expand All @@ -228,7 +270,7 @@ struct MaskedMatmulBench : public fixture {

rmm::device_uvector<value_t> a_data_d;
rmm::device_uvector<value_t> b_data_d;
rmm::device_uvector<bitmap_t> bitmap_d;
rmm::device_uvector<bits_t> bits_d;

rmm::device_uvector<value_t> c_dense_data_d;

Expand All @@ -253,7 +295,7 @@ static std::vector<MaskedMatmulBenchParams<value_t>> getInputs()
raft::util::itertools::product<TestParams>({size_t(10), size_t(1024)},
{size_t(128), size_t(1024)},
{size_t(1024 * 1024)},
{0.01f, 0.1f, 0.2f, 0.5f, 1.0f});
{0.99f, 0.9f, 0.8f, 0.5f, 0.0f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
Expand All @@ -263,6 +305,7 @@ static std::vector<MaskedMatmulBenchParams<value_t>> getInputs()
return param_vec;
}

RAFT_BENCH_REGISTER((MaskedMatmulBench<float>), "", getInputs<float>());
RAFT_BENCH_REGISTER((MaskedMatmulBench<float, true>), "", getInputs<float>());
RAFT_BENCH_REGISTER((MaskedMatmulBench<float, false>), "", getInputs<float>());

} // namespace raft::bench::linalg
Loading
Loading