diff --git a/cpp/src/sampling/neighbor_sampling_impl.hpp b/cpp/src/sampling/neighbor_sampling_impl.hpp index 11a792fe03..d666d2dd28 100644 --- a/cpp/src/sampling/neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/neighbor_sampling_impl.hpp @@ -490,13 +490,10 @@ neighbor_sample_impl(raft::handle_t const& handle, // If there are missing labels, still inlude it in the result_labels result_labels = std::move(*cp_starting_vertex_labels); auto unique_labels_end = - thrust::unique(handle.get_thrust_policy(), - result_labels->begin(), - result_labels->end()); - - auto num_unique_labels = thrust::distance( - result_labels->begin(), unique_labels_end); - + thrust::unique(handle.get_thrust_policy(), result_labels->begin(), result_labels->end()); + + auto num_unique_labels = thrust::distance(result_labels->begin(), unique_labels_end); + result_labels->resize(num_unique_labels, handle.get_stream()); result_offsets->resize(num_unique_labels + 1, handle.get_stream()); @@ -519,7 +516,6 @@ neighbor_sample_impl(raft::handle_t const& handle, return sampled_label_size; }); - // Run inclusive scan thrust::inclusive_scan(handle.get_thrust_policy(),