Skip to content

Commit

Permalink
Add operators
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Oct 6, 2023
1 parent 8e7bb87 commit 82a7575
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 73 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/prims/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct bitset_bench : public fixture {
loop_on_state(state, [this]() {
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
this->res, raft::make_const_mdspan(mask.view()), params.bitset_len);
my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view());
my_bitset.test(raft::make_const_mdspan(queries.view()), outputs.view());
});
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct CagraBench : public fixture {
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
removed_indices_bitset_.set(handle, removed_indices.view());
removed_indices_bitset_.set(removed_indices.view());
index_.emplace(raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())));
}
Expand Down
143 changes: 86 additions & 57 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,11 @@ struct bitset {
bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len}
bitset_len_{bitset_len},
res_{res}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
set(res, mask_index, !default_value);
reset(default_value);
set(mask_index, !default_value);
}

/**
Expand All @@ -180,12 +178,10 @@ struct bitset {
bitset(const raft::resources& res, index_t bitset_len, bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
resource::get_cuda_stream(res)},
bitset_len_{bitset_len}
bitset_len_{bitset_len},
res_{res}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
reset(default_value);
}
// Disable copy constructor
bitset(const bitset&) = delete;
Expand Down Expand Up @@ -237,18 +233,18 @@ struct bitset {

/** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to
* the default value. */
void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true)
void resize(index_t new_bitset_len, bool default_value = true)
{
auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
bitset_.resize(new_size);
bitset_len_ = new_bitset_len;
if (old_size < new_size) {
// If the new size is larger, set the new bits to the default value
cudaMemsetAsync(bitset_.data() + old_size,
default_value ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res));
RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data() + old_size,
default_value ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res_)));
}
}

Expand All @@ -261,14 +257,13 @@ struct bitset {
* @param output List of outputs
*/
template <typename output_t = bool>
void test(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> queries,
void test(raft::device_vector_view<const index_t, index_t> queries,
raft::device_vector_view<output_t, index_t> output) const
{
RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size");
auto bitset_view = view();
raft::linalg::map(
res,
res_,
output,
[bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); },
queries);
Expand All @@ -280,12 +275,10 @@ struct bitset {
* @param mask_index indices to remove from the bitset
* @param set_value Value to set the bits to (true or false)
*/
void set(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
bool set_value = false)
void set(raft::device_vector_view<const index_t, index_t> mask_index, bool set_value = false)
{
auto this_bitset_view = view();
thrust::for_each_n(resource::get_thrust_policy(res),
thrust::for_each_n(resource::get_thrust_policy(res_),
mask_index.data_handle(),
mask_index.extent(0),
[this_bitset_view, set_value] __device__(const index_t sample_index) {
Expand All @@ -294,38 +287,34 @@ struct bitset {
}
/**
* @brief Flip all the bits in a bitset.
*
* @param res RAFT resources
*/
void flip(const raft::resources& res)
void flip()
{
auto bitset_span = this->to_mdspan();
raft::linalg::map(
res,
res_,
bitset_span,
[] __device__(bitset_t element) { return bitset_t(~element); },
raft::make_const_mdspan(bitset_span));
}
/**
* @brief Reset the bits in a bitset.
*
* @param res RAFT resources
* @param default_value Value to set the bits to (true or false)
*/
void reset(const raft::resources& res, bool default_value = true)
void reset(bool default_value = true)
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res_)));
}
/**
* @brief Counts the bits set to true and stores the result in count_gpu_scalar.
*
* @param[in] res RAFT resources
* @param[out] count_gpu_scalar Device scalar to store the count
*/
void count(const raft::resources& res, raft::device_scalar_view<index_t> count_gpu_scalar)
void count(raft::device_scalar_view<index_t> count_gpu_scalar)
{
auto n_elements_ = n_elements();
auto count_gpu =
Expand All @@ -337,66 +326,106 @@ struct bitset {
bitset_t last_element_mask =
n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0};
raft::linalg::coalesced_reduction(
res,
res_,
bitset_matrix_view,
count_gpu,
index_t{0},
false,
[last_element_mask, n_elements_] __device__(bitset_t element, index_t index) {
index_t res = 0;
index_t result = 0;
if constexpr (bitset_element_size == 64) { // Needed because __popc doesn't support 64bit
if (index == n_elements_ - 1)
res = index_t(raft::detail::native_popc<uint64_t>(element & last_element_mask));
result = index_t(raft::detail::native_popc<uint64_t>(element & last_element_mask));
else
res = index_t(raft::detail::native_popc<uint64_t>(element));
result = index_t(raft::detail::native_popc<uint64_t>(element));
} else {
if (index == n_elements_ - 1)
res = index_t(__popc(element & last_element_mask));
result = index_t(__popc(element & last_element_mask));
else
res = index_t(__popc(element));
result = index_t(__popc(element));
}

return res;
return result;
});
}
/**
* @brief Returns the number of bits set to true.
*
* @param[in] res RAFT resources
* @return index_t Number of bits set to true
*/
auto count(const raft::resources& res) -> index_t
auto count() -> index_t
{
auto count_gpu_scalar = raft::make_device_scalar<index_t>(res, 0.0);
count(res, count_gpu_scalar.view());
auto count_gpu_scalar = raft::make_device_scalar<index_t>(res_, 0.0);
count(count_gpu_scalar.view());
index_t count_cpu = 0;
raft::update_host(
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
resource::sync_stream(res);
&count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res_));
resource::sync_stream(res_);
return count_cpu;
}
/**
* @brief Checks if any of the bits are set to true in the bitset.
*
* @param res RAFT resources
*/
bool any(const raft::resources& res) { return count(res) > 0; }
bool any() { return count() > 0; }
/**
* @brief Checks if all of the bits are set to true in the bitset.
*
* @param res RAFT resources
*/
bool all(const raft::resources& res) { return count(res) == bitset_len_; }
bool all() { return count() == bitset_len_; }
/**
* @brief Checks if none of the bits are set to true in the bitset.
*
* @param res RAFT resources
*/
bool none(const raft::resources& res) { return count(res) == 0; }
bool none() { return count() == 0; }

/**
 * @brief In-place bitwise OR of this bitset with another bitset.
 *
 * Every storage element of this bitset is replaced by (this | other). The work is submitted
 * through raft::linalg::map using the resources held in res_.
 *
 * @param other Bitset to combine with; the RAFT_EXPECTS check fails unless its size() equals
 * this bitset's size()
 * @return Reference to this bitset
 */
bitset<bitset_t, index_t>& operator|=(const bitset<bitset_t, index_t>& other)
{
  RAFT_EXPECTS(size() == other.size(), "Sizes must be equal");
  auto this_span  = to_mdspan();
  auto other_span = other.to_mdspan();
  raft::linalg::map(
    res_,
    this_span,
    [] __device__(bitset_t this_element, bitset_t other_element) {
      // Bitwise operators promote narrow integer types to int; cast back so the functor
      // returns bitset_t.
      return static_cast<bitset_t>(this_element | other_element);
    },
    raft::make_const_mdspan(this_span),
    other_span);
  return *this;
}
/**
 * @brief In-place bitwise AND of this bitset with another bitset.
 *
 * Every storage element of this bitset is replaced by (this & other). The work is submitted
 * through raft::linalg::map using the resources held in res_.
 *
 * @param other Bitset to combine with; the RAFT_EXPECTS check fails unless its size() equals
 * this bitset's size()
 * @return Reference to this bitset
 */
bitset<bitset_t, index_t>& operator&=(const bitset<bitset_t, index_t>& other)
{
  RAFT_EXPECTS(size() == other.size(), "Sizes must be equal");
  auto this_span  = to_mdspan();
  auto other_span = other.to_mdspan();
  raft::linalg::map(
    res_,
    this_span,
    [] __device__(bitset_t this_element, bitset_t other_element) {
      // Bitwise operators promote narrow integer types to int; cast back so the functor
      // returns bitset_t.
      return static_cast<bitset_t>(this_element & other_element);
    },
    raft::make_const_mdspan(this_span),
    other_span);
  return *this;
}
/**
 * @brief In-place bitwise XOR of this bitset with another bitset.
 *
 * Every storage element of this bitset is replaced by (this ^ other). The work is submitted
 * through raft::linalg::map using the resources held in res_.
 *
 * @param other Bitset to combine with; the RAFT_EXPECTS check fails unless its size() equals
 * this bitset's size()
 * @return Reference to this bitset
 */
bitset<bitset_t, index_t>& operator^=(const bitset<bitset_t, index_t>& other)
{
  RAFT_EXPECTS(size() == other.size(), "Sizes must be equal");
  auto this_span  = to_mdspan();
  auto other_span = other.to_mdspan();
  raft::linalg::map(
    res_,
    this_span,
    [] __device__(bitset_t this_element, bitset_t other_element) {
      // Bitwise operators promote narrow integer types to int; cast back so the functor
      // returns bitset_t.
      return static_cast<bitset_t>(this_element ^ other_element);
    },
    raft::make_const_mdspan(this_span),
    other_span);
  return *this;
}

private:
raft::device_uvector<bitset_t> bitset_;
index_t bitset_len_;
const raft::resources& res_;
};

/** @} */
Expand Down
43 changes: 29 additions & 14 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
// Create queries and verify the test results
raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len));
update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream);
my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view());
my_bitset.test(raft::make_const_mdspan(query_device.view()), result_device.view());
update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream);
test_cpu_bitset(bitset_ref, query_cpu, result_ref);
resource::sync_stream(res, stream);
Expand All @@ -138,33 +138,48 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len));
update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream);
resource::sync_stream(res, stream);
my_bitset.set(res, mask_device.view());
my_bitset.set(mask_device.view());
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);

add_cpu_bitset(bitset_ref, mask_cpu);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Flip the bitset and re-test
auto bitset_count = my_bitset.count(res);
my_bitset.flip(res);
ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count);
auto bitset_count = my_bitset.count();
my_bitset.flip();
ASSERT_EQ(my_bitset.count(), spec.bitset_len - bitset_count);
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
flip_cpu_bitset(bitset_ref);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

my_bitset.reset(res, false);
ASSERT_EQ(my_bitset.any(res), false);
ASSERT_EQ(my_bitset.none(res), true);
// Test count() operations
my_bitset.reset(false);
ASSERT_EQ(my_bitset.any(), false);
ASSERT_EQ(my_bitset.none(), true);
raft::linalg::range(query_device.data_handle(), query_device.size(), stream);
my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true);
bitset_count = my_bitset.count(res);
my_bitset.set(raft::make_const_mdspan(query_device.view()), true);
bitset_count = my_bitset.count();
ASSERT_EQ(bitset_count, query_device.size());
ASSERT_EQ(my_bitset.any(res), true);
ASSERT_EQ(my_bitset.none(res), false);

ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));
ASSERT_EQ(my_bitset.any(), true);
ASSERT_EQ(my_bitset.none(), false);

// Test operators
auto my_bitset_2 = raft::core::bitset<bitset_t, index_t>(
res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false);
auto my_bitset_3 = raft::core::bitset<bitset_t, index_t>(
res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false);
my_bitset_2 ^= my_bitset;
ASSERT_FALSE(devArrMatch(my_bitset_2.data_handle(),
my_bitset_3.data_handle(),
my_bitset.n_elements(),
raft::Compare<bitset_t>()));
my_bitset_2 ^= my_bitset;
ASSERT_TRUE(devArrMatch(my_bitset_2.data_handle(),
my_bitset_3.data_handle(),
my_bitset.n_elements(),
raft::Compare<bitset_t>()));
}
};

Expand Down

0 comments on commit 82a7575

Please sign in to comment.