diff --git a/cpp/src/sampling/sampling_post_processing_impl.cuh b/cpp/src/sampling/sampling_post_processing_impl.cuh index ca37205f175..8f5e6e20da0 100644 --- a/cpp/src/sampling/sampling_post_processing_impl.cuh +++ b/cpp/src/sampling/sampling_post_processing_impl.cuh @@ -1276,17 +1276,22 @@ renumber_and_compress_sampled_edgelist( : thrust::nullopt, edgelist_majors = raft::device_span(edgelist_majors.data(), edgelist_majors.size()), - num_hops] __device__(size_t i) { + num_hops, + compress_per_hop] __device__(size_t i) { size_t start_offset{0}; - auto end_offset = edgelist_majors.size(); + auto end_offset = edgelist_majors.size(); + auto label_start_offset = start_offset; + auto label_end_offset = end_offset; if (edgelist_label_offsets) { - auto l_idx = static_cast(i / num_hops); - start_offset = (*edgelist_label_offsets)[l_idx]; - end_offset = (*edgelist_label_offsets)[l_idx + 1]; + auto l_idx = static_cast(i / num_hops); + start_offset = (*edgelist_label_offsets)[l_idx]; + end_offset = (*edgelist_label_offsets)[l_idx + 1]; + label_start_offset = start_offset; + label_end_offset = end_offset; } - if (edgelist_hops) { + if (num_hops > 1) { auto h = static_cast(i % num_hops); auto lower_it = thrust::lower_bound(thrust::seq, (*edgelist_hops).begin() + start_offset, @@ -1299,7 +1304,17 @@ renumber_and_compress_sampled_edgelist( start_offset = static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); } - return (start_offset < end_offset) ? (edgelist_majors[end_offset - 1] + 1) : vertex_t{0}; + if (compress_per_hop) { + return (start_offset < end_offset) ? (edgelist_majors[end_offset - 1] + 1) : vertex_t{0}; + } else { + if (end_offset != label_end_offset) { + return edgelist_majors[end_offset]; + } else if (label_start_offset < label_end_offset) { + return edgelist_majors[end_offset - 1] + 1; + } else { + return vertex_t{0}; + } + } }); std::optional> minor_vertex_counts{std::nullopt};