Skip to content

Commit

Permalink
update fill_edge_property to work with graphs with edge masks
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Nov 16, 2023
1 parent 0887af9 commit 69463f4
Showing 1 changed file with 63 additions and 8 deletions.
71 changes: 63 additions & 8 deletions cpp/src/prims/fill_edge_property.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,78 @@ void fill_edge_property(raft::handle_t const& handle,
{
static_assert(std::is_same_v<T, typename EdgePropertyOutputWrapper::value_type>);

using edge_t = typename GraphViewType::edge_type;

auto edge_mask_view = graph_view.edge_mask_view();

auto value_firsts = edge_property_output.value_firsts();
auto edge_counts = edge_property_output.edge_counts();
for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) {
auto edge_partition_e_mask =
edge_mask_view
? thrust::make_optional<
detail::edge_partition_edge_property_device_view_t<edge_t, uint32_t const*, bool>>(
*edge_mask_view, i)
: thrust::nullopt;

if constexpr (cugraph::has_packed_bool_element<
std::remove_reference_t<decltype(value_firsts[i])>,
T>()) {
static_assert(std::is_arithmetic_v<T>, "unimplemented for thrust::tuple types.");
auto packed_input = input ? packed_bool_full_mask() : packed_bool_empty_mask();
thrust::fill_n(handle.get_thrust_policy(),
value_firsts[i],
packed_bool_size(static_cast<size_t>(edge_counts[i])),
packed_input);
auto rem = edge_counts[i] % packed_bools_per_word();
if (edge_partition_e_mask) {
auto input_first =
thrust::make_zip_iterator(value_firsts[i], (*edge_partition_e_mask).value_first());
thrust::transform(handle.get_thrust_policy(),
input_first,
input_first + packed_bool_size(static_cast<size_t>(edge_counts[i] - rem)),
value_firsts[i],
[packed_input] __device__(thrust::tuple<T, uint32_t> pair) {
auto old_value = thrust::get<0>(pair);
auto mask = thrust::get<1>(pair);
return (old_value & ~mask) | (packed_input & mask);
});
if (rem > 0) {
thrust::transform(
handle.get_thrust_policy(),
input_first + packed_bool_size(static_cast<size_t>(edge_counts[i] - rem)),
input_first + packed_bool_size(static_cast<size_t>(edge_counts[i])),
value_firsts[i] + packed_bool_size(static_cast<size_t>(edge_counts[i] - rem)),
[packed_input, rem] __device__(thrust::tuple<T, uint32_t> pair) {
auto old_value = thrust::get<0>(pair);
auto mask = thrust::get<1>(pair);
return ((old_value & ~mask) | (packed_input & mask)) & packed_bool_partial_mask(rem);
});
}
} else {
thrust::fill_n(handle.get_thrust_policy(),
value_firsts[i],
packed_bool_size(static_cast<size_t>(edge_counts[i] - rem)),
packed_input);
if (rem > 0) {
thrust::fill_n(
handle.get_thrust_policy(),
value_firsts[i] + packed_bool_size(static_cast<size_t>(edge_counts[i] - rem)),
1,
packed_input & packed_bool_partial_mask(rem));
}
}
} else {
thrust::fill_n(
handle.get_thrust_policy(), value_firsts[i], static_cast<size_t>(edge_counts[i]), input);
if (edge_partition_e_mask) {
thrust::transform_if(handle.get_thrust_policy(),
thrust::make_constant_iterator(input),
thrust::make_constant_iterator(input) + edge_counts[i],
thrust::make_counting_iterator(edge_t{0}),
value_firsts[i],
thrust::identity<T>{},
[edge_partition_e_mask = *edge_partition_e_mask] __device__(edge_t i) {
return edge_partition_e_mask.get(i);
});
} else {
thrust::fill_n(
handle.get_thrust_policy(), value_firsts[i], static_cast<size_t>(edge_counts[i]), input);
}
}
}
}
Expand All @@ -79,8 +136,6 @@ void fill_edge_property(raft::handle_t const& handle,
edge_property_t<GraphViewType, T>& edge_property_output,
bool do_expensive_check = false)
{
CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented.");

if (do_expensive_check) {
// currently, nothing to do
}
Expand Down

0 comments on commit 69463f4

Please sign in to comment.