Skip to content

Commit

Permalink
remove warp divergence
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 8, 2024
1 parent 2315e2f commit 2d06d0a
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -556,19 +556,21 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered,
i *= (cur_R_dist - min_warp_dist <= eps);
}

while (i >= WarpSize) {
uint32_t i0 = raft::shfl(i, 0);

while (i0 >= WarpSize) {
y_ptr -= WarpSize * n_cols;
i -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid];
i0 -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0];
const value_t dist = dfunc(x_ptr, y_ptr, n_cols);
const bool in_range = (dist <= eps2);
if (in_range) {
auto index = R_1nn_cols[R_start_offset + i];
auto index = R_1nn_cols[R_start_offset + i0 + lid];
column_count++;
adj[index] = true;
}
// abort in case subsequent points cannot possibly be in reach
i *= (cur_R_dist - min_warp_dist <= eps);
i0 *= (cur_R_dist - min_warp_dist <= eps);
}
} while (lane_mask);
}
Expand Down Expand Up @@ -687,16 +689,18 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered,
i *= (cur_R_dist - min_warp_dist <= eps);
}

while (i >= WarpSize) {
uint32_t i0 = raft::shfl(i, 0);

while (i0 >= WarpSize) {
y_ptr -= WarpSize * n_cols;
i -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid];
i0 -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0];
const value_t dist = dfunc(x_ptr, y_ptr, n_cols);
const bool in_range = (dist <= eps2);
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
const uint32_t index = R_1nn_cols[R_start_offset + i];
const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid];
const uint32_t row_pos = __popc(mask & lid_mask);
adj_ja[row_pos] = index;
}
Expand All @@ -705,7 +709,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered,
column_index_offset += (in_range);
}
// abort in case subsequent points cannot possibly be in reach
i *= (cur_R_dist - min_warp_dist <= eps);
i0 *= (cur_R_dist - min_warp_dist <= eps);
}
} while (lane_mask);
}
Expand Down Expand Up @@ -831,16 +835,18 @@ RAFT_KERNEL __launch_bounds__(tpb)
i *= (cur_R_dist - min_warp_dist <= eps);
}

while (i >= WarpSize) {
uint32_t i0 = raft::shfl(i, 0);

while (i0 >= WarpSize) {
y_ptr -= WarpSize * dim;
i -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid];
i0 -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0];
const value_t dist = dfunc(local_x_ptr, y_ptr, dim);
const bool in_range = (dist <= eps2);
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
const uint32_t index = R_1nn_cols[R_start_offset + i];
const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid];
const uint32_t row_pos = __popc(mask & lid_mask);
adj_ja[row_pos] = index;
}
Expand All @@ -849,7 +855,7 @@ RAFT_KERNEL __launch_bounds__(tpb)
column_index_offset += (in_range);
}
// abort in case subsequent points cannot possibly be in reach
i *= (cur_R_dist - min_warp_dist <= eps);
i0 *= (cur_R_dist - min_warp_dist <= eps);
}
} while (lane_mask);
}
Expand Down Expand Up @@ -961,24 +967,26 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered,
i *= (cur_R_dist - min_warp_dist <= eps);
}

while (i >= WarpSize) {
uint32_t i0 = raft::shfl(i, 0);

while (i0 >= WarpSize) {
y_ptr -= WarpSize * n_cols;
i -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid];
i0 -= WarpSize;
const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0];
const value_t dist = dfunc(x_ptr, y_ptr, n_cols);
const bool in_range = (dist <= eps2);
const int mask = raft::ballot(in_range);
if (in_range) {
auto row_pos = column_count + __popc(mask & lid_mask);
// we still continue to look for more hits to return valid vd
if (row_pos < max_k) {
auto index = R_1nn_cols[R_start_offset + i];
auto index = R_1nn_cols[R_start_offset + i0 + lid];
tmp[row_pos] = index;
}
}
column_count += __popc(mask);
// abort in case subsequent points cannot possibly be in reach
i *= (cur_R_dist - min_warp_dist <= eps);
i0 *= (cur_R_dist - min_warp_dist <= eps);
}
} while (lane_mask);
}
Expand Down

0 comments on commit 2d06d0a

Please sign in to comment.