diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 2062e6e421..eda6d33293 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -556,19 +556,21 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * n_cols; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(x_ptr, y_ptr, n_cols); const bool in_range = (dist <= eps2); if (in_range) { - auto index = R_1nn_cols[R_start_offset + i]; + auto index = R_1nn_cols[R_start_offset + i0 + lid]; column_count++; adj[index] = true; } // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); } @@ -687,16 +689,18 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * n_cols; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(x_ptr, y_ptr, n_cols); const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; const uint32_t row_pos = __popc(mask & lid_mask); adj_ja[row_pos] = index; } @@ -705,7 +709,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, column_index_offset += (in_range); } // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); } @@ -831,16 +835,18 @@ RAFT_KERNEL __launch_bounds__(tpb) i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * dim; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(local_x_ptr, y_ptr, dim); const bool in_range = (dist <= eps2); if constexpr (write_pass) { const int mask = raft::ballot(in_range); if (in_range) { - const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; const uint32_t row_pos = __popc(mask & lid_mask); adj_ja[row_pos] = index; } @@ -849,7 +855,7 @@ RAFT_KERNEL __launch_bounds__(tpb) column_index_offset += (in_range); } // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); } @@ -961,10 +967,12 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, i *= (cur_R_dist - min_warp_dist <= eps); } - while (i >= WarpSize) { + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { y_ptr -= WarpSize * n_cols; - i -= WarpSize; - const value_t min_warp_dist = R_1nn_dists[R_start_offset + i - lid]; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; const value_t dist = dfunc(x_ptr, y_ptr, n_cols); const bool in_range = (dist <= eps2); const int mask = raft::ballot(in_range); @@ -972,13 +980,13 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, auto row_pos = column_count + __popc(mask & lid_mask); // we still continue to look for more hits to return valid vd if (row_pos < max_k) { - auto index = R_1nn_cols[R_start_offset + i]; + auto index = R_1nn_cols[R_start_offset + i0 + lid]; tmp[row_pos] = index; } } column_count += __popc(mask); // abort in case subsequent points cannot possibly be in reach - i *= (cur_R_dist - min_warp_dist <= eps); + i0 *= (cur_R_dist - min_warp_dist <= eps); } } while (lane_mask); }