diff --git a/cpp/src/sampling/neighbor_sampling_impl.hpp b/cpp/src/sampling/neighbor_sampling_impl.hpp index feaa8824d73..bde467b9407 100644 --- a/cpp/src/sampling/neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/neighbor_sampling_impl.hpp @@ -200,8 +200,14 @@ neighbor_sample_impl( std::vector level_sizes{}; int32_t hop{0}; int32_t edge_type_id_max{1}; // A value of 1 translate to homogeneous neighbor sample + int32_t num_edge_type_per_hop{0}; auto cur_graph_view = modified_graph_view ? *modified_graph_view : graph_view; + + if (heterogeneous_fan_out) { + num_edge_type_per_hop = std::get<0>(*heterogeneous_fan_out).back() - 1; + } + while(true) { int32_t k_level{0}; if (fan_out) { @@ -210,8 +216,11 @@ neighbor_sample_impl( break; } } else if (heterogeneous_fan_out) { - // initially edge type - edge_type_id_max = std::get<0>(*heterogeneous_fan_out).back() - 1; + if (num_edge_type_per_hop == 0) { + break; + } + edge_type_id_max = std::get<0>(*heterogeneous_fan_out).back() - 1; + } for (int i = 0; i < edge_type_id_max; i++) { @@ -223,7 +232,10 @@ neighbor_sample_impl( auto k_level_size = (std::get<1>(*heterogeneous_fan_out)[i + 1] - std::get<1>(*heterogeneous_fan_out)[i]); if (k_level_size > hop) { k_level = i + hop; - } // otherwise, k_level = 0 + } else { // otherwise, k_level = 0 + --num_edge_type_per_hop ; + + } } rmm::device_uvector srcs(0, handle.get_stream()); rmm::device_uvector dsts(0, handle.get_stream());