Skip to content

Commit

Permalink
Merge branch 'branch-24.04' into cagra-q
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 authored Mar 20, 2024
2 parents 38ab2bd + 335236c commit 15afe26
Show file tree
Hide file tree
Showing 22 changed files with 221 additions and 99 deletions.
8 changes: 5 additions & 3 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand All @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle,
IdxT* out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo)
raft::matrix::SelectAlgo algo, \
const IdxT* len_i)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
Expand Down
16 changes: 10 additions & 6 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle,
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
* @param[in] len_i
* optional array of size (batch_size) providing the length of each individual row
* (may be nullptr); only supported by the radix select-k algorithms
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
Expand All @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
const IdxT* len_i = nullptr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
Expand All @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
true // fused_last_filter
);

true, // fused_last_filter
len_i);
} else {
bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(handle,
Expand All @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
fused_last_filter);
fused_last_filter,
len_i);
}
if (sorted) {
auto offsets = make_device_mdarray<IdxT, IdxT>(
Expand Down
35 changes: 31 additions & 4 deletions cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ RAFT_KERNEL radix_kernel(const T* in,
Counter<T, IdxT>* counters,
IdxT* histograms,
const IdxT len,
const IdxT* len_i,
const IdxT k,
const bool select_min,
const int pass)
Expand Down Expand Up @@ -598,6 +599,14 @@ RAFT_KERNEL radix_kernel(const T* in,
in_buf += batch_id * buf_len;
in_idx_buf += batch_id * buf_len;
}

// If an individual length is defined for each query (len_i != nullptr), make sure
// that we only iterate over valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = max_len;
}

// "current_len > buf_len" means current pass will skip writing buffer
if (pass == 0 || current_len > buf_len) {
out_buf = nullptr;
Expand Down Expand Up @@ -829,6 +838,7 @@ void radix_topk(const T* in,
IdxT* out_idx,
bool select_min,
bool fused_last_filter,
const IdxT* len_i,
unsigned grid_dim,
int sm_cnt,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -868,6 +878,7 @@ void radix_topk(const T* in,
const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr;
T* chunk_out = out + offset * k;
IdxT* chunk_out_idx = out_idx + offset * k;
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;

const T* in_buf = nullptr;
const IdxT* in_idx_buf = nullptr;
Expand Down Expand Up @@ -905,6 +916,7 @@ void radix_topk(const T* in,
counters.data(),
histograms.data(),
len,
chunk_len_i,
k,
select_min,
pass);
Expand Down Expand Up @@ -1007,6 +1019,7 @@ template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
const IdxT* in_idx,
const IdxT len,
const IdxT* len_i,
const IdxT k,
T* out,
IdxT* out_idx,
Expand Down Expand Up @@ -1057,6 +1070,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
out_idx_buf = nullptr;
}

// If an individual length is defined for each query (len_i != nullptr), make sure
// that we only iterate over valid elements.
if (len_i != nullptr) {
const IdxT max_len = max(len_i[batch_id], k);
if (max_len < previous_len) previous_len = max_len;
}

filter_and_histogram_for_one_block<T, IdxT, BitsPerPass>(in_buf,
in_idx_buf,
out_buf,
Expand Down Expand Up @@ -1106,6 +1126,7 @@ void radix_topk_one_block(const T* in,
T* out,
IdxT* out_idx,
bool select_min,
const IdxT* len_i,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
Expand All @@ -1121,10 +1142,12 @@ void radix_topk_one_block(const T* in,
max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr);

for (size_t offset = 0; offset < static_cast<size_t>(batch_size); offset += max_chunk_size) {
int chunk_size = std::min(max_chunk_size, batch_size - offset);
int chunk_size = std::min(max_chunk_size, batch_size - offset);
const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr;
kernel<<<chunk_size, BlockSize, 0, stream>>>(in + offset * len,
in_idx ? (in_idx + offset * len) : nullptr,
len,
chunk_len_i,
k,
out + offset * k,
out_idx + offset * k,
Expand Down Expand Up @@ -1188,6 +1211,8 @@ void radix_topk_one_block(const T* in,
* blocks is called. The latter case is preferable when leading bits of input data are almost the
* same. That is, when the value range of input data is narrow. In such case, there could be a
* large number of inputs for the last filter, hence using multiple thread blocks is beneficial.
* @param len_i
* optional array of size (batch_size) providing lengths for each individual row
*/
template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
void select_k(raft::resources const& res,
Expand All @@ -1199,7 +1224,8 @@ void select_k(raft::resources const& res,
T* out,
IdxT* out_idx,
bool select_min,
bool fused_last_filter)
bool fused_last_filter,
const IdxT* len_i)
{
auto stream = resource::get_cuda_stream(res);
auto mr = resource::get_workspace_resource(res);
Expand All @@ -1223,13 +1249,13 @@ void select_k(raft::resources const& res,

if (len <= BlockSize * items_per_thread) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
unsigned grid_dim =
impl::calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(batch_size, len, sm_cnt);
if (grid_dim == 1) {
impl::radix_topk_one_block<T, IdxT, BitsPerPass, BlockSize>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr);
in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr);
} else {
impl::radix_topk<T, IdxT, BitsPerPass, BlockSize>(in,
in_idx,
Expand All @@ -1240,6 +1266,7 @@ void select_k(raft::resources const& res,
out_idx,
select_min,
fused_last_filter,
len_i,
grid_dim,
sm_cnt,
stream,
Expand Down
20 changes: 10 additions & 10 deletions cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT
return ix_min;
}

template <int BlockDim, typename IdxT1, typename IdxT2 = uint32_t>
template <int BlockDim, typename IdxT>
__launch_bounds__(BlockDim) RAFT_KERNEL
postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand All @@ -170,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices);
const bool valid = chunk_ix < n_probes;
neighbors_out[k] =
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT1>;
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT>;
}

/**
Expand All @@ -180,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
* probed clusters / defined by the `chunk_indices`.
* We assume the searched sample sizes (for a single query) fit into `uint32_t`.
*/
template <typename IdxT1, typename IdxT2 = uint32_t>
void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk]
const IdxT2* neighbors_in, // [n_queries, topk]
const IdxT1* const* db_indices, // [n_clusters][..]
template <typename IdxT>
void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand All @@ -193,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to
{
constexpr int kPNThreads = 256;
const int pn_blocks = raft::div_rounding_up_unsafe<size_t>(n_queries * topk, kPNThreads);
postprocess_neighbors_kernel<kPNThreads, IdxT1, IdxT2>
postprocess_neighbors_kernel<kPNThreads, IdxT>
<<<pn_blocks, kPNThreads, 0, stream>>>(neighbors_out,
neighbors_in,
db_indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;
Expand All @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i
const uint32_t* chunk_indices, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)
Expand Down
25 changes: 10 additions & 15 deletions cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t query_smem_elems,
const T* query,
const uint32_t* coarse_index,
const IdxT* const* list_indices_ptrs,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
Expand All @@ -700,7 +699,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
const uint32_t* chunk_indices,
const uint32_t dim,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances)
{
extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[];
Expand All @@ -719,16 +718,16 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
distances += query_id * k * gridDim.x + blockIdx.x * k;
} else {
distances += query_id * uint64_t(max_samples);
chunk_indices += (n_probes * query_id);
}
chunk_indices += (n_probes * query_id);
coarse_index += query_id * n_probes;
}

// Copy a part of the query into shared memory for faster processing
copy_vectorized(query_shared, query, std::min(dim, query_smem_elems));
__syncthreads();

using local_topk_t = block_sort_t<Capacity, Ascending, float, IdxT>;
using local_topk_t = block_sort_t<Capacity, Ascending, float, uint32_t>;
local_topk_t queue(k);
{
using align_warp = Pow2<WarpSize>;
Expand All @@ -752,11 +751,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2

uint32_t sample_offset = 0;
if constexpr (!kManageLocalTopK) {
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);
}
if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; }
assert(list_length == chunk_indices[probe_id] - sample_offset);
assert(sample_offset + list_length <= max_samples);

constexpr int kUnroll = WarpSize / Veclen;
constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize;
Expand Down Expand Up @@ -806,8 +803,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
// Enqueue one element per thread
const float val = valid ? static_cast<float>(dist) : local_topk_t::queue_t::kDummy;
if constexpr (kManageLocalTopK) {
const size_t idx = valid ? static_cast<size_t>(list_indices_ptrs[list_id][vec_id]) : 0;
queue.add(val, idx);
queue.add(val, sample_offset + vec_id);
} else {
if (vec_id < list_length) distances[sample_offset + vec_id] = val;
}
Expand Down Expand Up @@ -873,7 +869,7 @@ void launch_kernel(Lambda lambda,
const uint32_t max_samples,
const uint32_t* chunk_indices,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down Expand Up @@ -927,7 +923,6 @@ void launch_kernel(Lambda lambda,
query_smem_elems,
queries,
coarse_index,
index.inds_ptrs().data_handle(),
index.data_ptrs().data_handle(),
index.list_sizes().data_handle(),
queries_offset + query_offset,
Expand All @@ -945,8 +940,8 @@ void launch_kernel(Lambda lambda,
distances += grid_dim_y * grid_dim_x * k;
} else {
distances += grid_dim_y * max_samples;
chunk_indices += grid_dim_y * n_probes;
}
chunk_indices += grid_dim_y * n_probes;
coarse_index += grid_dim_y * n_probes;
}
}
Expand Down Expand Up @@ -1161,7 +1156,7 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
Expand Down
Loading

0 comments on commit 15afe26

Please sign in to comment.