diff --git a/cpp/src/prims/fill_edge_src_dst_property.cuh b/cpp/src/prims/fill_edge_src_dst_property.cuh index 7155ce23dbd..9f561b57ff7 100644 --- a/cpp/src/prims/fill_edge_src_dst_property.cuh +++ b/cpp/src/prims/fill_edge_src_dst_property.cuh @@ -15,6 +15,8 @@ */ #pragma once +#include "prims/vertex_frontier.cuh" + #include #include #include @@ -153,12 +155,12 @@ void fill_edge_major_property(raft::handle_t const& handle, auto const minor_comm_rank = minor_comm.get_rank(); auto const minor_comm_size = minor_comm.get_size(); - auto rx_counts = host_scalar_allgather( + auto local_v_list_sizes = host_scalar_allgather( minor_comm, static_cast(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)), handle.get_stream()); - auto max_rx_size = - std::reduce(rx_counts.begin(), rx_counts.end(), size_t{0}, [](auto lhs, auto rhs) { + auto max_rx_size = std::reduce( + local_v_list_sizes.begin(), local_v_list_sizes.end(), size_t{0}, [](auto lhs, auto rhs) { return std::max(lhs, rhs); }); rmm::device_uvector rx_vertices(max_rx_size, handle.get_stream()); @@ -172,7 +174,7 @@ void fill_edge_major_property(raft::handle_t const& handle, device_bcast(minor_comm, sorted_unique_vertex_first, rx_vertices.begin(), - rx_counts[i], + local_v_list_sizes[i], i, handle.get_stream()); @@ -180,7 +182,7 @@ void fill_edge_major_property(raft::handle_t const& handle, thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(rx_counts[i]), + thrust::make_counting_iterator(local_v_list_sizes[i]), [rx_vertex_first = rx_vertices.begin(), input, edge_partition_key_first = ((*edge_partition_keys)[i]).begin(), @@ -203,7 +205,7 @@ void fill_edge_major_property(raft::handle_t const& handle, thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(vertex_t{0}), - thrust::make_counting_iterator(static_cast(rx_counts[i])), + thrust::make_counting_iterator(static_cast(local_v_list_sizes[i])), [edge_partition, rx_vertex_first = rx_vertices.begin(), input, @@ -223,7 +225,7 @@ void fill_edge_major_property(raft::handle_t const& handle, // directly scatters from the internal buffer) thrust::scatter(handle.get_thrust_policy(), val_first, - val_first + rx_counts[i], + val_first + local_v_list_sizes[i], map_first, edge_partition_value_firsts[i]); } @@ -312,15 +314,41 @@ void fill_edge_minor_property(raft::handle_t const& handle, auto const major_comm_rank = major_comm.get_rank(); auto const major_comm_size = major_comm.get_size(); - auto rx_counts = host_scalar_allgather( - major_comm, - static_cast(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)), - handle.get_stream()); - auto max_rx_size = - std::reduce(rx_counts.begin(), rx_counts.end(), size_t{0}, [](auto lhs, auto rhs) { - return std::max(lhs, rhs); + auto v_list_size = + static_cast(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)); + std::array v_list_range = {vertex_t{0}, vertex_t{0}}; + if (v_list_size > 0) { + rmm::device_uvector tmps(2, handle.get_stream()); + thrust::tabulate(handle.get_thrust_policy(), + tmps.begin(), + tmps.end(), + [sorted_unique_vertex_first, v_list_size] __device__(size_t i) { + return (i == 0) ? *sorted_unique_vertex_first + : (*(sorted_unique_vertex_first + (v_list_size - 1)) + 1); + }); + raft::update_host(v_list_range.data(), tmps.data(), 2, handle.get_stream()); + handle.sync_stream(); + } + + auto v_list_bitmap = compute_vertex_list_bitmap_info(sorted_unique_vertex_first, + sorted_unique_vertex_last, + v_list_range[0], + v_list_range[1], + handle.get_stream()); + + std::vector use_bitmap_flags(major_comm_size, false); + { + auto tmp_flags = host_scalar_allgather( + major_comm, v_list_bitmap ? uint8_t{1} : uint8_t{0}, handle.get_stream()); + std::transform(tmp_flags.begin(), tmp_flags.end(), use_bitmap_flags.begin(), [](auto flag) { + return flag == uint8_t{1}; }); - rmm::device_uvector rx_vertices(max_rx_size, handle.get_stream()); + } + auto local_v_list_sizes = host_scalar_allgather(major_comm, v_list_size, handle.get_stream()); + auto local_v_list_range_firsts = + host_scalar_allgather(major_comm, v_list_range[0], handle.get_stream()); + auto local_v_list_range_lasts = + host_scalar_allgather(major_comm, v_list_range[1], handle.get_stream()); std::optional> key_offsets{}; if constexpr (GraphViewType::is_storage_transposed) { @@ -334,21 +362,33 @@ void fill_edge_minor_property(raft::handle_t const& handle, graph_view.local_edge_partition_view(size_t{0})); auto edge_partition_keys = edge_minor_property_output.keys(); for (int i = 0; i < major_comm_size; ++i) { - // FIXME: we can optionally use bitmap for this broadcast + rmm::device_uvector rx_vertices(local_v_list_sizes[i], handle.get_stream()); // FIXME: these broadcast operations can be placed between ncclGroupStart() and // ncclGroupEnd() - device_bcast(major_comm, - sorted_unique_vertex_first, - rx_vertices.begin(), - rx_counts[i], - i, - handle.get_stream()); + std::variant, decltype(sorted_unique_vertex_first)> + v_list{}; + if (use_bitmap_flags[i]) { + v_list = + (i == major_comm_rank) + ? raft::device_span((*v_list_bitmap).data(), (*v_list_bitmap).size()) + : raft::device_span(static_cast(nullptr), size_t{0}); + } else { + v_list = sorted_unique_vertex_first; + } + device_bcast_vertex_list(major_comm, + v_list, + rx_vertices.begin(), + local_v_list_range_firsts[i], + local_v_list_range_lasts[i], + local_v_list_sizes[i], + i, + handle.get_stream()); if (edge_partition_keys) { thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(rx_counts[i]), + thrust::make_counting_iterator(local_v_list_sizes[i]), [rx_vertex_first = rx_vertices.begin(), input, subrange_key_first = (*edge_partition_keys).begin() + (*key_offsets)[i], @@ -370,18 +410,18 @@ void fill_edge_minor_property(raft::handle_t const& handle, }); } else { if constexpr (contains_packed_bool_element) { - thrust::for_each(handle.get_thrust_policy(), - thrust::make_counting_iterator(vertex_t{0}), - thrust::make_counting_iterator(static_cast(rx_counts[i])), - [edge_partition, - rx_vertex_first = rx_vertices.begin(), - input, - output_value_first = edge_partition_value_first] __device__(auto i) { - auto rx_vertex = *(rx_vertex_first + i); - auto minor_offset = - edge_partition.minor_offset_from_minor_nocheck(rx_vertex); - fill_scalar_or_thrust_tuple(output_value_first, minor_offset, input); - }); + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(vertex_t{0}), + thrust::make_counting_iterator(static_cast(local_v_list_sizes[i])), + [edge_partition, + rx_vertex_first = rx_vertices.begin(), + input, + output_value_first = edge_partition_value_first] __device__(auto i) { + auto rx_vertex = *(rx_vertex_first + i); + auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(rx_vertex); + fill_scalar_or_thrust_tuple(output_value_first, minor_offset, input); + }); } else { auto map_first = thrust::make_transform_iterator( rx_vertices.begin(), @@ -393,7 +433,7 @@ void fill_edge_minor_property(raft::handle_t const& handle, // directly scatters from the internal buffer) thrust::scatter(handle.get_thrust_policy(), val_first, - val_first + rx_counts[i], + val_first + local_v_list_sizes[i], map_first, edge_partition_value_first); } diff --git a/cpp/src/prims/update_edge_src_dst_property.cuh b/cpp/src/prims/update_edge_src_dst_property.cuh index 392e12420ad..2408dcb3d68 100644 --- a/cpp/src/prims/update_edge_src_dst_property.cuh +++ b/cpp/src/prims/update_edge_src_dst_property.cuh @@ -16,6 +16,7 @@ #pragma once #include "detail/graph_partition_utils.cuh" +#include "prims/vertex_frontier.cuh" #include #include @@ -288,12 +289,12 @@ void update_edge_major_property(raft::handle_t const& handle, auto const minor_comm_rank = minor_comm.get_rank(); auto const minor_comm_size = minor_comm.get_size(); - auto rx_counts = host_scalar_allgather( + auto local_v_list_sizes = host_scalar_allgather( minor_comm, static_cast(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)), handle.get_stream()); - auto max_rx_size = - std::reduce(rx_counts.begin(), rx_counts.end(), size_t{0}, [](auto lhs, auto rhs) { + auto max_rx_size = std::reduce( + local_v_list_sizes.begin(), local_v_list_sizes.end(), size_t{0}, [](auto lhs, auto rhs) { return std::max(lhs, rhs); }); rmm::device_uvector rx_vertices(max_rx_size, handle.get_stream()); @@ -352,13 +353,14 @@ void update_edge_major_property(raft::handle_t const& handle, device_bcast(minor_comm, sorted_unique_vertex_first, rx_vertices.begin(), - rx_counts[i], + local_v_list_sizes[i], i, handle.get_stream()); device_bcast(minor_comm, rx_value_first, rx_value_first, - contains_packed_bool_element ? packed_bool_size(rx_counts[i]) : rx_counts[i], + contains_packed_bool_element ? packed_bool_size(local_v_list_sizes[i]) + : local_v_list_sizes[i], i, handle.get_stream()); @@ -366,7 +368,7 @@ void update_edge_major_property(raft::handle_t const& handle, thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(rx_counts[i]), + thrust::make_counting_iterator(local_v_list_sizes[i]), [rx_vertex_first = rx_vertices.begin(), rx_value_first, edge_partition_key_first = ((*edge_partition_keys)[i]).begin(), @@ -392,7 +394,7 @@ void update_edge_major_property(raft::handle_t const& handle, thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(vertex_t{0}), - thrust::make_counting_iterator(static_cast(rx_counts[i])), + thrust::make_counting_iterator(static_cast(local_v_list_sizes[i])), [edge_partition, rx_vertex_first = rx_vertices.begin(), rx_value_first, @@ -413,7 +415,7 @@ void update_edge_major_property(raft::handle_t const& handle, // directly scatters from the internal buffer) thrust::scatter(handle.get_thrust_policy(), rx_value_first, - rx_value_first + rx_counts[i], + rx_value_first + local_v_list_sizes[i], map_first, edge_partition_value_firsts[i]); } @@ -463,13 +465,11 @@ void update_edge_minor_property(raft::handle_t const& handle, auto edge_partition_value_first = edge_minor_property_output.value_first(); if constexpr (GraphViewType::is_multi_gpu) { - using vertex_t = typename GraphViewType::vertex_type; - using bcast_buffer_type = - decltype(allocate_dataframe_buffer< - std::conditional_t>( - size_t{0}, handle.get_stream())); + using vertex_t = typename GraphViewType::vertex_type; + using bcast_buffer_type = dataframe_buffer_type_t< + std::conditional_t>; auto& comm = handle.get_comms(); auto const comm_rank = comm.get_rank(); @@ -540,15 +540,17 @@ void update_edge_minor_property(raft::handle_t const& handle, *(graph_view.local_sorted_unique_edge_dst_vertex_partition_offsets()); } } else { - std::vector rx_counts(major_comm_size, size_t{0}); + std::vector local_v_list_sizes(major_comm_size, size_t{0}); for (int i = 0; i < major_comm_size; ++i) { auto minor_range_vertex_partition_id = compute_local_edge_partition_minor_range_vertex_partition_id_t{ major_comm_size, minor_comm_size, major_comm_rank, minor_comm_rank}(i); - rx_counts[i] = graph_view.vertex_partition_range_size(minor_range_vertex_partition_id); + local_v_list_sizes[i] = + graph_view.vertex_partition_range_size(minor_range_vertex_partition_id); } std::vector rx_displacements(major_comm_size, size_t{0}); - std::exclusive_scan(rx_counts.begin(), rx_counts.end(), rx_displacements.begin(), size_t{0}); + std::exclusive_scan( + local_v_list_sizes.begin(), local_v_list_sizes.end(), rx_displacements.begin(), size_t{0}); key_offsets_or_rx_displacements = std::move(rx_displacements); } @@ -714,22 +716,42 @@ void update_edge_minor_property(raft::handle_t const& handle, auto const major_comm_rank = major_comm.get_rank(); auto const major_comm_size = major_comm.get_size(); - auto rx_counts = host_scalar_allgather( - major_comm, - static_cast(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)), - handle.get_stream()); - auto max_rx_size = - std::reduce(rx_counts.begin(), rx_counts.end(), size_t{0}, [](auto lhs, auto rhs) { - return std::max(lhs, rhs); + auto v_list_size = + static_cast(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)); + std::array v_list_range = {vertex_t{0}, vertex_t{0}}; + if (v_list_size > 0) { + rmm::device_uvector tmps(2, handle.get_stream()); + thrust::tabulate(handle.get_thrust_policy(), + tmps.begin(), + tmps.end(), + [sorted_unique_vertex_first, v_list_size] __device__(size_t i) { + return (i == 0) ? *sorted_unique_vertex_first + : (*(sorted_unique_vertex_first + (v_list_size - 1)) + 1); + }); + raft::update_host(v_list_range.data(), tmps.data(), 2, handle.get_stream()); + handle.sync_stream(); + } + + auto v_list_bitmap = compute_vertex_list_bitmap_info(sorted_unique_vertex_first, + sorted_unique_vertex_last, + v_list_range[0], + v_list_range[1], + handle.get_stream()); + + std::vector use_bitmap_flags(major_comm_size, false); + { + auto tmp_flags = host_scalar_allgather( + major_comm, v_list_bitmap ? uint8_t{1} : uint8_t{0}, handle.get_stream()); + std::transform(tmp_flags.begin(), tmp_flags.end(), use_bitmap_flags.begin(), [](auto flag) { + return flag == uint8_t{1}; }); - rmm::device_uvector rx_vertices(max_rx_size, handle.get_stream()); - auto rx_tmp_buffer = allocate_dataframe_buffer< - std::conditional_t>( - contains_packed_bool_element ? packed_bool_size(max_rx_size) : max_rx_size, - handle.get_stream()); - auto rx_value_first = get_dataframe_buffer_begin(rx_tmp_buffer); + } + + auto local_v_list_sizes = host_scalar_allgather(major_comm, v_list_size, handle.get_stream()); + auto local_v_list_range_firsts = + host_scalar_allgather(major_comm, v_list_range[0], handle.get_stream()); + auto local_v_list_range_lasts = + host_scalar_allgather(major_comm, v_list_range[1], handle.get_stream()); std::optional> key_offsets{}; if constexpr (GraphViewType::is_storage_transposed) { @@ -743,6 +765,16 @@ void update_edge_minor_property(raft::handle_t const& handle, graph_view.local_edge_partition_view(size_t{0})); auto edge_partition_keys = edge_minor_property_output.keys(); for (int i = 0; i < major_comm_size; ++i) { + rmm::device_uvector rx_vertices(local_v_list_sizes[i], handle.get_stream()); + auto rx_tmp_buffer = allocate_dataframe_buffer< + std::conditional_t>( + contains_packed_bool_element ? packed_bool_size(local_v_list_sizes[i]) + : local_v_list_sizes[i], + handle.get_stream()); + auto rx_value_first = get_dataframe_buffer_begin(rx_tmp_buffer); + if (i == major_comm_rank) { auto vertex_partition = vertex_partition_device_view_t( @@ -781,16 +813,29 @@ void update_edge_minor_property(raft::handle_t const& handle, // FIXME: these broadcast operations can be placed between ncclGroupStart() and // ncclGroupEnd() - device_bcast(major_comm, - sorted_unique_vertex_first, - rx_vertices.begin(), - rx_counts[i], - i, - handle.get_stream()); + std::variant, decltype(sorted_unique_vertex_first)> + v_list{}; + if (use_bitmap_flags[i]) { + v_list = + (i == major_comm_rank) + ? raft::device_span((*v_list_bitmap).data(), (*v_list_bitmap).size()) + : raft::device_span(static_cast(nullptr), size_t{0}); + } else { + v_list = sorted_unique_vertex_first; + } + device_bcast_vertex_list(major_comm, + v_list, + rx_vertices.begin(), + local_v_list_range_firsts[i], + local_v_list_range_lasts[i], + local_v_list_sizes[i], + i, + handle.get_stream()); device_bcast(major_comm, rx_value_first, rx_value_first, - contains_packed_bool_element ? packed_bool_size(rx_counts[i]) : rx_counts[i], + contains_packed_bool_element ? packed_bool_size(local_v_list_sizes[i]) + : local_v_list_sizes[i], i, handle.get_stream()); @@ -798,7 +843,7 @@ void update_edge_minor_property(raft::handle_t const& handle, thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(rx_counts[i]), + thrust::make_counting_iterator(local_v_list_sizes[i]), [rx_vertex_first = rx_vertices.begin(), rx_value_first, subrange_key_first = (*edge_partition_keys).begin() + (*key_offsets)[i], @@ -826,7 +871,7 @@ void update_edge_minor_property(raft::handle_t const& handle, thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(vertex_t{0}), - thrust::make_counting_iterator(static_cast(rx_counts[i])), + thrust::make_counting_iterator(static_cast(local_v_list_sizes[i])), [edge_partition, rx_vertex_first = rx_vertices.begin(), rx_value_first, @@ -847,7 +892,7 @@ void update_edge_minor_property(raft::handle_t const& handle, // directly scatters from the internal buffer) thrust::scatter(handle.get_thrust_policy(), rx_value_first, - rx_value_first + rx_counts[i], + rx_value_first + local_v_list_sizes[i], map_first, edge_partition_value_first); }