Skip to content

Commit

Permalink
updated CP
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Sep 22, 2023
1 parent 5b2a7e0 commit 4139c7e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 37 deletions.
10 changes: 5 additions & 5 deletions cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ struct bitfield_view_t {
* type: void (uint8_t code, uint32_t out_ix, uint32_t j), where j = [0..pq_dim).
*/
template <uint32_t PqBits, typename Action>
__device__ void run_on_vector(std::variant<const uint8_t*, list_const_view> in_list_data,
uint32_t in_ix,
uint32_t out_ix,
uint32_t pq_dim,
Action action)
__host__ __device__ void run_on_vector(std::variant<const uint8_t*, list_const_view> in_list_data,
uint32_t in_ix,
uint32_t out_ix,
uint32_t pq_dim,
Action action)
{
using group_align = Pow2<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(in_ix);
Expand Down
82 changes: 52 additions & 30 deletions cpp/include/raft/neighbors/ivf_pq_codepacker.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,45 +20,67 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/neighbors/detail/ivf_pq_build.cuh>
#include <raft/neighbors/detail/div_utils.hpp>
#include <raft/neighbors/detail/ivf_pq_codepacking.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>

namespace raft::neighbors::ivf_pq::codepacker {

template <uint32_t PqBits>
void unpack_1(
const uint8_t* block, uint8_t* flat_code, uint32_t pq_dim, uint32_t offset)
{
RAFT_EXPECTS(PqBits == 8, "host codepacker supports only PqBits == 8");
using group_align = neighbors::detail::div_utils<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(offset);
const uint32_t ingroup_ix = group_align::mod(offset);
/**
* A producer for `write_vector` that reads the codes byte-by-byte. That is,
* independent of the code width (pq_bits), one code uses the whole byte, hence
* one vector uses pq_dim bytes.
*/
struct pass_1_action {
const uint8_t* flat_code;

const uint32_t fixed_offset = group_ix * kIndexGroupSize * pq_dim + ingroup_ix * kIndexGroupVecLen;
/**
* Create a callable to be passed to `write_vector`.
*
* @param[in] flat_code flat PQ codes (one byte per code) of a single vector.
*/
__host__ __device__ inline pass_1_action(const uint8_t* flat_code) : flat_code{flat_code} {}

for (uint32_t j = 0; j < pq_dim; j += kIndexGroupVecLen) {
std::memcpy(flat_code + j, block + fixed_offset + (j / kIndexGroupVecLen) * pq_dim, kIndexGroupVecLen);
/** Read j-th component (code) of the i-th vector from the source. */
__host__ __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t
{
return flat_code[j];
}
}
};

/**
 * Consumer callable for `run_on_vector`: stores each PQ code it receives into
 * its own byte of a flat output buffer. Regardless of the code width (pq_bits),
 * every code occupies a whole byte, so one vector takes pq_dim bytes.
 */
struct unpack_1_action {
  // Destination buffer; component j of the vector is stored at out_flat_code[j].
  uint8_t* out_flat_code;

  /**
   * Build the callable handed to `run_on_vector`.
   *
   * @param[out] dst buffer that receives the PQ codes of a single vector.
   */
  __host__ __device__ inline unpack_1_action(uint8_t* dst) : out_flat_code(dst) {}

  /**
   * Store the j-th code of the vector into the output buffer.
   * The vector index `i` is accepted to match the expected action signature
   * but is not used: only one vector is written.
   */
  __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j)
  {
    uint8_t* slot = out_flat_code + j;
    *slot         = code;
  }
};

template <uint32_t PqBits>
void pack_1(
const uint8_t* flat_code,
uint8_t* block,
uint32_t pq_dim,
uint32_t offset)
void unpack_1(const uint8_t* block, uint8_t* flat_code, uint32_t pq_dim, uint32_t offset)
{
RAFT_EXPECTS(PqBits == 8, "host codepacker supports only PqBits == 8");
using group_align = neighbors::detail::div_utils<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(offset);
const uint32_t ingroup_ix = group_align::mod(offset);

const uint32_t fixed_offset = group_ix * kIndexGroupSize * pq_dim + ingroup_ix * kIndexGroupVecLen;
ivf_pq::detail::run_on_vector<PqBits>(block, offset, 0, pq_dim, unpack_1_action{flat_code});
}

for (uint32_t j = 0; j < pq_dim; j += kIndexGroupVecLen) {
size_t bytes = min(pq_dim - j, kIndexGroupVecLen);
std::memcpy(block + fixed_offset + (j / kIndexGroupVecLen) * pq_dim, flat_code + j, bytes);
}
template <uint32_t PqBits>
void pack_1(const uint8_t* flat_code, uint8_t* block, uint32_t pq_dim, uint32_t offset)
{
ivf_pq::detail::write_vector<PqBits>(block, offset, 0, pq_dim, pass_1_action{flat_code});
}
} // namespace raft::neighbors::ivf_flat::codepacker
} // namespace raft::neighbors::ivf_pq::codepacker
2 changes: 0 additions & 2 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

#include "../test_utils.cuh"
#include "ann_utils.cuh"
#include "raft/core/logger-macros.hpp"
#include <raft/core/resource/cuda_stream.hpp>

#include <raft_internal/neighbors/naive_knn.cuh>
Expand Down Expand Up @@ -387,7 +386,6 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
check_reconstruct_extend(&index, compression_ratio, label);
} break;
case 1: {
RAFT_LOG_INFO("------PACKING------");
// Dump and re-write codes for one label
check_packing(&index, label);
} break;
Expand Down

0 comments on commit 4139c7e

Please sign in to comment.