Fix and improve one-block radix select (#1878)
- fix matrix::detail::select::radix::calc_chunk_size() for one-block kernel
- use `calc_buf_len()` rather than `len` as the buffer size of one-block kernel
- reduce register footprint of one-block kernel by reducing the number of buffer pointers
- reduce the buffer size to 1/8 of its previous value for all radix select functions (rough illustration below)
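A rough, back-of-the-envelope sketch of the resulting per-row memory saving (not part of the commit; it assumes T = float and IdxT = uint32_t and simply mirrors the arithmetic of the new calc_buf_len() shown in the diff below):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
  using T    = float;
  using IdxT = uint32_t;
  const size_t len = 1 << 20;  // one row of the batch

  // Old sizing: two full-length value buffers plus two full-length index buffers per row.
  const size_t old_bytes = len * 2 * (sizeof(T) + sizeof(IdxT));

  // New sizing: buf_len = len / (ratio * 8), rounded down to a 256-byte-friendly multiple,
  // where ratio = 2 + sizeof(IdxT) * 2 / sizeof(T) (see calc_buf_len() in the diff).
  const size_t ratio     = 2 + sizeof(IdxT) * 2 / sizeof(T);
  const size_t aligned   = 256 / std::min(sizeof(T), sizeof(IdxT));
  size_t buf_len         = len / (ratio * 8);
  buf_len                = buf_len / aligned * aligned;
  const size_t new_bytes = buf_len * 2 * (sizeof(T) + sizeof(IdxT));

  // Prints roughly a 32x reduction for this type combination.
  printf("old: %zu bytes, new: %zu bytes per row\n", old_bytes, new_bytes);
  return 0;
}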


Resolve #1823

Authors:
  - Yong Wang (https://github.com/yong-wang)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Ben Frederickson (https://github.com/benfred)
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Artem M. Chirkin (https://github.com/achirkin)
  - Ben Frederickson (https://github.com/benfred)

URL: #1878
yong-wang authored Nov 9, 2023
1 parent 9c38633 commit 93e393d
Showing 1 changed file with 150 additions and 64 deletions.
214 changes: 150 additions & 64 deletions cpp/include/raft/matrix/detail/select_radix.cuh
@@ -22,6 +22,7 @@
#include <raft/linalg/map.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/device_atomics.cuh>
#include <raft/util/integer_utils.hpp>
#include <raft/util/pow2_utils.cuh>
#include <raft/util/vectorized.cuh>

@@ -103,15 +104,27 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min)
return (twiddle_in(x, select_min) >> start_bit) & mask;
}

template <typename T, typename IdxT>
// Strangely, RATIO_T has a strong impact on register usage and occupancy for sm80, e.g.
// using RATIO_T=unsigned for radix_kernel decreases occupancy (with CUDA 12).
// Meanwhile, RATIO_T has no impact for sm90.
template <typename T, typename IdxT, typename RATIO_T = float>
_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len)
{
// When writing is skipped, only read `in` (type T).
// When writing is not skipped, read `in_buf` (T) and `in_idx_buf` (IdxT), and write `out_buf` (T)
// and `out_idx_buf` (IdxT).
// The ratio between these cases determines whether to skip writing and hence the buffer size.
constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T);
return len / ratio;
constexpr RATIO_T ratio = 2 + sizeof(IdxT) * 2 / sizeof(T);
// Even this estimate is too conservative, so shrink buf_len further to 1/8 of it
IdxT buf_len = len / (ratio * 8);

// The one-block kernel splits one large buffer into smaller ones, so round the buffer size down
// to a multiple of 256 bytes to avoid alignment issues
static_assert(is_a_power_of_two(sizeof(T)));
static_assert(is_a_power_of_two(sizeof(IdxT)));
constexpr IdxT aligned = 256 / std::min(sizeof(T), sizeof(IdxT));
buf_len = Pow2<aligned>::roundDown(buf_len);
return buf_len;
}
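A minimal standalone sketch of the round-down step above, assuming Pow2<aligned>::roundDown on integral inputs just masks off the low bits (the helper round_down below is illustrative, not part of RAFT):

#include <cstdint>
#include <cstdio>

// Round x down to a multiple of Align, where Align is a power of two.
template <uint32_t Align>
constexpr uint32_t round_down(uint32_t x)
{
  static_assert((Align & (Align - 1)) == 0, "Align must be a power of two");
  return x & ~(Align - 1);
}

int main()
{
  // With T = float and IdxT = uint32_t: ratio = 4, aligned = 256 / 4 = 64 elements.
  // For len = 100000: buf_len = 100000 / (4 * 8) = 3125, which rounds down to 3072,
  // i.e. a multiple of 64 elements (256 bytes).
  printf("%u -> %u\n", 3125u, round_down<64>(3125u));
  return 0;
}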

/**
@@ -208,6 +221,11 @@ struct alignas(128) Counter {
/**
* Fused filtering of the current pass and building histogram for the next pass
* (see steps 4 & 1 in `radix_kernel` description).
*
* This function is more complicated than the one-block counterpart because this function handles
* the case of early stopping. When early stopping is triggered, it's desirable to do the final
* filtering in this function rather than in last_filter(), because this function is run by multiple
* blocks while last_filter is run by a single block.
*/
template <typename T, typename IdxT, int BitsPerPass>
_RAFT_DEVICE void filter_and_histogram(const T* in_buf,
@@ -397,7 +415,7 @@ _RAFT_DEVICE void last_filter(const T* in_buf,
const int start_bit = calc_start_bit<T, BitsPerPass>(pass);

// changed in choose_bucket(); need to reload
const IdxT needed_num_of_kth = counter->k;
const IdxT num_of_kth_needed = counter->k;
IdxT* p_out_cnt = &counter->out_cnt;
IdxT* p_out_back_cnt = &counter->out_back_cnt;
for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) {
@@ -412,7 +430,7 @@ _RAFT_DEVICE void last_filter(const T* in_buf,
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
} else if (bits == kth_value_bits) {
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if (back_pos < needed_num_of_kth) {
if (back_pos < num_of_kth_needed) {
IdxT pos = k - 1 - back_pos;
out[pos] = value;
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
@@ -428,8 +446,8 @@ RAFT_KERNEL last_filter_kernel(const T* in,
const IdxT* in_idx_buf,
T* out,
IdxT* out_idx,
IdxT len,
IdxT k,
const IdxT len,
const IdxT k,
Counter<T, IdxT>* counters,
const bool select_min)
{
@@ -454,14 +472,14 @@ RAFT_KERNEL last_filter_kernel(const T* in,
constexpr int start_bit = calc_start_bit<T, BitsPerPass>(pass);

const auto kth_value_bits = counter->kth_value_bits;
const IdxT needed_num_of_kth = counter->k;
const IdxT num_of_kth_needed = counter->k;
IdxT* p_out_cnt = &counter->out_cnt;
IdxT* p_out_back_cnt = &counter->out_back_cnt;

auto f = [k,
select_min,
kth_value_bits,
needed_num_of_kth,
num_of_kth_needed,
p_out_cnt,
p_out_back_cnt,
in_idx_buf,
@@ -474,7 +492,7 @@ RAFT_KERNEL last_filter_kernel(const T* in,
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
} else if (bits == kth_value_bits) {
IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast<IdxT>(1));
if (back_pos < needed_num_of_kth) {
if (back_pos < num_of_kth_needed) {
IdxT pos = k - 1 - back_pos;
out[pos] = value;
out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
@@ -657,16 +675,35 @@ RAFT_KERNEL radix_kernel(const T* in,
}

template <typename T, typename IdxT, int BlockSize, typename Kernel>
int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel)
int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel, bool one_block)
{
int active_blocks;
RAFT_CUDA_TRY(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0));

constexpr int items_per_thread = 32;
constexpr int num_waves = 10;
int chunk_size =
std::max<int>(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len);
// The chunk size is chosen so that there is enough workload to fully utilize the GPU.
// One full wave contains (sm_cnt * active_blocks) blocks, and 10 waves is an empirically safe
// estimate of enough workload. It also counteracts imbalance if some blocks run slower
// than others.
constexpr int num_waves = 10;
int chunk_size;
if (one_block) {
// For the one-block version, one block processes one instance in the chunk. Just ensure that
// there are enough blocks.
chunk_size = num_waves * sm_cnt * active_blocks;
} else {
// One instance in the chunk contains len items and is processed by multiple blocks.
// The total number of items in a chunk (chunk_size * len) should be large enough that every
// thread has enough items to process. So set it to num_waves * "max num of active threads"
// (sm_cnt * active_blocks * BlockSize) * items_per_thread.
//
// Also, the upper bound of the total number of items in a chunk is:
// 10 (num_waves) * ~100 (sm_cnt) * 2048 (active_blocks * BlockSize) * 32 (items_per_thread) ≈ 64M.
// So the temporary buffer size required for one chunk won't be too large.
constexpr int items_per_thread = 32;
chunk_size =
std::max<int>(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len);
}
return std::min(chunk_size, batch_size);
}
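A host-only sketch of how the two branches of calc_chunk_size() differ, reusing the same num_waves heuristic; dummy_kernel, BlockSize and the batch/len values are made up for illustration and are not from the diff:

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel() {}

int main()
{
  constexpr int BlockSize = 512;
  constexpr int num_waves = 10;

  int dev = 0, sm_cnt = 0, active_blocks = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, dummy_kernel, BlockSize, 0);

  const int batch_size = 10000;
  const int len        = 100000;

  // one-block flavour: one block handles one row, so just keep ~num_waves waves of blocks in flight
  const int chunk_one_block = std::min(num_waves * sm_cnt * active_blocks, batch_size);

  // multi-block flavour: enough total items so every active thread gets ~items_per_thread items
  constexpr int items_per_thread = 32;
  const int chunk_multi_block    = std::min(
    std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len),
    batch_size);

  printf("one-block chunk: %d rows, multi-block chunk: %d rows\n",
         chunk_one_block, chunk_multi_block);
  return 0;
}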

@@ -709,17 +746,17 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt)
}

template <typename T, typename IdxT>
_RAFT_HOST_DEVICE void set_buf_pointers(const T* in,
const IdxT* in_idx,
T* buf1,
IdxT* idx_buf1,
T* buf2,
IdxT* idx_buf2,
int pass,
const T*& in_buf,
const IdxT*& in_idx_buf,
T*& out_buf,
IdxT*& out_idx_buf)
_RAFT_HOST void set_buf_pointers(const T* in,
const IdxT* in_idx,
T* buf1,
IdxT* idx_buf1,
T* buf2,
IdxT* idx_buf2,
int pass,
const T*& in_buf,
const IdxT*& in_idx_buf,
T*& out_buf,
IdxT*& out_idx_buf)
{
if (pass == 0) {
in_buf = in;
@@ -744,6 +781,41 @@ _RAFT_HOST_DEVICE void set_buf_pointers(const T* in,
}
}

template <typename T, typename IdxT>
_RAFT_DEVICE void set_buf_pointers(const T* in,
const IdxT* in_idx,
char* bufs,
IdxT buf_len,
int pass,
const T*& in_buf,
const IdxT*& in_idx_buf,
T*& out_buf,
IdxT*& out_idx_buf)
{
// bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2
if (pass == 0) {
in_buf = in;
in_idx_buf = nullptr;
out_buf = nullptr;
out_idx_buf = nullptr;
} else if (pass == 1) {
in_buf = in;
in_idx_buf = in_idx;
out_buf = reinterpret_cast<T*>(bufs);
out_idx_buf = reinterpret_cast<IdxT*>(bufs + sizeof(T) * 2 * buf_len);
} else if (pass % 2 == 0) {
in_buf = reinterpret_cast<T*>(bufs);
in_idx_buf = reinterpret_cast<IdxT*>(bufs + sizeof(T) * 2 * buf_len);
out_buf = const_cast<T*>(in_buf + buf_len);
out_idx_buf = const_cast<IdxT*>(in_idx_buf + buf_len);
} else {
out_buf = reinterpret_cast<T*>(bufs);
out_idx_buf = reinterpret_cast<IdxT*>(bufs + sizeof(T) * 2 * buf_len);
in_buf = out_buf + buf_len;
in_idx_buf = out_idx_buf + buf_len;
}
}
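For reference, a small host-side sketch of the layout that the device overload above assumes inside the single char allocation (buf1 | buf2 | idx_buf1 | idx_buf2); the offsets only restate the pointer arithmetic from the diff, with T = float, IdxT = uint32_t and an example buf_len:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
  using T    = float;
  using IdxT = uint32_t;
  const size_t buf_len = 3072;  // per-row length returned by calc_buf_len()

  // bufs holds, back to back: T buf1[buf_len], T buf2[buf_len],
  //                           IdxT idx_buf1[buf_len], IdxT idx_buf2[buf_len]
  const size_t off_buf1     = 0;
  const size_t off_buf2     = sizeof(T) * buf_len;
  const size_t off_idx_buf1 = sizeof(T) * 2 * buf_len;
  const size_t off_idx_buf2 = sizeof(T) * 2 * buf_len + sizeof(IdxT) * buf_len;

  // Passes >= 2 ping-pong between the two pairs: even passes read buf1/idx_buf1 and write
  // buf2/idx_buf2, odd passes do the opposite (pass 1 reads `in` and writes buf1/idx_buf1).
  printf("buf1 @ %zu, buf2 @ %zu, idx_buf1 @ %zu, idx_buf2 @ %zu\n",
         off_buf1, off_buf2, off_idx_buf1, off_idx_buf2);
  return 0;
}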

template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
void radix_topk(const T* in,
const IdxT* in_idx,
Expand All @@ -765,7 +837,7 @@ void radix_topk(const T* in,

auto kernel = radix_kernel<T, IdxT, BitsPerPass, BlockSize, false>;
const size_t max_chunk_size =
calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel);
calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel, false);
if (max_chunk_size != static_cast<size_t>(batch_size)) {
grid_dim = calc_grid_dim<T, IdxT, BitsPerPass, BlockSize>(max_chunk_size, len, sm_cnt);
}
@@ -793,6 +865,7 @@ void radix_topk(const T* in,
RAFT_CUDA_TRY(
cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter<T, IdxT>), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream));
auto kernel = radix_kernel<T, IdxT, BitsPerPass, BlockSize, false>;

const T* chunk_in = in + offset * len;
const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr;
@@ -866,6 +939,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
IdxT* out_idx_buf,
T* out,
IdxT* out_idx,
const IdxT previous_len,
Counter<T, IdxT>* counter,
IdxT* histogram,
bool select_min,
@@ -879,16 +953,29 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf,
if (threadIdx.x == 0) { *p_filter_cnt = 0; }
__syncthreads();

const int start_bit = calc_start_bit<T, BitsPerPass>(pass);
const unsigned mask = calc_mask<T, BitsPerPass>(pass);
const IdxT previous_len = counter->previous_len;
const int start_bit = calc_start_bit<T, BitsPerPass>(pass);
const unsigned mask = calc_mask<T, BitsPerPass>(pass);

if (pass == 0) {
auto f = [histogram, select_min, start_bit, mask](T value, IdxT) {
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
};
vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f);
} else if (!out_buf) {
// do not use vectorized_process here because it increases #registers a lot
const auto kth_value_bits = counter->kth_value_bits;
const int previous_start_bit = calc_start_bit<T, BitsPerPass>(pass - 1);

for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) {
const T value = in_buf[i];
const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if (previous_bits == kth_value_bits) {
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
}
}
} else {
// do not use vectorized_process here because it increases #registers a lot
IdxT* p_out_cnt = &counter->out_cnt;
@@ -927,10 +1014,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
T* out,
IdxT* out_idx,
const bool select_min,
T* buf1,
IdxT* idx_buf1,
T* buf2,
IdxT* idx_buf2)
char* bufs)
{
constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
__shared__ Counter<T, IdxT> counter;
@@ -951,29 +1035,38 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
if (in_idx) { in_idx += batch_id * len; }
out += batch_id * k;
out_idx += batch_id * k;
buf1 += batch_id * len;
idx_buf1 += batch_id * len;
buf2 += batch_id * len;
idx_buf2 += batch_id * len;
const T* in_buf = nullptr;
const IdxT* in_idx_buf = nullptr;
T* out_buf = nullptr;
IdxT* out_idx_buf = nullptr;
const IdxT buf_len = calc_buf_len<T, IdxT, unsigned>(len);
bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT));

constexpr int num_passes = calc_num_passes<T, BitsPerPass>();
for (int pass = 0; pass < num_passes; ++pass) {
set_buf_pointers(
in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf);

IdxT current_len = counter.len;
IdxT current_k = counter.k;
const T* in_buf;
const IdxT* in_idx_buf;
T* out_buf;
IdxT* out_idx_buf;
set_buf_pointers(in, in_idx, bufs, buf_len, pass, in_buf, in_idx_buf, out_buf, out_idx_buf);

const IdxT current_len = counter.len;
const IdxT current_k = counter.k;
IdxT previous_len = counter.previous_len;
if (previous_len > buf_len) {
in_buf = in;
in_idx_buf = in_idx;
previous_len = len;
}
if (current_len > buf_len) {
// so "out_buf == nullptr" denotes skipping the buffer write in the current pass
out_buf = nullptr;
out_idx_buf = nullptr;
}

filter_and_histogram_for_one_block<T, IdxT, BitsPerPass>(in_buf,
in_idx_buf,
out_buf,
out_idx_buf,
out,
out_idx,
previous_len,
&counter,
histogram,
select_min,
@@ -988,11 +1081,11 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in,
__syncthreads();

if (counter.len == counter.k || pass == num_passes - 1) {
last_filter<T, IdxT, BitsPerPass>(pass == 0 ? in : out_buf,
pass == 0 ? in_idx : out_idx_buf,
last_filter<T, IdxT, BitsPerPass>(out_buf ? out_buf : in,
out_buf ? out_idx_buf : in_idx,
out,
out_idx,
current_len,
out_buf ? current_len : len,
k,
&counter,
select_min,
@@ -1022,21 +1115,17 @@ void radix_topk_one_block(const T* in,
{
static_assert(calc_num_passes<T, BitsPerPass>() > 1);

auto kernel = radix_topk_one_block_kernel<T, IdxT, BitsPerPass, BlockSize>;
auto kernel = radix_topk_one_block_kernel<T, IdxT, BitsPerPass, BlockSize>;
const IdxT buf_len = calc_buf_len<T, IdxT, unsigned>(len);
const size_t max_chunk_size =
calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel);
calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel, true);

auto pool_guard =
raft::get_pool_memory_resource(mr,
max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) +
256 * 4 // might need extra memory for alignment
);
raft::get_pool_memory_resource(mr, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)));
if (pool_guard) { RAFT_LOG_DEBUG("radix::select_k: using pool memory resource"); }

rmm::device_uvector<T> buf1(len * max_chunk_size, stream, mr);
rmm::device_uvector<IdxT> idx_buf1(len * max_chunk_size, stream, mr);
rmm::device_uvector<T> buf2(len * max_chunk_size, stream, mr);
rmm::device_uvector<IdxT> idx_buf2(len * max_chunk_size, stream, mr);
rmm::device_uvector<char> bufs(
max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr);

for (size_t offset = 0; offset < static_cast<size_t>(batch_size); offset += max_chunk_size) {
int chunk_size = std::min(max_chunk_size, batch_size - offset);
@@ -1047,10 +1136,7 @@ void radix_topk_one_block(const T* in,
out + offset * k,
out_idx + offset * k,
select_min,
buf1.data(),
idx_buf1.data(),
buf2.data(),
idx_buf2.data());
bufs.data());
}
}
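Finally, a plain-CPU sketch of the chunking loop at the end of radix_topk_one_block(); batch_size, max_chunk_size, len and k are arbitrary example values, not from the diff:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
  // Each launch covers chunk_size rows with one block per row; the shared `bufs`
  // allocation is reused across launches because it is sized for max_chunk_size rows.
  const size_t batch_size = 2500, max_chunk_size = 1000;
  const size_t len = 100000, k = 64;

  for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) {
    const size_t chunk_size = std::min(max_chunk_size, batch_size - offset);
    printf("launch: grid = %zu blocks, in += %zu, out += %zu\n",
           chunk_size, offset * len, offset * k);
  }
  return 0;
}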

