Skip to content

Commit

Permalink
update fill|update_edge_minor_property to optionally use bitmap to broadcast vertex list
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Aug 24, 2024
1 parent 3a950a5 commit d27a5e3
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 79 deletions.
112 changes: 76 additions & 36 deletions cpp/src/prims/fill_edge_src_dst_property.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
#pragma once

#include "prims/vertex_frontier.cuh"

#include <cugraph/edge_partition_device_view.cuh>
#include <cugraph/edge_partition_endpoint_property_device_view.cuh>
#include <cugraph/edge_src_dst_property.hpp>
Expand Down Expand Up @@ -153,12 +155,12 @@ void fill_edge_major_property(raft::handle_t const& handle,
auto const minor_comm_rank = minor_comm.get_rank();
auto const minor_comm_size = minor_comm.get_size();

auto rx_counts = host_scalar_allgather(
auto local_v_list_sizes = host_scalar_allgather(
minor_comm,
static_cast<size_t>(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)),
handle.get_stream());
auto max_rx_size =
std::reduce(rx_counts.begin(), rx_counts.end(), size_t{0}, [](auto lhs, auto rhs) {
auto max_rx_size = std::reduce(
local_v_list_sizes.begin(), local_v_list_sizes.end(), size_t{0}, [](auto lhs, auto rhs) {
return std::max(lhs, rhs);
});
rmm::device_uvector<vertex_t> rx_vertices(max_rx_size, handle.get_stream());
Expand All @@ -172,15 +174,15 @@ void fill_edge_major_property(raft::handle_t const& handle,
device_bcast(minor_comm,
sorted_unique_vertex_first,
rx_vertices.begin(),
rx_counts[i],
local_v_list_sizes[i],
i,
handle.get_stream());

if (edge_partition_keys) {
thrust::for_each(
handle.get_thrust_policy(),
thrust::make_counting_iterator(size_t{0}),
thrust::make_counting_iterator(rx_counts[i]),
thrust::make_counting_iterator(local_v_list_sizes[i]),
[rx_vertex_first = rx_vertices.begin(),
input,
edge_partition_key_first = ((*edge_partition_keys)[i]).begin(),
Expand All @@ -203,7 +205,7 @@ void fill_edge_major_property(raft::handle_t const& handle,
thrust::for_each(
handle.get_thrust_policy(),
thrust::make_counting_iterator(vertex_t{0}),
thrust::make_counting_iterator(static_cast<vertex_t>(rx_counts[i])),
thrust::make_counting_iterator(static_cast<vertex_t>(local_v_list_sizes[i])),
[edge_partition,
rx_vertex_first = rx_vertices.begin(),
input,
Expand All @@ -223,7 +225,7 @@ void fill_edge_major_property(raft::handle_t const& handle,
// directly scatters from the internal buffer)
thrust::scatter(handle.get_thrust_policy(),
val_first,
val_first + rx_counts[i],
val_first + local_v_list_sizes[i],
map_first,
edge_partition_value_firsts[i]);
}
Expand Down Expand Up @@ -312,15 +314,41 @@ void fill_edge_minor_property(raft::handle_t const& handle,
auto const major_comm_rank = major_comm.get_rank();
auto const major_comm_size = major_comm.get_size();

auto rx_counts = host_scalar_allgather(
major_comm,
static_cast<size_t>(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)),
handle.get_stream());
auto max_rx_size =
std::reduce(rx_counts.begin(), rx_counts.end(), size_t{0}, [](auto lhs, auto rhs) {
return std::max(lhs, rhs);
auto v_list_size =
static_cast<size_t>(thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last));
std::array<vertex_t, 2> v_list_range = {vertex_t{0}, vertex_t{0}};
if (v_list_size > 0) {
rmm::device_uvector<vertex_t> tmps(2, handle.get_stream());
thrust::tabulate(handle.get_thrust_policy(),
tmps.begin(),
tmps.end(),
[sorted_unique_vertex_first, v_list_size] __device__(size_t i) {
return (i == 0) ? *sorted_unique_vertex_first
: (*(sorted_unique_vertex_first + (v_list_size - 1)) + 1);
});
raft::update_host(v_list_range.data(), tmps.data(), 2, handle.get_stream());
handle.sync_stream();
}

auto v_list_bitmap = compute_vertex_list_bitmap_info(sorted_unique_vertex_first,
sorted_unique_vertex_last,
v_list_range[0],
v_list_range[1],
handle.get_stream());

std::vector<bool> use_bitmap_flags(major_comm_size, false);
{
auto tmp_flags = host_scalar_allgather(
major_comm, v_list_bitmap ? uint8_t{1} : uint8_t{0}, handle.get_stream());
std::transform(tmp_flags.begin(), tmp_flags.end(), use_bitmap_flags.begin(), [](auto flag) {
return flag == uint8_t{1};
});
rmm::device_uvector<vertex_t> rx_vertices(max_rx_size, handle.get_stream());
}
auto local_v_list_sizes = host_scalar_allgather(major_comm, v_list_size, handle.get_stream());
auto local_v_list_range_firsts =
host_scalar_allgather(major_comm, v_list_range[0], handle.get_stream());
auto local_v_list_range_lasts =
host_scalar_allgather(major_comm, v_list_range[1], handle.get_stream());

std::optional<raft::host_span<vertex_t const>> key_offsets{};
if constexpr (GraphViewType::is_storage_transposed) {
Expand All @@ -334,21 +362,33 @@ void fill_edge_minor_property(raft::handle_t const& handle,
graph_view.local_edge_partition_view(size_t{0}));
auto edge_partition_keys = edge_minor_property_output.keys();
for (int i = 0; i < major_comm_size; ++i) {
// FIXME: we can optionally use bitmap for this broadcast
rmm::device_uvector<vertex_t> rx_vertices(local_v_list_sizes[i], handle.get_stream());
// FIXME: these broadcast operations can be placed between ncclGroupStart() and
// ncclGroupEnd()
device_bcast(major_comm,
sorted_unique_vertex_first,
rx_vertices.begin(),
rx_counts[i],
i,
handle.get_stream());
std::variant<raft::device_span<uint32_t const>, decltype(sorted_unique_vertex_first)>
v_list{};
if (use_bitmap_flags[i]) {
v_list =
(i == major_comm_rank)
? raft::device_span<uint32_t const>((*v_list_bitmap).data(), (*v_list_bitmap).size())
: raft::device_span<uint32_t const>(static_cast<uint32_t const*>(nullptr), size_t{0});
} else {
v_list = sorted_unique_vertex_first;
}
device_bcast_vertex_list(major_comm,
v_list,
rx_vertices.begin(),
local_v_list_range_firsts[i],
local_v_list_range_lasts[i],
local_v_list_sizes[i],
i,
handle.get_stream());

if (edge_partition_keys) {
thrust::for_each(
handle.get_thrust_policy(),
thrust::make_counting_iterator(size_t{0}),
thrust::make_counting_iterator(rx_counts[i]),
thrust::make_counting_iterator(local_v_list_sizes[i]),
[rx_vertex_first = rx_vertices.begin(),
input,
subrange_key_first = (*edge_partition_keys).begin() + (*key_offsets)[i],
Expand All @@ -370,18 +410,18 @@ void fill_edge_minor_property(raft::handle_t const& handle,
});
} else {
if constexpr (contains_packed_bool_element) {
thrust::for_each(handle.get_thrust_policy(),
thrust::make_counting_iterator(vertex_t{0}),
thrust::make_counting_iterator(static_cast<vertex_t>(rx_counts[i])),
[edge_partition,
rx_vertex_first = rx_vertices.begin(),
input,
output_value_first = edge_partition_value_first] __device__(auto i) {
auto rx_vertex = *(rx_vertex_first + i);
auto minor_offset =
edge_partition.minor_offset_from_minor_nocheck(rx_vertex);
fill_scalar_or_thrust_tuple(output_value_first, minor_offset, input);
});
thrust::for_each(
handle.get_thrust_policy(),
thrust::make_counting_iterator(vertex_t{0}),
thrust::make_counting_iterator(static_cast<vertex_t>(local_v_list_sizes[i])),
[edge_partition,
rx_vertex_first = rx_vertices.begin(),
input,
output_value_first = edge_partition_value_first] __device__(auto i) {
auto rx_vertex = *(rx_vertex_first + i);
auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(rx_vertex);
fill_scalar_or_thrust_tuple(output_value_first, minor_offset, input);
});
} else {
auto map_first = thrust::make_transform_iterator(
rx_vertices.begin(),
Expand All @@ -393,7 +433,7 @@ void fill_edge_minor_property(raft::handle_t const& handle,
// directly scatters from the internal buffer)
thrust::scatter(handle.get_thrust_policy(),
val_first,
val_first + rx_counts[i],
val_first + local_v_list_sizes[i],
map_first,
edge_partition_value_first);
}
Expand Down
Loading

0 comments on commit d27a5e3

Please sign in to comment.