diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh index 622b911a8f..13da363fd5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh @@ -103,11 +103,11 @@ struct bitfield_view_t { * type: void (uint8_t code, uint32_t out_ix, uint32_t j), where j = [0..pq_dim). */ template -__device__ void run_on_vector(std::variant in_list_data, - uint32_t in_ix, - uint32_t out_ix, - uint32_t pq_dim, - Action action) +__host__ __device__ void run_on_vector(std::variant in_list_data, + uint32_t in_ix, + uint32_t out_ix, + uint32_t pq_dim, + Action action) { using group_align = Pow2; const uint32_t group_ix = group_align::div(in_ix); diff --git a/cpp/include/raft/neighbors/ivf_pq_codepacker.cuh b/cpp/include/raft/neighbors/ivf_pq_codepacker.cuh index 9562a60552..116ced0067 100644 --- a/cpp/include/raft/neighbors/ivf_pq_codepacker.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_codepacker.cuh @@ -20,45 +20,67 @@ #include #include #include -#include -#include #include +#include +#include namespace raft::neighbors::ivf_pq::codepacker { -template -void unpack_1( - const uint8_t* block, uint8_t* flat_code, uint32_t pq_dim, uint32_t offset) -{ - RAFT_EXPECTS(PqBits == 8, "host codepacker supports only PqBits == 8"); - using group_align = neighbors::detail::div_utils; - const uint32_t group_ix = group_align::div(offset); - const uint32_t ingroup_ix = group_align::mod(offset); +/** + * A producer for the `write_vector` that reads the codes byte-by-byte. That is, + * independent of the code width (pq_bits), one code uses the whole byte, hence + * one vector uses pq_dim bytes. + */ +struct pass_1_action { + const uint8_t* flat_code; - const uint32_t fixed_offset = group_ix * kIndexGroupSize * pq_dim + ingroup_ix * kIndexGroupVecLen; + /** + * Create a callable to be passed to `write_vector`. 
+ * + * @param[in] flat_code flat PQ codes (one byte per code) of a single vector. + */ + __host__ __device__ inline pass_1_action(const uint8_t* flat_code) : flat_code{flat_code} {} - for (uint32_t j = 0; j < pq_dim; j += kIndexGroupVecLen) { - std::memcpy(flat_code + j, block + fixed_offset + (j / kIndexGroupVecLen) * pq_dim, kIndexGroupVecLen); + /** Read j-th component (code) of the i-th vector from the source. */ + __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t + { + return flat_code[j]; } -} +}; + +/** + * A consumer for the `run_on_vector` that just flattens PQ codes + * one-per-byte. That is, independent of the code width (pq_bits), one code uses + * the whole byte, hence one vector uses pq_dim bytes. + */ +struct unpack_1_action { + uint8_t* out_flat_code; + + /** + * Create a callable to be passed to `run_on_vector`. + * + * @param[out] out_flat_code the destination for the read PQ codes of a single vector. + */ + __host__ __device__ inline unpack_1_action(uint8_t* out_flat_code) : out_flat_code{out_flat_code} + { + } + + /** Write j-th component (code) of the i-th vector into the output array. 
*/ + __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) + { + out_flat_code[j] = code; + } +}; template -void pack_1( - const uint8_t* flat_code, - uint8_t* block, - uint32_t pq_dim, - uint32_t offset) +void unpack_1(const uint8_t* block, uint8_t* flat_code, uint32_t pq_dim, uint32_t offset) { - RAFT_EXPECTS(PqBits == 8, "host codepacker supports only PqBits == 8"); - using group_align = neighbors::detail::div_utils; - const uint32_t group_ix = group_align::div(offset); - const uint32_t ingroup_ix = group_align::mod(offset); - - const uint32_t fixed_offset = group_ix * kIndexGroupSize * pq_dim + ingroup_ix * kIndexGroupVecLen; + ivf_pq::detail::run_on_vector(block, offset, 0, pq_dim, unpack_1_action{flat_code}); +} - for (uint32_t j = 0; j < pq_dim; j += kIndexGroupVecLen) { - size_t bytes = min(pq_dim - j, kIndexGroupVecLen); - std::memcpy(block + fixed_offset + (j / kIndexGroupVecLen) * pq_dim, flat_code + j, bytes); - } +template +void pack_1(const uint8_t* flat_code, uint8_t* block, uint32_t pq_dim, uint32_t offset) +{ + ivf_pq::detail::write_vector(block, offset, 0, pq_dim, pass_1_action{flat_code}); } -} // namespace raft::neighbors::ivf_flat::codepacker \ No newline at end of file +} // namespace raft::neighbors::ivf_pq::codepacker \ No newline at end of file diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index be04d0a686..e03d09ae50 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -17,7 +17,6 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" -#include "raft/core/logger-macros.hpp" #include #include @@ -387,7 +386,6 @@ class ivf_pq_test : public ::testing::TestWithParam { check_reconstruct_extend(&index, compression_ratio, label); } break; case 1: { - RAFT_LOG_INFO("------PACKING------"); // Dump and re-write codes for one label check_packing(&index, label); } break;