Skip to content

Commit

Permalink
Add dataset compression as an optional step during build
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 7, 2024
1 parent 1a72020 commit 99fa02f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 16 deletions.
16 changes: 12 additions & 4 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ struct index : ann::index {
mdspan<const IdxT, matrix_extent<int64_t>, row_major, graph_accessor> knn_graph)
: ann::index(),
metric_(metric),
dataset_(construct_aligned_dataset(res, dataset, uint32_t{16})),
dataset_(upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16))),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
Expand All @@ -275,14 +275,14 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
construct_aligned_dataset(res, dataset, 16).swap(dataset_);
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, layout_stride> dataset)
{
construct_aligned_dataset(res, dataset, 16).swap(dataset_);
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
}

/**
Expand All @@ -293,7 +293,15 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
construct_aligned_dataset(res, dataset, 16).swap(dataset_);
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
}

/** Replace the dataset with a new dataset. */
template <typename DatasetT>
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
upcast_dataset_ptr(std::make_unique<DatasetT>(std::move(dataset))).swap(dataset_);
}

/**
Expand Down
42 changes: 30 additions & 12 deletions cpp/include/raft/neighbors/dataset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,32 @@ namespace raft::neighbors {
/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */
template <typename IdxT>
struct dataset {
using index_type = IdxT;
/** Size of the dataset. */
[[nodiscard]] virtual auto n_rows() const noexcept -> IdxT;
[[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0;
/** Dimensionality of the dataset. */
[[nodiscard]] virtual auto dim() const noexcept -> uint32_t;
[[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0;
/** Whether the object owns the data. */
[[nodiscard]] virtual auto is_owning() const noexcept -> bool;
virtual ~dataset() noexcept = default;
[[nodiscard]] virtual auto is_owning() const noexcept -> bool = 0;
virtual ~dataset() noexcept = default;
};

template <typename IdxT>
struct empty_dataset : public dataset<IdxT> {
using index_type = IdxT;
uint32_t suggested_dim;
explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(0) {}
[[nodiscard]] auto n_rows() const noexcept -> IdxT final { return 0; }
[[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; }
[[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; }
[[nodiscard]] auto is_owning() const noexcept -> bool final { return true; }
};

template <typename DataT, typename IdxT>
struct strided_dataset : public dataset<IdxT> {
using view_type = device_matrix_view<const DataT, IdxT, layout_stride>;
[[nodiscard]] auto n_rows() const noexcept -> IdxT final { return view().extent(0); }
using index_type = IdxT;
using value_type = DataT;
using view_type = device_matrix_view<const value_type, index_type, layout_stride>;
[[nodiscard]] auto n_rows() const noexcept -> index_type final { return view().extent(0); }
[[nodiscard]] auto dim() const noexcept -> uint32_t final
{
return static_cast<uint32_t>(view().extent(1));
Expand All @@ -71,7 +75,9 @@ struct strided_dataset : public dataset<IdxT> {

template <typename DataT, typename IdxT>
struct non_owning_dataset : public strided_dataset<DataT, IdxT> {
using typename strided_dataset<DataT, IdxT>::view_type;
using index_type = IdxT;
using value_type = DataT;
using typename strided_dataset<value_type, index_type>::view_type;
view_type data;
explicit non_owning_dataset(view_type data) noexcept : data(data) {}
[[nodiscard]] auto is_owning() const noexcept -> bool final { return false; }
Expand All @@ -80,8 +86,11 @@ struct non_owning_dataset : public strided_dataset<DataT, IdxT> {

template <typename DataT, typename IdxT, typename LayoutPolicy, typename ContainerPolicy>
struct owning_dataset : public strided_dataset<DataT, IdxT> {
using typename strided_dataset<DataT, IdxT>::view_type;
using storage_type = mdarray<DataT, matrix_extent<IdxT>, LayoutPolicy, ContainerPolicy>;
using index_type = IdxT;
using value_type = DataT;
using typename strided_dataset<value_type, index_type>::view_type;
using storage_type =
mdarray<value_type, matrix_extent<index_type>, LayoutPolicy, ContainerPolicy>;
using mapping_type = typename view_type::mapping_type;
storage_type data;
mapping_type view_mapping;
Expand Down Expand Up @@ -153,17 +162,26 @@ auto construct_strided_dataset(const raft::resources& res,

template <typename SrcT>
auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes)
-> std::unique_ptr<dataset<typename SrcT::index_type>>
-> std::unique_ptr<strided_dataset<typename SrcT::value_type, typename SrcT::index_type>>
{
using value_type = typename SrcT::value_type;
using index_type = typename SrcT::index_type;
using out_type = dataset<index_type>;
using out_type = strided_dataset<value_type, index_type>;
constexpr size_t kSize = sizeof(value_type);
uint32_t required_stride =
raft::round_up_safe<size_t>(src.extent(1) * kSize, align_bytes) / kSize;
return std::unique_ptr<out_type>{construct_strided_dataset(res, src, required_stride).release()};
}

template <typename DatasetT>
auto upcast_dataset_ptr(std::unique_ptr<DatasetT>&& src)
-> std::unique_ptr<dataset<typename DatasetT::index_type>>
{
using out_type = dataset<typename DatasetT::index_type>;
static_assert(std::is_base_of_v<out_type, DatasetT>, "The source must be a child of `dataset`");
return std::unique_ptr<out_type>{src.release()};
}

/** Parameters for VPQ compression. */
struct vpq_params {
/**
Expand Down
11 changes: 11 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "../../cagra_types.hpp"
#include "../../vpq_dataset.cuh"
#include "graph_core.cuh"

#include <raft/core/device_mdarray.hpp>
Expand Down Expand Up @@ -344,6 +345,16 @@ index<T, IdxT> build(
RAFT_LOG_INFO("Graph optimized, creating index");
// Construct an index from dataset and optimized knn graph.
if (construct_index_with_dataset) {
if (params.compression.has_value()) {
index<T, IdxT> idx(res, params.metric);
idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view()));
idx.update_dataset(
res,
// TODO: ATM, only float math type is supported in kmeans training.
// Later, we can do runtime dispatching of the math type.
neighbors::vpq_build<decltype(dataset), float, int64_t>(res, *params.compression, dataset));
return idx;
}
return index<T, IdxT>(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view()));
} else {
// We just add the graph. User is expected to update dataset separately. This branch is used
Expand Down

0 comments on commit 99fa02f

Please sign in to comment.