Skip to content

Commit

Permalink
reduce code repetition
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Jan 24, 2024
1 parent 65104df commit 55927e9
Showing 1 changed file with 41 additions and 35 deletions.
76 changes: 41 additions & 35 deletions cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,35 @@ struct per_v_transform_reduce_call_e_op_t {
}
};

template <typename vertex_t,
typename edge_t,
bool multi_gpu,
typename result_t,
typename TransformOp,
typename ReduceOp,
typename ResultValueOutputIteratorOrWrapper>
struct transform_and_atomic_reduce_t {
edge_partition_device_view_t<vertex_t, edge_t, multi_gpu> edge_partition{};
result_t identity_element{};
vertex_t const* indices{nullptr};
TransformOp transform_op{};
ResultValueOutputIteratorOrWrapper result_value_output{};

__device__ void operator()(edge_t i) const
{
auto e_op_result = transform_op(i);
if (e_op_result != identity_element) {
auto minor = indices[i];
auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor);
if constexpr (multi_gpu) {
reduce_op::atomic_reduce<ReduceOp>(result_value_output, minor_offset, e_op_result);
} else {
reduce_op::atomic_reduce<ReduceOp>(result_value_output + minor_offset, e_op_result);
}
}
}
};

template <bool update_major,
typename vertex_t,
typename edge_t,
Expand Down Expand Up @@ -126,41 +155,18 @@ __device__ void update_result_value_output(
init,
reduce_op);
} else {
if constexpr (multi_gpu) {
thrust::for_each(
thrust::seq,
thrust::make_counting_iterator(edge_t{0}),
thrust::make_counting_iterator(local_degree),
[&edge_partition,
identity_element,
indices,
&result_value_output,
&transform_op] __device__(auto i) {
auto e_op_result = transform_op(i);
if (e_op_result != identity_element) {
auto minor = indices[i];
auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor);
reduce_op::atomic_reduce<ReduceOp>(result_value_output, minor_offset, e_op_result);
}
});
} else {
thrust::for_each(
thrust::seq,
thrust::make_counting_iterator(edge_t{0}),
thrust::make_counting_iterator(local_degree),
[&edge_partition,
identity_element,
indices,
&result_value_output,
&transform_op] __device__(auto i) {
auto e_op_result = transform_op(i);
if (e_op_result != identity_element) {
auto minor = indices[i];
auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor);
reduce_op::atomic_reduce<ReduceOp>(result_value_output + minor_offset, e_op_result);
}
});
}
thrust::for_each(
thrust::seq,
thrust::make_counting_iterator(edge_t{0}),
thrust::make_counting_iterator(local_degree),
transform_and_atomic_reduce_t<vertex_t,
edge_t,
multi_gpu,
result_t,
TransformOp,
ReduceOp,
ResultValueOutputIteratorOrWrapper>{
edge_partition, identity_element, indices, transform_op, result_value_output});
}
}

Expand Down

0 comments on commit 55927e9

Please sign in to comment.