Skip to content

Commit

Permalink
review suggestion constexpr
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 5, 2024
1 parent 1782558 commit b3e998a
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered,
const value_t dist =
(i >= R_size) ? std::numeric_limits<value_idx>::max() : dfunc(x_ptr, y_ptr, n_cols);
const bool in_range = (dist <= eps2);
if (write_pass) {
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
auto index = R_1nn_cols[R_start_offset + i];
Expand All @@ -683,7 +683,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered,
const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i));
const value_t dist = dfunc(x_ptr, y_ptr, n_cols);
const bool in_range = (dist <= eps2);
if (write_pass) {
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
auto index = R_1nn_cols[R_start_offset + i];
Expand All @@ -700,7 +700,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered,
} while (k_offset < WarpSize);
}

if (!write_pass) {
if constexpr (!write_pass) {
value_idx row_sum = raft::warpReduce(column_index_offset);
if (lid == 0) adj_ia[query_id] = row_sum;
}
Expand Down Expand Up @@ -801,7 +801,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered,
const value_t dist =
(i >= R_size) ? std::numeric_limits<value_idx>::max() : dfunc(local_x_ptr, y_ptr, dim);
const bool in_range = (dist <= eps2);
if (write_pass) {
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
auto index = R_1nn_cols[R_start_offset + i];
Expand All @@ -823,7 +823,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered,
const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i));
const value_t dist = dfunc(local_x_ptr, y_ptr, dim);
const bool in_range = (dist <= eps2);
if (write_pass) {
if constexpr (write_pass) {
const int mask = raft::ballot(in_range);
if (in_range) {
auto index = R_1nn_cols[R_start_offset + i];
Expand All @@ -840,7 +840,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass_xd(const value_t* X_reordered,
} while (k_offset < WarpSize);
}

if (!write_pass) {
if constexpr (!write_pass) {
value_idx row_sum = raft::warpReduce(column_index_offset);
if (lid == 0) adj_ia[query_id] = row_sum;
}
Expand Down

0 comments on commit b3e998a

Please sign in to comment.