diff --git a/cpp/src/traversal/od_shortest_distances_impl.cuh b/cpp/src/traversal/od_shortest_distances_impl.cuh index 09e41466393..6a0c5a4a675 100644 --- a/cpp/src/traversal/od_shortest_distances_impl.cuh +++ b/cpp/src/traversal/od_shortest_distances_impl.cuh @@ -210,12 +210,17 @@ size_t compute_kv_store_capacity(size_t new_min_size, int32_t constexpr multi_partition_copy_block_size = 512; // tuning parameter -template +template __global__ void multi_partition_copy( InputIterator input_first, InputIterator input_last, raft::device_span output_buffer_ptrs, PartitionOp partition_op, // returns max_num_partitions to discard + KeyOp key_op, raft::device_span partition_counters) { static_assert(max_num_partitions <= static_cast(std::numeric_limits::max())); @@ -283,7 +288,7 @@ __global__ void multi_partition_copy( if (partition != static_cast(max_num_partitions)) { auto offset = block_start_offsets[partition] + static_cast(tmp_intra_block_offsets[partition] + tmp_offsets[i]); - *(output_buffer_ptrs[partition] + offset) = thrust::get<0>(*(input_first + tmp_idx)); + *(output_buffer_ptrs[partition] + offset) = key_op(*(input_first + tmp_idx)); } } tmp_idx += gridDim.x * blockDim.x; @@ -794,6 +799,7 @@ rmm::device_uvector od_shortest_distances( split_thresholds.end(), thrust::get<1>(pair)))); }, + [] __device__(auto pair) { return thrust::get<0>(pair); }, raft::device_span(d_counters.data(), d_counters.size())); std::vector h_counters(d_counters.size()); @@ -912,13 +918,6 @@ rmm::device_uvector od_shortest_distances( thrust::fill( handle.get_thrust_policy(), d_counters.begin(), d_counters.end(), size_t{0}); if (tmp_buffer.size() > 0) { - auto distance_first = thrust::make_transform_iterator( - tmp_buffer.begin(), - [key_to_dist_map = detail::kv_cuco_store_find_device_view_t( - key_to_dist_map.view())] __device__(auto key) { - return key_to_dist_map.find(key); - }); - auto input_first = thrust::make_zip_iterator(tmp_buffer.begin(), distance_first); raft::grid_1d_thread_t update_grid(tmp_buffer.size(), multi_partition_copy_block_size, handle.get_device_properties().maxGridSize[0]); @@ -926,13 +925,15 @@ rmm::device_uvector od_shortest_distances( static_cast(1 /* near queue */ + num_far_buffers); multi_partition_copy <<>>( - input_first, - input_first + tmp_buffer.size(), + tmp_buffer.begin(), + tmp_buffer.end(), raft::device_span(d_buffer_ptrs.data(), d_buffer_ptrs.size()), - [split_thresholds = raft::device_span( + [key_to_dist_map = + detail::kv_cuco_store_find_device_view_t(key_to_dist_map.view()), + split_thresholds = raft::device_span( d_split_thresholds.data(), d_split_thresholds.size()), - invalid_threshold] __device__(auto pair) { - auto dist = thrust::get<1>(pair); + invalid_threshold] __device__(auto key) { + auto dist = key_to_dist_map.find(key); return static_cast( (dist < invalid_threshold) ? max_num_partitions /* discard */ @@ -942,6 +943,7 @@ rmm::device_uvector od_shortest_distances( split_thresholds.end(), dist))); }, + thrust::identity{}, raft::device_span(d_counters.data(), d_counters.size())); } std::vector h_counters(d_counters.size()); diff --git a/cpp/tests/traversal/od_shortest_distances_test.cpp b/cpp/tests/traversal/od_shortest_distances_test.cpp index e4fbbdf9275..cc283f24dfd 100644 --- a/cpp/tests/traversal/od_shortest_distances_test.cpp +++ b/cpp/tests/traversal/od_shortest_distances_test.cpp @@ -225,27 +225,27 @@ class Tests_ODShortestDistances using Tests_ODShortestDistances_File = Tests_ODShortestDistances; using Tests_ODShortestDistances_Rmat = Tests_ODShortestDistances; -TEST_P(Tests_ODShortestDistances_File, DISABLED_CheckInt32Int32Float) +TEST_P(Tests_ODShortestDistances_File, CheckInt32Int32Float) { auto param = GetParam(); run_current_test(std::get<0>(param), std::get<1>(param)); } -TEST_P(Tests_ODShortestDistances_Rmat, DISABLED_CheckInt32Int32Float) +TEST_P(Tests_ODShortestDistances_Rmat, CheckInt32Int32Float) { auto param = GetParam(); run_current_test( std::get<0>(param), override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_ODShortestDistances_Rmat, DISABLED_CheckInt32Int64Float) +TEST_P(Tests_ODShortestDistances_Rmat, CheckInt32Int64Float) { auto param = GetParam(); run_current_test( std::get<0>(param), override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_ODShortestDistances_Rmat, DISABLED_CheckInt64Int64Float) +TEST_P(Tests_ODShortestDistances_Rmat, CheckInt64Int64Float) { auto param = GetParam(); run_current_test(