Skip to content

Commit

Permalink
handle variable sized batches
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Nov 1, 2023
1 parent ad27879 commit d4394c1
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 12 deletions.
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ index<T> build(raft::resources const& res,
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] query_norms Optional device_vector_view of precomputed query norms
*/
template <typename T, typename IdxT>
Expand Down
26 changes: 15 additions & 11 deletions cpp/include/raft/neighbors/brute_force_batch_k_query.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ class batch_k_query {
iterator(const batch_k_query<T, IdxT>* query, int64_t offset = 0)
: current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset)
{
load_batches();
slice_current_batch();
load_batches(query->batch_size);
slice_current_batch(offset, query->batch_size);
}

reference operator*() const { return current; }
Expand All @@ -155,9 +155,7 @@ class batch_k_query {

iterator& operator++()
{
offset = std::min(offset + query->batch_size, query->index.size());
if (offset + query->batch_size > current_batch_size) { load_batches(); }
slice_current_batch();
advance(query->batch_size);
return *this;
}

Expand All @@ -168,30 +166,36 @@ class batch_k_query {
return previous;
}

void advance(int64_t next_batch_size)
{
offset = std::min(offset + current.indices().extent(1), query->index.size());
if (offset + next_batch_size > current_batch_size) { load_batches(next_batch_size); }
slice_current_batch(offset, next_batch_size);
}

friend bool operator==(const iterator& lhs, const iterator& rhs)
{
return (lhs.query == rhs.query) && (lhs.offset == rhs.offset);
};
friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); };

protected:
void load_batches()
void load_batches(int64_t next_batch_size)
{
if (offset >= query->index.size()) { return; }

// we're aiming to load multiple batches here - since we don't know the max iteration
// grow the size we're loading exponentially
int64_t batch_size =
std::min(std::max(offset * 2, query->batch_size * 2), query->index.size());
batches = batch(query->res, query->query.extent(0), batch_size);
int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), query->index.size());
batches = batch(query->res, query->query.extent(0), batch_size);
query->load_batch(batches);
current_batch_size = batch_size;
}

void slice_current_batch()
void slice_current_batch(int64_t offset, int64_t batch_size)
{
auto num_queries = batches.indices_.extent(0);
auto batch_size = std::min(query->batch_size, query->index.size() - offset);
batch_size = std::min(batch_size, query->index.size() - offset);
current = batch(query->res, num_queries, batch_size);

if (!num_queries || !batch_size) { return; }
Expand Down
35 changes: 34 additions & 1 deletion cpp/test/neighbors/tiled_knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,47 @@ class TiledKNNTest : public ::testing::TestWithParam<TiledKNNInputs> {
distances.data_handle(),
batch.distances().data_handle(),
num_queries,
batch.indices().extent(1),
batch_size,
float(0.001),
stream_,
true));

offset += batch_size;
if (offset + batch_size > all_size) break;
}

// also test out with variable batch sizes
offset = 0;
int64_t batch_size = k_;
batch_k_query<T, int> query(handle_, idx, query_view, batch_size);
for (auto it = query.begin(); it != query.end(); it.advance(batch_size)) {
// batch_size could be less than requested (in the case of final batch). handle.
ASSERT_TRUE(it->indices().extent(1) <= batch_size);
batch_size = it->indices().extent(1);

auto indices = raft::make_device_matrix<int, int64_t>(handle_, num_queries, batch_size);
auto distances = raft::make_device_matrix<T, int64_t>(handle_, num_queries, batch_size);

matrix::slice_coordinates<int64_t> coords{0, offset, num_queries, offset + batch_size};
matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords);
matrix::slice(
handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords);

ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(),
it->indices().data_handle(),
distances.data_handle(),
it->distances().data_handle(),
num_queries,
batch_size,
float(0.001),
stream_,
true));

offset += batch_size;
if (offset + batch_size > all_size) break;

batch_size += 2;
}
}
}

Expand Down

0 comments on commit d4394c1

Please sign in to comment.