diff --git a/cpp/src/prims/transform_e.cuh b/cpp/src/prims/transform_e.cuh index a34a5a04b49..2cb1a5358b0 100644 --- a/cpp/src/prims/transform_e.cuh +++ b/cpp/src/prims/transform_e.cuh @@ -85,7 +85,7 @@ __global__ void transform_e_packed_bool( if (local_edge_idx < num_edges) { bool compute_predicate = true; if constexpr (check_edge_mask) { - compute_predicate = (edge_mask & packed_bool_mask(lane_id) != packed_bool_empty_mask()); + compute_predicate = ((edge_mask & packed_bool_mask(lane_id)) != packed_bool_empty_mask()); } if (compute_predicate) { @@ -111,10 +111,10 @@ __global__ void transform_e_packed_bool( uint32_t new_val = __ballot_sync(raft::warp_full_mask(), predicate); if (lane_id == 0) { if constexpr (check_edge_mask) { - *(edge_partition_e_value_output.value_first() + idx) = new_val; - } else { auto old_val = *(edge_partition_e_value_output.value_first() + idx); *(edge_partition_e_value_output.value_first() + idx) = (old_val & ~edge_mask) | new_val; + } else { + *(edge_partition_e_value_output.value_first() + idx) = new_val; } } @@ -196,6 +196,9 @@ struct update_e_value_t { __device__ void operator()(typename GraphViewType::edge_type i) const { + if constexpr (check_edge_mask) { + if (!edge_partition_e_mask.get(i)) { return; } + } auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(i); auto major = edge_partition.major_from_major_idx_nocheck(major_idx); auto major_offset = edge_partition.major_offset_from_major_nocheck(major);