Skip to content

Commit

Permalink
Add test to sample_rows
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Mar 13, 2024
1 parent 7857f2f commit 93ff94f
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 18 deletions.
4 changes: 2 additions & 2 deletions cpp/include/raft/matrix/detail/sample_rows.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ namespace raft::matrix::detail {
/** Select rows randomly from input and copy to output. */
template <typename T, typename IdxT = int64_t>
void sample_rows(raft::resources const& res,
random::RngState random_state,
const T* input,
IdxT n_rows_input,
raft::device_matrix_view<T, IdxT> output,
random::RngState random_state)
raft::device_matrix_view<T, IdxT> output)
{
IdxT n_dim = output.extent(1);
IdxT n_samples = output.extent(0);
Expand Down
6 changes: 2 additions & 4 deletions cpp/include/raft/matrix/sample_rows.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ void sample_rows(raft::resources const& res,
mdspan<const T, matrix_extent<IdxT>, row_major, accessor> dataset,
raft::device_matrix_view<T, IdxT> output)
{
detail::sample_rows(res, input, n_rows_input, output, random_state);

detail::sample_rows(res, dataset.data_handle(), dataset.extent(0), output, random_state);
detail::sample_rows<T, IdxT>(res, random_state, dataset.data_handle(), dataset.extent(0), output);
}

/** Subsample the dataset to create a training set*/
Expand All @@ -46,7 +44,7 @@ raft::device_matrix<T, IdxT> sample_rows(
IdxT n_samples)
{
auto output = raft::make_device_matrix<T, IdxT>(res, n_samples, dataset.extent(1));
detail::sample_rows(res, random_state, dataset, output.view());
sample_rows(res, random_state, dataset, output.view());
return output;
}

Expand Down
78 changes: 67 additions & 11 deletions cpp/test/matrix/sample_rows.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,31 @@

#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/sample_rows.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/itertools.hpp>

#include <gtest/gtest.h>

#include <unordered_set>

namespace raft {
namespace matrix {

struct inputs {
int N;
int dim;
int n_samples;
bool host;
};

::std::ostream& operator<<(::std::ostream& os, const inputs p)
{
os << p.N << "#" << p.dim << "#" << p.n_samples;
os << p.N << "#" << p.dim << "#" << p.n_samples << (p.host ? "#host" : "#device");
return os;
}

Expand All @@ -46,38 +51,89 @@ class SampleRowsTest : public ::testing::TestWithParam<inputs> {
public:
SampleRowsTest()
: params(::testing::TestWithParam<inputs>::GetParam()),
stream(resource::get_cuda_stream(res)),
state{137ULL},
in(make_device_matrix<T, int64_t>(res, params.N, params.dim)),
out(make_device_matrix<T, int64_t>(res, 0, 0))

out(make_device_matrix<T, int64_t>(res, 0, 0)),
in_h(make_host_matrix<T, int64_t>(res, params.N, params.dim)),
out_h(make_host_matrix<T, int64_t>(res, params.n_samples, params.dim))
{
raft::random::uniform(res, state, in.data_handle(), in.size(), T(-1.0), T(1.0));
for (int64_t i = 0; i < params.N; i++) {
for (int64_t k = 0; k < params.dim; k++)
in_h(i, k) = i * 1000 + k;
}
raft::copy(in.data_handle(), in_h.data_handle(), in_h.size(), stream);
}

void check()
{
out = raft::matrix::sample_rows<T, int64_t>(
res, state, make_const_mdspan(in.view()), params.n_samples);
if (params.host) {
out = raft::matrix::sample_rows<T, int64_t>(
res, state, make_const_mdspan(in_h.view()), (int64_t)params.n_samples);
} else {
out = raft::matrix::sample_rows<T, int64_t>(
res, state, make_const_mdspan(in.view()), (int64_t)params.n_samples);
}

raft::copy(out_h.data_handle(), out.data_handle(), out.size(), stream);
resource::sync_stream(res, stream);

ASSERT_TRUE(out.extent(0) == params.n_samples);
ASSERT_TRUE(out.extent(1) == params.dim);
// TODO(tfeher): check sampled values
// TODO(tfeher): check host / device input

std::unordered_set<int> occurrence;

for (int64_t i = 0; i < params.n_samples; ++i) {
int val = (int)out_h(i, 0) / 1000;
ASSERT_TRUE(0 <= val && val < params.N)
<< "out-of-range index @i=" << i << " val=" << val << " params=" << params;
ASSERT_TRUE(occurrence.find(val) == occurrence.end())
<< "repeated index @i=" << i << " idx=" << val << " params=" << params;
occurrence.insert(val);
for (int64_t k = 0; k < params.dim; k++) {
ASSERT_TRUE(raft::match((int64_t)(out_h(i, k)), val * 1000 + k, raft::Compare<int64_t>()));
}
}
}

protected:
inputs params;
raft::resources res;
cudaStream_t stream;
random::RngState state;
device_matrix<T, int64_t> out, in;
device_matrix<T, int64_t> in, out;
host_matrix<T, int64_t> in_h, out_h;
};

const std::vector<inputs> input1 = {
{10, 1, 1}, {10, 4, 1}, {10, 4, 10}, {10, 10}, {137, 42, 59}, {10000, 128, 893}};
inline std::vector<inputs> generate_inputs()
{
std::vector<inputs> input1 =
raft::util::itertools::product<inputs>({10}, {1, 17, 96}, {1, 6, 9, 10}, {false});

std::vector<inputs> input2 =
raft::util::itertools::product<inputs>({137}, {1, 17, 128}, {1, 10, 100, 137}, {false});
input1.insert(input1.end(), input2.begin(), input2.end());

input2 = raft::util::itertools::product<inputs>(
{100000}, {1, 42}, {1, 137, 1000, 10000, 100000}, {false});
input1.insert(input1.end(), input2.begin(), input2.end());

int n = input1.size();
// Add same tests for host data
for (int i = 0; i < n; i++) {
inputs x = input1[i];
x.host = true;
input1.push_back(x);
}
return input1;
}

const std::vector<inputs> inputs1 = generate_inputs();

using SampleRowsTestInt64 = SampleRowsTest<float>;
TEST_P(SampleRowsTestInt64, SamplingTest) { check(); }
INSTANTIATE_TEST_SUITE_P(SampleRowsTests, SampleRowsTestInt64, ::testing::ValuesIn(input1));
INSTANTIATE_TEST_SUITE_P(SampleRowsTests, SampleRowsTestInt64, ::testing::ValuesIn(inputs1));

} // namespace matrix
} // namespace raft
3 changes: 2 additions & 1 deletion cpp/test/random/excess_sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ExcessSamplingTest : public ::testing::TestWithParam<inputs> {
public:
ExcessSamplingTest()
: params(::testing::TestWithParam<inputs>::GetParam()),
stream(resource::get_cuda_stream(res)),
state{137ULL},
in(make_device_vector<T, int64_t>(res, params.n_samples)),
out(make_device_vector<T, int64_t>(res, 0)),
Expand Down Expand Up @@ -89,7 +90,7 @@ class ExcessSamplingTest : public ::testing::TestWithParam<inputs> {
raft::resources res;
cudaStream_t stream;
RngState state;
device_vector<T, int64_t> out, in;
device_vector<T, int64_t> in, out;
host_vector<T, int64_t> h_out;
};

Expand Down

0 comments on commit 93ff94f

Please sign in to comment.