diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index b0697dd6c7..3c15cf3959 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -339,6 +339,7 @@ index build(raft::resources const& res, * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] * @param[in] query_norms Optional device_vector_view of precomputed query norms */ template diff --git a/cpp/include/raft/neighbors/brute_force_batch_k_query.cuh b/cpp/include/raft/neighbors/brute_force_batch_k_query.cuh index 5ea299c7b4..738cab9e48 100644 --- a/cpp/include/raft/neighbors/brute_force_batch_k_query.cuh +++ b/cpp/include/raft/neighbors/brute_force_batch_k_query.cuh @@ -145,8 +145,8 @@ class batch_k_query { iterator(const batch_k_query* query, int64_t offset = 0) : current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset) { - load_batches(); - slice_current_batch(); + load_batches(query->batch_size); + slice_current_batch(offset, query->batch_size); } reference operator*() const { return current; } @@ -155,9 +155,7 @@ class batch_k_query { iterator& operator++() { - offset = std::min(offset + query->batch_size, query->index.size()); - if (offset + query->batch_size > current_batch_size) { load_batches(); } - slice_current_batch(); + advance(query->batch_size); return *this; } @@ -168,6 +166,13 @@ class batch_k_query { return previous; } + void advance(int64_t next_batch_size) + { + offset = std::min(offset + current.indices().extent(1), query->index.size()); + if (offset + next_batch_size > current_batch_size) { load_batches(next_batch_size); } + slice_current_batch(offset, next_batch_size); + } + friend bool operator==(const iterator& lhs, const iterator& rhs) { return (lhs.query == rhs.query) && (lhs.offset == rhs.offset); @@ -175,23 +180,22 @@ class batch_k_query { friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); }; protected: - void load_batches() + void load_batches(int64_t next_batch_size) { if (offset >= query->index.size()) { return; } // we're aiming to load multiple batches here - since we don't know the max iteration // grow the size we're loading exponentially - int64_t batch_size = - std::min(std::max(offset * 2, query->batch_size * 2), query->index.size()); - batches = batch(query->res, query->query.extent(0), batch_size); + int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), query->index.size()); + batches = batch(query->res, query->query.extent(0), batch_size); query->load_batch(batches); current_batch_size = batch_size; } - void slice_current_batch() + void slice_current_batch(int64_t offset, int64_t batch_size) { auto num_queries = batches.indices_.extent(0); - auto batch_size = std::min(query->batch_size, query->index.size() - offset); + batch_size = std::min(batch_size, query->index.size() - offset); current = batch(query->res, num_queries, batch_size); if (!num_queries || !batch_size) { return; } diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index c70d904229..d0d56cea3b 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -238,7 +238,7 @@ class TiledKNNTest : public ::testing::TestWithParam { distances.data_handle(), batch.distances().data_handle(), num_queries, - batch.indices().extent(1), + batch_size, float(0.001), stream_, true)); @@ -246,6 +246,39 @@ class TiledKNNTest : public ::testing::TestWithParam { offset += batch_size; if (offset + batch_size > all_size) break; } + + // also test out with variable batch sizes + offset = 0; + int64_t batch_size = k_; + batch_k_query query(handle_, idx, query_view, batch_size); + for (auto it = query.begin(); it != query.end(); it.advance(batch_size)) { + // batch_size could be less than requested (in the case of final batch). handle. + ASSERT_TRUE(it->indices().extent(1) <= batch_size); + batch_size = it->indices().extent(1); + + auto indices = raft::make_device_matrix(handle_, num_queries, batch_size); + auto distances = raft::make_device_matrix(handle_, num_queries, batch_size); + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords); + matrix::slice( + handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(), + it->indices().data_handle(), + distances.data_handle(), + it->distances().data_handle(), + num_queries, + batch_size, + float(0.001), + stream_, + true)); + + offset += batch_size; + if (offset + batch_size > all_size) break; + + batch_size += 2; + } } }