From 2f587b13b62d2dd8b15bf6d2902c65d67e7afc7b Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Thu, 22 Aug 2024 17:17:30 -0700 Subject: [PATCH 1/4] [FEA] Batching NN Descent (#2403) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements batching NN Descent. It will be helpful for reducing device memory usages for large datasets (specifically if the dataset is kept on host). `index_params` now has... - `n_clusters`: number of clusters to make. Larger clusters reduce device memory usage. Default is 1, in which case it doesn't do the batched NND. ### Notes - The batching approach may have duplicate indices in the knn graph (in rare cases) because sometimes distances calculated for the same pair may be slightly different. This results in putting the same index far apart after sorting by distances, making it difficult to get unique indices (which is done by looking at 2 indices before the current one). - handled by adding a `max_duplicates` for `check_unique_indices` in tests ### Benchmarks - Dataset for NND (no batch and batch) is on host - Dataset for brute force knn is on device (but still won't be able to run with large datasets even if the data is put on the host because it brings the entire dataset to device anyway) - The dataset is just a slice of the wiki-all dataset (88M, 768) to test for different sizes Screenshot 2024-08-02 at 8 35 58 AM Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2403 --- .../raft/neighbors/detail/nn_descent.cuh | 48 +- .../neighbors/detail/nn_descent_batch.cuh | 701 ++++++++++++++++++ cpp/include/raft/neighbors/nn_descent.cuh | 30 +- .../raft/neighbors/nn_descent_types.hpp | 23 + cpp/test/CMakeLists.txt | 36 +- cpp/test/neighbors/ann_nn_descent.cuh | 155 +++- .../test_batch_float_uint32_t.cu | 30 + cpp/test/neighbors/ann_utils.cuh | 18 +- 8 files changed, 981 insertions(+), 60 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/nn_descent_batch.cuh create mode 100644 cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 9c37ee146d..02610f9afb 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -344,7 +345,9 @@ struct GnndGraph { ~GnndGraph(); }; -template +template > class GNND { public: GNND(raft::resources const& res, const BuildConfig& build_config); @@ -356,9 +359,10 @@ class GNND { Index_t* output_graph, bool return_distances, DistData_t* output_distances, - epilogue_op distance_epilogue = raft::identity_op()); + epilogue_op distance_epilogue = DistEpilogue()); ~GNND() = default; using ID_t = InternalID_t; + void reset(raft::resources const& res); private: void add_reverse_edges(Index_t* graph_ptr, @@ -366,7 +370,8 @@ class GNND { Index_t* d_rev_graph_ptr, int2* list_sizes, cudaStream_t stream = 0); - void local_join(cudaStream_t stream = 0, epilogue_op distance_epilogue = raft::identity_op()); + void local_join(cudaStream_t stream = 0, + epilogue_op distance_epilogue = DistEpilogue()); raft::resources const& res; @@ -701,7 +706,7 @@ __device__ __forceinline__ void remove_duplicates( // is 1024 and 1536 respectively, which means the bounds don't work anymore template , - typename epilogue_op = raft::identity_op> + typename epilogue_op = DistEpilogue> RAFT_KERNEL #ifdef __CUDA_ARCH__ #if (__CUDA_ARCH__) == 750 || ((__CUDA_ARCH__) >= 860 && (__CUDA_ARCH__) <= 890) @@ -1183,18 +1188,23 @@ GNND::GNND(raft::resources const& res, d_list_sizes_old_{raft::make_device_vector(res, nrow_)} { static_assert(NUM_SAMPLES <= 32); - - thrust::fill(thrust::device, - dists_buffer_.data_handle(), - dists_buffer_.data_handle() + dists_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, - reinterpret_cast(graph_buffer_.data_handle()), - reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); + raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); + auto graph_buffer_view = raft::make_device_matrix_view( + reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); + raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); + raft::matrix::fill(res, d_locks_.view(), 0); }; +template +void GNND::reset(raft::resources const& res) +{ + raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); + auto graph_buffer_view = raft::make_device_matrix_view( + reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); + raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); + raft::matrix::fill(res, d_locks_.view(), 0); +} + template void GNND::add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev_graph_ptr, @@ -1246,6 +1256,7 @@ void GNND::build(Data_t* data, cudaStream_t stream = raft::resource::get_cuda_stream(res); nrow_ = nrow; + graph_.nrow = nrow; graph_.h_graph = (InternalID_t*)output_graph; cudaPointerAttributes data_ptr_attr; @@ -1384,6 +1395,7 @@ void GNND::build(Data_t* data, static_cast(build_config_.output_graph_degree)}; raft::matrix::slice( res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords); + raft::resource::sync_stream(res); } Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); @@ -1414,14 +1426,14 @@ void GNND::build(Data_t* data, template , typename Accessor = host_device_accessor, memory_type::host>> void build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset, index& idx, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, "The dataset size for GNND should be less than %d", @@ -1491,13 +1503,13 @@ void build(raft::resources const& res, template , typename Accessor = host_device_accessor, memory_type::host>> index build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { size_t intermediate_degree = params.intermediate_graph_degree; size_t graph_degree = params.graph_degree; diff --git a/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh b/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh new file mode 100644 index 0000000000..78467c9741 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh @@ -0,0 +1,701 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + +#include "../nn_descent_types.hpp" +#include "nn_descent.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent::detail { + +// +// Run balanced kmeans on a subsample of the dataset to get centroids +// +template , memory_type::host>> +void get_balanced_kmeans_centroids( + raft::resources const& res, + raft::distance::DistanceType metric, + mdspan, row_major, Accessor> dataset, + raft::device_matrix_view centroids) +{ + size_t num_rows = static_cast(dataset.extent(0)); + size_t num_cols = static_cast(dataset.extent(1)); + size_t n_clusters = centroids.extent(0); + size_t num_subsamples = + std::min(static_cast(num_rows / n_clusters), static_cast(num_rows * 0.1)); + + auto d_subsample_dataset = + raft::make_device_matrix(res, num_subsamples, num_cols); + raft::matrix::sample_rows( + res, raft::random::RngState{0}, dataset, d_subsample_dataset.view()); + + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.metric = metric; + + auto d_subsample_dataset_const_view = + raft::make_device_matrix_view( + d_subsample_dataset.data_handle(), num_subsamples, num_cols); + raft::cluster::kmeans_balanced::fit( + res, kmeans_params, d_subsample_dataset_const_view, centroids); +} + +// +// Get the top k closest centroid indices for each data point +// Loads the data in batches onto device if data is on host for memory efficiency +// +template +void get_global_nearest_k( + raft::resources const& res, + size_t k, + size_t num_rows, + size_t n_clusters, + const T* dataset, + raft::host_matrix_view global_nearest_cluster, + raft::device_matrix_view centroids, + raft::distance::DistanceType metric) +{ + size_t num_cols = centroids.extent(1); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, dataset)); + float* ptr = reinterpret_cast(attr.devicePointer); + + if (ptr == nullptr) { // data on host + size_t num_batches = n_clusters; + size_t batch_size = (num_rows + n_clusters) / n_clusters; + + auto d_dataset_batch = + raft::make_device_matrix(res, batch_size, num_cols); + + auto nearest_clusters_idx = + raft::make_device_matrix(res, batch_size, k); + auto nearest_clusters_dist = + raft::make_device_matrix(res, batch_size, k); + + for (size_t i = 0; i < num_batches; i++) { + size_t batch_size_ = batch_size; + + if (i == num_batches - 1) { batch_size_ = num_rows - batch_size * i; } + raft::copy(d_dataset_batch.data_handle(), + dataset + i * batch_size * num_cols, + batch_size_ * num_cols, + resource::get_cuda_stream(res)); + + raft::neighbors::brute_force::fused_l2_knn( + res, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(d_dataset_batch.view()), + nearest_clusters_idx.view(), + nearest_clusters_dist.view(), + metric); + raft::copy(global_nearest_cluster.data_handle() + i * batch_size * k, + nearest_clusters_idx.data_handle(), + batch_size_ * k, + resource::get_cuda_stream(res)); + } + } else { // data on device + auto nearest_clusters_idx = + raft::make_device_matrix(res, num_rows, k); + auto nearest_clusters_dist = + raft::make_device_matrix(res, num_rows, k); + + raft::neighbors::brute_force::fused_l2_knn( + res, + raft::make_const_mdspan(centroids), + raft::make_device_matrix_view(dataset, num_rows, num_cols), + nearest_clusters_idx.view(), + nearest_clusters_dist.view(), + metric); + + raft::copy(global_nearest_cluster.data_handle(), + nearest_clusters_idx.data_handle(), + num_rows * k, + resource::get_cuda_stream(res)); + } +} + +// +// global_nearest_cluster [num_rows X k=2] : top 2 closest clusters for each data point +// inverted_indices [num_rows x k vector] : sparse vector for data indices for each cluster +// cluster_size [n_cluster] : cluster size for each cluster +// offset [n_cluster] : offset in inverted_indices for each cluster +// Loads the data in batches onto device if data is on host for memory efficiency +// +template +void get_inverted_indices(raft::resources const& res, + size_t n_clusters, + size_t& max_cluster_size, + size_t& min_cluster_size, + raft::host_matrix_view global_nearest_cluster, + raft::host_vector_view inverted_indices, + raft::host_vector_view cluster_size, + raft::host_vector_view offset) +{ + // build sparse inverted indices and get number of data points for each cluster + size_t num_rows = global_nearest_cluster.extent(0); + size_t k = global_nearest_cluster.extent(1); + + auto local_offset = raft::make_host_vector(n_clusters); + + max_cluster_size = 0; + min_cluster_size = std::numeric_limits::max(); + + thrust::fill( + thrust::host, cluster_size.data_handle(), cluster_size.data_handle() + n_clusters, 0); + thrust::fill( + thrust::host, local_offset.data_handle(), local_offset.data_handle() + n_clusters, 0); + + // TODO: this part isn't really a bottleneck but maybe worth trying omp parallel + // for with atomic add + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < k; j++) { + IdxT cluster_id = global_nearest_cluster(i, j); + cluster_size(cluster_id) += 1; + } + } + + offset(0) = 0; + for (size_t i = 1; i < n_clusters; i++) { + offset(i) = offset(i - 1) + cluster_size(i - 1); + } + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < k; j++) { + IdxT cluster_id = global_nearest_cluster(i, j); + inverted_indices(offset(cluster_id) + local_offset(cluster_id)) = i; + local_offset(cluster_id) += 1; + } + } + + max_cluster_size = static_cast( + *std::max_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); + min_cluster_size = static_cast( + *std::min_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); +} + +template +struct KeyValuePair { + KeyType key; + ValueType value; +}; + +template +struct CustomKeyComparator { + __device__ bool operator()(const KeyValuePair& a, + const KeyValuePair& b) const + { + if (a.key == b.key) { return a.value < b.value; } + return a.key < b.key; + } +}; + +template +RAFT_KERNEL merge_subgraphs(IdxT* cluster_data_indices, + size_t graph_degree, + size_t num_cluster_in_batch, + float* global_distances, + float* batch_distances, + IdxT* global_indices, + IdxT* batch_indices) +{ + size_t batch_row = blockIdx.x; + typedef cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD> + BlockMergeSortType; + __shared__ typename cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD>:: + TempStorage tmpSmem; + + extern __shared__ char sharedMem[]; + float* blockKeys = reinterpret_cast(sharedMem); + IdxT* blockValues = reinterpret_cast(&sharedMem[graph_degree * 2 * sizeof(float)]); + int16_t* uniqueMask = + reinterpret_cast(&sharedMem[graph_degree * 2 * (sizeof(float) + sizeof(IdxT))]); + + if (batch_row < num_cluster_in_batch) { + // load batch or global depending on threadIdx + size_t global_row = cluster_data_indices[batch_row]; + + KeyValuePair threadKeyValuePair[ITEMS_PER_THREAD]; + + size_t halfway = BLOCK_SIZE / 2; + size_t do_global = threadIdx.x < halfway; + + float* distances; + IdxT* indices; + + if (do_global) { + distances = global_distances; + indices = global_indices; + } else { + distances = batch_distances; + indices = batch_indices; + } + + size_t idxBase = (threadIdx.x * do_global + (threadIdx.x - halfway) * (1lu - do_global)) * + static_cast(ITEMS_PER_THREAD); + size_t arrIdxBase = (global_row * do_global + batch_row * (1lu - do_global)) * graph_degree; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId < graph_degree) { + threadKeyValuePair[i].key = distances[arrIdxBase + colId]; + threadKeyValuePair[i].value = indices[arrIdxBase + colId]; + } else { + threadKeyValuePair[i].key = std::numeric_limits::max(); + threadKeyValuePair[i].value = std::numeric_limits::max(); + } + } + + __syncthreads(); + + BlockMergeSortType(tmpSmem).Sort(threadKeyValuePair, CustomKeyComparator{}); + + // load sorted result into shared memory to get unique values + idxBase = threadIdx.x * ITEMS_PER_THREAD; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId < 2 * graph_degree) { + blockKeys[colId] = threadKeyValuePair[i].key; + blockValues[colId] = threadKeyValuePair[i].value; + } + } + + __syncthreads(); + + // get unique mask + if (threadIdx.x == 0) { uniqueMask[0] = 1; } + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId > 0 && colId < 2 * graph_degree) { + uniqueMask[colId] = static_cast(blockValues[colId] != blockValues[colId - 1]); + } + } + + __syncthreads(); + + // prefix sum + if (threadIdx.x == 0) { + for (int i = 1; i < 2 * graph_degree; i++) { + uniqueMask[i] += uniqueMask[i - 1]; + } + } + + __syncthreads(); + // load unique values to global memory + if (threadIdx.x == 0) { + global_distances[global_row * graph_degree] = blockKeys[0]; + global_indices[global_row * graph_degree] = blockValues[0]; + } + + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId > 0 && colId < 2 * graph_degree) { + bool is_unique = uniqueMask[colId] != uniqueMask[colId - 1]; + int16_t global_colId = uniqueMask[colId] - 1; + if (is_unique && static_cast(global_colId) < graph_degree) { + global_distances[global_row * graph_degree + global_colId] = blockKeys[colId]; + global_indices[global_row * graph_degree + global_colId] = blockValues[colId]; + } + } + } + } +} + +// +// builds knn graph using NN Descent and merge with global graph +// +template , + typename Accessor = + host_device_accessor, memory_type::host>> +void build_and_merge(raft::resources const& res, + const index_params& params, + size_t num_data_in_cluster, + size_t graph_degree, + size_t int_graph_node_degree, + T* cluster_data, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_d, + float* global_distances_d, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + GNND& nnd, + epilogue_op distance_epilogue) +{ + nnd.build( + cluster_data, num_data_in_cluster, int_graph, true, batch_distances_d, distance_epilogue); + + // remap indices +#pragma omp parallel for + for (size_t i = 0; i < num_data_in_cluster; i++) { + for (size_t j = 0; j < graph_degree; j++) { + size_t local_idx = int_graph[i * int_graph_node_degree + j]; + batch_indices_h[i * graph_degree + j] = inverted_indices[local_idx]; + } + } + + raft::copy(batch_indices_d, + batch_indices_h, + num_data_in_cluster * graph_degree, + raft::resource::get_cuda_stream(res)); + + size_t num_elems = graph_degree * 2; + size_t sharedMemSize = num_elems * (sizeof(float) + sizeof(IdxT) + sizeof(int16_t)); + + if (num_elems <= 128) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 512) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 1024) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 2048) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else { + // this is as far as we can get due to the shared mem usage of cub::BlockMergeSort + RAFT_FAIL("The degree of knn is too large (%lu). It must be smaller than 1024", graph_degree); + } + raft::resource::sync_stream(res); +} + +// +// For each cluster, gather the data samples that belong to that cluster, and +// call build_and_merge +// +template > +void cluster_nnd(raft::resources const& res, + const index_params& params, + size_t graph_degree, + size_t extended_graph_degree, + size_t max_cluster_size, + raft::host_matrix_view dataset, + IdxT* offsets, + IdxT* cluster_size, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_h, + float* global_distances_h, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + const BuildConfig& build_config, + epilogue_op distance_epilogue) +{ + size_t num_rows = dataset.extent(0); + size_t num_cols = dataset.extent(1); + + GNND nnd(res, build_config); + + auto cluster_data_matrix = + raft::make_host_matrix(max_cluster_size, num_cols); + + for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { + RAFT_LOG_DEBUG( + "# Data on host. Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); + size_t num_data_in_cluster = cluster_size[cluster_id]; + size_t offset = offsets[cluster_id]; + +#pragma omp parallel for + for (size_t i = 0; i < num_data_in_cluster; i++) { + for (size_t j = 0; j < num_cols; j++) { + size_t global_row = (inverted_indices + offset)[i]; + cluster_data_matrix(i, j) = dataset(global_row, j); + } + } + + distance_epilogue.preprocess_for_batch(cluster_data_indices + offset, num_data_in_cluster); + + build_and_merge(res, + params, + num_data_in_cluster, + graph_degree, + extended_graph_degree, + cluster_data_matrix.data_handle(), + cluster_data_indices + offset, + int_graph, + inverted_indices + offset, + global_indices_h, + global_distances_h, + batch_indices_h, + batch_indices_d, + batch_distances_d, + nnd, + distance_epilogue); + nnd.reset(res); + } +} + +template > +void cluster_nnd(raft::resources const& res, + const index_params& params, + size_t graph_degree, + size_t extended_graph_degree, + size_t max_cluster_size, + raft::device_matrix_view dataset, + IdxT* offsets, + IdxT* cluster_size, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_h, + float* global_distances_h, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + const BuildConfig& build_config, + epilogue_op distance_epilogue) +{ + size_t num_rows = dataset.extent(0); + size_t num_cols = dataset.extent(1); + + GNND nnd(res, build_config); + + auto cluster_data_matrix = + raft::make_device_matrix(res, max_cluster_size, num_cols); + + for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { + RAFT_LOG_DEBUG( + "# Data on device. Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); + size_t num_data_in_cluster = cluster_size[cluster_id]; + size_t offset = offsets[cluster_id]; + + auto cluster_data_view = raft::make_device_matrix_view( + cluster_data_matrix.data_handle(), num_data_in_cluster, num_cols); + auto cluster_data_indices_view = raft::make_device_vector_view( + cluster_data_indices + offset, num_data_in_cluster); + distance_epilogue.preprocess_for_batch(cluster_data_indices + offset, num_data_in_cluster); + + auto dataset_IdxT = + raft::make_device_matrix_view(dataset.data_handle(), num_rows, num_cols); + raft::matrix::gather(res, dataset_IdxT, cluster_data_indices_view, cluster_data_view); + + build_and_merge(res, + params, + num_data_in_cluster, + graph_degree, + extended_graph_degree, + cluster_data_view.data_handle(), + cluster_data_indices + offset, + int_graph, + inverted_indices + offset, + global_indices_h, + global_distances_h, + batch_indices_h, + batch_indices_d, + batch_distances_d, + nnd, + distance_epilogue); + nnd.reset(res); + } +} + +template , + typename Accessor = + host_device_accessor, memory_type::host>> +index batch_build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset, + epilogue_op distance_epilogue = DistEpilogue()) +{ + size_t graph_degree = params.graph_degree; + size_t intermediate_degree = params.intermediate_graph_degree; + + size_t num_rows = static_cast(dataset.extent(0)); + size_t num_cols = static_cast(dataset.extent(1)); + + auto centroids = + raft::make_device_matrix(res, params.n_clusters, num_cols); + get_balanced_kmeans_centroids(res, params.metric, dataset, centroids.view()); + + size_t k = 2; + auto global_nearest_cluster = raft::make_host_matrix(num_rows, k); + get_global_nearest_k(res, + k, + num_rows, + params.n_clusters, + dataset.data_handle(), + global_nearest_cluster.view(), + centroids.view(), + params.metric); + + auto inverted_indices = raft::make_host_vector(num_rows * k); + auto cluster_size = raft::make_host_vector(params.n_clusters); + auto offset = raft::make_host_vector(params.n_clusters); + + size_t max_cluster_size, min_cluster_size; + get_inverted_indices(res, + params.n_clusters, + max_cluster_size, + min_cluster_size, + global_nearest_cluster.view(), + inverted_indices.view(), + cluster_size.view(), + offset.view()); + + if (intermediate_degree >= min_cluster_size) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than minimum cluster size, reducing it to %lu", + dataset.extent(0)); + intermediate_degree = min_cluster_size - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + size_t extended_graph_degree = + align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); + size_t extended_intermediate_degree = align32::roundUp( + static_cast(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3))); + + auto int_graph = raft::make_host_matrix( + max_cluster_size, static_cast(extended_graph_degree)); + + BuildConfig build_config{.max_dataset_size = max_cluster_size, + .dataset_dim = num_cols, + .node_degree = extended_graph_degree, + .internal_node_degree = extended_intermediate_degree, + .max_iterations = params.max_iterations, + .termination_threshold = params.termination_threshold, + .output_graph_degree = graph_degree}; + + auto global_indices_h = raft::make_managed_matrix(res, num_rows, graph_degree); + auto global_distances_h = raft::make_managed_matrix(res, num_rows, graph_degree); + + thrust::fill(thrust::host, + global_indices_h.data_handle(), + global_indices_h.data_handle() + num_rows * graph_degree, + std::numeric_limits::max()); + thrust::fill(thrust::host, + global_distances_h.data_handle(), + global_distances_h.data_handle() + num_rows * graph_degree, + std::numeric_limits::max()); + + auto batch_indices_h = + raft::make_host_matrix(max_cluster_size, graph_degree); + auto batch_indices_d = + raft::make_device_matrix(res, max_cluster_size, graph_degree); + auto batch_distances_d = + raft::make_device_matrix(res, max_cluster_size, graph_degree); + + auto cluster_data_indices = raft::make_device_vector(res, num_rows * k); + raft::copy(cluster_data_indices.data_handle(), + inverted_indices.data_handle(), + num_rows * k, + resource::get_cuda_stream(res)); + + cluster_nnd(res, + params, + graph_degree, + extended_graph_degree, + max_cluster_size, + dataset, + offset.data_handle(), + cluster_size.data_handle(), + cluster_data_indices.data_handle(), + int_graph.data_handle(), + inverted_indices.data_handle(), + global_indices_h.data_handle(), + global_distances_h.data_handle(), + batch_indices_h.data_handle(), + batch_indices_d.data_handle(), + batch_distances_d.data_handle(), + build_config, + distance_epilogue); + + index global_idx{ + res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; + + raft::copy(global_idx.graph().data_handle(), + global_indices_h.data_handle(), + num_rows * graph_degree, + raft::resource::get_cuda_stream(res)); + if (params.return_distances && global_idx.distances().has_value()) { + raft::copy(global_idx.distances().value().data_handle(), + global_distances_h.data_handle(), + num_rows * graph_degree, + raft::resource::get_cuda_stream(res)); + } + return global_idx; +} + +} // namespace raft::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index a46a2006d6..6c08546d3f 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -17,9 +17,11 @@ #pragma once #include "detail/nn_descent.cuh" +#include "detail/nn_descent_batch.cuh" #include #include +#include namespace raft::neighbors::experimental::nn_descent { @@ -57,13 +59,17 @@ namespace raft::neighbors::experimental::nn_descent { * @param[in] distance_epilogue epilogue operation for distances * @return index index containing all-neighbors knn graph in host memory */ -template +template > index build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { - return detail::build(res, params, dataset, distance_epilogue); + if (params.n_clusters > 1) { + return detail::batch_build(res, params, dataset, distance_epilogue); + } else { + return detail::build(res, params, dataset, distance_epilogue); + } } /** @@ -98,12 +104,12 @@ index build(raft::resources const& res, * in host memory * @param[in] distance_epilogue epilogue operation for distances */ -template +template > void build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset, index& idx, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { detail::build(res, params, dataset, idx, distance_epilogue); } @@ -137,13 +143,17 @@ void build(raft::resources const& res, * @param[in] distance_epilogue epilogue operation for distances * @return index index containing all-neighbors knn graph in host memory */ -template +template > index build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { - return detail::build(res, params, dataset, distance_epilogue); + if (params.n_clusters > 1) { + return detail::batch_build(res, params, dataset, distance_epilogue); + } else { + return detail::build(res, params, dataset, distance_epilogue); + } } /** @@ -178,12 +188,12 @@ index build(raft::resources const& res, * in host memory * @param[in] distance_epilogue epilogue operation for distances */ -template +template > void build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset, index& idx, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { detail::build(res, params, dataset, idx, distance_epilogue); } diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index 5d23ff2c2e..eb01a423be 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -48,6 +48,20 @@ namespace raft::neighbors::experimental::nn_descent { * `max_iterations`: The number of iterations that nn-descent will refine * the graph for. More iterations produce a better quality graph at cost of performance * `termination_threshold`: The delta at which nn-descent will terminate its iterations + * `return_distances`: boolean whether to return distances + * `n_clusters`: NN Descent offers batching a dataset to save GPU memory usage. + * Increase `n_clusters` to save GPU memory and run NN Descent with large datasets. + * Most effective when data is put on CPU memory. + * Setting this number too big may results in too much overhead of doing multiple + * iterations of graph building. Recommend starting at 4 and continue to increase + * depending on desired GPU memory usages. + * (Specifically, with n_clusters > 1, the NN Descent build algorithm will first + * find n_clusters number of cluster centroids of the dataset, then consider data + * points that belong to each cluster as a batch. + * Then we build knn subgraphs on each batch of the entire data. This is especially + * useful when the dataset is put on host, since only a subset of the data will + * be on GPU at once, enabling running NN Descent with large datasets that do not + * fit on the GPU as a whole.) * */ struct index_params : ann::index_params { @@ -56,6 +70,7 @@ struct index_params : ann::index_params { size_t max_iterations = 20; // Number of nn-descent iterations. float termination_threshold = 0.0001; // Termination threshold of nn-descent. bool return_distances = false; // return distances if true + size_t n_clusters = 1; // defaults to not using any batching }; /** @@ -178,6 +193,14 @@ struct index : ann::index { bool return_distances_; }; +template +struct DistEpilogue : raft::identity_op { + __host__ void preprocess_for_batch(value_idx* cluster_indices, size_t num_data_in_cluster) + { + return; + } +}; + /** @} */ } // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index e3af6ebb78..a497e6d3ba 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -96,17 +96,8 @@ endfunction() if(BUILD_TESTS) ConfigureTest( - NAME - CLUSTER_TEST - PATH - cluster/kmeans.cu - cluster/kmeans_balanced.cu - cluster/kmeans_find_k.cu - cluster/cluster_solvers.cu - cluster/linkage.cu - cluster/spectral.cu - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME CLUSTER_TEST PATH cluster/kmeans.cu cluster/kmeans_balanced.cu cluster/kmeans_find_k.cu + cluster/cluster_solvers.cu cluster/linkage.cu cluster/spectral.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -144,8 +135,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME CORE_TEST PATH core/stream_view.cpp core/mdspan_copy.cpp LIB - EXPLICIT_INSTANTIATE_ONLY NOCUDA + NAME CORE_TEST PATH core/stream_view.cpp core/mdspan_copy.cpp LIB EXPLICIT_INSTANTIATE_ONLY + NOCUDA ) ConfigureTest( @@ -301,8 +292,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SOLVERS_TEST PATH cluster/cluster_solvers_deprecated.cu linalg/eigen_solvers.cu - lap/lap.cu sparse/mst.cu LIB EXPLICIT_INSTANTIATE_ONLY + NAME SOLVERS_TEST PATH cluster/cluster_solvers_deprecated.cu linalg/eigen_solvers.cu lap/lap.cu + sparse/mst.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -331,19 +322,13 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SPARSE_DIST_TEST PATH sparse/dist_coo_spmv.cu sparse/distance.cu - sparse/gram.cu LIB EXPLICIT_INSTANTIATE_ONLY + NAME SPARSE_DIST_TEST PATH sparse/dist_coo_spmv.cu sparse/distance.cu sparse/gram.cu LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( - NAME - SPARSE_NEIGHBORS_TEST - PATH - sparse/neighbors/cross_component_nn.cu - sparse/neighbors/brute_force.cu - sparse/neighbors/knn_graph.cu - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME SPARSE_NEIGHBORS_TEST PATH sparse/neighbors/cross_component_nn.cu + sparse/neighbors/brute_force.cu sparse/neighbors/knn_graph.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -455,6 +440,7 @@ if(BUILD_TESTS) neighbors/ann_nn_descent/test_float_uint32_t.cu neighbors/ann_nn_descent/test_int8_t_uint32_t.cu neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu + neighbors/ann_nn_descent/test_batch_float_uint32_t.cu LIB EXPLICIT_INSTANTIATE_ONLY GPUS diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index f74cadb415..2f9d4e252b 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -42,6 +42,15 @@ struct AnnNNDescentInputs { double min_recall; }; +struct AnnNNDescentBatchInputs { + std::pair recall_cluster; + int n_rows; + int dim; + int graph_degree; + raft::distance::DistanceType metric; + bool host_dataset; +}; + inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& p) { os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree @@ -50,6 +59,14 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& return os; } +inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentBatchInputs& p) +{ + os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") + << ", clusters=" << p.recall_cluster.second << std::endl; + return os; +} + template class AnnNNDescentTest : public ::testing::TestWithParam { public: @@ -105,7 +122,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam { raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); auto database_host_view = raft::make_host_matrix_view( (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - auto index = nn_descent::build(handle_, index_params, database_host_view); + index index{handle_, ps.n_rows, static_cast(ps.graph_degree), true}; + nn_descent::build( + handle_, index_params, database_host_view, index, DistEpilogue()); raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); if (index.distances().has_value()) { @@ -116,7 +135,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam { } } else { - auto index = nn_descent::build(handle_, index_params, database_view); + index index{handle_, ps.n_rows, static_cast(ps.graph_degree), true}; + nn_descent::build( + handle_, index_params, database_view, index, DistEpilogue()); raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); if (index.distances().has_value()) { @@ -168,6 +189,127 @@ class AnnNNDescentTest : public ::testing::TestWithParam { rmm::device_uvector database; }; +template +class AnnNNDescentBatchTest : public ::testing::TestWithParam { + public: + AnnNNDescentBatchTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_) + { + } + + void testNNDescentBatch() + { + size_t queries_size = ps.n_rows * ps.graph_degree; + std::vector indices_NNDescent(queries_size); + std::vector distances_NNDescent(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + database.data(), + database.data(), + ps.n_rows, + ps.n_rows, + ps.dim, + ps.graph_degree, + ps.metric); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + { + nn_descent::index_params index_params; + index_params.metric = ps.metric; + index_params.graph_degree = ps.graph_degree; + index_params.intermediate_graph_degree = 2 * ps.graph_degree; + index_params.max_iterations = 10; + index_params.return_distances = true; + index_params.n_clusters = ps.recall_cluster.second; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + { + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + auto index = nn_descent::build( + handle_, index_params, database_host_view, DistEpilogue()); + raft::copy( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } + + } else { + auto index = nn_descent::build( + handle_, index_params, database_view, DistEpilogue()); + raft::copy( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } + }; + } + resource::sync_stream(handle_); + } + double min_recall = ps.recall_cluster.first; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_NNDescent, + distances_naive, + distances_NNDescent, + ps.n_rows, + ps.graph_degree, + 0.01, + min_recall, + true, + static_cast(ps.graph_degree * 0.1))); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnNNDescentBatchInputs ps; + rmm::device_uvector database; +}; + const std::vector inputs = raft::util::itertools::product( {1000, 2000}, // n_rows {3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim @@ -176,4 +318,13 @@ const std::vector inputs = raft::util::itertools::product inputsBatch = + raft::util::itertools::product( + {std::make_pair(0.9, 3lu), std::make_pair(0.9, 2lu)}, // min_recall, n_clusters + {4000, 5000}, // n_rows + {192, 512}, // dim + {32, 64}, // graph_degree + {raft::distance::DistanceType::L2Expanded}, + {false, true}); + } // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu new file mode 100644 index 0000000000..c6f56e8c39 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_nn_descent.cuh" + +#include + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentBatchTest AnnNNDescentBatchTestF_U32; +TEST_P(AnnNNDescentBatchTestF_U32, AnnNNDescentBatch) { this->testNNDescentBatch(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentBatchTest, + AnnNNDescentBatchTestF_U32, + ::testing::ValuesIn(inputsBatch)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 2139e97428..82e3ace9da 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -153,9 +153,13 @@ auto calc_recall(const std::vector& expected_idx, /** check uniqueness of indices */ template -auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +auto check_unique_indices(const std::vector& actual_idx, + size_t rows, + size_t cols, + size_t max_duplicates) { size_t max_count; + size_t dup_count = 0lu; std::set unique_indices; for (size_t i = 0; i < rows; ++i) { unique_indices.clear(); @@ -168,8 +172,11 @@ auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t } else if (unique_indices.find(act_idx) == unique_indices.end()) { unique_indices.insert(act_idx); } else { - return testing::AssertionFailure() - << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + dup_count++; + if (dup_count > max_duplicates) { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + } } } } @@ -252,7 +259,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t cols, double eps, double min_recall, - bool test_unique = true) -> testing::AssertionResult + bool test_unique = true, + size_t max_duplicates = 0) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -270,7 +278,7 @@ auto eval_neighbours(const std::vector& expected_idx, << min_recall << "); eps = " << eps << ". "; } if (test_unique) - return check_unique_indices(actual_idx, rows, cols); + return check_unique_indices(actual_idx, rows, cols, max_duplicates); else return testing::AssertionSuccess(); } From 2dafe79af0216ad7189415ca3e2864277194c3e9 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Fri, 23 Aug 2024 07:49:30 +0200 Subject: [PATCH 2/4] Remove NumPy <2 pin (#2414) This PR removes the NumPy<2 pin which is expected to work for RAPIDS projects once CuPy 13.3.0 is released (CuPy 13.2.0 had some issues preventing the use with NumPy 2). Authors: - Sebastian Berg (https://github.com/seberg) - https://github.com/jakirkham Approvers: - James Lamb (https://github.com/jameslamb) - https://github.com/jakirkham URL: https://github.com/rapidsai/raft/pull/2414 --- conda/environments/all_cuda-118_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-118_arch-x86_64.yaml | 2 +- conda/environments/all_cuda-125_arch-aarch64.yaml | 2 +- conda/environments/all_cuda-125_arch-x86_64.yaml | 2 +- conda/recipes/pylibraft/meta.yaml | 2 +- dependencies.yaml | 3 +-- python/pylibraft/pyproject.toml | 2 +- python/raft-dask/pyproject.toml | 1 - 8 files changed, 7 insertions(+), 9 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index ef8524dce3..462874a7e7 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - nvcc_linux-aarch64=11.8 - pre-commit diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 6ffb27bb29..cfd974a6a8 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - nvcc_linux-64=11.8 - pre-commit diff --git a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml index fd0e380a1c..82e391e9ae 100644 --- a/conda/environments/all_cuda-125_arch-aarch64.yaml +++ b/conda/environments/all_cuda-125_arch-aarch64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - pre-commit - pydata-sphinx-theme diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index ad4ecb7ff2..0389427d13 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - pre-commit - pydata-sphinx-theme diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index 4ef85fc0e5..9d91af712e 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -65,7 +65,7 @@ requirements: {% endif %} - libraft {{ version }} - libraft-headers {{ version }} - - numpy >=1.23,<2.0a0 + - numpy >=1.23,<3.0a0 - python x.x - rmm ={{ minor_version }} diff --git a/dependencies.yaml b/dependencies.yaml index 92c6d98414..a5b818e0f3 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -473,7 +473,7 @@ dependencies: common: - output_types: [conda, pyproject] packages: - - &numpy numpy>=1.23,<2.0a0 + - numpy>=1.23,<3.0a0 - output_types: [conda] packages: - *rmm_unsuffixed @@ -513,7 +513,6 @@ dependencies: - dask-cuda==24.10.*,>=0.0.0a0 - joblib>=0.11 - numba>=0.57 - - *numpy - rapids-dask-dependency==24.10.*,>=0.0.0a0 - output_types: conda packages: diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index 9a826e53c6..9a6e454e7b 100644 --- a/python/pylibraft/pyproject.toml +++ b/python/pylibraft/pyproject.toml @@ -32,7 +32,7 @@ license = { text = "Apache 2.0" } requires-python = ">=3.9" dependencies = [ "cuda-python", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "nvidia-cublas", "nvidia-curand", "nvidia-cusolver", diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index 4fadfa5c9f..4e2e45206f 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -35,7 +35,6 @@ dependencies = [ "distributed-ucxx==0.40.*,>=0.0.0a0", "joblib>=0.11", "numba>=0.57", - "numpy>=1.23,<2.0a0", "pylibraft==24.10.*,>=0.0.0a0", "rapids-dask-dependency==24.10.*,>=0.0.0a0", "ucx-py==0.40.*,>=0.0.0a0", From 73dc36c5304b338ee8d12d775522dbf4d75a014a Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 23 Aug 2024 00:51:31 -0500 Subject: [PATCH 3/4] Drop Python 3.9 support (#2417) Contributes to https://github.com/rapidsai/build-planning/issues/88 Finishes the work of dropping Python 3.9 support. This project stopped building / testing against Python 3.9 as of https://github.com/rapidsai/shared-workflows/pull/235. This PR updates configuration and docs to reflect that. ## Notes for Reviewers ### How I tested this Checked that there were no remaining uses like this: ```shell git grep -E '3\.9' git grep '39' git grep 'py39' ``` And similar for variations on Python 3.8 (to catch things that were missed the last time this was done). Authors: - James Lamb (https://github.com/jameslamb) Approvers: - https://github.com/jakirkham URL: https://github.com/rapidsai/raft/pull/2417 --- dependencies.yaml | 6 +----- pyproject.toml | 2 +- python/pylibraft/pyproject.toml | 3 +-- python/raft-ann-bench/pyproject.toml | 3 +-- python/raft-dask/pyproject.toml | 3 +-- 5 files changed, 5 insertions(+), 12 deletions(-) diff --git a/dependencies.yaml b/dependencies.yaml index a5b818e0f3..e4e361548f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -454,10 +454,6 @@ dependencies: specific: - output_types: conda matrices: - - matrix: - py: "3.9" - packages: - - python=3.9 - matrix: py: "3.10" packages: @@ -468,7 +464,7 @@ dependencies: - python=3.11 - matrix: packages: - - python>=3.9,<3.12 + - python>=3.10,<3.12 run_pylibraft: common: - output_types: [conda, pyproject] diff --git a/pyproject.toml b/pyproject.toml index 1e4ba0b369..5042113388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 79 -target-version = ["py39"] +target-version = ["py310"] include = '\.py?$' force-exclude = ''' /( diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index 9a6e454e7b..14f2ba7d2f 100644 --- a/python/pylibraft/pyproject.toml +++ b/python/pylibraft/pyproject.toml @@ -29,7 +29,7 @@ authors = [ { name = "NVIDIA Corporation" }, ] license = { text = "Apache 2.0" } -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "cuda-python", "numpy>=1.23,<3.0a0", @@ -42,7 +42,6 @@ dependencies = [ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] diff --git a/python/raft-ann-bench/pyproject.toml b/python/raft-ann-bench/pyproject.toml index d22dd567fe..fa5781893b 100644 --- a/python/raft-ann-bench/pyproject.toml +++ b/python/raft-ann-bench/pyproject.toml @@ -16,7 +16,7 @@ authors = [ { name = "NVIDIA Corporation" }, ] license = { text = "Apache 2.0" } -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ @@ -25,7 +25,6 @@ classifiers = [ "Topic :: Scientific/Engineering", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index 4e2e45206f..44012b5f10 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -29,7 +29,7 @@ authors = [ { name = "NVIDIA Corporation" }, ] license = { text = "Apache 2.0" } -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "dask-cuda==24.10.*,>=0.0.0a0", "distributed-ucxx==0.40.*,>=0.0.0a0", @@ -42,7 +42,6 @@ dependencies = [ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] From e17dcdfdbe6985d9a9dfb4bd76d047dc08481630 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Tue, 27 Aug 2024 15:50:10 -0400 Subject: [PATCH 4/4] Update rapidsai/pre-commit-hooks (#2420) This PR updates rapidsai/pre-commit-hooks to the version 0.4.0. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/raft/pull/2420 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9aed6bf387..88be628e55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -99,7 +99,7 @@ repos: hooks: - id: check-json - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.3.1 + rev: v0.4.0 hooks: - id: verify-copyright files: |