Skip to content

Commit

Permalink
Add mdspan input API, fix cmakelists
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Mar 13, 2024
1 parent cc2cf24 commit eb7e6d1
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 392 deletions.
9 changes: 6 additions & 3 deletions cpp/include/raft/matrix/detail/sample_rows.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/gather.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft::matrix {
namespace raft::matrix::detail {

/** Select rows randomly from input and copy to output. */
template <typename T, typename IdxT = int64_t>
void sample_rows(raft::resources const& res,
const T* input,
IdxT n_rows_input,
raft::device_matrix_view<T, IdxT> output,
RngState random_state)
random::RngState random_state)
{
IdxT n_dim = output.extent(1);
IdxT n_samples = output.extent(0);
Expand All @@ -51,4 +54,4 @@ void sample_rows(raft::resources const& res,
raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output);
}
}
} // namespace raft::matrix
} // namespace raft::matrix::detail
36 changes: 19 additions & 17 deletions cpp/include/raft/matrix/sample_rows.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,38 @@

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/matrix/gather.cuh>
#include <raft/core/resources.hpp>
#include <raft/matrix/detail/sample_rows.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft::matrix {

/** Select rows randomly from input and copy to output. */
template <typename T, typename IdxT = int64_t>
template <typename T, typename IdxT = int64_t, typename accessor>
void sample_rows(raft::resources const& res,
const T* input,
IdxT n_rows_input,
raft::device_matrix_view<T, IdxT> output,
RngState random_state)
random::RngState random_state,
mdspan<const T, matrix_extent<int64_t>, row_major, accessor> dataset,
raft::device_matrix_view<T, IdxT> output)
{
detail::sample_rows(res, input, n_rows_input, output, random_state);

detail::sample_rows(res, dataset.data_handle(), dataset.extent(0), output, random_state);
}

/** Subsample the dataset to create a training set*/
template <typename T, typename IdxT = int64_t>
raft::device_matrix<T, IdxT> sample_rows(raft::resources const& res,
const T* input,
IdxT n_rows_input,
IdxT n_train,
IdxT n_dim,
RngState random_state)
template <typename T, typename IdxT = int64_t, typename accessor>
raft::device_matrix<T, IdxT> sample_rows(
raft::resources const& res,
random::RngState random_state,
mdspan<const T, matrix_extent<int64_t>, row_major, accessor> dataset,
IdxT n_samples)
{
auto output = raft::make_device_matrix<T, IdxT>(res, n_train, n_dim);
detail::sample_rows(res, input, n_rows_input, output, random_state);
auto output = raft::make_device_matrix<T, IdxT>(res, n_samples, dataset.extent(1));
detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output);
return output;
}

} // namespace raft::matrix
Loading

0 comments on commit eb7e6d1

Please sign in to comment.