diff --git a/cpp/include/raft/neighbors/ball_cover-inl.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh index 4af2a6772c..398e7d6e42 100644 --- a/cpp/include/raft/neighbors/ball_cover-inl.cuh +++ b/cpp/include/raft/neighbors/ball_cover-inl.cuh @@ -333,7 +333,7 @@ void eps_nn(raft::resources const& handle, query.extent(0), adj.data_handle(), vd.data_handle(), - spatial::knn::detail::EuclideanFunc()); + spatial::knn::detail::EuclideanSqFunc()); } /** @@ -392,7 +392,7 @@ void eps_nn(raft::resources const& handle, adj_ia.data_handle(), adj_ja.data_handle(), vd.data_handle(), - spatial::knn::detail::EuclideanFunc()); + spatial::knn::detail::EuclideanSqFunc()); } /** diff --git a/cpp/include/raft/neighbors/ball_cover_types.hpp b/cpp/include/raft/neighbors/ball_cover_types.hpp index a627a1d234..e3ea5f0005 100644 --- a/cpp/include/raft/neighbors/ball_cover_types.hpp +++ b/cpp/include/raft/neighbors/ball_cover_types.hpp @@ -69,6 +69,7 @@ class BallCoverIndex { R_1nn_dists(raft::make_device_vector(handle, m_)), R_closest_landmark_dists(raft::make_device_vector(handle, m_)), R(raft::make_device_matrix(handle, sqrt(m_), n_)), + X_reordered(raft::make_device_matrix(handle, m_, n_)), R_radius(raft::make_device_vector(handle, sqrt(m_))), index_trained(false) { @@ -93,6 +94,8 @@ class BallCoverIndex { R_1nn_dists(raft::make_device_vector(handle, X_.extent(0))), R_closest_landmark_dists(raft::make_device_vector(handle, X_.extent(0))), R(raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1))), + X_reordered( + raft::make_device_matrix(handle, X_.extent(0), X_.extent(1))), R_radius(raft::make_device_vector(handle, sqrt(X_.extent(0)))), index_trained(false) { @@ -122,6 +125,10 @@ class BallCoverIndex { { return R_closest_landmark_dists.view(); } + auto get_X_reordered() const -> raft::device_matrix_view + { + return X_reordered.view(); + } raft::device_vector_view get_R_indptr() { return R_indptr.view(); } raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } @@ -132,6 +139,10 @@ class BallCoverIndex { { return R_closest_landmark_dists.view(); } + raft::device_matrix_view get_X_reordered() + { + return X_reordered.view(); + } raft::device_matrix_view get_X() const { return X; } raft::distance::DistanceType get_metric() const { return metric; } @@ -162,6 +173,7 @@ class BallCoverIndex { raft::device_vector R_radius; raft::device_matrix R; + raft::device_matrix X_reordered; protected: bool index_trained; diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 5c91b1ad6d..c4ca2ffa61 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -165,6 +165,10 @@ void construct_landmark_1nn(raft::resources const& handle, index.get_R_indptr().data_handle(), index.n_landmarks + 1, resource::get_cuda_stream(handle)); + + // reorder X to allow aligned access + raft::matrix::copy_rows( + handle, index.get_X(), index.get_X_reordered(), index.get_R_1nn_cols()); } /** @@ -337,12 +341,6 @@ void perform_rbc_query(raft::resources const& handle, /** * Perform eps-select * - * a. Map 1 row to each warp/block - * b. Add closest k R points to heap - * c. Iterate through batches of R, having each thread in the warp load a set - * of distances y from R (only if d(q, r) < 3 * distance to closest r) and - * marking the distance to be computed between x, y only - * if knn[k].distance >= d(x_i, R_k) + d(R_k, y) */ template ( - handle, index, query, n_query_pts, eps, landmark_dists, dfunc, adj, vd); + handle, index, query, n_query_pts, eps, landmarks, dfunc, adj, vd); resource::sync_stream(handle); } @@ -384,14 +382,14 @@ void perform_rbc_eps_nn_query( value_int n_query_pts, value_t eps, value_int* max_k, - const value_t* landmark_dists, + const value_t* landmarks, dist_func dfunc, value_idx* adj_ia, value_idx* adj_ja, value_idx* vd) { rbc_eps_pass( - handle, index, query, n_query_pts, eps, max_k, landmark_dists, dfunc, adj_ia, adj_ja, vd); + handle, index, query, n_query_pts, eps, max_k, landmarks, dfunc, adj_ia, adj_ja, vd); resource::sync_stream(handle); } @@ -664,15 +662,9 @@ void rbc_eps_nn_query(raft::resources const& handle, { ASSERT(index.is_index_trained(), "index must be previously trained"); - auto R_dists = - raft::make_device_matrix(handle, index.n_landmarks, n_query_pts); - - // find all landmarks that might have points in range - compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); - // query all points and write to adj perform_rbc_eps_nn_query( - handle, index, query, n_query_pts, eps, R_dists.data_handle(), dfunc, adj, vd); + handle, index, query, n_query_pts, eps, index.get_R().data_handle(), dfunc, adj, vd); } template (handle, index.n_landmarks, n_query_pts); - - // find all landmarks that might have points in range - compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); - // query all points and write to adj perform_rbc_eps_nn_query(handle, index, @@ -706,7 +692,7 @@ void rbc_eps_nn_query(raft::resources const& handle, n_query_pts, eps, max_k, - R_dists.data_handle(), + index.get_R().data_handle(), dfunc, adj_ia, adj_ja, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh index 81500a0eae..70df6d0165 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh @@ -190,7 +190,7 @@ instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc); instantiate_raft_spatial_knn_detail_rbc_eps_pass( - std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index e4acd1dc4c..eda6d33293 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -157,7 +157,7 @@ template -RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, +RAFT_KERNEL compute_final_dists_registers(const value_t* X_reordered, const value_t* X, const value_int n_cols, bitset_type* bitset, @@ -238,7 +238,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -267,7 +267,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -313,7 +313,7 @@ template -RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_reordered, const value_t* X, value_int n_cols, // n_cols should be 2 or 3 dims const value_idx* R_knn_inds, @@ -408,7 +408,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -433,7 +433,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -456,15 +456,22 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, } } +template +__device__ value_t squared(const value_t& a) +{ + return a * a; +} + template -RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -476,70 +483,115 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_index, bool* adj, value_idx* vd) { - __shared__ int column_count_smem; - - // initialize - if (vd != nullptr) { - if (threadIdx.x == 0) { column_count_smem = 0; } - __syncthreads(); - } - - const value_t* x_ptr = X + (n_cols * blockIdx.x); - - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; - - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; - - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; - - value_idx R_size = R_stop_offset - R_start_offset; - - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - adj[blockIdx.x * m + cur_candidate_ind] = true; - if (vd != nullptr) atomicAdd(&column_count_smem, 1); + constexpr int num_warps = tpb / WarpSize; + + // process 1 query per warp + const uint32_t lid = raft::laneId(); + + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); + + // this is an early out for a full warp + if (query_id >= n_queries) return; + + value_idx column_count = 0; + + const value_t* x_ptr = X + (n_cols * query_id); + adj += query_id * m; + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + column_count++; + adj[index] = true; + } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } - - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - adj[blockIdx.x * m + cur_candidate_ind] = true; - if (vd != nullptr) atomicAdd(&column_count_smem, 1); + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i0 + lid]; + column_count++; + adj[index] = true; + } + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); } - } + } while (lane_mask); } if (vd != nullptr) { - __syncthreads(); - if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } + value_idx row_sum = raft::warpReduce(column_count); + if (lid == 0) vd[query_id] = row_sum; } } template -RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -551,58 +603,267 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index, value_idx* adj_ia, value_idx* adj_ja) { - const value_t* x_ptr = X + (n_cols * blockIdx.x); + constexpr int num_warps = tpb / WarpSize; - __shared__ unsigned long long int column_index_smem; + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; - bool pass2 = adj_ja != nullptr; + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); - // initialize - if (threadIdx.x == 0) { column_index_smem = pass2 ? adj_ia[blockIdx.x] : 0; } + // this is an early out for a full warp + if (query_id >= n_queries) return; - __syncthreads(); + uint32_t column_index_offset = 0; - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + if constexpr (write_pass) { + value_idx offset = adj_ia[query_id]; + // we have no neighbors to fill for this query + if (offset == adj_ia[query_id + 1]) return; + adj_ja += offset; + } - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; + const value_t* x_ptr = X + (n_cols * query_id); + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); + } - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); + } + } while (lane_mask); + } - value_idx R_size = R_stop_offset - R_start_offset; + if constexpr (!write_pass) { + value_idx row_sum = raft::warpReduce(column_index_offset); + if (lid == 0) adj_ia[query_id] = row_sum; + } +} - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; +template +RAFT_KERNEL __launch_bounds__(tpb) + block_rbc_kernel_eps_csr_pass_xd(const value_t* __restrict__ X_reordered, + const value_t* __restrict__ X, + const value_int n_queries, + const value_int n_cols, + const value_t* __restrict__ R, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* __restrict__ R_indptr, + const value_idx* __restrict__ R_1nn_cols, + const value_t* __restrict__ R_1nn_dists, + const value_t* __restrict__ R_radius, + distance_func dfunc, + value_idx* __restrict__ adj_ia, + value_idx* adj_ja) +{ + constexpr int num_warps = tpb / WarpSize; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - auto row_pos = atomicAdd(&column_index_smem, 1); - if (pass2) adj_ja[row_pos] = cur_candidate_ind; - } - } + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); + + // this is an early out for a full warp + if (query_id >= n_queries) return; + + uint32_t column_index_offset = 0; + + if constexpr (write_pass) { + value_idx offset = adj_ia[query_id]; + // we have no neighbors to fill for this query + if (offset == adj_ia[query_id + 1]) return; + adj_ja += offset; + } + + const value_t* x_ptr = X + (dim * query_id); + value_t local_x_ptr[dim]; +#pragma unroll + for (uint32_t i = 0; i < dim; ++i) { + local_x_ptr[i] = x_ptr[i]; + } + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(local_x_ptr, R + lane_k * dim, dim) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(local_x_ptr, y_ptr, dim) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); + } - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - auto row_pos = atomicAdd(&column_index_smem, 1); - if (pass2) adj_ja[row_pos] = cur_candidate_ind; + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * dim; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(local_x_ptr, y_ptr, dim); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); } - } + } while (lane_mask); } - __syncthreads(); - if (threadIdx.x == 0 && !pass2) { adj_ia[blockIdx.x] = (value_idx)column_index_smem; } + if constexpr (!write_pass) { + value_idx row_sum = raft::warpReduce(column_index_offset); + if (lid == 0) adj_ia[query_id] = row_sum; + } } template -RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -626,59 +888,110 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_index, const value_int max_k, value_idx* tmp) { - const value_t* x_ptr = X + (n_cols * blockIdx.x); - - __shared__ int column_count_smem; - - // initialize - if (threadIdx.x == 0) { column_count_smem = 0; } - - __syncthreads(); - - // we store all column indices in dense tmp store [blockDim.x * max_k] - value_int offset = blockIdx.x * max_k; - - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; - - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; - - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; - - value_idx R_size = R_stop_offset - R_start_offset; - - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - int row_pos = atomicAdd(&column_count_smem, 1); - if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; + constexpr int num_warps = tpb / WarpSize; + + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; + + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); + + // this is an early out for a full warp + if (query_id >= n_queries) return; + + value_idx column_count = 0; + + const value_t* x_ptr = X + (n_cols * query_id); + tmp += query_id * max_k; + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + const int mask = raft::ballot(in_range); + if (in_range) { + auto row_pos = column_count + __popc(mask & lid_mask); + // we still continue to look for more hits to return valid vd + if (row_pos < max_k) { + auto index = R_1nn_cols[R_start_offset + i]; + tmp[row_pos] = index; + } + } + column_count += __popc(mask); + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } - - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - int row_pos = atomicAdd(&column_count_smem, 1); - if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + const int mask = raft::ballot(in_range); + if (in_range) { + auto row_pos = column_count + __popc(mask & lid_mask); + // we still continue to look for more hits to return valid vd + if (row_pos < max_k) { + auto index = R_1nn_cols[R_start_offset + i0 + lid]; + tmp[row_pos] = index; + } + } + column_count += __popc(mask); + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); } - } + } while (lane_mask); } - __syncthreads(); - if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } + if (lid == 0) vd[query_id] = column_count; } template @@ -723,7 +1036,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, if (k <= 32) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -743,7 +1056,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 64) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -762,7 +1075,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 128) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -782,7 +1095,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 256) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -802,7 +1115,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 512) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -822,7 +1135,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 1024) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -892,7 +1205,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -918,7 +1231,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -944,7 +1257,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -970,7 +1283,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -995,7 +1308,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1020,7 +1333,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1047,17 +1360,18 @@ void rbc_eps_pass(raft::resources const& handle, const value_t* query, const value_int n_query_rows, value_t eps, - const value_t* R_dists, + const value_t* R, dist_func& dfunc, bool* adj, value_idx* vd) { block_rbc_kernel_eps_dense <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, + n_query_rows, index.n, - R_dists, + R, index.m, eps, index.n_landmarks, @@ -1093,7 +1407,7 @@ void rbc_eps_pass(raft::resources const& handle, const value_int n_query_rows, value_t eps, value_int* max_k, - const value_t* R_dists, + const value_t* R, dist_func& dfunc, value_idx* adj_ia, value_idx* adj_ja, @@ -1104,22 +1418,61 @@ void rbc_eps_pass(raft::resources const& handle, if (adj_ja == nullptr) { // pass 1 -> only compute adj_ia / vd value_idx* vd_ptr = (vd != nullptr) ? vd : adj_ia; - block_rbc_kernel_eps_csr_pass - <<>>( - index.get_X().data_handle(), - query, - index.n, - R_dists, - index.m, - eps, - index.n_landmarks, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - index.get_R_radius().data_handle(), - dfunc, - vd_ptr, - nullptr); + if (index.n == 2) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } else if (index.n == 3) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } else { + block_rbc_kernel_eps_csr_pass + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } thrust::exclusive_scan(resource::get_thrust_policy(handle), vd_ptr, @@ -1129,22 +1482,61 @@ void rbc_eps_pass(raft::resources const& handle, } else { // pass 2 -> fill in adj_ja - block_rbc_kernel_eps_csr_pass - <<>>( - index.get_X().data_handle(), - query, - index.n, - R_dists, - index.m, - eps, - index.n_landmarks, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - index.get_R_radius().data_handle(), - dfunc, - adj_ia, - adj_ja); + if (index.n == 2) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } else if (index.n == 3) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } else { + block_rbc_kernel_eps_csr_pass + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } } } else { value_int max_k_in = *max_k; @@ -1153,11 +1545,12 @@ void rbc_eps_pass(raft::resources const& handle, rmm::device_uvector tmp(n_query_rows * max_k_in, resource::get_cuda_stream(handle)); block_rbc_kernel_eps_max_k - <<>>( - index.get_X().data_handle(), + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), query, + n_query_rows, index.n, - R_dists, + R, index.m, eps, index.n_landmarks, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh index f38a3eeec9..eacd4aeff3 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh @@ -61,6 +61,21 @@ struct EuclideanFunc : public DistFunc { } }; +template +struct EuclideanSqFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + value_t sum_sq = 0; + for (value_int i = 0; i < n_dims; ++i) { + value_t diff = a[i] - b[i]; + sum_sq += diff * diff; + } + return sum_sq; + } +}; + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py index dff2e015a4..10d9c95ece 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py @@ -121,6 +121,8 @@ dist="raft::spatial::knn::detail::DistFunc", ) +euclideanSq="raft::spatial::knn::detail::EuclideanSqFunc", + types = dict( int64_float=("std::int64_t", "float"), #int64_double=("std::int64_t", "double"), @@ -156,7 +158,7 @@ f.write(macro_pass_eps) for type_path, (int_t, data_t) in types.items(): f.write(f"instantiate_raft_spatial_knn_detail_rbc_eps_pass(\n") - f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {distances['euclidean']});\n") + f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {euclideanSq});\n") f.write("#undef instantiate_raft_spatial_knn_detail_rbc_eps_pass\n") print(f"src/spatial/knn/detail/ball_cover/{path}") diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu index bbc1b55eb2..2a88862b2c 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu @@ -56,5 +56,5 @@ Mvalue_idx* vd) instantiate_raft_spatial_knn_detail_rbc_eps_pass( - std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc); #undef instantiate_raft_spatial_knn_detail_rbc_eps_pass