Skip to content

Commit

Permalink
corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Mar 13, 2024
1 parent eb7e6d1 commit 7857f2f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions cpp/include/raft/matrix/sample_rows.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace raft::matrix {
template <typename T, typename IdxT = int64_t, typename accessor>
void sample_rows(raft::resources const& res,
random::RngState random_state,
mdspan<const T, matrix_extent<int64_t>, row_major, accessor> dataset,
mdspan<const T, matrix_extent<IdxT>, row_major, accessor> dataset,
raft::device_matrix_view<T, IdxT> output)
{
detail::sample_rows(res, input, n_rows_input, output, random_state);
Expand All @@ -42,11 +42,11 @@ template <typename T, typename IdxT = int64_t, typename accessor>
raft::device_matrix<T, IdxT> sample_rows(
raft::resources const& res,
random::RngState random_state,
mdspan<const T, matrix_extent<int64_t>, row_major, accessor> dataset,
mdspan<const T, matrix_extent<IdxT>, row_major, accessor> dataset,
IdxT n_samples)
{
auto output = raft::make_device_matrix<T, IdxT>(res, n_samples, dataset.extent(1));
detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output);
detail::sample_rows(res, random_state, dataset, output.view());
return output;
}

Expand Down
3 changes: 2 additions & 1 deletion cpp/test/matrix/sample_rows.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class SampleRowsTest : public ::testing::TestWithParam<inputs> {

void check()
{
out = raft::matrix::sample_rows<T, int64_t>(res, state, make_const_mdspan(in.view()));
out = raft::matrix::sample_rows<T, int64_t>(
res, state, make_const_mdspan(in.view()), params.n_samples);
ASSERT_TRUE(out.extent(0) == params.n_samples);
ASSERT_TRUE(out.extent(1) == params.dim);
// TODO(tfeher): check sampled values
Expand Down

0 comments on commit 7857f2f

Please sign in to comment.