Skip to content

Commit

Permalink
move public API to naming scope of sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Apr 4, 2024
1 parent 96ded4d commit 9e8ae31
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 55 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ if(BUILD_PRIMS_BENCH)
bench/prims/matrix/argmin.cu
bench/prims/matrix/gather.cu
bench/prims/matrix/select_k.cu
bench/prims/matrix/select_k_csr.cu
bench/prims/main.cpp
OPTIONAL
LIB
Expand All @@ -146,6 +145,7 @@ if(BUILD_PRIMS_BENCH)
PATH
bench/prims/sparse/bitmap_to_csr.cu
bench/prims/sparse/convert_csr.cu
bench/prims/sparse/select_k_csr.cu
bench/prims/main.cpp
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/copy.cuh>
#include <raft/matrix/select_k.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/matrix/select_k.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/itertools.hpp>

Expand Down Expand Up @@ -203,10 +203,11 @@ struct SelectKCsrTest : public fixture {
auto out_idx = raft::make_device_matrix_view<index_t, index_t, raft::row_major>(
dst_indices_d.data(), params.n_rows, params.top_k);

raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min);
raft::sparse::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min);
resource::sync_stream(handle);
loop_on_state(state, [this, &in_val, &in_idx, &out_val, &out_idx]() {
raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min, false);
raft::sparse::matrix::select_k(
handle, in_val, in_idx, out_val, out_idx, params.select_min, false);
resource::sync_stream(handle);
});
}
Expand Down
48 changes: 0 additions & 48 deletions cpp/include/raft/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

#include "detail/select_k.cuh"

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
Expand Down Expand Up @@ -117,53 +116,6 @@ void select_k(raft::resources const& handle,
algo);
}

/**
* Selects the k smallest or largest keys/values from each row of the input matrix.
*
* This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`,
* selecting the k smallest or largest elements from each row. The selected elements are then stored
* in a row-major output matrix `out_val` with dimensions `batch_size` x k.
* If the total number of values in a row is less than K, then the extra position in the
* corresponding row of out_val will maintain the original value. This applies to out_idx
*
* @tparam T
* Type of the elements being compared (keys).
* @tparam IdxT
* Type of the indices associated with the keys.
*
* @param[in] handle
* Container for managing reusable resources.
* @param[in] in_val
* Input matrix in CSR format with a logical dense shape of [batch_size, len],
* containing the elements to be compared and selected.
* @param[in] in_idx
* Optional input indices [in_val.nnz] associated with `in_val.values`.
* If `in_idx` is `std::nullopt`, it defaults to a contiguous array from 0 to len-1.
* @param[out] out_val
* Output matrix [in_val.get_n_row(), k] storing the selected k smallest/largest elements
* from each row of `in_val`.
* @param[out] out_idx
* Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`.
* @param[in] select_min
* Flag indicating whether to select the k smallest (true) or largest (false) elements.
* @param[in] sorted
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> in_val,
std::optional<raft::device_vector_view<const IdxT, IdxT>> in_idx,
raft::device_matrix_view<T, IdxT, raft::row_major> out_val,
raft::device_matrix_view<IdxT, IdxT, raft::row_major> out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
{
return detail::select_k<T, IdxT>(
handle, in_val, in_idx, out_val, out_idx, select_min, sorted, algo);
}
/** @} */ // end of group select_k

} // namespace raft::matrix
81 changes: 81 additions & 0 deletions cpp/include/raft/sparse/matrix/select_k.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/detail/select_k.cuh>
#include <raft/matrix/select_k_types.hpp>

#include <optional>

namespace raft::sparse::matrix {

using SelectAlgo = raft::matrix::SelectAlgo;
/**
* Selects the k smallest or largest keys/values from each row of the input matrix.
*
* This function operates on a row-major matrix `in_val` with dimensions `batch_size` x `len`,
* selecting the k smallest or largest elements from each row. The selected elements are then stored
* in a row-major output matrix `out_val` with dimensions `batch_size` x k.
* If the total number of values in a row is less than K, then the extra position in the
* corresponding row of out_val will maintain the original value. This applies to out_idx
*
* @tparam T
* Type of the elements being compared (keys).
* @tparam IdxT
* Type of the indices associated with the keys.
*
* @param[in] handle
* Container for managing reusable resources.
* @param[in] in_val
* Input matrix in CSR format with a logical dense shape of [batch_size, len],
* containing the elements to be compared and selected.
* @param[in] in_idx
* Optional input indices [in_val.nnz] associated with `in_val.values`.
* If `in_idx` is `std::nullopt`, it defaults to a contiguous array from 0 to len-1.
* @param[out] out_val
* Output matrix [in_val.get_n_row(), k] storing the selected k smallest/largest elements
* from each row of `in_val`.
* @param[out] out_idx
* Output indices [in_val.get_n_row(), k] corresponding to the selected elements in `out_val`.
* @param[in] select_min
* Flag indicating whether to select the k smallest (true) or largest (false) elements.
* @param[in] sorted
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
raft::device_csr_matrix_view<const T, IdxT, IdxT, IdxT> in_val,
std::optional<raft::device_vector_view<const IdxT, IdxT>> in_idx,
raft::device_matrix_view<T, IdxT, raft::row_major> out_val,
raft::device_matrix_view<IdxT, IdxT, raft::row_major> out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
{
return raft::matrix::detail::select_k<T, IdxT>(
handle, in_val, in_idx, out_val, out_idx, select_min, sorted, algo);
}
/** @} */ // end of group select_k

} // namespace raft::sparse::matrix
2 changes: 1 addition & 1 deletion cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ if(BUILD_TESTS)
NAME
MATRIX_SELECT_TEST
PATH test/matrix/select_k.cu
PATH test/matrix/select_k_csr.cu
LIB
EXPLICIT_INSTANTIATE_ONLY)

Expand Down Expand Up @@ -326,6 +325,7 @@ if(BUILD_TESTS)
test/sparse/reduce.cu
test/sparse/row_op.cu
test/sparse/sddmm.cu
test/sparse/select_k_csr.cu
test/sparse/sort.cu
test/sparse/spgemmi.cu
test/sparse/spmm.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/copy.cuh>
#include <raft/matrix/select_k.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/sparse/matrix/select_k.cuh>
#include <raft/util/cuda_utils.cuh>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -298,7 +298,8 @@ class SelectKCsrTest : public ::testing::TestWithParam<SelectKCsrInputs<index_t>
auto out_idx = raft::make_device_matrix_view<index_t, index_t, raft::row_major>(
dst_indices_d.data(), params.n_rows, params.top_k);

raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min, true);
raft::sparse::matrix::select_k(
handle, in_val, in_idx, out_val, out_idx, params.select_min, true);

ASSERT_TRUE(raft::devArrMatch<index_t>(dst_indices_expected_d.data(),
out_idx.data_handle(),
Expand Down

0 comments on commit 9e8ae31

Please sign in to comment.