Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test select_random_vertices for all possible values of flags #4042

Merged
merged 14 commits into from
Dec 12, 2023
159 changes: 88 additions & 71 deletions cpp/tests/structure/mg_select_random_vertices_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class Tests_MGSelectRandomVertices
//

std::vector<bool> with_replacement_flags = {true, false};
std::vector<bool> sort_vertices_flags = {true, false};

{
// Generate distributed vertex set to sample from
std::srand((unsigned)std::chrono::duration_cast<std::chrono::milliseconds>(
Expand Down Expand Up @@ -107,80 +109,95 @@ class Tests_MGSelectRandomVertices
? select_random_vertices_usecase.select_count
: std::rand() % (num_of_elements_in_given_set + 1);

for (int idx = 0; idx < with_replacement_flags.size(); idx++) {
bool with_replacement = with_replacement_flags[idx];
auto d_sampled_vertices =
cugraph::select_random_vertices(*handle_,
mg_graph_view,
std::make_optional(raft::device_span<vertex_t const>{
d_given_set.data(), d_given_set.size()}),
rng_state,
select_count,
with_replacement,
true);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
for (int i = 0; i < with_replacement_flags.size(); i++) {
for (int j = 0; j < sort_vertices_flags.size(); j++) {
bool with_replacement = with_replacement_flags[i];
bool sort_vertices = sort_vertices_flags[j];

auto d_sampled_vertices =
cugraph::select_random_vertices(*handle_,
mg_graph_view,
std::make_optional(raft::device_span<vertex_t const>{
d_given_set.data(), d_given_set.size()}),
rng_state,
select_count,
with_replacement,
sort_vertices);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
}

std::sort(h_given_set.begin(), h_given_set.end());
if (sort_vertices) {
assert(std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end()));
} else {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());
}
std::for_each(
h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) {
ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v));
});
}

std::sort(h_given_set.begin(), h_given_set.end());
std::for_each(
h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) {
ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v));
});
}
}
}

//
// Test sampling from [0, V)
//

for (int idx = 0; idx < with_replacement_flags.size(); idx++) {
bool with_replacement = false;
auto d_sampled_vertices = cugraph::select_random_vertices(
*handle_,
mg_graph_view,
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
rng_state,
select_random_vertices_usecase.select_count,
with_replacement,
true);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
//
// Test sampling from [0, V)
//

for (int i = 0; i < with_replacement_flags.size(); i++) {
for (int j = 0; j < sort_vertices_flags.size(); j++) {
bool with_replacement = with_replacement_flags[i];
bool sort_vertices = sort_vertices_flags[j];

auto d_sampled_vertices = cugraph::select_random_vertices(
*handle_,
mg_graph_view,
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
rng_state,
select_random_vertices_usecase.select_count,
with_replacement,
sort_vertices);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
}
if (sort_vertices) {
assert(std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end()));
}

auto vertex_first = mg_graph_view.local_vertex_partition_range_first();
auto vertex_last = mg_graph_view.local_vertex_partition_range_last();
std::for_each(h_sampled_vertices.begin(),
h_sampled_vertices.end(),
[vertex_first, vertex_last](vertex_t v) {
ASSERT_TRUE((v >= vertex_first) && (v < vertex_last));
});
}
}

auto vertex_first = mg_graph_view.local_vertex_partition_range_first();
auto vertex_last = mg_graph_view.local_vertex_partition_range_last();

std::for_each(h_sampled_vertices.begin(),
h_sampled_vertices.end(),
[vertex_first, vertex_last](vertex_t v) {
ASSERT_TRUE((v >= vertex_first) && (v < vertex_last));
});
}
}
}
Expand Down Expand Up @@ -242,8 +259,8 @@ INSTANTIATE_TEST_SUITE_P(
factor (to avoid running same benchmarks more than once) */
Tests_MGSelectRandomVertices_Rmat,
::testing::Combine(
::testing::Values(SelectRandomVertices_Usecase{500, false},
SelectRandomVertices_Usecase{500, false}),
::testing::Values(SelectRandomVertices_Usecase{500, true},
SelectRandomVertices_Usecase{500, true}),
naimnv marked this conversation as resolved.
Show resolved Hide resolved
::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false))));

CUGRAPH_MG_TEST_PROGRAM_MAIN()
Loading