Skip to content

Commit

Permalink
remove one more FLO per landmark visited
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 7, 2024
1 parent 2be4c17 commit 768eb48
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,11 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered,
if (lane_mask == 0) continue;

// reverse to use __clz instead of __ffs
lane_mask = __brev(lane_mask);
uint32_t k_offset = __clz(lane_mask);
lane_mask = __brev(lane_mask);
do {
// look for next k_offset
const uint32_t k_offset = __clz(lane_mask);

const uint32_t cur_k = cur_k0 + k_offset;

// The whole warp should iterate through the elements in the current R
Expand All @@ -536,9 +538,6 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered,
const uint32_t limit = Pow2<WarpSize>::roundDown(R_size);
uint32_t i = limit + lid;

// look ahead for next k_offset
k_offset = __clz(lane_mask);

// R_1nn_dists are sorted ascendingly for each landmark
// Iterating backwards, after pruning the first point w.r.t. triangle
// inequality all subsequent points can be pruned as well
Expand Down Expand Up @@ -644,9 +643,11 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered,
if (lane_mask == 0) continue;

// reverse to use __clz instead of __ffs
lane_mask = __brev(lane_mask);
uint32_t k_offset = __clz(lane_mask);
lane_mask = __brev(lane_mask);
do {
// look for next k_offset
const uint32_t k_offset = __clz(lane_mask);

const uint32_t cur_k = cur_k0 + k_offset;

// The whole warp should iterate through the elements in the current R
Expand All @@ -663,9 +664,6 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered,
const uint32_t limit = Pow2<WarpSize>::roundDown(R_size);
uint32_t i = limit + lid;

// look ahead for next k_offset
k_offset = __clz(lane_mask);

// R_1nn_dists are sorted ascendingly for each landmark
// Iterating backwards, after pruning the first point w.r.t. triangle
// inequality all subsequent points can be pruned as well
Expand Down Expand Up @@ -790,9 +788,11 @@ RAFT_KERNEL __launch_bounds__(tpb)
if (lane_mask == 0) continue;

// reverse to use __clz instead of __ffs
lane_mask = __brev(lane_mask);
uint32_t k_offset = __clz(lane_mask);
lane_mask = __brev(lane_mask);
do {
// look for next k_offset
const uint32_t k_offset = __clz(lane_mask);

const uint32_t cur_k = cur_k0 + k_offset;

// The whole warp should iterate through the elements in the current R
Expand All @@ -809,9 +809,6 @@ RAFT_KERNEL __launch_bounds__(tpb)
const uint32_t limit = Pow2<WarpSize>::roundDown(R_size);
uint32_t i = limit + lid;

// look ahead for next k_offset
k_offset = __clz(lane_mask);

// R_1nn_dists are sorted ascendingly for each landmark
// Iterating backwards, after pruning the first point w.r.t. triangle
// inequality all subsequent points can be pruned as well
Expand Down Expand Up @@ -923,9 +920,11 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered,
if (lane_mask == 0) continue;

// reverse to use __clz instead of __ffs
lane_mask = __brev(lane_mask);
uint32_t k_offset = __clz(lane_mask);
lane_mask = __brev(lane_mask);
do {
// look for next k_offset
const uint32_t k_offset = __clz(lane_mask);

const uint32_t cur_k = cur_k0 + k_offset;

// The whole warp should iterate through the elements in the current R
Expand All @@ -942,9 +941,6 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered,
const uint32_t limit = Pow2<WarpSize>::roundDown(R_size);
uint32_t i = limit + lid;

// look ahead for next k_offset
k_offset = __clz(lane_mask);

// R_1nn_dists are sorted ascendingly for each landmark
// Iterating backwards, after pruning the first point w.r.t. triangle
// inequality all subsequent points can be pruned as well
Expand Down

0 comments on commit 768eb48

Please sign in to comment.