Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Nov 9, 2023
1 parent 140701e commit 0b88ca4
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/div_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
85 changes: 72 additions & 13 deletions cpp/include/raft/neighbors/ivf_pq_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,15 @@ void erase_list(raft::resources const& res, index<IdxT>* index, uint32_t label)
}

/**
* @brief Public helper API for computing the index's rotation matrix.
* @brief Public helper API exposing the computation of the index's rotation matrix.
* NB: This is to be used only when the rotation matrix is not already computed through
* raft::neighbors::ivf_pq::build.
*
* Usage example:
* @code{.cpp}
* // compute the rotation matrix with random_rotation
* raft::neighbors::ivf_pq::helpers::make_rotation_matrix(res, &index, true);
* @endcode
*
* @tparam IdxT
* @param[in] res
Expand All @@ -598,6 +606,25 @@ void make_rotation_matrix(raft::resources const& res,

/**
* @brief Public helper API for externally modifying the index's IVF centroids.
* NB: The index must be reset before this. Use raft::neighbors::ivf_pq::extend to construct IVF
lists according to new centroids.
*
* Usage example:
* @code{.cpp}
* // allocate the buffer for the input centers
* auto cluster_centers = raft::make_device_matrix<float, uint32_t>(res, index.n_lists(),
index.dim());
* ... prepare ivf centroids in cluster_centers ...
* // reset the index
* reset_index(res, index);
* // recompute the state of the index
* raft::neighbors::ivf_pq::helpers::recompute_internal_state(res, index);
* // Write the IVF centroids
* raft::neighbors::ivf_pq::helpers::set_centers(
res,
&index,
cluster_centers);
* @endcode
*
* @tparam IdxT
* @param[in] res
Expand All @@ -613,46 +640,78 @@ void set_centers(raft::resources const& res,
"Number of rows in cluster centers and IVF centers are different");
RAFT_EXPECTS(cluster_centers.extent(1) == index->dim(),
"Number of columns in cluster centers and index dim are different");
RAFT_EXPECTS(index->size() == 0, "Index must be empty");
ivf_pq::detail::set_centers(res, index, cluster_centers.data_handle());
}

/**
* @brief Helper to fetch size of a particular IVF list in bytes using the list extents.
*
* Usage example:
* @code{.cpp}
* // Fetch the size of the fourth list
* raft::neighbors::ivf_pq::helpers::get_list_size_in_bytes(index, 3);
* @endcode
*
* @tparam IdxT
* @param[in] index
* @param[in] label list ID
*/
template <typename IdxT>
void set_pq_centers(raft::resources const& res, index<IdxT>* index, float* pq_centers)
{
ivf_pq::detail::transpose_pq_centers(res, *index, pq_centers);
}

template <typename IdxT>
auto get_list_size_in_bytes(const index<IdxT>* index, uint32_t label) -> uint32_t
auto get_list_size_in_bytes(const index<IdxT>& index, uint32_t label) -> uint32_t
{
RAFT_EXPECTS(label < index->n_lists(),
"Expected label to be less than number of lists in the index");
auto list_data = index->lists()[label]->data;
return list_data.size();
}

/**
* @brief Helper exposing the re-computation of list sizes and related arrays if IVF lists have been
* modified.
*
* @tparam IdxT
* @param[in] res
* @param[inout] index
*/
template <typename IdxT>
void recompute_internal_state(const raft::resources& res, index<IdxT>* index)
{
ivf_pq::detail::recompute_internal_state(res, *index);
}

/**
* @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for
* externally modifying the index without going through the build stage.
*
* @tparam IdxT
* @param[in] res
* @param[inout] index
*/
template <typename IdxT>
void reset_index(const raft::resources& res, index<IdxT>& index)
void reset_index(const raft::resources& res, index<IdxT>* index)
{
auto stream = resource::get_cuda_stream(res);

utils::memzero(
index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream);
utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream);
utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream);
utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream);
index->accum_sorted_sizes().data_handle(), index->accum_sorted_sizes().size(), stream);
utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream);
utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream);
utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream);
}

/**
* @brief Public helper API for fetching a trained index's IVF centroids into a buffer that may be
* allocated on either host or device.
*
* Usage example:
* @code{.cpp}
* // allocate the buffer for the output centers
* auto cluster_centers = raft::make_device_matrix<float, uint32_t>(res, index.n_lists(),
* index.dim());
* // Extract the IVF centroids into the buffer
* raft::neighbors::ivf_pq::helpers::extract_centers(res, index, cluster_centers.data_handle());
* @endcode
* @tparam IdxT
* @param[in] res
* @param[in] index
Expand Down

0 comments on commit 0b88ca4

Please sign in to comment.