From 60fd404c65282a7b56a4c3c31643fc91a5b94de3 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Wed, 7 Feb 2024 15:43:33 -0800 Subject: [PATCH] update extract_transform_e to support edge masking --- cpp/src/prims/extract_transform_e.cuh | 2 -- cpp/tests/prims/mg_extract_transform_e.cu | 15 ++++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/cpp/src/prims/extract_transform_e.cuh b/cpp/src/prims/extract_transform_e.cuh index f135b76d6e3..fcd5e4c1483 100644 --- a/cpp/src/prims/extract_transform_e.cuh +++ b/cpp/src/prims/extract_transform_e.cuh @@ -107,8 +107,6 @@ extract_transform_e(raft::handle_t const& handle, static_assert(!std::is_same_v); using payload_t = typename e_op_result_t::value_type; - CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); - // FIXME: Consider updating detail::extract_transform_v_forntier_e to take std::nullopt to as a // frontier or create a new key bucket type that just stores [vertex_first, vertex_last) for // further optimization. Better revisit this once this becomes a performance bottleneck and after diff --git a/cpp/tests/prims/mg_extract_transform_e.cu b/cpp/tests/prims/mg_extract_transform_e.cu index bca6471a5bb..29ff25ea8bd 100644 --- a/cpp/tests/prims/mg_extract_transform_e.cu +++ b/cpp/tests/prims/mg_extract_transform_e.cu @@ -116,6 +116,7 @@ struct e_op_t { }; struct Prims_Usecase { + bool edge_masking{false}; bool check_correctness{true}; }; @@ -180,6 +181,13 @@ class Tests_MGExtractTransformE auto mg_graph_view = mg_graph.view(); + std::optional> edge_mask{std::nullopt}; + if (prims_usecase.edge_masking) { + edge_mask = + cugraph::test::generate::edge_property(*handle_, mg_graph_view, 2); + mg_graph_view.attach_edge_mask((*edge_mask).view()); + } + // 2. run MG extract_transform_e const int hash_bin_count = 5; @@ -400,7 +408,7 @@ INSTANTIATE_TEST_SUITE_P( file_test, Tests_MGExtractTransformE_File, ::testing::Combine( - ::testing::Values(Prims_Usecase{true}), + ::testing::Values(Prims_Usecase{false, true}, Prims_Usecase{true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), @@ -408,7 +416,8 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(rmat_small_test, Tests_MGExtractTransformE_Rmat, - ::testing::Combine(::testing::Values(Prims_Usecase{true}), + ::testing::Combine(::testing::Values(Prims_Usecase{false, true}, + Prims_Usecase{true, true}), ::testing::Values(cugraph::test::Rmat_Usecase( 10, 16, 0.57, 0.19, 0.19, 0, false, false)))); @@ -420,7 +429,7 @@ INSTANTIATE_TEST_SUITE_P( factor (to avoid running same benchmarks more than once) */ Tests_MGExtractTransformE_Rmat, ::testing::Combine( - ::testing::Values(Prims_Usecase{false}), + ::testing::Values(Prims_Usecase{false, false}, Prims_Usecase{true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_MG_TEST_PROGRAM_MAIN()