diff --git a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh index 4c5c43c7d1e..62cdbec3c13 100644 --- a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh +++ b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -53,10 +54,27 @@ namespace cugraph { namespace detail { +int32_t constexpr per_v_random_select_transform_outgoing_e_block_size = 256; + +size_t constexpr compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold = + packed_bools_per_word() * + size_t{4} /* tuning parameter */; // minimum local degree to compute inclusive sums of valid + // local neighbors per word to accelerate finding n'th local + // neighbor vertex +size_t constexpr compute_valid_local_nbr_count_inclusive_sum_mid_local_degree_threshold = + packed_bools_per_word() * static_cast(raft::warp_size()) * + size_t{ + 4} /* tuning parameter */; // minimum local degree to use a CUDA warp to compute inclusive sums +size_t constexpr compute_valid_local_nbr_count_inclusive_sum_high_local_degree_threshold = + packed_bools_per_word() * per_v_random_select_transform_outgoing_e_block_size * + size_t{4} /* tuning parameter */; // minimum local degree to use a CUDA block to compute + // inclusive sums + template struct compute_local_degree_displacements_and_global_degree_t { raft::device_span gathered_local_degrees{}; - raft::device_span segmented_local_degree_displacements{}; + raft::device_span + partitioned_local_degree_displacements{}; // one partition per gpu in the same minor_comm raft::device_span global_degrees{}; int minor_comm_size{}; @@ -75,7 +93,7 @@ struct compute_local_degree_displacements_and_global_degree_t { thrust::seq, displacements, displacements + loop_count, - segmented_local_degree_displacements.begin() + i * minor_comm_size + round * buffer_size); + partitioned_local_degree_displacements.begin() + i * minor_comm_size + round * buffer_size); } global_degrees[i] = sum; } @@ -86,7 +104,8 @@ struct compute_local_degree_displacements_and_global_degree_t { // invalid template struct convert_pair_to_quadruplet_t { - raft::device_span segmented_local_degree_displacements{}; + raft::device_span + partitioned_local_degree_displacements{}; // one partition per gpu in the same minor_comm raft::device_span tx_counts{}; size_t stride{}; int minor_comm_size{}; @@ -95,14 +114,14 @@ struct convert_pair_to_quadruplet_t { __device__ thrust::tuple operator()( thrust::tuple index_pair) const { - auto nbr_idx = thrust::get<0>(index_pair); - auto key_idx = thrust::get<1>(index_pair); - auto local_nbr_idx = nbr_idx; + auto nbr_idx = thrust::get<0>(index_pair); + auto key_idx = thrust::get<1>(index_pair); + edge_t local_nbr_idx{0}; int minor_comm_rank{-1}; size_t intra_partition_offset{}; if (nbr_idx != invalid_idx) { auto displacement_first = - segmented_local_degree_displacements.begin() + key_idx * minor_comm_size; + partitioned_local_degree_displacements.begin() + key_idx * minor_comm_size; minor_comm_rank = static_cast(thrust::distance( displacement_first, @@ -156,6 +175,7 @@ template struct transform_local_nbr_indices_t { @@ -171,6 +191,9 @@ struct transform_local_nbr_indices_t { EdgePartitionSrcValueInputWrapper edge_partition_src_value_input; EdgePartitionDstValueInputWrapper edge_partition_dst_value_input; EdgePartitionEdgeValueInputWrapper edge_partition_e_value_input; + EdgePartitionEdgeMaskWrapper edge_partition_e_mask; + thrust::optional, raft::device_span>> + key_valid_local_nbr_count_inclusive_sums{}; EdgeOp e_op{}; edge_t invalid_idx{}; thrust::optional invalid_value{thrust::nullopt}; @@ -207,7 +230,33 @@ struct transform_local_nbr_indices_t { } auto local_nbr_idx = *(local_nbr_idx_first + i); if (local_nbr_idx != invalid_idx) { - auto minor = indices[local_nbr_idx]; + vertex_t minor{}; + if (edge_partition_e_mask) { + if (local_degree < compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold) { + local_nbr_idx = find_nth_set_bits( + (*edge_partition_e_mask).value_first(), edge_offset, local_degree, local_nbr_idx); + } else { + auto inclusive_sum_first = + thrust::get<1>(*key_valid_local_nbr_count_inclusive_sums).begin(); + auto start_offset = thrust::get<0>(*key_valid_local_nbr_count_inclusive_sums)[key_idx]; + auto end_offset = thrust::get<0>(*key_valid_local_nbr_count_inclusive_sums)[key_idx + 1]; + auto word_idx = static_cast( + thrust::distance(inclusive_sum_first + start_offset, + thrust::upper_bound(thrust::seq, + inclusive_sum_first + start_offset, + inclusive_sum_first + end_offset, + local_nbr_idx))); + local_nbr_idx = + word_idx * packed_bools_per_word() + + find_nth_set_bits( + (*edge_partition_e_mask).value_first(), + edge_offset + word_idx * packed_bools_per_word(), + local_degree - word_idx * packed_bools_per_word(), + local_nbr_idx - ((word_idx > 0) ? *(inclusive_sum_first + start_offset + word_idx - 1) + : edge_t{0})); + } + } + minor = indices[local_nbr_idx]; auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); std::conditional_t @@ -279,6 +328,247 @@ struct return_value_compute_offset_t { } }; +template +__global__ void compute_valid_local_nbr_inclusive_sums_mid_local_degree( + edge_partition_device_view_t edge_partition, + edge_partition_edge_property_device_view_t edge_partition_e_mask, + raft::device_span edge_partition_frontier_majors, + raft::device_span inclusive_sum_offsets, + raft::device_span frontier_indices, + raft::device_span inclusive_sums) +{ + static_assert(per_v_random_select_transform_outgoing_e_block_size % raft::warp_size() == 0); + + auto const tid = threadIdx.x + blockIdx.x * blockDim.x; + auto const lane_id = tid % raft::warp_size(); + + auto idx = static_cast(tid / raft::warp_size()); + + using WarpScan = cub::WarpScan; + __shared__ typename WarpScan::TempStorage temp_storage; + + while (idx < frontier_indices.size()) { + auto frontier_idx = frontier_indices[idx]; + auto major = edge_partition_frontier_majors[frontier_idx]; + vertex_t major_idx{}; + if constexpr (multi_gpu) { + major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); + } else { + major_idx = edge_partition.major_offset_from_major_nocheck(major); + } + auto edge_offset = edge_partition.local_offset(major_idx); + auto local_degree = edge_partition.local_degree(major_idx); + + auto num_inclusive_sums = + inclusive_sum_offsets[frontier_idx + 1] - inclusive_sum_offsets[frontier_idx]; + auto rounded_up_num_inclusive_sums = + ((num_inclusive_sums + raft::warp_size() - 1) / raft::warp_size()) * raft::warp_size(); + edge_t sum{0}; + for (size_t j = lane_id; j <= rounded_up_num_inclusive_sums; j += raft::warp_size()) { + auto inc = + (j < num_inclusive_sums) + ? static_cast(count_set_bits( + edge_partition_e_mask.value_first(), + edge_offset + packed_bools_per_word() * j, + cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j))) + : edge_t{0}; + WarpScan(temp_storage).InclusiveSum(inc, inc); + inclusive_sums[j] = sum + inc; + sum += __shfl_sync(raft::warp_full_mask(), inc, raft::warp_size() - 1); + } + + idx += gridDim.x * (blockDim.x / raft::warp_size()); + } +} + +template +__global__ void compute_valid_local_nbr_inclusive_sums_high_local_degree( + edge_partition_device_view_t edge_partition, + edge_partition_edge_property_device_view_t edge_partition_e_mask, + raft::device_span edge_partition_frontier_majors, + raft::device_span inclusive_sum_offsets, + raft::device_span frontier_indices, + raft::device_span inclusive_sums) +{ + static_assert(per_v_random_select_transform_outgoing_e_block_size % raft::warp_size() == 0); + + auto idx = static_cast(blockIdx.x); + + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + __shared__ edge_t sum; + + while (idx < frontier_indices.size()) { + auto frontier_idx = frontier_indices[idx]; + auto major = edge_partition_frontier_majors[frontier_idx]; + vertex_t major_idx{}; + if constexpr (multi_gpu) { + major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); + } else { + major_idx = edge_partition.major_offset_from_major_nocheck(major); + } + auto edge_offset = edge_partition.local_offset(major_idx); + auto local_degree = edge_partition.local_degree(major_idx); + + auto num_inclusive_sums = + inclusive_sum_offsets[frontier_idx + 1] - inclusive_sum_offsets[frontier_idx]; + auto rounded_up_num_inclusive_sums = + ((num_inclusive_sums + per_v_random_select_transform_outgoing_e_block_size - 1) / + per_v_random_select_transform_outgoing_e_block_size) * + per_v_random_select_transform_outgoing_e_block_size; + if (threadIdx.x == per_v_random_select_transform_outgoing_e_block_size - 1) { sum = 0; } + for (size_t j = threadIdx.x; j <= rounded_up_num_inclusive_sums; j += blockDim.x) { + auto inc = + (j < num_inclusive_sums) + ? static_cast(count_set_bits( + edge_partition_e_mask.value_first(), + edge_offset + packed_bools_per_word() * j, + cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j))) + : edge_t{0}; + BlockScan(temp_storage).InclusiveSum(inc, inc); + inclusive_sums[j] = sum + inc; + __syncthreads(); + if (threadIdx.x == per_v_random_select_transform_outgoing_e_block_size - 1) { sum += inc; } + } + + idx += gridDim.x; + } +} + +template +std::tuple, rmm::device_uvector> +compute_valid_local_nbr_count_inclusive_sums( + raft::handle_t const& handle, + edge_partition_device_view_t const& edge_partition, + edge_partition_edge_property_device_view_t const& + edge_partition_e_mask, + raft::device_span edge_partition_frontier_majors) +{ + auto edge_partition_local_degrees = + edge_partition.compute_local_degrees(edge_partition_frontier_majors.begin(), + edge_partition_frontier_majors.end(), + handle.get_stream()); + auto offsets = + rmm::device_uvector(edge_partition_frontier_majors.size() + 1, handle.get_stream()); + offsets.set_element_to_zero_async(0, handle.get_stream()); + auto size_first = thrust::make_transform_iterator( + edge_partition_local_degrees.begin(), + cuda::proclaim_return_type([] __device__(edge_t local_degree) { + return static_cast((local_degree + packed_bools_per_word() - 1) / + packed_bools_per_word()); + })); + thrust::inclusive_scan(handle.get_thrust_policy(), + size_first, + size_first + edge_partition_local_degrees.size(), + offsets.begin() + 1); + + rmm::device_uvector frontier_indices(edge_partition_frontier_majors.size(), + handle.get_stream()); + frontier_indices.resize( + thrust::distance( + frontier_indices.begin(), + thrust::copy_if( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(edge_partition_frontier_majors.size()), + frontier_indices.begin(), + [threshold = compute_valid_local_nbr_count_inclusive_sum_local_degree_threshold, + local_degrees = raft::device_span( + edge_partition_local_degrees.data(), + edge_partition_local_degrees.size())] __device__(size_t i) { + return local_degrees[i] >= threshold; + })), + handle.get_stream()); + + auto low_last = thrust::partition( + handle.get_thrust_policy(), + frontier_indices.begin(), + frontier_indices.end(), + [threshold = compute_valid_local_nbr_count_inclusive_sum_mid_local_degree_threshold, + local_degrees = + raft::device_span(edge_partition_local_degrees.data(), + edge_partition_local_degrees.size())] __device__(size_t i) { + return local_degrees[i] < threshold; + }); + auto mid_last = thrust::partition( + handle.get_thrust_policy(), + low_last, + frontier_indices.end(), + [threshold = compute_valid_local_nbr_count_inclusive_sum_high_local_degree_threshold, + local_degrees = + raft::device_span(edge_partition_local_degrees.data(), + edge_partition_local_degrees.size())] __device__(size_t i) { + return local_degrees[i] < threshold; + }); + + rmm::device_uvector inclusive_sums(offsets.back_element(handle.get_stream()), + handle.get_stream()); + + thrust::for_each( + handle.get_thrust_policy(), + frontier_indices.begin(), + low_last, + [edge_partition, + edge_partition_e_mask, + edge_partition_frontier_majors, + offsets = raft::device_span(offsets.data(), offsets.size()), + inclusive_sums = raft::device_span(inclusive_sums.data(), + inclusive_sums.size())] __device__(size_t i) { + auto major = edge_partition_frontier_majors[i]; + vertex_t major_idx{}; + if constexpr (multi_gpu) { + major_idx = *(edge_partition.major_idx_from_major_nocheck(major)); + } else { + major_idx = edge_partition.major_offset_from_major_nocheck(major); + } + auto edge_offset = edge_partition.local_offset(major_idx); + auto local_degree = edge_partition.local_degree(major_idx); + edge_t sum{0}; + for (size_t j = offsets[i]; j <= offsets[i + 1]; ++j) { + sum += count_set_bits( + edge_partition_e_mask.value_first(), + edge_offset + packed_bools_per_word() * j, + cuda::std::min(packed_bools_per_word(), local_degree - packed_bools_per_word() * j)); + inclusive_sums[j] = sum; + } + }); + + if (thrust::distance(low_last, mid_last) > 0) { + raft::grid_1d_warp_t update_grid(thrust::distance(low_last, mid_last), + per_v_random_select_transform_outgoing_e_block_size, + handle.get_device_properties().maxGridSize[0]); + compute_valid_local_nbr_inclusive_sums_mid_local_degree<<>>( + edge_partition, + edge_partition_e_mask, + edge_partition_frontier_majors, + raft::device_span(offsets.data(), offsets.size()), + raft::device_span(low_last, thrust::distance(low_last, mid_last)), + raft::device_span(inclusive_sums.data(), inclusive_sums.size())); + } + + if (thrust::distance(mid_last, frontier_indices.end()) > 0) { + raft::grid_1d_block_t update_grid(thrust::distance(mid_last, frontier_indices.end()), + per_v_random_select_transform_outgoing_e_block_size, + handle.get_device_properties().maxGridSize[0]); + compute_valid_local_nbr_inclusive_sums_high_local_degree<<>>( + edge_partition, + edge_partition_e_mask, + edge_partition_frontier_majors, + raft::device_span(offsets.data(), offsets.size()), + raft::device_span(mid_last, thrust::distance(mid_last, frontier_indices.end())), + raft::device_span(inclusive_sums.data(), inclusive_sums.size())); + } + + return std::make_tuple(std::move(offsets), std::move(inclusive_sums)); +} + template rmm::device_uvector get_sampling_index_without_replacement( raft::handle_t const& handle, @@ -369,7 +659,7 @@ rmm::device_uvector get_sampling_index_without_replacement( static_cast(std::numeric_limits::max())); rmm::device_uvector tmp_sample_indices( seeds_to_sort_per_iteration * high_partition_over_sampling_K, - handle.get_stream()); // sample indices within a segment (one segment per seed) + handle.get_stream()); // sample indices within a segment (one partition per seed) rmm::device_uvector segment_sorted_tmp_sample_nbr_indices( seeds_to_sort_per_iteration * high_partition_over_sampling_K, handle.get_stream()); @@ -790,16 +1080,36 @@ per_v_random_select_transform_e(raft::handle_t const& handle, // 2. compute degrees + auto edge_mask_view = graph_view.edge_mask_view(); + auto aggregate_local_frontier_local_degrees = (minor_comm_size > 1) ? std::make_optional>( local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()) : std::nullopt; rmm::device_uvector frontier_degrees(frontier.size(), handle.get_stream()); + + std::optional, rmm::device_uvector>>> + local_frontier_valid_local_nbr_count_inclusive_sums{}; // to avoid searching the entire + // neighbor list K times for high degree + // vertices with edge masking + if (edge_mask_view) { + local_frontier_valid_local_nbr_count_inclusive_sums = + std::vector, rmm::device_uvector>>{}; + (*local_frontier_valid_local_nbr_count_inclusive_sums) + .reserve(graph_view.number_of_local_edge_partitions()); + } + for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { auto edge_partition = edge_partition_device_view_t( graph_view.local_edge_partition_view(i)); + auto edge_partition_e_mask = + edge_mask_view + ? thrust::make_optional< + detail::edge_partition_edge_property_device_view_t>( + *edge_mask_view, i) + : thrust::nullopt; vertex_t const* edge_partition_frontier_major_first{nullptr}; @@ -813,10 +1123,16 @@ per_v_random_select_transform_e(raft::handle_t const& handle, edge_partition_frontier_major_first = thrust::get<0>(edge_partition_frontier_key_first); } - auto edge_partition_frontier_local_degrees = edge_partition.compute_local_degrees( - edge_partition_frontier_major_first, - edge_partition_frontier_major_first + local_frontier_sizes[i], - handle.get_stream()); + auto edge_partition_frontier_local_degrees = + edge_partition_e_mask ? edge_partition.compute_local_degrees_with_mask( + (*edge_partition_e_mask).value_first(), + edge_partition_frontier_major_first, + edge_partition_frontier_major_first + local_frontier_sizes[i], + handle.get_stream()) + : edge_partition.compute_local_degrees( + edge_partition_frontier_major_first, + edge_partition_frontier_major_first + local_frontier_sizes[i], + handle.get_stream()); if (minor_comm_size > 1) { // FIXME: this copy is unnecessary if edge_partition.compute_local_degrees() takes a pointer @@ -829,12 +1145,22 @@ per_v_random_select_transform_e(raft::handle_t const& handle, } else { frontier_degrees = std::move(edge_partition_frontier_local_degrees); } + + if (edge_partition_e_mask) { + (*local_frontier_valid_local_nbr_count_inclusive_sums) + .push_back(compute_valid_local_nbr_count_inclusive_sums( + handle, + edge_partition, + *edge_partition_e_mask, + raft::device_span(edge_partition_frontier_major_first, + local_frontier_sizes[i]))); + } } - auto frontier_segmented_local_degree_displacements = + auto frontier_partitioned_local_degree_displacements = (minor_comm_size > 1) ? std::make_optional>(size_t{0}, handle.get_stream()) - : std::nullopt; + : std::nullopt; // one partition per gpu in the same minor_comm if (minor_comm_size > 1) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); @@ -845,7 +1171,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, local_frontier_sizes, handle.get_stream()); aggregate_local_frontier_local_degrees = std::nullopt; - frontier_segmented_local_degree_displacements = + frontier_partitioned_local_degree_displacements = rmm::device_uvector(frontier_degrees.size() * minor_comm_size, handle.get_stream()); thrust::for_each( handle.get_thrust_policy(), @@ -854,8 +1180,8 @@ per_v_random_select_transform_e(raft::handle_t const& handle, compute_local_degree_displacements_and_global_degree_t{ raft::device_span(frontier_gathered_local_degrees.data(), frontier_gathered_local_degrees.size()), - raft::device_span((*frontier_segmented_local_degree_displacements).data(), - (*frontier_segmented_local_degree_displacements).size()), + raft::device_span((*frontier_partitioned_local_degree_displacements).data(), + (*frontier_partitioned_local_degree_displacements).size()), raft::device_span(frontier_degrees.data(), frontier_degrees.size()), minor_comm_size}); } @@ -914,8 +1240,8 @@ per_v_random_select_transform_e(raft::handle_t const& handle, sample_local_nbr_indices.begin(), (*sample_key_indices).begin())), convert_pair_to_quadruplet_t{ - raft::device_span((*frontier_segmented_local_degree_displacements).data(), - (*frontier_segmented_local_degree_displacements).size()), + raft::device_span((*frontier_partitioned_local_degree_displacements).data(), + (*frontier_partitioned_local_degree_displacements).size()), raft::device_span(d_tx_counts.data(), d_tx_counts.size()), frontier.size(), minor_comm_size, @@ -980,6 +1306,12 @@ per_v_random_select_transform_e(raft::handle_t const& handle, auto edge_partition = edge_partition_device_view_t( graph_view.local_edge_partition_view(i)); + auto edge_partition_e_mask = + edge_mask_view + ? thrust::make_optional< + detail::edge_partition_edge_property_device_view_t>( + *edge_mask_view, i) + : thrust::nullopt; auto edge_partition_frontier_key_first = ((minor_comm_size > 1) ? get_dataframe_buffer_begin(*aggregate_local_frontier_keys) @@ -1018,6 +1350,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, edge_partition_src_input_device_view_t, edge_partition_dst_input_device_view_t, edge_partition_e_input_device_view_t, + decltype(edge_partition_e_mask), EdgeOp, T>{ edge_partition, @@ -1028,6 +1361,16 @@ per_v_random_select_transform_e(raft::handle_t const& handle, edge_partition_src_value_input, edge_partition_dst_value_input, edge_partition_e_value_input, + edge_partition_e_mask, + local_frontier_valid_local_nbr_count_inclusive_sums + ? thrust::make_optional(thrust::make_tuple( + raft::device_span( + std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(), + std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()), + raft::device_span( + std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(), + std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()))) + : thrust::nullopt, e_op, cugraph::ops::graph::INVALID_ID, to_thrust_optional(invalid_value), @@ -1044,19 +1387,31 @@ per_v_random_select_transform_e(raft::handle_t const& handle, edge_partition_src_input_device_view_t, edge_partition_dst_input_device_view_t, edge_partition_e_input_device_view_t, + decltype(edge_partition_e_mask), EdgeOp, - T>{edge_partition, - thrust::nullopt, - edge_partition_frontier_key_first, - edge_partition_sample_local_nbr_index_first, - edge_partition_sample_e_op_result_first, - edge_partition_src_value_input, - edge_partition_dst_value_input, - edge_partition_e_value_input, - e_op, - cugraph::ops::graph::INVALID_ID, - to_thrust_optional(invalid_value), - K}); + T>{ + edge_partition, + thrust::nullopt, + edge_partition_frontier_key_first, + edge_partition_sample_local_nbr_index_first, + edge_partition_sample_e_op_result_first, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + edge_partition_e_mask, + local_frontier_valid_local_nbr_count_inclusive_sums + ? thrust::make_optional(thrust::make_tuple( + raft::device_span( + std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(), + std::get<0>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()), + raft::device_span( + std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).data(), + std::get<1>((*local_frontier_valid_local_nbr_count_inclusive_sums)[i]).size()))) + : thrust::nullopt, + e_op, + cugraph::ops::graph::INVALID_ID, + to_thrust_optional(invalid_value), + K}); } } aggregate_local_frontier_keys = std::nullopt;