diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 7dc0c8b711..4ba9159556 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -52,12 +52,23 @@ namespace raft::neighbors::brute_force { * // create a brute_force knn index from the dataset * auto index = raft::neighbors::brute_force::build(res, * raft::make_const_mdspan(dataset.view())); + * * // search the index in batches of 128 nearest neighbors * auto search = raft::make_const_mdspan(dataset.view()); * auto query = make_batch_k_query(res, index, search, 128); * for (auto & batch: *query) { * // batch.indices() and batch.distances() contain the information on the current batch * } + * + * // we can also support variable sized batches - loading up a different number + * // of neighbors at each iteration through the ::advance method + * int64_t batch_size = 128; + * query = make_batch_k_query(res, index, search, batch_size); + * for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { + * // (*it).indices() and (*it).distances() contain the information on the current batch + * + * batch_size += 16; // load up an extra 16 items in the next batch + * } * @endcode * * @tparam T data element type diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index cc095e664e..039599845e 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -164,6 +164,19 @@ struct index : ann::index { /** * @brief Interface for performing queries over values of k * + * This interface lets you iterate over batches of k from a brute_force::index. + * This lets you do things like retrieve the first 100 neighbors for a query, + * apply post-processing to remove any unwanted items and then if needed get the + * next 100 closest neighbors for the query. 
+ * + * This query interface exposes C++ iterators through the ::begin and ::end methods, and + * is compatible with range-based for loops. + * + * Note that this class is an abstract class without any CUDA dependencies, meaning + * that it doesn't require a CUDA compiler to use - but this also means it can't be directly + * instantiated. See the raft::neighbors::brute_force::make_batch_k_query + * function for usage examples. + * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset */ @@ -211,6 +224,16 @@ class batch_k_query { return previous; } + /** + * @brief Advance the iterator, using a custom size for the next batch + * + * Using operator++ means that we will load up the same batch_size for each + * batch. This method allows us to get around this restriction, and load up + * arbitrary batch sizes on each iteration. + * See raft::neighbors::brute_force::make_batch_k_query for a usage example. + * + * @param[in] next_batch_size size of the next batch to load up + */ void advance(int64_t next_batch_size) { offset = std::min(offset + current.batch_size(), query->index_size);