diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index eb44e58cb5..5919de07e7 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -90,7 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_CAGRA - OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB + OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -263,7 +263,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib LINKS raft::compiled - CXXFLAGS "${HNSW_CXX_FLAGS}" + CXXFLAGS + "${HNSW_CXX_FLAGS}" ) endif() diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 5b4048c1c3..a218c85a0a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -44,7 +44,7 @@ struct l2_exp_cutlass_op { __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index b8c00616da..4c1f7ea21e 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -50,7 +50,7 @@ void search(raft::resources const& res, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; + raft::device_matrix_view distances) RAFT_EXPLICIT; template & idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances) { - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), - "Number of columns in queries must match brute force index"); - - auto k = neighbors.extent(1); - auto d = idx.dataset().extent(1); - - std::vector dataset = {const_cast(idx.dataset().data_handle())}; - std::vector sizes = {idx.dataset().extent(0)}; - std::vector norms; - if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } - - detail::brute_force_knn_impl(res, - dataset, - sizes, - d, - const_cast(queries.data_handle()), - queries.extent(0), - neighbors.data_handle(), - distances.data_handle(), - k, - true, - true, - nullptr, - idx.metric(), - idx.metric_arg(), - raft::identity_op(), - norms.size() ? &norms : nullptr); + raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); } /** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 6cebf4b52a..4ba9159556 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -14,6 +14,7 @@ * limitations under the License. */ #pragma once +#include #ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "brute_force-inl.cuh" @@ -22,3 +23,70 @@ #ifdef RAFT_COMPILED #include "brute_force-ext.cuh" #endif + +#include + +namespace raft::neighbors::brute_force { +/** + * @brief Make a brute force query over batches of k + * + * This lets you query for batches of k. For example, you can get + * the first 100 neighbors, then the next 100 neighbors etc. + * + * Example usage: + * @code{.cpp} + * #include + * #include + * #include + + * // create a random dataset + * int n_rows = 10000; + * int n_cols = 10000; + + * raft::device_resources res; + * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); + * auto labels = raft::make_device_vector(res, n_rows); + + * raft::make_blobs(res, dataset.view(), labels.view()); + * + * // create a brute_force knn index from the dataset + * auto index = raft::neighbors::brute_force::build(res, + * raft::make_const_mdspan(dataset.view())); + * + * // search the index in batches of 128 nearest neighbors + * auto search = raft::make_const_mdspan(dataset.view()); + * auto query = make_batch_k_query(res, index, search, 128); + * for (auto & batch: *query) { + * // batch.indices() and batch.distances() contain the information on the current batch + * } + * + * // we can also support variable sized batches - loaded up a different number + * // of neighbors at each iteration through the ::advance method + * int64_t batch_size = 128; + * query = make_batch_k_query(res, index, search, batch_size); + * for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { + * // batch.indices() and batch.distances() contain the information on the current batch + * + * batch_size += 16; // load up an extra 16 items in the next batch + * } + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * @param[in] res + * @param[in] index The index to query + * @param[in] query A device matrix view to query for [n_queries, index->dim()] + * @param[in] batch_size The size of each batch + */ + +template +std::shared_ptr> make_batch_k_query( + const raft::resources& res, + const raft::neighbors::brute_force::index& index, + raft::device_matrix_view query, + int64_t batch_size) +{ + return std::shared_ptr>( + new detail::gpu_batch_k_query(res, index, query, batch_size)); +} +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index f7030503f1..039599845e 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -69,7 +70,7 @@ struct index : ann::index { return norms_view_.value(); } - /** Whether ot not this index has dataset norms */ + /** Whether or not this index has dataset norms */ [[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); } [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } @@ -160,6 +161,122 @@ struct index : ann::index { T metric_arg_; }; +/** + * @brief Interface for performing queries over values of k + * + * This interface lets you iterate over batches of k from a brute_force::index. + * This lets you do things like retrieve the first 100 neighbors for a query, + * apply post processing to remove any unwanted items and then if needed get the + * next 100 closest neighbors for the query. + * + * This query interface exposes C++ iterators through the ::begin and ::end, and + * is compatible with range based for loops. + * + * Note that this class is an abstract class without any cuda dependencies, meaning + * that it doesn't require a cuda compiler to use - but also means it can't be directly + * instantiated. See the raft::neighbors::brute_force::make_batch_k_query + * function for usage examples. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + */ +template +class batch_k_query { + public: + batch_k_query(const raft::resources& res, + int64_t index_size, + int64_t query_size, + int64_t batch_size) + : res(res), index_size(index_size), query_size(query_size), batch_size(batch_size) + { + } + virtual ~batch_k_query() {} + + using value_type = raft::neighbors::batch; + + class iterator { + public: + using value_type = raft::neighbors::batch; + using reference = const value_type&; + using pointer = const value_type*; + + iterator(const batch_k_query* query, int64_t offset = 0) + : current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset) + { + query->load_batch(offset, query->batch_size, &batches); + query->slice_batch(batches, offset, query->batch_size, ¤t); + } + + reference operator*() const { return current; } + + pointer operator->() const { return ¤t; } + + iterator& operator++() + { + advance(query->batch_size); + return *this; + } + + iterator operator++(int) + { + iterator previous(*this); + operator++(); + return previous; + } + + /** + * @brief Advance the iterator, using a custom size for the next batch + * + * Using operator++ means that we will load up the same batch_size for each + * batch. This method allows us to get around this restriction, and load up + * arbitrary batch sizes on each iteration. + * See raft::neighbors::brute_force::make_batch_k_query for a usage example. + * + * @param[in] next_batch_size: size of the next batch to load up + */ + void advance(int64_t next_batch_size) + { + offset = std::min(offset + current.batch_size(), query->index_size); + if (offset + next_batch_size > batches.batch_size()) { + query->load_batch(offset, next_batch_size, &batches); + } + query->slice_batch(batches, offset, next_batch_size, ¤t); + } + + friend bool operator==(const iterator& lhs, const iterator& rhs) + { + return (lhs.query == rhs.query) && (lhs.offset == rhs.offset); + }; + friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); }; + + protected: + // the current batch of data + value_type current; + + // the currently loaded group of data (containing multiple batches of data that we can iterate + // through) + value_type batches; + + const batch_k_query* query; + int64_t offset, current_batch_size; + }; + + iterator begin() const { return iterator(this); } + iterator end() const { return iterator(this, index_size); } + + protected: + // these two methods need cuda code, and are implemented in the subclass + virtual void load_batch(int64_t offset, + int64_t next_batch_size, + batch* output) const = 0; + virtual void slice_batch(const value_type& input, + int64_t offset, + int64_t batch_size, + value_type* output) const = 0; + + const raft::resources& res; + int64_t index_size, query_size, batch_size; +}; /** @} */ } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 5da4e77874..27ef00e385 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -65,11 +66,12 @@ void tiled_brute_force_knn(const raft::resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - DistanceEpilogue distance_epilogue = raft::identity_op(), - const ElementType* precomputed_index_norms = nullptr) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op(), + const ElementType* precomputed_index_norms = nullptr, + const ElementType* precomputed_search_norms = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -98,18 +100,20 @@ void tiled_brute_force_knn(const raft::resources& handle, if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::CosineExpanded) { - search_norms.resize(m, stream); + if (!precomputed_search_norms) { search_norms.resize(m, stream); } if (!precomputed_index_norms) { index_norms.resize(n, stream); } // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == raft::distance::DistanceType::CosineExpanded) { - raft::linalg::rowNorm(search_norms.data(), - search, - d, - m, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); + if (!precomputed_search_norms) { + raft::linalg::rowNorm(search_norms.data(), + search, + d, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } if (!precomputed_index_norms) { raft::linalg::rowNorm(index_norms.data(), index, @@ -121,9 +125,10 @@ void tiled_brute_force_knn(const raft::resources& handle, raft::sqrt_op{}); } } else { - raft::linalg::rowNorm( - search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - + if (!precomputed_search_norms) { + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + } if (!precomputed_index_norms) { raft::linalg::rowNorm( index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); @@ -184,7 +189,7 @@ void tiled_brute_force_knn(const raft::resources& handle, metric_arg); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = search_norms.data(); + auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); bool sqrt = metric == raft::distance::DistanceType::L2SqrtExpanded; @@ -201,7 +206,7 @@ void tiled_brute_force_knn(const raft::resources& handle, return distance_epilogue(val, row, col); }); } else if (metric == raft::distance::DistanceType::CosineExpanded) { - auto row_norms = search_norms.data(); + auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); @@ -333,7 +338,8 @@ void brute_force_knn_impl( raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metricArg = 0, DistanceEpilogue distance_epilogue = raft::identity_op(), - std::vector* input_norms = nullptr) + std::vector* input_norms = nullptr, + const value_t* search_norms = nullptr) { auto userStream = resource::get_cuda_stream(handle); @@ -376,7 +382,7 @@ void brute_force_knn_impl( } // currently we don't support col_major inside tiled_brute_force_knn, because - // of limitattions of the pairwise_distance API: + // of limitations of the pairwise_distance API: // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have // multiple options here (like rowMajorQuery/rowMajorIndex) // 2) because of tiling, we need to be able to set a custom stride in the PW @@ -428,7 +434,8 @@ void brute_force_knn_impl( rowMajorQuery, stream, metric, - input_norms ? (*input_norms)[i] : nullptr); + input_norms ? (*input_norms)[i] : nullptr, + search_norms); // Perform necessary post-processing if (metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -478,7 +485,8 @@ void brute_force_knn_impl( 0, 0, distance_epilogue, - input_norms ? (*input_norms)[i] : nullptr); + input_norms ? (*input_norms)[i] : nullptr, + search_norms); break; } } @@ -500,4 +508,43 @@ void brute_force_knn_impl( if (translations == nullptr) delete id_ranges; }; +template +void brute_force_search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) +{ + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + + auto k = neighbors.extent(1); + auto d = idx.dataset().extent(1); + + std::vector dataset = {const_cast(idx.dataset().data_handle())}; + std::vector sizes = {idx.dataset().extent(0)}; + std::vector norms; + if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } + + brute_force_knn_impl(res, + dataset, + sizes, + d, + const_cast(queries.data_handle()), + queries.extent(0), + neighbors.data_handle(), + distances.data_handle(), + k, + true, + true, + nullptr, + idx.metric(), + idx.metric_arg(), + raft::identity_op(), + norms.size() ? &norms : nullptr, + query_norms ? query_norms->data_handle() : nullptr); +} } // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh new file mode 100644 index 0000000000..384eacae79 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::neighbors::brute_force::detail { +template +class gpu_batch_k_query : public batch_k_query { + public: + gpu_batch_k_query(const raft::resources& res, + const raft::neighbors::brute_force::index& index, + raft::device_matrix_view query, + int64_t batch_size) + : batch_k_query(res, index.size(), query.extent(0), batch_size), + index(index), + query(query) + { + auto metric = index.metric(); + + // precompute query norms, and re-use across batches + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded) { + query_norms = make_device_vector(res, query.extent(0)); + + if (metric == raft::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(res, + query, + query_norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op{}); + } else { + raft::linalg::norm(res, + query, + query_norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + } + } + } + + protected: + void load_batch(int64_t offset, int64_t next_batch_size, batch* output) const override + { + if (offset >= index.size()) { return; } + + // we're aiming to load multiple batches here - since we don't know the max iteration + // grow the size we're loading exponentially + int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), this->index_size); + output->resize(this->res, this->query_size, batch_size); + + std::optional> query_norms_view; + if (query_norms) { query_norms_view = query_norms->view(); } + + raft::neighbors::detail::brute_force_search( + this->res, index, query, output->indices(), output->distances(), query_norms_view); + }; + + void slice_batch(const batch& input, + int64_t offset, + int64_t batch_size, + batch* output) const override + { + auto num_queries = input.indices().extent(0); + batch_size = std::min(batch_size, index.size() - offset); + + output->resize(this->res, num_queries, batch_size); + + if (!num_queries || !batch_size) { return; } + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + matrix::slice(this->res, input.indices(), output->indices(), coords); + matrix::slice(this->res, input.distances(), output->distances(), coords); + } + + const raft::neighbors::brute_force::index& index; + raft::device_matrix_view query; + std::optional> query_norms; +}; +} // namespace raft::neighbors::brute_force::detail diff --git a/cpp/include/raft/neighbors/neighbors_types.hpp b/cpp/include/raft/neighbors/neighbors_types.hpp new file mode 100644 index 0000000000..d503779741 --- /dev/null +++ b/cpp/include/raft/neighbors/neighbors_types.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::neighbors { + +/** A single batch of nearest neighbors in device memory */ +template +class batch { + public: + /** Create a new empty batch of data */ + batch(raft::resources const& res, int64_t rows, int64_t cols) + : indices_(make_device_matrix(res, rows, cols)), + distances_(make_device_matrix(res, rows, cols)) + { + } + + void resize(raft::resources const& res, int64_t rows, int64_t cols) + { + indices_ = make_device_matrix(res, rows, cols); + distances_ = make_device_matrix(res, rows, cols); + } + + /** Returns the indices for the batch */ + device_matrix_view indices() const + { + return raft::make_const_mdspan(indices_.view()); + } + device_matrix_view indices() { return indices_.view(); } + + /** Returns the distances for the batch */ + device_matrix_view distances() const + { + return raft::make_const_mdspan(distances_.view()); + } + device_matrix_view distances() { return distances_.view(); } + + /** Returns the size of the batch */ + int64_t batch_size() const { return indices().extent(1); } + + protected: + raft::device_matrix indices_; + raft::device_matrix distances_; +}; +} // namespace raft::neighbors diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index f2fda93a97..d4f902c087 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -36,4 +36,4 @@ template raft::neighbors::brute_force::index raft::neighbors::brute_force raft::resources const& res, raft::device_matrix_view dataset, raft::distance::DistanceType metric, - float metric_arg); \ No newline at end of file + float metric_arg); diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index ebde8e6d35..a84c9749d7 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -38,6 +38,7 @@ #include namespace raft::neighbors::brute_force { + struct TiledKNNInputs { int num_queries; int num_db_vecs; @@ -190,11 +191,13 @@ class TiledKNNTest : public ::testing::TestWithParam { metric, metric_arg); + auto query_view = raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim); + raft::neighbors::brute_force::search( handle_, idx, - raft::make_device_matrix_view( - search_queries.data(), params_.num_queries, params_.dim), + query_view, raft::make_device_matrix_view( raft_indices_.data(), params_.num_queries, params_.k), raft::make_device_matrix_view( @@ -209,6 +212,73 @@ class TiledKNNTest : public ::testing::TestWithParam { float(0.001), stream_, true)); + // also test out the batch api. First get new reference results (all k, up to a certain + // max size) + auto all_size = std::min(params_.num_db_vecs, 1024); + auto all_indices = raft::make_device_matrix(handle_, num_queries, all_size); + auto all_distances = raft::make_device_matrix(handle_, num_queries, all_size); + raft::neighbors::brute_force::search( + handle_, idx, query_view, all_indices.view(), all_distances.view()); + + int64_t offset = 0; + auto query = make_batch_k_query(handle_, idx, query_view, k_); + for (auto batch : *query) { + auto batch_size = batch.batch_size(); + auto indices = raft::make_device_matrix(handle_, num_queries, batch_size); + auto distances = raft::make_device_matrix(handle_, num_queries, batch_size); + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + + matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords); + matrix::slice( + handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(), + batch.indices().data_handle(), + distances.data_handle(), + batch.distances().data_handle(), + num_queries, + batch_size, + float(0.001), + stream_, + true)); + + offset += batch_size; + if (offset + batch_size > all_size) break; + } + + // also test out with variable batch sizes + offset = 0; + int64_t batch_size = k_; + query = make_batch_k_query(handle_, idx, query_view, batch_size); + for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { + // batch_size could be less than requested (in the case of final batch). handle. + ASSERT_TRUE(it->indices().extent(1) <= batch_size); + batch_size = it->indices().extent(1); + + auto indices = raft::make_device_matrix(handle_, num_queries, batch_size); + auto distances = raft::make_device_matrix(handle_, num_queries, batch_size); + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords); + matrix::slice( + handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(), + it->indices().data_handle(), + distances.data_handle(), + it->distances().data_handle(), + num_queries, + batch_size, + float(0.001), + stream_, + true)); + + offset += batch_size; + if (offset + batch_size > all_size) break; + + batch_size += 2; + } } }