From 93e393d9fd810d3f6f3019a02f1dd35adab7fa1e Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Fri, 10 Nov 2023 05:17:21 +0800 Subject: [PATCH] Fix and improve one-block radix select (#1878) - fix matrix::detail::select::radix::calc_chunk_size() for one-block kernel - use `calc_buf_len()` rather than `len` as the buffer size of one-block kernel - reduce register footprint of one-block kernel by reducing the number of buffer pointers - reduce the buffer size by 1/8 for all radix select functions Resolve #1823 Authors: - Yong Wang (https://github.com/yong-wang) - Corey J. Nolet (https://github.com/cjnolet) - Ben Frederickson (https://github.com/benfred) - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Artem M. Chirkin (https://github.com/achirkin) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1878 --- .../raft/matrix/detail/select_radix.cuh | 214 ++++++++++++------ 1 file changed, 150 insertions(+), 64 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index b3c07b9d3a..fa12005df2 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -103,15 +104,27 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) return (twiddle_in(x, select_min) >> start_bit) & mask; } -template +// Strangely, RATIO_T has a strong impact on register usage and occupancy for sm80, e.g. +// using RATIO_T=unsigned for radix_kernel decreases occupancy (with CUDA 12). +// In the meanwhile, RATIO_T has no impact for sm90. +template _RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) { // When writing is skipped, only read `in`(type T). // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) // and `out_idx_buf`(IdxT). // The ratio between these cases determines whether to skip writing and hence the buffer size. - constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); - return len / ratio; + constexpr RATIO_T ratio = 2 + sizeof(IdxT) * 2 / sizeof(T); + // Even such estimation is too conservative, so further decrease buf_len by 1/8 + IdxT buf_len = len / (ratio * 8); + + // one-block kernel splits one large buffer into smaller ones, so round buf size to 256 bytes to + // avoid alignment issues + static_assert(is_a_power_of_two(sizeof(T))); + static_assert(is_a_power_of_two(sizeof(IdxT))); + constexpr IdxT aligned = 256 / std::min(sizeof(T), sizeof(IdxT)); + buf_len = Pow2::roundDown(buf_len); + return buf_len; } /** @@ -208,6 +221,11 @@ struct alignas(128) Counter { /** * Fused filtering of the current pass and building histogram for the next pass * (see steps 4 & 1 in `radix_kernel` description). + * + * This function is more complicated than the one-block counterpart because this function handles + * the case of early stopping. When early stopping is triggered, it's desirable to do the final + * filtering in this function rather than in last_filter(), because this function is run by multiple + * blocks while last_filter is run by a single block. */ template _RAFT_DEVICE void filter_and_histogram(const T* in_buf, @@ -397,7 +415,7 @@ _RAFT_DEVICE void last_filter(const T* in_buf, const int start_bit = calc_start_bit(pass); // changed in choose_bucket(); need to reload - const IdxT needed_num_of_kth = counter->k; + const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { @@ -412,7 +430,7 @@ _RAFT_DEVICE void last_filter(const T* in_buf, out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); - if (back_pos < needed_num_of_kth) { + if (back_pos < num_of_kth_needed) { IdxT pos = k - 1 - back_pos; out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; @@ -428,8 +446,8 @@ RAFT_KERNEL last_filter_kernel(const T* in, const IdxT* in_idx_buf, T* out, IdxT* out_idx, - IdxT len, - IdxT k, + const IdxT len, + const IdxT k, Counter* counters, const bool select_min) { @@ -454,14 +472,14 @@ RAFT_KERNEL last_filter_kernel(const T* in, constexpr int start_bit = calc_start_bit(pass); const auto kth_value_bits = counter->kth_value_bits; - const IdxT needed_num_of_kth = counter->k; + const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; auto f = [k, select_min, kth_value_bits, - needed_num_of_kth, + num_of_kth_needed, p_out_cnt, p_out_back_cnt, in_idx_buf, @@ -474,7 +492,7 @@ RAFT_KERNEL last_filter_kernel(const T* in, out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); - if (back_pos < needed_num_of_kth) { + if (back_pos < num_of_kth_needed) { IdxT pos = k - 1 - back_pos; out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; @@ -657,16 +675,35 @@ RAFT_KERNEL radix_kernel(const T* in, } template -int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel) +int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel, bool one_block) { int active_blocks; RAFT_CUDA_TRY( cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0)); - constexpr int items_per_thread = 32; - constexpr int num_waves = 10; - int chunk_size = - std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); + // The chunk size is chosen so that there is enough workload to fully utilize GPU. + // One full wave contains (sm_cnt * active_blocks) blocks, and 10 waves is an empirically safe + // estimation of enough workload. It also counteracts imbalance if some blocks run slower + // than others. + constexpr int num_waves = 10; + int chunk_size; + if (one_block) { + // For one-block version, one block processes one instance in the chunk. Just ensure that there + // are enough blocks. + chunk_size = num_waves * sm_cnt * active_blocks; + } else { + // One instance in the chunk contains len items and is processed by multiple blocks. + // The total number of items in a chunk (chunk_size * len) should be large enough that every + // thread has enough items to processes. So set it to num_waves * "max num of active threads" + // (sm_cnt * active_blocks * BlockSize) * items_per_thread. + // + // Also, the upper bound of the total number of items in a chunk is: + // 10 (num_waves) * ~100 (sm_cnt) * 2048 (active_blocks*BlockSize) * 32 (items_per_thread) =64M. + // So temporary buffer size required for one chunk won't be too large. + constexpr int items_per_thread = 32; + chunk_size = + std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); + } return std::min(chunk_size, batch_size); } @@ -709,17 +746,17 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) } template -_RAFT_HOST_DEVICE void set_buf_pointers(const T* in, - const IdxT* in_idx, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) +_RAFT_HOST void set_buf_pointers(const T* in, + const IdxT* in_idx, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) { if (pass == 0) { in_buf = in; @@ -744,6 +781,41 @@ _RAFT_HOST_DEVICE void set_buf_pointers(const T* in, } } +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + in_buf = reinterpret_cast(bufs); + in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + out_buf = const_cast(in_buf + buf_len); + out_idx_buf = const_cast(in_idx_buf + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + in_buf = out_buf + buf_len; + in_idx_buf = out_idx_buf + buf_len; + } +} + template void radix_topk(const T* in, const IdxT* in_idx, @@ -765,7 +837,7 @@ void radix_topk(const T* in, auto kernel = radix_kernel; const size_t max_chunk_size = - calc_chunk_size(batch_size, len, sm_cnt, kernel); + calc_chunk_size(batch_size, len, sm_cnt, kernel, false); if (max_chunk_size != static_cast(batch_size)) { grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); } @@ -793,6 +865,7 @@ void radix_topk(const T* in, RAFT_CUDA_TRY( cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + auto kernel = radix_kernel; const T* chunk_in = in + offset * len; const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; @@ -866,6 +939,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, IdxT* out_idx_buf, T* out, IdxT* out_idx, + const IdxT previous_len, Counter* counter, IdxT* histogram, bool select_min, @@ -879,9 +953,8 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, if (threadIdx.x == 0) { *p_filter_cnt = 0; } __syncthreads(); - const int start_bit = calc_start_bit(pass); - const unsigned mask = calc_mask(pass); - const IdxT previous_len = counter->previous_len; + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); if (pass == 0) { auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { @@ -889,6 +962,20 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, atomicAdd(histogram + bucket, static_cast(1)); }; vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } else if (!out_buf) { + // not use vectorized_process here because it increases #registers a lot + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + } } else { // not use vectorized_process here because it increases #registers a lot IdxT* p_out_cnt = &counter->out_cnt; @@ -927,10 +1014,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, T* out, IdxT* out_idx, const bool select_min, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2) + char* bufs) { constexpr int num_buckets = calc_num_buckets(); __shared__ Counter counter; @@ -951,22 +1035,30 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, if (in_idx) { in_idx += batch_id * len; } out += batch_id * k; out_idx += batch_id * k; - buf1 += batch_id * len; - idx_buf1 += batch_id * len; - buf2 += batch_id * len; - idx_buf2 += batch_id * len; - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; + const IdxT buf_len = calc_buf_len(len); + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers( - in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); - - IdxT current_len = counter.len; - IdxT current_k = counter.k; + const T* in_buf; + const IdxT* in_idx_buf; + T* out_buf; + IdxT* out_idx_buf; + set_buf_pointers(in, in_idx, bufs, buf_len, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); + + const IdxT current_len = counter.len; + const IdxT current_k = counter.k; + IdxT previous_len = counter.previous_len; + if (previous_len > buf_len) { + in_buf = in; + in_idx_buf = in_idx; + previous_len = len; + } + if (current_len > buf_len) { + // so "out_buf==nullptr" denotes skipping writing buffer in current pass + out_buf = nullptr; + out_idx_buf = nullptr; + } filter_and_histogram_for_one_block(in_buf, in_idx_buf, @@ -974,6 +1066,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_idx_buf, out, out_idx, + previous_len, &counter, histogram, select_min, @@ -988,11 +1081,11 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, __syncthreads(); if (counter.len == counter.k || pass == num_passes - 1) { - last_filter(pass == 0 ? in : out_buf, - pass == 0 ? in_idx : out_idx_buf, + last_filter(out_buf ? out_buf : in, + out_buf ? out_idx_buf : in_idx, out, out_idx, - current_len, + out_buf ? current_len : len, k, &counter, select_min, @@ -1022,21 +1115,17 @@ void radix_topk_one_block(const T* in, { static_assert(calc_num_passes() > 1); - auto kernel = radix_topk_one_block_kernel; + auto kernel = radix_topk_one_block_kernel; + const IdxT buf_len = calc_buf_len(len); const size_t max_chunk_size = - calc_chunk_size(batch_size, len, sm_cnt, kernel); + calc_chunk_size(batch_size, len, sm_cnt, kernel, true); auto pool_guard = - raft::get_pool_memory_resource(mr, - max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) + - 256 * 4 // might need extra memory for alignment - ); + raft::get_pool_memory_resource(mr, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT))); if (pool_guard) { RAFT_LOG_DEBUG("radix::select_k: using pool memory resource"); } - rmm::device_uvector buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector buf2(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector bufs( + max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); @@ -1047,10 +1136,7 @@ void radix_topk_one_block(const T* in, out + offset * k, out_idx + offset * k, select_min, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data()); + bufs.data()); } }