diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 7eb5e21f53..34e79987ae 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -256,7 +256,7 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(construct_aligned_dataset(res, dataset, uint32_t{16})), + dataset_(upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16))), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), @@ -275,14 +275,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - construct_aligned_dataset(res, dataset, 16).swap(dataset_); + upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - construct_aligned_dataset(res, dataset, 16).swap(dataset_); + upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); } /** @@ -293,7 +293,15 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - construct_aligned_dataset(res, dataset, 16).swap(dataset_); + upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + } + + /** Replace the dataset with a new dataset. */ + template + auto update_dataset(raft::resources const& res, DatasetT&& dataset) + -> std::enable_if_t, DatasetT>> + { + upcast_dataset_ptr(std::make_unique(std::move(dataset))).swap(dataset_); } /** diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index cc655438bc..8586757679 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -34,28 +34,32 @@ namespace raft::neighbors { /** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */ template struct dataset { + using index_type = IdxT; /** Size of the dataset. */ - [[nodiscard]] virtual auto n_rows() const noexcept -> IdxT; + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; /** Dimensionality of the dataset. */ - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t; + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; /** Whether the object owns the data. */ - [[nodiscard]] virtual auto is_owning() const noexcept -> bool; - virtual ~dataset() noexcept = default; + [[nodiscard]] virtual auto is_owning() const noexcept -> bool = 0; + virtual ~dataset() noexcept = default; }; template struct empty_dataset : public dataset { + using index_type = IdxT; uint32_t suggested_dim; explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(0) {} - [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return 0; } + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; } [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } }; template struct strided_dataset : public dataset { - using view_type = device_matrix_view; - [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return view().extent(0); } + using index_type = IdxT; + using value_type = DataT; + using view_type = device_matrix_view; + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return view().extent(0); } [[nodiscard]] auto dim() const noexcept -> uint32_t final { return static_cast(view().extent(1)); @@ -71,7 +75,9 @@ struct strided_dataset : public dataset { template struct non_owning_dataset : public strided_dataset { - using typename strided_dataset::view_type; + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; view_type data; explicit non_owning_dataset(view_type data) noexcept : data(data) {} [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } @@ -80,8 +86,11 @@ struct non_owning_dataset : public strided_dataset { template struct owning_dataset : public strided_dataset { - using typename strided_dataset::view_type; - using storage_type = mdarray, LayoutPolicy, ContainerPolicy>; + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + using storage_type = + mdarray, LayoutPolicy, ContainerPolicy>; using mapping_type = typename view_type::mapping_type; storage_type data; mapping_type view_mapping; @@ -153,17 +162,26 @@ auto construct_strided_dataset(const raft::resources& res, template auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes) - -> std::unique_ptr> + -> std::unique_ptr> { using value_type = typename SrcT::value_type; using index_type = typename SrcT::index_type; - using out_type = dataset; + using out_type = strided_dataset; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = raft::round_up_safe(src.extent(1) * kSize, align_bytes) / kSize; return std::unique_ptr{construct_strided_dataset(res, src, required_stride).release()}; } +template +auto upcast_dataset_ptr(std::unique_ptr&& src) + -> std::unique_ptr> +{ + using out_type = dataset; + static_assert(std::is_base_of_v, "The source must be a child of `dataset`"); + return std::unique_ptr{src.release()}; +} + /** Parameters for VPQ compression. */ struct vpq_params { /** diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 08cc2beaeb..743188dae3 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -16,6 +16,7 @@ #pragma once #include "../../cagra_types.hpp" +#include "../../vpq_dataset.cuh" #include "graph_core.cuh" #include @@ -344,6 +345,16 @@ index build( RAFT_LOG_INFO("Graph optimized, creating index"); // Construct an index from dataset and optimized knn graph. if (construct_index_with_dataset) { + if (params.compression.has_value()) { + index idx(res, params.metric); + idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); + idx.update_dataset( + res, + // TODO: ATM, only float math type is supported in kmeans training. + // Later, we can do runtime dispatching of the math type. + neighbors::vpq_build(res, *params.compression, dataset)); + return idx; + } return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); } else { // We just add the graph. User is expected to update dataset separately. This branch is used