From 61a0926fd61693165228a261ddf096bb507e9721 Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Wed, 20 Nov 2024 13:02:10 -0800 Subject: [PATCH] fix bug in heterogeneous renumbering --- cpp/src/c_api/neighbor_sampling.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/cpp/src/c_api/neighbor_sampling.cpp b/cpp/src/c_api/neighbor_sampling.cpp index be3a44d813..e89630dfb8 100644 --- a/cpp/src/c_api/neighbor_sampling.cpp +++ b/cpp/src/c_api/neighbor_sampling.cpp @@ -1220,13 +1220,18 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { } else { // heterogeneous renumbering + // FIXME: If no 'vertex_type_offsets' is provided, all vertices are assumed to have + // a vertex type of value 1. Update the API once 'vertex_type_offsets' is supported rmm::device_uvector vertex_type_offsets( - graph_view.local_vertex_partition_range_size(), handle_.get_stream()); + 2, handle_.get_stream()); + + cugraph::detail::stride_fill( + handle_.get_stream(), + vertex_type_offsets.begin(), + vertex_type_offsets.size(), + vertex_t{0}, + vertex_t{graph_view.local_vertex_partition_range_size()} - cugraph::detail::sequence_fill(handle_.get_stream(), - vertex_type_offsets.begin(), - vertex_type_offsets.size(), - vertex_t{0} // FIXME: Update array ); rmm::device_uvector output_majors(0, handle_.get_stream()); @@ -1240,7 +1245,7 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { edge_id, label_type_hop_offsets, // Contains information about the type and hop offsets output_renumber_map, - (*renumber_map_offsets), + renumber_map_offsets, renumbered_and_sorted_edge_id_renumber_map, renumbered_and_sorted_edge_id_renumber_map_label_type_offsets) = cugraph::heterogeneous_renumber_and_sort_sampled_edgelist( @@ -1267,7 +1272,7 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor { edge_label ? (*offsets).size() - 1 : size_t{1}, hop ? fan_out_->size_ : size_t{1}, - size_t{1}, + vertex_type_offsets.size() - 1, // num_vertex_type is by default 1 if not provided num_edge_types_, src_is_major, do_expensive_check_);