Skip to content

Commit

Permalink
Undo changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Naim committed Dec 12, 2023
1 parent 7692477 commit 0573351
Showing 1 changed file with 84 additions and 65 deletions.
149 changes: 84 additions & 65 deletions cpp/tests/structure/mg_select_random_vertices_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Tests_MGSelectRandomVertices
//

std::vector<bool> with_replacement_flags = {true, false};
std::vector<bool> sort_vertices_flags = {true, false};

{
// Generate distributed vertex set to sample from
Expand Down Expand Up @@ -108,75 +109,93 @@ class Tests_MGSelectRandomVertices
? select_random_vertices_usecase.select_count
: std::rand() % (num_of_elements_in_given_set + 1);

for (int idx = 0; idx < with_replacement_flags.size(); idx++) {
bool with_replacement = with_replacement_flags[idx];
auto d_sampled_vertices =
cugraph::select_random_vertices(*handle_,
mg_graph_view,
std::make_optional(raft::device_span<vertex_t const>{
d_given_set.data(), d_given_set.size()}),
rng_state,
select_count,
with_replacement,
true);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
for (int i = 0; i < with_replacement_flags.size(); i++) {
for (int j = 0; j < sort_vertices_flags.size(); j++) {
bool with_replacement = with_replacement_flags[i];
bool sort_vertices = sort_vertices_flags[j];

auto d_sampled_vertices =
cugraph::select_random_vertices(*handle_,
mg_graph_view,
std::make_optional(raft::device_span<vertex_t const>{
d_given_set.data(), d_given_set.size()}),
rng_state,
select_count,
with_replacement,
sort_vertices);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
}

std::sort(h_given_set.begin(), h_given_set.end());
if (sort_vertices) {
std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end());
} else {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());
}
std::for_each(
h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) {
ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v));
});
}

std::sort(h_given_set.begin(), h_given_set.end());
std::for_each(
h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) {
ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v));
});
}
}
}

//
// Test sampling from [0, V)
//

std::vector<bool> sort_vertices_flags = {true, false};

for (int i = 0; i < with_replacement_flags.size(); i++) {
for (int j = 0; j < sort_vertices_flags.size(); j++) {
bool with_replacement = with_replacement_flags[i];
bool sort_vertices = sort_vertices_flags[j];

auto d_sampled_vertices = cugraph::select_random_vertices(
*handle_,
mg_graph_view,
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
rng_state,
select_random_vertices_usecase.select_count,
with_replacement,
sort_vertices);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
//
// Test sampling from [0, V)
//

for (int i = 0; i < with_replacement_flags.size(); i++) {
for (int j = 0; j < sort_vertices_flags.size(); j++) {
bool with_replacement = with_replacement_flags[i];
bool sort_vertices = sort_vertices_flags[j];

auto d_sampled_vertices = cugraph::select_random_vertices(
*handle_,
mg_graph_view,
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
rng_state,
select_random_vertices_usecase.select_count,
with_replacement,
sort_vertices);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
}
if (sort_vertices) {
std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end());
}

auto vertex_first = mg_graph_view.local_vertex_partition_range_first();
auto vertex_last = mg_graph_view.local_vertex_partition_range_last();
std::for_each(h_sampled_vertices.begin(),
h_sampled_vertices.end(),
[vertex_first, vertex_last](vertex_t v) {
ASSERT_TRUE((v >= vertex_first) && (v < vertex_last));
});
}
}
}
Expand Down

0 comments on commit 0573351

Please sign in to comment.