diff --git a/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh b/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh index c63557f157..8ef7da3f02 100644 --- a/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh +++ b/cpp/src/prims/detail/extract_transform_v_frontier_e.cuh @@ -791,6 +791,10 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, // 3. communication over minor_comm std::vector local_frontier_sizes{}; + std::conditional_t, std::byte /* dummy */> + max_tmp_buffer_sizes{}; + std::conditional_t, std::byte /* dummy */> + tmp_buffer_size_per_loop_approximations{}; std::conditional_t, std::byte /* dummy */> local_frontier_range_firsts{}; std::conditional_t, std::byte /* dummy */> @@ -801,7 +805,39 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, auto const minor_comm_rank = minor_comm.get_rank(); auto const minor_comm_size = minor_comm.get_size(); - size_t num_scalars = 1; // local_frontier_size + auto max_tmp_buffer_size = + static_cast(static_cast(handle.get_device_properties().totalGlobalMem) * 0.2); + size_t approx_tmp_buffer_size_per_loop{}; + { + size_t key_size{0}; + if constexpr (std::is_arithmetic_v) { + key_size = sizeof(key_t); + } else { + key_size = cugraph::sum_thrust_tuple_element_sizes(); + } + size_t output_key_size{0}; + if constexpr (!std::is_same_v) { + if constexpr (std::is_arithmetic_v) { + output_key_size = sizeof(output_key_t); + } else { + output_key_size = cugraph::sum_thrust_tuple_element_sizes(); + } + } + size_t output_value_size{0}; + if constexpr (!std::is_same_v) { + if constexpr (std::is_arithmetic_v) { + output_value_size = sizeof(output_value_t); + } else { + output_value_size = cugraph::sum_thrust_tuple_element_sizes(); + } + } + approx_tmp_buffer_size_per_loop = + static_cast(thrust::distance(frontier_key_first, frontier_key_last)) * key_size + + local_max_pushes * (output_key_size + output_value_size); + } + + size_t num_scalars = + 3; // local_frontier_size, max_tmp_buffer_size, approx_tmp_buffer_size_per_loop if constexpr (try_bitmap) { num_scalars += 2; // local_frontier_range_first, local_frontier_range_last } @@ -810,16 +846,23 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, handle.get_stream()); thrust::tabulate( handle.get_thrust_policy(), - d_aggregate_tmps.begin() + minor_comm_rank * num_scalars, - d_aggregate_tmps.begin() + minor_comm_rank * num_scalars + (try_bitmap ? 3 : 1), + d_aggregate_tmps.begin() + num_scalars * minor_comm_rank, + d_aggregate_tmps.begin() + (num_scalars * minor_comm_rank + (try_bitmap ? 5 : 3)), [frontier_key_first, + max_tmp_buffer_size, + approx_tmp_buffer_size_per_loop, v_list_size = static_cast(thrust::distance(frontier_key_first, frontier_key_last)), vertex_partition_range_first = graph_view.local_vertex_partition_range_first()] __device__(size_t i) { + if (i == 0) { + return v_list_size; + } else if (i == 1) { + return max_tmp_buffer_size; + } else if (i == 2) { + return approx_tmp_buffer_size_per_loop; + } if constexpr (try_bitmap) { - if (i == 0) { - return v_list_size; - } else if (i == 1) { + if (i == 3) { vertex_t first{}; if (v_list_size > 0) { first = *frontier_key_first; @@ -828,8 +871,8 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, } assert(static_cast(static_cast(first)) == first); return static_cast(first); - } else { - assert(i == 2); + } else if (i == 4) { + assert(i == 4); vertex_t last{}; if (v_list_size > 0) { last = *(frontier_key_first + (v_list_size - 1)) + 1; @@ -839,14 +882,13 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, assert(static_cast(static_cast(last)) == last); return static_cast(last); } - } else { - assert(i == 0); - return v_list_size; } + assert(false); + return size_t{0}; }); if (key_segment_offsets) { raft::update_device( - d_aggregate_tmps.data() + (minor_comm_rank * num_scalars + (try_bitmap ? 3 : 1)), + d_aggregate_tmps.data() + (minor_comm_rank * num_scalars + (try_bitmap ? 5 : 3)), (*key_segment_offsets).data(), (*key_segment_offsets).size(), handle.get_stream()); @@ -866,7 +908,9 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, d_aggregate_tmps.size(), handle.get_stream()); handle.sync_stream(); - local_frontier_sizes = std::vector(minor_comm_size); + local_frontier_sizes = std::vector(minor_comm_size); + max_tmp_buffer_sizes = std::vector(minor_comm_size); + tmp_buffer_size_per_loop_approximations = std::vector(minor_comm_size); if constexpr (try_bitmap) { local_frontier_range_firsts = std::vector(minor_comm_size); local_frontier_range_lasts = std::vector(minor_comm_size); @@ -876,18 +920,20 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, (*key_segment_offset_vectors).reserve(minor_comm_size); } for (int i = 0; i < minor_comm_size; ++i) { - local_frontier_sizes[i] = h_aggregate_tmps[i * num_scalars]; + local_frontier_sizes[i] = h_aggregate_tmps[i * num_scalars]; + max_tmp_buffer_sizes[i] = h_aggregate_tmps[i * num_scalars + 1]; + tmp_buffer_size_per_loop_approximations[i] = h_aggregate_tmps[i * num_scalars + 2]; if constexpr (try_bitmap) { local_frontier_range_firsts[i] = - static_cast(h_aggregate_tmps[i * num_scalars + 1]); + static_cast(h_aggregate_tmps[i * num_scalars + 3]); local_frontier_range_lasts[i] = - static_cast(h_aggregate_tmps[i * num_scalars + 2]); + static_cast(h_aggregate_tmps[i * num_scalars + 4]); } if (key_segment_offsets) { (*key_segment_offset_vectors) - .emplace_back(h_aggregate_tmps.begin() + (i * num_scalars + (try_bitmap ? 3 : 1)), + .emplace_back(h_aggregate_tmps.begin() + (i * num_scalars + (try_bitmap ? 5 : 3)), h_aggregate_tmps.begin() + - (i * num_scalars + (try_bitmap ? 3 : 1) + (*key_segment_offsets).size())); + (i * num_scalars + (try_bitmap ? 5 : 3) + (*key_segment_offsets).size())); } } } else { @@ -971,63 +1017,17 @@ extract_transform_v_frontier_e(raft::handle_t const& handle, std::optional> stream_pool_indices{std::nullopt}; if constexpr (GraphViewType::is_multi_gpu) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - auto const minor_comm_rank = minor_comm.get_rank(); - auto partition_idx = static_cast(minor_comm_rank); + auto const minor_comm_size = minor_comm.get_size(); - if (graph_view.local_edge_partition_segment_offsets(partition_idx) && + if (graph_view.local_vertex_partition_segment_offsets() && (handle.get_stream_pool_size() >= max_segments)) { - auto& comm = handle.get_comms(); - auto const comm_size = comm.get_size(); - - auto max_tmp_buffer_size = static_cast( - static_cast(handle.get_device_properties().totalGlobalMem) * 0.2); - - size_t aggregate_major_range_size{}; - size_t aggregate_max_pushes{}; // this is approximate as we only consider local edges for - // [frontier_key_first, frontier_key_last), note that neighbor - // lists are partitioned if minor_comm_size > 1 - { - auto tmp = host_scalar_allreduce( - comm, - thrust::make_tuple( - static_cast(thrust::distance(frontier_key_first, frontier_key_last)), - local_max_pushes), - raft::comms::op_t::SUM, - handle.get_stream()); - aggregate_major_range_size = thrust::get<0>(tmp); - aggregate_max_pushes = thrust::get<1>(tmp); - } - - size_t key_size{0}; - if constexpr (std::is_arithmetic_v) { - if (v_compressible) { - key_size = sizeof(uint32_t); - } else { - key_size = sizeof(key_t); - } - } else { - key_size = cugraph::sum_thrust_tuple_element_sizes(); - } - size_t output_key_size{0}; - if constexpr (!std::is_same_v) { - if constexpr (std::is_arithmetic_v) { - output_key_size = sizeof(output_key_t); - } else { - output_key_size = cugraph::sum_thrust_tuple_element_sizes(); - } - } - size_t output_value_size{0}; - if constexpr (!std::is_same_v) { - if constexpr (std::is_arithmetic_v) { - output_value_size = sizeof(output_value_t); - } else { - output_value_size = cugraph::sum_thrust_tuple_element_sizes(); - } - } + auto max_tmp_buffer_size = + std::reduce(max_tmp_buffer_sizes.begin(), max_tmp_buffer_sizes.end()) / + static_cast(minor_comm_size); auto approx_tmp_buffer_size_per_loop = - (aggregate_major_range_size / comm_size) * key_size + - (aggregate_max_pushes / comm_size) * (output_key_size + output_value_size); - + std::reduce(tmp_buffer_size_per_loop_approximations.begin(), + tmp_buffer_size_per_loop_approximations.end()) / + static_cast(minor_comm_size); stream_pool_indices = init_stream_pool_indices(max_tmp_buffer_size, approx_tmp_buffer_size_per_loop, graph_view.number_of_local_edge_partitions(), diff --git a/cpp/src/prims/detail/per_v_transform_reduce_e.cuh b/cpp/src/prims/detail/per_v_transform_reduce_e.cuh index 3ab4ba39e3..650a307d54 100644 --- a/cpp/src/prims/detail/per_v_transform_reduce_e.cuh +++ b/cpp/src/prims/detail/per_v_transform_reduce_e.cuh @@ -1724,61 +1724,117 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, } } - // 5. collect local_key_list_sizes & local_v_list_range_firsts & local_v_list_range_lasts & + // 5. collect max_tmp_buffer_size, approx_tmp_buffer_size_per_loop, local_key_list_sizes, + // local_v_list_range_firsts, local_v_list_range_lasts, local_key_list_deg1_sizes, // key_segment_offset_vectors + std::conditional_t, std::byte /* dummy */> + max_tmp_buffer_sizes{}; + std::conditional_t, std::byte /* dummy */> + tmp_buffer_size_per_loop_approximations{}; std::conditional_t, std::byte /* dummy */> local_key_list_sizes{}; std::conditional_t, std::byte /* dummy */> local_v_list_range_firsts{}; std::conditional_t, std::byte /* dummy */> local_v_list_range_lasts{}; - std::conditional_t>, - std::optional>, - std::byte /* dummy */> + std::conditional_t>, std::byte /* dummy */> local_key_list_deg1_sizes{}; // if global degree is 1, any valid local value should be selected std::conditional_t>>, std::byte /* dummy */> key_segment_offset_vectors{}; - if constexpr (use_input_key) { - if constexpr (GraphViewType::is_multi_gpu) { - auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - auto const minor_comm_rank = minor_comm.get_rank(); - auto const minor_comm_size = minor_comm.get_size(); + if constexpr (GraphViewType::is_multi_gpu) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_rank = minor_comm.get_rank(); + auto const minor_comm_size = minor_comm.get_size(); + + auto max_tmp_buffer_size = + static_cast(static_cast(handle.get_device_properties().totalGlobalMem) * 0.2); + size_t approx_tmp_buffer_size_per_loop{0}; + if constexpr (update_major) { + size_t key_size{0}; + if constexpr (use_input_key) { + if constexpr (std::is_arithmetic_v) { + key_size = sizeof(key_t); + } else { + key_size = sum_thrust_tuple_element_sizes(); + } + } + size_t value_size{0}; + if constexpr (std::is_arithmetic_v) { + value_size = sizeof(T); + } else { + value_size = sum_thrust_tuple_element_sizes(); + } - size_t num_scalars = 1; // local_key_list_size + size_t major_range_size{}; + if constexpr (use_input_key) { + major_range_size = static_cast( + thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last)); + ; + } else { + major_range_size = graph_view.local_vertex_partition_range_size(); + } + size_t size_per_key{}; + if constexpr (filter_input_key) { + size_per_key = + key_size + + value_size / 2; // to reflect that many keys will be filtered out, note that this is a + // simple approximation, memory requirement in this case is much more + // complex as we store additional temporary variables + + } else { + size_per_key = key_size + value_size; + } + approx_tmp_buffer_size_per_loop = major_range_size * size_per_key; + } + + size_t num_scalars = 2; // max_tmp_buffer_size, approx_tmp_buffer_size_per_loop + size_t num_scalars_less_key_segment_offsets = num_scalars; + if constexpr (use_input_key) { + num_scalars += 1; // local_key_list_size if constexpr (try_bitmap) { num_scalars += 2; // local_key_list_range_first, local_key_list_range_last } if (filter_input_key && graph_view.use_dcs()) { num_scalars += 1; // local_key_list_degree_1_size } + num_scalars_less_key_segment_offsets = num_scalars; if (key_segment_offsets) { num_scalars += (*key_segment_offsets).size(); } + } - rmm::device_uvector d_aggregate_tmps(minor_comm_size * num_scalars, - handle.get_stream()); - auto hypersparse_degree_offsets = - graph_view.local_vertex_partition_hypersparse_degree_offsets(); - thrust::tabulate( - handle.get_thrust_policy(), - d_aggregate_tmps.begin() + minor_comm_rank * num_scalars, - d_aggregate_tmps.begin() + minor_comm_rank * num_scalars + (try_bitmap ? 3 : 1) + - (filter_input_key && graph_view.use_dcs() ? 1 : 0), - [sorted_unique_key_first, - v_list_size = static_cast( - thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last)), - deg1_v_first = (filter_input_key && graph_view.use_dcs()) - ? thrust::make_optional(graph_view.local_vertex_partition_range_first() + - (*local_vertex_partition_segment_offsets)[3] + - *((*hypersparse_degree_offsets).rbegin() + 1)) - : thrust::nullopt, - vertex_partition_range_first = - graph_view.local_vertex_partition_range_first()] __device__(size_t i) { + rmm::device_uvector d_aggregate_tmps(minor_comm_size * num_scalars, + handle.get_stream()); + auto hypersparse_degree_offsets = + graph_view.local_vertex_partition_hypersparse_degree_offsets(); + thrust::tabulate( + handle.get_thrust_policy(), + d_aggregate_tmps.begin() + num_scalars * minor_comm_rank, + d_aggregate_tmps.begin() + num_scalars * minor_comm_rank + + num_scalars_less_key_segment_offsets, + [max_tmp_buffer_size, + approx_tmp_buffer_size_per_loop, + sorted_unique_key_first, + sorted_unique_nzd_key_last, + deg1_v_first = (filter_input_key && graph_view.use_dcs()) + ? thrust::make_optional(graph_view.local_vertex_partition_range_first() + + (*local_vertex_partition_segment_offsets)[3] + + *((*hypersparse_degree_offsets).rbegin() + 1)) + : thrust::nullopt, + vertex_partition_range_first = + graph_view.local_vertex_partition_range_first()] __device__(size_t i) { + if (i == 0) { + return max_tmp_buffer_size; + } else if (i == 1) { + return approx_tmp_buffer_size_per_loop; + } + if constexpr (use_input_key) { + auto v_list_size = static_cast( + thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last)); + if (i == 2) { return v_list_size; } if constexpr (try_bitmap) { - if (i == 0) { - return v_list_size; - } else if (i == 1) { + if (i == 3) { vertex_t first{}; if (v_list_size > 0) { first = *sorted_unique_key_first; @@ -1787,8 +1843,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, } assert(static_cast(static_cast(first)) == first); return static_cast(first); - } - if (i == 2) { + } else if (i == 4) { vertex_t last{}; if (v_list_size > 0) { last = *(sorted_unique_key_first + (v_list_size - 1)) + 1; @@ -1797,7 +1852,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, } assert(static_cast(static_cast(last)) == last); return static_cast(last); - } else { + } else if (i == 5) { if (deg1_v_first) { auto sorted_unique_v_first = thrust::make_transform_iterator( sorted_unique_key_first, @@ -1810,15 +1865,10 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, sorted_unique_v_first, sorted_unique_v_first + v_list_size, deg1_v_first))); - } else { - assert(false); - return size_t{0}; } } } else { - if (i == 0) { - return v_list_size; - } else { + if (i == 3) { if (deg1_v_first) { auto sorted_unique_v_first = thrust::make_transform_iterator( sorted_unique_key_first, @@ -1831,36 +1881,40 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, sorted_unique_v_first, sorted_unique_v_first + v_list_size, deg1_v_first))); - } else { - assert(false); - return size_t{0}; } } } - }); + } + assert(false); + return size_t{0}; + }); + if constexpr (use_input_key) { if (key_segment_offsets) { - raft::update_device( - d_aggregate_tmps.data() + (minor_comm_rank * num_scalars + (try_bitmap ? 3 : 1) + - (filter_input_key && graph_view.use_dcs() ? 1 : 0)), - (*key_segment_offsets).data(), - (*key_segment_offsets).size(), - handle.get_stream()); + raft::update_device(d_aggregate_tmps.data() + (num_scalars * minor_comm_rank + + num_scalars_less_key_segment_offsets), + (*key_segment_offsets).data(), + (*key_segment_offsets).size(), + handle.get_stream()); } + } - if (minor_comm_size > 1) { - device_allgather(minor_comm, - d_aggregate_tmps.data() + minor_comm_rank * num_scalars, - d_aggregate_tmps.data(), - num_scalars, - handle.get_stream()); - } + if (minor_comm_size > 1) { + device_allgather(minor_comm, + d_aggregate_tmps.data() + minor_comm_rank * num_scalars, + d_aggregate_tmps.data(), + num_scalars, + handle.get_stream()); + } - std::vector h_aggregate_tmps(d_aggregate_tmps.size()); - raft::update_host(h_aggregate_tmps.data(), - d_aggregate_tmps.data(), - d_aggregate_tmps.size(), - handle.get_stream()); - handle.sync_stream(); + std::vector h_aggregate_tmps(d_aggregate_tmps.size()); + raft::update_host(h_aggregate_tmps.data(), + d_aggregate_tmps.data(), + d_aggregate_tmps.size(), + handle.get_stream()); + handle.sync_stream(); + max_tmp_buffer_sizes = std::vector(minor_comm_size); + tmp_buffer_size_per_loop_approximations = std::vector(minor_comm_size); + if constexpr (use_input_key) { local_key_list_sizes = std::vector(minor_comm_size); if constexpr (try_bitmap) { local_v_list_range_firsts = std::vector(minor_comm_size); @@ -1875,30 +1929,35 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, key_segment_offset_vectors = std::vector>{}; (*key_segment_offset_vectors).reserve(minor_comm_size); } - for (int i = 0; i < minor_comm_size; ++i) { - local_key_list_sizes[i] = h_aggregate_tmps[i * num_scalars]; + } + for (int i = 0; i < minor_comm_size; ++i) { + max_tmp_buffer_sizes[i] = h_aggregate_tmps[i * num_scalars]; + tmp_buffer_size_per_loop_approximations[i] = h_aggregate_tmps[i * num_scalars + 1]; + if constexpr (use_input_key) { + local_key_list_sizes[i] = h_aggregate_tmps[i * num_scalars + 2]; if constexpr (try_bitmap) { local_v_list_range_firsts[i] = - static_cast(h_aggregate_tmps[i * num_scalars + 1]); + static_cast(h_aggregate_tmps[i * num_scalars + 3]); local_v_list_range_lasts[i] = - static_cast(h_aggregate_tmps[i * num_scalars + 2]); + static_cast(h_aggregate_tmps[i * num_scalars + 4]); } if constexpr (filter_input_key) { if (graph_view.use_dcs()) { (*local_key_list_deg1_sizes)[i] = - static_cast(h_aggregate_tmps[i * num_scalars + (try_bitmap ? 3 : 1)]); + static_cast(h_aggregate_tmps[i * num_scalars + (try_bitmap ? 5 : 3)]); } } if (key_segment_offsets) { (*key_segment_offset_vectors) - .emplace_back(h_aggregate_tmps.begin() + i * num_scalars + (try_bitmap ? 3 : 1) + - ((filter_input_key && graph_view.use_dcs()) ? 1 : 0), - h_aggregate_tmps.begin() + i * num_scalars + (try_bitmap ? 3 : 1) + - ((filter_input_key && graph_view.use_dcs()) ? 1 : 0) + - (*key_segment_offsets).size()); + .emplace_back( + h_aggregate_tmps.begin() + i * num_scalars + num_scalars_less_key_segment_offsets, + h_aggregate_tmps.begin() + i * num_scalars + num_scalars_less_key_segment_offsets + + (*key_segment_offsets).size()); } } - } else { + } + } else { + if constexpr (use_input_key) { local_key_list_sizes = std::vector{ static_cast(thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last))}; if (key_segment_offsets) { @@ -2008,63 +2067,17 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, std::optional> stream_pool_indices{std::nullopt}; if constexpr (GraphViewType::is_multi_gpu) { if (local_vertex_partition_segment_offsets && (handle.get_stream_pool_size() >= max_segments)) { - auto max_tmp_buffer_size = static_cast( - static_cast(handle.get_device_properties().totalGlobalMem) * 0.2); - size_t tmp_buffer_size_per_loop{0}; - if constexpr (update_major) { - auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - auto const minor_comm_size = minor_comm.get_size(); - - size_t key_size{0}; - if constexpr (use_input_key) { - if constexpr (std::is_arithmetic_v) { - if (v_compressible) { - key_size = sizeof(uint32_t); - } else { - key_size = sizeof(key_t); - } - } else { - key_size = sum_thrust_tuple_element_sizes(); - } - } - size_t value_size{0}; - if constexpr (std::is_arithmetic_v) { - value_size = sizeof(T); - } else { - value_size = sum_thrust_tuple_element_sizes(); - } - - size_t aggregate_major_range_size{}; - if constexpr (use_input_key) { - aggregate_major_range_size = - std::reduce(local_key_list_sizes.begin(), local_key_list_sizes.end()); - } else { - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { - if constexpr (GraphViewType::is_storage_transposed) { - aggregate_major_range_size += graph_view.local_edge_partition_dst_range_size(i); - } else { - aggregate_major_range_size += graph_view.local_edge_partition_src_range_size(i); - } - } - } - size_t size_per_key{}; - if constexpr (filter_input_key) { - size_per_key = - key_size + - value_size / 2; // to reflect that many keys will be filtered out, note that this is a - // simple approximation, memory requirement in this case is much more - // complex as we store additional temporary variables - - } else { - size_per_key = key_size + value_size; - } - tmp_buffer_size_per_loop = - (aggregate_major_range_size / graph_view.number_of_local_edge_partitions()) * - size_per_key; - } - + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_size = minor_comm.get_size(); + auto max_tmp_buffer_size = + std::reduce(max_tmp_buffer_sizes.begin(), max_tmp_buffer_sizes.end()) / + static_cast(minor_comm_size); + auto approx_tmp_buffer_size_per_loop = + std::reduce(tmp_buffer_size_per_loop_approximations.begin(), + tmp_buffer_size_per_loop_approximations.end()) / + static_cast(minor_comm_size); stream_pool_indices = init_stream_pool_indices(max_tmp_buffer_size, - tmp_buffer_size_per_loop, + approx_tmp_buffer_size_per_loop, graph_view.number_of_local_edge_partitions(), max_segments, handle.get_stream_pool_size()); diff --git a/cpp/src/prims/fill_edge_src_dst_property.cuh b/cpp/src/prims/fill_edge_src_dst_property.cuh index 6efee71f5a..bef61080f4 100644 --- a/cpp/src/prims/fill_edge_src_dst_property.cuh +++ b/cpp/src/prims/fill_edge_src_dst_property.cuh @@ -335,60 +335,74 @@ void fill_edge_minor_property(raft::handle_t const& handle, sizeof( uint32_t); // 128B cache line alignment (unaligned ncclBroadcast operations are slower) + std::vector max_tmp_buffer_sizes{}; std::vector local_v_list_sizes{}; std::vector local_v_list_range_firsts{}; std::vector local_v_list_range_lasts{}; { auto v_list_size = static_cast( thrust::distance(sorted_unique_vertex_first, sorted_unique_vertex_last)); - rmm::device_uvector d_aggregate_tmps(major_comm_size * size_t{3}, - handle.get_stream()); - thrust::tabulate(handle.get_thrust_policy(), - d_aggregate_tmps.begin() + major_comm_rank * size_t{3}, - d_aggregate_tmps.begin() + (major_comm_rank + 1) * size_t{3}, - [sorted_unique_vertex_first, - v_list_size, - vertex_partition_range_first = - graph_view.local_vertex_partition_range_first()] __device__(size_t i) { - if (i == 0) { - return v_list_size; - } else if (i == 1) { - if (v_list_size > 0) { - return *sorted_unique_vertex_first; - } else { - return vertex_partition_range_first; - } - } else { - if (v_list_size > 0) { - return *(sorted_unique_vertex_first + (v_list_size - 1)) + 1; - } else { - return vertex_partition_range_first; - } - } - }); + rmm::device_uvector d_aggregate_tmps(major_comm_size * size_t{4}, + handle.get_stream()); + thrust::tabulate( + handle.get_thrust_policy(), + d_aggregate_tmps.begin() + major_comm_rank * size_t{4}, + d_aggregate_tmps.begin() + (major_comm_rank + 1) * size_t{4}, + [max_tmp_buffer_size = static_cast( + static_cast(handle.get_device_properties().totalGlobalMem) * 0.05), + sorted_unique_vertex_first, + v_list_size, + vertex_partition_range_first = + graph_view.local_vertex_partition_range_first()] __device__(size_t i) { + if (i == 0) { + return max_tmp_buffer_size; + } else if (i == 1) { + return static_cast(v_list_size); + } else if (i == 2) { + vertex_t first{}; + if (v_list_size > 0) { + first = *sorted_unique_vertex_first; + } else { + first = vertex_partition_range_first; + } + assert(static_cast(static_cast(first)) == first); + return static_cast(first); + } else { + vertex_t last{}; + if (v_list_size > 0) { + last = *(sorted_unique_vertex_first + (v_list_size - 1)) + 1; + } else { + last = vertex_partition_range_first; + } + assert(static_cast(static_cast(last)) == last); + return static_cast(last); + } + }); - if (major_comm_size > 1) { // allgather v_list_size, v_list_range_first (inclusive), - // v_list_range_last (exclusive) + if (major_comm_size > 1) { // allgather max_tmp_buffer_size, v_list_size, v_list_range_first + // (inclusive), v_list_range_last (exclusive) device_allgather(major_comm, - d_aggregate_tmps.data() + major_comm_rank * size_t{3}, + d_aggregate_tmps.data() + major_comm_rank * size_t{4}, d_aggregate_tmps.data(), - size_t{3}, + size_t{4}, handle.get_stream()); } - std::vector h_aggregate_tmps(d_aggregate_tmps.size()); + std::vector h_aggregate_tmps(d_aggregate_tmps.size()); raft::update_host(h_aggregate_tmps.data(), d_aggregate_tmps.data(), d_aggregate_tmps.size(), handle.get_stream()); handle.sync_stream(); + max_tmp_buffer_sizes = std::vector(major_comm_size); local_v_list_sizes = std::vector(major_comm_size); local_v_list_range_firsts = std::vector(major_comm_size); local_v_list_range_lasts = std::vector(major_comm_size); for (int i = 0; i < major_comm_size; ++i) { - local_v_list_sizes[i] = h_aggregate_tmps[i * size_t{3}]; - local_v_list_range_firsts[i] = h_aggregate_tmps[i * size_t{3} + 1]; - local_v_list_range_lasts[i] = h_aggregate_tmps[i * size_t{3} + 2]; + max_tmp_buffer_sizes[i] = h_aggregate_tmps[i * size_t{4}]; + local_v_list_sizes[i] = static_cast(h_aggregate_tmps[i * size_t{4} + 1]); + local_v_list_range_firsts[i] = static_cast(h_aggregate_tmps[i * size_t{4} + 2]); + local_v_list_range_lasts[i] = static_cast(h_aggregate_tmps[i * size_t{4} + 3]); } } @@ -546,8 +560,8 @@ void fill_edge_minor_property(raft::handle_t const& handle, } tmp_buffer_size_per_loop /= major_comm_size; stream_pool_indices = init_stream_pool_indices( - static_cast(static_cast(handle.get_device_properties().totalGlobalMem) * - 0.05), + std::reduce(max_tmp_buffer_sizes.begin(), max_tmp_buffer_sizes.end()) / + static_cast(major_comm_size), tmp_buffer_size_per_loop, major_comm_size, 1,