Skip to content

Commit

Permalink
Switch to half as the vpq codebook type
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 11, 2024
1 parent 4498a22 commit dd1cc99
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
5 changes: 2 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,8 @@ index<T, IdxT> build(
idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view()));
idx.update_dataset(
res,
// TODO: ATM, only float math type is supported in kmeans training.
// Later, we can do runtime dispatching of the math type.
neighbors::vpq_build<decltype(dataset), float, int64_t>(res, *params.compression, dataset));
// TODO: hardcoding codebook math to `half`, we can do runtime dispatching later
neighbors::vpq_build<decltype(dataset), half, int64_t>(res, *params.compression, dataset));
return idx;
}
return index<T, IdxT>(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view()));
Expand Down
19 changes: 19 additions & 0 deletions cpp/include/raft/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,25 @@ auto process_and_fill_codes(const raft::resources& res,
return codes;
}

template <typename NewMathT, typename OldMathT, typename IdxT>
auto vpq_convert_math_type(const raft::resources& res, vpq_dataset<OldMathT, IdxT>&& src)
-> vpq_dataset<NewMathT, IdxT>
{
auto vq_code_book = make_device_mdarray<NewMathT>(res, src.vq_code_book.extents());
auto pq_code_book = make_device_mdarray<NewMathT>(res, src.pq_code_book.extents());

linalg::map(res,
vq_code_book.view(),
spatial::knn::detail::utils::mapping<NewMathT>{},
raft::make_const_mdspan(src.vq_code_book.view()));
linalg::map(res,
pq_code_book.view(),
spatial::knn::detail::utils::mapping<NewMathT>{},
raft::make_const_mdspan(src.pq_code_book.view()));
return vpq_dataset<NewMathT, IdxT>{
std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)};
}

template <typename DatasetT, typename MathT, typename IdxT>
auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset)
-> vpq_dataset<MathT, IdxT>
Expand Down
7 changes: 6 additions & 1 deletion cpp/include/raft/neighbors/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ template <typename DatasetT,
auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset)
-> vpq_dataset<MathT, IdxT>
{
return detail::vpq_build<DatasetT, MathT, IdxT>(res, params, dataset);
if constexpr (std::is_same_v<MathT, half>) {
return detail::vpq_convert_math_type<half, float, IdxT>(
res, detail::vpq_build<DatasetT, float, IdxT>(res, params, dataset));
} else {
return detail::vpq_build<DatasetT, MathT, IdxT>(res, params, dataset);
}
}

} // namespace raft::neighbors

0 comments on commit dd1cc99

Please sign in to comment.