Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward-merge branch-23.12 to branch-24.02 #2014

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions cpp/include/raft/neighbors/detail/div_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef _RAFT_HAS_CUDA
#include <raft/util/pow2_utils.cuh>
#else
#include <raft/util/integer_utils.hpp>
#endif

/**
 * @brief A thin dispatch layer over power-of-two division helpers.
 *
 * When CUDA support is available, the operations forward to raft::Pow2 (which
 * exploits the power-of-two structure); otherwise they fall back to plain
 * integer arithmetic. This gives CUDA and non-CUDA headers one common
 * interface for division arithmetic.
 *
 * @tparam Value_ a compile-time value representable as a power-of-two.
 */
namespace raft::neighbors::detail {
template <auto Value_>
struct div_utils {
  using Type                  = decltype(Value_);
  static constexpr Type Value = Value_;

  /** Round `x` down to the nearest multiple of `Value_`. */
  template <typename T>
  static constexpr _RAFT_HOST_DEVICE inline auto roundDown(T x)
  {
#if !defined(_RAFT_HAS_CUDA)
    return raft::round_down_safe(x, Value_);
#else
    return Pow2<Value_>::roundDown(x);
#endif
  }

  /** Remainder of `x` divided by `Value_`. */
  template <typename T>
  static constexpr _RAFT_HOST_DEVICE inline auto mod(T x)
  {
#if !defined(_RAFT_HAS_CUDA)
    return x % Value_;
#else
    return Pow2<Value_>::mod(x);
#endif
  }

  /** Quotient of `x` divided by `Value_`. */
  template <typename T>
  static constexpr _RAFT_HOST_DEVICE inline auto div(T x)
  {
#if !defined(_RAFT_HAS_CUDA)
    return x / Value_;
#else
    return Pow2<Value_>::div(x);
#endif
  }
};
}  // namespace raft::neighbors::detail
290 changes: 243 additions & 47 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,59 @@ auto calculate_offsets_and_indices(IdxT n_rows,
return max_cluster_size;
}

/**
 * Fill the index with cluster centers: copy them into the extended-layout
 * `centers()` matrix, append a per-center norm column, and store the rotated
 * centers in `centers_rot()`.
 */
template <typename IdxT>
void set_centers(raft::resources const& handle, index<IdxT>* index, const float* cluster_centers)
{
  auto stream = resource::get_cuda_stream(handle);
  auto* mr    = resource::get_workspace_resource(handle);

  // Strided copy: each destination row spans dim_ext floats, of which the
  // first dim are filled with the center coordinates.
  RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(),
                                  sizeof(float) * index->dim_ext(),
                                  cluster_centers,
                                  sizeof(float) * index->dim(),
                                  sizeof(float) * index->dim(),
                                  index->n_lists(),
                                  cudaMemcpyDefault,
                                  stream));

  // One norm value per center, written right after the dim coordinates of
  // every extended row (again via a strided 2D copy).
  rmm::device_uvector<float> norms(index->n_lists(), stream, mr);
  raft::linalg::rowNorm(norms.data(),
                        cluster_centers,
                        index->dim(),
                        index->n_lists(),
                        raft::linalg::L2Norm,
                        true,
                        stream);
  RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(),
                                  sizeof(float) * index->dim_ext(),
                                  norms.data(),
                                  sizeof(float),
                                  sizeof(float),
                                  index->n_lists(),
                                  cudaMemcpyDefault,
                                  stream));

  // Rotate the centers into the rot_dim space.
  // NOTE(review): with transa = true this appears to compute
  // rotation_matrix^T * cluster_centers — confirm against raft::linalg::gemm docs.
  const float kAlpha = 1.0f;
  const float kBeta  = 0.0f;
  linalg::gemm(handle,
               true,
               false,
               index->rot_dim(),
               index->n_lists(),
               index->dim(),
               &kAlpha,
               index->rotation_matrix().data_handle(),
               index->dim(),
               cluster_centers,
               index->dim(),
               &kBeta,
               index->centers_rot().data_handle(),
               index->rot_dim(),
               stream);
}

template <typename IdxT>
void transpose_pq_centers(const resources& handle,
index<IdxT>& index,
Expand Down Expand Up @@ -613,6 +666,100 @@ void unpack_list_data(raft::resources const& res,
resource::get_cuda_stream(res));
}

/**
 * A consumer for the `run_on_vector` that writes PQ codes into a tightly
 * bit-packed matrix — the codes are NOT expanded to one code per byte.
 */
template <uint32_t PqBits>
struct unpack_contiguous {
  uint8_t* codes;      // destination buffer of flat, bit-packed PQ codes
  uint32_t code_size;  // bytes per packed record: ceil(pq_dim * PqBits / 8)

  /**
   * Create a callable to be passed to `run_on_vector`.
   *
   * @param[in] codes flat compressed PQ codes
   * @param[in] pq_dim number of PQ codes per vector
   */
  __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim)
    : codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
  {
  }

  /** Write j-th component (code) of the i-th vector into the output array. */
  __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j)
  {
    // Locate the packed row for record i, then set its j-th PqBits-wide slot.
    uint8_t* row = codes + i * code_size;
    bitfield_view_t<PqBits> view{row};
    view[j] = code;
  }
};

/** Kernel: flatten list records into tightly bit-packed codes via `run_on_list`. */
template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel(
  uint8_t* out_codes,
  device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> in_list_data,
  uint32_t n_rows,
  uint32_t pq_dim,
  std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  // The consumer receives each code of each selected record and packs it
  // contiguously into out_codes.
  auto consumer = unpack_contiguous<PqBits>(out_codes, pq_dim);
  run_on_list<PqBits>(in_list_data, offset_or_indices, n_rows, pq_dim, consumer);
}

/**
 * Unpack flat PQ codes from an existing list by the given offset.
 *
 * @param[out] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
 * @param[in] list_data the packed ivf::list data.
 * @param[in] n_rows how many records to unpack
 * @param[in] pq_dim the number of PQ codes per record
 * @param[in] offset_or_indices how many records in the list to skip or the exact indices.
 * @param[in] pq_bits bit length of a single PQ code (the codebook size is 1 << pq_bits)
 * @param[in] stream CUDA stream the kernel is launched on
 */
inline void unpack_contiguous_list_data(
  uint8_t* codes,
  device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
  uint32_t n_rows,
  uint32_t pq_dim,
  std::variant<uint32_t, const uint32_t*> offset_or_indices,
  uint32_t pq_bits,
  rmm::cuda_stream_view stream)
{
  // An empty selection needs no work (and would otherwise yield a zero-sized grid).
  if (n_rows == 0) { return; }

  constexpr uint32_t kBlockSize = 256;
  dim3 threads(kBlockSize, 1, 1);
  dim3 blocks(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);
  // Map the run-time pq_bits onto the matching compile-time kernel instantiation.
  decltype(&unpack_contiguous_list_data_kernel<kBlockSize, 8>) kernel = nullptr;
  switch (pq_bits) {
    case 4: kernel = unpack_contiguous_list_data_kernel<kBlockSize, 4>; break;
    case 5: kernel = unpack_contiguous_list_data_kernel<kBlockSize, 5>; break;
    case 6: kernel = unpack_contiguous_list_data_kernel<kBlockSize, 6>; break;
    case 7: kernel = unpack_contiguous_list_data_kernel<kBlockSize, 7>; break;
    case 8: kernel = unpack_contiguous_list_data_kernel<kBlockSize, 8>; break;
    default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
  }
  kernel<<<blocks, threads, 0, stream>>>(codes, list_data, n_rows, pq_dim, offset_or_indices);
  RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/** Unpack the list data; see the public interface for the api and usage. */
template <typename IdxT>
void unpack_contiguous_list_data(raft::resources const& res,
                                 const index<IdxT>& index,
                                 uint8_t* out_codes,
                                 uint32_t n_rows,
                                 uint32_t label,
                                 std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  // Pull the list buffer and codec parameters out of the index, then delegate
  // to the low-level overload.
  auto stream    = resource::get_cuda_stream(res);
  auto list_view = index.lists()[label]->data.view();
  unpack_contiguous_list_data(
    out_codes, list_view, n_rows, index.pq_dim(), offset_or_indices, index.pq_bits(), stream);
}

/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data.
*/
struct reconstruct_vectors {
Expand Down Expand Up @@ -850,6 +997,101 @@ void pack_list_data(raft::resources const& res,
resource::get_cuda_stream(res));
}

/**
 * A producer for the `write_vector` that reads tightly bit-packed flat codes —
 * the input is NOT expanded to one code per byte.
 */
template <uint32_t PqBits>
struct pack_contiguous {
  const uint8_t* codes;  // source buffer of flat, bit-packed PQ codes
  uint32_t code_size;    // bytes per packed record: ceil(pq_dim * PqBits / 8)

  /**
   * Create a callable to be passed to `write_vector`.
   *
   * @param[in] codes flat compressed PQ codes
   * @param[in] pq_dim number of PQ codes per vector
   */
  __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim)
    : codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
  {
  }

  /** Read j-th component (code) of the i-th vector from the source. */
  __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t
  {
    // bitfield_view_t takes a mutable pointer; the view is only read here.
    uint8_t* row = const_cast<uint8_t*>(codes) + i * code_size;
    bitfield_view_t<PqBits> view{row};
    return uint8_t(view[j]);
  }
};

/** Kernel: scatter tightly bit-packed codes into the list layout via `write_list`. */
template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel(
  device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
  const uint8_t* codes,
  uint32_t n_rows,
  uint32_t pq_dim,
  std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  // The producer yields each code of each selected record from the packed input.
  auto producer = pack_contiguous<PqBits>(codes, pq_dim);
  write_list<PqBits, 1>(list_data, offset_or_indices, n_rows, pq_dim, producer);
}

/**
 * Write flat PQ codes into an existing list by the given offset.
 *
 * NB: no memory allocation happens here; the list must fit the data (offset + n_rows).
 *
 * @param[out] list_data the packed ivf::list data.
 * @param[in] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
 * @param[in] n_rows how many records to pack
 * @param[in] pq_dim the number of PQ codes per record
 * @param[in] offset_or_indices how many records in the list to skip or the exact indices.
 * @param[in] pq_bits bit length of a single PQ code (the codebook size is 1 << pq_bits)
 * @param[in] stream CUDA stream the kernel is launched on
 */
inline void pack_contiguous_list_data(
  device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
  const uint8_t* codes,
  uint32_t n_rows,
  uint32_t pq_dim,
  std::variant<uint32_t, const uint32_t*> offset_or_indices,
  uint32_t pq_bits,
  rmm::cuda_stream_view stream)
{
  // An empty selection needs no work (and would otherwise yield a zero-sized grid).
  if (n_rows == 0) { return; }

  constexpr uint32_t kBlockSize = 256;
  dim3 threads(kBlockSize, 1, 1);
  dim3 blocks(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);
  // Map the run-time pq_bits onto the matching compile-time kernel instantiation.
  decltype(&pack_contiguous_list_data_kernel<kBlockSize, 8>) kernel = nullptr;
  switch (pq_bits) {
    case 4: kernel = pack_contiguous_list_data_kernel<kBlockSize, 4>; break;
    case 5: kernel = pack_contiguous_list_data_kernel<kBlockSize, 5>; break;
    case 6: kernel = pack_contiguous_list_data_kernel<kBlockSize, 6>; break;
    case 7: kernel = pack_contiguous_list_data_kernel<kBlockSize, 7>; break;
    case 8: kernel = pack_contiguous_list_data_kernel<kBlockSize, 8>; break;
    default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
  }
  kernel<<<blocks, threads, 0, stream>>>(list_data, codes, n_rows, pq_dim, offset_or_indices);
  RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/** Pack flat codes into the list of the given label; delegates to the low-level overload. */
template <typename IdxT>
void pack_contiguous_list_data(raft::resources const& res,
                               index<IdxT>* index,
                               const uint8_t* new_codes,
                               uint32_t n_rows,
                               uint32_t label,
                               std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  // Pull the list buffer and codec parameters out of the index first.
  auto stream    = resource::get_cuda_stream(res);
  auto list_view = index->lists()[label]->data.view();
  pack_contiguous_list_data(
    list_view, new_codes, n_rows, index->pq_dim(), offset_or_indices, index->pq_bits(), stream);
}

/**
*
* A producer for the `write_list` and `write_vector` that encodes level-1 input vector residuals
Expand Down Expand Up @@ -1634,60 +1876,14 @@ auto build(raft::resources const& handle,
labels_view,
utils::mapping<float>());

{
// combine cluster_centers and their norms
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle(),
sizeof(float) * index.dim_ext(),
cluster_centers,
sizeof(float) * index.dim(),
sizeof(float) * index.dim(),
index.n_lists(),
cudaMemcpyDefault,
stream));

rmm::device_uvector<float> center_norms(index.n_lists(), stream, device_memory);
raft::linalg::rowNorm(center_norms.data(),
cluster_centers,
index.dim(),
index.n_lists(),
raft::linalg::L2Norm,
true,
stream);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(),
sizeof(float) * index.dim_ext(),
center_norms.data(),
sizeof(float),
sizeof(float),
index.n_lists(),
cudaMemcpyDefault,
stream));
}

// Make rotation matrix
make_rotation_matrix(handle,
params.force_random_rotation,
index.rot_dim(),
index.dim(),
index.rotation_matrix().data_handle());

// Rotate cluster_centers
float alpha = 1.0;
float beta = 0.0;
linalg::gemm(handle,
true,
false,
index.rot_dim(),
index.n_lists(),
index.dim(),
&alpha,
index.rotation_matrix().data_handle(),
index.dim(),
cluster_centers,
index.dim(),
&beta,
index.centers_rot().data_handle(),
index.rot_dim(),
stream);
set_centers(handle, &index, cluster_centers);

// Train PQ codebooks
switch (index.codebook_kind()) {
Expand Down
Loading
Loading