Skip to content

Commit

Permalink
Add support for different data type of bitset (#2535)
Browse files Browse the repository at this point in the history
This PR is useful for Milvus.
Previously, a `bitset_view` could only be used with the same element data type that was used to create the bitset. With this change, a `bitset_view` can read a bitset that was created with a different element data type: pass the bit width of the original element type as the `original_nbits` parameter of the class constructor.

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - rhdong (https://github.com/rhdong)

URL: #2535
  • Loading branch information
lowener authored Jan 13, 2025
1 parent 1b62c41 commit 5c826d7
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 21 deletions.
24 changes: 20 additions & 4 deletions cpp/include/raft/core/bitmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
* @param bitmap_ptr Device raw pointer
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
* @param original_nbits Original number of bits used when the bitmap was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitmap was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols), rows_(rows), cols_(cols)
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr,
index_t rows,
index_t cols,
index_t original_nbits = 0)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols, original_nbits),
rows_(rows),
cols_(cols)
{
}

Expand All @@ -65,11 +74,18 @@ struct bitmap_view : public bitset_view<bitmap_t, index_t> {
* @param bitmap_span Device vector view of the bitmap
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
* @param original_nbits Original number of bits used when the bitmap was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitmap was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span,
index_t rows,
index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols), rows_(rows), cols_(cols)
index_t cols,
index_t original_nbits = 0)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols, original_nbits),
rows_(rows),
cols_(cols)
{
}

Expand Down
53 changes: 45 additions & 8 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,41 @@

namespace raft::core {

/**
 * @brief Translate a bit position expressed in elements of `original_nbits` bits into an
 * element index and in-element bit offset for elements of `nbits` bits.
 *
 * `original_nbits` is used as a divisor, so it must be non-zero. When `original_nbits <= nbits`,
 * `nbits / original_nbits` is used as the number of original elements per new element.
 * NOTE(review): presumably one width is always a whole multiple of the other — confirm with
 * callers.
 *
 * @tparam index_t Integer type used for all positions.
 * @param[in]  original_nbits Bit width of the elements the bitset was created with.
 * @param[in]  nbits          Bit width of the elements the bitset is read with.
 * @param[in]  sample_index   Bit position to translate.
 * @param[out] new_bit_index  Index of the `nbits`-wide element to access.
 * @param[out] new_bit_offset Bit offset inside that element.
 */
template <typename index_t>
_RAFT_HOST_DEVICE void inline compute_original_nbits_position(const index_t original_nbits,
                                                              const index_t nbits,
                                                              const index_t sample_index,
                                                              index_t& new_bit_index,
                                                              index_t& new_bit_offset)
{
  // Position of the sample in units of the original element width.
  const index_t src_element = sample_index / original_nbits;
  const index_t src_bit     = sample_index % original_nbits;

  // Index of the first `nbits`-wide element overlapping the original element.
  index_t dst_element = src_element * original_nbits / nbits;
  index_t dst_bit     = 0;

  if (original_nbits <= nbits) {
    // Several original elements share one new element: shift by the slot this one occupies.
    const index_t elements_per_slot = nbits / original_nbits;
    dst_bit = (src_element % elements_per_slot) * original_nbits + src_bit % nbits;
  } else {
    // One original element spans several new elements: step to the one holding the bit.
    dst_element += src_bit / nbits;
    dst_bit = src_bit % nbits;
  }

  new_bit_index  = dst_element;
  new_bit_offset = dst_bit;
}

/**
 * @brief Report whether the bit at `sample_index` is set.
 *
 * When `original_nbits_` is 0 or equals the bit width of `bitset_t`, the element index and the
 * bit offset are taken directly from `sample_index`; otherwise they are obtained from
 * `compute_original_nbits_position`.
 *
 * @param sample_index Index of the bit to read.
 * @return true when the addressed bit is 1, false otherwise.
 */
template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE inline bool bitset_view<bitset_t, index_t>::test(const index_t sample_index) const
{
// NOTE(review): the next three statements look like the pre-change body kept by the diff
// rendering; the statements after them redeclare `bit_index`, `bit_element` and `is_bit_set`.
// Confirm which set is live before building.
const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size];
const index_t bit_index = sample_index % bitset_element_size;
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
// Bit width of the element type this view reads with.
const index_t nbits = sizeof(bitset_t) * 8;
index_t bit_index = 0;
index_t bit_offset = 0;
// 0 means "no original width recorded"; it is handled like a matching width.
if (original_nbits_ == 0 || nbits == original_nbits_) {
bit_index = sample_index / bitset_element_size;
bit_offset = sample_index % bitset_element_size;
} else {
// Widths differ: translate the position from the original element width to this one.
compute_original_nbits_position(original_nbits_, nbits, sample_index, bit_index, bit_offset);
}
const bitset_t bit_element = bitset_ptr_[bit_index];
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_offset)) != 0;
return is_bit_set;
}

Expand All @@ -51,14 +80,22 @@ template <typename bitset_t, typename index_t>
/**
 * @brief Set or clear the bit at `sample_index` using `atomicOr` / `atomicAnd` on the element
 * that holds it.
 *
 * When `original_nbits_` is 0 or equals the bit width of `bitset_t`, the element index and the
 * bit offset are taken directly from `sample_index`; otherwise they are obtained from
 * `compute_original_nbits_position`.
 *
 * @param sample_index Index of the bit to modify.
 * @param set_value    true to set the bit to 1, false to clear it to 0.
 */
_RAFT_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index,
bool set_value) const
{
// NOTE(review): the next three statements look like the pre-change body kept by the diff
// rendering; `bit_index` and `bitmask` are declared again below. Confirm which set is live.
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
// Bit width of the element type this view writes with.
const index_t nbits = sizeof(bitset_t) * 8;
index_t bit_index = 0;
index_t bit_offset = 0;

// 0 means "no original width recorded"; it is handled like a matching width.
if (original_nbits_ == 0 || nbits == original_nbits_) {
bit_index = sample_index / bitset_element_size;
bit_offset = sample_index % bitset_element_size;
} else {
// Widths differ: translate the position from the original element width to this one.
compute_original_nbits_position(original_nbits_, nbits, sample_index, bit_index, bit_offset);
}
// Mask with only the target bit set.
const bitset_t bitmask = bitset_t{1} << bit_offset;
if (set_value) {
// NOTE(review): two atomicOr calls appear here; the first (using `bit_element`) looks like the
// pre-change line and the second (using `bit_index`) its replacement.
atomicOr(bitset_ptr_ + bit_element, bitmask);
atomicOr(bitset_ptr_ + bit_index, bitmask);
} else {
// Clearing: AND with the complement of the mask leaves every other bit untouched.
const bitset_t bitmask2 = ~bitmask;
// NOTE(review): as above, the first atomicAnd looks like the pre-change line.
atomicAnd(bitset_ptr_ + bit_element, bitmask2);
atomicAnd(bitset_ptr_ + bit_index, bitmask2);
}
}

Expand Down
34 changes: 30 additions & 4 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,38 @@ template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset_view {
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
/**
* @brief Create a bitset view from a device pointer to the bitset.
*
* @param bitset_ptr Device pointer to the bitset
* @param bitset_len Number of bits in the bitset
* @param original_nbits Original number of bits used when the bitset was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitset was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr,
index_t bitset_len,
index_t original_nbits = 0)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}, original_nbits_{original_nbits}
{
}
/**
* @brief Create a bitset view from a device vector view of the bitset.
*
* @param bitset_span Device vector view of the bitset
* @param bitset_len Number of bits in the bitset
* @param original_nbits Original number of bits used when the bitset was created, to handle
* potential mismatches of data types. This is useful for using ANN indexes when a bitset was
* originally created with a different data type than the ones currently supported in cuVS ANN
* indexes.
*/
_RAFT_HOST_DEVICE bitset_view(raft::device_vector_view<bitset_t, index_t> bitset_span,
index_t bitset_len)
: bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len}
index_t bitset_len,
index_t original_nbits = 0)
: bitset_ptr_{bitset_span.data_handle()},
bitset_len_{bitset_len},
original_nbits_{original_nbits}
{
}
/**
Expand Down Expand Up @@ -180,9 +199,16 @@ struct bitset_view {
return (bitset_len + bits_per_element - 1) / bits_per_element;
}

/**
* @brief Get the original number of bits of the bitset.
*/
auto get_original_nbits() const -> index_t { return original_nbits_; }
void set_original_nbits(index_t original_nbits) { original_nbits_ = original_nbits; }

private:
bitset_t* bitset_ptr_;
index_t bitset_len_;
index_t original_nbits_;
};

/**
Expand Down
98 changes: 93 additions & 5 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <numeric>

namespace raft::core {
Expand Down Expand Up @@ -73,6 +75,40 @@ void test_cpu_bitset(const std::vector<bitset_t>& bitset,
}
}

/**
 * @brief CPU reference: read bits from a bitset through elements of type `bitset_t` when the
 * bit positions are expressed in elements of `original_nbits_` bits.
 *
 * For every `queries[i]` the function writes 1 into `result[i]` when the addressed bit is set
 * and 0 otherwise. `result` must already hold at least `queries.size()` entries; it is not
 * resized here.
 *
 * `original_nbits_ == 0` is treated like "same width as `bitset_t`", mirroring the
 * `original_nbits_ == 0` convention of `bitset_view::test`; it previously caused a division by
 * zero. The same-width case now returns right after the direct lookup instead of falling through
 * and recomputing every result in the general loop.
 *
 * @tparam bitset_t Element type the bitset is read with.
 * @tparam index_t  Integer type of the queried bit positions.
 * @param[in]  bitset          Pointer to the first element of the bitset.
 * @param[in]  queries         Bit positions to test.
 * @param[out] result          Receives one 0/1 value per query.
 * @param[in]  original_nbits_ Bit width of the elements the bitset was created with (0 = same
 *                             as `bitset_t`).
 */
template <typename bitset_t, typename index_t>
void test_cpu_bitset_nbits(const bitset_t* bitset,
                           const std::vector<index_t>& queries,
                           std::vector<uint8_t>& result,
                           unsigned original_nbits_)
{
  constexpr size_t nbits = sizeof(bitset_t) * 8;

  // Same width (or none recorded): index the elements directly and stop.
  if (original_nbits_ == 0 || original_nbits_ == nbits) {
    for (size_t i = 0; i < queries.size(); i++) {
      result[i] =
        uint8_t((bitset[queries[i] / nbits] & (bitset_t{1} << (queries[i] % nbits))) != 0);
    }
    return;
  }

  for (size_t i = 0; i < queries.size(); i++) {
    // Position of the sample in units of the original element width.
    const index_t sample_index        = queries[i];
    const index_t original_bit_index  = sample_index / original_nbits_;
    const index_t original_bit_offset = sample_index % original_nbits_;

    // First `nbits`-wide element overlapping the original element.
    index_t new_bit_index  = original_bit_index * original_nbits_ / nbits;
    index_t new_bit_offset = 0;
    if (original_nbits_ > nbits) {
      // One original element spans several new elements: step to the one holding the bit.
      new_bit_index += original_bit_offset / nbits;
      new_bit_offset = original_bit_offset % nbits;
    } else {
      // Several original elements share one new element: shift by the occupied slot.
      index_t ratio = nbits / original_nbits_;
      new_bit_offset += (original_bit_index % ratio) * original_nbits_;
      new_bit_offset += original_bit_offset % nbits;
    }

    const bitset_t bit_element = bitset[new_bit_index];
    const bool is_bit_set      = (bit_element & (bitset_t{1} << new_bit_offset)) != 0;

    result[i] = uint8_t(is_bit_set);
  }
}

template <typename bitset_t>
void flip_cpu_bitset(std::vector<bitset_t>& bitset)
{
Expand Down Expand Up @@ -168,11 +204,12 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

auto query_device = raft::make_device_vector<index_t, index_t>(res, spec.query_len);
auto result_device = raft::make_device_vector<uint8_t, index_t>(res, spec.query_len);
auto query_cpu = std::vector<index_t>(spec.query_len);
auto result_cpu = std::vector<uint8_t>(spec.query_len);
auto result_ref = std::vector<uint8_t>(spec.query_len);
auto query_device = raft::make_device_vector<index_t, index_t>(res, spec.query_len);
auto result_device = raft::make_device_vector<uint8_t, index_t>(res, spec.query_len);
auto query_cpu = std::vector<index_t>(spec.query_len);
auto result_cpu = std::vector<uint8_t>(spec.query_len);
auto result_ref_nbits = std::vector<uint8_t>(spec.query_len);
auto result_ref = std::vector<uint8_t>(spec.query_len);

// Create queries and verify the test results
raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len));
Expand All @@ -194,6 +231,57 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Reinterpret the bitset as uint8_t, uint32_t, then uint64_t
{
// Test CPU logic
test_cpu_bitset(bitset_ref, query_cpu, result_ref);
uint8_t* bitset_cpu_uint8 = (uint8_t*)std::malloc(sizeof(bitset_t) * bitset_ref.size());
std::memcpy(bitset_cpu_uint8, bitset_ref.data(), sizeof(bitset_t) * bitset_ref.size());
test_cpu_bitset_nbits(bitset_cpu_uint8, query_cpu, result_ref_nbits, sizeof(bitset_t) * 8);
ASSERT_TRUE(hostVecMatch(result_ref, result_ref_nbits, raft::Compare<uint8_t>()));
std::free(bitset_cpu_uint8);

// Test GPU uint8_t, uint32_t, uint64_t
auto my_bitset_view_uint8_t = raft::core::bitset_view<uint8_t, uint32_t>(
reinterpret_cast<uint8_t*>(my_bitset.data()), my_bitset.size(), sizeof(bitset_t) * 8);
raft::linalg::map(
res,
result_device.view(),
[my_bitset_view_uint8_t] __device__(index_t query) {
return my_bitset_view_uint8_t.test(query);
},
raft::make_const_mdspan(query_device.view()));
update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(result_ref, result_cpu, Compare<uint8_t>()));

auto my_bitset_view_uint32_t = raft::core::bitset_view<uint32_t, uint32_t>(
reinterpret_cast<uint32_t*>(my_bitset.data()), my_bitset.size(), sizeof(bitset_t) * 8);
raft::linalg::map(
res,
result_device.view(),
[my_bitset_view_uint32_t] __device__(index_t query) {
return my_bitset_view_uint32_t.test(query);
},
raft::make_const_mdspan(query_device.view()));
update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(result_ref, result_cpu, Compare<uint8_t>()));

auto my_bitset_view_uint64_t = raft::core::bitset_view<uint64_t, uint32_t>(
reinterpret_cast<uint64_t*>(my_bitset.data()), my_bitset.size(), sizeof(bitset_t) * 8);
raft::linalg::map(
res,
result_device.view(),
[my_bitset_view_uint64_t] __device__(index_t query) {
return my_bitset_view_uint64_t.test(query);
},
raft::make_const_mdspan(query_device.view()));
update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(result_ref, result_cpu, Compare<uint8_t>()));
}

// test sparsity, repeat and eval_n_elements
{
auto my_bitset_view = my_bitset.view();
Expand Down

0 comments on commit 5c826d7

Please sign in to comment.