Skip to content

Commit

Permalink
reduce comm. sync
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Oct 30, 2024
1 parent aa13925 commit 2db13e9
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 238 deletions.
144 changes: 72 additions & 72 deletions cpp/src/prims/detail/extract_transform_v_frontier_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,10 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
// 3. communication over minor_comm

std::vector<size_t> local_frontier_sizes{};
std::conditional_t<GraphViewType::is_multi_gpu, std::vector<size_t>, std::byte /* dummy */>
max_tmp_buffer_sizes{};
std::conditional_t<GraphViewType::is_multi_gpu, std::vector<size_t>, std::byte /* dummy */>
tmp_buffer_size_per_loop_approximations{};
std::conditional_t<try_bitmap, std::vector<vertex_t>, std::byte /* dummy */>
local_frontier_range_firsts{};
std::conditional_t<try_bitmap, std::vector<vertex_t>, std::byte /* dummy */>
Expand All @@ -801,7 +805,39 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
auto const minor_comm_rank = minor_comm.get_rank();
auto const minor_comm_size = minor_comm.get_size();

size_t num_scalars = 1; // local_frontier_size
auto max_tmp_buffer_size =
static_cast<size_t>(static_cast<double>(handle.get_device_properties().totalGlobalMem) * 0.2);
size_t approx_tmp_buffer_size_per_loop{};
{
size_t key_size{0};
if constexpr (std::is_arithmetic_v<key_t>) {
key_size = sizeof(key_t);
} else {
key_size = cugraph::sum_thrust_tuple_element_sizes<key_t>();
}
size_t output_key_size{0};
if constexpr (!std::is_same_v<output_key_t, void>) {
if constexpr (std::is_arithmetic_v<output_key_t>) {
output_key_size = sizeof(output_key_t);
} else {
output_key_size = cugraph::sum_thrust_tuple_element_sizes<output_key_t>();
}
}
size_t output_value_size{0};
if constexpr (!std::is_same_v<output_value_t, void>) {
if constexpr (std::is_arithmetic_v<output_value_t>) {
output_value_size = sizeof(output_value_t);
} else {
output_value_size = cugraph::sum_thrust_tuple_element_sizes<output_value_t>();
}
}
approx_tmp_buffer_size_per_loop =
static_cast<size_t>(thrust::distance(frontier_key_first, frontier_key_last)) * key_size +
local_max_pushes * (output_key_size + output_value_size);
}

size_t num_scalars =
3; // local_frontier_size, max_tmp_buffer_size, approx_tmp_buffer_size_per_loop
if constexpr (try_bitmap) {
num_scalars += 2; // local_frontier_range_first, local_frontier_range_last
}
Expand All @@ -810,16 +846,23 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
handle.get_stream());
thrust::tabulate(
handle.get_thrust_policy(),
d_aggregate_tmps.begin() + minor_comm_rank * num_scalars,
d_aggregate_tmps.begin() + minor_comm_rank * num_scalars + (try_bitmap ? 3 : 1),
d_aggregate_tmps.begin() + num_scalars * minor_comm_rank,
d_aggregate_tmps.begin() + (num_scalars * minor_comm_rank + (try_bitmap ? 5 : 3)),
[frontier_key_first,
max_tmp_buffer_size,
approx_tmp_buffer_size_per_loop,
v_list_size = static_cast<size_t>(thrust::distance(frontier_key_first, frontier_key_last)),
vertex_partition_range_first =
graph_view.local_vertex_partition_range_first()] __device__(size_t i) {
if (i == 0) {
return v_list_size;
} else if (i == 1) {
return max_tmp_buffer_size;
} else if (i == 2) {
return approx_tmp_buffer_size_per_loop;
}
if constexpr (try_bitmap) {
if (i == 0) {
return v_list_size;
} else if (i == 1) {
if (i == 3) {
vertex_t first{};
if (v_list_size > 0) {
first = *frontier_key_first;
Expand All @@ -828,8 +871,8 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
}
assert(static_cast<vertex_t>(static_cast<size_t>(first)) == first);
return static_cast<size_t>(first);
} else {
assert(i == 2);
} else if (i == 4) {
assert(i == 4);
vertex_t last{};
if (v_list_size > 0) {
last = *(frontier_key_first + (v_list_size - 1)) + 1;
Expand All @@ -839,14 +882,13 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
assert(static_cast<vertex_t>(static_cast<size_t>(last)) == last);
return static_cast<size_t>(last);
}
} else {
assert(i == 0);
return v_list_size;
}
assert(false);
return size_t{0};
});
if (key_segment_offsets) {
raft::update_device(
d_aggregate_tmps.data() + (minor_comm_rank * num_scalars + (try_bitmap ? 3 : 1)),
d_aggregate_tmps.data() + (minor_comm_rank * num_scalars + (try_bitmap ? 5 : 3)),
(*key_segment_offsets).data(),
(*key_segment_offsets).size(),
handle.get_stream());
Expand All @@ -866,7 +908,9 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
d_aggregate_tmps.size(),
handle.get_stream());
handle.sync_stream();
local_frontier_sizes = std::vector<size_t>(minor_comm_size);
local_frontier_sizes = std::vector<size_t>(minor_comm_size);
max_tmp_buffer_sizes = std::vector<size_t>(minor_comm_size);
tmp_buffer_size_per_loop_approximations = std::vector<size_t>(minor_comm_size);
if constexpr (try_bitmap) {
local_frontier_range_firsts = std::vector<vertex_t>(minor_comm_size);
local_frontier_range_lasts = std::vector<vertex_t>(minor_comm_size);
Expand All @@ -876,18 +920,20 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
(*key_segment_offset_vectors).reserve(minor_comm_size);
}
for (int i = 0; i < minor_comm_size; ++i) {
local_frontier_sizes[i] = h_aggregate_tmps[i * num_scalars];
local_frontier_sizes[i] = h_aggregate_tmps[i * num_scalars];
max_tmp_buffer_sizes[i] = h_aggregate_tmps[i * num_scalars + 1];
tmp_buffer_size_per_loop_approximations[i] = h_aggregate_tmps[i * num_scalars + 2];
if constexpr (try_bitmap) {
local_frontier_range_firsts[i] =
static_cast<vertex_t>(h_aggregate_tmps[i * num_scalars + 1]);
static_cast<vertex_t>(h_aggregate_tmps[i * num_scalars + 3]);
local_frontier_range_lasts[i] =
static_cast<vertex_t>(h_aggregate_tmps[i * num_scalars + 2]);
static_cast<vertex_t>(h_aggregate_tmps[i * num_scalars + 4]);
}
if (key_segment_offsets) {
(*key_segment_offset_vectors)
.emplace_back(h_aggregate_tmps.begin() + (i * num_scalars + (try_bitmap ? 3 : 1)),
.emplace_back(h_aggregate_tmps.begin() + (i * num_scalars + (try_bitmap ? 5 : 3)),
h_aggregate_tmps.begin() +
(i * num_scalars + (try_bitmap ? 3 : 1) + (*key_segment_offsets).size()));
(i * num_scalars + (try_bitmap ? 5 : 3) + (*key_segment_offsets).size()));
}
}
} else {
Expand Down Expand Up @@ -971,63 +1017,17 @@ extract_transform_v_frontier_e(raft::handle_t const& handle,
std::optional<std::vector<size_t>> stream_pool_indices{std::nullopt};
if constexpr (GraphViewType::is_multi_gpu) {
auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
auto const minor_comm_rank = minor_comm.get_rank();
auto partition_idx = static_cast<size_t>(minor_comm_rank);
auto const minor_comm_size = minor_comm.get_size();

if (graph_view.local_edge_partition_segment_offsets(partition_idx) &&
if (graph_view.local_vertex_partition_segment_offsets() &&
(handle.get_stream_pool_size() >= max_segments)) {
auto& comm = handle.get_comms();
auto const comm_size = comm.get_size();

auto max_tmp_buffer_size = static_cast<size_t>(
static_cast<double>(handle.get_device_properties().totalGlobalMem) * 0.2);

size_t aggregate_major_range_size{};
size_t aggregate_max_pushes{}; // this is approximate as we only consider local edges for
// [frontier_key_first, frontier_key_last), note that neighbor
// lists are partitioned if minor_comm_size > 1
{
auto tmp = host_scalar_allreduce(
comm,
thrust::make_tuple(
static_cast<size_t>(thrust::distance(frontier_key_first, frontier_key_last)),
local_max_pushes),
raft::comms::op_t::SUM,
handle.get_stream());
aggregate_major_range_size = thrust::get<0>(tmp);
aggregate_max_pushes = thrust::get<1>(tmp);
}

size_t key_size{0};
if constexpr (std::is_arithmetic_v<key_t>) {
if (v_compressible) {
key_size = sizeof(uint32_t);
} else {
key_size = sizeof(key_t);
}
} else {
key_size = cugraph::sum_thrust_tuple_element_sizes<key_t>();
}
size_t output_key_size{0};
if constexpr (!std::is_same_v<output_key_t, void>) {
if constexpr (std::is_arithmetic_v<output_key_t>) {
output_key_size = sizeof(output_key_t);
} else {
output_key_size = cugraph::sum_thrust_tuple_element_sizes<output_key_t>();
}
}
size_t output_value_size{0};
if constexpr (!std::is_same_v<output_value_t, void>) {
if constexpr (std::is_arithmetic_v<output_value_t>) {
output_value_size = sizeof(output_value_t);
} else {
output_value_size = cugraph::sum_thrust_tuple_element_sizes<output_value_t>();
}
}
auto max_tmp_buffer_size =
std::reduce(max_tmp_buffer_sizes.begin(), max_tmp_buffer_sizes.end()) /
static_cast<size_t>(minor_comm_size);
auto approx_tmp_buffer_size_per_loop =
(aggregate_major_range_size / comm_size) * key_size +
(aggregate_max_pushes / comm_size) * (output_key_size + output_value_size);

std::reduce(tmp_buffer_size_per_loop_approximations.begin(),
tmp_buffer_size_per_loop_approximations.end()) /
static_cast<size_t>(minor_comm_size);
stream_pool_indices = init_stream_pool_indices(max_tmp_buffer_size,
approx_tmp_buffer_size_per_loop,
graph_view.number_of_local_edge_partitions(),
Expand Down
Loading

0 comments on commit 2db13e9

Please sign in to comment.