Skip to content

Commit

Permalink
IVF-FLAT support k > 256 (#2169)
Browse files Browse the repository at this point in the history
Add support for topk > 256 for ivf_flat (Issue [#1555](#1555))

The PR adds a non-fused version of topk that is utilized if k > 256.

FYI, @tfeher

Authors:
  - Malte Förster (https://github.com/mfoerste4)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2169
  • Loading branch information
mfoerste4 authored Mar 4, 2024
1 parent ad27a73 commit 515ac62
Show file tree
Hide file tree
Showing 22 changed files with 694 additions and 462 deletions.
17 changes: 9 additions & 8 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#pragma once

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/sample_filter_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

Expand Down Expand Up @@ -181,13 +181,14 @@ void search_main(raft::resources const& res,
// and divide the values by kDivisor. Here we restore the original scale.
constexpr float kScale = spatial::knn::detail::utils::config<T>::kDivisor /
spatial::knn::detail::utils::config<DistanceT>::kDivisor;
ivf_pq::detail::postprocess_distances(dist_out,
dist_in,
index.metric(),
distances.extent(0),
distances.extent(1),
kScale,
resource::get_cuda_stream(res));
ivf::detail::postprocess_distances(dist_out,
dist_in,
index.metric(),
distances.extent(0),
distances.extent(1),
kScale,
true,
resource::get_cuda_stream(res));
}
/** @} */ // end group cagra

Expand Down
325 changes: 325 additions & 0 deletions cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cub/cub.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/detail/select_warpsort.cuh> // matrix::detail::select::warpsort::warp_sort_distributed

namespace raft::neighbors::ivf::detail {

/**
* Default value returned by `search` when the `n_probes` is too small and top-k is too large.
* One may encounter it if the combined size of probed clusters is smaller than the requested
* number of results per query.
*/
template <typename IdxT>
constexpr static IdxT kOutOfBoundsRecord = std::numeric_limits<IdxT>::max();

template <typename T, typename IdxT, bool Ascending = true>
struct dummy_block_sort_t {
using queue_t =
matrix::detail::select::warpsort::warp_sort_distributed<WarpSize, Ascending, T, IdxT>;
template <typename... Args>
__device__ dummy_block_sort_t(int k, Args...){};
};

/**
* For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that
* in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total
* number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples.
*/
template <int BlockDim>
__launch_bounds__(BlockDim) RAFT_KERNEL
calc_chunk_indices_kernel(uint32_t n_probes,
const uint32_t* cluster_sizes, // [n_clusters]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t* n_samples // [n_queries]
)
{
using block_scan = cub::BlockScan<uint32_t, BlockDim>;
__shared__ typename block_scan::TempStorage shm;

// locate the query data
clusters_to_probe += n_probes * blockIdx.x;
chunk_indices += n_probes * blockIdx.x;

// block scan
const uint32_t n_probes_aligned = Pow2<BlockDim>::roundUp(n_probes);
uint32_t total = 0;
for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) {
auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u;
auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u;
if (threadIdx.x == 0) { chunk += total; }
block_scan(shm).InclusiveSum(chunk, chunk, total);
__syncthreads();
if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; }
}
// save the total size
if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; }
}

struct calc_chunk_indices {
public:
struct configured {
void* kernel;
dim3 block_dim;
dim3 grid_dim;
uint32_t n_probes;

inline void operator()(const uint32_t* cluster_sizes,
const uint32_t* clusters_to_probe,
uint32_t* chunk_indices,
uint32_t* n_samples,
rmm::cuda_stream_view stream)
{
void* args[] = // NOLINT
{&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples};
RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream));
}
};

static inline auto configure(uint32_t n_probes, uint32_t n_queries) -> configured
{
return try_block_dim<1024>(n_probes, n_queries);
}

private:
template <int BlockDim>
static auto try_block_dim(uint32_t n_probes, uint32_t n_queries) -> configured
{
if constexpr (BlockDim >= WarpSize * 2) {
if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); }
}
return {reinterpret_cast<void*>(calc_chunk_indices_kernel<BlockDim>),
dim3(BlockDim, 1, 1),
dim3(n_queries, 1, 1),
n_probes};
}
};

/**
* Look up the chunk id corresponding to the sample index.
*
* Each query vector was compared to all the vectors from n_probes clusters, and sample_ix is an
* ordered number of one of such vectors. This function looks up to which chunk it belongs,
* and returns the index within the chunk (which is also an index within a cluster).
*
* @param[inout] sample_ix
* input: the offset of the sample in the batch;
* output: the offset inside the chunk (probe) / selected cluster.
* @param[in] n_probes number of probes
* @param[in] chunk_indices offsets of the chunks within the batch [n_probes]
* @return chunk index (== n_probes when the input index is not in the valid range,
* which can happen if there is not enough data to output in the selected clusters).
*/
__device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT
uint32_t n_probes,
const uint32_t* chunk_indices) -> uint32_t
{
uint32_t ix_min = 0;
uint32_t ix_max = n_probes;
do {
uint32_t i = (ix_min + ix_max) / 2;
if (chunk_indices[i] <= sample_ix) {
ix_min = i + 1;
} else {
ix_max = i;
}
} while (ix_min < ix_max);
if (ix_min > 0) { sample_ix -= chunk_indices[ix_min - 1]; }
return ix_min;
}

template <int BlockDim, typename IdxT1, typename IdxT2 = uint32_t>
__launch_bounds__(BlockDim) RAFT_KERNEL
postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
uint32_t n_probes,
uint32_t topk)
{
const uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x);
const uint32_t query_ix = i / uint64_t(topk);
if (query_ix >= n_queries) { return; }
const uint32_t k = i % uint64_t(topk);
neighbors_in += query_ix * topk;
neighbors_out += query_ix * topk;
chunk_indices += query_ix * n_probes;
clusters_to_probe += query_ix * n_probes;
uint32_t data_ix = neighbors_in[k];
const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices);
const bool valid = chunk_ix < n_probes;
neighbors_out[k] =
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT1>;
}

/**
* Transform found sample indices into the corresponding database indices
* (as stored in index.indices()).
* The sample indices are the record indices as they appear in the database view formed by the
* probed clusters / defined by the `chunk_indices`.
* We assume the searched sample sizes (for a single query) fit into `uint32_t`.
*/
template <typename IdxT1, typename IdxT2 = uint32_t>
void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
uint32_t n_probes,
uint32_t topk,
rmm::cuda_stream_view stream)
{
constexpr int kPNThreads = 256;
const int pn_blocks = raft::div_rounding_up_unsafe<size_t>(n_queries * topk, kPNThreads);
postprocess_neighbors_kernel<kPNThreads, IdxT1, IdxT2>
<<<pn_blocks, kPNThreads, 0, stream>>>(neighbors_out,
neighbors_in,
db_indices,
clusters_to_probe,
chunk_indices,
n_queries,
n_probes,
topk);
}

/**
* Post-process the scores depending on the metric type;
* translate the element type if necessary.
*/
template <typename ScoreInT, typename ScoreOutT = float>
void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
const ScoreInT* in, // [n_queries, topk]
distance::DistanceType metric,
uint32_t n_queries,
uint32_t topk,
float scaling_factor,
bool account_for_max_close,
rmm::cuda_stream_view stream)
{
constexpr bool needs_cast = !std::is_same<ScoreInT, ScoreOutT>::value;
const bool needs_copy = ((void*)in) != ((void*)out);
size_t len = size_t(n_queries) * size_t(topk);
switch (metric) {
case distance::DistanceType::L2Unexpanded:
case distance::DistanceType::L2Expanded: {
if (scaling_factor != 1.0) {
linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{scaling_factor * scaling_factor},
raft::cast_op<ScoreOutT>{}),
stream);
} else if (needs_cast || needs_copy) {
linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
case distance::DistanceType::L2SqrtUnexpanded:
case distance::DistanceType::L2SqrtExpanded: {
if (scaling_factor != 1.0) {
linalg::unaryOp(out,
in,
len,
raft::compose_op{raft::mul_const_op<ScoreOutT>{scaling_factor},
raft::sqrt_op{},
raft::cast_op<ScoreOutT>{}},
stream);
} else if (needs_cast) {
linalg::unaryOp(
out, in, len, raft::compose_op{raft::sqrt_op{}, raft::cast_op<ScoreOutT>{}}, stream);
} else {
linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream);
}
} break;
case distance::DistanceType::InnerProduct: {
float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor;
if (factor != 1.0) {
linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{factor}, raft::cast_op<ScoreOutT>{}),
stream);
} else if (needs_cast || needs_copy) {
linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
default: RAFT_FAIL("Unexpected metric.");
}
}

/** Update the state of the dependent index members. */
template <typename Index>
void recompute_internal_state(const raft::resources& res, Index& index)
{
auto stream = resource::get_cuda_stream(res);
auto tmp_res = resource::get_workspace_resource(res);
rmm::device_uvector<uint32_t> sorted_sizes(index.n_lists(), stream, tmp_res);

// Actualize the list pointers
auto data_ptrs = index.data_ptrs();
auto inds_ptrs = index.inds_ptrs();
for (uint32_t label = 0; label < index.n_lists(); label++) {
auto& list = index.lists()[label];
const auto data_ptr = list ? list->data.data_handle() : nullptr;
const auto inds_ptr = list ? list->indices.data_handle() : nullptr;
copy(&data_ptrs(label), &data_ptr, 1, stream);
copy(&inds_ptrs(label), &inds_ptr, 1, stream);
}

// Sort the cluster sizes in the descending order.
int begin_bit = 0;
int end_bit = sizeof(uint32_t) * 8;
size_t cub_workspace_size = 0;
cub::DeviceRadixSort::SortKeysDescending(nullptr,
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res);
cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(),
cub_workspace_size,
index.list_sizes().data_handle(),
sorted_sizes.data(),
index.n_lists(),
begin_bit,
end_bit,
stream);
// copy the results to CPU
std::vector<uint32_t> sorted_sizes_host(index.n_lists());
copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream);
resource::sync_stream(res);

// accumulate the sorted cluster sizes
auto accum_sorted_sizes = index.accum_sorted_sizes();
accum_sorted_sizes(0) = 0;
for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) {
accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label];
}
}

} // namespace raft::neighbors::ivf::detail
Loading

0 comments on commit 515ac62

Please sign in to comment.