From 50719712889137fb451e5bdb8eab4f8f6fb80408 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 31 Jan 2024 08:06:14 -0800 Subject: [PATCH 01/12] [FEA] Add support for `select_k` on CSR matrix - This PR is one part of the feature of #1969 - Add the API of 'select_k' accepting CSR as input - Add the API of 'segmented_copy' Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) --- cpp/bench/prims/CMakeLists.txt | 1 + cpp/bench/prims/matrix/select_k_csr.cu | 257 +++++++++++++ cpp/include/raft/matrix/copy.cuh | 37 +- cpp/include/raft/matrix/detail/matrix.cuh | 67 +++- .../raft/matrix/detail/select_k-ext.cuh | 30 ++ .../raft/matrix/detail/select_k-inl.cuh | 99 +++++ cpp/include/raft/matrix/select_k.cuh | 39 ++ .../matrix/detail/select_k_double_int64_t.cu | 14 + .../matrix/detail/select_k_double_uint32_t.cu | 14 + cpp/src/matrix/detail/select_k_float_int32.cu | 14 + .../matrix/detail/select_k_float_int64_t.cu | 14 + .../matrix/detail/select_k_float_uint32_t.cu | 14 + .../matrix/detail/select_k_half_int64_t.cu | 14 + .../matrix/detail/select_k_half_uint32_t.cu | 14 + cpp/test/CMakeLists.txt | 9 +- cpp/test/matrix/copy.cu | 253 +++++++++++++ cpp/test/matrix/select_k_csr.cu | 350 ++++++++++++++++++ 17 files changed, 1237 insertions(+), 3 deletions(-) create mode 100644 cpp/bench/prims/matrix/select_k_csr.cu create mode 100644 cpp/test/matrix/copy.cu create mode 100644 cpp/test/matrix/select_k_csr.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 3a2431cd34..253bc6c2e0 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -128,6 +128,7 @@ if(BUILD_PRIMS_BENCH) bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu bench/prims/matrix/select_k.cu + bench/prims/matrix/select_k_csr.cu bench/prims/matrix/main.cpp OPTIONAL LIB diff --git a/cpp/bench/prims/matrix/select_k_csr.cu b/cpp/bench/prims/matrix/select_k_csr.cu new file mode 100644 index 0000000000..99c59f4fde --- /dev/null +++ b/cpp/bench/prims/matrix/select_k_csr.cu @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_rows; + index_t n_cols; + index_t top_k; + float sparsity; + bool select_min = true; + bool customized_indices = false; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << " rows*cols=" << params.n_rows << "*" << params.n_cols << "\ttop_k=" << params.top_k + << "\tsparsity=" << params.sparsity; + return os; +} + +template +struct SelectKCsrTest : public fixture { + SelectKCsrTest(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + values_d(0, stream), + indptr_d(0, stream), + indices_d(0, stream), + customized_indices_d(0, stream), + dst_values_d(0, stream), + dst_indices_d(0, stream) + { + std::vector dense_values_h(params.n_rows * params.n_cols, false); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, dense_values_h); + + std::vector indices_h(nnz); + std::vector customized_indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + convert_to_csr(dense_values_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + std::vector dst_values_h(params.n_rows * params.top_k, static_cast(2.0f)); + std::vector dst_indices_h(params.n_rows * params.top_k, + static_cast(params.n_rows * params.n_cols * 100)); + + dst_values_d.resize(params.n_rows * params.top_k, stream); + dst_indices_d.resize(params.n_rows * params.top_k, stream); + values_d.resize(nnz, stream); + + if (nnz) { + auto blobs_values = raft::make_device_matrix(handle, 1, nnz); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_values.data_handle(), + labels.data_handle(), + 1, + nnz, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-10.0f), + value_t(10.0f), + uint64_t(2024)); + raft::copy(values_d.data(), blobs_values.data_handle(), nnz, stream); + resource::sync_stream(handle); + } + + indices_d.resize(nnz, stream); + indptr_d.resize(params.n_rows + 1, stream); + + update_device(indices_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_d.data(), indptr_h.data(), indptr_h.size(), stream); + + if (params.customized_indices) { + customized_indices_d.resize(nnz, stream); + update_device(customized_indices_d.data(), + customized_indices_h.data(), + customized_indices_h.size(), + stream); + } + } + + index_t create_sparse_matrix(index_t m, index_t n, value_t sparsity, std::vector& matrix) + { + index_t total_elements = static_cast(m * n); + index_t num_ones = static_cast((total_elements * 1.0f) * sparsity); + index_t res = num_ones; + + for (index_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_idx(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis_idx(gen); + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + void convert_to_csr(std::vector& matrix, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + template + std::optional get_opt_var(data_t x) + { + if (params.customized_indices) { + return x; + } else { + return std::nullopt; + } + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto in_val_structure = raft::make_device_compressed_structure_view( + indptr_d.data(), + indices_d.data(), + params.n_rows, + params.n_cols, + static_cast(indices_d.size())); + + auto in_val = + raft::make_device_csr_matrix_view(values_d.data(), in_val_structure); + + std::optional> in_idx; + + in_idx = get_opt_var( + raft::make_device_vector_view(customized_indices_d.data(), nnz)); + + auto out_val = raft::make_device_matrix_view( + dst_values_d.data(), params.n_rows, params.top_k); + auto out_idx = raft::make_device_matrix_view( + dst_indices_d.data(), params.n_rows, params.top_k); + + raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + resource::sync_stream(handle); + loop_on_state(state, [this, &in_val, &in_idx, &out_val, &out_idx]() { + raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + resource::sync_stream(handle); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + index_t nnz; + + rmm::device_uvector values_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector customized_indices_d; + + rmm::device_uvector dst_values_d; + rmm::device_uvector dst_indices_d; +}; // struct SelectKCsrTest + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + index_t k; + float sparsity; + }; + + const std::vector params_group = + raft::util::itertools::product({index_t(10), index_t(1024)}, + {index_t(1024 * 10), index_t(1024 * 1024)}, + {index_t(128), index_t(100), index_t(2048)}, + {0.1f, 0.2f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, params.sparsity})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((SelectKCsrTest), "", getInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index be83a4a19e..785ff84b56 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -122,6 +122,41 @@ void trunc_zero_origin(raft::resources const& handle, resource::get_cuda_stream(handle)); } +/** + * @brief Copy a specific number of elements row by row from the source vector to the target matrix + * according to the segment indicated by offsets + * + * @tparam m_t the type of the copied items. + * @tparam idx_t the index type of vectors and matrix. + * @param[in] handle raft handle + * @param[in] max_len_per_row Maximum number of copies per row + * @param[in] src Source vector + * @param[in] offsets Indicates the starting and ending index of each row in the vector + * @param[out] dst Destination matrix in row major order + * + * @note When the length of one segment is less than max_len_per_row, the remaining position values + * of dst will remain unchanged. + */ +template +void segmented_copy(raft::resources const& handle, + idx_t max_len_per_row, + raft::device_vector_view src, + raft::device_vector_view offsets, + raft::device_matrix_view dst) +{ + RAFT_EXPECTS(static_cast(offsets.size()) == (dst.extent(0) + 1), + "Number of offsets must be larger than number of output rows by 1"); + RAFT_EXPECTS(dst.extent(1) >= max_len_per_row, + "Number of rows in the out must be equal or larger than max_len_per_row"); + detail::segmented_copy(handle, + src.data_handle(), + dst.extent(0), + dst.extent(1), + max_len_per_row, + offsets.data_handle(), + dst.data_handle()); +} + /** @} */ // end of group matrix_copy } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 2fa741fd96..415ef31965 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -316,6 +316,71 @@ m_t getL2Norm(raft::resources const& handle, const m_t* in, idx_t size, cudaStre return normval; } +// Threads per block in segmented_copy_kernel. +static const constexpr int SEGMENTED_COPY_TPB_256 = 256; +static const constexpr int SEGMENTED_COPY_TPB_32 = 32; + +template +RAFT_KERNEL __launch_bounds__(TPB) segmented_copy_kernel( + const m_t* src, idx_t n_rows, idx_t n_cols, idx_t max_len_per_row, idx_t* offsets, m_t* dst) +{ +#pragma unroll + for (idx_t row_id = blockIdx.y; row_id < n_rows; row_id += gridDim.y) { + idx_t segment_start = offsets[row_id]; + idx_t len = min(offsets[row_id + 1] - segment_start, max_len_per_row); + for (idx_t col_id = threadIdx.x + blockIdx.x * blockDim.x; col_id < len; + col_id += blockDim.x * gridDim.x) { + dst[row_id * n_cols + col_id] = src[segment_start + col_id]; + } + } +} + +template +void segmented_copy(raft::resources const& handle, + const m_t* src, + idx_t n_rows, + idx_t n_cols, + idx_t max_len_per_row, + idx_t* offsets, + m_t* dst) +{ + auto stream = resource::get_cuda_stream(handle); + + idx_t tpb = max_len_per_row >= 256 ? SEGMENTED_COPY_TPB_256 : SEGMENTED_COPY_TPB_32; + + int dev_id, sm_count, blocks_per_sm; + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + if (tpb == SEGMENTED_COPY_TPB_32) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, segmented_copy_kernel, tpb, 0); + } else if (tpb == SEGMENTED_COPY_TPB_256) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, segmented_copy_kernel, tpb, 0); + } + + idx_t max_active_blocks = sm_count * blocks_per_sm; + // `max threads number = sm_count * blocks_per_sm * tpb` + // `problem size = n_rows * max_len_per_row` + idx_t required_active_blocks = + raft::min(max_active_blocks, raft::ceildiv(n_rows * max_len_per_row, tpb)); + + idx_t blocks_per_row = raft::ceildiv(required_active_blocks, n_rows); + idx_t grid_rows = raft::ceildiv(required_active_blocks, blocks_per_row); + dim3 block(tpb, 1); + dim3 grid(blocks_per_row, grid_rows); + + if (tpb == SEGMENTED_COPY_TPB_32) { + segmented_copy_kernel + <<>>(src, n_rows, n_cols, max_len_per_row, offsets, dst); + } else if (tpb == SEGMENTED_COPY_TPB_256) { + segmented_copy_kernel + <<>>(src, n_rows, n_cols, max_len_per_row, offsets, dst); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + } // end namespace detail } // end namespace matrix } // end namespace raft diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index dfdbfa2d07..af47d45685 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -18,6 +18,7 @@ #include // uint32_t #include // __half +#include #include #include #include // RAFT_EXPLICIT @@ -41,6 +42,15 @@ void select_k(raft::resources const& handle, rmm::mr::device_memory_resource* mr = nullptr, bool sorted = false, SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + +template +void select_k(raft::resources const& handle, + raft::device_csr_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -70,3 +80,23 @@ instantiate_raft_matrix_detail_select_k(double, int64_t); instantiate_raft_matrix_detail_select_k(double, uint32_t); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + extern template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); +instantiate_raft_matrix_detail_select_k(__half, int64_t); +instantiate_raft_matrix_detail_select_k(float, int64_t); +instantiate_raft_matrix_detail_select_k(float, uint32_t); +instantiate_raft_matrix_detail_select_k(float, int); +instantiate_raft_matrix_detail_select_k(double, int64_t); +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 0a6f292e68..a9d1456e29 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -17,12 +17,17 @@ #pragma once +#include + #include "select_radix.cuh" #include "select_warpsort.cuh" +#include #include #include #include +#include +#include #include #include @@ -320,4 +325,98 @@ void select_k(raft::resources const& handle, default: RAFT_FAIL("K-selection Algorithm not supported."); } } + +/** + * Selects the k smallest or largest keys/values from each row of the input matrix. + * + * This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`, + * selecting the k smallest or largest elements from each row. The selected elements are then stored + * in a row-major output matrix `out_val` with dimensions `batch_size` x k. + * + * @tparam T + * Type of the elements being compared (keys). + * @tparam IdxT + * Type of the indices associated with the keys. + * @tparam NZType + * Type representing non-zero elements of `in_val`. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] in_val + * Input matrix in CSR format with a logical dense shape of [batch_size, len], + * containing the elements to be compared and selected. + * @param[in] in_idx + * Optional input indices [in_val.nnz] associated with `in_val.values`. + * If `in_idx` is `std::nullopt`, it defaults to a contiguous array from 0 to len-1. + * @param[out] out_val + * Output matrix [in_val.get_n_row(), k] storing the selected k smallest/largest elements + * from each row of `in_val`. + * @param[out] out_idx + * Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`. + * @param[in] select_min + * Flag indicating whether to select the k smallest (true) or largest (false) elements. + * @param[in] mr + * An optional memory resource to use across the calls (you can provide a large enough + * memory pool here to avoid memory allocations within the call). + */ +template +void select_k(raft::resources const& handle, + raft::device_csr_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min, + rmm::mr::device_memory_resource* mr = nullptr) +{ + auto csr_view = in_val.structure_view(); + auto nnz = csr_view.get_nnz(); + + if (nnz == 0) return; + + auto batch_size = csr_view.get_n_rows(); + auto len = csr_view.get_n_cols(); + auto k = IdxT(out_val.extent(1)); + + if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } + RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), + "output k must fit the int type."); + + RAFT_EXPECTS(batch_size == out_val.extent(0), "batch sizes must be equal"); + RAFT_EXPECTS(batch_size == out_idx.extent(0), "batch sizes must be equal"); + + if (in_idx.has_value()) { + RAFT_EXPECTS(size_t(nnz) == in_idx->size(), + "nnz of in_val must be equal to the length of in_idx"); + } + RAFT_EXPECTS(IdxT(k) == out_idx.extent(1), "value and index output lengths must be equal"); + + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector offsets(batch_size + 1, stream); + rmm::device_uvector keys(nnz, stream); + rmm::device_uvector values(nnz, stream); + + raft::copy(offsets.data(), csr_view.get_indptr().data(), batch_size + 1, stream); + raft::copy(keys.data(), in_val.get_elements().data(), nnz, stream); + raft::copy(values.data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + nnz, + stream); + + segmented_sort_by_key(handle, + keys.data(), + values.data(), + size_t(batch_size), + size_t(nnz), + offsets.data(), + select_min); + + auto src_val = raft::make_device_vector_view(keys.data(), nnz); + auto offsets_view = raft::make_device_vector_view(offsets.data(), batch_size + 1); + raft::matrix::segmented_copy(handle, k, src_val, offsets_view, out_val); + + auto src_idx = raft::make_device_vector_view(values.data(), nnz); + raft::matrix::segmented_copy(handle, k, src_idx, offsets_view, out_idx); +} + } // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 92d7db006d..1f8136290b 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -19,6 +19,7 @@ #include "detail/select_k.cuh" #include +#include #include #include #include @@ -117,6 +118,44 @@ void select_k(raft::resources const& handle, algo); } +/** + * Selects the k smallest or largest keys/values from each row of the input matrix. + * + * This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`, + * selecting the k smallest or largest elements from each row. The selected elements are then stored + * in a row-major output matrix `out_val` with dimensions `batch_size` x k. + * + * @tparam T + * Type of the elements being compared (keys). + * @tparam IdxT + * Type of the indices associated with the keys. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] in_val + * Input matrix in CSR format with a logical dense shape of [batch_size, len], + * containing the elements to be compared and selected. + * @param[in] in_idx + * Optional input indices [in_val.nnz] associated with `in_val.values`. + * If `in_idx` is `std::nullopt`, it defaults to a contiguous array from 0 to len-1. + * @param[out] out_val + * Output matrix [in_val.get_n_row(), k] storing the selected k smallest/largest elements + * from each row of `in_val`. + * @param[out] out_idx + * Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`. + * @param[in] select_min + * Flag indicating whether to select the k smallest (true) or largest (false) elements. + */ +template +void select_k(raft::resources const& handle, + raft::device_csr_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min) +{ + return detail::select_k(handle, in_val, in_idx, out_val, out_idx, select_min); +} /** @} */ // end of group select_k } // namespace raft::matrix diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index 87e5d49d29..7f8aed2506 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -33,3 +33,17 @@ instantiate_raft_matrix_detail_select_k(double, int64_t); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, int64_t); + +#undef instantiate_raft_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 67dce0e166..73338e7578 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -34,3 +34,17 @@ instantiate_raft_matrix_detail_select_k(double, uint32_t); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k \ No newline at end of file diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 4be7c54839..a2d796a7c5 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -33,3 +33,17 @@ instantiate_raft_matrix_detail_select_k(float, int); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, int); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 6337994e86..c7d93ab463 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -33,3 +33,17 @@ instantiate_raft_matrix_detail_select_k(float, int64_t); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, uint64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index ad26547812..dbf7afa06e 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -33,3 +33,17 @@ instantiate_raft_matrix_detail_select_k(float, uint32_t); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index e3c29a2033..9923088a84 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -33,3 +33,17 @@ instantiate_raft_matrix_detail_select_k(__half, int64_t); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index 3e3a738915..e90fe42c3e 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -33,3 +33,17 @@ instantiate_raft_matrix_detail_select_k(__half, uint32_t); #undef instantiate_raft_matrix_detail_select_k + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k( \ + raft::resources const& handle, \ + raft::device_csr_matrix_view in_val, \ + std::optional> in_idx, \ + raft::device_matrix_view out_val, \ + raft::device_matrix_view out_idx, \ + bool select_min, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index fe29409d9b..af283cf60c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -256,6 +256,7 @@ if(BUILD_TESTS) test/matrix/argmax.cu test/matrix/argmin.cu test/matrix/columnSort.cu + test/matrix/copy.cu test/matrix/diagonal.cu test/matrix/gather.cu test/matrix/scatter.cu @@ -272,7 +273,13 @@ if(BUILD_TESTS) EXPLICIT_INSTANTIATE_ONLY ) - ConfigureTest(NAME MATRIX_SELECT_TEST PATH test/matrix/select_k.cu LIB EXPLICIT_INSTANTIATE_ONLY) + ConfigureTest( + NAME + MATRIX_SELECT_TEST + PATH test/matrix/select_k.cu + PATH test/matrix/select_k_csr.cu + LIB + EXPLICIT_INSTANTIATE_ONLY) ConfigureTest( NAME MATRIX_SELECT_LARGE_TEST PATH test/matrix/select_large_k.cu LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/matrix/copy.cu b/cpp/test/matrix/copy.cu new file mode 100644 index 0000000000..58bc8970a4 --- /dev/null +++ b/cpp/test/matrix/copy.cu @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace sparse { + +template +struct SegmentedCopyInputs { + index_t n_rows; + index_t n_cols; + index_t top_k; + float sparsity; +}; + +template +class SegmentedCopyTest : public ::testing::TestWithParam> { + public: + SegmentedCopyTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + indices_d(0, stream), + indptr_d(0, stream), + values_d(0, stream), + dst_values_d(0, stream), + dst_values_expected_d(0, stream), + dst_indices_d(0, stream), + dst_indices_expected_d(0, stream) + { + } + + protected: + index_t create_sparse_matrix(index_t m, index_t n, value_t sparsity, std::vector& matrix) + { + index_t total_elements = static_cast(m * n); + index_t num_ones = static_cast((total_elements * 1.0f) * sparsity); + index_t res = num_ones; + + for (index_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_idx(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis_idx(gen); + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + void convert_to_csr(std::vector& matrix, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + template + void cpu_segmented_copy(index_t rows, + index_t max_len_per_row, + const std::vector& src, + const std::vector& offsets, + std::vector& dst) + { + for (index_t row = 0; row < rows; ++row) { + index_t start = offsets[row]; + index_t end = offsets[row + 1]; //(row < rows - 1) ? offsets[row + 1] : src.size(); + index_t length = std::min(end - start, max_len_per_row); + if (length == 0) continue; + std::copy( + src.begin() + start, src.begin() + start + length, dst.begin() + row * max_len_per_row); + } + } + + void SetUp() override + { + std::vector dense_values_h(params.n_rows * params.n_cols); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, dense_values_h); + + std::vector values_h(nnz); + std::vector indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + std::vector dst_values_h(params.n_rows * params.top_k, static_cast(2.0f)); + + std::vector dst_indices_h(params.n_rows * params.top_k, + static_cast(params.n_rows * params.n_cols + 1)); + + // sync up the initial values in advance to 2.0 which is out of random range [-1.0, 1.0]. + dst_values_d.resize(params.n_rows * params.top_k, stream); + dst_indices_d.resize(params.n_rows * params.top_k, stream); + + update_device(dst_values_d.data(), dst_values_h.data(), dst_values_h.size(), stream); + update_device(dst_indices_d.data(), dst_indices_h.data(), dst_indices_h.size(), stream); + resource::sync_stream(handle); + + auto blobs_values = raft::make_device_matrix(handle, 1, dst_values_h.size()); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_values.data_handle(), + labels.data_handle(), + 1, + dst_values_h.size(), + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-1.0f), + value_t(1.0f), + uint64_t(2024)); + raft::copy(dst_values_h.data(), blobs_values.data_handle(), dst_values_h.size(), stream); + raft::copy(dst_values_d.data(), blobs_values.data_handle(), dst_values_h.size(), stream); + resource::sync_stream(handle); + + convert_to_csr(dense_values_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + cpu_segmented_copy(params.n_rows, params.top_k, values_h, indptr_h, dst_values_h); + cpu_segmented_copy(params.n_rows, params.top_k, indices_h, indptr_h, dst_indices_h); + + values_d.resize(nnz, stream); + indices_d.resize(nnz, stream); + indptr_d.resize(params.n_rows + 1, stream); + dst_values_expected_d.resize(params.n_rows * params.top_k, stream); + dst_indices_expected_d.resize(params.n_rows * params.top_k, stream); + + update_device(values_d.data(), values_h.data(), values_h.size(), stream); + update_device(indices_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_d.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(dst_values_expected_d.data(), dst_values_h.data(), dst_values_h.size(), stream); + update_device( + dst_indices_expected_d.data(), dst_indices_h.data(), dst_indices_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto src_values = raft::make_device_vector_view(values_d.data(), nnz); + auto src_indices = raft::make_device_vector_view(indices_d.data(), nnz); + auto offsets = + raft::make_device_vector_view(indptr_d.data(), params.n_rows + 1); + auto dst_values = raft::make_device_matrix_view( + dst_values_d.data(), params.n_rows, params.top_k); + auto dst_indices = raft::make_device_matrix_view( + dst_indices_d.data(), params.n_rows, params.top_k); + + raft::matrix::segmented_copy(handle, params.top_k, src_values, offsets, dst_values); + raft::matrix::segmented_copy(handle, params.top_k, src_indices, offsets, dst_indices); + + resource::sync_stream(handle); + + ASSERT_TRUE(raft::devArrMatch(dst_values_expected_d.data(), + dst_values_d.data(), + params.n_rows * params.top_k, + raft::CompareApprox(1e-6f), + stream)); + + ASSERT_TRUE(raft::devArrMatch(dst_indices_expected_d.data(), + dst_indices_d.data(), + params.n_rows * params.top_k, + raft::Compare(), + stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + SegmentedCopyInputs params; + + index_t nnz; + + rmm::device_uvector values_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + + rmm::device_uvector dst_values_d; + rmm::device_uvector dst_values_expected_d; + + rmm::device_uvector dst_indices_d; + rmm::device_uvector dst_indices_expected_d; +}; + +using SegmentedCopyTest_float_int = SegmentedCopyTest; +TEST_P(SegmentedCopyTest_float_int, Result) { Run(); } + +using SegmentedCopyTest_double_int64 = SegmentedCopyTest; +TEST_P(SegmentedCopyTest_double_int64, Result) { Run(); } + +template +const std::vector> segmentedcopy_inputs = { + {10, 32, 10, 0.0}, + {10, 32, 10, 0.3}, + {32, 1024, 63, 0.3}, + {1024, 1024, 128, 0.2}, + {1024, 1024 * 2000, 251, 0.2}, + {2048, 1024 * 100, 1000, 0.3}, + {2048, 1024 * 100, 2100, 0.5}}; + +INSTANTIATE_TEST_CASE_P(SegmentedCopyTest, + SegmentedCopyTest_float_int, + ::testing::ValuesIn(segmentedcopy_inputs)); +INSTANTIATE_TEST_CASE_P(SegmentedCopyTest, + SegmentedCopyTest_double_int64, + ::testing::ValuesIn(segmentedcopy_inputs)); + +} // namespace sparse +} // namespace raft diff --git a/cpp/test/matrix/select_k_csr.cu b/cpp/test/matrix/select_k_csr.cu new file mode 100644 index 0000000000..b0b24fae08 --- /dev/null +++ b/cpp/test/matrix/select_k_csr.cu @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { + +template +struct SelectKCsrInputs { + index_t n_rows; + index_t n_cols; + index_t top_k; + float sparsity; + bool select_min; + bool customized_indices; +}; + +template +class SelectKCsrTest : public ::testing::TestWithParam> { + public: + SelectKCsrTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + indices_d(0, stream), + customized_indices_d(0, stream), + indptr_d(0, stream), + values_d(0, stream), + dst_values_d(0, stream), + dst_values_expected_d(0, stream), + dst_indices_d(0, stream), + dst_indices_expected_d(0, stream) + { + } + + protected: + index_t create_sparse_matrix(index_t m, index_t n, value_t sparsity, std::vector& matrix) + { + index_t total_elements = static_cast(m * n); + index_t num_ones = static_cast((total_elements * 1.0f) * sparsity); + index_t res = num_ones; + + for (index_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_idx(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis_idx(gen); + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + void convert_to_csr(std::vector& matrix, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + void cpu_select_k(const std::vector& indptr_h, + const std::vector& indices_h, + const std::vector& values_h, + std::optional>& in_idx_h, + index_t n_rows, + index_t n_cols, + index_t top_k, + std::vector& out_values_h, + std::vector& out_indices_h, + bool select_min = true) + { + auto comp = [select_min](const std::pair& a, + const std::pair& b) { + return select_min ? a.first < b.first : a.first >= b.first; + }; + + for (index_t row = 0; row < n_rows; ++row) { + std::priority_queue, + std::vector>, + decltype(comp)> + pq(comp); + + for (index_t idx = indptr_h[row]; idx < indptr_h[row + 1]; ++idx) { + pq.push({values_h[idx], (in_idx_h.has_value()) ? (*in_idx_h)[idx] : indices_h[idx]}); + if (pq.size() > size_t(top_k)) { pq.pop(); } + } + + std::vector> row_pairs; + while (!pq.empty()) { + row_pairs.push_back(pq.top()); + pq.pop(); + } + + if (select_min) { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first <= b.first; + }); + } else { + std::sort(row_pairs.begin(), row_pairs.end(), [](const auto& a, const auto& b) { + return a.first >= b.first; + }); + } + for (index_t col = 0; col < top_k; col++) { + if (col < index_t(row_pairs.size())) { + out_values_h[row * top_k + col] = row_pairs[col].first; + out_indices_h[row * top_k + col] = row_pairs[col].second; + } + } + } + } + + void random_array(value_t* array, size_t size) + { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(-10.0, 10.0); + std::unordered_set uset; + + while (uset.size() < size) { + uset.insert(dis(gen)); + } + typename std::unordered_set::iterator it = uset.begin(); + for (size_t i = 0; i < size; ++i) { + array[i] = *(it++); + } + } + + template + std::optional get_opt_var(data_t x) + { + if (params.customized_indices) { + return x; + } else { + return std::nullopt; + } + } + + void SetUp() override + { + std::vector dense_values_h(params.n_rows * params.n_cols, false); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, dense_values_h); + + std::vector values_h(nnz); + std::vector indices_h(nnz); + std::vector customized_indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + convert_to_csr(dense_values_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + std::vector dst_values_h(params.n_rows * params.top_k, static_cast(2.0f)); + std::vector dst_indices_h(params.n_rows * params.top_k, + static_cast(params.n_rows * params.n_cols * 100)); + + dst_values_d.resize(params.n_rows * params.top_k, stream); + dst_indices_d.resize(params.n_rows * params.top_k, stream); + values_d.resize(nnz, stream); + + update_device(dst_values_d.data(), dst_values_h.data(), dst_values_h.size(), stream); + update_device(dst_indices_d.data(), dst_indices_h.data(), dst_indices_h.size(), stream); + + if (params.customized_indices) { + customized_indices_d.resize(nnz, stream); + update_device(customized_indices_d.data(), + customized_indices_h.data(), + customized_indices_h.size(), + stream); + } + + resource::sync_stream(handle); + + if (values_h.size()) { + random_array(values_h.data(), values_h.size()); + raft::copy(values_d.data(), values_h.data(), values_h.size(), stream); + resource::sync_stream(handle); + } + + auto optional_indices_h = get_opt_var(customized_indices_h); + + cpu_select_k(indptr_h, + indices_h, + values_h, + optional_indices_h, + params.n_rows, + params.n_cols, + params.top_k, + dst_values_h, + dst_indices_h, + params.select_min); + + indices_d.resize(nnz, stream); + indptr_d.resize(params.n_rows + 1, stream); + + dst_values_expected_d.resize(params.n_rows * params.top_k, stream); + dst_indices_expected_d.resize(params.n_rows * params.top_k, stream); + + update_device(values_d.data(), values_h.data(), values_h.size(), stream); + update_device(indices_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_d.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(dst_values_expected_d.data(), dst_values_h.data(), dst_values_h.size(), stream); + update_device( + dst_indices_expected_d.data(), dst_indices_h.data(), dst_indices_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto in_val_structure = raft::make_device_compressed_structure_view( + indptr_d.data(), + indices_d.data(), + params.n_rows, + params.n_cols, + static_cast(indices_d.size())); + + auto in_val = + raft::make_device_csr_matrix_view(values_d.data(), in_val_structure); + + std::optional> in_idx; + + in_idx = get_opt_var( + raft::make_device_vector_view(customized_indices_d.data(), nnz)); + + auto out_val = raft::make_device_matrix_view( + dst_values_d.data(), params.n_rows, params.top_k); + auto out_idx = raft::make_device_matrix_view( + dst_indices_d.data(), params.n_rows, params.top_k); + + raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + + ASSERT_TRUE(raft::devArrMatch(dst_values_expected_d.data(), + out_val.data_handle(), + params.n_rows * params.top_k, + raft::CompareApprox(1e-6f), + stream)); + + ASSERT_TRUE(raft::devArrMatch(dst_indices_expected_d.data(), + out_idx.data_handle(), + params.n_rows * params.top_k, + raft::Compare(), + stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + SelectKCsrInputs params; + + index_t nnz; + + rmm::device_uvector values_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector customized_indices_d; + + rmm::device_uvector dst_values_d; + rmm::device_uvector dst_values_expected_d; + + rmm::device_uvector dst_indices_d; + rmm::device_uvector dst_indices_expected_d; +}; + +using SelectKCsrTest_float_int = SelectKCsrTest; +TEST_P(SelectKCsrTest_float_int, Result) { Run(); } + +using SelectKCsrTest_double_int64 = SelectKCsrTest; +TEST_P(SelectKCsrTest_double_int64, Result) { Run(); } + +template +const std::vector> selectk_inputs = { + {10, 32, 10, 0.0, true, false}, + {10, 32, 10, 0.0, true, true}, + {10, 32, 10, 0.01, true, false}, + {10, 32, 10, 0.1, true, true}, + {10, 32, 251, 0.1, true, false}, + {10, 32, 251, 0.6, true, true}, + {1024, 1024, 258, 0.3, true, false}, + {1024, 1024, 600, 0.2, true, true}, + {100, 1024 * 1000, 251, 0.1, true, false}, + {100, 1024 * 1000, 251, 0.2, true, true}, + {1024, 1024 * 10, 251, 0.3, true, false}, + {1024, 1024 * 10, 251, 0.2, true, true}, + {2048, 1024 * 10, 1000, 0.2, true, false}, + {2048, 1024 * 10, 1000, 0.3, true, true}, + {2048, 1024 * 10, 2100, 0.1, true, false}, + {2048, 1024 * 10, 2100, 0.2, true, true}}; + +INSTANTIATE_TEST_CASE_P(SelectKCsrTest, + SelectKCsrTest_float_int, + ::testing::ValuesIn(selectk_inputs)); +INSTANTIATE_TEST_CASE_P(SelectKCsrTest, + SelectKCsrTest_double_int64, + ::testing::ValuesIn(selectk_inputs)); + +} // namespace sparse +} // namespace raft From 435286a6880a3b2f6c0d763f3d19e687d2f9f7b8 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 31 Jan 2024 09:24:38 -0800 Subject: [PATCH 02/12] add more comments on the select_k API --- cpp/include/raft/matrix/select_k.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 5c20227d23..7df1430455 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -123,6 +123,8 @@ void select_k(raft::resources const& handle, * This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`, * selecting the k smallest or largest elements from each row. The selected elements are then stored * in a row-major output matrix `out_val` with dimensions `batch_size` x k. + * If the total number of values in a row is less than K, then the extra position in the + * corresponding row of out_val will maintain the original value. This applies to out_idx * * @tparam T * Type of the elements being compared (keys). From 977ccee26674a72145c1879bffecb3a991e92d9a Mon Sep 17 00:00:00 2001 From: hrong Date: Mon, 4 Mar 2024 13:04:42 -0800 Subject: [PATCH 03/12] remove mr argument --- cpp/include/raft/matrix/detail/matrix.cuh | 6 +++--- cpp/include/raft/matrix/detail/select_k-ext.cuh | 6 ++---- cpp/include/raft/matrix/detail/select_k-inl.cuh | 14 +++++--------- cpp/src/matrix/detail/select_k_double_int64_t.cu | 3 +-- cpp/src/matrix/detail/select_k_double_uint32_t.cu | 3 +-- cpp/src/matrix/detail/select_k_float_int32.cu | 3 +-- cpp/src/matrix/detail/select_k_float_int64_t.cu | 3 +-- cpp/src/matrix/detail/select_k_float_uint32_t.cu | 3 +-- cpp/src/matrix/detail/select_k_half_int64_t.cu | 3 +-- cpp/src/matrix/detail/select_k_half_uint32_t.cu | 3 +-- 10 files changed, 17 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index f0cb4aaef8..69d38a68da 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -33,6 +33,7 @@ #include #include +#include namespace raft { namespace matrix { @@ -349,9 +350,8 @@ void segmented_copy(raft::resources const& handle, idx_t tpb = max_len_per_row >= 256 ? SEGMENTED_COPY_TPB_256 : SEGMENTED_COPY_TPB_32; - int dev_id, sm_count, blocks_per_sm; - cudaGetDevice(&dev_id); - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + int blocks_per_sm; + int sm_count = resource::get_device_properties(handle).multiProcessorCount; if (tpb == SEGMENTED_COPY_TPB_32) { cudaOccupancyMaxActiveBlocksPerMultiprocessor( diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 572abf6564..23ab3113e5 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -51,8 +51,7 @@ void select_k(raft::resources const& handle, std::optional> in_idx, raft::device_matrix_view out_val, raft::device_matrix_view out_idx, - bool select_min, - rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + bool select_min) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -89,8 +88,7 @@ instantiate_raft_matrix_detail_select_k(double, uint32_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index d1b50dd2db..86d835a0f0 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -348,9 +348,6 @@ void select_k(raft::resources const& handle, * Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`. * @param[in] select_min * Flag indicating whether to select the k smallest (true) or largest (false) elements. - * @param[in] mr - * An optional memory resource to use across the calls (you can provide a large enough - * memory pool here to avoid memory allocations within the call). */ template void select_k(raft::resources const& handle, @@ -358,8 +355,7 @@ void select_k(raft::resources const& handle, std::optional> in_idx, raft::device_matrix_view out_val, raft::device_matrix_view out_idx, - bool select_min, - rmm::mr::device_memory_resource* mr = nullptr) + bool select_min) { auto csr_view = in_val.structure_view(); auto nnz = csr_view.get_nnz(); @@ -370,7 +366,7 @@ void select_k(raft::resources const& handle, auto len = csr_view.get_n_cols(); auto k = IdxT(out_val.extent(1)); - if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } + auto mr = resource::get_workspace_resource(handle); RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), "output k must fit the int type."); @@ -385,9 +381,9 @@ void select_k(raft::resources const& handle, auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector offsets(batch_size + 1, stream); - rmm::device_uvector keys(nnz, stream); - rmm::device_uvector values(nnz, stream); + rmm::device_uvector offsets(batch_size + 1, stream, mr); + rmm::device_uvector keys(nnz, stream, mr); + rmm::device_uvector values(nnz, stream, mr); raft::copy(offsets.data(), csr_view.get_indptr().data(), batch_size + 1, stream); raft::copy(keys.data(), in_val.get_elements().data(), nnz, stream); diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index 5ed77fed3c..f90d518f71 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -40,8 +40,7 @@ instantiate_raft_matrix_detail_select_k(double, int64_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 5cf075f0e7..b88e81e2e7 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -42,8 +42,7 @@ instantiate_raft_matrix_detail_select_k(double, uint32_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 90613b696f..2ba7d41146 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -40,8 +40,7 @@ instantiate_raft_matrix_detail_select_k(float, int); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 771a851e82..c62121d70e 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -40,8 +40,7 @@ instantiate_raft_matrix_detail_select_k(float, int64_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(float, uint64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index f337574b8c..6b5cb6927d 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -40,8 +40,7 @@ instantiate_raft_matrix_detail_select_k(float, uint32_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index 0a4dd0f668..78a7cb7a7e 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -40,8 +40,7 @@ instantiate_raft_matrix_detail_select_k(__half, int64_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index f546984690..58c1668bf1 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -40,8 +40,7 @@ instantiate_raft_matrix_detail_select_k(__half, uint32_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr) + bool select_min) instantiate_raft_matrix_detail_select_k(__half, uint32_t); From bc544e3d82bf99ae5fd7b8a1629ba79f427c9c97 Mon Sep 17 00:00:00 2001 From: rhdong Date: Fri, 15 Mar 2024 15:24:04 -0700 Subject: [PATCH 04/12] fix format issue --- cpp/bench/prims/matrix/select_k_csr.cu | 25 +++++++++---------- cpp/include/raft/matrix/detail/matrix.cuh | 2 +- .../raft/matrix/detail/select_k-inl.cuh | 4 +-- cpp/test/matrix/copy.cu | 4 ++- cpp/test/matrix/select_k_csr.cu | 3 ++- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/cpp/bench/prims/matrix/select_k_csr.cu b/cpp/bench/prims/matrix/select_k_csr.cu index 99c59f4fde..0282f873c2 100644 --- a/cpp/bench/prims/matrix/select_k_csr.cu +++ b/cpp/bench/prims/matrix/select_k_csr.cu @@ -14,28 +14,27 @@ * limitations under the License. */ #include -#include -#include - -#include -#include - -#include -#include - -#include -#include -#include -#include #include #include #include +#include +#include +#include #include #include #include #include +#include #include +#include + +#include + +#include +#include +#include +#include namespace raft::bench::sparse { diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 69d38a68da..a9109d37ba 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include -#include namespace raft { namespace matrix { diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 86d835a0f0..e2490cca0b 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -17,8 +17,6 @@ #pragma once -#include - #include "select_radix.cuh" #include "select_warpsort.cuh" @@ -35,6 +33,8 @@ #include +#include + namespace raft::matrix::detail { /** diff --git a/cpp/test/matrix/copy.cu b/cpp/test/matrix/copy.cu index 58bc8970a4..adeeae73f5 100644 --- a/cpp/test/matrix/copy.cu +++ b/cpp/test/matrix/copy.cu @@ -15,7 +15,7 @@ */ #include "../test_utils.cuh" -#include + #include #include #include @@ -24,6 +24,8 @@ #include #include +#include + #include namespace raft { diff --git a/cpp/test/matrix/select_k_csr.cu b/cpp/test/matrix/select_k_csr.cu index b0b24fae08..ece773ee97 100644 --- a/cpp/test/matrix/select_k_csr.cu +++ b/cpp/test/matrix/select_k_csr.cu @@ -15,7 +15,6 @@ */ #include "../test_utils.cuh" -#include #include #include @@ -28,6 +27,8 @@ #include #include +#include + #include #include #include From 558b69e7e8bbe103f2e1e8387d827338596d1a91 Mon Sep 17 00:00:00 2001 From: rhdong Date: Fri, 29 Mar 2024 23:59:33 -0700 Subject: [PATCH 05/12] Optimizing the performance by reusing the dense `select_k` --- cpp/bench/prims/matrix/select_k_csr.cu | 54 ++- .../raft/matrix/detail/select_k-ext.cuh | 8 +- .../raft/matrix/detail/select_k-inl.cuh | 160 +++++-- .../raft/matrix/detail/select_radix.cuh | 427 ++++++++++-------- .../raft/matrix/detail/select_warpsort.cuh | 55 ++- cpp/include/raft/matrix/select_k.cuh | 7 +- .../matrix/detail/select_k_double_int64_t.cu | 4 +- .../matrix/detail/select_k_double_uint32_t.cu | 4 +- cpp/src/matrix/detail/select_k_float_int32.cu | 4 +- .../matrix/detail/select_k_float_int64_t.cu | 4 +- .../matrix/detail/select_k_float_uint32_t.cu | 4 +- .../matrix/detail/select_k_half_int64_t.cu | 4 +- .../matrix/detail/select_k_half_uint32_t.cu | 4 +- .../matrix/select_k_float_int64_t.cu | 5 +- cpp/test/matrix/select_k_csr.cu | 40 +- 15 files changed, 521 insertions(+), 263 deletions(-) diff --git a/cpp/bench/prims/matrix/select_k_csr.cu b/cpp/bench/prims/matrix/select_k_csr.cu index 0282f873c2..4ab706f471 100644 --- a/cpp/bench/prims/matrix/select_k_csr.cu +++ b/cpp/bench/prims/matrix/select_k_csr.cu @@ -51,8 +51,7 @@ struct bench_param { template inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& { - os << " rows*cols=" << params.n_rows << "*" << params.n_cols << "\ttop_k=" << params.top_k - << "\tsparsity=" << params.sparsity; + os << params.n_rows << "#" << params.n_cols << "#" << params.top_k << "#" << params.sparsity; return os; } @@ -69,7 +68,7 @@ struct SelectKCsrTest : public fixture { dst_values_d(0, stream), dst_indices_d(0, stream) { - std::vector dense_values_h(params.n_rows * params.n_cols, false); + std::vector dense_values_h(params.n_rows * params.n_cols); nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, dense_values_h); std::vector indices_h(nnz); @@ -207,7 +206,7 @@ struct SelectKCsrTest : public fixture { raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); resource::sync_stream(handle); loop_on_state(state, [this, &in_val, &in_idx, &out_val, &out_idx]() { - raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min, false); resource::sync_stream(handle); }); } @@ -235,22 +234,53 @@ const std::vector> getInputs() index_t m; index_t n; index_t k; - float sparsity; }; - const std::vector params_group = - raft::util::itertools::product({index_t(10), index_t(1024)}, - {index_t(1024 * 10), index_t(1024 * 1024)}, - {index_t(128), index_t(100), index_t(2048)}, - {0.1f, 0.2f, 0.5f}); + const std::vector params_group{ + {20000, 500, 1}, {20000, 500, 2}, {20000, 500, 4}, {20000, 500, 8}, + {20000, 500, 16}, {20000, 500, 32}, {20000, 500, 64}, {20000, 500, 128}, + {20000, 500, 256}, + + {1000, 10000, 1}, {1000, 10000, 2}, {1000, 10000, 4}, {1000, 10000, 8}, + {1000, 10000, 16}, {1000, 10000, 32}, {1000, 10000, 64}, {1000, 10000, 128}, + {1000, 10000, 256}, + + {100, 100000, 1}, {100, 100000, 2}, {100, 100000, 4}, {100, 100000, 8}, + {100, 100000, 16}, {100, 100000, 32}, {100, 100000, 64}, {100, 100000, 128}, + {100, 100000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}}; param_vec.reserve(params_group.size()); for (TestParams params : params_group) { - param_vec.push_back(bench_param({params.m, params.n, params.k, params.sparsity})); + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.1})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.2})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.5})); } return param_vec; } -RAFT_BENCH_REGISTER((SelectKCsrTest), "", getInputs()); +RAFT_BENCH_REGISTER((SelectKCsrTest), "", getInputs()); } // namespace raft::bench::sparse diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 1097935e3b..95d806dd43 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -52,7 +52,9 @@ void select_k(raft::resources const& handle, std::optional> in_idx, raft::device_matrix_view out_val, raft::device_matrix_view out_idx, - bool select_min) RAFT_EXPLICIT; + bool select_min, + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -90,7 +92,9 @@ instantiate_raft_matrix_detail_select_k(double, uint32_t); std::optional> in_idx, \ raft::device_matrix_view out_val, \ raft::device_matrix_view out_idx, \ - bool select_min) + bool select_min, \ + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 7683c03283..bcf00db709 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -324,9 +324,9 @@ void select_k(raft::resources const& handle, } /** - * Selects the k smallest or largest keys/values from each row of the input matrix. + * Selects the k smallest or largest keys/values from each row of the input CSR matrix. * - * This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`, + * This function operates on a CSR matrix `in_val` with a logical dense shape of [batch_size, len], * selecting the k smallest or largest elements from each row. The selected elements are then stored * in a row-major output matrix `out_val` with dimensions `batch_size` x k. * @@ -352,6 +352,10 @@ void select_k(raft::resources const& handle, * Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`. * @param[in] select_min * Flag indicating whether to select the k smallest (true) or largest (false) elements. + * @param[in] sorted + * whether to make sure selected pairs are sorted by value + * @param[in] algo + * the selection algorithm to use */ template void select_k(raft::resources const& handle, @@ -359,7 +363,9 @@ void select_k(raft::resources const& handle, std::optional> in_idx, raft::device_matrix_view out_val, raft::device_matrix_view out_idx, - bool select_min) + bool select_min, + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) { auto csr_view = in_val.structure_view(); auto nnz = csr_view.get_nnz(); @@ -383,33 +389,127 @@ void select_k(raft::resources const& handle, } RAFT_EXPECTS(IdxT(k) == out_idx.extent(1), "value and index output lengths must be equal"); - auto stream = raft::resource::get_cuda_stream(handle); - - rmm::device_uvector offsets(batch_size + 1, stream, mr); - rmm::device_uvector keys(nnz, stream, mr); - rmm::device_uvector values(nnz, stream, mr); - - raft::copy(offsets.data(), csr_view.get_indptr().data(), batch_size + 1, stream); - raft::copy(keys.data(), in_val.get_elements().data(), nnz, stream); - raft::copy(values.data(), - (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), - nnz, - stream); - - segmented_sort_by_key(handle, - keys.data(), - values.data(), - size_t(batch_size), - size_t(nnz), - offsets.data(), - select_min); - - auto src_val = raft::make_device_vector_view(keys.data(), nnz); - auto offsets_view = raft::make_device_vector_view(offsets.data(), batch_size + 1); - raft::matrix::segmented_copy(handle, k, src_val, offsets_view, out_val); - - auto src_idx = raft::make_device_vector_view(values.data(), nnz); - raft::matrix::segmented_copy(handle, k, src_idx, offsets_view, out_idx); + if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); } + + auto indptr = csr_view.get_indptr().data(); + + switch (algo) { + case SelectAlgo::kRadix8bits: + case SelectAlgo::kRadix11bits: + case SelectAlgo::kRadix11bitsExtraPass: { + if (algo == SelectAlgo::kRadix8bits) { + detail::select::radix::select_k( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + true, + indptr); + } else { + bool fused_last_filter = algo == SelectAlgo::kRadix11bits; + detail::select::radix::select_k( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + fused_last_filter, + indptr); + } + + if (sorted) { + auto offsets = make_device_mdarray( + handle, resource::get_workspace_resource(handle), make_extents(batch_size + 1)); + raft::linalg::map_offset(handle, offsets.view(), mul_const_op(k)); + + auto keys = + raft::make_device_vector_view(out_val.data_handle(), (IdxT)(batch_size * k)); + auto vals = + raft::make_device_vector_view(out_idx.data_handle(), (IdxT)(batch_size * k)); + + segmented_sort_by_key( + handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min); + } + + return; + } + case SelectAlgo::kWarpDistributed: + return detail::select::warpsort:: + select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpDistributedShm: + return detail::select::warpsort:: + select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpAuto: + return detail::select::warpsort::select_k( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpImmediate: + return detail::select::warpsort:: + select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + case SelectAlgo::kWarpFiltered: + return detail::select::warpsort:: + select_k_impl( + handle, + in_val.get_elements().data(), + (in_idx.has_value() ? in_idx->data_handle() : csr_view.get_indices().data()), + batch_size, + len, + k, + out_val.data_handle(), + out_idx.data_handle(), + select_min, + indptr); + default: RAFT_FAIL("K-selection Algorithm not supported."); + } + + return; } } // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 36a346fda3..83d4845c31 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -442,14 +442,76 @@ _RAFT_DEVICE void last_filter(const T* in_buf, } } -template +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + in_buf = reinterpret_cast(bufs); + in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + out_buf = const_cast(in_buf + buf_len); + out_idx_buf = const_cast(in_idx_buf + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + in_buf = out_buf + buf_len; + in_idx_buf = out_idx_buf + buf_len; + } +} + +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + const int pass, + const T*& out_buf, + const IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + out_buf = const_cast(reinterpret_cast(bufs) + buf_len); + out_idx_buf = + const_cast(reinterpret_cast(bufs + sizeof(T) * 2 * buf_len) + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } +} + +template RAFT_KERNEL last_filter_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, const IdxT len, + const IdxT* len_i, const IdxT k, Counter* counters, const bool select_min) @@ -458,22 +520,31 @@ RAFT_KERNEL last_filter_kernel(const T* in, Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + + const IdxT l_len = len_or_indptr ? len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + const IdxT buf_len = calc_buf_len(len); - if (previous_len > buf_len || in_buf == in) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; - } - out += batch_id * k; - out_idx += batch_id * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); constexpr int pass = calc_num_passes() - 1; constexpr int start_bit = calc_start_bit(pass); + set_buf_pointers(in + l_offset, in_idx + l_offset, bufs, buf_len, pass, in_buf, in_idx_buf); + + if (previous_len > buf_len || in_buf == in + l_offset) { + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; + } + out += batch_id * k; + out_idx += batch_id * k; + const auto kth_value_bits = counter->kth_value_bits; const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; @@ -510,6 +581,29 @@ RAFT_KERNEL last_filter_kernel(const T* in, f); } +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_val( + T* dest, const T* src, S len, IdxT k, const bool select_min) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + const T default_val = select_min ? upper_bound() : lower_bound(); + for (S i = idx; i < k; i += stride) { + dest[i] = i < len ? src[i] : default_val; + } +} + +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_idx(T* dest, const T* src, S len) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + + for (S i = idx; i < len; i += stride) { + dest[i] = src ? src[i] : i; + } +} + /** * * It is expected to call this kernel multiple times (passes), in each pass we process a radix, @@ -545,13 +639,16 @@ RAFT_KERNEL last_filter_kernel(const T* in, * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and * their indices. */ -template +template RAFT_KERNEL radix_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, Counter* counters, @@ -567,21 +664,38 @@ RAFT_KERNEL radix_kernel(const T* in, IdxT current_k; IdxT previous_len; IdxT current_len; + + const IdxT l_len = len_or_indptr ? len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + if (pass == 0) { current_k = k; - previous_len = len; + previous_len = l_len; // Need to do this so setting counter->previous_len for the next pass is correct. // This value is meaningless for pass 0, but it's fine because pass 0 won't be the // last pass in this implementation so pass 0 won't hit the "if (pass == // num_passes - 1)" branch. // Maybe it's better to reload counter->previous_len and use it rather than // current_len in last_filter() - current_len = len; + current_len = l_len; } else { current_k = counter->k; current_len = counter->len; previous_len = counter->previous_len; } + if constexpr (!len_or_indptr) { + if (pass == 0 && l_len <= k) { + copy_in_val(out + batch_id * k, in + l_offset, l_len, k, select_min); + copy_in_idx(out_idx + batch_id * k, (in_idx ? (in_idx + l_offset) : nullptr), l_len); + if (threadIdx.x == 0) { + counter->previous_len = 0; + counter->len = 0; + } + __syncthreads(); + return; + } + } + if (current_len == 0) { return; } // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle @@ -590,20 +704,33 @@ RAFT_KERNEL radix_kernel(const T* in, const bool early_stop = (current_len == current_k); const IdxT buf_len = calc_buf_len(len); + const T* in_buf; + const IdxT* in_idx_buf; + T* out_buf; + IdxT* out_idx_buf; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + set_buf_pointers(in + l_offset, + (in_idx ? (in_idx + l_offset) : nullptr), + bufs, + buf_len, + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + // "previous_len > buf_len" means previous pass skips writing buffer if (pass == 0 || pass == 1 || previous_len > buf_len) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; } // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -611,9 +738,6 @@ RAFT_KERNEL radix_kernel(const T* in, if (pass == 0 || current_len > buf_len) { out_buf = nullptr; out_idx_buf = nullptr; - } else { - out_buf += batch_id * buf_len; - out_idx_buf += batch_id * buf_len; } out += batch_id * k; out_idx += batch_id * k; @@ -640,7 +764,6 @@ RAFT_KERNEL radix_kernel(const T* in, unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); isLastBlock = (finished == (gridDim.x - 1)); } - if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { @@ -676,7 +799,7 @@ RAFT_KERNEL radix_kernel(const T* in, out_idx_buf ? out_idx_buf : in_idx_buf, out, out_idx, - out_buf ? current_len : len, + out_buf ? current_len : l_len, k, counter, select_min, @@ -726,7 +849,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &active_blocks, radix_kernel, BlockSize, 0)); + &active_blocks, radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; IdxT best_num_blocks = 0; @@ -757,78 +880,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) return best_num_blocks; } -template -_RAFT_HOST void set_buf_pointers(const T* in, - const IdxT* in_idx, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = buf1; - out_idx_buf = idx_buf1; - } else if (pass % 2 == 0) { - in_buf = buf1; - in_idx_buf = idx_buf1; - out_buf = buf2; - out_idx_buf = idx_buf2; - } else { - in_buf = buf2; - in_idx_buf = idx_buf2; - out_buf = buf1; - out_idx_buf = idx_buf1; - } -} - -template -_RAFT_DEVICE void set_buf_pointers(const T* in, - const IdxT* in_idx, - char* bufs, - IdxT buf_len, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - } else if (pass % 2 == 0) { - in_buf = reinterpret_cast(bufs); - in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - out_buf = const_cast(in_buf + buf_len); - out_idx_buf = const_cast(in_idx_buf + buf_len); - } else { - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - in_buf = out_buf + buf_len; - in_idx_buf = out_idx_buf + buf_len; - } -} - -template +template void radix_topk(const T* in, const IdxT* in_idx, int batch_size, @@ -850,7 +902,7 @@ void radix_topk(const T* in, if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - auto kernel = radix_kernel; + auto kernel = radix_kernel; const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, false); if (max_chunk_size != static_cast(batch_size)) { @@ -862,55 +914,33 @@ void radix_topk(const T* in, rmm::device_uvector> counters(max_chunk_size, stream, mr); rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); + + rmm::device_uvector bufs( + max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); RAFT_CUDA_TRY( cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - auto kernel = radix_kernel; + auto kernel = radix_kernel; - const T* chunk_in = in + offset * len; - const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; - T* chunk_out = out + offset * k; - IdxT* chunk_out_idx = out_idx + offset * k; - const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; dim3 blocks(grid_dim, chunk_size); constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers(chunk_in, - chunk_in_idx, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data(), - pass, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf); - if (fused_last_filter && pass == num_passes - 1) { - kernel = radix_kernel; + kernel = radix_kernel; } - kernel<<>>(chunk_in, - chunk_in_idx, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, + kernel<<>>(in, + in_idx, + bufs.data(), + offset, chunk_out, chunk_out_idx, counters.data(), @@ -924,16 +954,18 @@ void radix_topk(const T* in, } if (!fused_last_filter) { - last_filter_kernel<<>>(chunk_in, - chunk_in_idx, - out_buf, - out_idx_buf, - chunk_out, - chunk_out_idx, - len, - k, - counters.data(), - select_min); + last_filter_kernel + <<>>(in, + in_idx, + bufs.data(), + offset, + chunk_out, + chunk_out_idx, + len, + chunk_len_i, + k, + counters.data(), + select_min); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } @@ -1015,7 +1047,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, } } -template +template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, @@ -1024,30 +1056,48 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, T* out, IdxT* out_idx, const bool select_min, - char* bufs) + char* bufs, + size_t offset) { constexpr int num_buckets = calc_num_buckets(); __shared__ Counter counter; __shared__ IdxT histogram[num_buckets]; + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + + IdxT l_len = len; + IdxT l_offset = (offset + batch_id) * len; + if constexpr (!len_or_indptr) { + l_offset = len_i[batch_id]; + l_len = len_i[batch_id + 1] - l_offset; + } + if (threadIdx.x == 0) { counter.k = k; - counter.len = len; - counter.previous_len = len; + counter.len = l_len; + counter.previous_len = l_len; counter.kth_value_bits = 0; counter.out_cnt = 0; counter.out_back_cnt = 0; } __syncthreads(); - const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow - in += batch_id * len; - if (in_idx) { in_idx += batch_id * len; } + in += l_offset; + if (in_idx) { in_idx += l_offset; } out += batch_id * k; out_idx += batch_id * k; const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + if constexpr (!len_or_indptr) { + if (l_len <= k) { + copy_in_val(out, in, l_len, k, select_min); + copy_in_idx(out_idx, in_idx, l_len); + __syncthreads(); + return; + } + } + constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { const T* in_buf; @@ -1073,7 +1123,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -1102,7 +1152,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_buf ? out_idx_buf : in_idx, out, out_idx, - out_buf ? current_len : len, + out_buf ? current_len : l_len, k, &counter, select_min, @@ -1117,7 +1167,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // counters and global histograms, can be kept in shared memory and cheap sync operations can be // used. It's used when len is relatively small or when the number of blocks per row calculated by // `calc_grid_dim()` is 1. -template +template void radix_topk_one_block(const T* in, const IdxT* in_idx, int batch_size, @@ -1133,7 +1183,7 @@ void radix_topk_one_block(const T* in, { static_assert(calc_num_passes() > 1); - auto kernel = radix_topk_one_block_kernel; + auto kernel = radix_topk_one_block_kernel; const IdxT buf_len = calc_buf_len(len); const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, true); @@ -1144,15 +1194,16 @@ void radix_topk_one_block(const T* in, for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - kernel<<>>(in + offset * len, - in_idx ? (in_idx + offset * len) : nullptr, + kernel<<>>(in, + in_idx, len, chunk_len_i, k, out + offset * k, out_idx + offset * k, select_min, - bufs.data()); + bufs.data(), + offset); } } @@ -1182,6 +1233,10 @@ void radix_topk_one_block(const T* in, * it affects the number of passes and number of buckets. * @tparam BlockSize * Number of threads in a kernel thread block. + * @tparam len_or_indptr + * Flag to interpret `len_i` as either direct row lengths (true) or CSR format + * index pointers (false). When true, each `len_i` element denotes the length of a row. When + * false, `len_i` represents the index pointers for a CSR matrix with shape of `batch_size + 1`. * * @param[in] res container of reusable resources * @param[in] in @@ -1212,9 +1267,12 @@ void radix_topk_one_block(const T* in, * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param len_i - * optional array of size (batch_size) providing lengths for each individual row + * Optional array used differently based on `len_or_indptr`: + * When `len_or_indptr` is true, `len_i` presents the lengths of each row, which is `batch_size`. + * When `len_or_indptr` is false, `len_i` works like a indptr for a CSR matrix. The length of each + * row would be (`len_i[row_id + 1] - len_i[row_id]`). `len_i` size is `batch_size + 1`. */ -template +template void select_k(raft::resources const& res, const T* in, const IdxT* in_idx, @@ -1227,9 +1285,12 @@ void select_k(raft::resources const& res, bool fused_last_filter, const IdxT* len_i) { + RAFT_EXPECTS(!(!len_or_indptr && (len_i == nullptr)), + "When `len_or_indptr` is false, `len_i` must not be nullptr!"); + auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); - if (k == len) { + if (k == len && len_or_indptr) { RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { @@ -1248,29 +1309,29 @@ void select_k(raft::resources const& res, constexpr int items_per_thread = 32; if (len <= BlockSize * items_per_thread) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { - impl::radix_topk(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - fused_last_filter, - len_i, - grid_dim, - sm_cnt, - stream, - mr); + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + len_i, + grid_dim, + sm_cnt, + stream, + mr); } } } diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index 572558153d..2cb32585d5 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -754,22 +754,32 @@ template