From 9e8ae31d343690c9527c7fba77e1a07ca4016f01 Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 4 Apr 2024 11:25:22 -0700 Subject: [PATCH] move public API to naming scope of sparse --- cpp/bench/prims/CMakeLists.txt | 2 +- .../prims/{matrix => sparse}/select_k_csr.cu | 7 +- cpp/include/raft/matrix/select_k.cuh | 48 ----------- cpp/include/raft/sparse/matrix/select_k.cuh | 81 +++++++++++++++++++ cpp/test/CMakeLists.txt | 2 +- cpp/test/{matrix => sparse}/select_k_csr.cu | 5 +- 6 files changed, 90 insertions(+), 55 deletions(-) rename cpp/bench/prims/{matrix => sparse}/select_k_csr.cu (97%) create mode 100644 cpp/include/raft/sparse/matrix/select_k.cuh rename cpp/test/{matrix => sparse}/select_k_csr.cu (98%) diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 1b28e7d0b9..063d69a737 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -128,7 +128,6 @@ if(BUILD_PRIMS_BENCH) bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu bench/prims/matrix/select_k.cu - bench/prims/matrix/select_k_csr.cu bench/prims/main.cpp OPTIONAL LIB @@ -146,6 +145,7 @@ if(BUILD_PRIMS_BENCH) PATH bench/prims/sparse/bitmap_to_csr.cu bench/prims/sparse/convert_csr.cu + bench/prims/sparse/select_k_csr.cu bench/prims/main.cpp ) diff --git a/cpp/bench/prims/matrix/select_k_csr.cu b/cpp/bench/prims/sparse/select_k_csr.cu similarity index 97% rename from cpp/bench/prims/matrix/select_k_csr.cu rename to cpp/bench/prims/sparse/select_k_csr.cu index 4ab706f471..a91e6c8514 100644 --- a/cpp/bench/prims/matrix/select_k_csr.cu +++ b/cpp/bench/prims/sparse/select_k_csr.cu @@ -22,10 +22,10 @@ #include #include #include -#include #include #include #include +#include #include #include @@ -203,10 +203,11 @@ struct SelectKCsrTest : public fixture { auto out_idx = raft::make_device_matrix_view( dst_indices_d.data(), params.n_rows, params.top_k); - raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + raft::sparse::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); resource::sync_stream(handle); loop_on_state(state, [this, &in_val, &in_idx, &out_val, &out_idx]() { - raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min, false); + raft::sparse::matrix::select_k( + handle, in_val, in_idx, out_val, out_idx, params.select_min, false); resource::sync_stream(handle); }); } diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 6a8cab5c90..2efa146495 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -18,7 +18,6 @@ #include "detail/select_k.cuh" -#include #include #include #include @@ -117,53 +116,6 @@ void select_k(raft::resources const& handle, algo); } -/** - * Selects the k smallest or largest keys/values from each row of the input matrix. - * - * This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`, - * selecting the k smallest or largest elements from each row. The selected elements are then stored - * in a row-major output matrix `out_val` with dimensions `batch_size` x k. - * If the total number of values in a row is less than K, then the extra position in the - * corresponding row of out_val will maintain the original value. This applies to out_idx - * - * @tparam T - * Type of the elements being compared (keys). - * @tparam IdxT - * Type of the indices associated with the keys. - * - * @param[in] handle - * Container for managing reusable resources. - * @param[in] in_val - * Input matrix in CSR format with a logical dense shape of [batch_size, len], - * containing the elements to be compared and selected. - * @param[in] in_idx - * Optional input indices [in_val.nnz] associated with `in_val.values`. - * If `in_idx` is `std::nullopt`, it defaults to a contiguous array from 0 to len-1. - * @param[out] out_val - * Output matrix [in_val.get_n_row(), k] storing the selected k smallest/largest elements - * from each row of `in_val`. - * @param[out] out_idx - * Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`. - * @param[in] select_min - * Flag indicating whether to select the k smallest (true) or largest (false) elements. - * @param[in] sorted - * whether to make sure selected pairs are sorted by value - * @param[in] algo - * the selection algorithm to use - */ -template -void select_k(raft::resources const& handle, - raft::device_csr_matrix_view in_val, - std::optional> in_idx, - raft::device_matrix_view out_val, - raft::device_matrix_view out_idx, - bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) -{ - return detail::select_k( - handle, in_val, in_idx, out_val, out_idx, select_min, sorted, algo); -} /** @} */ // end of group select_k } // namespace raft::matrix diff --git a/cpp/include/raft/sparse/matrix/select_k.cuh b/cpp/include/raft/sparse/matrix/select_k.cuh new file mode 100644 index 0000000000..f6c8bbe0c7 --- /dev/null +++ b/cpp/include/raft/sparse/matrix/select_k.cuh @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::sparse::matrix { + +using SelectAlgo = raft::matrix::SelectAlgo; +/** + * Selects the k smallest or largest keys/values from each row of the input matrix. + * + * This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`, + * selecting the k smallest or largest elements from each row. The selected elements are then stored + * in a row-major output matrix `out_val` with dimensions `batch_size` x k. + * If the total number of values in a row is less than K, then the extra position in the + * corresponding row of out_val will maintain the original value. This applies to out_idx + * + * @tparam T + * Type of the elements being compared (keys). + * @tparam IdxT + * Type of the indices associated with the keys. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] in_val + * Input matrix in CSR format with a logical dense shape of [batch_size, len], + * containing the elements to be compared and selected. + * @param[in] in_idx + * Optional input indices [in_val.nnz] associated with `in_val.values`. + * If `in_idx` is `std::nullopt`, it defaults to a contiguous array from 0 to len-1. + * @param[out] out_val + * Output matrix [in_val.get_n_row(), k] storing the selected k smallest/largest elements + * from each row of `in_val`. + * @param[out] out_idx + * Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`. + * @param[in] select_min + * Flag indicating whether to select the k smallest (true) or largest (false) elements. + * @param[in] sorted + * whether to make sure selected pairs are sorted by value + * @param[in] algo + * the selection algorithm to use + */ +template +void select_k(raft::resources const& handle, + raft::device_csr_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min, + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) +{ + return raft::matrix::detail::select_k( + handle, in_val, in_idx, out_val, out_idx, select_min, sorted, algo); +} +/** @} */ // end of group select_k + +} // namespace raft::sparse::matrix diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index be7e469da6..17990700e6 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -279,7 +279,6 @@ if(BUILD_TESTS) NAME MATRIX_SELECT_TEST PATH test/matrix/select_k.cu - PATH test/matrix/select_k_csr.cu LIB EXPLICIT_INSTANTIATE_ONLY) @@ -326,6 +325,7 @@ if(BUILD_TESTS) test/sparse/reduce.cu test/sparse/row_op.cu test/sparse/sddmm.cu + test/sparse/select_k_csr.cu test/sparse/sort.cu test/sparse/spgemmi.cu test/sparse/spmm.cu diff --git a/cpp/test/matrix/select_k_csr.cu b/cpp/test/sparse/select_k_csr.cu similarity index 98% rename from cpp/test/matrix/select_k_csr.cu rename to cpp/test/sparse/select_k_csr.cu index ed58e6d80d..fc1061d7bb 100644 --- a/cpp/test/matrix/select_k_csr.cu +++ b/cpp/test/sparse/select_k_csr.cu @@ -22,9 +22,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -298,7 +298,8 @@ class SelectKCsrTest : public ::testing::TestWithParam auto out_idx = raft::make_device_matrix_view( dst_indices_d.data(), params.n_rows, params.top_k); - raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min, true); + raft::sparse::matrix::select_k( + handle, in_val, in_idx, out_val, out_idx, params.select_min, true); ASSERT_TRUE(raft::devArrMatch(dst_indices_expected_d.data(), out_idx.data_handle(),