From dddb05a40f387f5224302e2cd00b1b298dc522cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Mon, 11 Mar 2024 12:33:34 +0100 Subject: [PATCH 1/3] Improve RBC eps-neighborhood query performance (#2211) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR significantly improves performance of epsilon neighborhood search via RBC. Scope: * the rbc index was modified to contain a reordered dataset allowing for better access patterns * all kernels for dense, sparse and hybrid have been improved Optimizations include: - improve pruning of complete landmarks by pre-fetching distances and minimizing overhead for skipping - specialized 2D and 3D kernels, allowing for register per-fetch of query points - improve inner loop iterating landmark neighborhood - pruning of points within selected landmark neighborhood using triangle inequality - reverse iterate landmark neighborhood to allow complete processing stop once subsequent points cannot be reached - removal of shared memory atomics in favor of warp voting - general prevention of branches within a warp (at least to some extent) CC @cjnolet , @tfeher Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2211 --- cpp/include/raft/neighbors/ball_cover-inl.cuh | 4 +- .../raft/neighbors/ball_cover_types.hpp | 12 + .../raft/spatial/knn/detail/ball_cover.cuh | 34 +- .../knn/detail/ball_cover/registers-ext.cuh | 2 +- .../knn/detail/ball_cover/registers-inl.cuh | 781 +++++++++++++----- .../knn/detail/ball_cover/registers_types.cuh | 15 + .../ball_cover/registers_00_generate.py | 4 +- .../registers_eps_pass_euclidean.cu | 2 +- 8 files changed, 631 insertions(+), 223 deletions(-) diff --git a/cpp/include/raft/neighbors/ball_cover-inl.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh index 4af2a6772c..398e7d6e42 100644 --- a/cpp/include/raft/neighbors/ball_cover-inl.cuh +++ b/cpp/include/raft/neighbors/ball_cover-inl.cuh @@ -333,7 +333,7 @@ void eps_nn(raft::resources const& handle, query.extent(0), adj.data_handle(), vd.data_handle(), - spatial::knn::detail::EuclideanFunc()); + spatial::knn::detail::EuclideanSqFunc()); } /** @@ -392,7 +392,7 @@ void eps_nn(raft::resources const& handle, adj_ia.data_handle(), adj_ja.data_handle(), vd.data_handle(), - spatial::knn::detail::EuclideanFunc()); + spatial::knn::detail::EuclideanSqFunc()); } /** diff --git a/cpp/include/raft/neighbors/ball_cover_types.hpp b/cpp/include/raft/neighbors/ball_cover_types.hpp index a627a1d234..e3ea5f0005 100644 --- a/cpp/include/raft/neighbors/ball_cover_types.hpp +++ b/cpp/include/raft/neighbors/ball_cover_types.hpp @@ -69,6 +69,7 @@ class BallCoverIndex { R_1nn_dists(raft::make_device_vector(handle, m_)), R_closest_landmark_dists(raft::make_device_vector(handle, m_)), R(raft::make_device_matrix(handle, sqrt(m_), n_)), + X_reordered(raft::make_device_matrix(handle, m_, n_)), R_radius(raft::make_device_vector(handle, sqrt(m_))), index_trained(false) { @@ -93,6 +94,8 @@ class BallCoverIndex { R_1nn_dists(raft::make_device_vector(handle, X_.extent(0))), R_closest_landmark_dists(raft::make_device_vector(handle, X_.extent(0))), R(raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1))), + X_reordered( + raft::make_device_matrix(handle, X_.extent(0), X_.extent(1))), R_radius(raft::make_device_vector(handle, sqrt(X_.extent(0)))), index_trained(false) { @@ -122,6 +125,10 @@ class BallCoverIndex { { return R_closest_landmark_dists.view(); } + auto get_X_reordered() const -> raft::device_matrix_view + { + return X_reordered.view(); + } raft::device_vector_view get_R_indptr() { return R_indptr.view(); } raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } @@ -132,6 +139,10 @@ class BallCoverIndex { { return R_closest_landmark_dists.view(); } + raft::device_matrix_view get_X_reordered() + { + return X_reordered.view(); + } raft::device_matrix_view get_X() const { return X; } raft::distance::DistanceType get_metric() const { return metric; } @@ -162,6 +173,7 @@ class BallCoverIndex { raft::device_vector R_radius; raft::device_matrix R; + raft::device_matrix X_reordered; protected: bool index_trained; diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 5c91b1ad6d..c4ca2ffa61 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -165,6 +165,10 @@ void construct_landmark_1nn(raft::resources const& handle, index.get_R_indptr().data_handle(), index.n_landmarks + 1, resource::get_cuda_stream(handle)); + + // reorder X to allow aligned access + raft::matrix::copy_rows( + handle, index.get_X(), index.get_X_reordered(), index.get_R_1nn_cols()); } /** @@ -337,12 +341,6 @@ void perform_rbc_query(raft::resources const& handle, /** * Perform eps-select * - * a. Map 1 row to each warp/block - * b. Add closest k R points to heap - * c. Iterate through batches of R, having each thread in the warp load a set - * of distances y from R (only if d(q, r) < 3 * distance to closest r) and - * marking the distance to be computed between x, y only - * if knn[k].distance >= d(x_i, R_k) + d(R_k, y) */ template ( - handle, index, query, n_query_pts, eps, landmark_dists, dfunc, adj, vd); + handle, index, query, n_query_pts, eps, landmarks, dfunc, adj, vd); resource::sync_stream(handle); } @@ -384,14 +382,14 @@ void perform_rbc_eps_nn_query( value_int n_query_pts, value_t eps, value_int* max_k, - const value_t* landmark_dists, + const value_t* landmarks, dist_func dfunc, value_idx* adj_ia, value_idx* adj_ja, value_idx* vd) { rbc_eps_pass( - handle, index, query, n_query_pts, eps, max_k, landmark_dists, dfunc, adj_ia, adj_ja, vd); + handle, index, query, n_query_pts, eps, max_k, landmarks, dfunc, adj_ia, adj_ja, vd); resource::sync_stream(handle); } @@ -664,15 +662,9 @@ void rbc_eps_nn_query(raft::resources const& handle, { ASSERT(index.is_index_trained(), "index must be previously trained"); - auto R_dists = - raft::make_device_matrix(handle, index.n_landmarks, n_query_pts); - - // find all landmarks that might have points in range - compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); - // query all points and write to adj perform_rbc_eps_nn_query( - handle, index, query, n_query_pts, eps, R_dists.data_handle(), dfunc, adj, vd); + handle, index, query, n_query_pts, eps, index.get_R().data_handle(), dfunc, adj, vd); } template (handle, index.n_landmarks, n_query_pts); - - // find all landmarks that might have points in range - compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); - // query all points and write to adj perform_rbc_eps_nn_query(handle, index, @@ -706,7 +692,7 @@ void rbc_eps_nn_query(raft::resources const& handle, n_query_pts, eps, max_k, - R_dists.data_handle(), + index.get_R().data_handle(), dfunc, adj_ia, adj_ja, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh index 81500a0eae..70df6d0165 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh @@ -190,7 +190,7 @@ instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc); instantiate_raft_spatial_knn_detail_rbc_eps_pass( - std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index e4acd1dc4c..eda6d33293 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -157,7 +157,7 @@ template -RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, +RAFT_KERNEL compute_final_dists_registers(const value_t* X_reordered, const value_t* X, const value_int n_cols, bitset_type* bitset, @@ -238,7 +238,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -267,7 +267,7 @@ RAFT_KERNEL compute_final_dists_registers(const value_t* X_index, // the closest k neighbors, compute it and add to k-select value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -313,7 +313,7 @@ template -RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_reordered, const value_t* X, value_int n_cols, // n_cols should be 2 or 3 dims const value_idx* R_knn_inds, @@ -408,7 +408,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -433,7 +433,7 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, value_t dist = std::numeric_limits::max(); if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); value_t local_y_ptr[col_q]; for (value_int j = 0; j < n_cols; ++j) { local_y_ptr[j] = y_ptr[j]; @@ -456,15 +456,22 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, } } +template +__device__ value_t squared(const value_t& a) +{ + return a * a; +} + template -RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -476,70 +483,115 @@ RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_index, bool* adj, value_idx* vd) { - __shared__ int column_count_smem; - - // initialize - if (vd != nullptr) { - if (threadIdx.x == 0) { column_count_smem = 0; } - __syncthreads(); - } - - const value_t* x_ptr = X + (n_cols * blockIdx.x); - - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; - - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; - - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; - - value_idx R_size = R_stop_offset - R_start_offset; - - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - adj[blockIdx.x * m + cur_candidate_ind] = true; - if (vd != nullptr) atomicAdd(&column_count_smem, 1); + constexpr int num_warps = tpb / WarpSize; + + // process 1 query per warp + const uint32_t lid = raft::laneId(); + + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); + + // this is an early out for a full warp + if (query_id >= n_queries) return; + + value_idx column_count = 0; + + const value_t* x_ptr = X + (n_cols * query_id); + adj += query_id * m; + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i]; + column_count++; + adj[index] = true; + } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } - - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - adj[blockIdx.x * m + cur_candidate_ind] = true; - if (vd != nullptr) atomicAdd(&column_count_smem, 1); + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if (in_range) { + auto index = R_1nn_cols[R_start_offset + i0 + lid]; + column_count++; + adj[index] = true; + } + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); } - } + } while (lane_mask); } if (vd != nullptr) { - __syncthreads(); - if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } + value_idx row_sum = raft::warpReduce(column_count); + if (lid == 0) vd[query_id] = row_sum; } } template -RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -551,58 +603,267 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index, value_idx* adj_ia, value_idx* adj_ja) { - const value_t* x_ptr = X + (n_cols * blockIdx.x); + constexpr int num_warps = tpb / WarpSize; - __shared__ unsigned long long int column_index_smem; + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; - bool pass2 = adj_ja != nullptr; + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); - // initialize - if (threadIdx.x == 0) { column_index_smem = pass2 ? adj_ia[blockIdx.x] : 0; } + // this is an early out for a full warp + if (query_id >= n_queries) return; - __syncthreads(); + uint32_t column_index_offset = 0; - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + if constexpr (write_pass) { + value_idx offset = adj_ia[query_id]; + // we have no neighbors to fill for this query + if (offset == adj_ia[query_id + 1]) return; + adj_ja += offset; + } - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; + const value_t* x_ptr = X + (n_cols * query_id); + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); + } - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); + } + } while (lane_mask); + } - value_idx R_size = R_stop_offset - R_start_offset; + if constexpr (!write_pass) { + value_idx row_sum = raft::warpReduce(column_index_offset); + if (lid == 0) adj_ia[query_id] = row_sum; + } +} - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; +template +RAFT_KERNEL __launch_bounds__(tpb) + block_rbc_kernel_eps_csr_pass_xd(const value_t* __restrict__ X_reordered, + const value_t* __restrict__ X, + const value_int n_queries, + const value_int n_cols, + const value_t* __restrict__ R, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* __restrict__ R_indptr, + const value_idx* __restrict__ R_1nn_cols, + const value_t* __restrict__ R_1nn_dists, + const value_t* __restrict__ R_radius, + distance_func dfunc, + value_idx* __restrict__ adj_ia, + value_idx* adj_ja) +{ + constexpr int num_warps = tpb / WarpSize; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - auto row_pos = atomicAdd(&column_index_smem, 1); - if (pass2) adj_ja[row_pos] = cur_candidate_ind; - } - } + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); + + // this is an early out for a full warp + if (query_id >= n_queries) return; + + uint32_t column_index_offset = 0; + + if constexpr (write_pass) { + value_idx offset = adj_ia[query_id]; + // we have no neighbors to fill for this query + if (offset == adj_ia[query_id + 1]) return; + adj_ja += offset; + } + + const value_t* x_ptr = X + (dim * query_id); + value_t local_x_ptr[dim]; +#pragma unroll + for (uint32_t i = 0; i < dim; ++i) { + local_x_ptr[i] = x_ptr[i]; + } + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(local_x_ptr, R + lane_k * dim, dim) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (dim * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(local_x_ptr, y_ptr, dim) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); + } - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - auto row_pos = atomicAdd(&column_index_smem, 1); - if (pass2) adj_ja[row_pos] = cur_candidate_ind; + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * dim; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(local_x_ptr, y_ptr, dim); + const bool in_range = (dist <= eps2); + if constexpr (write_pass) { + const int mask = raft::ballot(in_range); + if (in_range) { + const uint32_t index = R_1nn_cols[R_start_offset + i0 + lid]; + const uint32_t row_pos = __popc(mask & lid_mask); + adj_ja[row_pos] = index; + } + adj_ja += __popc(mask); + } else { + column_index_offset += (in_range); + } + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); } - } + } while (lane_mask); } - __syncthreads(); - if (threadIdx.x == 0 && !pass2) { adj_ia[blockIdx.x] = (value_idx)column_index_smem; } + if constexpr (!write_pass) { + value_idx row_sum = raft::warpReduce(column_index_offset); + if (lid == 0) adj_ia[query_id] = row_sum; + } } template -RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_index, +RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_reordered, const value_t* X, + const value_int n_queries, const value_int n_cols, - const value_t* R_dists, + const value_t* R, const value_int m, const value_t eps, const value_int n_landmarks, @@ -626,59 +888,110 @@ RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_index, const value_int max_k, value_idx* tmp) { - const value_t* x_ptr = X + (n_cols * blockIdx.x); - - __shared__ int column_count_smem; - - // initialize - if (threadIdx.x == 0) { column_count_smem = 0; } - - __syncthreads(); - - // we store all column indices in dense tmp store [blockDim.x * max_k] - value_int offset = blockIdx.x * max_k; - - for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { - // TODO: this might also be worth computing in-place here - value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; - - // prune all R's that can't be within eps - if (cur_R_dist - R_radius[cur_k] > eps) continue; - - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_k]; - value_idx R_stop_offset = R_indptr[cur_k + 1]; - - value_idx R_size = R_stop_offset - R_start_offset; - - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - int row_pos = atomicAdd(&column_count_smem, 1); - if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; + constexpr int num_warps = tpb / WarpSize; + + // process 1 query per warp + const uint32_t lid = raft::laneId(); + const uint32_t lid_mask = (1 << lid) - 1; + + // this should help the compiler to prevent branches + const int query_id = raft::shfl(blockIdx.x * num_warps + (threadIdx.x / WarpSize), 0); + + // this is an early out for a full warp + if (query_id >= n_queries) return; + + value_idx column_count = 0; + + const value_t* x_ptr = X + (n_cols * query_id); + tmp += query_id * max_k; + + // we omit the sqrt() in the inner distance compute + const value_t eps2 = eps * eps; + +#pragma nounroll + for (uint32_t cur_k0 = 0; cur_k0 < n_landmarks; cur_k0 += WarpSize) { + // Pre-compute landmark_dist & triangularization checks for 32 iterations + const uint32_t lane_k = cur_k0 + lid; + const value_t lane_R_dist_sq = lane_k < n_landmarks ? dfunc(x_ptr, R + lane_k * n_cols, n_cols) + : std::numeric_limits::max(); + const int lane_check = lane_k < n_landmarks + ? static_cast(lane_R_dist_sq <= squared(eps + R_radius[lane_k])) + : 0; + + int lane_mask = raft::ballot(lane_check); + if (lane_mask == 0) continue; + + // reverse to use __clz instead of __ffs + lane_mask = __brev(lane_mask); + do { + // look for next k_offset + const uint32_t k_offset = __clz(lane_mask); + + const uint32_t cur_k = cur_k0 + k_offset; + + // The whole warp should iterate through the elements in the current R + const value_idx R_start_offset = R_indptr[cur_k]; + + // update lane_mask for next iteration - erase bits up to k_offset + lane_mask &= (0x7fffffff >> k_offset); + + const uint32_t R_size = R_indptr[cur_k + 1] - R_start_offset; + + // we have precomputed the query<->landmark distance + const value_t cur_R_dist = raft::sqrt(raft::shfl(lane_R_dist_sq, k_offset)); + + const uint32_t limit = Pow2::roundDown(R_size); + uint32_t i = limit + lid; + + // R_1nn_dists are sorted ascendingly for each landmark + // Iterating backwards, after pruning the first point w.r.t. triangle + // inequality all subsequent points can be pruned as well + const value_t* y_ptr = X_reordered + (n_cols * (R_start_offset + i)); + { + const value_t min_warp_dist = + limit < R_size ? R_1nn_dists[R_start_offset + limit] : cur_R_dist; + const value_t dist = + (i < R_size) ? dfunc(x_ptr, y_ptr, n_cols) : std::numeric_limits::max(); + const bool in_range = (dist <= eps2); + const int mask = raft::ballot(in_range); + if (in_range) { + auto row_pos = column_count + __popc(mask & lid_mask); + // we still continue to look for more hits to return valid vd + if (row_pos < max_k) { + auto index = R_1nn_cols[R_start_offset + i]; + tmp[row_pos] = index; + } + } + column_count += __popc(mask); + // abort in case subsequent points cannot possibly be in reach + i *= (cur_R_dist - min_warp_dist <= eps); } - } - - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { - int row_pos = atomicAdd(&column_count_smem, 1); - if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; + uint32_t i0 = raft::shfl(i, 0); + + while (i0 >= WarpSize) { + y_ptr -= WarpSize * n_cols; + i0 -= WarpSize; + const value_t min_warp_dist = R_1nn_dists[R_start_offset + i0]; + const value_t dist = dfunc(x_ptr, y_ptr, n_cols); + const bool in_range = (dist <= eps2); + const int mask = raft::ballot(in_range); + if (in_range) { + auto row_pos = column_count + __popc(mask & lid_mask); + // we still continue to look for more hits to return valid vd + if (row_pos < max_k) { + auto index = R_1nn_cols[R_start_offset + i0 + lid]; + tmp[row_pos] = index; + } + } + column_count += __popc(mask); + // abort in case subsequent points cannot possibly be in reach + i0 *= (cur_R_dist - min_warp_dist <= eps); } - } + } while (lane_mask); } - __syncthreads(); - if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } + if (lid == 0) vd[query_id] = column_count; } template @@ -723,7 +1036,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, if (k <= 32) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -743,7 +1056,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 64) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -762,7 +1075,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 128) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -782,7 +1095,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 256) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -802,7 +1115,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 512) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -822,7 +1135,7 @@ void rbc_low_dim_pass_one(raft::resources const& handle, else if (k <= 1024) block_rbc_kernel_registers <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, R_knn_inds, @@ -892,7 +1205,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -918,7 +1231,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -944,7 +1257,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -970,7 +1283,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 128, dims> <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -995,7 +1308,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1020,7 +1333,7 @@ void rbc_low_dim_pass_two(raft::resources const& handle, 8, 64, dims><<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, index.n, bitset.data(), @@ -1047,17 +1360,18 @@ void rbc_eps_pass(raft::resources const& handle, const value_t* query, const value_int n_query_rows, value_t eps, - const value_t* R_dists, + const value_t* R, dist_func& dfunc, bool* adj, value_idx* vd) { block_rbc_kernel_eps_dense <<>>( - index.get_X().data_handle(), + index.get_X_reordered().data_handle(), query, + n_query_rows, index.n, - R_dists, + R, index.m, eps, index.n_landmarks, @@ -1093,7 +1407,7 @@ void rbc_eps_pass(raft::resources const& handle, const value_int n_query_rows, value_t eps, value_int* max_k, - const value_t* R_dists, + const value_t* R, dist_func& dfunc, value_idx* adj_ia, value_idx* adj_ja, @@ -1104,22 +1418,61 @@ void rbc_eps_pass(raft::resources const& handle, if (adj_ja == nullptr) { // pass 1 -> only compute adj_ia / vd value_idx* vd_ptr = (vd != nullptr) ? vd : adj_ia; - block_rbc_kernel_eps_csr_pass - <<>>( - index.get_X().data_handle(), - query, - index.n, - R_dists, - index.m, - eps, - index.n_landmarks, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - index.get_R_radius().data_handle(), - dfunc, - vd_ptr, - nullptr); + if (index.n == 2) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } else if (index.n == 3) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } else { + block_rbc_kernel_eps_csr_pass + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + } thrust::exclusive_scan(resource::get_thrust_policy(handle), vd_ptr, @@ -1129,22 +1482,61 @@ void rbc_eps_pass(raft::resources const& handle, } else { // pass 2 -> fill in adj_ja - block_rbc_kernel_eps_csr_pass - <<>>( - index.get_X().data_handle(), - query, - index.n, - R_dists, - index.m, - eps, - index.n_landmarks, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - index.get_R_radius().data_handle(), - dfunc, - adj_ia, - adj_ja); + if (index.n == 2) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } else if (index.n == 3) { + block_rbc_kernel_eps_csr_pass_xd + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } else { + block_rbc_kernel_eps_csr_pass + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), + query, + n_query_rows, + index.n, + R, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } } } else { value_int max_k_in = *max_k; @@ -1153,11 +1545,12 @@ void rbc_eps_pass(raft::resources const& handle, rmm::device_uvector tmp(n_query_rows * max_k_in, resource::get_cuda_stream(handle)); block_rbc_kernel_eps_max_k - <<>>( - index.get_X().data_handle(), + <<(n_query_rows, 2), 64, 0, resource::get_cuda_stream(handle)>>>( + index.get_X_reordered().data_handle(), query, + n_query_rows, index.n, - R_dists, + R, index.m, eps, index.n_landmarks, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh index f38a3eeec9..eacd4aeff3 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh @@ -61,6 +61,21 @@ struct EuclideanFunc : public DistFunc { } }; +template +struct EuclideanSqFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + value_t sum_sq = 0; + for (value_int i = 0; i < n_dims; ++i) { + value_t diff = a[i] - b[i]; + sum_sq += diff * diff; + } + return sum_sq; + } +}; + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py index dff2e015a4..10d9c95ece 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py @@ -121,6 +121,8 @@ dist="raft::spatial::knn::detail::DistFunc", ) +euclideanSq="raft::spatial::knn::detail::EuclideanSqFunc", + types = dict( int64_float=("std::int64_t", "float"), #int64_double=("std::int64_t", "double"), @@ -156,7 +158,7 @@ f.write(macro_pass_eps) for type_path, (int_t, data_t) in types.items(): f.write(f"instantiate_raft_spatial_knn_detail_rbc_eps_pass(\n") - f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {distances['euclidean']});\n") + f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {euclideanSq});\n") f.write("#undef instantiate_raft_spatial_knn_detail_rbc_eps_pass\n") print(f"src/spatial/knn/detail/ball_cover/{path}") diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu index bbc1b55eb2..2a88862b2c 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu @@ -56,5 +56,5 @@ Mvalue_idx* vd) instantiate_raft_spatial_knn_detail_rbc_eps_pass( - std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanSqFunc); #undef instantiate_raft_spatial_knn_detail_rbc_eps_pass From 3b887cba346a009d5adcff562cf21076fd601b9c Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 12 Mar 2024 15:33:54 -0400 Subject: [PATCH 2/3] Remove hard-coding of RAPIDS version where possible (#2219) * Read `VERSION` file from CMake * Read `pylibraft.__version__` from docs build * Read `VERSION` file from shell scripts * Use environment variable substitution in Doxygen * Remove updates from `ci/release/update-version.sh` Issue: https://github.com/rapidsai/build-planning/issues/15 Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - Jake Awe (https://github.com/AyodeAwe) - Bradley Dice (https://github.com/bdice) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/2219 --- build.sh | 4 +++- ci/build_docs.sh | 6 ++++-- ci/release/update-version.sh | 17 +---------------- cpp/CMakeLists.txt | 7 ++----- cpp/doxygen/Doxyfile | 2 +- docs/source/conf.py | 10 +++++++--- fetch_rapids.cmake | 20 ------------------- python/pylibraft/CMakeLists.txt | 10 ++++------ python/raft-dask/CMakeLists.txt | 10 ++++------ rapids_config.cmake | 34 +++++++++++++++++++++++++++++++++ 10 files changed, 60 insertions(+), 60 deletions(-) delete mode 100644 fetch_rapids.cmake create mode 100644 rapids_config.cmake diff --git a/build.sh b/build.sh index e5df0af826..148d23c9c1 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # raft build scripts @@ -512,6 +512,8 @@ fi if hasArg docs; then set -x + export RAPIDS_VERSION="$(sed -E -e 's/^([0-9]{2})\.([0-9]{2})\.([0-9]{2}).*$/\1.\2.\3/' "${REPODIR}/VERSION")" + export RAPIDS_VERSION_MAJOR_MINOR="$(sed -E -e 's/^([0-9]{2})\.([0-9]{2})\.([0-9]{2}).*$/\1.\2/' "${REPODIR}/VERSION")" cd ${DOXYGEN_BUILD_DIR} doxygen Doxyfile cd ${SPHINX_BUILD_DIR} diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 4c07683642..3d72c815db 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. set -euo pipefail @@ -28,7 +28,9 @@ rapids-mamba-retry install \ pylibraft \ raft-dask -export RAPIDS_VERSION_NUMBER="24.04" +export RAPIDS_VERSION="$(rapids-version)" +export RAPIDS_VERSION_MAJOR_MINOR="$(rapids-version-major-minor)" +export RAPIDS_VERSION_NUMBER="$RAPIDS_VERSION_MAJOR_MINOR" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-logger "Build CPP docs" diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index d268c16e0a..636f637d0c 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. ######################## # RAFT Version Updater # ######################## @@ -36,23 +36,11 @@ function sed_runner() { sed -i.bak ''"$1"'' $2 && rm -f ${2}.bak } -sed_runner "s/set(RAPIDS_VERSION .*)/set(RAPIDS_VERSION \"${NEXT_SHORT_TAG}\")/g" cpp/CMakeLists.txt sed_runner "s/set(RAPIDS_VERSION .*)/set(RAPIDS_VERSION \"${NEXT_SHORT_TAG}\")/g" cpp/template/cmake/thirdparty/fetch_rapids.cmake -sed_runner "s/set(RAFT_VERSION .*)/set(RAFT_VERSION \"${NEXT_FULL_TAG}\")/g" cpp/CMakeLists.txt -sed_runner 's/'"pylibraft_version .*)"'/'"pylibraft_version ${NEXT_FULL_TAG})"'/g' python/pylibraft/CMakeLists.txt -sed_runner 's/'"raft_dask_version .*)"'/'"raft_dask_version ${NEXT_FULL_TAG})"'/g' python/raft-dask/CMakeLists.txt -sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cmake"'/g' fetch_rapids.cmake # Centralized version file update echo "${NEXT_FULL_TAG}" > VERSION -# Wheel testing script -sed_runner "s/branch-.*/branch-${NEXT_SHORT_TAG}/g" ci/test_wheel_raft_dask.sh - -# Docs update -sed_runner 's/version = .*/version = '"'${NEXT_SHORT_TAG}'"'/g' docs/source/conf.py -sed_runner 's/release = .*/release = '"'${NEXT_FULL_TAG}'"'/g' docs/source/conf.py - DEPENDENCIES=( dask-cuda pylibraft @@ -84,9 +72,6 @@ sed_runner "/^ucx_py_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/ for FILE in .github/workflows/*.yaml; do sed_runner "/shared-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" done -sed_runner "s/RAPIDS_VERSION_NUMBER=\".*/RAPIDS_VERSION_NUMBER=\"${NEXT_SHORT_TAG}\"/g" ci/build_docs.sh - -sed_runner "/^PROJECT_NUMBER/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" cpp/doxygen/Doxyfile sed_runner "/^set(RAFT_VERSION/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" docs/source/build.md sed_runner "s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" docs/source/build.md diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 650bc1a059..638ceb3b45 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -10,11 +10,8 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. -set(RAPIDS_VERSION "24.04") -set(RAFT_VERSION "24.04.00") - cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -include(../fetch_rapids.cmake) +include(../rapids_config.cmake) include(rapids-cmake) include(rapids-cpm) include(rapids-export) @@ -34,7 +31,7 @@ endif() project( RAFT - VERSION ${RAFT_VERSION} + VERSION "${RAPIDS_VERSION}" LANGUAGES ${lang_list} ) diff --git a/cpp/doxygen/Doxyfile b/cpp/doxygen/Doxyfile index 779472d880..67566ac1f9 100644 --- a/cpp/doxygen/Doxyfile +++ b/cpp/doxygen/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "RAFT C++ API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = "24.04" +PROJECT_NUMBER = "$(RAPIDS_VERSION_MAJOR_MINOR)" # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/docs/source/conf.py b/docs/source/conf.py index 07dd4825fa..3dc9909b45 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,7 +1,10 @@ -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. import os import sys +from packaging.version import Version + +import pylibraft # If extensions (or modules to document with autodoc) are in another # directory, add these directories to sys.path here. If the directory @@ -66,10 +69,11 @@ # |version| and |release|, also used in various other places throughout the # built documents. # +RAFT_VERSION = Version(pylibraft.__version__) # The short X.Y version. -version = '24.04' +version = f"{RAFT_VERSION.major:02}.{RAFT_VERSION.minor:02}" # The full version, including alpha/beta/rc tags. -release = '24.04.00' +release = f"{RAFT_VERSION.major:02}.{RAFT_VERSION.minor:02}.{RAFT_VERSION.micro:02}" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake deleted file mode 100644 index 1dca136c97..0000000000 --- a/fetch_rapids.cmake +++ /dev/null @@ -1,20 +0,0 @@ -# ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. -# ============================================================================= -if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.04/RAPIDS.cmake - ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake - ) -endif() - -include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index c17243728e..7a2d77041d 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -14,9 +14,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -include(../../fetch_rapids.cmake) - -set(pylibraft_version 24.04.00) +include(../../rapids_config.cmake) # We always need CUDA for pylibraft because the raft dependency brings in a header-only cuco # dependency that enables CUDA unconditionally. @@ -25,7 +23,7 @@ rapids_cuda_init_architectures(pylibraft) project( pylibraft - VERSION ${pylibraft_version} + VERSION "${RAPIDS_VERSION}" LANGUAGES CXX CUDA ) @@ -35,7 +33,7 @@ option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulti # If the user requested it we attempt to find RAFT. if(FIND_RAFT_CPP) - find_package(raft ${pylibraft_version} REQUIRED COMPONENTS compiled) + find_package(raft "${RAPIDS_VERSION}" REQUIRED COMPONENTS compiled) if(NOT TARGET raft::raft_lib) message( FATAL_ERROR diff --git a/python/raft-dask/CMakeLists.txt b/python/raft-dask/CMakeLists.txt index ff441e343e..58e5ae8104 100644 --- a/python/raft-dask/CMakeLists.txt +++ b/python/raft-dask/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -14,15 +14,13 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -set(raft_dask_version 24.04.00) - -include(../../fetch_rapids.cmake) +include(../../rapids_config.cmake) include(rapids-cuda) rapids_cuda_init_architectures(raft-dask-python) project( raft-dask-python - VERSION ${raft_dask_version} + VERSION "${RAPIDS_VERSION}" LANGUAGES CXX CUDA ) @@ -32,7 +30,7 @@ option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulti # If the user requested it we attempt to find RAFT. if(FIND_RAFT_CPP) - find_package(raft ${raft_dask_version} REQUIRED COMPONENTS distributed) + find_package(raft "${RAPIDS_VERSION}" REQUIRED COMPONENTS distributed) else() set(raft_FOUND OFF) endif() diff --git a/rapids_config.cmake b/rapids_config.cmake new file mode 100644 index 0000000000..c8077f7f4b --- /dev/null +++ b/rapids_config.cmake @@ -0,0 +1,34 @@ +# ============================================================================= +# Copyright (c) 2018-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= +file(READ "${CMAKE_CURRENT_LIST_DIR}/VERSION" _rapids_version) +if(_rapids_version MATCHES [[^([0-9][0-9])\.([0-9][0-9])\.([0-9][0-9])]]) + set(RAPIDS_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(RAPIDS_VERSION_MINOR "${CMAKE_MATCH_2}") + set(RAPIDS_VERSION_PATCH "${CMAKE_MATCH_3}") + set(RAPIDS_VERSION_MAJOR_MINOR "${RAPIDS_VERSION_MAJOR}.${RAPIDS_VERSION_MINOR}") + set(RAPIDS_VERSION "${RAPIDS_VERSION_MAJOR}.${RAPIDS_VERSION_MINOR}.${RAPIDS_VERSION_PATCH}") +else() + string(REPLACE "\n" "\n " _rapids_version_formatted " ${_rapids_version}") + message( + FATAL_ERROR + "Could not determine RAPIDS version. Contents of VERSION file:\n${_rapids_version_formatted}") +endif() + +if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake") + file( + DOWNLOAD + "https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION_MAJOR_MINOR}/RAPIDS.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake") +endif() +include("${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake") From 3e67a6c3bc5c73cd4a77a71bb5c2d911f8a7bcd4 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 13 Mar 2024 08:40:43 -0500 Subject: [PATCH 3/3] Add upper bound to prevent usage of NumPy 2 (#2222) NumPy 2 is expected to be released in the near future. For the RAPIDS 24.04 release, we will pin to `numpy>=1.23,<2.0a0`. This PR adds an upper bound to affected RAPIDS repositories. xref: https://github.com/rapidsai/build-planning/issues/29 Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Ray Douglass (https://github.com/raydouglass) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2222 --- conda/environments/all_cuda-118_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-122_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-122_arch-x86_64.yaml | 2 +- conda/recipes/pylibraft/meta.yaml | 2 +- dependencies.yaml | 2 +- python/pylibraft/pyproject.toml | 2 +- python/raft-dask/pyproject.toml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index 40b031d677..634a6e7bcf 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - nvcc_linux-aarch64=11.8 - pre-commit diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 5485d09a37..da0aa74e16 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - nvcc_linux-64=11.8 - pre-commit diff --git a/conda/environments/all_cuda-122_arch-aarch64.yaml b/conda/environments/all_cuda-122_arch-aarch64.yaml index b688bf3952..f82f408759 100644 --- a/conda/environments/all_cuda-122_arch-aarch64.yaml +++ b/conda/environments/all_cuda-122_arch-aarch64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - pre-commit - pydata-sphinx-theme diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml index 013f852aee..06a6953ee5 100644 --- a/conda/environments/all_cuda-122_arch-x86_64.yaml +++ b/conda/environments/all_cuda-122_arch-x86_64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - pre-commit - pydata-sphinx-theme diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index 5c2829d297..e524a68f9e 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -65,7 +65,7 @@ requirements: {% endif %} - libraft {{ version }} - libraft-headers {{ version }} - - numpy >=1.23 + - numpy >=1.23,<2.0a0 - python x.x - rmm ={{ minor_version }} diff --git a/dependencies.yaml b/dependencies.yaml index 72aa3427d1..60f2306773 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -402,7 +402,7 @@ dependencies: common: - output_types: [conda, pyproject] packages: - - &numpy numpy>=1.23 + - &numpy numpy>=1.23,<2.0a0 - output_types: [conda] packages: - *rmm_conda diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index 6468220330..d687f70cf5 100644 --- a/python/pylibraft/pyproject.toml +++ b/python/pylibraft/pyproject.toml @@ -36,7 +36,7 @@ license = { text = "Apache 2.0" } requires-python = ">=3.9" dependencies = [ "cuda-python>=11.7.1,<12.0a0", - "numpy>=1.23", + "numpy>=1.23,<2.0a0", "rmm==24.4.*", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index b869290d5c..07e2463c5c 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "dask-cuda==24.4.*", "joblib>=0.11", "numba>=0.57", - "numpy>=1.23", + "numpy>=1.23,<2.0a0", "pylibraft==24.4.*", "rapids-dask-dependency==24.4.*", "ucx-py==0.37.*",