diff --git a/cpp/src/sampling/sampling_post_processing_impl.cuh b/cpp/src/sampling/sampling_post_processing_impl.cuh index ff8da72ff35..0c397d91b20 100644 --- a/cpp/src/sampling/sampling_post_processing_impl.cuh +++ b/cpp/src/sampling/sampling_post_processing_impl.cuh @@ -1619,10 +1619,13 @@ renumber_and_sort_sampled_edgelist( (*edgelist_label_hop_offsets).begin(), (*edgelist_label_hop_offsets).end(), size_t{0}); - thrust::for_each( + // FIXME: the device lambda should be placed in cuda::proclaim_return_type() + // once we update CCCL version to 2.x + thrust::transform( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(num_labels * num_hops), + (*edgelist_label_hop_offsets).begin(), [edgelist_label_offsets = edgelist_label_offsets ? thrust::make_optional(std::get<0>(*edgelist_label_offsets)) : thrust::nullopt, @@ -1743,10 +1746,13 @@ sort_sampled_edgelist( (*edgelist_label_hop_offsets).begin(), (*edgelist_label_hop_offsets).end(), size_t{0}); - thrust::for_each( + // FIXME: the device lambda should be placed in cuda::proclaim_return_type() + // once we update CCCL version to 2.x + thrust::transform( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(num_labels * num_hops), + (*edgelist_label_hop_offsets).begin(), [edgelist_label_offsets = edgelist_label_offsets ? thrust::make_optional(std::get<0>(*edgelist_label_offsets)) : thrust::nullopt, diff --git a/cpp/tests/sampling/sampling_post_processing_test.cu b/cpp/tests/sampling/sampling_post_processing_test.cu index 422fe953b20..e5267d75ac2 100644 --- a/cpp/tests/sampling/sampling_post_processing_test.cu +++ b/cpp/tests/sampling/sampling_post_processing_test.cu @@ -635,6 +635,12 @@ class Tests_SamplingPostProcessing (*renumbered_and_sorted_edgelist_label_hop_offsets).end())) << "Renumbered and sorted edge list (label,hop) offset array values should be " "non-decreasing."; + + ASSERT_TRUE( + (*renumbered_and_sorted_edgelist_label_hop_offsets).back_element(handle.get_stream()) == + renumbered_and_sorted_edgelist_srcs.size()) + << "Renumbered and sorted edge list (label,hop) offset array's last element should " + "coincide with the number of edges."; } if (renumbered_and_sorted_renumber_map_label_offsets) { @@ -1189,6 +1195,11 @@ class Tests_SamplingPostProcessing (*sorted_edgelist_label_hop_offsets).end())) << "Sorted edge list (label,hop) offset array values should be " "non-decreasing."; + + ASSERT_TRUE((*sorted_edgelist_label_hop_offsets).back_element(handle.get_stream()) == + sorted_edgelist_srcs.size()) + << "Sorted edge list (label,hop) offset array's last element should coincide with the " + "number of edges."; } for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) {