From 4857b36635cfcb9fce9465c0cd301c1a9637e00f Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Mon, 23 Sep 2024 11:48:41 -0700 Subject: [PATCH] call scatter instead of gather and fix type bug --- cpp/src/sampling/detail/conversion_utilities.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/sampling/detail/conversion_utilities.cu b/cpp/src/sampling/detail/conversion_utilities.cu index 5e6f4b00fd2..c6e396b7024 100644 --- a/cpp/src/sampling/detail/conversion_utilities.cu +++ b/cpp/src/sampling/detail/conversion_utilities.cu @@ -35,7 +35,7 @@ rmm::device_uvector flatten_label_map( { rmm::device_uvector label_map(0, handle.get_stream()); - label_t max_label = thrust::reduce(handle.get_thrust_policy(), + label_t max_label = thrust::scatter(handle.get_thrust_policy(), std::get<0>(label_to_output_comm_rank).begin(), std::get<0>(label_to_output_comm_rank).end(), label_t{0}, @@ -43,7 +43,7 @@ rmm::device_uvector flatten_label_map( label_map.resize(max_label, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), label_t{0}); + thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), int32_t); thrust::gather(handle.get_thrust_policy(), std::get<0>(label_to_output_comm_rank).begin(), std::get<0>(label_to_output_comm_rank).end(),