diff --git a/cpp/src/sampling/detail/conversion_utilities.cu b/cpp/src/sampling/detail/conversion_utilities.cu index 5e6f4b00fd2..c6e396b7024 100644 --- a/cpp/src/sampling/detail/conversion_utilities.cu +++ b/cpp/src/sampling/detail/conversion_utilities.cu @@ -35,7 +35,7 @@ rmm::device_uvector flatten_label_map( { rmm::device_uvector label_map(0, handle.get_stream()); - label_t max_label = thrust::reduce(handle.get_thrust_policy(), + label_t max_label = thrust::scatter(handle.get_thrust_policy(), std::get<0>(label_to_output_comm_rank).begin(), std::get<0>(label_to_output_comm_rank).end(), label_t{0}, @@ -43,7 +43,7 @@ rmm::device_uvector flatten_label_map( label_map.resize(max_label, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), label_t{0}); + thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), int32_t); thrust::gather(handle.get_thrust_policy(), std::get<0>(label_to_output_comm_rank).begin(), std::get<0>(label_to_output_comm_rank).end(),