Skip to content

Commit

Permalink
change helpers name to contiguous
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Nov 17, 2023
1 parent a2d4575 commit 1efd28f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 46 deletions.
48 changes: 24 additions & 24 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ void unpack_list_data(raft::resources const& res,
* into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte.
*/
template <uint32_t PqBits>
struct unpack_compressed {
struct unpack_contiguous {
uint8_t* codes;
uint32_t code_size;

Expand All @@ -680,7 +680,7 @@ struct unpack_compressed {
*
* @param[in] codes flat compressed PQ codes
*/
__host__ __device__ inline unpack_compressed(uint8_t* codes, uint32_t pq_dim)
__host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim)
: codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
{
}
Expand All @@ -694,15 +694,15 @@ struct unpack_compressed {
};

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL unpack_compressed_list_data_kernel(
__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel(
uint8_t* out_codes,
device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> in_list_data,
uint32_t n_rows,
uint32_t pq_dim,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
run_on_list<PqBits>(
in_list_data, offset_or_indices, n_rows, pq_dim, unpack_compressed<PqBits>(out_codes, pq_dim));
in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous<PqBits>(out_codes, pq_dim));
}

/**
Expand All @@ -714,7 +714,7 @@ __launch_bounds__(BlockSize) RAFT_KERNEL unpack_compressed_list_data_kernel(
* @param[in] pq_bits codebook size (1 << pq_bits)
* @param[in] stream
*/
inline void unpack_compressed_list_data(
inline void unpack_contiguous_list_data(
uint8_t* codes,
device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
uint32_t n_rows,
Expand All @@ -730,11 +730,11 @@ inline void unpack_compressed_list_data(
dim3 threads(kBlockSize, 1, 1);
auto kernel = [pq_bits]() {
switch (pq_bits) {
case 4: return unpack_compressed_list_data_kernel<kBlockSize, 4>;
case 5: return unpack_compressed_list_data_kernel<kBlockSize, 5>;
case 6: return unpack_compressed_list_data_kernel<kBlockSize, 6>;
case 7: return unpack_compressed_list_data_kernel<kBlockSize, 7>;
case 8: return unpack_compressed_list_data_kernel<kBlockSize, 8>;
case 4: return unpack_contiguous_list_data_kernel<kBlockSize, 4>;
case 5: return unpack_contiguous_list_data_kernel<kBlockSize, 5>;
case 6: return unpack_contiguous_list_data_kernel<kBlockSize, 6>;
case 7: return unpack_contiguous_list_data_kernel<kBlockSize, 7>;
case 8: return unpack_contiguous_list_data_kernel<kBlockSize, 8>;
default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
}
}();
Expand All @@ -744,14 +744,14 @@ inline void unpack_compressed_list_data(

/** Unpack the list data; see the public interface for the api and usage. */
template <typename IdxT>
void unpack_compressed_list_data(raft::resources const& res,
void unpack_contiguous_list_data(raft::resources const& res,
const index<IdxT>& index,
uint8_t* out_codes,
uint32_t n_rows,
uint32_t label,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
unpack_compressed_list_data(out_codes,
unpack_contiguous_list_data(out_codes,
index.lists()[label]->data.view(),
n_rows,
index.pq_dim(),
Expand Down Expand Up @@ -1002,7 +1002,7 @@ void pack_list_data(raft::resources const& res,
* the codes are not expanded to one code-per-byte.
*/
template <uint32_t PqBits>
struct pack_compressed {
struct pack_contiguous {
const uint8_t* codes;
uint32_t code_size;

Expand All @@ -1011,7 +1011,7 @@ struct pack_compressed {
*
* @param[in] codes flat compressed PQ codes
*/
__host__ __device__ inline pack_compressed(const uint8_t* codes, uint32_t pq_dim)
__host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim)
: codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
{
}
Expand All @@ -1025,15 +1025,15 @@ struct pack_compressed {
};

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL pack_compressed_list_data_kernel(
__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel(
device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
const uint8_t* codes,
uint32_t n_rows,
uint32_t pq_dim,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
write_list<PqBits, 1>(
list_data, offset_or_indices, n_rows, pq_dim, pack_compressed<PqBits>(codes, pq_dim));
list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous<PqBits>(codes, pq_dim));
}

/**
Expand All @@ -1047,7 +1047,7 @@ __launch_bounds__(BlockSize) RAFT_KERNEL pack_compressed_list_data_kernel(
* @param[in] pq_bits codebook size (1 << pq_bits)
* @param[in] stream
*/
inline void pack_compressed_list_data(
inline void pack_contiguous_list_data(
device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
const uint8_t* codes,
uint32_t n_rows,
Expand All @@ -1063,11 +1063,11 @@ inline void pack_compressed_list_data(
dim3 threads(kBlockSize, 1, 1);
auto kernel = [pq_bits]() {
switch (pq_bits) {
case 4: return pack_compressed_list_data_kernel<kBlockSize, 4>;
case 5: return pack_compressed_list_data_kernel<kBlockSize, 5>;
case 6: return pack_compressed_list_data_kernel<kBlockSize, 6>;
case 7: return pack_compressed_list_data_kernel<kBlockSize, 7>;
case 8: return pack_compressed_list_data_kernel<kBlockSize, 8>;
case 4: return pack_contiguous_list_data_kernel<kBlockSize, 4>;
case 5: return pack_contiguous_list_data_kernel<kBlockSize, 5>;
case 6: return pack_contiguous_list_data_kernel<kBlockSize, 6>;
case 7: return pack_contiguous_list_data_kernel<kBlockSize, 7>;
case 8: return pack_contiguous_list_data_kernel<kBlockSize, 8>;
default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
}
}();
Expand All @@ -1076,14 +1076,14 @@ inline void pack_compressed_list_data(
}

template <typename IdxT>
void pack_compressed_list_data(raft::resources const& res,
void pack_contiguous_list_data(raft::resources const& res,
index<IdxT>* index,
const uint8_t* new_codes,
uint32_t n_rows,
uint32_t label,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
pack_compressed_list_data(index->lists()[label]->data.view(),
pack_contiguous_list_data(index->lists()[label]->data.view(),
new_codes,
n_rows,
index->pq_dim(),
Expand Down
35 changes: 18 additions & 17 deletions cpp/include/raft/neighbors/ivf_pq_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ inline void unpack(

/**
* @brief Unpack `n_rows` consecutive records of a single list (cluster) in the compressed index
* starting at given `offset`. The output codes are not expanded to one code per byte, which means
* the output has ceildiv(pq_dim * pq_bits, 8) bytes per pq encoded vector.
* starting at given `offset`. The output codes of a single vector are contiguous, not expanded to
* one code per byte, which means the output has ceildiv(pq_dim * pq_bits, 8) bytes per PQ encoded
* vector.
*
* Usage example:
* @code{.cpp}
Expand All @@ -90,7 +91,7 @@ inline void unpack(
* res, n_rows, raft::ceildiv(index.pq_dim() * index.pq_bits(), 8));
* uint32_t offset = 0;
* // unpack n_rows elements from the list
* ivf_pq::helpers::codepacker::unpack_compressed(
* ivf_pq::helpers::codepacker::unpack_contiguous(
* res, list_data, index.pq_bits(), offset, n_rows, index.pq_dim(), codes.data_handle());
* @endcode
*
Expand All @@ -108,7 +109,7 @@ inline void unpack(
* The length `n_rows` defines how many records to unpack,
* it must be smaller than the list size.
*/
inline void unpack_compressed(
inline void unpack_contiguous(
raft::resources const& res,
device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
uint32_t pq_bits,
Expand All @@ -117,7 +118,7 @@ inline void unpack_compressed(
uint32_t pq_dim,
uint8_t* codes)
{
ivf_pq::detail::unpack_compressed_list_data(
ivf_pq::detail::unpack_contiguous_list_data(
codes, list_data, n_rows, pq_dim, offset, pq_bits, resource::get_cuda_stream(res));
}

Expand Down Expand Up @@ -154,8 +155,8 @@ inline void pack(
}

/**
* Write flat PQ codes into an existing list by the given offset. The input codes are compressed
* (not expanded to one code per byte).
* Write flat PQ codes into an existing list by the given offset. The input codes of a single vector
* are contiguous (not expanded to one code per byte).
*
* NB: no memory allocation happens here; the list must fit the data (offset + n_rows records).
*
Expand All @@ -169,7 +170,7 @@ inline void pack(
* ... prepare compressed vectors to pack into the list in codes ...
* // write codes into the list starting from the 42nd position. If the current size of the list
* // is greater than 42, this will overwrite the codes starting at this offset.
* ivf_pq::helpers::codepacker::pack_compressed(
* ivf_pq::helpers::codepacker::pack_contiguous(
* res, codes.data_handle(), n_rows, index.pq_dim(), index.pq_bits(), 42, list_data);
* @endcode
*
Expand All @@ -181,7 +182,7 @@ inline void pack(
* @param[in] offset how many records to skip before writing the data into the list
* @param[in] list_data block to write into
*/
inline void pack_compressed(
inline void pack_contiguous(
raft::resources const& res,
const uint8_t* codes,
uint32_t n_rows,
Expand All @@ -190,7 +191,7 @@ inline void pack_compressed(
uint32_t offset,
device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data)
{
ivf_pq::detail::pack_compressed_list_data(
ivf_pq::detail::pack_contiguous_list_data(
list_data, codes, n_rows, pq_dim, offset, pq_bits, resource::get_cuda_stream(res));
}
} // namespace codepacker
Expand Down Expand Up @@ -230,7 +231,7 @@ void pack_list_data(raft::resources const& res,
}

/**
* Write flat compressed PQ codes into an existing list by the given offset. Use this when the input
* Write flat PQ codes into an existing list by the given offset. Use this when the input
* vectors are PQ encoded and not expanded to one code per byte.
*
* The list is identified by its label.
Expand All @@ -255,28 +256,28 @@ void pack_list_data(raft::resources const& res,
* // the first n_rows codes in the fourth IVF list are to be overwritten.
* uint32_t label = 3;
* // write codes into the list starting from the 0th position
* ivf_pq::helpers::pack_compressed_list_data(
* ivf_pq::helpers::pack_contiguous_list_data(
* res, &index, codes.data_handle(), n_rows, label, 0);
* @endcode
*
* @tparam IdxT
*
* @param[in] res raft resource
* @param[inout] index pointer to IVF-PQ index
* @param[in] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
* @param[in] codes flat contiguous PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
* @param[in] n_rows how many records to pack
* @param[in] label The id of the list (cluster) into which we write.
* @param[in] offset how many records to skip before writing the data into the list
*/
template <typename IdxT>
void pack_compressed_list_data(raft::resources const& res,
void pack_contiguous_list_data(raft::resources const& res,
index<IdxT>* index,
uint8_t* codes,
uint32_t n_rows,
uint32_t label,
uint32_t offset)
{
ivf_pq::detail::pack_compressed_list_data(res, index, codes, n_rows, label, offset);
ivf_pq::detail::pack_contiguous_list_data(res, index, codes, n_rows, label, offset);
}

/**
Expand Down Expand Up @@ -398,14 +399,14 @@ void unpack_list_data(raft::resources const& res,
* How many records in the list to skip.
*/
template <typename IdxT>
void unpack_compressed_list_data(raft::resources const& res,
void unpack_contiguous_list_data(raft::resources const& res,
const index<IdxT>& index,
uint8_t* out_codes,
uint32_t n_rows,
uint32_t label,
uint32_t offset)
{
return ivf_pq::detail::unpack_compressed_list_data<IdxT>(
return ivf_pq::detail::unpack_contiguous_list_data<IdxT>(
res, index, out_codes, n_rows, label, offset);
}

Expand Down
10 changes: 5 additions & 5 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,11 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {

auto codes_compressed = make_device_matrix<uint8_t>(handle_, n_rows, code_size);

ivf_pq::helpers::unpack_compressed_list_data(
ivf_pq::helpers::unpack_contiguous_list_data(
handle_, *index, codes_compressed.data_handle(), n_rows, label, 0);
ivf_pq::helpers::erase_list(handle_, index, label);
ivf_pq::detail::extend_list_prepare(handle_, index, make_const_mdspan(indices.view()), label);
ivf_pq::helpers::pack_compressed_list_data<IdxT>(
ivf_pq::helpers::pack_contiguous_list_data<IdxT>(
handle_, index, codes_compressed.data_handle(), n_rows, label, 0);
ivf_pq::helpers::recompute_internal_state(handle_, index);

Expand All @@ -424,7 +424,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
size_t offset = row_offset * code_size;
auto codes_to_pack = make_device_matrix_view<uint8_t, uint32_t>(
codes_compressed.data_handle() + offset, n_vec, index->pq_dim());
ivf_pq::helpers::pack_compressed_list_data(
ivf_pq::helpers::pack_contiguous_list_data(
handle_, index, codes_to_pack.data_handle(), n_vec, label, row_offset);
ASSERT_TRUE(devArrMatch(old_list->data.data_handle(),
new_list->data.data_handle(),
Expand All @@ -436,7 +436,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
uint32_t n_take = 4;
ASSERT_TRUE(row_offset + n_take < n_rows);
auto codes2 = raft::make_device_matrix<uint8_t>(handle_, n_take, code_size);
ivf_pq::helpers::codepacker::unpack_compressed(handle_,
ivf_pq::helpers::codepacker::unpack_contiguous(handle_,
list_data,
index->pq_bits(),
row_offset,
Expand All @@ -445,7 +445,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
codes2.data_handle());

// Write it back
ivf_pq::helpers::codepacker::pack_compressed(handle_,
ivf_pq::helpers::codepacker::pack_contiguous(handle_,
codes2.data_handle(),
n_vec,
index->pq_dim(),
Expand Down

0 comments on commit 1efd28f

Please sign in to comment.