Skip to content

Commit

Permalink
Fix popc, data() and cuda calls as suggested in reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Oct 18, 2023
1 parent 22a9559 commit d59d08b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 29 deletions.
37 changes: 19 additions & 18 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ struct bitset_view {
/**
* @brief Get the device pointer to the bitset.
*/
inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data() -> bitset_t* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto data() const -> const bitset_t* { return bitset_ptr_; }
/**
* @brief Get the number of bits of the bitset representation.
*/
Expand Down Expand Up @@ -206,8 +206,8 @@ struct bitset {
/**
* @brief Get the device pointer to the bitset.
*/
inline auto data_handle() -> bitset_t* { return bitset_.data(); }
inline auto data_handle() const -> const bitset_t* { return bitset_.data(); }
inline auto data() -> bitset_t* { return bitset_.data(); }
inline auto data() const -> const bitset_t* { return bitset_.data(); }
/**
* @brief Get the number of bits of the bitset representation.
*/
Expand Down Expand Up @@ -241,10 +241,11 @@ struct bitset {
bitset_len_ = new_bitset_len;
if (old_size < new_size) {
// If the new size is larger, set the new bits to the default value
RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data() + old_size,
default_value ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res_)));

thrust::fill_n(resource::get_thrust_policy(res_),
bitset_.data() + old_size,
new_size - old_size,
default_value ? ~bitset_t{0} : bitset_t{0});
}
}

Expand Down Expand Up @@ -302,10 +303,10 @@ struct bitset {
*/
void reset(bool default_value = true)
{
RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res_)));
thrust::fill_n(resource::get_thrust_policy(res_),
bitset_.data(),
n_elements(),
default_value ? ~bitset_t{0} : bitset_t{0});
}
/**
* @brief Computes the number of bits set to true and returns it in count_gpu_scalar.
Expand All @@ -331,16 +332,16 @@ struct bitset {
false,
[last_element_mask, n_elements_] __device__(bitset_t element, index_t index) {
index_t result = 0;
if constexpr (bitset_element_size == 64) { // Needed because __popc doesn't support 64bit
if constexpr (bitset_element_size == 64) {
if (index == n_elements_ - 1)
result = index_t(raft::detail::native_popc<uint64_t>(element & last_element_mask));
result = index_t(raft::detail::popc(element & last_element_mask));
else
result = index_t(raft::detail::native_popc<uint64_t>(element));
} else {
result = index_t(raft::detail::popc(element));
} else { // Needed because popc is not overloaded for 16 and 8 bit elements
if (index == n_elements_ - 1)
result = index_t(__popc(element & last_element_mask));
result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask));
else
result = index_t(__popc(element));
result = index_t(raft::detail::popc(uint32_t{element}));
}

return result;
Expand Down
18 changes: 7 additions & 11 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
// calculate the results
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len));
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);

// calculate the reference
create_cpu_bitset(bitset_ref, mask_cpu);
Expand All @@ -139,7 +139,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream);
resource::sync_stream(res, stream);
my_bitset.set(mask_device.view());
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);

add_cpu_bitset(bitset_ref, mask_cpu);
resource::sync_stream(res, stream);
Expand All @@ -149,7 +149,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
auto bitset_count = my_bitset.count();
my_bitset.flip();
ASSERT_EQ(my_bitset.count(), spec.bitset_len - bitset_count);
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream);
flip_cpu_bitset(bitset_ref);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));
Expand All @@ -171,15 +171,11 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
auto my_bitset_3 = raft::core::bitset<bitset_t, index_t>(
res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false);
my_bitset_2 ^= my_bitset;
ASSERT_FALSE(devArrMatch(my_bitset_2.data_handle(),
my_bitset_3.data_handle(),
my_bitset.n_elements(),
raft::Compare<bitset_t>()));
ASSERT_FALSE(devArrMatch(
my_bitset_2.data(), my_bitset_3.data(), my_bitset.n_elements(), raft::Compare<bitset_t>()));
my_bitset_2 ^= my_bitset;
ASSERT_TRUE(devArrMatch(my_bitset_2.data_handle(),
my_bitset_3.data_handle(),
my_bitset.n_elements(),
raft::Compare<bitset_t>()));
ASSERT_TRUE(devArrMatch(
my_bitset_2.data(), my_bitset_3.data(), my_bitset.n_elements(), raft::Compare<bitset_t>()));
}
};

Expand Down

0 comments on commit d59d08b

Please sign in to comment.