diff --git a/cpp/test/neighbors/dynamic_batching.cuh b/cpp/test/neighbors/dynamic_batching.cuh index 23c11eee4..b64c5b01e 100644 --- a/cpp/test/neighbors/dynamic_batching.cuh +++ b/cpp/test/neighbors/dynamic_batching.cuh @@ -138,32 +138,38 @@ struct dynamic_batching_test : public ::testing::TestWithParam minibatch_sizes{1, 3, 7, 10}; + auto get_bs = [&minibatch_sizes](auto i) { + return minibatch_sizes[i % minibatch_sizes.size()]; + }; + int64_t i = 0; + for (int64_t offset = 0; offset < ps.n_queries; offset += get_bs(i++)) { + auto bs = std::min(get_bs(i), ps.n_queries - offset); + auto j = i % ps.max_concurrent_threads; // wait for previous job in the same slot to finish if (i >= ps.max_concurrent_threads) { futures[j].wait(); } // submit a new job - if (i < ps.n_queries) { - futures[j] = - std::async(std::launch::async, - [&res = resource_pool[j], - ¶ms = search_params_dynb, - index = index_dynb.value(), - query_view = raft::make_device_matrix_view( - queries->data_handle() + i * ps.dim, 1, ps.dim), - neighbors_view = raft::make_device_matrix_view( - neighbors_dynb->data_handle() + i * ps.k, 1, ps.k), - distances_view = raft::make_device_matrix_view( - distances_dynb->data_handle() + i * ps.k, 1, ps.k)]() { - dynamic_batching::search( - res, params, index, query_view, neighbors_view, distances_view); - }); - } else { - // finalize all resources - raft::resource::sync_stream(resource_pool[j]); - } + futures[j] = std::async( + std::launch::async, + [&res = resource_pool[j], + ¶ms = search_params_dynb, + index = index_dynb.value(), + query_view = raft::make_device_matrix_view( + queries->data_handle() + offset * ps.dim, bs, ps.dim), + neighbors_view = raft::make_device_matrix_view( + neighbors_dynb->data_handle() + offset * ps.k, bs, ps.k), + distances_view = raft::make_device_matrix_view( + distances_dynb->data_handle() + offset * ps.k, bs, ps.k)]() { + dynamic_batching::search(res, params, index, query_view, neighbors_view, distances_view); + }); } + // finalize all resources + for (int64_t j = 0; j < ps.max_concurrent_threads && j < i; j++) { + futures[j].wait(); + raft::resource::sync_stream(resource_pool[j]); + } raft::resource::sync_stream(res); }