diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9aed6bf387..88be628e55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -99,7 +99,7 @@ repos: hooks: - id: check-json - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.3.1 + rev: v0.4.0 hooks: - id: verify-copyright files: | diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index ef8524dce3..462874a7e7 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - nvcc_linux-aarch64=11.8 - pre-commit diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 6ffb27bb29..cfd974a6a8 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - nvcc_linux-64=11.8 - pre-commit diff --git a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml index fd0e380a1c..82e391e9ae 100644 --- a/conda/environments/all_cuda-125_arch-aarch64.yaml +++ b/conda/environments/all_cuda-125_arch-aarch64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - pre-commit - pydata-sphinx-theme diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index ad4ecb7ff2..0389427d13 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23,<2.0a0 +- numpy>=1.23,<3.0a0 - numpydoc - pre-commit - pydata-sphinx-theme diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index 4ef85fc0e5..9d91af712e 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -65,7 +65,7 @@ requirements: {% endif %} - libraft {{ version }} - libraft-headers {{ version }} - - numpy >=1.23,<2.0a0 + - numpy >=1.23,<3.0a0 - python x.x - rmm ={{ minor_version }} diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index 9c37ee146d..02610f9afb 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -344,7 +345,9 @@ struct GnndGraph { ~GnndGraph(); }; -template +template > class GNND { public: GNND(raft::resources const& res, const BuildConfig& build_config); @@ -356,9 +359,10 @@ class GNND { Index_t* output_graph, bool return_distances, DistData_t* output_distances, - epilogue_op distance_epilogue = raft::identity_op()); + epilogue_op distance_epilogue = DistEpilogue()); ~GNND() = default; using ID_t = InternalID_t; + void reset(raft::resources const& res); private: void add_reverse_edges(Index_t* graph_ptr, @@ -366,7 +370,8 @@ class GNND { Index_t* d_rev_graph_ptr, int2* list_sizes, cudaStream_t stream = 0); - void local_join(cudaStream_t stream = 0, epilogue_op distance_epilogue = raft::identity_op()); + void local_join(cudaStream_t stream = 0, + epilogue_op distance_epilogue = DistEpilogue()); 
raft::resources const& res; @@ -701,7 +706,7 @@ __device__ __forceinline__ void remove_duplicates( // is 1024 and 1536 respectively, which means the bounds don't work anymore template , - typename epilogue_op = raft::identity_op> + typename epilogue_op = DistEpilogue> RAFT_KERNEL #ifdef __CUDA_ARCH__ #if (__CUDA_ARCH__) == 750 || ((__CUDA_ARCH__) >= 860 && (__CUDA_ARCH__) <= 890) @@ -1183,18 +1188,23 @@ GNND::GNND(raft::resources const& res, d_list_sizes_old_{raft::make_device_vector(res, nrow_)} { static_assert(NUM_SAMPLES <= 32); - - thrust::fill(thrust::device, - dists_buffer_.data_handle(), - dists_buffer_.data_handle() + dists_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, - reinterpret_cast(graph_buffer_.data_handle()), - reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); + raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); + auto graph_buffer_view = raft::make_device_matrix_view( + reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); + raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); + raft::matrix::fill(res, d_locks_.view(), 0); }; +template +void GNND::reset(raft::resources const& res) +{ + raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); + auto graph_buffer_view = raft::make_device_matrix_view( + reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); + raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); + raft::matrix::fill(res, d_locks_.view(), 0); +} + template void GNND::add_reverse_edges(Index_t* graph_ptr, Index_t* h_rev_graph_ptr, @@ -1246,6 +1256,7 @@ void GNND::build(Data_t* data, cudaStream_t stream = raft::resource::get_cuda_stream(res); nrow_ = nrow; + graph_.nrow = nrow; graph_.h_graph = (InternalID_t*)output_graph; cudaPointerAttributes data_ptr_attr; @@ -1384,6 +1395,7 @@ void GNND::build(Data_t* data, static_cast(build_config_.output_graph_degree)}; raft::matrix::slice( res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords); + raft::resource::sync_stream(res); } Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); @@ -1414,14 +1426,14 @@ void GNND::build(Data_t* data, template , typename Accessor = host_device_accessor, memory_type::host>> void build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset, index& idx, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, "The dataset size for GNND should be less than %d", @@ -1491,13 +1503,13 @@ void build(raft::resources const& res, template , typename Accessor = host_device_accessor, memory_type::host>> index build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { size_t intermediate_degree = params.intermediate_graph_degree; size_t graph_degree = params.graph_degree; diff --git a/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh b/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh new file mode 100644 index 0000000000..78467c9741 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh @@ -0,0 +1,701 @@ +/* + * Copyright (c) 2024, NVIDIA 
CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + +#include "../nn_descent_types.hpp" +#include "nn_descent.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent::detail { + +// +// Run balanced kmeans on a subsample of the dataset to get centroids +// +template , memory_type::host>> +void get_balanced_kmeans_centroids( + raft::resources const& res, + raft::distance::DistanceType metric, + mdspan, row_major, Accessor> dataset, + raft::device_matrix_view centroids) +{ + size_t num_rows = static_cast(dataset.extent(0)); + size_t num_cols = static_cast(dataset.extent(1)); + size_t n_clusters = centroids.extent(0); + size_t num_subsamples = + std::min(static_cast(num_rows / n_clusters), static_cast(num_rows * 0.1)); + + auto d_subsample_dataset = + raft::make_device_matrix(res, num_subsamples, num_cols); + raft::matrix::sample_rows( + res, raft::random::RngState{0}, dataset, d_subsample_dataset.view()); + + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.metric = metric; + + auto d_subsample_dataset_const_view = + raft::make_device_matrix_view( + d_subsample_dataset.data_handle(), num_subsamples, num_cols); + raft::cluster::kmeans_balanced::fit( + res, kmeans_params, d_subsample_dataset_const_view, centroids); +} + +// +// Get the top k closest centroid indices for each data point +// Loads the data in batches onto device if data is on host for memory efficiency +// +template +void get_global_nearest_k( + raft::resources const& res, + size_t k, + size_t num_rows, + size_t n_clusters, + const T* dataset, + raft::host_matrix_view global_nearest_cluster, + raft::device_matrix_view centroids, + raft::distance::DistanceType metric) +{ + size_t num_cols = centroids.extent(1); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, dataset)); + float* ptr = reinterpret_cast(attr.devicePointer); + + if (ptr == nullptr) { // data on host + size_t num_batches = n_clusters; + size_t batch_size = (num_rows + n_clusters) / n_clusters; + + auto d_dataset_batch = + raft::make_device_matrix(res, batch_size, num_cols); + + auto nearest_clusters_idx = + raft::make_device_matrix(res, batch_size, k); + auto nearest_clusters_dist = + raft::make_device_matrix(res, batch_size, k); + + for (size_t i = 0; i < num_batches; i++) { + size_t batch_size_ = batch_size; + + if (i == num_batches - 1) { batch_size_ = num_rows - batch_size * i; } + raft::copy(d_dataset_batch.data_handle(), + dataset + i * batch_size * num_cols, + batch_size_ * num_cols, + resource::get_cuda_stream(res)); + + raft::neighbors::brute_force::fused_l2_knn( + res, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(d_dataset_batch.view()), + nearest_clusters_idx.view(), + nearest_clusters_dist.view(), + metric); + 
raft::copy(global_nearest_cluster.data_handle() + i * batch_size * k, + nearest_clusters_idx.data_handle(), + batch_size_ * k, + resource::get_cuda_stream(res)); + } + } else { // data on device + auto nearest_clusters_idx = + raft::make_device_matrix(res, num_rows, k); + auto nearest_clusters_dist = + raft::make_device_matrix(res, num_rows, k); + + raft::neighbors::brute_force::fused_l2_knn( + res, + raft::make_const_mdspan(centroids), + raft::make_device_matrix_view(dataset, num_rows, num_cols), + nearest_clusters_idx.view(), + nearest_clusters_dist.view(), + metric); + + raft::copy(global_nearest_cluster.data_handle(), + nearest_clusters_idx.data_handle(), + num_rows * k, + resource::get_cuda_stream(res)); + } +} + +// +// global_nearest_cluster [num_rows X k=2] : top 2 closest clusters for each data point +// inverted_indices [num_rows x k vector] : sparse vector for data indices for each cluster +// cluster_size [n_cluster] : cluster size for each cluster +// offset [n_cluster] : offset in inverted_indices for each cluster +// Loads the data in batches onto device if data is on host for memory efficiency +// +template +void get_inverted_indices(raft::resources const& res, + size_t n_clusters, + size_t& max_cluster_size, + size_t& min_cluster_size, + raft::host_matrix_view global_nearest_cluster, + raft::host_vector_view inverted_indices, + raft::host_vector_view cluster_size, + raft::host_vector_view offset) +{ + // build sparse inverted indices and get number of data points for each cluster + size_t num_rows = global_nearest_cluster.extent(0); + size_t k = global_nearest_cluster.extent(1); + + auto local_offset = raft::make_host_vector(n_clusters); + + max_cluster_size = 0; + min_cluster_size = std::numeric_limits::max(); + + thrust::fill( + thrust::host, cluster_size.data_handle(), cluster_size.data_handle() + n_clusters, 0); + thrust::fill( + thrust::host, local_offset.data_handle(), local_offset.data_handle() + n_clusters, 0); + + // TODO: this part isn't really a bottleneck but maybe worth trying omp parallel + // for with atomic add + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < k; j++) { + IdxT cluster_id = global_nearest_cluster(i, j); + cluster_size(cluster_id) += 1; + } + } + + offset(0) = 0; + for (size_t i = 1; i < n_clusters; i++) { + offset(i) = offset(i - 1) + cluster_size(i - 1); + } + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < k; j++) { + IdxT cluster_id = global_nearest_cluster(i, j); + inverted_indices(offset(cluster_id) + local_offset(cluster_id)) = i; + local_offset(cluster_id) += 1; + } + } + + max_cluster_size = static_cast( + *std::max_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); + min_cluster_size = static_cast( + *std::min_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); +} + +template +struct KeyValuePair { + KeyType key; + ValueType value; +}; + +template +struct CustomKeyComparator { + __device__ bool operator()(const KeyValuePair& a, + const KeyValuePair& b) const + { + if (a.key == b.key) { return a.value < b.value; } + return a.key < b.key; + } +}; + +template +RAFT_KERNEL merge_subgraphs(IdxT* cluster_data_indices, + size_t graph_degree, + size_t num_cluster_in_batch, + float* global_distances, + float* batch_distances, + IdxT* global_indices, + IdxT* batch_indices) +{ + size_t batch_row = blockIdx.x; + typedef cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD> + BlockMergeSortType; + __shared__ typename cub::BlockMergeSort, BLOCK_SIZE, 
ITEMS_PER_THREAD>:: + TempStorage tmpSmem; + + extern __shared__ char sharedMem[]; + float* blockKeys = reinterpret_cast(sharedMem); + IdxT* blockValues = reinterpret_cast(&sharedMem[graph_degree * 2 * sizeof(float)]); + int16_t* uniqueMask = + reinterpret_cast(&sharedMem[graph_degree * 2 * (sizeof(float) + sizeof(IdxT))]); + + if (batch_row < num_cluster_in_batch) { + // load batch or global depending on threadIdx + size_t global_row = cluster_data_indices[batch_row]; + + KeyValuePair threadKeyValuePair[ITEMS_PER_THREAD]; + + size_t halfway = BLOCK_SIZE / 2; + size_t do_global = threadIdx.x < halfway; + + float* distances; + IdxT* indices; + + if (do_global) { + distances = global_distances; + indices = global_indices; + } else { + distances = batch_distances; + indices = batch_indices; + } + + size_t idxBase = (threadIdx.x * do_global + (threadIdx.x - halfway) * (1lu - do_global)) * + static_cast(ITEMS_PER_THREAD); + size_t arrIdxBase = (global_row * do_global + batch_row * (1lu - do_global)) * graph_degree; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId < graph_degree) { + threadKeyValuePair[i].key = distances[arrIdxBase + colId]; + threadKeyValuePair[i].value = indices[arrIdxBase + colId]; + } else { + threadKeyValuePair[i].key = std::numeric_limits::max(); + threadKeyValuePair[i].value = std::numeric_limits::max(); + } + } + + __syncthreads(); + + BlockMergeSortType(tmpSmem).Sort(threadKeyValuePair, CustomKeyComparator{}); + + // load sorted result into shared memory to get unique values + idxBase = threadIdx.x * ITEMS_PER_THREAD; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId < 2 * graph_degree) { + blockKeys[colId] = threadKeyValuePair[i].key; + blockValues[colId] = threadKeyValuePair[i].value; + } + } + + __syncthreads(); + + // get unique mask + if (threadIdx.x == 0) { uniqueMask[0] = 1; } + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId > 0 && colId < 2 * graph_degree) { + uniqueMask[colId] = static_cast(blockValues[colId] != blockValues[colId - 1]); + } + } + + __syncthreads(); + + // prefix sum + if (threadIdx.x == 0) { + for (int i = 1; i < 2 * graph_degree; i++) { + uniqueMask[i] += uniqueMask[i - 1]; + } + } + + __syncthreads(); + // load unique values to global memory + if (threadIdx.x == 0) { + global_distances[global_row * graph_degree] = blockKeys[0]; + global_indices[global_row * graph_degree] = blockValues[0]; + } + + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId > 0 && colId < 2 * graph_degree) { + bool is_unique = uniqueMask[colId] != uniqueMask[colId - 1]; + int16_t global_colId = uniqueMask[colId] - 1; + if (is_unique && static_cast(global_colId) < graph_degree) { + global_distances[global_row * graph_degree + global_colId] = blockKeys[colId]; + global_indices[global_row * graph_degree + global_colId] = blockValues[colId]; + } + } + } + } +} + +// +// builds knn graph using NN Descent and merge with global graph +// +template , + typename Accessor = + host_device_accessor, memory_type::host>> +void build_and_merge(raft::resources const& res, + const index_params& params, + size_t num_data_in_cluster, + size_t graph_degree, + size_t int_graph_node_degree, + T* cluster_data, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_d, + float* global_distances_d, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + GNND& nnd, + epilogue_op 
distance_epilogue) +{ + nnd.build( + cluster_data, num_data_in_cluster, int_graph, true, batch_distances_d, distance_epilogue); + + // remap indices +#pragma omp parallel for + for (size_t i = 0; i < num_data_in_cluster; i++) { + for (size_t j = 0; j < graph_degree; j++) { + size_t local_idx = int_graph[i * int_graph_node_degree + j]; + batch_indices_h[i * graph_degree + j] = inverted_indices[local_idx]; + } + } + + raft::copy(batch_indices_d, + batch_indices_h, + num_data_in_cluster * graph_degree, + raft::resource::get_cuda_stream(res)); + + size_t num_elems = graph_degree * 2; + size_t sharedMemSize = num_elems * (sizeof(float) + sizeof(IdxT) + sizeof(int16_t)); + + if (num_elems <= 128) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 512) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 1024) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 2048) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else { + // this is as far as we can get due to the shared mem usage of cub::BlockMergeSort + RAFT_FAIL("The degree of knn is too large (%lu). It must be smaller than 1024", graph_degree); + } + raft::resource::sync_stream(res); +} + +// +// For each cluster, gather the data samples that belong to that cluster, and +// call build_and_merge +// +template > +void cluster_nnd(raft::resources const& res, + const index_params& params, + size_t graph_degree, + size_t extended_graph_degree, + size_t max_cluster_size, + raft::host_matrix_view dataset, + IdxT* offsets, + IdxT* cluster_size, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_h, + float* global_distances_h, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + const BuildConfig& build_config, + epilogue_op distance_epilogue) +{ + size_t num_rows = dataset.extent(0); + size_t num_cols = dataset.extent(1); + + GNND nnd(res, build_config); + + auto cluster_data_matrix = + raft::make_host_matrix(max_cluster_size, num_cols); + + for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { + RAFT_LOG_DEBUG( + "# Data on host. 
Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); + size_t num_data_in_cluster = cluster_size[cluster_id]; + size_t offset = offsets[cluster_id]; + +#pragma omp parallel for + for (size_t i = 0; i < num_data_in_cluster; i++) { + for (size_t j = 0; j < num_cols; j++) { + size_t global_row = (inverted_indices + offset)[i]; + cluster_data_matrix(i, j) = dataset(global_row, j); + } + } + + distance_epilogue.preprocess_for_batch(cluster_data_indices + offset, num_data_in_cluster); + + build_and_merge(res, + params, + num_data_in_cluster, + graph_degree, + extended_graph_degree, + cluster_data_matrix.data_handle(), + cluster_data_indices + offset, + int_graph, + inverted_indices + offset, + global_indices_h, + global_distances_h, + batch_indices_h, + batch_indices_d, + batch_distances_d, + nnd, + distance_epilogue); + nnd.reset(res); + } +} + +template > +void cluster_nnd(raft::resources const& res, + const index_params& params, + size_t graph_degree, + size_t extended_graph_degree, + size_t max_cluster_size, + raft::device_matrix_view dataset, + IdxT* offsets, + IdxT* cluster_size, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_h, + float* global_distances_h, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + const BuildConfig& build_config, + epilogue_op distance_epilogue) +{ + size_t num_rows = dataset.extent(0); + size_t num_cols = dataset.extent(1); + + GNND nnd(res, build_config); + + auto cluster_data_matrix = + raft::make_device_matrix(res, max_cluster_size, num_cols); + + for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { + RAFT_LOG_DEBUG( + "# Data on device. Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); + size_t num_data_in_cluster = cluster_size[cluster_id]; + size_t offset = offsets[cluster_id]; + + auto cluster_data_view = raft::make_device_matrix_view( + cluster_data_matrix.data_handle(), num_data_in_cluster, num_cols); + auto cluster_data_indices_view = raft::make_device_vector_view( + cluster_data_indices + offset, num_data_in_cluster); + distance_epilogue.preprocess_for_batch(cluster_data_indices + offset, num_data_in_cluster); + + auto dataset_IdxT = + raft::make_device_matrix_view(dataset.data_handle(), num_rows, num_cols); + raft::matrix::gather(res, dataset_IdxT, cluster_data_indices_view, cluster_data_view); + + build_and_merge(res, + params, + num_data_in_cluster, + graph_degree, + extended_graph_degree, + cluster_data_view.data_handle(), + cluster_data_indices + offset, + int_graph, + inverted_indices + offset, + global_indices_h, + global_distances_h, + batch_indices_h, + batch_indices_d, + batch_distances_d, + nnd, + distance_epilogue); + nnd.reset(res); + } +} + +template , + typename Accessor = + host_device_accessor, memory_type::host>> +index batch_build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset, + epilogue_op distance_epilogue = DistEpilogue()) +{ + size_t graph_degree = params.graph_degree; + size_t intermediate_degree = params.intermediate_graph_degree; + + size_t num_rows = static_cast(dataset.extent(0)); + size_t num_cols = static_cast(dataset.extent(1)); + + auto centroids = + raft::make_device_matrix(res, params.n_clusters, num_cols); + get_balanced_kmeans_centroids(res, params.metric, dataset, centroids.view()); + + size_t k = 2; + auto global_nearest_cluster = raft::make_host_matrix(num_rows, k); + get_global_nearest_k(res, + k, + num_rows, + params.n_clusters, + 
dataset.data_handle(), + global_nearest_cluster.view(), + centroids.view(), + params.metric); + + auto inverted_indices = raft::make_host_vector(num_rows * k); + auto cluster_size = raft::make_host_vector(params.n_clusters); + auto offset = raft::make_host_vector(params.n_clusters); + + size_t max_cluster_size, min_cluster_size; + get_inverted_indices(res, + params.n_clusters, + max_cluster_size, + min_cluster_size, + global_nearest_cluster.view(), + inverted_indices.view(), + cluster_size.view(), + offset.view()); + + if (intermediate_degree >= min_cluster_size) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than minimum cluster size, reducing it to %lu", + dataset.extent(0)); + intermediate_degree = min_cluster_size - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + size_t extended_graph_degree = + align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); + size_t extended_intermediate_degree = align32::roundUp( + static_cast(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3))); + + auto int_graph = raft::make_host_matrix( + max_cluster_size, static_cast(extended_graph_degree)); + + BuildConfig build_config{.max_dataset_size = max_cluster_size, + .dataset_dim = num_cols, + .node_degree = extended_graph_degree, + .internal_node_degree = extended_intermediate_degree, + .max_iterations = params.max_iterations, + .termination_threshold = params.termination_threshold, + .output_graph_degree = graph_degree}; + + auto global_indices_h = raft::make_managed_matrix(res, num_rows, graph_degree); + auto global_distances_h = raft::make_managed_matrix(res, num_rows, graph_degree); + + thrust::fill(thrust::host, + global_indices_h.data_handle(), + global_indices_h.data_handle() + num_rows * graph_degree, + std::numeric_limits::max()); + thrust::fill(thrust::host, + global_distances_h.data_handle(), + global_distances_h.data_handle() + num_rows * graph_degree, + std::numeric_limits::max()); + + auto batch_indices_h = + raft::make_host_matrix(max_cluster_size, graph_degree); + auto batch_indices_d = + raft::make_device_matrix(res, max_cluster_size, graph_degree); + auto batch_distances_d = + raft::make_device_matrix(res, max_cluster_size, graph_degree); + + auto cluster_data_indices = raft::make_device_vector(res, num_rows * k); + raft::copy(cluster_data_indices.data_handle(), + inverted_indices.data_handle(), + num_rows * k, + resource::get_cuda_stream(res)); + + cluster_nnd(res, + params, + graph_degree, + extended_graph_degree, + max_cluster_size, + dataset, + offset.data_handle(), + cluster_size.data_handle(), + cluster_data_indices.data_handle(), + int_graph.data_handle(), + inverted_indices.data_handle(), + global_indices_h.data_handle(), + global_distances_h.data_handle(), + batch_indices_h.data_handle(), + batch_indices_d.data_handle(), + batch_distances_d.data_handle(), + build_config, + distance_epilogue); + + index global_idx{ + res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; + + raft::copy(global_idx.graph().data_handle(), + global_indices_h.data_handle(), + num_rows * graph_degree, + raft::resource::get_cuda_stream(res)); + if (params.return_distances && global_idx.distances().has_value()) { + raft::copy(global_idx.distances().value().data_handle(), + global_distances_h.data_handle(), + num_rows * 
graph_degree, + raft::resource::get_cuda_stream(res)); + } + return global_idx; +} + +} // namespace raft::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh index a46a2006d6..6c08546d3f 100644 --- a/cpp/include/raft/neighbors/nn_descent.cuh +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -17,9 +17,11 @@ #pragma once #include "detail/nn_descent.cuh" +#include "detail/nn_descent_batch.cuh" #include #include +#include namespace raft::neighbors::experimental::nn_descent { @@ -57,13 +59,17 @@ namespace raft::neighbors::experimental::nn_descent { * @param[in] distance_epilogue epilogue operation for distances * @return index index containing all-neighbors knn graph in host memory */ -template +template > index build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { - return detail::build(res, params, dataset, distance_epilogue); + if (params.n_clusters > 1) { + return detail::batch_build(res, params, dataset, distance_epilogue); + } else { + return detail::build(res, params, dataset, distance_epilogue); + } } /** @@ -98,12 +104,12 @@ index build(raft::resources const& res, * in host memory * @param[in] distance_epilogue epilogue operation for distances */ -template +template > void build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset, index& idx, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { detail::build(res, params, dataset, idx, distance_epilogue); } @@ -137,13 +143,17 @@ void build(raft::resources const& res, * @param[in] distance_epilogue epilogue operation for distances * @return index index containing all-neighbors knn graph in host memory */ -template +template > index build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { - return detail::build(res, params, dataset, distance_epilogue); + if (params.n_clusters > 1) { + return detail::batch_build(res, params, dataset, distance_epilogue); + } else { + return detail::build(res, params, dataset, distance_epilogue); + } } /** @@ -178,12 +188,12 @@ index build(raft::resources const& res, * in host memory * @param[in] distance_epilogue epilogue operation for distances */ -template +template > void build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset, index& idx, - epilogue_op distance_epilogue = raft::identity_op()) + epilogue_op distance_epilogue = DistEpilogue()) { detail::build(res, params, dataset, idx, distance_epilogue); } diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index 5d23ff2c2e..eb01a423be 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -48,6 +48,20 @@ namespace raft::neighbors::experimental::nn_descent { * `max_iterations`: The number of iterations that nn-descent will refine * the graph for. 
More iterations produce a better quality graph at cost of performance * `termination_threshold`: The delta at which nn-descent will terminate its iterations + `return_distances`: boolean whether to return distances + `n_clusters`: NN Descent offers batching a dataset to save GPU memory usage. + Increase `n_clusters` to save GPU memory and run NN Descent with large datasets. + Most effective when the data is in host (CPU) memory. + Setting this number too high may result in too much overhead from doing multiple + iterations of graph building. We recommend starting at 4 and increasing it + depending on the desired GPU memory usage. + (Specifically, with n_clusters > 1, the NN Descent build algorithm will first + find n_clusters cluster centroids of the dataset, then consider data + points that belong to each cluster as a batch. + Then we build knn subgraphs on each batch of the data. This is especially + useful when the dataset is put on host, since only a subset of the data will + be on GPU at once, enabling running NN Descent with large datasets that do not + fit on the GPU as a whole.) + */ struct index_params : ann::index_params { @@ -56,6 +70,7 @@ struct index_params : ann::index_params { size_t max_iterations = 20; // Number of nn-descent iterations. float termination_threshold = 0.0001; // Termination threshold of nn-descent. bool return_distances = false; // return distances if true + size_t n_clusters = 1; // defaults to not using any batching }; /** @@ -178,6 +193,14 @@ struct index : ann::index { bool return_distances_; }; +template +struct DistEpilogue : raft::identity_op { + __host__ void preprocess_for_batch(value_idx* cluster_indices, size_t num_data_in_cluster) + { + return; + } +}; + /** @} */ } // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index e3af6ebb78..a497e6d3ba 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -96,17 +96,8 @@ endfunction() if(BUILD_TESTS) ConfigureTest( - NAME - CLUSTER_TEST - PATH - cluster/kmeans.cu - cluster/kmeans_balanced.cu - cluster/kmeans_find_k.cu - cluster/cluster_solvers.cu - cluster/linkage.cu - cluster/spectral.cu - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME CLUSTER_TEST PATH cluster/kmeans.cu cluster/kmeans_balanced.cu cluster/kmeans_find_k.cu + cluster/cluster_solvers.cu cluster/linkage.cu cluster/spectral.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -144,8 +135,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME CORE_TEST PATH core/stream_view.cpp core/mdspan_copy.cpp LIB - EXPLICIT_INSTANTIATE_ONLY NOCUDA + NAME CORE_TEST PATH core/stream_view.cpp core/mdspan_copy.cpp LIB EXPLICIT_INSTANTIATE_ONLY + NOCUDA ) ConfigureTest( @@ -301,8 +292,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SOLVERS_TEST PATH cluster/cluster_solvers_deprecated.cu linalg/eigen_solvers.cu - lap/lap.cu sparse/mst.cu LIB EXPLICIT_INSTANTIATE_ONLY + NAME SOLVERS_TEST PATH cluster/cluster_solvers_deprecated.cu linalg/eigen_solvers.cu lap/lap.cu + sparse/mst.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -331,19 +322,13 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SPARSE_DIST_TEST PATH sparse/dist_coo_spmv.cu sparse/distance.cu - sparse/gram.cu LIB EXPLICIT_INSTANTIATE_ONLY + NAME SPARSE_DIST_TEST PATH sparse/dist_coo_spmv.cu sparse/distance.cu sparse/gram.cu LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( - NAME - SPARSE_NEIGHBORS_TEST - PATH - sparse/neighbors/cross_component_nn.cu - sparse/neighbors/brute_force.cu - 
sparse/neighbors/knn_graph.cu - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME SPARSE_NEIGHBORS_TEST PATH sparse/neighbors/cross_component_nn.cu + sparse/neighbors/brute_force.cu sparse/neighbors/knn_graph.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -455,6 +440,7 @@ if(BUILD_TESTS) neighbors/ann_nn_descent/test_float_uint32_t.cu neighbors/ann_nn_descent/test_int8_t_uint32_t.cu neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu + neighbors/ann_nn_descent/test_batch_float_uint32_t.cu LIB EXPLICIT_INSTANTIATE_ONLY GPUS diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index f74cadb415..2f9d4e252b 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -42,6 +42,15 @@ struct AnnNNDescentInputs { double min_recall; }; +struct AnnNNDescentBatchInputs { + std::pair recall_cluster; + int n_rows; + int dim; + int graph_degree; + raft::distance::DistanceType metric; + bool host_dataset; +}; + inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& p) { os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree @@ -50,6 +59,14 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& return os; } +inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentBatchInputs& p) +{ + os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") + << ", clusters=" << p.recall_cluster.second << std::endl; + return os; +} + template class AnnNNDescentTest : public ::testing::TestWithParam { public: @@ -105,7 +122,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam { raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); auto database_host_view = raft::make_host_matrix_view( (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - auto index = nn_descent::build(handle_, index_params, database_host_view); + index index{handle_, ps.n_rows, static_cast(ps.graph_degree), true}; + nn_descent::build( + handle_, index_params, database_host_view, index, DistEpilogue()); raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); if (index.distances().has_value()) { @@ -116,7 +135,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam { } } else { - auto index = nn_descent::build(handle_, index_params, database_view); + index index{handle_, ps.n_rows, static_cast(ps.graph_degree), true}; + nn_descent::build( + handle_, index_params, database_view, index, DistEpilogue()); raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); if (index.distances().has_value()) { @@ -168,6 +189,127 @@ class AnnNNDescentTest : public ::testing::TestWithParam { rmm::device_uvector database; }; +template +class AnnNNDescentBatchTest : public ::testing::TestWithParam { + public: + AnnNNDescentBatchTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_) + { + } + + void testNNDescentBatch() + { + size_t queries_size = ps.n_rows * ps.graph_degree; + std::vector indices_NNDescent(queries_size); + std::vector distances_NNDescent(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + 
naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + database.data(), + database.data(), + ps.n_rows, + ps.n_rows, + ps.dim, + ps.graph_degree, + ps.metric); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + { + nn_descent::index_params index_params; + index_params.metric = ps.metric; + index_params.graph_degree = ps.graph_degree; + index_params.intermediate_graph_degree = 2 * ps.graph_degree; + index_params.max_iterations = 10; + index_params.return_distances = true; + index_params.n_clusters = ps.recall_cluster.second; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + { + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + auto index = nn_descent::build( + handle_, index_params, database_host_view, DistEpilogue()); + raft::copy( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } + + } else { + auto index = nn_descent::build( + handle_, index_params, database_view, DistEpilogue()); + raft::copy( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } + }; + } + resource::sync_stream(handle_); + } + double min_recall = ps.recall_cluster.first; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_NNDescent, + distances_naive, + distances_NNDescent, + ps.n_rows, + ps.graph_degree, + 0.01, + min_recall, + true, + static_cast(ps.graph_degree * 0.1))); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnNNDescentBatchInputs ps; + rmm::device_uvector database; +}; + const std::vector inputs = raft::util::itertools::product( {1000, 2000}, // n_rows {3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim @@ -176,4 +318,13 @@ const std::vector inputs = raft::util::itertools::product inputsBatch = + raft::util::itertools::product( + {std::make_pair(0.9, 3lu), std::make_pair(0.9, 2lu)}, // min_recall, n_clusters + {4000, 5000}, // n_rows + {192, 512}, // dim + {32, 64}, // graph_degree + {raft::distance::DistanceType::L2Expanded}, + {false, true}); + } // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu new file mode 100644 index 0000000000..c6f56e8c39 --- 
/dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_nn_descent.cuh" + +#include + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentBatchTest AnnNNDescentBatchTestF_U32; +TEST_P(AnnNNDescentBatchTestF_U32, AnnNNDescentBatch) { this->testNNDescentBatch(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentBatchTest, + AnnNNDescentBatchTestF_U32, + ::testing::ValuesIn(inputsBatch)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 2139e97428..82e3ace9da 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -153,9 +153,13 @@ auto calc_recall(const std::vector& expected_idx, /** check uniqueness of indices */ template -auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +auto check_unique_indices(const std::vector& actual_idx, + size_t rows, + size_t cols, + size_t max_duplicates) { size_t max_count; + size_t dup_count = 0lu; std::set unique_indices; for (size_t i = 0; i < rows; ++i) { unique_indices.clear(); @@ -168,8 +172,11 @@ auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t } else if (unique_indices.find(act_idx) == unique_indices.end()) { unique_indices.insert(act_idx); } else { - return testing::AssertionFailure() - << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + dup_count++; + if (dup_count > max_duplicates) { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + } } } } @@ -252,7 +259,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t cols, double eps, double min_recall, - bool test_unique = true) -> testing::AssertionResult + bool test_unique = true, + size_t max_duplicates = 0) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -270,7 +278,7 @@ auto eval_neighbours(const std::vector& expected_idx, << min_recall << "); eps = " << eps << ". 
"; } if (test_unique) - return check_unique_indices(actual_idx, rows, cols); + return check_unique_indices(actual_idx, rows, cols, max_duplicates); else return testing::AssertionSuccess(); } diff --git a/dependencies.yaml b/dependencies.yaml index 92c6d98414..e4e361548f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -454,10 +454,6 @@ dependencies: specific: - output_types: conda matrices: - - matrix: - py: "3.9" - packages: - - python=3.9 - matrix: py: "3.10" packages: @@ -468,12 +464,12 @@ dependencies: - python=3.11 - matrix: packages: - - python>=3.9,<3.12 + - python>=3.10,<3.12 run_pylibraft: common: - output_types: [conda, pyproject] packages: - - &numpy numpy>=1.23,<2.0a0 + - numpy>=1.23,<3.0a0 - output_types: [conda] packages: - *rmm_unsuffixed @@ -513,7 +509,6 @@ dependencies: - dask-cuda==24.10.*,>=0.0.0a0 - joblib>=0.11 - numba>=0.57 - - *numpy - rapids-dask-dependency==24.10.*,>=0.0.0a0 - output_types: conda packages: diff --git a/pyproject.toml b/pyproject.toml index 1e4ba0b369..5042113388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 79 -target-version = ["py39"] +target-version = ["py310"] include = '\.py?$' force-exclude = ''' /( diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index 9a826e53c6..14f2ba7d2f 100644 --- a/python/pylibraft/pyproject.toml +++ b/python/pylibraft/pyproject.toml @@ -29,10 +29,10 @@ authors = [ { name = "NVIDIA Corporation" }, ] license = { text = "Apache 2.0" } -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "cuda-python", - "numpy>=1.23,<2.0a0", + "numpy>=1.23,<3.0a0", "nvidia-cublas", "nvidia-curand", "nvidia-cusolver", @@ -42,7 +42,6 @@ dependencies = [ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] diff --git a/python/raft-ann-bench/pyproject.toml b/python/raft-ann-bench/pyproject.toml index d22dd567fe..fa5781893b 100644 --- a/python/raft-ann-bench/pyproject.toml +++ b/python/raft-ann-bench/pyproject.toml @@ -16,7 +16,7 @@ authors = [ { name = "NVIDIA Corporation" }, ] license = { text = "Apache 2.0" } -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
classifiers = [ @@ -25,7 +25,6 @@ classifiers = [ "Topic :: Scientific/Engineering", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index 4fadfa5c9f..44012b5f10 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -29,13 +29,12 @@ authors = [ { name = "NVIDIA Corporation" }, ] license = { text = "Apache 2.0" } -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "dask-cuda==24.10.*,>=0.0.0a0", "distributed-ucxx==0.40.*,>=0.0.0a0", "joblib>=0.11", "numba>=0.57", - "numpy>=1.23,<2.0a0", "pylibraft==24.10.*,>=0.0.0a0", "rapids-dask-dependency==24.10.*,>=0.0.0a0", "ucx-py==0.40.*,>=0.0.0a0", @@ -43,7 +42,6 @@ dependencies = [ classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ]
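The nn_descent.cuh changes above route `build` to `detail::batch_build` whenever `index_params::n_clusters > 1`. A minimal usage sketch of that batched path follows, assuming the host-data `build` overload and the `index_params` fields shown in this diff (`n_clusters`, `return_distances`); the function name, parameter values, and resource/dataset arguments are placeholders, not part of the patch.

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/neighbors/nn_descent.cuh>
#include <raft/neighbors/nn_descent_types.hpp>

#include <cstdint>

namespace nnd = raft::neighbors::experimental::nn_descent;

// Sketch: build an all-neighbors graph with the batched (clustered) NN Descent
// path. With n_clusters > 1 and a host-resident dataset, only one cluster's
// worth of rows is moved to the GPU at a time.
nnd::index<uint32_t> build_batched(
  raft::device_resources const& res,
  raft::host_matrix_view<const float, int64_t, raft::row_major> dataset)
{
  nnd::index_params params;
  params.graph_degree              = 64;    // example value
  params.intermediate_graph_degree = 128;   // example value
  params.return_distances          = true;  // also populate index::distances()
  params.n_clusters                = 4;     // > 1 dispatches to detail::batch_build

  return nnd::build(res, params, dataset);
}
```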
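The `DistEpilogue` struct added to `nn_descent_types.hpp` is the customization point for the distance epilogue in the batched path: `cluster_nnd` calls `preprocess_for_batch` with the device-resident row ids of the current cluster before each sub-graph build, and the epilogue is otherwise applied to distances like `raft::identity_op`. Below is a sketch of a user-defined epilogue with the same interface; the struct name and its template parameter are assumptions for illustration (the exact template parameters of `DistEpilogue` are not visible in this diff).

```cpp
#include <raft/core/operators.hpp>

#include <cstddef>
#include <cstdint>

// Hypothetical epilogue mirroring the interface the batched build expects:
// identity behaviour on distances (inherited from raft::identity_op) plus a
// host-side per-batch hook. DistEpilogue in this diff provides the no-op default.
template <typename value_idx = uint32_t>
struct MyDistEpilogue : raft::identity_op {
  __host__ void preprocess_for_batch(value_idx* cluster_indices, size_t num_data_in_cluster)
  {
    // cluster_indices is device memory holding the dataset row ids of the
    // current cluster (see cluster_nnd above); a real epilogue could stage
    // per-row state here for use when distances are computed.
    (void)cluster_indices;
    (void)num_data_in_cluster;
  }
};

// Passed as the trailing argument of build() in place of the default
// DistEpilogue(), e.g.:
//   auto idx = nnd::build(res, params, dataset, MyDistEpilogue<uint32_t>{});
```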