From 7857f2fd433c958d51a533a8ffe5b1e7881b93f0 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 13 Mar 2024 10:15:12 +0100 Subject: [PATCH] corrections --- cpp/include/raft/matrix/sample_rows.cuh | 6 +++--- cpp/test/matrix/sample_rows.cu | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/matrix/sample_rows.cuh b/cpp/include/raft/matrix/sample_rows.cuh index 55b17800c7..67281ff297 100644 --- a/cpp/include/raft/matrix/sample_rows.cuh +++ b/cpp/include/raft/matrix/sample_rows.cuh @@ -29,7 +29,7 @@ namespace raft::matrix { template void sample_rows(raft::resources const& res, random::RngState random_state, - mdspan, row_major, accessor> dataset, + mdspan, row_major, accessor> dataset, raft::device_matrix_view output) { detail::sample_rows(res, input, n_rows_input, output, random_state); @@ -42,11 +42,11 @@ template raft::device_matrix sample_rows( raft::resources const& res, random::RngState random_state, - mdspan, row_major, accessor> dataset, + mdspan, row_major, accessor> dataset, IdxT n_samples) { auto output = raft::make_device_matrix(res, n_samples, dataset.extent(1)); - detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output); + detail::sample_rows(res, random_state, dataset, output.view()); return output; } diff --git a/cpp/test/matrix/sample_rows.cu b/cpp/test/matrix/sample_rows.cu index 80abeb7397..8d9be8e1e1 100644 --- a/cpp/test/matrix/sample_rows.cu +++ b/cpp/test/matrix/sample_rows.cu @@ -56,7 +56,8 @@ class SampleRowsTest : public ::testing::TestWithParam { void check() { - out = raft::matrix::sample_rows(res, state, make_const_mdspan(in.view())); + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in.view()), params.n_samples); ASSERT_TRUE(out.extent(0) == params.n_samples); ASSERT_TRUE(out.extent(1) == params.dim); // TODO(tfeher): check sampled values