Skip to content

Commit

Permalink
Integrate vpq_dataset into cagra
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 7, 2024
1 parent aa70b61 commit 1a72020
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 180 deletions.
21 changes: 0 additions & 21 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#pragma once

#include "detail/cagra/cagra_build.cuh"
#include "detail/cagra/cagra_build_q.cuh"
#include "detail/cagra/cagra_search.cuh"
#include "detail/cagra/graph_core.cuh"

Expand Down Expand Up @@ -280,26 +279,6 @@ index<T, IdxT> build(raft::resources const& res,
return detail::build<T, IdxT, Accessor>(res, params, dataset);
}

/**
* @brief Compress a dataset for use in CAGRA-Q search in place of the original data.
*
* This is a thin public entry point: all arguments are forwarded unchanged to
* `detail::compress`, and its result is returned as is.
*
* @tparam DatasetT a row-major mdspan or mdarray (device or host).
* @tparam MathT a type of the codebook elements and internal math ops
* (defaults to `DatasetT::value_type`).
* @tparam IdxT type of the indices in the source dataset
* (defaults to `DatasetT::index_type`).
*
* @param[in] res raft resources object, passed through to the implementation
* @param[in] params VQ and PQ parameters for compressing the data
* @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim].
*
* @return the `compressed_dataset<MathT, IdxT>` produced by `detail::compress`.
*/
template <typename DatasetT,
typename MathT = typename DatasetT::value_type,
typename IdxT = typename DatasetT::index_type>
auto compress(const raft::resources& res, const compression_params& params, const DatasetT& dataset)
-> compressed_dataset<MathT, IdxT>
{
// Template arguments are spelled out explicitly so that the defaults chosen here
// (rather than any defaults of the detail function) determine MathT and IdxT.
return detail::compress<DatasetT, MathT, IdxT>(res, params, dataset);
}

/**
* @brief Search ANN using the constructed index.
*
Expand Down
161 changes: 17 additions & 144 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "ann_types.hpp"
#include "dataset.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
Expand All @@ -36,10 +37,6 @@
#include <string>
#include <type_traits>

#ifdef __cpp_lib_bitops
#include <bit>
#endif

namespace raft::neighbors::cagra {
/**
* @addtogroup cagra
Expand All @@ -66,43 +63,8 @@ struct index_params : ann::index_params {
graph_build_algo build_algo = graph_build_algo::IVF_PQ;
/** Number of iterations to run if building with NN_DESCENT. */
size_t nn_descent_niter = 20;
};

/** Parameters for CAGRA-Q compression. */
struct compression_params {
/**
* The bit length of the vector element after compression by PQ.
*
* Possible values: [4, 5, 6, 7, 8].
*
* Hint: the smaller the 'pq_bits', the smaller the index size and the better the search
* performance, but the lower the recall.
*/
uint32_t pq_bits = 8;
/**
* The dimensionality of the vector after compression by PQ.
* When zero, an optimal value is selected using a heuristic.
*
* TODO: at the moment `dim` must be a multiple of `pq_dim`.
*/
uint32_t pq_dim = 0;
/**
* Vector Quantization (VQ) codebook size - number of "coarse cluster centers".
* When zero, an optimal value is selected using a heuristic.
*/
uint32_t vq_n_centers = 0;
/** The number of iterations searching for kmeans centers (both VQ & PQ phases). */
uint32_t kmeans_n_iters = 25;
/**
* The fraction of data to use during iterative kmeans building (VQ phase).
* When zero, an optimal value is selected using a heuristic.
*/
double vq_kmeans_trainset_fraction = 0;
/**
* The fraction of data to use during iterative kmeans building (PQ phase).
* When zero, an optimal value is selected using a heuristic.
*/
double pq_kmeans_trainset_fraction = 0;
/** Specify compression params if compression is desired. */
std::optional<vpq_params> compression = std::nullopt;
};

enum class search_algo {
Expand Down Expand Up @@ -187,14 +149,12 @@ struct index : ann::index {
/** Total length of the index (number of vectors). */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT
{
return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0);
auto data_rows = dataset_->n_rows();
return data_rows > 0 ? data_rows : graph_view_.extent(0);
}

/** Dimensionality of the data. */
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
{
return dataset_view_.extent(1);
}
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); }
/** Graph degree */
[[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t
{
Expand All @@ -205,7 +165,10 @@ struct index : ann::index {
[[nodiscard]] inline auto dataset() const noexcept
-> device_matrix_view<const T, int64_t, layout_stride>
{
return dataset_view_;
auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
if (p != nullptr) { return p->view(); }
auto d = dataset_->dim();
return make_device_strided_matrix_view<const T, int64_t>(nullptr, 0, d, d);
}

/** neighborhood graph [size, graph-degree] */
Expand All @@ -227,7 +190,7 @@ struct index : ann::index {
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_(new neighbors::empty_dataset<int64_t>(0)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
}
Expand Down Expand Up @@ -293,12 +256,11 @@ struct index : ann::index {
mdspan<const IdxT, matrix_extent<int64_t>, row_major, graph_accessor> knn_graph)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_(construct_aligned_dataset(res, dataset, uint32_t{16})),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
update_dataset(res, dataset);
update_graph(res, knn_graph);
resource::sync_stream(res);
}
Expand All @@ -313,21 +275,14 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
if (dataset.extent(1) * sizeof(T) % 16 != 0) {
RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory");
copy_padded(res, dataset);
} else {
dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1));
}
construct_aligned_dataset(res, dataset, 16).swap(dataset_);
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const&,
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, layout_stride> dataset)
{
RAFT_EXPECTS(dataset.stride(0) * sizeof(T) % 16 == 0, "Incorrect data padding.");
dataset_view_ = dataset;
construct_aligned_dataset(res, dataset, 16).swap(dataset_);
}

/**
Expand All @@ -338,8 +293,7 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device");
copy_padded(res, dataset);
construct_aligned_dataset(res, dataset, 16).swap(dataset_);
}

/**
Expand Down Expand Up @@ -376,91 +330,10 @@ struct index : ann::index {
}

private:
/** Create a device copy of the dataset, and pad it if necessary. */
template <typename data_accessor>
void copy_padded(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
detail::copy_with_padding(res, dataset_, dataset);

dataset_view_ = make_device_strided_matrix_view<const T, int64_t>(
dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1));
RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu",
static_cast<size_t>(dataset_view_.extent(0)),
static_cast<size_t>(dataset_view_.extent(1)),
static_cast<size_t>(dataset_view_.stride(0)));
}

raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix<IdxT, int64_t, row_major> graph_;
raft::device_matrix_view<const T, int64_t, layout_stride> dataset_view_;
raft::device_matrix_view<const IdxT, int64_t, row_major> graph_view_;
};

/**
* @brief CAGRA-Q compressed dataset.
*
* @tparam MathT the type of elements in the codebooks
* @tparam IdxT type of the vector indices (represents dataset.extent(0))
*
*/
template <typename MathT, typename IdxT>
struct compressed_dataset {
/** Vector Quantization codebook - "coarse cluster centers". */
device_matrix<MathT, uint32_t, row_major> vq_code_book;
/** Product Quantization codebook - "fine cluster centers". */
device_matrix<MathT, uint32_t, row_major> pq_code_book;
/** Compressed dataset. */
device_matrix<uint8_t, IdxT, row_major> dataset;

/** Total length of the index. */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset.extent(0); }
/** Row length of the encoded data in bytes. */
[[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t
{
return dataset.extent(1);
}
/** Dimensionality of the data. */
[[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
{
return vq_code_book.extent(1);
}
/** The number of "coarse cluster centers" */
[[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t
{
return vq_code_book.extent(0);
}
/**
* The bit length of an encoded vector element after compression by PQ.
*
* The value is not stored; it is recovered from the number of rows in the PQ codebook
* (`pq_n_centers()`), i.e. it is the base-2 logarithm of the codebook size.
*
* NOTE(review): both branches below agree only when `pq_n_centers()` is a power of two
* (presumably always the case, since the codebook size is described as `1 << pq_bits`);
* for other values the trailing-zero count and the shift loop give different results.
*/
[[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t
{
auto pq_width = pq_n_centers();
#ifdef __cpp_lib_bitops
// C++20 <bit> is available: for a power of two, the number of trailing zero bits is
// exactly log2 of the value.
return std::countr_zero(pq_width);
#else
// Fallback without <bit>: count how many times the width can be halved before it
// reaches 1, which yields floor(log2(pq_width)).
uint32_t pq_bits = 0;
while (pq_width > 1) {
pq_bits++;
pq_width >>= 1;
}
return pq_bits;
#endif
}
/** The dimensionality of an encoded vector after compression by PQ. */
[[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t
{
return raft::div_rounding_up_unsafe(dim(), pq_len());
}
/** Dimensionality of a subspace, i.e. the number of vector components mapped to a subspace. */
[[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t
{
return pq_code_book.extent(1);
}
/** The number of vectors in a PQ codebook (`1 << pq_bits`). */
[[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t
{
return pq_code_book.extent(0);
}
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
};

/** @} */
Expand Down
Loading

0 comments on commit 1a72020

Please sign in to comment.