Skip to content

Commit

Permalink
extract recompute_internal_state
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Feb 22, 2024
1 parent c42167c commit 5c49e9d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 123 deletions.
55 changes: 54 additions & 1 deletion cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once

#include <cub/block/block_scan.cuh>
#include <cub/cub.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/detail/select_warpsort.cuh> // matrix::detail::select::warpsort::warp_sort_distributed

Expand Down Expand Up @@ -268,4 +268,57 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
}
}

/** Update the state of the dependent index members. */
template <typename Index>
void recompute_internal_state(const raft::resources& res, Index& index)
{
auto stream = resource::get_cuda_stream(res);
auto tmp_res = resource::get_workspace_resource(res);
rmm::device_uvector<uint32_t> sorted_sizes(index.n_lists(), stream, tmp_res);

// Actualize the list pointers
auto data_ptrs = index.data_ptrs();
auto inds_ptrs = index.inds_ptrs();
for (uint32_t label = 0; label < index.n_lists(); label++) {
auto& list = index.lists()[label];
const auto data_ptr = list ? list->data.data_handle() : nullptr;
const auto inds_ptr = list ? list->indices.data_handle() : nullptr;
copy(&data_ptrs(label), &data_ptr, 1, stream);
copy(&inds_ptrs(label), &inds_ptr, 1, stream);
}

// Sort the cluster sizes in the descending order.
int begin_bit = 0;
int end_bit = sizeof(uint32_t) * 8;
size_t cub_workspace_size = 0;
cub::DeviceRadixSort::SortKeysDescending(nullptr,
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res);
cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(),
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
// copy the results to CPU
std::vector<uint32_t> sorted_sizes_host(index.n_lists());
copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream);
resource::sync_stream(res);

// accumulate the sorted cluster sizes
auto accum_sorted_sizes = index.accum_sorted_sizes();
accum_sorted_sizes(0) = 0;
for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) {
accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label];
}
}

} // namespace raft::neighbors::ivf::detail
65 changes: 4 additions & 61 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <raft/linalg/add.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/ivf_flat_codepacker.hpp>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft/neighbors/ivf_list.hpp>
Expand All @@ -42,64 +43,6 @@ namespace raft::neighbors::ivf_flat::detail {

using namespace raft::spatial::knn::detail; // NOLINT

/**
* Update the state of the dependent index members.
*/
template <typename T, typename IdxT>
void recompute_internal_state(const raft::resources& res, index<T, IdxT>& index)
{
auto stream = resource::get_cuda_stream(res);
auto tmp_res = resource::get_workspace_resource(res);
rmm::device_uvector<uint32_t> sorted_sizes(index.n_lists(), stream, tmp_res);

// Actualize the list pointers
auto this_lists = index.lists();
auto this_data_ptrs = index.data_ptrs();
auto this_inds_ptrs = index.inds_ptrs();
for (uint32_t label = 0; label < this_lists.size(); label++) {
auto& list = this_lists[label];
const auto data_ptr = list ? list->data.data_handle() : nullptr;
const auto inds_ptr = list ? list->indices.data_handle() : nullptr;
copy(&this_data_ptrs(label), &data_ptr, 1, stream);
copy(&this_inds_ptrs(label), &inds_ptr, 1, stream);
}

// Sort the cluster sizes in the descending order.
int begin_bit = 0;
int end_bit = sizeof(uint32_t) * 8;
size_t cub_workspace_size = 0;
cub::DeviceRadixSort::SortKeysDescending(nullptr,
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res);
cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(),
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
// copy the results to CPU
std::vector<uint32_t> sorted_sizes_host(index.n_lists());
copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream);
resource::sync_stream(res);

// accumulate the sorted cluster sizes
auto accum_sorted_sizes = index.accum_sorted_sizes();
accum_sorted_sizes(0) = 0;
for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) {
accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label];
}

index.check_consistency();
}

template <typename T, typename IdxT>
auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T, IdxT>
{
Expand Down Expand Up @@ -133,7 +76,7 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,
target.lists() = source.lists();

// Make sure the device pointers point to the new lists
recompute_internal_state(res, target);
ivf::detail::recompute_internal_state(res, target);

return target;
}
Expand Down Expand Up @@ -320,7 +263,7 @@ void extend(raft::resources const& handle,
}
}
// Update the pointers and the sizes
recompute_internal_state(handle, *index);
ivf::detail::recompute_internal_state(handle, *index);
// Copy the old sizes, so we can start from the current state of the index;
// we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter.
raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream);
Expand Down Expand Up @@ -502,7 +445,7 @@ inline void fill_refinement_index(raft::resources const& handle,
ivf::resize_list(handle, lists[label], list_device_spec, n_candidates, uint32_t(0));
}
// Update the pointers and the sizes
recompute_internal_state(handle, *refinement_index);
ivf::detail::recompute_internal_state(handle, *refinement_index);

RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream));

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index<T, Id
}
resource::sync_stream(handle);

recompute_internal_state(handle, index_);
ivf::detail::recompute_internal_state(handle, index_);

return index_;
}
Expand Down
64 changes: 6 additions & 58 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/detail/ivf_pq_codepacking.cuh>
#include <raft/neighbors/ivf_list.hpp>
#include <raft/neighbors/ivf_pq_types.hpp>
Expand Down Expand Up @@ -1363,59 +1364,6 @@ void process_and_fill_codes(raft::resources const& handle,
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/** Update the state of the dependent index members. */
template <typename IdxT>
void recompute_internal_state(const raft::resources& res, index<IdxT>& index)
{
auto stream = resource::get_cuda_stream(res);
auto tmp_res = resource::get_workspace_resource(res);
rmm::device_uvector<uint32_t> sorted_sizes(index.n_lists(), stream, tmp_res);

// Actualize the list pointers
auto data_ptrs = index.data_ptrs();
auto inds_ptrs = index.inds_ptrs();
for (uint32_t label = 0; label < index.n_lists(); label++) {
auto& list = index.lists()[label];
const auto data_ptr = list ? list->data.data_handle() : nullptr;
const auto inds_ptr = list ? list->indices.data_handle() : nullptr;
copy(&data_ptrs(label), &data_ptr, 1, stream);
copy(&inds_ptrs(label), &inds_ptr, 1, stream);
}

// Sort the cluster sizes in the descending order.
int begin_bit = 0;
int end_bit = sizeof(uint32_t) * 8;
size_t cub_workspace_size = 0;
cub::DeviceRadixSort::SortKeysDescending(nullptr,
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res);
cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(),
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
// copy the results to CPU
std::vector<uint32_t> sorted_sizes_host(index.n_lists());
copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream);
resource::sync_stream(res);

// accumulate the sorted cluster sizes
auto accum_sorted_sizes = index.accum_sorted_sizes();
accum_sorted_sizes(0) = 0;
for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) {
accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label];
}
}

/**
* Helper function: allocate enough space in the list, compute the offset, at which to start
* writing, and fill-in indices.
Expand Down Expand Up @@ -1463,7 +1411,7 @@ void extend_list_with_codes(raft::resources const& res,
// Pack the data
pack_list_data<IdxT>(res, index, new_codes, label, offset);
// Update the pointers and the sizes
recompute_internal_state(res, *index);
ivf::detail::recompute_internal_state(res, *index);
}

/**
Expand All @@ -1482,7 +1430,7 @@ void extend_list(raft::resources const& res,
// Encode the data
encode_list_data<T, IdxT>(res, index, new_vectors, label, offset);
// Update the pointers and the sizes
recompute_internal_state(res, *index);
ivf::detail::recompute_internal_state(res, *index);
}

/**
Expand All @@ -1495,7 +1443,7 @@ void erase_list(raft::resources const& res, index<IdxT>* index, uint32_t label)
uint32_t zero = 0;
copy(index->list_sizes().data_handle() + label, &zero, 1, resource::get_cuda_stream(res));
index->lists()[label].reset();
recompute_internal_state(res, *index);
ivf::detail::recompute_internal_state(res, *index);
}

/** Copy the state of an index into a new index, but share the list data among the two. */
Expand Down Expand Up @@ -1539,7 +1487,7 @@ auto clone(const raft::resources& res, const index<IdxT>& source) -> index<IdxT>
target.lists() = source.lists();

// Make sure the device pointers point to the new lists
recompute_internal_state(res, target);
ivf::detail::recompute_internal_state(res, target);

return target;
}
Expand Down Expand Up @@ -1688,7 +1636,7 @@ void extend(raft::resources const& handle,
}

// Update the pointers and the sizes
recompute_internal_state(handle, *index);
ivf::detail::recompute_internal_state(handle, *index);

// Recover old cluster sizes: they are used as counters in the fill-codes kernel
copy(list_sizes, orig_list_sizes.data(), n_clusters, stream);
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -160,7 +160,7 @@ auto deserialize(raft::resources const& handle_, std::istream& is) -> index<IdxT

resource::sync_stream(handle_);

recompute_internal_state(handle_, index);
ivf::detail::recompute_internal_state(handle_, index);

return index;
}
Expand Down

0 comments on commit 5c49e9d

Please sign in to comment.