From 4a20d03af7f6181e3083bc3b65522d7f2c3b6218 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 8 Apr 2024 09:36:00 -0700 Subject: [PATCH] [FEA] Add support for `select_k` on CSR matrix (#2140) - This PR is one part of the feature of #1969 - Add the API of 'select_k' accepting CSR as input Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Authors: - rhdong (https://github.com/rhdong) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2140 --- cpp/bench/prims/CMakeLists.txt | 1 + cpp/bench/prims/sparse/select_k_csr.cu | 287 ++++++++++++ .../raft/matrix/detail/select_radix.cuh | 427 ++++++++++-------- .../raft/matrix/detail/select_warpsort.cuh | 55 ++- .../sparse/matrix/detail/select_k-ext.cuh | 67 +++ .../sparse/matrix/detail/select_k-inl.cuh | 225 +++++++++ .../raft/sparse/matrix/detail/select_k.cuh | 24 + cpp/include/raft/sparse/matrix/select_k.cuh | 87 ++++ .../matrix/detail/select_k_double_int64_t.cu | 32 ++ .../matrix/detail/select_k_double_uint32_t.cu | 34 ++ .../matrix/detail/select_k_float_int32.cu | 32 ++ .../matrix/detail/select_k_float_int64_t.cu | 32 ++ .../matrix/detail/select_k_float_uint32_t.cu | 32 ++ .../matrix/detail/select_k_half_int64_t.cu | 32 ++ .../matrix/detail/select_k_half_uint32_t.cu | 32 ++ cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/select_k_csr.cu | 398 ++++++++++++++++ 17 files changed, 1600 insertions(+), 198 deletions(-) create mode 100644 cpp/bench/prims/sparse/select_k_csr.cu create mode 100644 cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh create mode 100644 cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh create mode 100644 cpp/include/raft/sparse/matrix/detail/select_k.cuh create mode 100644 cpp/include/raft/sparse/matrix/select_k.cuh create mode 100644 cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int32.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu create mode 100644 cpp/test/sparse/select_k_csr.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 9f23c44a5c..0c5521d447 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -137,6 +137,7 @@ if(BUILD_PRIMS_BENCH) PATH bench/prims/sparse/bitmap_to_csr.cu bench/prims/sparse/convert_csr.cu + bench/prims/sparse/select_k_csr.cu bench/prims/main.cpp ) diff --git a/cpp/bench/prims/sparse/select_k_csr.cu b/cpp/bench/prims/sparse/select_k_csr.cu new file mode 100644 index 0000000000..a91e6c8514 --- /dev/null +++ b/cpp/bench/prims/sparse/select_k_csr.cu @@ -0,0 +1,287 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_rows; + index_t n_cols; + index_t top_k; + float sparsity; + bool select_min = true; + bool customized_indices = false; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << params.n_rows << "#" << params.n_cols << "#" << params.top_k << "#" << params.sparsity; + return os; +} + +template +struct SelectKCsrTest : public fixture { + SelectKCsrTest(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + values_d(0, stream), + indptr_d(0, stream), + indices_d(0, stream), + customized_indices_d(0, stream), + dst_values_d(0, stream), + dst_indices_d(0, stream) + { + std::vector dense_values_h(params.n_rows * params.n_cols); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, dense_values_h); + + std::vector indices_h(nnz); + std::vector customized_indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + convert_to_csr(dense_values_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + std::vector dst_values_h(params.n_rows * params.top_k, static_cast(2.0f)); + std::vector dst_indices_h(params.n_rows * params.top_k, + static_cast(params.n_rows * params.n_cols * 100)); + + dst_values_d.resize(params.n_rows * params.top_k, stream); + dst_indices_d.resize(params.n_rows * params.top_k, stream); + values_d.resize(nnz, stream); + + if (nnz) { + auto blobs_values = raft::make_device_matrix(handle, 1, nnz); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_values.data_handle(), + labels.data_handle(), + 1, + nnz, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-10.0f), + value_t(10.0f), + uint64_t(2024)); + raft::copy(values_d.data(), blobs_values.data_handle(), nnz, stream); + resource::sync_stream(handle); + } + + indices_d.resize(nnz, stream); + indptr_d.resize(params.n_rows + 1, stream); + + update_device(indices_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_d.data(), indptr_h.data(), indptr_h.size(), stream); + + if (params.customized_indices) { + customized_indices_d.resize(nnz, stream); + update_device(customized_indices_d.data(), + customized_indices_h.data(), + customized_indices_h.size(), + stream); + } + } + + index_t create_sparse_matrix(index_t m, index_t n, value_t sparsity, std::vector& matrix) + { + index_t total_elements = static_cast(m * n); + index_t num_ones = static_cast((total_elements * 1.0f) * sparsity); + index_t res = num_ones; + + for (index_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_idx(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis_idx(gen); + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + void convert_to_csr(std::vector& matrix, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + template + std::optional get_opt_var(data_t x) + { + if (params.customized_indices) { + return x; + } else { + return std::nullopt; + } + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto in_val_structure = raft::make_device_compressed_structure_view( + indptr_d.data(), + indices_d.data(), + params.n_rows, + params.n_cols, + static_cast(indices_d.size())); + + auto in_val = + raft::make_device_csr_matrix_view(values_d.data(), in_val_structure); + + std::optional> in_idx; + + in_idx = get_opt_var( + raft::make_device_vector_view(customized_indices_d.data(), nnz)); + + auto out_val = raft::make_device_matrix_view( + dst_values_d.data(), params.n_rows, params.top_k); + auto out_idx = raft::make_device_matrix_view( + dst_indices_d.data(), params.n_rows, params.top_k); + + raft::sparse::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + resource::sync_stream(handle); + loop_on_state(state, [this, &in_val, &in_idx, &out_val, &out_idx]() { + raft::sparse::matrix::select_k( + handle, in_val, in_idx, out_val, out_idx, params.select_min, false); + resource::sync_stream(handle); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + index_t nnz; + + rmm::device_uvector values_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector customized_indices_d; + + rmm::device_uvector dst_values_d; + rmm::device_uvector dst_indices_d; +}; // struct SelectKCsrTest + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + index_t k; + }; + + const std::vector params_group{ + {20000, 500, 1}, {20000, 500, 2}, {20000, 500, 4}, {20000, 500, 8}, + {20000, 500, 16}, {20000, 500, 32}, {20000, 500, 64}, {20000, 500, 128}, + {20000, 500, 256}, + + {1000, 10000, 1}, {1000, 10000, 2}, {1000, 10000, 4}, {1000, 10000, 8}, + {1000, 10000, 16}, {1000, 10000, 32}, {1000, 10000, 64}, {1000, 10000, 128}, + {1000, 10000, 256}, + + {100, 100000, 1}, {100, 100000, 2}, {100, 100000, 4}, {100, 100000, 8}, + {100, 100000, 16}, {100, 100000, 32}, {100, 100000, 64}, {100, 100000, 128}, + {100, 100000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}}; + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.1})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.2})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.5})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((SelectKCsrTest), "", getInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 36a346fda3..83d4845c31 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -442,14 +442,76 @@ _RAFT_DEVICE void last_filter(const T* in_buf, } } -template +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + in_buf = reinterpret_cast(bufs); + in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + out_buf = const_cast(in_buf + buf_len); + out_idx_buf = const_cast(in_idx_buf + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + in_buf = out_buf + buf_len; + in_idx_buf = out_idx_buf + buf_len; + } +} + +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + const int pass, + const T*& out_buf, + const IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + out_buf = const_cast(reinterpret_cast(bufs) + buf_len); + out_idx_buf = + const_cast(reinterpret_cast(bufs + sizeof(T) * 2 * buf_len) + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } +} + +template RAFT_KERNEL last_filter_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, const IdxT len, + const IdxT* len_i, const IdxT k, Counter* counters, const bool select_min) @@ -458,22 +520,31 @@ RAFT_KERNEL last_filter_kernel(const T* in, Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + + const IdxT l_len = len_or_indptr ? len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + const IdxT buf_len = calc_buf_len(len); - if (previous_len > buf_len || in_buf == in) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; - } - out += batch_id * k; - out_idx += batch_id * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); constexpr int pass = calc_num_passes() - 1; constexpr int start_bit = calc_start_bit(pass); + set_buf_pointers(in + l_offset, in_idx + l_offset, bufs, buf_len, pass, in_buf, in_idx_buf); + + if (previous_len > buf_len || in_buf == in + l_offset) { + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; + } + out += batch_id * k; + out_idx += batch_id * k; + const auto kth_value_bits = counter->kth_value_bits; const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; @@ -510,6 +581,29 @@ RAFT_KERNEL last_filter_kernel(const T* in, f); } +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_val( + T* dest, const T* src, S len, IdxT k, const bool select_min) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + const T default_val = select_min ? upper_bound() : lower_bound(); + for (S i = idx; i < k; i += stride) { + dest[i] = i < len ? src[i] : default_val; + } +} + +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_idx(T* dest, const T* src, S len) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + + for (S i = idx; i < len; i += stride) { + dest[i] = src ? src[i] : i; + } +} + /** * * It is expected to call this kernel multiple times (passes), in each pass we process a radix, @@ -545,13 +639,16 @@ RAFT_KERNEL last_filter_kernel(const T* in, * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and * their indices. */ -template +template RAFT_KERNEL radix_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, Counter* counters, @@ -567,21 +664,38 @@ RAFT_KERNEL radix_kernel(const T* in, IdxT current_k; IdxT previous_len; IdxT current_len; + + const IdxT l_len = len_or_indptr ? len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + if (pass == 0) { current_k = k; - previous_len = len; + previous_len = l_len; // Need to do this so setting counter->previous_len for the next pass is correct. // This value is meaningless for pass 0, but it's fine because pass 0 won't be the // last pass in this implementation so pass 0 won't hit the "if (pass == // num_passes - 1)" branch. // Maybe it's better to reload counter->previous_len and use it rather than // current_len in last_filter() - current_len = len; + current_len = l_len; } else { current_k = counter->k; current_len = counter->len; previous_len = counter->previous_len; } + if constexpr (!len_or_indptr) { + if (pass == 0 && l_len <= k) { + copy_in_val(out + batch_id * k, in + l_offset, l_len, k, select_min); + copy_in_idx(out_idx + batch_id * k, (in_idx ? (in_idx + l_offset) : nullptr), l_len); + if (threadIdx.x == 0) { + counter->previous_len = 0; + counter->len = 0; + } + __syncthreads(); + return; + } + } + if (current_len == 0) { return; } // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle @@ -590,20 +704,33 @@ RAFT_KERNEL radix_kernel(const T* in, const bool early_stop = (current_len == current_k); const IdxT buf_len = calc_buf_len(len); + const T* in_buf; + const IdxT* in_idx_buf; + T* out_buf; + IdxT* out_idx_buf; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + set_buf_pointers(in + l_offset, + (in_idx ? (in_idx + l_offset) : nullptr), + bufs, + buf_len, + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + // "previous_len > buf_len" means previous pass skips writing buffer if (pass == 0 || pass == 1 || previous_len > buf_len) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; } // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -611,9 +738,6 @@ RAFT_KERNEL radix_kernel(const T* in, if (pass == 0 || current_len > buf_len) { out_buf = nullptr; out_idx_buf = nullptr; - } else { - out_buf += batch_id * buf_len; - out_idx_buf += batch_id * buf_len; } out += batch_id * k; out_idx += batch_id * k; @@ -640,7 +764,6 @@ RAFT_KERNEL radix_kernel(const T* in, unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); isLastBlock = (finished == (gridDim.x - 1)); } - if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { @@ -676,7 +799,7 @@ RAFT_KERNEL radix_kernel(const T* in, out_idx_buf ? out_idx_buf : in_idx_buf, out, out_idx, - out_buf ? current_len : len, + out_buf ? current_len : l_len, k, counter, select_min, @@ -726,7 +849,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &active_blocks, radix_kernel, BlockSize, 0)); + &active_blocks, radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; IdxT best_num_blocks = 0; @@ -757,78 +880,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) return best_num_blocks; } -template -_RAFT_HOST void set_buf_pointers(const T* in, - const IdxT* in_idx, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = buf1; - out_idx_buf = idx_buf1; - } else if (pass % 2 == 0) { - in_buf = buf1; - in_idx_buf = idx_buf1; - out_buf = buf2; - out_idx_buf = idx_buf2; - } else { - in_buf = buf2; - in_idx_buf = idx_buf2; - out_buf = buf1; - out_idx_buf = idx_buf1; - } -} - -template -_RAFT_DEVICE void set_buf_pointers(const T* in, - const IdxT* in_idx, - char* bufs, - IdxT buf_len, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - } else if (pass % 2 == 0) { - in_buf = reinterpret_cast(bufs); - in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - out_buf = const_cast(in_buf + buf_len); - out_idx_buf = const_cast(in_idx_buf + buf_len); - } else { - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - in_buf = out_buf + buf_len; - in_idx_buf = out_idx_buf + buf_len; - } -} - -template +template void radix_topk(const T* in, const IdxT* in_idx, int batch_size, @@ -850,7 +902,7 @@ void radix_topk(const T* in, if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - auto kernel = radix_kernel; + auto kernel = radix_kernel; const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, false); if (max_chunk_size != static_cast(batch_size)) { @@ -862,55 +914,33 @@ void radix_topk(const T* in, rmm::device_uvector> counters(max_chunk_size, stream, mr); rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); + + rmm::device_uvector bufs( + max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); RAFT_CUDA_TRY( cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - auto kernel = radix_kernel; + auto kernel = radix_kernel; - const T* chunk_in = in + offset * len; - const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; - T* chunk_out = out + offset * k; - IdxT* chunk_out_idx = out_idx + offset * k; - const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; dim3 blocks(grid_dim, chunk_size); constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers(chunk_in, - chunk_in_idx, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data(), - pass, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf); - if (fused_last_filter && pass == num_passes - 1) { - kernel = radix_kernel; + kernel = radix_kernel; } - kernel<<>>(chunk_in, - chunk_in_idx, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, + kernel<<>>(in, + in_idx, + bufs.data(), + offset, chunk_out, chunk_out_idx, counters.data(), @@ -924,16 +954,18 @@ void radix_topk(const T* in, } if (!fused_last_filter) { - last_filter_kernel<<>>(chunk_in, - chunk_in_idx, - out_buf, - out_idx_buf, - chunk_out, - chunk_out_idx, - len, - k, - counters.data(), - select_min); + last_filter_kernel + <<>>(in, + in_idx, + bufs.data(), + offset, + chunk_out, + chunk_out_idx, + len, + chunk_len_i, + k, + counters.data(), + select_min); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } @@ -1015,7 +1047,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, } } -template +template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, @@ -1024,30 +1056,48 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, T* out, IdxT* out_idx, const bool select_min, - char* bufs) + char* bufs, + size_t offset) { constexpr int num_buckets = calc_num_buckets(); __shared__ Counter counter; __shared__ IdxT histogram[num_buckets]; + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + + IdxT l_len = len; + IdxT l_offset = (offset + batch_id) * len; + if constexpr (!len_or_indptr) { + l_offset = len_i[batch_id]; + l_len = len_i[batch_id + 1] - l_offset; + } + if (threadIdx.x == 0) { counter.k = k; - counter.len = len; - counter.previous_len = len; + counter.len = l_len; + counter.previous_len = l_len; counter.kth_value_bits = 0; counter.out_cnt = 0; counter.out_back_cnt = 0; } __syncthreads(); - const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow - in += batch_id * len; - if (in_idx) { in_idx += batch_id * len; } + in += l_offset; + if (in_idx) { in_idx += l_offset; } out += batch_id * k; out_idx += batch_id * k; const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + if constexpr (!len_or_indptr) { + if (l_len <= k) { + copy_in_val(out, in, l_len, k, select_min); + copy_in_idx(out_idx, in_idx, l_len); + __syncthreads(); + return; + } + } + constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { const T* in_buf; @@ -1073,7 +1123,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -1102,7 +1152,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_buf ? out_idx_buf : in_idx, out, out_idx, - out_buf ? current_len : len, + out_buf ? current_len : l_len, k, &counter, select_min, @@ -1117,7 +1167,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // counters and global histograms, can be kept in shared memory and cheap sync operations can be // used. It's used when len is relatively small or when the number of blocks per row calculated by // `calc_grid_dim()` is 1. -template +template void radix_topk_one_block(const T* in, const IdxT* in_idx, int batch_size, @@ -1133,7 +1183,7 @@ void radix_topk_one_block(const T* in, { static_assert(calc_num_passes() > 1); - auto kernel = radix_topk_one_block_kernel; + auto kernel = radix_topk_one_block_kernel; const IdxT buf_len = calc_buf_len(len); const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, true); @@ -1144,15 +1194,16 @@ void radix_topk_one_block(const T* in, for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - kernel<<>>(in + offset * len, - in_idx ? (in_idx + offset * len) : nullptr, + kernel<<>>(in, + in_idx, len, chunk_len_i, k, out + offset * k, out_idx + offset * k, select_min, - bufs.data()); + bufs.data(), + offset); } } @@ -1182,6 +1233,10 @@ void radix_topk_one_block(const T* in, * it affects the number of passes and number of buckets. * @tparam BlockSize * Number of threads in a kernel thread block. + * @tparam len_or_indptr + * Flag to interpret `len_i` as either direct row lengths (true) or CSR format + * index pointers (false). When true, each `len_i` element denotes the length of a row. When + * false, `len_i` represents the index pointers for a CSR matrix with shape of `batch_size + 1`. * * @param[in] res container of reusable resources * @param[in] in @@ -1212,9 +1267,12 @@ void radix_topk_one_block(const T* in, * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param len_i - * optional array of size (batch_size) providing lengths for each individual row + * Optional array used differently based on `len_or_indptr`: + * When `len_or_indptr` is true, `len_i` presents the lengths of each row, which is `batch_size`. + * When `len_or_indptr` is false, `len_i` works like a indptr for a CSR matrix. The length of each + * row would be (`len_i[row_id + 1] - len_i[row_id]`). `len_i` size is `batch_size + 1`. */ -template +template void select_k(raft::resources const& res, const T* in, const IdxT* in_idx, @@ -1227,9 +1285,12 @@ void select_k(raft::resources const& res, bool fused_last_filter, const IdxT* len_i) { + RAFT_EXPECTS(!(!len_or_indptr && (len_i == nullptr)), + "When `len_or_indptr` is false, `len_i` must not be nullptr!"); + auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); - if (k == len) { + if (k == len && len_or_indptr) { RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { @@ -1248,29 +1309,29 @@ void select_k(raft::resources const& res, constexpr int items_per_thread = 32; if (len <= BlockSize * items_per_thread) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { - impl::radix_topk(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - fused_last_filter, - len_i, - grid_dim, - sm_cnt, - stream, - mr); + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + len_i, + grid_dim, + sm_cnt, + stream, + mr); } } } diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index 572558153d..2cb32585d5 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -754,22 +754,32 @@ template