diff --git a/cpp/tests/prims/mg_transform_e.cu b/cpp/tests/prims/mg_transform_e.cu index 127eddd43c7..175113a3858 100644 --- a/cpp/tests/prims/mg_transform_e.cu +++ b/cpp/tests/prims/mg_transform_e.cu @@ -51,6 +51,7 @@ #include struct Prims_Usecase { + bool use_edgelist{false}; bool check_correctness{true}; }; @@ -113,8 +114,9 @@ class Tests_MGTransformE auto mg_dst_prop = cugraph::test::generate::dst_property( *handle_, mg_graph_view, mg_vertex_prop); - cugraph::edge_bucket_t edge_list(*handle_); - { + cugraph::edge_bucket_t edge_list( + *handle_); + if (prims_usecase.use_edgelist) { rmm::device_uvector srcs(0, handle_->get_stream()); rmm::device_uvector dsts(0, handle_->get_stream()); std::tie(srcs, dsts, std::ignore, std::ignore) = cugraph::decompress_to_edgelist( @@ -154,24 +156,41 @@ class Tests_MGTransformE if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement handle_->get_comms().barrier(); - hr_timer.start("MG transform_reduce_e"); + hr_timer.start("MG transform_e"); } - cugraph::transform_e( - *handle_, - mg_graph_view, - edge_list, - mg_src_prop.view(), - mg_dst_prop.view(), - cugraph::edge_dummy_property_t{}.view(), - [] __device__(auto src, auto dst, auto src_property, auto dst_property, thrust::nullopt_t) { - if (src_property < dst_property) { - return src_property; - } else { - return dst_property; - } - }, - edge_value_output.mutable_view()); + if (prims_usecase.use_edgelist) { + cugraph::transform_e( + *handle_, + mg_graph_view, + edge_list, + mg_src_prop.view(), + mg_dst_prop.view(), + cugraph::edge_dummy_property_t{}.view(), + [] __device__(auto src, auto dst, auto src_property, auto dst_property, thrust::nullopt_t) { + if (src_property < dst_property) { + return src_property; + } else { + return dst_property; + } + }, + edge_value_output.mutable_view()); + } else { + cugraph::transform_e( + *handle_, + mg_graph_view, + mg_src_prop.view(), + mg_dst_prop.view(), + cugraph::edge_dummy_property_t{}.view(), + [] __device__(auto src, auto dst, auto src_property, auto dst_property, thrust::nullopt_t) { + if (src_property < dst_property) { + return src_property; + } else { + return dst_property; + } + }, + edge_value_output.mutable_view()); + } if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement @@ -183,24 +202,42 @@ class Tests_MGTransformE // 3. validate MG results if (prims_usecase.check_correctness) { - auto num_invalids = cugraph::count_if_e( - *handle_, - mg_graph_view, - mg_src_prop.view(), - mg_dst_prop.view(), - edge_value_output.view(), - [property_initial_value] __device__( - auto src, auto dst, auto src_property, auto dst_property, auto edge_property) { - if (((src + dst) % 2) == 0) { + size_t num_invalids{}; + if (prims_usecase.use_edgelist) { + num_invalids = cugraph::count_if_e( + *handle_, + mg_graph_view, + mg_src_prop.view(), + mg_dst_prop.view(), + edge_value_output.view(), + [property_initial_value] __device__( + auto src, auto dst, auto src_property, auto dst_property, auto edge_property) { + if (((src + dst) % 2) == 0) { + if (src_property < dst_property) { + return edge_property != src_property; + } else { + return edge_property != dst_property; + } + } else { + return edge_property != property_initial_value; + } + }); + } else { + num_invalids = cugraph::count_if_e( + *handle_, + mg_graph_view, + mg_src_prop.view(), + mg_dst_prop.view(), + edge_value_output.view(), + [property_initial_value] __device__( + auto src, auto dst, auto src_property, auto dst_property, auto edge_property) { if (src_property < dst_property) { return edge_property != src_property; } else { return edge_property != dst_property; } - } else { - return edge_property != property_initial_value; - } - }); + }); + } ASSERT_TRUE(num_invalids == 0); } @@ -342,7 +379,7 @@ INSTANTIATE_TEST_SUITE_P( file_test, Tests_MGTransformE_File, ::testing::Combine( - ::testing::Values(Prims_Usecase{true}), + ::testing::Values(Prims_Usecase{false, true}, Prims_Usecase{true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), @@ -350,7 +387,8 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(rmat_small_test, Tests_MGTransformE_Rmat, - ::testing::Combine(::testing::Values(Prims_Usecase{true}), + ::testing::Combine(::testing::Values(Prims_Usecase{false, true}, + Prims_Usecase{true, true}), ::testing::Values(cugraph::test::Rmat_Usecase( 10, 16, 0.57, 0.19, 0.19, 0, false, false)))); @@ -362,7 +400,7 @@ INSTANTIATE_TEST_SUITE_P( factor (to avoid running same benchmarks more than once) */ Tests_MGTransformE_Rmat, ::testing::Combine( - ::testing::Values(Prims_Usecase{false}), + ::testing::Values(Prims_Usecase{false, false}, Prims_Usecase{true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_MG_TEST_PROGRAM_MAIN()