From b85c6b0d79b8a425497d3d229b029b462159a47b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 8 Nov 2024 17:01:10 -0500 Subject: [PATCH] Moving select_k back --- .../raft/matrix/detail/select_k-ext.cuh | 72 + .../raft/matrix/detail/select_k-inl.cuh | 320 ++++ cpp/include/raft/matrix/detail/select_k.cuh | 24 + .../raft/matrix/detail/select_radix.cuh | 1337 +++++++++++++++++ .../raft/matrix/detail/select_warpsort.cuh | 1210 +++++++++++++++ cpp/include/raft/matrix/select_k.cuh | 121 ++ cpp/include/raft/matrix/select_k_types.hpp | 101 ++ 7 files changed, 3185 insertions(+) create mode 100644 cpp/include/raft/matrix/detail/select_k-ext.cuh create mode 100644 cpp/include/raft/matrix/detail/select_k-inl.cuh create mode 100644 cpp/include/raft/matrix/detail/select_k.cuh create mode 100644 cpp/include/raft/matrix/detail/select_radix.cuh create mode 100644 cpp/include/raft/matrix/detail/select_warpsort.cuh create mode 100644 cpp/include/raft/matrix/select_k.cuh create mode 100644 cpp/include/raft/matrix/select_k_types.hpp diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh new file mode 100644 index 0000000000..6db1a5acac --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include // RAFT_EXPLICIT + +#include // __half + +#include // uint32_t + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::matrix::detail { + +template +void select_k(raft::resources const& handle, + const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) RAFT_EXPLICIT; +} // namespace raft::matrix::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + extern template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) +instantiate_raft_matrix_detail_select_k(__half, uint32_t); +instantiate_raft_matrix_detail_select_k(__half, int64_t); +instantiate_raft_matrix_detail_select_k(float, int64_t); +instantiate_raft_matrix_detail_select_k(float, uint32_t); +// needed for brute force knn +instantiate_raft_matrix_detail_select_k(float, int); +// We did not have these two for double before, but there are tests for them. We +// therefore include them here. 
+instantiate_raft_matrix_detail_select_k(double, int64_t); +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh new file mode 100644 index 0000000000..93d233152b --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -0,0 +1,320 @@ +/* + + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "select_radix.cuh" +#include "select_warpsort.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::matrix::detail { + +/** + * Predict the fastest select_k algorithm based on the number of rows/cols/k + * + * The body of this method is automatically generated, using a DecisionTree + * to predict the fastest algorithm based off of thousands of trial runs + * on different values of rows/cols/k. The decision tree is converted to c++ + * code, which is cut and paste below. + * + * NOTE: The code to generate is in cpp/scripts/heuristics/select_k, running the + * 'generate_heuristic' notebook there will replace the body of this function + * with the latest learned heuristic + */ +inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k) +{ + if (k > 256) { + if (cols > 16862) { + if (rows > 1020) { + return SelectAlgo::kRadix11bitsExtraPass; + } else { + return SelectAlgo::kRadix11bits; + } + } else { + return SelectAlgo::kRadix11bitsExtraPass; + } + } else { + if (k > 2) { + if (cols > 22061) { + return SelectAlgo::kWarpDistributedShm; + } else { + if (rows > 198) { + return SelectAlgo::kWarpDistributedShm; + } else { + return SelectAlgo::kWarpImmediate; + } + } + } else { + return SelectAlgo::kWarpImmediate; + } + } +} + +/** + * Performs a segmented sorting of a keys array with respect to + * the segments of a values array. 
+ * @tparam KeyT + * @tparam ValT + * @param handle + * @param values + * @param keys + * @param n_segments + * @param k + * @param select_min + */ +template +void segmented_sort_by_key(raft::resources const& handle, + KeyT* keys, + ValT* values, + size_t n_segments, + size_t n_elements, + const ValT* offsets, + bool asc) +{ + auto stream = resource::get_cuda_stream(handle); + auto mr = resource::get_workspace_resource(handle); + auto out_inds = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_elements)); + auto out_dists = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_elements)); + + // Determine temporary device storage requirements + size_t temp_storage_bytes = 0; + if (asc) { + cub::DeviceSegmentedRadixSort::SortPairs(nullptr, + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + } else { + cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + } + + auto d_temp_storage = raft::make_device_mdarray( + handle, mr, raft::make_extents(temp_storage_bytes)); + + if (asc) { + // Run sorting operation + cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(), + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + + } else { + // Run sorting operation + cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)d_temp_storage.data_handle(), + temp_storage_bytes, + keys, + out_dists.data_handle(), + values, + out_inds.data_handle(), + n_elements, + n_segments, + offsets, + offsets + 1, + 0, + sizeof(ValT) * 8, + stream); + } + + raft::copy(values, out_inds.data_handle(), out_inds.size(), stream); + raft::copy(keys, out_dists.data_handle(), out_dists.size(), stream); +} + +template +void segmented_sort_by_key(raft::resources const& handle, + raft::device_vector_view offsets, + raft::device_vector_view keys, + raft::device_vector_view values, + bool asc) +{ + RAFT_EXPECTS(keys.size() == values.size(), + "Keys and values must contain the same number of elements."); + segmented_sort_by_key(handle, + keys.data_handle(), + values.data_handle(), + offsets.size() - 1, + keys.size(), + offsets.data_handle(), + asc); +} + +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * + * @param[in] handle container of reusable resources + * @param[in] in_val + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_val. + * @param batch_size + * number of input rows, i.e. the batch size. + * @param len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. 
+ * @param k + * the number of outputs to select in each input row. + * @param[out] out_val + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_val`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out_val`. + * @param select_min + * whether to select k smallest (true) or largest (false) keys. + * @param[in] sorted + * whether to make sure selected pairs are sorted by value + * @param[in] algo + * the selection algorithm to use + * @param[in] len_i + * array of size (batch_size) providing lengths for each individual row + * only radix select-k supported + */ +template +void select_k(raft::resources const& handle, + const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) +{ + common::nvtx::range fun_scope( + "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); + + if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); } + + switch (algo) { + case SelectAlgo::kRadix8bits: + case SelectAlgo::kRadix11bits: + case SelectAlgo::kRadix11bitsExtraPass: { + if (algo == SelectAlgo::kRadix8bits) { + detail::select::radix::select_k(handle, + in_val, + in_idx, + batch_size, + len, + k, + out_val, + out_idx, + select_min, + true, // fused_last_filter + len_i); + } else { + bool fused_last_filter = algo == SelectAlgo::kRadix11bits; + detail::select::radix::select_k(handle, + in_val, + in_idx, + batch_size, + len, + k, + out_val, + out_idx, + select_min, + fused_last_filter, + len_i); + } + if (sorted) { + auto offsets = make_device_mdarray( + handle, resource::get_workspace_resource(handle), make_extents(batch_size + 1)); + raft::linalg::map_offset(handle, offsets.view(), mul_const_op(k)); + + auto keys = raft::make_device_vector_view(out_val, (IdxT)(batch_size * k)); + auto vals = raft::make_device_vector_view(out_idx, (IdxT)(batch_size * k)); + + segmented_sort_by_key( + handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min); + } + return; + } + case SelectAlgo::kWarpDistributed: + return detail::select::warpsort:: + select_k_impl( + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); + case SelectAlgo::kWarpDistributedShm: + return detail::select::warpsort:: + select_k_impl( + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); + case SelectAlgo::kWarpAuto: + return detail::select::warpsort::select_k( + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); + case SelectAlgo::kWarpImmediate: + return detail::select::warpsort:: + select_k_impl( + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); + case SelectAlgo::kWarpFiltered: + return detail::select::warpsort:: + select_k_impl( + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); + default: RAFT_FAIL("K-selection Algorithm not supported."); + } +} +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh new file mode 100644 index 0000000000..711169984b --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "select_k-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "select_k-ext.cuh" +#endif diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh new file mode 100644 index 0000000000..2207b0216e --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -0,0 +1,1337 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace raft::matrix::detail::select::radix { +namespace impl { + +constexpr int VECTORIZED_READ_SIZE = 16; + +template +_RAFT_HOST_DEVICE constexpr int calc_num_buckets() +{ + return 1 << BitsPerPass; +} + +template +_RAFT_HOST_DEVICE constexpr int calc_num_passes() +{ + return ceildiv(sizeof(T) * 8, BitsPerPass); +} + +/** + * Bit 0 is the least significant (rightmost); + * this implementation processes input from the most to the least significant bit. + * This way, we can skip some passes in the end at the cost of having an unsorted output. + * + * NB: Use pass=-1 for calc_mask(). + */ +template +_RAFT_DEVICE constexpr int calc_start_bit(int pass) +{ + int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; + if (start_bit < 0) { start_bit = 0; } + return start_bit; +} + +template +_RAFT_DEVICE constexpr unsigned calc_mask(int pass) +{ + static_assert(BitsPerPass <= 31); + int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); + return (1 << num_bits) - 1; +} + +/** + * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point values as well + * as of integers. 
+ */ +template +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) +{ + auto bits = reinterpret_cast::UnsignedBits&>(key); + bits = cub::Traits::TwiddleIn(bits); + if (!select_min) { bits = ~bits; } + return bits; +} + +template +_RAFT_DEVICE T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) +{ + if (!select_min) { bits = ~bits; } + bits = cub::Traits::TwiddleOut(bits); + return reinterpret_cast(bits); +} + +template +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) +{ + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); + return (twiddle_in(x, select_min) >> start_bit) & mask; +} + +// Strangely, RATIO_T has a strong impact on register usage and occupancy for sm80, e.g. +// using RATIO_T=unsigned for radix_kernel decreases occupancy (with CUDA 12). +// In the meanwhile, RATIO_T has no impact for sm90. +template +_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) +{ + // When writing is skipped, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) + // and `out_idx_buf`(IdxT). + // The ratio between these cases determines whether to skip writing and hence the buffer size. + constexpr RATIO_T ratio = 2 + sizeof(IdxT) * 2 / sizeof(T); + // Even such estimation is too conservative, so further decrease buf_len by 1/8 + IdxT buf_len = len / (ratio * 8); + + // one-block kernel splits one large buffer into smaller ones, so round buf size to 256 bytes to + // avoid alignment issues + static_assert(is_a_power_of_two(sizeof(T))); + static_assert(is_a_power_of_two(sizeof(IdxT))); + constexpr IdxT aligned = 256 / std::min(sizeof(T), sizeof(IdxT)); + buf_len = Pow2::roundDown(buf_len); + return buf_len; +} + +/** + * Map a Func over the input data, using vectorized load instructions if possible. 
+ * + * NB: in future, we should move this to cpp/include/raft/linalg/detail/unary_op.cuh, which + * currently does not support the second lambda argument (index of an element) + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ +template +_RAFT_DEVICE void vectorized_process( + size_t thread_rank, size_t num_threads, const T* in, IdxT len, Func f) +{ + if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { + for (IdxT i = thread_rank; i < len; i += num_threads) { + f(in[i], i); + } + } else { + using wide_t = TxN_t; + using align_bytes = Pow2<(size_t)VECTORIZED_READ_SIZE>; + using align_elems = Pow2; + wide_t wide; + + // how many elements to skip in order to do aligned vectorized load + const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); + + // The main loop: process all aligned data + for (IdxT i = thread_rank * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; + i += num_threads * wide_t::Ratio) { + wide.load(in, i); +#pragma unroll + for (int j = 0; j < wide_t::Ratio; ++j) { + f(wide.val.data[j], i + j); + } + } + + static_assert(WarpSize >= wide_t::Ratio); + // Processes the skipped elements on the left + if (thread_rank < skip_cnt_left) { f(in[thread_rank], thread_rank); } + // Processes the skipped elements on the right + const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); + const IdxT remain_i = len - skip_cnt_right + thread_rank; + if (remain_i < len) { f(in[remain_i], remain_i); } + } +} + +template +struct alignas(128) Counter { + // We are processing the values in multiple passes, from most significant to least significant. In + // each pass, we keep the length of input (`len`) and the `k` of current pass, and update them at + // the end of the pass. + IdxT k; + IdxT len; + + // `previous_len` is the length of input in previous pass. Note that `previous_len` rather + // than `len` is used for the filtering step because filtering is indeed for previous pass (see + // comments before `radix_kernel`). + IdxT previous_len; + + // We determine the bits of the k_th value inside the mask processed by the pass. The + // already known bits are stored in `kth_value_bits`. It's used to discriminate a element is a + // result (written to `out`), a candidate for next pass (written to `out_buf`), or not useful + // (discarded). The bits that are not yet processed do not matter for this purpose. + typename cub::Traits::UnsignedBits kth_value_bits; + + // Record how many elements have passed filtering. It's used to determine the position in the + // `out_buf` where an element should be written. + alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. This counter is used to + // determine if the current block is the last running block. If so, this block will execute scan() + // and choose_bucket(). + alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements less (if + // select_min==true) than the k-th value are written from front to back. + alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. 
Elements equal to the k-th + // value are written from back to front. We need to keep count of them separately because the + // number of elements that <= the k-th value might exceed k. + alignas(128) IdxT out_back_cnt; +}; + +/** + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). + * + * This function is more complicated than the one-block counterpart because this function handles + * the case of early stopping. When early stopping is triggered, it's desirable to do the final + * filtering in this function rather than in last_filter(), because this function is run by multiple + * blocks while last_filter is run by a single block. + */ +template +_RAFT_DEVICE void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT previous_len, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass, + bool early_stop) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + // Passed to vectorized_process, this function executes in all blocks in parallel, + // i.e. the work is split along the input (both, in batches and chunks of a single row). + // Later, the histograms are merged using atomicAdd. + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); + } else { + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + // See the remark above on the distributed execution of `f` using vectorized_process. + auto f = [in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + select_min, + start_bit, + mask, + previous_start_bit, + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + } + } + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to + // `out` too. So we won't write the same value to `out` multiple times in different passes. + // And if we keep skipping the writing, values will be written in `last_filter_kernel()` at + // last. But when `early_stop` is true, we need to write to `out` since it's the last chance. 
+ else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); + } + if (early_stop) { return; } + __syncthreads(); + + // merge histograms produced by individual blocks + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } + } +} + +/** + * Replace histogram with its own prefix sum + * (step 2 in `radix_kernel` description) + */ +template +_RAFT_DEVICE void scan(volatile IdxT* histogram) +{ + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore + BlockStore; + typedef cub::BlockScan BlockScan; + + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + IdxT thread_data[items_per_thread]; + + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } else { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; } + } +} + +/** + * Calculate in which bucket the k-th value will fall + * (steps 3 in `radix_kernel` description) + */ +template +_RAFT_DEVICE void choose_bucket(Counter* counter, + const IdxT* histogram, + const IdxT k, + const int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so counter is written by only one thread + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } + } +} + +// For one-block version, last_filter() could be called when pass < num_passes - 1. 
+// So `pass` could not be constexpr +template +_RAFT_DEVICE void last_filter(const T* in_buf, + const IdxT* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT current_len, + IdxT k, + Counter* counter, + const bool select_min, + const int pass) +{ + const auto kth_value_bits = counter->kth_value_bits; + const int start_bit = calc_start_bit(pass); + + // changed in choose_bucket(); need to reload + const IdxT num_of_kth_needed = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` could be nullptr if + // `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < num_of_kth_needed) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + } +} + +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + in_buf = reinterpret_cast(bufs); + in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + out_buf = const_cast(in_buf + buf_len); + out_idx_buf = const_cast(in_idx_buf + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + in_buf = out_buf + buf_len; + in_idx_buf = out_idx_buf + buf_len; + } +} + +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + const int pass, + const T*& out_buf, + const IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + out_buf = const_cast(reinterpret_cast(bufs) + buf_len); + out_idx_buf = + const_cast(reinterpret_cast(bufs + sizeof(T) * 2 * buf_len) + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } +} + +template +RAFT_KERNEL last_filter_kernel(const T* in, + const IdxT* in_idx, + char* bufs, + size_t offset, + T* out, + IdxT* out_idx, + const IdxT len, + const IdxT* len_i, + const IdxT k, + Counter* counters, + const bool select_min) +{ + const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow + + Counter* counter = counters + batch_id; + IdxT previous_len = counter->previous_len; + + if (previous_len == 0) { return; } + + const IdxT l_len = len_or_indptr ? 
len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + + const IdxT buf_len = calc_buf_len(len); + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + set_buf_pointers(in + l_offset, in_idx + l_offset, bufs, buf_len, pass, in_buf, in_idx_buf); + + if (previous_len > buf_len || in_buf == in + l_offset) { + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; + } + out += batch_id * k; + out_idx += batch_id * k; + + const auto kth_value_bits = counter->kth_value_bits; + const IdxT num_of_kth_needed = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + + auto f = [k, + select_min, + kth_value_bits, + num_of_kth_needed, + p_out_cnt, + p_out_back_cnt, + in_idx_buf, + out, + out_idx](T value, IdxT i) { + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < num_of_kth_needed) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + }; + + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); +} + +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_val( + T* dest, const T* src, S len, IdxT k, const bool select_min) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + const T default_val = select_min ? upper_bound() : lower_bound(); + for (S i = idx; i < k; i += stride) { + dest[i] = i < len ? src[i] : default_val; + } +} + +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_idx(T* dest, const T* src, S len) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + + for (S i = idx; i < len; i += stride) { + dest[i] = src ? src[i] : i; + } +} + +/** + * + * It is expected to call this kernel multiple times (passes), in each pass we process a radix, + * going from the most significant towards the least significant bits (MSD). + * + * Conceptually, each pass consists of 4 steps: + * + * 1. Calculate histogram + * First, transform bits into a digit, the value of which is in the range + * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value and the result is a + * histogram. That is, histogram[i] contains the count of inputs having value i. + * + * 2. Scan the histogram + * Inclusive prefix sum is computed for the histogram. After this step, histogram[i] contains + * the count of inputs having value <= i. + * + * 3. Find the bucket j of the histogram that the k-th value falls into + * + * 4. Filtering + * Input elements whose digit value +RAFT_KERNEL radix_kernel(const T* in, + const IdxT* in_idx, + char* bufs, + size_t offset, + T* out, + IdxT* out_idx, + Counter* counters, + IdxT* histograms, + const IdxT len, + const IdxT* len_i, + const IdxT k, + const bool select_min, + const int pass) +{ + const size_t batch_id = blockIdx.y; + auto counter = counters + batch_id; + IdxT current_k; + IdxT previous_len; + IdxT current_len; + + const IdxT l_len = len_or_indptr ? 
len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + + if (pass == 0) { + current_k = k; + previous_len = l_len; + // Need to do this so setting counter->previous_len for the next pass is correct. + // This value is meaningless for pass 0, but it's fine because pass 0 won't be the + // last pass in this implementation so pass 0 won't hit the "if (pass == + // num_passes - 1)" branch. + // Maybe it's better to reload counter->previous_len and use it rather than + // current_len in last_filter() + current_len = l_len; + } else { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if constexpr (!len_or_indptr) { + if (pass == 0 && l_len <= k) { + copy_in_val(out + batch_id * k, in + l_offset, l_len, k, select_min); + copy_in_idx(out_idx + batch_id * k, (in_idx ? (in_idx + l_offset) : nullptr), l_len); + if (threadIdx.x == 0) { + counter->previous_len = 0; + counter->len = 0; + } + __syncthreads(); + return; + } + } + + if (current_len == 0) { return; } + + // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle + // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is + // handled in other way in select_k() so such case is not possible here. + const bool early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(len); + + const T* in_buf; + const IdxT* in_idx_buf; + T* out_buf; + IdxT* out_idx_buf; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + set_buf_pointers(in + l_offset, + (in_idx ? (in_idx + l_offset) : nullptr), + bufs, + buf_len, + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + + // "previous_len > buf_len" means previous pass skips writing buffer + if (pass == 0 || pass == 1 || previous_len > buf_len) { + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; + } + + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. 
+ if (len_i != nullptr) { + const IdxT max_len = max(l_len, k); + if (max_len < previous_len) previous_len = max_len; + } + + // "current_len > buf_len" means current pass will skip writing buffer + if (pass == 0 || current_len > buf_len) { + out_buf = nullptr; + out_idx_buf = nullptr; + } + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; + + filter_and_histogram(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + previous_len, + counter, + histogram, + select_min, + pass, + early_stop); + __threadfence(); + + bool isLastBlock = false; + if (threadIdx.x == 0) { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + isLastBlock = (finished == (gridDim.x - 1)); + } + if (__syncthreads_or(isLastBlock)) { + if (early_stop) { + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; + } + + scan(histogram); + __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); + __syncthreads(); + + constexpr int num_passes = calc_num_passes(); + // reset for next pass + if (pass != num_passes - 1) { + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + } + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len even in the last pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if constexpr (fused_last_filter) { + if (pass == num_passes - 1) { + last_filter(out_buf ? out_buf : in_buf, + out_idx_buf ? out_idx_buf : in_idx_buf, + out, + out_idx, + out_buf ? current_len : l_len, + k, + counter, + select_min, + pass); + } + } + } +} + +template +int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel, bool one_block) +{ + int active_blocks; + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0)); + + // The chunk size is chosen so that there is enough workload to fully utilize GPU. + // One full wave contains (sm_cnt * active_blocks) blocks, and 10 waves is an empirically safe + // estimation of enough workload. It also counteracts imbalance if some blocks run slower + // than others. + constexpr int num_waves = 10; + int chunk_size; + if (one_block) { + // For one-block version, one block processes one instance in the chunk. Just ensure that there + // are enough blocks. + chunk_size = num_waves * sm_cnt * active_blocks; + } else { + // One instance in the chunk contains len items and is processed by multiple blocks. + // The total number of items in a chunk (chunk_size * len) should be large enough that every + // thread has enough items to processes. So set it to num_waves * "max num of active threads" + // (sm_cnt * active_blocks * BlockSize) * items_per_thread. + // + // Also, the upper bound of the total number of items in a chunk is: + // 10 (num_waves) * ~100 (sm_cnt) * 2048 (active_blocks*BlockSize) * 32 (items_per_thread) =64M. + // So temporary buffer size required for one chunk won't be too large. 
+ constexpr int items_per_thread = 32; + chunk_size = + std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); + } + return std::min(chunk_size, batch_size); +} + +template +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) +{ + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + + int active_blocks; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, radix_kernel, BlockSize, 0)); + active_blocks *= sm_cnt; + + IdxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); + for (int num_waves = 1;; ++num_waves) { + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); + float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); + + // 0.15 is determined experimentally. It also ensures breaking the loop early, + // e.g. when num_waves > 7, tail_wave_penalty will always <0.15 + if (tail_wave_penalty < 0.15) { + best_num_blocks = num_blocks; + break; + } else if (tail_wave_penalty < best_tail_wave_penalty) { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; + } + + if (num_blocks == max_num_blocks) { break; } + } + return best_num_blocks; +} + +template +void radix_topk(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool fused_last_filter, + const IdxT* len_i, + unsigned grid_dim, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + // TODO: is it possible to relax this restriction? + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + auto kernel = radix_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel, false); + if (max_chunk_size != static_cast(batch_size)) { + grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); + } + const IdxT buf_len = calc_buf_len(len); + + size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + rmm::device_uvector> counters(max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + + rmm::device_uvector bufs( + max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + RAFT_CUDA_TRY( + cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + auto kernel = radix_kernel; + + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? 
(len_i + offset) : nullptr; + + dim3 blocks(grid_dim, chunk_size); + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + if (fused_last_filter && pass == num_passes - 1) { + kernel = radix_kernel; + } + + kernel<<>>(in, + in_idx, + bufs.data(), + offset, + chunk_out, + chunk_out_idx, + counters.data(), + histograms.data(), + len, + chunk_len_i, + k, + select_min, + pass); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + if (!fused_last_filter) { + last_filter_kernel + <<>>(in, + in_idx, + bufs.data(), + offset, + chunk_out, + chunk_out_idx, + len, + chunk_len_i, + k, + counters.data(), + select_min); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } +} + +// The following a few functions are for the one-block version, which uses single thread block for +// each row of a batch. +template +_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + const IdxT previous_len, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + IdxT* p_filter_cnt = &counter->filter_cnt; + if (threadIdx.x == 0) { *p_filter_cnt = 0; } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } else if (!out_buf) { + // not use vectorized_process here because it increases #registers a lot + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + } + } else { + // not use vectorized_process here because it increases #registers a lot + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { +#if CUDART_VERSION < 12000 + // Avoiding potential compiler bug in CUDA 11 + volatile +#endif + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } else if (previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } + } + } +} + +template +RAFT_KERNEL radix_topk_one_block_kernel(const T* in, + const IdxT* in_idx, + const IdxT len, + const IdxT* len_i, + const IdxT k, + T* out, + IdxT* out_idx, + const bool select_min, + char* bufs, + size_t offset) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; + + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + + IdxT l_len = len; + IdxT l_offset = (offset + batch_id) * len; + if constexpr (!len_or_indptr) { + l_offset = len_i[batch_id]; + l_len = len_i[batch_id + 1] - l_offset; + } + + if (threadIdx.x == 0) { + counter.k = k; + counter.len = l_len; + counter.previous_len = l_len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; + } + __syncthreads(); + + in += l_offset; + if (in_idx) { in_idx += l_offset; } + out += batch_id * k; + out_idx += batch_id * k; + const IdxT buf_len = calc_buf_len(len); + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + if constexpr (!len_or_indptr) { + if (l_len <= k) { + copy_in_val(out, in, l_len, k, select_min); + copy_in_idx(out_idx, in_idx, l_len); + __syncthreads(); + return; + } + } + + constexpr int num_passes = calc_num_passes(); + for (int pass = 0; pass < num_passes; ++pass) { + const T* in_buf; + const IdxT* in_idx_buf; + T* out_buf; + IdxT* out_idx_buf; + set_buf_pointers(in, in_idx, bufs, buf_len, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); + + const IdxT current_len = counter.len; + const IdxT current_k = counter.k; + IdxT previous_len = counter.previous_len; + if (previous_len > buf_len) { + in_buf = in; + in_idx_buf = in_idx; + previous_len = len; + } + if (current_len > buf_len) { + // so "out_buf==nullptr" denotes skipping writing buffer in current pass + out_buf = nullptr; + out_idx_buf = nullptr; + } + + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(l_len, k); + if (max_len < previous_len) previous_len = max_len; + } + + filter_and_histogram_for_one_block(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + previous_len, + &counter, + histogram, + select_min, + pass); + __syncthreads(); + + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + if (threadIdx.x == 0) { counter.previous_len = current_len; } + __syncthreads(); + + if (counter.len == counter.k || pass == num_passes - 1) { + last_filter(out_buf ? out_buf : in, + out_buf ? out_idx_buf : in_idx, + out, + out_idx, + out_buf ? current_len : l_len, + k, + &counter, + select_min, + pass); + break; + } + } +} + +// radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the following +// one-block version uses single thread block for one row of a batch, so intermediate data, like +// counters and global histograms, can be kept in shared memory and cheap sync operations can be +// used. It's used when len is relatively small or when the number of blocks per row calculated by +// `calc_grid_dim()` is 1. 
+template +void radix_topk_one_block(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + const IdxT* len_i, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + static_assert(calc_num_passes() > 1); + + auto kernel = radix_topk_one_block_kernel; + const IdxT buf_len = calc_buf_len(len); + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel, true); + + rmm::device_uvector bufs( + max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; + kernel<<>>(in, + in_idx, + len, + chunk_len_i, + k, + out + offset * k, + out_idx + offset * k, + select_min, + bufs.data(), + offset); + } +} + +} // namespace impl + +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_keys` as a row-major matrix with len columns and + * batch_size rows, then this function selects k smallest/largest values in each row and fills + * in the row-major matrix `out` of size (batch_size, k). + * + * Note, the output is NOT sorted within the groups of `k` selected elements. + * + * Reference: + * Jingrong Zhang, Akira Naruse, Xipeng Li, and Yong Wang. 2023. Parallel Top-K Algorithms on GPU: + * A Comprehensive Study and New Methods. In The International Conference for High Performance + * Computing, Networking, Storage and Analysis (SC ’23), November 12–17, 2023, Denver, CO, USA. + * ACM, New York, NY, USA. https://doi.org/10.1145/3581784.3607062 + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * @tparam BitsPerPass + * The size of the radix; + * it affects the number of passes and number of buckets. + * @tparam BlockSize + * Number of threads in a kernel thread block. + * @tparam len_or_indptr + * Flag to interpret `len_i` as either direct row lengths (true) or CSR format + * index pointers (false). When true, each `len_i` element denotes the length of a row. When + * false, `len_i` represents the index pointers for a CSR matrix with shape of `batch_size + 1`. + * + * @param[in] res container of reusable resources + * @param[in] in + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_keys. + * @param batch_size + * number of input rows, i.e. the batch size. + * @param len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. + * @param k + * the number of outputs to select in each input row. + * @param[out] out + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_keys`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out`. + * @param select_min + * whether to select k smallest (true) or largest (false) keys. 
+ * @param fused_last_filter + * when it's true, the last filter is fused into the kernel in the last pass and only one thread + * block will do the filtering; when false, a standalone filter kernel with multiple thread + * blocks is called. The later case is preferable when leading bits of input data are almost the + * same. That is, when the value range of input data is narrow. In such case, there could be a + * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. + * @param len_i + * Optional array used differently based on `len_or_indptr`: + * When `len_or_indptr` is true, `len_i` presents the lengths of each row, which is `batch_size`. + * When `len_or_indptr` is false, `len_i` works like a indptr for a CSR matrix. The length of each + * row would be (`len_i[row_id + 1] - len_i[row_id]`). `len_i` size is `batch_size + 1`. + */ +template +void select_k(raft::resources const& res, + const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool fused_last_filter, + const IdxT* len_i) +{ + RAFT_EXPECTS(!(!len_or_indptr && (len_i == nullptr)), + "When `len_or_indptr` is false, `len_i` must not be nullptr!"); + + auto stream = resource::get_cuda_stream(res); + auto mr = resource::get_workspace_resource(res); + if (k == len && len_or_indptr) { + RAFT_CUDA_TRY( + cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + if (in_idx) { + RAFT_CUDA_TRY(cudaMemcpyAsync( + out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + } else { + auto out_idx_view = + raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); + raft::linalg::map_offset(res, out_idx_view, raft::mod_const_op(len)); + } + return; + } + + int sm_cnt = resource::get_device_properties(res).multiProcessorCount; + + constexpr int items_per_thread = 32; + + if (len <= BlockSize * items_per_thread) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); + } else { + unsigned grid_dim = + impl::calc_grid_dim(batch_size, len, sm_cnt); + if (grid_dim == 1) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); + } else { + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + len_i, + grid_dim, + sm_cnt, + stream, + mr); + } + } +} + +} // namespace raft::matrix::detail::select::radix diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh new file mode 100644 index 0000000000..7da659291c --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -0,0 +1,1210 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +/* + Three APIs of different scopes are provided: + 1. host function: select_k() + 2. block-wide API: class block_sort + 3. warp-wide API: several implementations of warp_sort_* + + + 1. select_k() + (see the docstring) + + 2. class block_sort + It can be regarded as a fixed size priority queue for a thread block, + although the API is not typical. + one of the classes `warp_sort_*` can be used to instantiate block_sort. + + It uses dynamic shared memory as an intermediate buffer. + So the required shared memory size should be calculated using + calc_smem_size_for_block_wide() and passed as the 3rd kernel launch parameter. + + To add elements to the queue, use add(T val, IdxT idx) with unique values per-thread. + Use WarpSortClass<...>::kDummy constant for the threads outside of input bounds. + + After adding is finished, function done() should be called. And finally, store() is used to get + the top-k result. + + Example: + RAFT_KERNEL kernel() { + block_sort queue(...); + + for (IdxT i = threadIdx.x; i < len, i += blockDim.x) { + queue.add(in[i], in_idx[i]); + } + + queue.done(); + queue.store(out, out_idx); + } + + int smem_size = calc_smem_size_for_block_wide(...); + kernel<<>>(); + + + 3. class warp_sort_* + These two classes can be regarded as fixed size priority queue for a warp. + Usage is similar to class block_sort. No shared memory is needed. + + The host function (select_k) uses a heuristic to choose between these two classes for + sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small + (see the usage of LaunchThreshold::len_factor_for_choosing). + + Example: + RAFT_KERNEL kernel() { + warp_sort_immediate<...> queue(...); + int warp_id = threadIdx.x / WarpSize; + int lane_id = threadIdx.x % WarpSize; + + for (IdxT i = lane_id; i < len, i += WarpSize) { + queue.add(in[i], idx[i]); + } + + queue.done(); + // each warp outputs to a different offset + queue.store(out + warp_id * k, out_idx + warp_id * k); + } + */ + +namespace raft::matrix::detail::select::warpsort { + +static constexpr int kMaxCapacity = 256; + +namespace { + +/** Whether 'left` should indeed be on the left w.r.t. `right`. */ +template +_RAFT_DEVICE _RAFT_FORCEINLINE auto is_ordered(T left, T right) -> bool +{ + if constexpr (Ascending) { return left < right; } + if constexpr (!Ascending) { return left > right; } +} + +} // namespace + +/** + * A fixed-size warp-level priority queue. + * By feeding the data through this queue, you get the `k <= Capacity` + * smallest/greatest values in the data. + * + * @tparam Capacity + * maximum number of elements in the queue. + * @tparam Ascending + * which comparison to use: `true` means `<`, collect the smallest elements, + * `false` means `>`, collect the greatest elements. + * @tparam T + * the type of keys (what is being compared) + * @tparam IdxT + * the type of payload (normally, indices of elements), i.e. + * the content sorted alongside the keys. + */ +template +class warp_sort { + static_assert(is_a_power_of_two(Capacity)); + static_assert(std::is_default_constructible_v); + + public: + /** + * The `empty` value for the chosen binary operation, + * i.e. `Ascending ? upper_bound() : lower_bound()`. + */ + static constexpr T kDummy = Ascending ? upper_bound() : lower_bound(); + /** Width of the subwarp. 
*/ + static constexpr int kWarpWidth = std::min(Capacity, WarpSize); + /** The number of elements to select. */ + const int k; + + /** Extra memory required per-block for keeping the state (shared or global). */ + constexpr static auto mem_required(uint32_t block_size) -> size_t { return 0; } + + /** + * Construct the warp_sort empty queue. + * + * @param k + * number of elements to select. + */ + _RAFT_DEVICE warp_sort(int k) : k(k) + { +#pragma unroll + for (int i = 0; i < kMaxArrLen; i++) { + val_arr_[i] = kDummy; + idx_arr_[i] = IdxT{}; + } + } + + /** + * Load k values from the pointers at the given position, and merge them in the storage. + * + * When it actually loads the values, it always performs some collective warp operations in the + * end, thus enforcing warp sync. This means, it's safe to call `store` with the same arguments + * after `load_sorted` without extra sync. Note, however, that this is not necessarily true for + * the reverse order, because the access patterns of `store` and `load_sorted` are different. + * + * @param[in] in + * a device pointer to a contiguous array, unique per-subwarp + * (length: k <= kWarpWidth * kMaxArrLen). + * @param[in] in_idx + * a device pointer to a contiguous array, unique per-subwarp + * (length: k <= kWarpWidth * kMaxArrLen). + * @param[in] do_merge + * must be the same for all threads within a subwarp of size `kWarpWidth`. + * It serves as a conditional; when `false` the function does nothing. + * We need it to ensure threads within a full warp don't diverge calling `bitonic::merge()`. + */ + _RAFT_DEVICE void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) + { + if (do_merge) { + int idx = Pow2::mod(laneId()) ^ Pow2::Mask; +#pragma unroll + for (int i = kMaxArrLen - 1; i >= 0; --i, idx += kWarpWidth) { + if (idx < k) { + T t = in[idx]; + if (is_ordered(t, val_arr_[i])) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + } + if (kWarpWidth < WarpSize || do_merge) { + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + } + } + + /** + * Save the content by the pointer location. + * + * @param[out] out + * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` + * (length: k <= kWarpWidth * kMaxArrLen). + * @param[out] out_idx + * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` + * (length: k <= kWarpWidth * kMaxArrLen). + * @param valF (optional) postprocess values (T -> OutT) + * @param idxF (optional) postprocess indices (IdxT -> OutIdxT) + */ + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const + { + int idx = Pow2::mod(laneId()); +#pragma unroll kMaxArrLen + for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) { + out[idx] = valF(val_arr_[i]); + out_idx[idx] = idxF(idx_arr_[i]); + } + } + + protected: + static constexpr int kMaxArrLen = Capacity / kWarpWidth; + + T val_arr_[kMaxArrLen]; + IdxT idx_arr_[kMaxArrLen]; + + /** + * Merge another array (sorted in the opposite direction) in the queue. + * Thanks to the other array being sorted in the opposite direction, + * it's enough to call bitonic.merge once to maintain the valid state + * of the queue. + * + * @tparam PerThreadSizeIn + * the size of the other array per-thread (compared to `kMaxArrLen`). + * + * @param keys_in + * the values to be merged in. Pointers are unique per-thread. The values + * must already be sorted in the opposite direction. 
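+   *   (that is, in `!Ascending` order; the derived classes below achieve this by calling
+   *   `util::bitonic(!Ascending, kWarpWidth).sort(...)` on their buffers before `merge_in`).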
+ * The layout of `keys_in` must be the same as the layout of `val_arr_`. + * @param ids_in + * the associated indices of the elements in the same format as `keys_in`. + */ + template + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_in(const T* __restrict__ keys_in, + const IdxT* __restrict__ ids_in) + { +#pragma unroll + for (int i = std::min(kMaxArrLen, PerThreadSizeIn); i > 0; i--) { + T& key = val_arr_[kMaxArrLen - i]; + T other = keys_in[PerThreadSizeIn - i]; + if (is_ordered(other, key)) { + key = other; + idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i]; + } + } + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + } +}; + +/** + * This version of warp_sort compares each input element against the current + * estimate of k-th value before adding it to the intermediate sorting buffer. + * This makes the algorithm do less sorting steps for long input sequences + * at the cost of extra checks on each step. + * + * This implementation is preferred for large len values. + */ +template +class warp_sort_filtered : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + using warp_sort::mem_required; + + explicit _RAFT_DEVICE warp_sort_filtered(int k, T limit = kDummy) + : warp_sort(k), buf_len_(0), k_th_(limit) + { +#pragma unroll + for (int i = 0; i < kMaxBufLen; i++) { + val_buf_[i] = kDummy; + idx_buf_[i] = IdxT{}; + } + } + + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) + { + return warp_sort_filtered{k, limit}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + // comparing for k_th should reduce the total amount of updates: + // `false` means the input value is surely not in the top-k values. + bool do_add = is_ordered(val, k_th_); + // merge the buf if it's full and we cannot add an element anymore. + if (any(buf_len_ + do_add > kMaxBufLen)) { + // still, add an element before merging if possible for this thread + if (do_add && buf_len_ < kMaxBufLen) { + add_to_buf_(val, idx); + do_add = false; + } + merge_buf_(); + } + // add an element if necessary and haven't already. + if (do_add) { add_to_buf_(val, idx); } + } + + _RAFT_DEVICE void done() + { + if (any(buf_len_ != 0)) { merge_buf_(); } + } + + private: + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() + { + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + this->merge_in(val_buf_, idx_buf_); + buf_len_ = 0; + set_k_th_(); // contains warp sync +#pragma unroll + for (int i = 0; i < kMaxBufLen; i++) { + val_buf_[i] = kDummy; + } + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void add_to_buf_(T val, IdxT idx) + { + // NB: the loop is used here to ensure the constant indexing, + // to not force the buffers spill into the local memory. +#pragma unroll + for (int i = 0; i < kMaxBufLen; i++) { + if (i == buf_len_) { + val_buf_[i] = val; + idx_buf_[i] = idx; + } + } + buf_len_++; + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + static constexpr int kMaxBufLen = (Capacity <= 64) ? 
2 : 4; + + T val_buf_[kMaxBufLen]; + IdxT idx_buf_[kMaxBufLen]; + int buf_len_; + + T k_th_; +}; + +/** + * This version of warp_sort compares each input element against the current + * estimate of k-th value before adding it to the intermediate sorting buffer. + * In contrast to `warp_sort_filtered`, it keeps one distributed buffer for + * all threads in a warp (independently of the subwarp size), which makes its flushing less often. + */ +template +class warp_sort_distributed : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + using warp_sort::mem_required; + + explicit _RAFT_DEVICE warp_sort_distributed(int k, T limit = kDummy) + : warp_sort(k), + buf_val_(kDummy), + buf_idx_(IdxT{}), + buf_len_(0), + k_th_(limit) + { + } + + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) + { + return warp_sort_distributed{k, limit}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + // mask tells which lanes in the warp have valid items to be added + uint32_t mask = ballot(is_ordered(val, k_th_)); + if (mask == 0) { return; } + // how many elements to be added + uint32_t n_valid = __popc(mask); + // index of the source lane containing the value to put into the current lane. + uint32_t src_ix = 0; + // remove a few smallest set bits from the mask. + for (uint32_t i = std::min(n_valid, Pow2::mod(uint32_t(laneId()) - buf_len_)); i > 0; + i--) { + src_ix = __ffs(mask) - 1; + mask ^= (0x1u << src_ix); + } + // now the least significant bit of the mask corresponds to the lane id we want to get. + // for not-added (invalid) indices, the mask is zeroed by now. + src_ix = __ffs(mask) - 1; + // rearrange the inputs to be ready to put them into the tmp buffer + val = shfl(val, src_ix); + idx = shfl(idx, src_ix); + // for non-valid lanes, src_ix should be uint(-1) + if (mask == 0) { val = kDummy; } + // save the values into the free slots of the warp tmp buffer + if (laneId() >= buf_len_) { + buf_val_ = val; + buf_idx_ = idx; + } + buf_len_ += n_valid; + if (buf_len_ < WarpSize) { return; } + // merge the warp tmp buffer into the queue + merge_buf_(); + buf_len_ -= WarpSize; + // save the inputs that couldn't fit before the merge + if (laneId() < buf_len_) { + buf_val_ = val; + buf_idx_ = idx; + } + } + + _RAFT_DEVICE void done() + { + if (buf_len_ != 0) { + merge_buf_(); + buf_len_ = 0; + } + } + + private: + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() + { + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); + this->merge_in<1>(&buf_val_, &buf_idx_); + set_k_th_(); // contains warp sync + buf_val_ = kDummy; + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + T buf_val_; + IdxT buf_idx_; + uint32_t buf_len_; // 0 <= buf_len_ <= WarpSize + + T k_th_; +}; + +/** + * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers + * in the given external pointers (normally, a shared memory pointer should be passed in). 
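+ *
+ * A minimal usage sketch (the kernel and launch-side names are illustrative only; the buffer
+ * size comes from mem_required() and the buffers are carved out of `shmem` by init_blockwide()):
+ *
+ *   // host side: reserve the per-block buffer as dynamic shared memory
+ *   size_t smem_size = warp_sort_distributed_ext<256, true, float, int>::mem_required(block_dim);
+ *   example_kernel<<<grid_dim, block_dim, smem_size, stream>>>(...);
+ *
+ *   // device side: hand the dynamic shared memory pointer to the queue
+ *   extern __shared__ uint8_t shmem[];
+ *   auto queue = warp_sort_distributed_ext<256, true, float, int>::init_blockwide(k, shmem);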
+ */ +template +class warp_sort_distributed_ext : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + + constexpr static auto mem_required(uint32_t block_size) -> size_t + { + return (sizeof(T) + sizeof(IdxT)) * block_size; + } + + _RAFT_DEVICE warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit = kDummy) + : warp_sort(k), + val_buf_(val_buf), + idx_buf_(idx_buf), + buf_len_(0), + k_th_(limit) + { + val_buf_[laneId()] = kDummy; + } + + _RAFT_DEVICE static auto init_blockwide(int k, uint8_t* shmem, T limit = kDummy) + { + T* val_buf = nullptr; + IdxT* idx_buf = nullptr; + if constexpr (alignof(T) >= alignof(IdxT)) { + val_buf = reinterpret_cast(shmem); + idx_buf = reinterpret_cast(val_buf + blockDim.x); + } else { + idx_buf = reinterpret_cast(shmem); + val_buf = reinterpret_cast(idx_buf + blockDim.x); + } + auto warp_offset = Pow2::roundDown(threadIdx.x); + val_buf += warp_offset; + idx_buf += warp_offset; + return warp_sort_distributed_ext{k, val_buf, idx_buf, limit}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + bool do_add = is_ordered(val, k_th_); + // mask tells which lanes in the warp have valid items to be added + uint32_t mask = ballot(do_add); + if (mask == 0) { return; } + // where to put the element in the tmp buffer + int dst_ix = buf_len_ + __popc(mask & ((1u << laneId()) - 1u)); + // put all elements, which fit into the current tmp buffer + if (do_add && dst_ix < WarpSize) { + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + do_add = false; + } + // Total number of elements to be added + buf_len_ += __popc(mask); + // If the buffer is still not full, we can return + if (buf_len_ < WarpSize) { return; } + // Otherwise, merge the warp tmp buffer into the queue + merge_buf_(); // implies warp sync + buf_len_ -= WarpSize; + // save the inputs that couldn't fit before the merge + if (do_add) { + dst_ix -= WarpSize; + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + } + } + + _RAFT_DEVICE void done() + { + if (buf_len_ != 0) { + merge_buf_(); + buf_len_ = 0; + } + __syncthreads(); + } + + private: + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() + { + __syncwarp(); // make sure the threads are aware of the data written by others + T buf_val = val_buf_[laneId()]; + IdxT buf_idx = idx_buf_[laneId()]; + val_buf_[laneId()] = kDummy; + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val, buf_idx); + this->merge_in<1>(&buf_val, &buf_idx); + set_k_th_(); // contains warp sync + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + T* val_buf_; + IdxT* idx_buf_; + uint32_t buf_len_; // 0 <= buf_len_ < WarpSize + + T k_th_; +}; + +/** + * This version of warp_sort adds every input element into the intermediate sorting + * buffer, and thus does the sorting step every `Capacity` input elements. + * + * This implementation is preferred for very small len values. 
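+ *
+ * As a rough illustration of the buffering: with `Capacity = 256` and a full warp
+ * (`kWarpWidth = 32`), each thread buffers `kMaxArrLen = 8` candidates, so the internal
+ * sort-and-merge step runs once per 256 elements consumed by the warp.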
+ */ +template +class warp_sort_immediate : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + using warp_sort::mem_required; + + explicit _RAFT_DEVICE warp_sort_immediate(int k) + : warp_sort(k), buf_len_(0) + { +#pragma unroll + for (int i = 0; i < kMaxArrLen; i++) { + val_buf_[i] = kDummy; + idx_buf_[i] = IdxT{}; + } + } + + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, uint8_t* = nullptr) + { + return warp_sort_immediate{k}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + // NB: the loop is used here to ensure the constant indexing, + // to not force the buffers spill into the local memory. +#pragma unroll + for (int i = 0; i < kMaxArrLen; ++i) { + if (i == buf_len_) { + val_buf_[i] = val; + idx_buf_[i] = idx; + } + } + + ++buf_len_; + if (buf_len_ == kMaxArrLen) { + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + this->merge_in(val_buf_, idx_buf_); +#pragma unroll + for (int i = 0; i < kMaxArrLen; i++) { + val_buf_[i] = kDummy; + } + buf_len_ = 0; + } + } + + _RAFT_DEVICE void done() + { + if (buf_len_ != 0) { + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + this->merge_in(val_buf_, idx_buf_); + } + } + + private: + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + T val_buf_[kMaxArrLen]; + IdxT idx_buf_[kMaxArrLen]; + int buf_len_; +}; + +template +auto calc_smem_size_for_block_wide(int num_of_warp, int k) -> int +{ + return Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k) + + ceildiv(num_of_warp, 2) * sizeof(IdxT) * k; +} + +template