diff --git a/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh b/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh index 1349454f5b6..cfba8a35cc1 100644 --- a/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh +++ b/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh @@ -513,24 +513,20 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, static_assert(is_arithmetic_or_thrust_tuple_of_arithmetic::value); - [[maybe_unused]] std::conditional_t, - edge_dst_property_t> - minor_tmp_buffer(handle); // relevant only when (GraphViewType::is_multi_gpu && !update_major + using minor_tmp_buffer_type = std::conditional_t, + edge_dst_property_t>; + std::unique_ptr minor_tmp_buffer{}; if constexpr (GraphViewType::is_multi_gpu && !update_major) { - if constexpr (GraphViewType::is_storage_transposed) { - minor_tmp_buffer = edge_src_property_t(handle, graph_view); - } else { - minor_tmp_buffer = edge_dst_property_t(handle, graph_view); - } + minor_tmp_buffer = std::make_unique(handle, graph_view); } using edge_partition_minor_output_device_view_t = - std::conditional_t>; + decltype(minor_tmp_buffer->mutable_view().value_first())>, + void /* dummy */>; if constexpr (update_major) { size_t partition_idx = 0; @@ -549,7 +545,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, } else { if constexpr (GraphViewType::is_multi_gpu) { auto minor_init = init; - auto view = minor_tmp_buffer.view(); + auto view = minor_tmp_buffer->view(); if (view.keys()) { // defer applying the initial value to the end as minor_tmp_buffer may not // store values for the entire minor range minor_init = ReduceOp::identity_element; @@ -558,7 +554,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, auto const major_comm_rank = major_comm.get_rank(); minor_init = (major_comm_rank == 0) ? init : ReduceOp::identity_element; } - fill_edge_minor_property(handle, graph_view, minor_init, minor_tmp_buffer.mutable_view()); + fill_edge_minor_property(handle, graph_view, minor_init, minor_tmp_buffer->mutable_view()); } else { thrust::fill(handle.get_thrust_policy(), vertex_value_output_first, @@ -699,7 +695,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, if constexpr (update_major) { output_buffer = major_buffer_first; } else { - output_buffer = edge_partition_minor_output_device_view_t(minor_tmp_buffer.mutable_view()); + output_buffer = edge_partition_minor_output_device_view_t(minor_tmp_buffer->mutable_view()); } } else { output_buffer = vertex_value_output_first; @@ -913,7 +909,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, auto const minor_comm_rank = minor_comm.get_rank(); auto const minor_comm_size = minor_comm.get_size(); - auto view = minor_tmp_buffer.view(); + auto view = minor_tmp_buffer->view(); if (view.keys()) { // applying the initial value is deferred to here vertex_t max_vertex_partition_size{0}; for (int i = 0; i < major_comm_size; ++i) {