Skip to content

Commit

Permalink
add restrict, free 1 register
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 7, 2024
1 parent fde845b commit 663dabf
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -719,21 +719,22 @@ template <typename value_idx = std::int64_t,
int dim = 3,
typename value_int = std::uint32_t,
typename distance_func>
RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered,
const value_t* X,
const value_int n_queries,
const value_int n_cols,
const value_t* R,
const value_int m,
const value_t eps,
const value_int n_landmarks,
const value_idx* R_indptr,
const value_idx* R_1nn_cols,
const value_t* R_1nn_dists,
const value_t* R_radius,
distance_func dfunc,
value_idx* adj_ia,
value_idx* adj_ja)
RAFT_KERNEL __launch_bounds__(tpb)
block_rbc_kernel_eps_csr_pass_xd(const value_t* __restrict__ X_reordered,
const value_t* __restrict__ X,
const value_int n_queries,
const value_int n_cols,
const value_t* __restrict__ R,
const value_int m,
const value_t eps,
const value_int n_landmarks,
const value_idx* __restrict__ R_indptr,
const value_idx* __restrict__ R_1nn_cols,
const value_t* __restrict__ R_1nn_dists,
const value_t* __restrict__ R_radius,
distance_func dfunc,
value_idx* __restrict__ adj_ia,
value_idx* adj_ja)
{
constexpr int num_warps = tpb / WarpSize;
constexpr int max_lid = WarpSize - 1;
Expand All @@ -748,10 +749,14 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_
// this is an early out for a full warp
if (query_id >= n_queries) return;

value_idx column_index_offset = write_pass ? adj_ia[query_id] : 0;
uint32_t column_index_offset = 0;

// we have no neighbors to fill for this query
if (write_pass && adj_ia[query_id + 1] == column_index_offset) return;
if constexpr (write_pass) {
value_idx offset = adj_ia[query_id];
// we have no neighbors to fill for this query
if (offset == adj_ia[query_id + 1]) return;
adj_ja += offset;
}

const value_t* x_ptr = X + (dim * query_id);
value_t local_x_ptr[dim];
Expand Down Expand Up @@ -812,11 +817,11 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
const uint32_t index = R_1nn_cols[R_start_offset + i];
const value_idx row_pos = column_index_offset + __popc(mask & lid_mask);
adj_ja[row_pos] = index;
const uint32_t index = R_1nn_cols[R_start_offset + i];
const uint32_t row_pos = __popc(mask & lid_mask);
adj_ja[row_pos] = index;
}
column_index_offset += __popc(mask);
adj_ja += __popc(mask);
} else {
column_index_offset += (in_range);
}
Expand All @@ -833,11 +838,11 @@ RAFT_KERNEL __launch_bounds__(tpb) block_rbc_kernel_eps_csr_pass_xd(const value_
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
const uint32_t index = R_1nn_cols[R_start_offset + i];
const value_idx row_pos = column_index_offset + __popc(mask & lid_mask);
adj_ja[row_pos] = index;
const uint32_t index = R_1nn_cols[R_start_offset + i];
const uint32_t row_pos = __popc(mask & lid_mask);
adj_ja[row_pos] = index;
}
column_index_offset += __popc(mask);
adj_ja += __popc(mask);
} else {
column_index_offset += (in_range);
}
Expand Down

0 comments on commit 663dabf

Please sign in to comment.