Skip to content

Commit

Permalink
IVF-flat neighbor ids forced to uint32 during compute
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 12, 2024
1 parent 09014e8 commit 1f6f6e3
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;
Expand All @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t* chunk_indices,
const uint32_t dim,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances)
{
extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[];
Expand Down Expand Up @@ -752,11 +752,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2

uint32_t sample_offset = 0;
if constexpr (!kManageLocalTopK) {
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);
}
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);

constexpr int kUnroll = WarpSize / Veclen;
constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize;
Expand Down Expand Up @@ -806,8 +804,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
// Enqueue one element per thread
const float val = valid ? static_cast<float>(dist) : local_topk_t::queue_t::kDummy;
if constexpr (kManageLocalTopK) {
const size_t idx = valid ? static_cast<size_t>(list_indices_ptrs[list_id][vec_id]) : 0;
queue.add(val, idx);
queue.add(val, sample_offset + vec_id);
} else {
if (vec_id < list_length) distances[sample_offset + vec_id] = val;
}
Expand Down Expand Up @@ -873,7 +870,7 @@ void launch_kernel(Lambda lambda,
const uint32_t max_samples,
const uint32_t* chunk_indices,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down Expand Up @@ -1161,7 +1158,7 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down
97 changes: 54 additions & 43 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,16 @@ void search_impl(raft::resources const& handle,
// Optional structures if postprocessing is required
// The topk distance value of candidate vectors from each cluster(list)
rmm::device_uvector<AccT> distances_tmp_dev(0, stream, search_mr);
// The topk index of candidate vectors from each cluster(list)
rmm::device_uvector<IdxT> indices_tmp_dev(0, stream, search_mr);
// Number of samples for each query
rmm::device_uvector<IdxT> num_samples(0, stream, search_mr);
rmm::device_uvector<uint32_t> num_samples(0, stream, search_mr);
// Offsets per probe for each query
rmm::device_uvector<uint32_t> chunk_index(0, stream, search_mr);

// The top-k indices of candidate vectors from each cluster (list), stored as local index offsets.
// Additional uint32 storage may also be needed for the select_k output.
rmm::device_uvector<uint32_t> indices_tmp_dev(0, stream, search_mr);
rmm::device_uvector<uint32_t> neighbors_uint32_buf(0, stream, search_mr);

size_t float_query_size;
if constexpr (std::is_integral_v<T>) {
float_query_size = n_queries * index.dim();
Expand Down Expand Up @@ -175,30 +178,39 @@ void search_impl(raft::resources const& handle,
grid_dim_x = 1;
}

num_samples.resize(n_queries, stream);
chunk_index.resize(n_queries_probes, stream);

ivf::detail::calc_chunk_indices<uint32_t>::configure(n_probes, n_queries)(
index.list_sizes().data_handle(),
coarse_indices_dev.data(),
chunk_index.data(),
num_samples.data(),
stream);

auto distances_dev_ptr = distances;
auto indices_dev_ptr = neighbors;

uint32_t* neighbors_uint32 = nullptr;
if constexpr (sizeof(IdxT) == sizeof(uint32_t)) {
neighbors_uint32 = reinterpret_cast<uint32_t*>(neighbors);
} else {
neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), stream);
neighbors_uint32 = neighbors_uint32_buf.data();
}

uint32_t* indices_dev_ptr = nullptr;

bool manage_local_topk = is_local_topk_feasible(k);
if (!manage_local_topk || grid_dim_x > 1) {
if (!manage_local_topk) {
num_samples.resize(n_queries, stream);
chunk_index.resize(n_queries_probes, stream);

ivf::detail::calc_chunk_indices<IdxT>::configure(n_probes, n_queries)(
index.list_sizes().data_handle(),
coarse_indices_dev.data(),
chunk_index.data(),
num_samples.data(),
stream);
}

auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples);

distances_tmp_dev.resize(target_size, stream);
if (manage_local_topk) indices_tmp_dev.resize(target_size, stream);

distances_dev_ptr = distances_tmp_dev.data();
indices_dev_ptr = indices_tmp_dev.data();
} else {
indices_dev_ptr = neighbors_uint32;
}

ivfflat_interleaved_scan<T, typename utils::config<T>::value_t, IdxT, IvfSampleFilterT>(
Expand All @@ -224,34 +236,33 @@ void search_impl(raft::resources const& handle,

// Merge topk values from different blocks
if (!manage_local_topk || grid_dim_x > 1) {
matrix::detail::select_k<AccT, IdxT>(handle,
distances_tmp_dev.data(),
indices_tmp_dev.data(),
n_queries,
manage_local_topk ? (k * grid_dim_x) : max_samples,
k,
distances,
neighbors,
select_min,
false,
matrix::SelectAlgo::kAuto,
num_samples.data());

if (!manage_local_topk) {
// post process distances && neighbor IDs
ivf::detail::postprocess_distances(
distances, distances, index.metric(), n_queries, k, 1.0, false, stream);
ivf::detail::postprocess_neighbors(neighbors,
neighbors,
index.inds_ptrs().data_handle(),
coarse_indices_dev.data(),
chunk_index.data(),
n_queries,
n_probes,
k,
stream);
}
matrix::detail::select_k<AccT, uint32_t>(handle,
distances_tmp_dev.data(),
indices_tmp_dev.data(),
n_queries,
manage_local_topk ? (k * grid_dim_x) : max_samples,
k,
distances,
neighbors_uint32,
select_min,
false,
matrix::SelectAlgo::kAuto,
num_samples.data());
}
if (!manage_local_topk) {
// post-process distances (neighbor ids are post-processed unconditionally below)
ivf::detail::postprocess_distances(
distances, distances, index.metric(), n_queries, k, 1.0, false, stream);
}
ivf::detail::postprocess_neighbors(neighbors,
neighbors_uint32,
index.inds_ptrs().data_handle(),
coarse_indices_dev.data(),
chunk_index.data(),
n_queries,
n_probes,
k,
stream);
}

/** See raft::neighbors::ivf_flat::search docs */
Expand Down
39 changes: 37 additions & 2 deletions cpp/include/raft/neighbors/detail/refine_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,30 @@ void refine_device(raft::resources const& handle,
n_queries,
n_candidates);
uint32_t grid_dim_x = 1;

// The neighbor ids are first computed as uint32_t offsets.
rmm::device_uvector<uint32_t> neighbors_uint32_buf(0, resource::get_cuda_stream(handle));
// Number of samples for each query
rmm::device_uvector<uint32_t> num_samples(n_queries, resource::get_cuda_stream(handle));
// Offsets per probe for each query
rmm::device_uvector<uint32_t> chunk_index(n_queries, resource::get_cuda_stream(handle));

ivf::detail::calc_chunk_indices<uint32_t>::configure(1, n_queries)(
refinement_index.list_sizes().data_handle(),
fake_coarse_idx.data(),
chunk_index.data(),
num_samples.data(),
resource::get_cuda_stream(handle));

uint32_t* neighbors_uint32 = nullptr;
if constexpr (sizeof(idx_t) == sizeof(uint32_t)) {
neighbors_uint32 = reinterpret_cast<uint32_t*>(indices.data_handle());
} else {
neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k),
resource::get_cuda_stream(handle));
neighbors_uint32 = neighbors_uint32_buf.data();
}

raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<
data_t,
typename raft::spatial::knn::detail::utils::config<data_t>::value_t,
Expand All @@ -100,13 +124,24 @@ void refine_device(raft::resources const& handle,
1,
k,
0,
nullptr,
chunk_index.data(),
raft::distance::is_min_close(metric),
raft::neighbors::filtering::none_ivf_sample_filter(),
indices.data_handle(),
neighbors_uint32,
distances.data_handle(),
grid_dim_x,
resource::get_cuda_stream(handle));

// Postprocessing: convert neighbors from positions (offsets) to actual ids.
ivf::detail::postprocess_neighbors(indices.data_handle(),
neighbors_uint32,
refinement_index.inds_ptrs().data_handle(),
fake_coarse_idx.data(),
chunk_index.data(),
n_queries,
1,
k,
resource::get_cuda_stream(handle));
}

} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down

0 comments on commit 1f6f6e3

Please sign in to comment.