Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Helpers and CodePacker for IVF-PQ #1826

Merged
merged 89 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 87 commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
cc9cbd3
Unpack list data kernel
tarang-jain Jul 1, 2023
28484ef
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 1, 2023
e39ee56
update packing and unpacking functions
tarang-jain Jul 5, 2023
68bf927
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 5, 2023
78d6380
Update codepacker
tarang-jain Jul 14, 2023
49a8834
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 14, 2023
897338e
refactor codepacker (does not build)
tarang-jain Jul 17, 2023
c1d80f5
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 17, 2023
2a2ee51
Undo deletions
tarang-jain Jul 17, 2023
834dd2c
undo yaml changes
tarang-jain Jul 17, 2023
6013429
style
tarang-jain Jul 17, 2023
ab6345a
Update tests, correct make_list_extents
tarang-jain Jul 18, 2023
ed80d1a
More changes
tarang-jain Jul 19, 2023
cdff9e1
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 19, 2023
7412272
debugging
tarang-jain Jul 20, 2023
700ea82
Working build
tarang-jain Jul 21, 2023
27451c6
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 21, 2023
9d742ef
rename codepacking api
tarang-jain Jul 21, 2023
d1ef8a1
Updated gtest
tarang-jain Jul 27, 2023
e187147
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 27, 2023
4f233a6
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 27, 2023
4ee99e3
updates
tarang-jain Jul 27, 2023
22f4f80
update testing
tarang-jain Jul 28, 2023
9f4e22c
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 28, 2023
c95d1e0
updates
tarang-jain Jul 28, 2023
da78c66
Update testing, pow2
tarang-jain Jul 31, 2023
5cc6dc9
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 31, 2023
15db0c6
remove unneccessary changes
tarang-jain Jul 31, 2023
154dc6d
Delete log.txt
tarang-jain Jul 31, 2023
47d6421
updates
tarang-jain Jul 31, 2023
0f1d106
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Jul 31, 2023
e2e1308
ore cleanup
tarang-jain Jul 31, 2023
3f470c8
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 31, 2023
41a49b2
style
tarang-jain Jul 31, 2023
1d2a5b0
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Aug 9, 2023
8ce8115
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Aug 23, 2023
171215b
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Sep 11, 2023
135d973
Initial commit
tarang-jain Sep 13, 2023
d7a9b4e
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Sep 13, 2023
5738cca
im
tarang-jain Sep 21, 2023
62b39cf
host pq codepacker
tarang-jain Sep 22, 2023
8702b92
refactored codepacker
tarang-jain Sep 22, 2023
5b2a7e0
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Sep 22, 2023
4139c7e
updated CP
tarang-jain Sep 22, 2023
e846352
undo some diffs
tarang-jain Sep 22, 2023
2ab3da2
undo some diffs
tarang-jain Sep 22, 2023
eb493a7
undo some diffs
tarang-jain Sep 22, 2023
28b7125
update docs
tarang-jain Sep 22, 2023
4b3b3bb
Merge branch 'branch-23.10' into faiss-ivf
tarang-jain Sep 25, 2023
3da5265
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Oct 5, 2023
d546d89
initial efforts for compress/decompress codepacker
tarang-jain Oct 6, 2023
b6e3de9
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Oct 6, 2023
4b94c45
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Oct 11, 2023
ec11fd8
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Oct 12, 2023
8a41330
Update codepacker and helpers
tarang-jain Oct 17, 2023
86f1aa4
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Oct 17, 2023
0baee4a
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Oct 17, 2023
9d66a8f
more helpers and debugging
tarang-jain Oct 26, 2023
3be7afd
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Oct 26, 2023
fd01442
Update tests
tarang-jain Oct 26, 2023
1b4fd0e
action struct correction
tarang-jain Nov 2, 2023
7d760e9
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 2, 2023
aaff0bf
testing
tarang-jain Nov 3, 2023
c4bc220
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 3, 2023
6a5443a
remove unneeded funcs
tarang-jain Nov 3, 2023
bca8f40
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Nov 7, 2023
8edc7a1
Add helper for extracting cluster centers
tarang-jain Nov 7, 2023
93eebab
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 7, 2023
140701e
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Nov 7, 2023
0b88ca4
Update docs
tarang-jain Nov 9, 2023
d67fe8d
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 9, 2023
a68d7a7
Add test
tarang-jain Nov 9, 2023
41ac27f
correction
tarang-jain Nov 9, 2023
5073ea3
Update docs
tarang-jain Nov 16, 2023
889bbdd
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 16, 2023
3dbf3a7
more updates to docs
tarang-jain Nov 16, 2023
30bdee5
style
tarang-jain Nov 16, 2023
55fa0ef
more docs
tarang-jain Nov 16, 2023
8eb07f8
undo small docstring change
tarang-jain Nov 16, 2023
f8956d5
style
tarang-jain Nov 16, 2023
228e997
more doc updates
tarang-jain Nov 16, 2023
bdd75cf
small doc fix
tarang-jain Nov 16, 2023
6adcb98
resource docs
tarang-jain Nov 16, 2023
1893963
Update docs for ivf_flat::helpers::reset_index
tarang-jain Nov 16, 2023
91e17c2
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 16, 2023
a2d4575
update reset_index
tarang-jain Nov 16, 2023
1efd28f
change helpers name to contiguous
tarang-jain Nov 17, 2023
9841e6c
move get_list_size to index struct
tarang-jain Nov 17, 2023
3f8baaa
change test name
tarang-jain Nov 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions cpp/include/raft/neighbors/detail/div_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef _RAFT_HAS_CUDA
#include <raft/util/pow2_utils.cuh>
#else
#include <raft/util/integer_utils.hpp>
#endif

/**
* @brief A thin wrapper around raft::Pow2 that uses the Pow2 utilities when CUDA is available
* (i.e. `_RAFT_HAS_CUDA` is defined) and falls back to regular integer arithmetic otherwise. This
* provides a common interface for division arithmetic that can also be used from non-CUDA headers.
*
* @tparam Value_ a compile-time value representable as a power-of-two.
*/
namespace raft::neighbors::detail {
template <auto Value_>
struct div_utils {
  /** The type of the compile-time divisor. */
  using Type = decltype(Value_);
  /** The compile-time divisor. */
  static constexpr Type Value = Value_;

  /**
   * Forwards to `Pow2<Value_>::roundDown(x)` when `_RAFT_HAS_CUDA` is defined and to
   * `raft::round_down_safe(x, Value_)` otherwise.
   *
   * NOTE(review): the fallback hands `x` and `Value_` to `raft::round_down_safe` as-is; if that
   * helper expects both arguments to share one type, `T` has to match `Type` - confirm.
   */
  template <typename T>
  static constexpr _RAFT_HOST_DEVICE inline auto roundDown(T x)
  {
#ifdef _RAFT_HAS_CUDA
    return Pow2<Value_>::roundDown(x);
#else
    return raft::round_down_safe(x, Value_);
#endif
  }

  /**
   * Forwards to `Pow2<Value_>::mod(x)` when `_RAFT_HAS_CUDA` is defined and evaluates
   * `x % Value_` otherwise.
   */
  template <typename T>
  static constexpr _RAFT_HOST_DEVICE inline auto mod(T x)
  {
#ifdef _RAFT_HAS_CUDA
    return Pow2<Value_>::mod(x);
#else
    return x % Value_;
#endif
  }

  /**
   * Forwards to `Pow2<Value_>::div(x)` when `_RAFT_HAS_CUDA` is defined and evaluates
   * `x / Value_` otherwise.
   */
  template <typename T>
  static constexpr _RAFT_HOST_DEVICE inline auto div(T x)
  {
#ifdef _RAFT_HAS_CUDA
    return Pow2<Value_>::div(x);
#else
    return x / Value_;
#endif
  }
};
} // namespace raft::neighbors::detail
290 changes: 243 additions & 47 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,59 @@ auto calculate_offsets_and_indices(IdxT n_rows,
return max_cluster_size;
}

/**
 * Fill the center-related fields of the index from a flat array of cluster centers.
 *
 * All work is enqueued on the CUDA stream obtained from `handle`:
 *   1. the `dim()` components of every center are copied into the rows of `index->centers()`
 *      (row pitch: `dim_ext()` floats);
 *   2. the per-center output of `raft::linalg::rowNorm(..., L2Norm, ...)` is stored at offset
 *      `dim()` of each of those rows;
 *   3. `index->centers_rot()` receives the result of a GEMM between `index->rotation_matrix()`
 *      and `cluster_centers`.
 *
 * NB: `index->rotation_matrix()` is only read here, so it has to be populated beforehand.
 *
 * @param[in] handle raft resources (provides the stream and the workspace memory resource)
 * @param[inout] index the index whose `centers()` and `centers_rot()` are overwritten
 * @param[in] cluster_centers `n_lists()` rows of `dim()` floats, tightly packed
 *   (presumably device-accessible, as it is passed to `rowNorm` and `gemm` - TODO confirm)
 */
template <typename IdxT>
void set_centers(raft::resources const& handle, index<IdxT>* index, const float* cluster_centers)
{
  auto stream         = resource::get_cuda_stream(handle);
  auto* device_memory = resource::get_workspace_resource(handle);

  const auto n_centers  = index->n_lists();
  const auto center_dim = index->dim();
  const auto dst_pitch  = sizeof(float) * index->dim_ext();

  // Step 1: copy the center coordinates into the (wider) rows of the extended centers matrix.
  RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(),
                                  dst_pitch,
                                  cluster_centers,
                                  sizeof(float) * center_dim,
                                  sizeof(float) * center_dim,
                                  n_centers,
                                  cudaMemcpyDefault,
                                  stream));

  // Step 2: compute one norm per center and scatter them into the column right after the
  // coordinates (one float per row).
  rmm::device_uvector<float> norms(n_centers, stream, device_memory);
  raft::linalg::rowNorm(
    norms.data(), cluster_centers, center_dim, n_centers, raft::linalg::L2Norm, true, stream);
  RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + center_dim,
                                  dst_pitch,
                                  norms.data(),
                                  sizeof(float),
                                  sizeof(float),
                                  n_centers,
                                  cudaMemcpyDefault,
                                  stream));

  // Step 3: rotate the cluster centers.
  float alpha = 1.0f;
  float beta  = 0.0f;
  linalg::gemm(handle,
               true,
               false,
               index->rot_dim(),
               n_centers,
               center_dim,
               &alpha,
               index->rotation_matrix().data_handle(),
               center_dim,
               cluster_centers,
               center_dim,
               &beta,
               index->centers_rot().data_handle(),
               index->rot_dim(),
               stream);
}

template <typename IdxT>
void transpose_pq_centers(const resources& handle,
index<IdxT>& index,
Expand Down Expand Up @@ -613,6 +666,100 @@ void unpack_list_data(raft::resources const& res,
resource::get_cuda_stream(res));
}

/**
 * A consumer for the `run_on_vector` that just flattens PQ codes
 * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte.
 *
 * @tparam PqBits the number of bits occupied by one PQ code
 */
template <uint32_t PqBits>
struct unpack_contiguous {
  /** Destination buffer for the flat compressed PQ codes. */
  uint8_t* codes;
  /** Size of one packed vector in bytes: ceildiv(pq_dim * PqBits, 8). */
  uint32_t code_size;

  /**
   * Create a callable to be passed to `run_on_vector`.
   *
   * @param[out] codes flat compressed PQ codes
   * @param[in] pq_dim the number of PQ codes per vector
   */
  __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim)
    : codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
  {
  }

  /** Write j-th component (code) of the i-th vector into the output array. */
  __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j)
  {
    // Widen before multiplying: a 32-bit `i * code_size` wraps around as soon as the byte offset
    // into the output exceeds 4 GiB, which would silently write to the wrong location.
    bitfield_view_t<PqBits> code_view{codes + static_cast<size_t>(i) * code_size};
    code_view[j] = code;
  }
};

/**
 * Kernel wrapper around `run_on_list` that feeds every visited code of the list into an
 * `unpack_contiguous` consumer, i.e. writes the codes into `out_codes` in the flat packed layout.
 *
 * The kernel is compiled with `__launch_bounds__(BlockSize)`.
 *
 * @tparam BlockSize the upper bound on the number of threads per block
 * @tparam PqBits the number of bits occupied by one PQ code
 *
 * @param[out] out_codes flat compressed PQ codes
 * @param[in] in_list_data the packed ivf::list data
 * @param[in] n_rows forwarded to `run_on_list` as the number of records to process
 * @param[in] pq_dim the number of PQ codes per vector
 * @param[in] offset_or_indices how many records in the list to skip or the exact indices
 */
template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel(
  uint8_t* out_codes,
  device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> in_list_data,
  uint32_t n_rows,
  uint32_t pq_dim,
  std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  unpack_contiguous<PqBits> flat_writer(out_codes, pq_dim);
  run_on_list<PqBits>(in_list_data, offset_or_indices, n_rows, pq_dim, flat_writer);
}

/**
* Unpack flat PQ codes from an existing list by the given offset.
*
* @param[out] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
* @param[in] list_data the packed ivf::list data.
* @param[in] offset_or_indices how many records in the list to skip or the exact indices.
* @param[in] pq_bits codebook size (1 << pq_bits)
* @param[in] stream
*/
inline void unpack_contiguous_list_data(
uint8_t* codes,
device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
uint32_t n_rows,
uint32_t pq_dim,
std::variant<uint32_t, const uint32_t*> offset_or_indices,
uint32_t pq_bits,
rmm::cuda_stream_view stream)
{
if (n_rows == 0) { return; }

constexpr uint32_t kBlockSize = 256;
dim3 blocks(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);
dim3 threads(kBlockSize, 1, 1);
auto kernel = [pq_bits]() {
switch (pq_bits) {
case 4: return unpack_contiguous_list_data_kernel<kBlockSize, 4>;
case 5: return unpack_contiguous_list_data_kernel<kBlockSize, 5>;
case 6: return unpack_contiguous_list_data_kernel<kBlockSize, 6>;
case 7: return unpack_contiguous_list_data_kernel<kBlockSize, 7>;
case 8: return unpack_contiguous_list_data_kernel<kBlockSize, 8>;
default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
}
}();
kernel<<<blocks, threads, 0, stream>>>(codes, list_data, n_rows, pq_dim, offset_or_indices);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/**
 * Unpack the codes of one list (cluster) of the index into a flat packed buffer.
 *
 * Looks up the list `label` in the index and forwards its data view together with the index'
 * `pq_dim()` / `pq_bits()` and the stream from `res` to the low-level overload; see the public
 * interface for the api and usage.
 *
 * @param[in] res raft resources (provides the CUDA stream)
 * @param[in] index the IVF-PQ index to read from
 * @param[out] out_codes flat compressed PQ codes
 * @param[in] n_rows the number of records to unpack
 * @param[in] label the id of the list (cluster) to read
 * @param[in] offset_or_indices how many records in the list to skip or the exact indices
 */
template <typename IdxT>
void unpack_contiguous_list_data(raft::resources const& res,
                                 const index<IdxT>& index,
                                 uint8_t* out_codes,
                                 uint32_t n_rows,
                                 uint32_t label,
                                 std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  auto stream    = resource::get_cuda_stream(res);
  auto list_view = index.lists()[label]->data.view();
  unpack_contiguous_list_data(
    out_codes, list_view, n_rows, index.pq_dim(), offset_or_indices, index.pq_bits(), stream);
}

/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data.
*/
struct reconstruct_vectors {
Expand Down Expand Up @@ -850,6 +997,101 @@ void pack_list_data(raft::resources const& res,
resource::get_cuda_stream(res));
}

/**
 * A producer for the `write_vector` that reads tightly packed flat codes. That is,
 * the codes are not expanded to one code-per-byte.
 *
 * @tparam PqBits the number of bits occupied by one PQ code
 */
template <uint32_t PqBits>
struct pack_contiguous {
  /** Source buffer with the flat compressed PQ codes. */
  const uint8_t* codes;
  /** Size of one packed vector in bytes: ceildiv(pq_dim * PqBits, 8). */
  uint32_t code_size;

  /**
   * Create a callable to be passed to `write_vector`.
   *
   * @param[in] codes flat compressed PQ codes
   * @param[in] pq_dim the number of PQ codes per vector
   */
  __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim)
    : codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
  {
  }

  /** Read j-th component (code) of the i-th vector from the source. */
  __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t
  {
    // Widen before multiplying: a 32-bit `i * code_size` wraps around as soon as the byte offset
    // into the source exceeds 4 GiB, which would silently read from the wrong location.
    // The const_cast is only there because the view is built over a mutable pointer; the data
    // is only read through it below.
    bitfield_view_t<PqBits> code_view{
      const_cast<uint8_t*>(codes + static_cast<size_t>(i) * code_size)};
    return uint8_t(code_view[j]);
  }
};

/**
 * Kernel wrapper around `write_list<PqBits, 1>` that sources the codes from a `pack_contiguous`
 * producer, i.e. reads them from `codes` in the flat packed layout.
 *
 * The kernel is compiled with `__launch_bounds__(BlockSize)`.
 *
 * @tparam BlockSize the upper bound on the number of threads per block
 * @tparam PqBits the number of bits occupied by one PQ code
 *
 * @param[out] list_data the packed ivf::list data
 * @param[in] codes flat compressed PQ codes
 * @param[in] n_rows forwarded to `write_list` as the number of records to process
 * @param[in] pq_dim the number of PQ codes per vector
 * @param[in] offset_or_indices how many records in the list to skip or the exact indices
 */
template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel(
  device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
  const uint8_t* codes,
  uint32_t n_rows,
  uint32_t pq_dim,
  std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  pack_contiguous<PqBits> flat_reader(codes, pq_dim);
  write_list<PqBits, 1>(list_data, offset_or_indices, n_rows, pq_dim, flat_reader);
}

/**
 * Write flat, tightly packed PQ codes into an existing list.
 *
 * Picks the `pack_contiguous_list_data_kernel` instantiation matching `pq_bits` and launches it
 * on `stream` using 256-thread blocks and ceil(n_rows / 256) blocks. Nothing is launched when
 * `n_rows` is zero. The launch is asynchronous; only the launch status is checked here.
 *
 * NB: no memory allocation happens here; the list must fit the data (offset + n_rows).
 *
 * @param[out] list_data the packed ivf::list data.
 * @param[in] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
 * @param[in] n_rows the number of records to write
 * @param[in] pq_dim the number of PQ codes per vector
 * @param[in] offset_or_indices how many records in the list to skip or the exact indices.
 * @param[in] pq_bits the number of bits per PQ code; must be within [4, 8], otherwise `RAFT_FAIL`
 *   is triggered
 * @param[in] stream the CUDA stream to launch the kernel on
 */
inline void pack_contiguous_list_data(
  device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
  const uint8_t* codes,
  uint32_t n_rows,
  uint32_t pq_dim,
  std::variant<uint32_t, const uint32_t*> offset_or_indices,
  uint32_t pq_bits,
  rmm::cuda_stream_view stream)
{
  // Nothing to do for an empty request (also avoids a zero-sized grid).
  if (n_rows == 0) { return; }

  constexpr uint32_t kBlockSize = 256;

  // Translate the runtime `pq_bits` into the corresponding compile-time kernel instantiation.
  auto select_kernel = [](uint32_t bits) {
    switch (bits) {
      case 4: return pack_contiguous_list_data_kernel<kBlockSize, 4>;
      case 5: return pack_contiguous_list_data_kernel<kBlockSize, 5>;
      case 6: return pack_contiguous_list_data_kernel<kBlockSize, 6>;
      case 7: return pack_contiguous_list_data_kernel<kBlockSize, 7>;
      case 8: return pack_contiguous_list_data_kernel<kBlockSize, 8>;
      default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", bits);
    }
  };
  auto pack_kernel = select_kernel(pq_bits);

  dim3 grid_dim(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);
  dim3 block_dim(kBlockSize, 1, 1);
  pack_kernel<<<grid_dim, block_dim, 0, stream>>>(
    list_data, codes, n_rows, pq_dim, offset_or_indices);
  RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/**
 * Write flat packed codes into one list (cluster) of the index.
 *
 * Looks up the list `label` in the index and forwards its data view together with the index'
 * `pq_dim()` / `pq_bits()` and the stream from `res` to the low-level overload.
 *
 * @param[in] res raft resources (provides the CUDA stream)
 * @param[inout] index the IVF-PQ index whose list is written
 * @param[in] new_codes flat compressed PQ codes
 * @param[in] n_rows the number of records to write
 * @param[in] label the id of the list (cluster) to write into
 * @param[in] offset_or_indices how many records in the list to skip or the exact indices
 */
template <typename IdxT>
void pack_contiguous_list_data(raft::resources const& res,
                               index<IdxT>* index,
                               const uint8_t* new_codes,
                               uint32_t n_rows,
                               uint32_t label,
                               std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  auto stream    = resource::get_cuda_stream(res);
  auto list_view = index->lists()[label]->data.view();
  pack_contiguous_list_data(
    list_view, new_codes, n_rows, index->pq_dim(), offset_or_indices, index->pq_bits(), stream);
}

/**
*
* A producer for the `write_list` and `write_vector` that encodes level-1 input vector residuals
Expand Down Expand Up @@ -1634,60 +1876,14 @@ auto build(raft::resources const& handle,
labels_view,
utils::mapping<float>());

{
// combine cluster_centers and their norms
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle(),
sizeof(float) * index.dim_ext(),
cluster_centers,
sizeof(float) * index.dim(),
sizeof(float) * index.dim(),
index.n_lists(),
cudaMemcpyDefault,
stream));

rmm::device_uvector<float> center_norms(index.n_lists(), stream, device_memory);
raft::linalg::rowNorm(center_norms.data(),
cluster_centers,
index.dim(),
index.n_lists(),
raft::linalg::L2Norm,
true,
stream);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(),
sizeof(float) * index.dim_ext(),
center_norms.data(),
sizeof(float),
sizeof(float),
index.n_lists(),
cudaMemcpyDefault,
stream));
}

// Make rotation matrix
make_rotation_matrix(handle,
params.force_random_rotation,
index.rot_dim(),
index.dim(),
index.rotation_matrix().data_handle());

// Rotate cluster_centers
float alpha = 1.0;
float beta = 0.0;
linalg::gemm(handle,
true,
false,
index.rot_dim(),
index.n_lists(),
index.dim(),
&alpha,
index.rotation_matrix().data_handle(),
index.dim(),
cluster_centers,
index.dim(),
&beta,
index.centers_rot().data_handle(),
index.rot_dim(),
stream);
set_centers(handle, &index, cluster_centers);

// Train PQ codebooks
switch (index.codebook_kind()) {
Expand Down
Loading
Loading