From df3e4ffa62dd873f5e9df23c2af7e8f7fc6b90ab Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Sat, 28 Sep 2024 08:28:26 -0700 Subject: [PATCH] add tests for heterogeneous uniform/biased neighborhood sampling --- cpp/tests/CMakeLists.txt | 13 +- ...heterogeneous_biased_neighbor_sampling.cpp | 340 ++++++++++++++++++ ...eterogeneous_uniform_neighbor_sampling.cpp | 339 +++++++++++++++++ 3 files changed, 690 insertions(+), 2 deletions(-) create mode 100644 cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp create mode 100644 cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 7007344261e..bfcab9a026b 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -481,19 +481,28 @@ ConfigureTest(RANDOM_WALKS_TEST sampling/sg_random_walks_test.cpp) # - UNIFORM NBR SAMPLING tests -------------------------------------------------------------------- ConfigureTest(UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/uniform_neighbor_sampling.cpp) -# - HOMOGENEOUS UNIFORM NBR SAMPLING tests -------------------------------------------------------------------- +# - HOMOGENEOUS UNIFORM NBR SAMPLING tests -------------------------------------------------------- ConfigureTest( HOMOGENEOUS_UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/homogeneous_uniform_neighbor_sampling.cpp) +# - HETEROGENEOUS UNIFORM NBR SAMPLING tests ----------------------------------------------------- +ConfigureTest( + HETEROGENEOUS_UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/heterogeneous_uniform_neighbor_sampling.cpp) + ################################################################################################### # - BIASED NBR SAMPLING tests --------------------------------------------------------------------- ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cpp) ################################################################################################### -# - HOMOGENEOUS BIASED NBR SAMPLING tests --------------------------------------------------------------------- +# - HOMOGENEOUS BIASED NBR SAMPLING tests --------------------------------------------------------- ConfigureTest( HOMOGENEOUS_BIASED_NEIGHBOR_SAMPLING_TEST sampling/homogeneous_biased_neighbor_sampling.cpp) +################################################################################################### +# - HETEROGENEOUS BIASED NBR SAMPLING tests ------------------------------------------------------- +ConfigureTest( + HETEROGENEOUS_BIASED_NEIGHBOR_SAMPLING_TEST sampling/heterogeneous_biased_neighbor_sampling.cpp) + ################################################################################################### # - SAMPLING_POST_PROCESSING tests ---------------------------------------------------------------- ConfigureTest(SAMPLING_POST_PROCESSING_TEST sampling/sampling_post_processing_test.cpp) diff --git a/cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp b/cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp new file mode 100644 index 00000000000..ca6f1da2ad3 --- /dev/null +++ b/cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail/nbr_sampling_validate.hpp" +#include "utilities/base_fixture.hpp" +#include "utilities/property_generator_utilities.hpp" + +#include +#include + +#include + +struct Heterogeneous_Biased_Neighbor_Sampling_Usecase { + std::vector fanout{{-1}}; + int32_t batch_size{10}; + int32_t num_edge_types{1}; + bool flag_replacement{true}; + + bool check_correctness{true}; +}; + +template +class Tests_Heterogeneous_Biased_Neighbor_Sampling + : public ::testing::TestWithParam< + std::tuple> { + public: + Tests_Heterogeneous_Biased_Neighbor_Sampling() {} + + static void SetUpTestCase() {} + static void TearDownTestCase() {} + + virtual void SetUp() {} + virtual void TearDown() {} + + template + void run_current_test( + std::tuple const& param) + { + using edge_type_t = int32_t; + + auto [heterogeneous_biased_neighbor_sampling_usecase, input_usecase] = param; + + raft::handle_t handle{}; + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Construct graph"); + } + + auto [graph, edge_weights, renumber_map_labels] = + cugraph::test::construct_graph( + handle, input_usecase, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto graph_view = graph.view(); + auto edge_weight_view = + edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt; + + std::optional> edge_mask{std::nullopt}; + + constexpr float select_probability{0.05}; + + // FIXME: Update the tests to initialize RngState and use it instead + // of seed... + constexpr uint64_t seed{0}; + + raft::random::RngState rng_state(seed); + + auto random_sources = cugraph::select_random_vertices( + handle, + graph_view, + std::optional>{std::nullopt}, + rng_state, + std::max(static_cast(graph_view.number_of_vertices() * select_probability), + std::min(static_cast(graph_view.number_of_vertices()), size_t{1})), + false, + false); + + // + // Now we'll assign the vertices to batches + // + + auto batch_number = std::make_optional>(0, handle.get_stream()); + + batch_number = cugraph::test::sequence( + handle, random_sources.size(), heterogeneous_biased_neighbor_sampling_usecase.batch_size, int32_t{0}); + + rmm::device_uvector random_sources_copy(random_sources.size(), handle.get_stream()); + + raft::copy(random_sources_copy.data(), + random_sources.data(), + random_sources.size(), + handle.get_stream()); + + std::optional> + label_to_output_comm_rank_mapping{std::nullopt}; + + // Generate the edge types + + std::optional> edge_types{ + std::nullopt}; + + if (heterogeneous_biased_neighbor_sampling_usecase.num_edge_types > 1) { + edge_types = cugraph::test::generate::edge_property( + handle, + graph_view, + heterogeneous_biased_neighbor_sampling_usecase.num_edge_types); + } + +#ifdef NO_CUGRAPH_OPS + EXPECT_THROW( + cugraph::heterogeneous_biased_neighbor_sample( + handle, + graph_view, + edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + raft::device_span{random_sources_copy.data(), random_sources.size()}, + batch_number ? std::make_optional(raft::device_span{batch_number->data(), + batch_number->size()}) + : std::nullopt, + label_to_output_comm_rank_mapping, + raft::host_span(heterogeneous_biased_neighbor_sampling_usecase.fanout.data(), + heterogeneous_biased_neighbor_sampling_usecase.fanout.size()), + rng_state, + true, + heterogeneous_biased_neighbor_sampling_usecase.flag_replacement), + std::exception); +#else + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Biased neighbor sampling"); + } + + auto&& [src_out, dst_out, wgt_out, edge_id, edge_type, hop, offsets] = + cugraph::heterogeneous_biased_neighbor_sample( + handle, + rng_state, + graph_view, + edge_weight_view, + std::optional>{std::nullopt}, + edge_types + ? std::optional>{(*edge_types) + .view()} + : std::nullopt, + *edge_weight_view, + raft::device_span{random_sources_copy.data(), random_sources.size()}, + batch_number ? std::make_optional(raft::device_span{batch_number->data(), + batch_number->size()}) + : std::nullopt, + label_to_output_comm_rank_mapping, + raft::host_span(heterogeneous_biased_neighbor_sampling_usecase.fanout.data(), + heterogeneous_biased_neighbor_sampling_usecase.fanout.size()), + heterogeneous_biased_neighbor_sampling_usecase.num_edge_types, + cugraph::sampling_flags_t{ + cugraph::prior_sources_behavior_t{0}, + true, // return_hops + false, // dedupe_sources + heterogeneous_biased_neighbor_sampling_usecase.flag_replacement + } + ); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (heterogeneous_biased_neighbor_sampling_usecase.check_correctness) { + // First validate that the extracted edges are actually a subset of the + // edges in the input graph + rmm::device_uvector vertices(2 * src_out.size(), handle.get_stream()); + raft::copy(vertices.data(), src_out.data(), src_out.size(), handle.get_stream()); + raft::copy( + vertices.data() + src_out.size(), dst_out.data(), dst_out.size(), handle.get_stream()); + vertices = cugraph::test::sort(handle, std::move(vertices)); + vertices = cugraph::test::unique(handle, std::move(vertices)); + + rmm::device_uvector d_subgraph_offsets(2, handle.get_stream()); + std::vector h_subgraph_offsets({0, vertices.size()}); + + raft::update_device(d_subgraph_offsets.data(), + h_subgraph_offsets.data(), + h_subgraph_offsets.size(), + handle.get_stream()); + + rmm::device_uvector src_compare(0, handle.get_stream()); + rmm::device_uvector dst_compare(0, handle.get_stream()); + std::optional> wgt_compare{std::nullopt}; + + std::tie(src_compare, dst_compare, wgt_compare, std::ignore) = extract_induced_subgraphs( + handle, + graph_view, + edge_weight_view, + raft::device_span(d_subgraph_offsets.data(), 2), + raft::device_span(vertices.data(), vertices.size()), + true); + + + + + ASSERT_TRUE(cugraph::test::validate_extracted_graph_is_subgraph( + handle, src_compare, dst_compare, wgt_compare, src_out, dst_out, wgt_out)); + + if (random_sources.size() < 100) { + // This validation is too expensive for large number of vertices + ASSERT_TRUE( + cugraph::test::validate_sampling_depth(handle, + std::move(src_out), + std::move(dst_out), + std::move(wgt_out), + std::move(random_sources), + heterogeneous_biased_neighbor_sampling_usecase.fanout.size())); + } + } +#endif + } +}; + +using Tests_Heterogeneous_Biased_Neighbor_Sampling_File = + Tests_Heterogeneous_Biased_Neighbor_Sampling; +//#if 0 +using Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat = + Tests_Heterogeneous_Biased_Neighbor_Sampling; +//#endif + +TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int32Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +#if 0 +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Heterogeneous_Biased_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); +#endif + +//#if 0 +TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int64Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt64Int64Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Heterogeneous_Biased_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_Heterogeneous_Biased_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, + ::testing::Combine( + ::testing::Values(Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}), + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, + ::testing::Combine( + ::testing::Values(Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, false, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, false, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, true, false}, + Heterogeneous_Biased_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, true, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false, 0)))); +//#endif + +CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp b/cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp new file mode 100644 index 00000000000..0b4cf128975 --- /dev/null +++ b/cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp @@ -0,0 +1,339 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail/nbr_sampling_validate.hpp" +#include "utilities/base_fixture.hpp" +#include "utilities/property_generator_utilities.hpp" + +#include +#include + +#include + +struct Heterogeneous_Uniform_Neighbor_Sampling_Usecase { + std::vector fanout{{-1}}; + int32_t batch_size{10}; + int32_t num_edge_types{1}; + bool flag_replacement{true}; + + bool check_correctness{true}; +}; + +template +class Tests_Heterogeneous_Uniform_Neighbor_Sampling + : public ::testing::TestWithParam< + std::tuple> { + public: + Tests_Heterogeneous_Uniform_Neighbor_Sampling() {} + + static void SetUpTestCase() {} + static void TearDownTestCase() {} + + virtual void SetUp() {} + virtual void TearDown() {} + + template + void run_current_test( + std::tuple const& param) + { + using edge_type_t = int32_t; + + auto [heterogeneous_uniform_neighbor_sampling_usecase, input_usecase] = param; + + raft::handle_t handle{}; + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Construct graph"); + } + + auto [graph, edge_weights, renumber_map_labels] = + cugraph::test::construct_graph( + handle, input_usecase, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto graph_view = graph.view(); + auto edge_weight_view = + edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt; + + std::optional> edge_mask{std::nullopt}; + + constexpr float select_probability{0.05}; + + // FIXME: Update the tests to initialize RngState and use it instead + // of seed... + constexpr uint64_t seed{0}; + + raft::random::RngState rng_state(seed); + + auto random_sources = cugraph::select_random_vertices( + handle, + graph_view, + std::optional>{std::nullopt}, + rng_state, + std::max(static_cast(graph_view.number_of_vertices() * select_probability), + std::min(static_cast(graph_view.number_of_vertices()), size_t{1})), + false, + false); + + // + // Now we'll assign the vertices to batches + // + + auto batch_number = std::make_optional>(0, handle.get_stream()); + + batch_number = cugraph::test::sequence( + handle, random_sources.size(), heterogeneous_uniform_neighbor_sampling_usecase.batch_size, int32_t{0}); + + rmm::device_uvector random_sources_copy(random_sources.size(), handle.get_stream()); + + raft::copy(random_sources_copy.data(), + random_sources.data(), + random_sources.size(), + handle.get_stream()); + + std::optional> + label_to_output_comm_rank_mapping{std::nullopt}; + + // Generate the edge types + + std::optional> edge_types{ + std::nullopt}; + + if (heterogeneous_uniform_neighbor_sampling_usecase.num_edge_types > 1) { + edge_types = cugraph::test::generate::edge_property( + handle, + graph_view, + heterogeneous_uniform_neighbor_sampling_usecase.num_edge_types); + } + +#ifdef NO_CUGRAPH_OPS + EXPECT_THROW( + cugraph::heterogeneous_uniform_neighbor_sample( + handle, + graph_view, + edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + raft::device_span{random_sources_copy.data(), random_sources.size()}, + batch_number ? std::make_optional(raft::device_span{batch_number->data(), + batch_number->size()}) + : std::nullopt, + label_to_output_comm_rank_mapping, + raft::host_span(heterogeneous_uniform_neighbor_sampling_usecase.fanout.data(), + heterogeneous_uniform_neighbor_sampling_usecase.fanout.size()), + rng_state, + true, + heterogeneous_uniform_neighbor_sampling_usecase.flag_replacement), + std::exception); +#else + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Uniform neighbor sampling"); + } + + auto&& [src_out, dst_out, wgt_out, edge_id, edge_type, hop, offsets] = + cugraph::heterogeneous_uniform_neighbor_sample( + handle, + rng_state, + graph_view, + edge_weight_view, + std::optional>{std::nullopt}, + edge_types + ? std::optional>{(*edge_types) + .view()} + : std::nullopt, + raft::device_span{random_sources_copy.data(), random_sources.size()}, + batch_number ? std::make_optional(raft::device_span{batch_number->data(), + batch_number->size()}) + : std::nullopt, + label_to_output_comm_rank_mapping, + raft::host_span(heterogeneous_uniform_neighbor_sampling_usecase.fanout.data(), + heterogeneous_uniform_neighbor_sampling_usecase.fanout.size()), + heterogeneous_uniform_neighbor_sampling_usecase.num_edge_types, + cugraph::sampling_flags_t{ + cugraph::prior_sources_behavior_t{0}, + true, // return_hops + false, // dedupe_sources + heterogeneous_uniform_neighbor_sampling_usecase.flag_replacement + } + ); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (heterogeneous_uniform_neighbor_sampling_usecase.check_correctness) { + // First validate that the extracted edges are actually a subset of the + // edges in the input graph + rmm::device_uvector vertices(2 * src_out.size(), handle.get_stream()); + raft::copy(vertices.data(), src_out.data(), src_out.size(), handle.get_stream()); + raft::copy( + vertices.data() + src_out.size(), dst_out.data(), dst_out.size(), handle.get_stream()); + vertices = cugraph::test::sort(handle, std::move(vertices)); + vertices = cugraph::test::unique(handle, std::move(vertices)); + + rmm::device_uvector d_subgraph_offsets(2, handle.get_stream()); + std::vector h_subgraph_offsets({0, vertices.size()}); + + raft::update_device(d_subgraph_offsets.data(), + h_subgraph_offsets.data(), + h_subgraph_offsets.size(), + handle.get_stream()); + + rmm::device_uvector src_compare(0, handle.get_stream()); + rmm::device_uvector dst_compare(0, handle.get_stream()); + std::optional> wgt_compare{std::nullopt}; + + std::tie(src_compare, dst_compare, wgt_compare, std::ignore) = extract_induced_subgraphs( + handle, + graph_view, + edge_weight_view, + raft::device_span(d_subgraph_offsets.data(), 2), + raft::device_span(vertices.data(), vertices.size()), + true); + + + + + ASSERT_TRUE(cugraph::test::validate_extracted_graph_is_subgraph( + handle, src_compare, dst_compare, wgt_compare, src_out, dst_out, wgt_out)); + + if (random_sources.size() < 100) { + // This validation is too expensive for large number of vertices + ASSERT_TRUE( + cugraph::test::validate_sampling_depth(handle, + std::move(src_out), + std::move(dst_out), + std::move(wgt_out), + std::move(random_sources), + heterogeneous_uniform_neighbor_sampling_usecase.fanout.size())); + } + } +#endif + } +}; + +using Tests_Heterogeneous_Uniform_Neighbor_Sampling_File = + Tests_Heterogeneous_Uniform_Neighbor_Sampling; +//#if 0 +using Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat = + Tests_Heterogeneous_Uniform_Neighbor_Sampling; +//#endif + +TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +#if 0 +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); +#endif + +//#if 0 +TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int64Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt64Int64Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, + ::testing::Combine( + ::testing::Values(Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8}, 128, 2, true}), + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, + ::testing::Combine( + ::testing::Values(Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, false, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, false, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, true, false}, + Heterogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10, 7, 8, 1, 9, 5, 12}, 1024, 4, true, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false, 0)))); +//#endif + +CUGRAPH_TEST_PROGRAM_MAIN()