Skip to content

Commit

Permalink
adjust hop offsets when there is a jump in major vertex IDs between hops
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Sep 11, 2023
1 parent 6eaf67e commit db35940
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions cpp/src/sampling/sampling_post_processing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1276,17 +1276,22 @@ renumber_and_compress_sampled_edgelist(
: thrust::nullopt,
edgelist_majors =
raft::device_span<vertex_t const>(edgelist_majors.data(), edgelist_majors.size()),
num_hops] __device__(size_t i) {
num_hops,
compress_per_hop] __device__(size_t i) {
size_t start_offset{0};
auto end_offset = edgelist_majors.size();
auto end_offset = edgelist_majors.size();
auto label_start_offset = start_offset;
auto label_end_offset = end_offset;

if (edgelist_label_offsets) {
auto l_idx = static_cast<label_index_t>(i / num_hops);
start_offset = (*edgelist_label_offsets)[l_idx];
end_offset = (*edgelist_label_offsets)[l_idx + 1];
auto l_idx = static_cast<label_index_t>(i / num_hops);
start_offset = (*edgelist_label_offsets)[l_idx];
end_offset = (*edgelist_label_offsets)[l_idx + 1];
label_start_offset = start_offset;
label_end_offset = end_offset;
}

if (edgelist_hops) {
if (num_hops > 1) {
auto h = static_cast<int32_t>(i % num_hops);
auto lower_it = thrust::lower_bound(thrust::seq,
(*edgelist_hops).begin() + start_offset,
Expand All @@ -1299,7 +1304,17 @@ renumber_and_compress_sampled_edgelist(
start_offset = static_cast<size_t>(thrust::distance((*edgelist_hops).begin(), lower_it));
end_offset = static_cast<size_t>(thrust::distance((*edgelist_hops).begin(), upper_it));
}
return (start_offset < end_offset) ? (edgelist_majors[end_offset - 1] + 1) : vertex_t{0};
if (compress_per_hop) {
return (start_offset < end_offset) ? (edgelist_majors[end_offset - 1] + 1) : vertex_t{0};
} else {
if (end_offset != label_end_offset) {
return edgelist_majors[end_offset];
} else if (label_start_offset < label_end_offset) {
return edgelist_majors[end_offset - 1] + 1;
} else {
return vertex_t{0};
}
}
});

std::optional<rmm::device_uvector<vertex_t>> minor_vertex_counts{std::nullopt};
Expand Down

0 comments on commit db35940

Please sign in to comment.