Skip to content

Commit

Permalink
Make the request batch sizes variable to improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Dec 3, 2024
1 parent 07ed312 commit 7a1bde3
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions cpp/test/neighbors/dynamic_batching.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -138,32 +138,38 @@ struct dynamic_batching_test : public ::testing::TestWithParam<dynamic_batching_
raft::resource::set_cuda_stream(resource_pool[i], worker_streams.get_stream(i));
}

for (int64_t i = 0; i < ps.n_queries + ps.max_concurrent_threads; i++) {
auto j = i % ps.max_concurrent_threads;
// Try multiple batch sizes in a round-robin to improve test coverage
std::vector<int64_t> minibatch_sizes{1, 3, 7, 10};
auto get_bs = [&minibatch_sizes](auto i) {
return minibatch_sizes[i % minibatch_sizes.size()];
};
int64_t i = 0;
for (int64_t offset = 0; offset < ps.n_queries; offset += get_bs(i++)) {
auto bs = std::min<int64_t>(get_bs(i), ps.n_queries - offset);
auto j = i % ps.max_concurrent_threads;
// wait for previous job in the same slot to finish
if (i >= ps.max_concurrent_threads) { futures[j].wait(); }
// submit a new job
if (i < ps.n_queries) {
futures[j] =
std::async(std::launch::async,
[&res = resource_pool[j],
&params = search_params_dynb,
index = index_dynb.value(),
query_view = raft::make_device_matrix_view<data_type, int64_t>(
queries->data_handle() + i * ps.dim, 1, ps.dim),
neighbors_view = raft::make_device_matrix_view<index_type, int64_t>(
neighbors_dynb->data_handle() + i * ps.k, 1, ps.k),
distances_view = raft::make_device_matrix_view<distance_type, int64_t>(
distances_dynb->data_handle() + i * ps.k, 1, ps.k)]() {
dynamic_batching::search(
res, params, index, query_view, neighbors_view, distances_view);
});
} else {
// finalize all resources
raft::resource::sync_stream(resource_pool[j]);
}
futures[j] = std::async(
std::launch::async,
[&res = resource_pool[j],
&params = search_params_dynb,
index = index_dynb.value(),
query_view = raft::make_device_matrix_view<data_type, int64_t>(
queries->data_handle() + offset * ps.dim, bs, ps.dim),
neighbors_view = raft::make_device_matrix_view<index_type, int64_t>(
neighbors_dynb->data_handle() + offset * ps.k, bs, ps.k),
distances_view = raft::make_device_matrix_view<distance_type, int64_t>(
distances_dynb->data_handle() + offset * ps.k, bs, ps.k)]() {
dynamic_batching::search(res, params, index, query_view, neighbors_view, distances_view);
});
}

// finalize all resources
for (int64_t j = 0; j < ps.max_concurrent_threads && j < i; j++) {
futures[j].wait();
raft::resource::sync_stream(resource_pool[j]);
}
raft::resource::sync_stream(res);
}

Expand Down

0 comments on commit 7a1bde3

Please sign in to comment.