From 0b88ca4c7fc5671a4614113740f8d4cad97ba25a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 9 Nov 2023 11:41:06 -0800 Subject: [PATCH] Update docs --- .../raft/neighbors/detail/div_utils.hpp | 2 +- cpp/include/raft/neighbors/ivf_pq_helpers.cuh | 85 ++++++++++++++++--- 2 files changed, 73 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/div_utils.hpp b/cpp/include/raft/neighbors/detail/div_utils.hpp index 023ac1020f..0455d0ec9b 100644 --- a/cpp/include/raft/neighbors/detail/div_utils.hpp +++ b/cpp/include/raft/neighbors/detail/div_utils.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh index 4495c71b8a..fd2d61a0d0 100644 --- a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh @@ -576,7 +576,15 @@ void erase_list(raft::resources const& res, index* index, uint32_t label) } /** - * @brief Public helper API for computing the index's rotation matrix. + * @brief Public helper API exposing the computation of the index's rotation matrix. + * NB: This is to be used only when the rotation matrix is not already computed through + * raft::neighbors::ivf_pq::build. + * + * Usage example: + * @code{.cpp} + * // compute the rotation matrix with random_rotation + * raft::neighbors::ivf_pq::helpers::make_rotation_matrix(res, &index, true); + * @endcode * * @tparam IdxT * @param[in] res @@ -598,6 +606,25 @@ void make_rotation_matrix(raft::resources const& res, /** * @brief Public helper API for externally modifying the index's IVF centroids. + * NB: The index must be reset before this. Use raft::neighbors::ivf_pq::extend to construct IVF + lists according to new centroids. + * + * Usage example: + * @code{.cpp} + * // allocate the buffer for the input centers + * auto cluster_centers = raft::make_device_matrix(res, index.n_lists(), + index.dim()); + * ... prepare ivf centroids in cluster_centers ... + * // reset the index + * reset_index(res, index); + * // recompute the state of the index + * raft::neighbors::ivf_pq::helpers::recompute_internal_state(res, index); + * // Write the IVF centroids + * raft::neighbors::ivf_pq::helpers::set_centers( + res, + &index, + cluster_centers); + * @endcode * * @tparam IdxT * @param[in] res @@ -613,17 +640,25 @@ void set_centers(raft::resources const& res, "Number of rows in cluster centers and IVF centers are different"); RAFT_EXPECTS(cluster_centers.extent(1) == index->dim(), "Number of columns in cluster centers and index dim are different"); + RAFT_EXPECTS(index->size() == 0, "Index must be empty"); ivf_pq::detail::set_centers(res, index, cluster_centers.data_handle()); } +/** + * @brief Helper to fetch size of a particular IVF list in bytes using the list extents. + * + * Usage example: + * @code{.cpp} + * // Fetch the size of the fourth list + * raft::neighbors::ivf_pq::helpers::get_list_size_in_bytes(index, 3); + * @endcode + * + * @tparam IdxT + * @param[in] index + * @param[in] label list ID + */ template -void set_pq_centers(raft::resources const& res, index* index, float* pq_centers) -{ - ivf_pq::detail::transpose_pq_centers(res, *index, pq_centers); -} - -template -auto get_list_size_in_bytes(const index* index, uint32_t label) -> uint32_t +auto get_list_size_in_bytes(const index& index, uint32_t label) -> uint32_t { RAFT_EXPECTS(label < index->n_lists(), "Expected label to be less than number of lists in the index"); @@ -631,28 +666,52 @@ auto get_list_size_in_bytes(const index* index, uint32_t label) -> uint32_ return list_data.size(); } +/** + * @brief Helper exposing the re-computation of list sizes and related arrays if IVF lists have been + * modified. + * + * @tparam IdxT + * @param[in] res + * @param[inout] index + */ template void recompute_internal_state(const raft::resources& res, index* index) { ivf_pq::detail::recompute_internal_state(res, *index); } +/** + * @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for + * externally modifying the index without going through the build stage. + * + * @tparam IdxT + * @param[in] res + * @param[inout] index + */ template -void reset_index(const raft::resources& res, index& index) +void reset_index(const raft::resources& res, index* index) { auto stream = resource::get_cuda_stream(res); utils::memzero( - index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); - utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); - utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); + index->accum_sorted_sizes().data_handle(), index->accum_sorted_sizes().size(), stream); + utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); + utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); + utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); } /** * @brief Public helper API for fetching a trained index's IVF centroids into a buffer that may be * allocated on either host or device. * + * Usage example: + * @code{.cpp} + * // allocate the buffer for the output centers + * auto cluster_centers = raft::make_device_matrix(res, index.n_lists(), + * index.dim()); + * // Extract the IVF centroids into the buffer + * raft::neighbors::ivf_pq::helpers::extract_centers(res, index, cluster_centers.data_handle()); + * @endcode * @tparam IdxT * @param[in] res * @param[in] index