diff --git a/cpp/tests/centrality/mg_betweenness_centrality_test.cpp b/cpp/tests/centrality/mg_betweenness_centrality_test.cpp index 8c28d493c65..7924d449897 100644 --- a/cpp/tests/centrality/mg_betweenness_centrality_test.cpp +++ b/cpp/tests/centrality/mg_betweenness_centrality_test.cpp @@ -17,6 +17,7 @@ #include "utilities/base_fixture.hpp" #include "utilities/conversion_utilities.hpp" #include "utilities/device_comm_wrapper.hpp" +#include "utilities/property_generator_utilities.hpp" #include "utilities/test_graphs.hpp" #include "utilities/thrust_wrapper.hpp" @@ -39,6 +40,8 @@ struct BetweennessCentrality_Usecase { bool normalized{false}; bool include_endpoints{false}; bool test_weighted{false}; + + bool edge_masking{false}; bool check_correctness{true}; }; @@ -82,6 +85,13 @@ class Tests_MGBetweennessCentrality auto mg_edge_weight_view = mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt; + std::optional> edge_mask{std::nullopt}; + if (betweenness_usecase.edge_masking) { + edge_mask = cugraph::test::generate::edge_property( + *handle_, mg_graph_view, 2); + mg_graph_view.attach_edge_mask((*edge_mask).view()); + } + raft::random::RngState rng_state(handle_->get_comms().get_rank()); auto d_mg_seeds = cugraph::select_random_vertices( *handle_, @@ -210,9 +220,13 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGBetweennessCentrality_File, ::testing::Combine( // enable correctness checks - ::testing::Values(BetweennessCentrality_Usecase{20, false, false, false, true}, + ::testing::Values(BetweennessCentrality_Usecase{20, false, false, false, false}, + BetweennessCentrality_Usecase{20, false, false, false, true}, + BetweennessCentrality_Usecase{20, false, false, true, false}, BetweennessCentrality_Usecase{20, false, false, true, true}, - BetweennessCentrality_Usecase{20, false, true, true, true}, + BetweennessCentrality_Usecase{20, false, true, false, false}, + BetweennessCentrality_Usecase{20, false, true, false, true}, + BetweennessCentrality_Usecase{20, false, true, true, false}, BetweennessCentrality_Usecase{20, false, true, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), @@ -225,7 +239,21 @@ INSTANTIATE_TEST_SUITE_P( // disable correctness checks, running out of memory ::testing::Combine( ::testing::Values(BetweennessCentrality_Usecase{50, false, false, false, false}, - BetweennessCentrality_Usecase{50, false, false, true, false}), + BetweennessCentrality_Usecase{50, false, false, false, true}, + BetweennessCentrality_Usecase{50, false, false, true, false}, + BetweennessCentrality_Usecase{50, false, false, true, true}, + BetweennessCentrality_Usecase{50, false, true, false, false}, + BetweennessCentrality_Usecase{50, false, true, false, true}, + BetweennessCentrality_Usecase{50, false, true, true, false}, + BetweennessCentrality_Usecase{50, false, true, true, true}, + BetweennessCentrality_Usecase{50, true, false, false, false}, + BetweennessCentrality_Usecase{50, true, false, false, true}, + BetweennessCentrality_Usecase{50, true, false, true, false}, + BetweennessCentrality_Usecase{50, true, false, true, true}, + BetweennessCentrality_Usecase{50, true, true, false, false}, + BetweennessCentrality_Usecase{50, true, true, false, true}, + BetweennessCentrality_Usecase{50, true, true, true, false}, + BetweennessCentrality_Usecase{50, true, true, true, true}), ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, true, false)))); INSTANTIATE_TEST_SUITE_P( @@ -237,8 +265,10 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGBetweennessCentrality_Rmat, // disable correctness checks for large graphs ::testing::Combine( - ::testing::Values(BetweennessCentrality_Usecase{500, false, false, false, false}, - BetweennessCentrality_Usecase{500, false, false, true, false}), + ::testing::Values(BetweennessCentrality_Usecase{500, false, false, false, false, false}, + BetweennessCentrality_Usecase{500, false, false, false, true, false}, + BetweennessCentrality_Usecase{500, false, false, true, false, false}, + BetweennessCentrality_Usecase{500, false, false, true, true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_MG_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp b/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp index 7b3613f8210..c3417e96c03 100644 --- a/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp +++ b/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp @@ -17,6 +17,7 @@ #include "utilities/base_fixture.hpp" #include "utilities/conversion_utilities.hpp" #include "utilities/device_comm_wrapper.hpp" +#include "utilities/property_generator_utilities.hpp" #include "utilities/test_graphs.hpp" #include "utilities/thrust_wrapper.hpp" @@ -38,6 +39,8 @@ struct EdgeBetweennessCentrality_Usecase { size_t num_seeds{std::numeric_limits::max()}; bool normalized{false}; bool test_weighted{false}; + + bool edge_masking{false}; bool check_correctness{true}; }; @@ -84,6 +87,13 @@ class Tests_MGEdgeBetweennessCentrality auto mg_edge_weight_view = mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt; + std::optional> edge_mask{std::nullopt}; + if (betweenness_usecase.edge_masking) { + edge_mask = cugraph::test::generate::edge_property( + *handle_, mg_graph_view, 2); + mg_graph_view.attach_edge_mask((*edge_mask).view()); + } + raft::random::RngState rng_state(handle_->get_comms().get_rank()); auto d_mg_seeds = cugraph::select_random_vertices( *handle_, @@ -116,6 +126,9 @@ class Tests_MGEdgeBetweennessCentrality } if (betweenness_usecase.check_correctness) { + auto d_mg_aggregate_seeds = cugraph::test::device_gatherv( + *handle_, raft::device_span{d_mg_seeds.data(), d_mg_seeds.size()}); + // Extract MG results auto [d_cugraph_src_vertex_ids, d_cugraph_dst_vertex_ids, d_cugraph_results] = cugraph::test::graph_to_device_coo( @@ -136,9 +149,6 @@ class Tests_MGEdgeBetweennessCentrality std::optional>{std::nullopt}, false); - auto d_mg_aggregate_seeds = cugraph::test::device_gatherv( - *handle_, raft::device_span{d_mg_seeds.data(), d_mg_seeds.size()}); - if (handle_->get_comms().get_rank() == 0) { auto sg_edge_weights_view = sg_edge_weights ? std::make_optional(sg_edge_weights->view()) : std::nullopt; @@ -214,7 +224,9 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGEdgeBetweennessCentrality_File, ::testing::Combine( // enable correctness checks - ::testing::Values(EdgeBetweennessCentrality_Usecase{20, false, false, true}, + ::testing::Values(EdgeBetweennessCentrality_Usecase{20, false, false, false}, + EdgeBetweennessCentrality_Usecase{20, false, false, true}, + EdgeBetweennessCentrality_Usecase{20, false, true, false}, EdgeBetweennessCentrality_Usecase{20, false, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), @@ -226,8 +238,14 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGEdgeBetweennessCentrality_Rmat, // enable correctness checks ::testing::Combine( - ::testing::Values(EdgeBetweennessCentrality_Usecase{50, false, false, true}, - EdgeBetweennessCentrality_Usecase{50, false, true, true}), + ::testing::Values(EdgeBetweennessCentrality_Usecase{50, false, false, false}, + EdgeBetweennessCentrality_Usecase{50, false, false, true}, + EdgeBetweennessCentrality_Usecase{50, false, true, false}, + EdgeBetweennessCentrality_Usecase{50, false, true, true}, + EdgeBetweennessCentrality_Usecase{50, true, false, false}, + EdgeBetweennessCentrality_Usecase{50, true, false, true}, + EdgeBetweennessCentrality_Usecase{50, true, true, false}, + EdgeBetweennessCentrality_Usecase{50, true, true, true}), ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, true, false)))); INSTANTIATE_TEST_SUITE_P( @@ -239,8 +257,10 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGEdgeBetweennessCentrality_Rmat, // disable correctness checks for large graphs ::testing::Combine( - ::testing::Values(EdgeBetweennessCentrality_Usecase{500, false, false, false}, - EdgeBetweennessCentrality_Usecase{500, false, true, false}), + ::testing::Values(EdgeBetweennessCentrality_Usecase{500, false, false, false, false}, + EdgeBetweennessCentrality_Usecase{500, false, false, true, false}, + EdgeBetweennessCentrality_Usecase{500, false, true, false, false}, + EdgeBetweennessCentrality_Usecase{500, false, true, true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_MG_TEST_PROGRAM_MAIN()