diff --git a/cpp/src/sampling/detail/conversion_utilities.cu b/cpp/src/sampling/detail/conversion_utilities.cu index bc6cb128561..a067005514d 100644 --- a/cpp/src/sampling/detail/conversion_utilities.cu +++ b/cpp/src/sampling/detail/conversion_utilities.cu @@ -41,7 +41,7 @@ rmm::device_uvector flatten_label_map( label_t{0}, thrust::maximum()); - label_map.resize(max_label, handle.get_stream()); + label_map.resize(max_label + 1, handle.get_stream()); thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), int32_t{0}); thrust::scatter(handle.get_thrust_policy(), diff --git a/cpp/src/sampling/detail/conversion_utilities_impl.cuh b/cpp/src/sampling/detail/conversion_utilities_impl.cuh index 3ed11d05181..0c8d8ac95ea 100644 --- a/cpp/src/sampling/detail/conversion_utilities_impl.cuh +++ b/cpp/src/sampling/detail/conversion_utilities_impl.cuh @@ -42,10 +42,10 @@ rmm::device_uvector flatten_label_map( label_t{0}, thrust::maximum()); - label_map.resize(max_label, handle.get_stream()); + label_map.resize(max_label + 1, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), label_t{0}); - thrust::gather(handle.get_thrust_policy(), + thrust::fill(handle.get_thrust_policy(), label_map.begin(), label_map.end(), int32_t{0}); + thrust::scatter(handle.get_thrust_policy(), std::get<0>(label_to_output_comm_rank).begin(), std::get<0>(label_to_output_comm_rank).end(), std::get<1>(label_to_output_comm_rank).begin(),