From fe35c805c276e4ec8ef7182f4fa7fd9dc465f718 Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Mon, 9 Sep 2024 16:42:33 -0700 Subject: [PATCH] add exit condition --- cpp/src/sampling/neighbor_sampling_impl.hpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/cpp/src/sampling/neighbor_sampling_impl.hpp b/cpp/src/sampling/neighbor_sampling_impl.hpp index feaa8824d73..bde467b9407 100644 --- a/cpp/src/sampling/neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/neighbor_sampling_impl.hpp @@ -200,8 +200,14 @@ neighbor_sample_impl( std::vector level_sizes{}; int32_t hop{0}; int32_t edge_type_id_max{1}; // A value of 1 translate to homogeneous neighbor sample + int32_t num_edge_type_per_hop{0}; auto cur_graph_view = modified_graph_view ? *modified_graph_view : graph_view; + + if (heterogeneous_fan_out) { + num_edge_type_per_hop = std::get<0>(*heterogeneous_fan_out).back() - 1; + } + while(true) { int32_t k_level{0}; if (fan_out) { @@ -210,8 +216,11 @@ neighbor_sample_impl( break; } } else if (heterogeneous_fan_out) { - // initially edge type - edge_type_id_max = std::get<0>(*heterogeneous_fan_out).back() - 1; + if (num_edge_type_per_hop == 0) { + break; + } + edge_type_id_max = std::get<0>(*heterogeneous_fan_out).back() - 1; + } for (int i = 0; i < edge_type_id_max; i++) { @@ -223,7 +232,10 @@ neighbor_sample_impl( auto k_level_size = (std::get<1>(*heterogeneous_fan_out)[i + 1] - std::get<1>(*heterogeneous_fan_out)[i]); if (k_level_size > hop) { k_level = i + hop; - } // otherwise, k_level = 0 + } else { // otherwise, k_level = 0 + --num_edge_type_per_hop ; + + } } rmm::device_uvector srcs(0, handle.get_stream()); rmm::device_uvector dsts(0, handle.get_stream());