Skip to content

Commit

Permalink
update transform_reduce_v_frontier_outgoing_e_by_dst to support edge …
Browse files Browse the repository at this point in the history
…masking
  • Loading branch information
seunghwak committed Feb 7, 2024
1 parent 217ef47 commit 6a44b55
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
41 changes: 31 additions & 10 deletions cpp/src/prims/transform_reduce_v_frontier_outgoing_e_by_dst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ size_t compute_num_out_nbrs_from_frontier(raft::handle_t const& handle,
using edge_t = typename GraphViewType::edge_type;
using key_t = typename VertexFrontierBucketType::key_type;

CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented.");

size_t ret{0};

vertex_t const* local_frontier_vertex_first{nullptr};
Expand All @@ -207,10 +205,19 @@ size_t compute_num_out_nbrs_from_frontier(raft::handle_t const& handle,
} else {
local_frontier_sizes = std::vector<size_t>{static_cast<size_t>(frontier.size())};
}

auto edge_mask_view = graph_view.edge_mask_view();

for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) {
auto edge_partition =
edge_partition_device_view_t<vertex_t, edge_t, GraphViewType::is_multi_gpu>(
graph_view.local_edge_partition_view(i));
auto edge_partition_e_mask =
edge_mask_view
? thrust::make_optional<
detail::edge_partition_edge_property_device_view_t<edge_t, uint32_t const*, bool>>(
*edge_mask_view, i)
: thrust::nullopt;

if constexpr (GraphViewType::is_multi_gpu) {
auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
Expand All @@ -225,14 +232,30 @@ size_t compute_num_out_nbrs_from_frontier(raft::handle_t const& handle,
static_cast<int>(i),
handle.get_stream());

ret += edge_partition.compute_number_of_edges(edge_partition_frontier_vertices.begin(),
edge_partition_frontier_vertices.end(),
handle.get_stream());
if (edge_partition_e_mask) {
ret +=
edge_partition.compute_number_of_edges_with_mask((*edge_partition_e_mask).value_first(),
edge_partition_frontier_vertices.begin(),
edge_partition_frontier_vertices.end(),
handle.get_stream());
} else {
ret += edge_partition.compute_number_of_edges(edge_partition_frontier_vertices.begin(),
edge_partition_frontier_vertices.end(),
handle.get_stream());
}
} else {
assert(i == 0);
ret += edge_partition.compute_number_of_edges(local_frontier_vertex_first,
local_frontier_vertex_first + frontier.size(),
handle.get_stream());
if (edge_partition_e_mask) {
ret += edge_partition.compute_number_of_edges_with_mask(
(*edge_partition_e_mask).value_first(),
local_frontier_vertex_first,
local_frontier_vertex_first + frontier.size(),
handle.get_stream());
} else {
ret += edge_partition.compute_number_of_edges(local_frontier_vertex_first,
local_frontier_vertex_first + frontier.size(),
handle.get_stream());
}
}
}

Expand Down Expand Up @@ -323,8 +346,6 @@ transform_reduce_v_frontier_outgoing_e_by_dst(raft::handle_t const& handle,
using key_t = typename VertexFrontierBucketType::key_type;
using payload_t = typename ReduceOp::value_type;

CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented.");

if (do_expensive_check) {
// currently, nothing to do
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ struct e_op_t {
};

struct Prims_Usecase {
bool edge_masking{false};
bool check_correctness{true};
};

Expand Down Expand Up @@ -152,6 +153,13 @@ class Tests_MGTransformReduceVFrontierOutgoingEByDst

auto mg_graph_view = mg_graph.view();

std::optional<cugraph::edge_property_t<decltype(mg_graph_view), bool>> edge_mask{std::nullopt};
if (prims_usecase.edge_masking) {
edge_mask =
cugraph::test::generate<vertex_t, bool>::edge_property(*handle_, mg_graph_view, 2);
mg_graph_view.attach_edge_mask((*edge_mask).view());
}

// 2. run MG transform reduce

const int hash_bin_count = 5;
Expand Down Expand Up @@ -533,15 +541,16 @@ INSTANTIATE_TEST_SUITE_P(
file_test,
Tests_MGTransformReduceVFrontierOutgoingEByDst_File,
::testing::Combine(
::testing::Values(Prims_Usecase{true}),
::testing::Values(Prims_Usecase{false, true}, Prims_Usecase{true, true}),
::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"),
cugraph::test::File_Usecase("test/datasets/web-Google.mtx"),
cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"),
cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))));

INSTANTIATE_TEST_SUITE_P(rmat_small_test,
Tests_MGTransformReduceVFrontierOutgoingEByDst_Rmat,
::testing::Combine(::testing::Values(Prims_Usecase{true}),
::testing::Combine(::testing::Values(Prims_Usecase{false, true},
Prims_Usecase{true, true}),
::testing::Values(cugraph::test::Rmat_Usecase(
10, 16, 0.57, 0.19, 0.19, 0, false, false))));

Expand All @@ -553,7 +562,7 @@ INSTANTIATE_TEST_SUITE_P(
factor (to avoid running same benchmarks more than once) */
Tests_MGTransformReduceVFrontierOutgoingEByDst_Rmat,
::testing::Combine(
::testing::Values(Prims_Usecase{false}),
::testing::Values(Prims_Usecase{false, false}, Prims_Usecase{true, false}),
::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false))));

CUGRAPH_MG_TEST_PROGRAM_MAIN()

0 comments on commit 6a44b55

Please sign in to comment.