diff --git a/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh b/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh index 01b9ceec176..1557b378bd9 100644 --- a/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh +++ b/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh @@ -613,24 +613,11 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, } } - std::vector local_frontier_sizes{}; - if constexpr (GraphViewType::is_multi_gpu) { - auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - local_frontier_sizes = host_scalar_allgather( - minor_comm, - static_cast(thrust::distance(frontier_key_first, frontier_key_last)), - handle.get_stream()); - } else { - local_frontier_sizes = std::vector{static_cast( - static_cast(thrust::distance(frontier_key_first, frontier_key_last)))}; - } - // update frontier bitmap (used to reduce broadcast bandwidth size) std:: conditional_t>, std::byte /* dummy */> frontier_bitmap{}; - std::conditional_t, std::byte /* dummy */> use_bitmap_flags{}; if constexpr (try_bitmap) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); auto const minor_comm_size = minor_comm.get_size(); @@ -648,12 +635,6 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, graph_view.local_vertex_partition_range_first() + bool_size, handle.get_stream()); } - auto tmp_flags = host_scalar_allgather( - minor_comm, frontier_bitmap ? uint8_t{1} : uint8_t{0}, handle.get_stream()); - use_bitmap_flags.resize(tmp_flags.size()); - std::transform(tmp_flags.begin(), tmp_flags.end(), use_bitmap_flags.begin(), [](auto flag) { - return flag == uint8_t{1}; - }); } // compute max_pushes @@ -679,6 +660,29 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, frontier_major_first, frontier_major_last, handle.get_stream()); } + // communication over minor_comm + + std::vector local_frontier_sizes{}; + std::conditional_t, std::byte /* dummy */> use_bitmap_flags{}; + if constexpr (GraphViewType::is_multi_gpu) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + local_frontier_sizes = host_scalar_allgather( + minor_comm, + static_cast(thrust::distance(frontier_key_first, frontier_key_last)), + handle.get_stream()); + if constexpr (try_bitmap) { + auto tmp_flags = host_scalar_allgather( + minor_comm, frontier_bitmap ? uint8_t{1} : uint8_t{0}, handle.get_stream()); + use_bitmap_flags.resize(tmp_flags.size()); + std::transform(tmp_flags.begin(), tmp_flags.end(), use_bitmap_flags.begin(), [](auto flag) { + return flag == uint8_t{1}; + }); + } + } else { + local_frontier_sizes = std::vector{static_cast( + static_cast(thrust::distance(frontier_key_first, frontier_key_last)))}; + } + // set-up stream ppol std::optional> stream_pool_indices{std::nullopt};