diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 48bf1d70d8..77adfb0623 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -88,6 +88,9 @@ void parse_build_param(const nlohmann::json& conf, "', should be either 'cluster' or 'subspace'"); } } + if (conf.contains("max_train_points_per_pq_code")) { + param.max_train_points_per_pq_code = conf.at("max_train_points_per_pq_code"); + } } template diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index 6d3f430e88..509638bbd4 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -881,10 +881,10 @@ auto build_fine_clusters(const raft::resources& handle, if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } } if (k != static_cast(mesocluster_sizes[i])) - RAFT_LOG_WARN("Incorrect mesocluster size at %d. %zu vs %zu", - static_cast(i), - static_cast(k), - static_cast(mesocluster_sizes[i])); + RAFT_LOG_DEBUG("Incorrect mesocluster size at %d. %zu vs %zu", + static_cast(i), + static_cast(k), + static_cast(mesocluster_sizes[i])); if (k == 0) { RAFT_LOG_DEBUG("Empty cluster %d", i); RAFT_EXPECTS(fine_clusters_nums[i] == 0, @@ -1030,7 +1030,7 @@ void build_hierarchical(const raft::resources& handle, const IdxT mesocluster_size_max_balanced = div_rounding_up_safe( 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu)); if (mesocluster_size_max > mesocluster_size_max_balanced) { - RAFT_LOG_WARN( + RAFT_LOG_DEBUG( "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " "At most %u points will be used for training within each mesocluster. " "Consider increasing the number of training iterations `n_iters`.", diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index 55184cc615..5c7cbb431b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -26,11 +26,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -364,28 +366,23 @@ inline auto build(raft::resources const& handle, // Train the kmeans clustering { + raft::random::RngState random_state{137}; auto trainset_ratio = std::max( 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); auto n_rows_train = n_rows / trainset_ratio; - rmm::device_uvector trainset(n_rows_train * index.dim(), stream); - // TODO: a proper sampling - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); + auto trainset = make_device_matrix(handle, n_rows_train, index.dim()); + raft::matrix::detail::sample_rows(handle, random_state, dataset, n_rows, trainset.view()); + auto centers_view = raft::make_device_matrix_view( index.centers().data_handle(), index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = index.metric(); - raft::cluster::kmeans_balanced::fit( - handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); + raft::cluster::kmeans_balanced::fit(handle, + kmeans_params, + make_const_mdspan(trainset.view()), + centers_view, + utils::mapping{}); } // add the data if necessary diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 8e3f7dbaf3..ec4225ae16 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -31,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -63,51 +65,6 @@ using namespace raft::spatial::knn::detail; // NOLINT using internal_extents_t = int64_t; // The default mdspan extent type used internally. -template -__launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel( - T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) -{ - using warp = Pow2; - size_t row_ix = warp::div(size_t(threadIdx.x) + size_t(BlockDim) * size_t(blockIdx.x)); - uint32_t i = warp::mod(threadIdx.x); - if (row_ix >= n_rows) return; - out += row_ix * ld_out; - in += row_ix * ld_in; - auto f = utils::mapping{}; - for (uint32_t col_ix = i; col_ix < n_cols; col_ix += warp::Value) { - auto x = f(in[col_ix]); - __syncwarp(); - out[col_ix] = x; - } -} - -/** - * Copy the data one warp-per-row: - * - * 1. load the data per-warp - * 2. apply the `utils::mapping{}` - * 3. sync within warp - * 4. store the data. - * - * Assuming sizeof(T) >= sizeof(S) and the data is properly aligned (see the usage in `build`), this - * allows to re-structure the data within rows in-place. - */ -template -void copy_warped(T* out, - uint32_t ld_out, - const S* in, - uint32_t ld_in, - uint32_t n_cols, - size_t n_rows, - rmm::cuda_stream_view stream) -{ - constexpr uint32_t kBlockDim = 128; - dim3 threads(kBlockDim, 1, 1); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockDim / WarpSize), 1, 1); - copy_warped_kernel - <<>>(out, ld_out, in, ld_in, n_cols, n_rows); -} - /** * @brief Fill-in a random orthogonal transformation matrix. * @@ -397,14 +354,19 @@ void train_per_subset(raft::resources const& handle, const float* trainset, // [n_rows, dim] const uint32_t* labels, // [n_rows] uint32_t kmeans_n_iters, + uint32_t max_train_points_per_pq_code, rmm::mr::device_memory_resource* managed_memory) { auto stream = resource::get_cuda_stream(handle); auto device_memory = resource::get_workspace_resource(handle); rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); - rmm::device_uvector sub_trainset(n_rows * size_t(index.pq_len()), stream, device_memory); - rmm::device_uvector sub_labels(n_rows, stream, device_memory); + // Subsampling the train set for codebook generation based on max_train_points_per_pq_code. + size_t big_enough = max_train_points_per_pq_code * size_t(index.pq_book_size()); + auto pq_n_rows = uint32_t(std::min(big_enough, n_rows)); + rmm::device_uvector sub_trainset( + pq_n_rows * size_t(index.pq_len()), stream, device_memory); + rmm::device_uvector sub_labels(pq_n_rows, stream, device_memory); rmm::device_uvector pq_cluster_sizes(index.pq_book_size(), stream, device_memory); @@ -415,7 +377,7 @@ void train_per_subset(raft::resources const& handle, // Get the rotated cluster centers for each training vector. // This will be subtracted from the input vectors afterwards. utils::copy_selected( - n_rows, + pq_n_rows, index.pq_len(), index.centers_rot().data_handle() + index.pq_len() * j, labels, @@ -431,7 +393,7 @@ void train_per_subset(raft::resources const& handle, true, false, index.pq_len(), - n_rows, + pq_n_rows, index.dim(), &alpha, index.rotation_matrix().data_handle() + index.dim() * index.pq_len() * j, @@ -445,13 +407,13 @@ void train_per_subset(raft::resources const& handle, // train PQ codebook for this subspace auto sub_trainset_view = raft::make_device_matrix_view( - sub_trainset.data(), n_rows, index.pq_len()); + sub_trainset.data(), pq_n_rows, index.pq_len()); auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j, index.pq_book_size(), index.pq_len()); auto sub_labels_view = - raft::make_device_vector_view(sub_labels.data(), n_rows); + raft::make_device_vector_view(sub_labels.data(), pq_n_rows); auto cluster_sizes_view = raft::make_device_vector_view( pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params; @@ -475,6 +437,7 @@ void train_per_cluster(raft::resources const& handle, const float* trainset, // [n_rows, dim] const uint32_t* labels, // [n_rows] uint32_t kmeans_n_iters, + uint32_t max_train_points_per_pq_code, rmm::mr::device_memory_resource* managed_memory) { auto stream = resource::get_cuda_stream(handle); @@ -522,9 +485,11 @@ void train_per_cluster(raft::resources const& handle, indices + cluster_offsets[l], device_memory); - // limit the cluster size to bound the training time. + // limit the cluster size to bound the training time based on max_train_points_per_pq_code + // If pq_book_size is less than pq_dim, use max_train_points_per_pq_code per pq_dim instead // [sic] we interpret the data as pq_len-dimensional - size_t big_enough = 256ul * std::max(index.pq_book_size(), index.pq_dim()); + size_t big_enough = + max_train_points_per_pq_code * std::max(index.pq_book_size(), index.pq_dim()); size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim()); auto pq_n_rows = uint32_t(std::min(big_enough, available_rows)); // train PQ codebook for this cluster @@ -1702,77 +1667,45 @@ auto build(raft::resources const& handle, utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); { + raft::random::RngState random_state{137}; auto trainset_ratio = std::max( 1, size_t(n_rows) / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); size_t n_rows_train = n_rows / trainset_ratio; - auto* device_memory = resource::get_workspace_resource(handle); - rmm::mr::managed_memory_resource managed_memory_upstream; + auto* device_mr = resource::get_workspace_resource(handle); + rmm::mr::managed_memory_resource managed_mr; // Besides just sampling, we transform the input dataset into floats to make it easier // to use gemm operations from cublas. - rmm::device_uvector trainset(n_rows_train * index.dim(), stream, device_memory); - // TODO: a proper sampling + auto trainset = make_device_matrix(handle, n_rows_train, dim); + if constexpr (std::is_same_v) { - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); + raft::matrix::detail::sample_rows( + handle, random_state, dataset, n_rows, trainset.view()); } else { - size_t dim = index.dim(); - cudaPointerAttributes dataset_attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&dataset_attr, dataset)); - if (dataset_attr.devicePointer != nullptr) { - // data is available on device: just run the kernel to copy and map the data - auto p = reinterpret_cast(dataset_attr.devicePointer); - auto trainset_view = - raft::make_device_vector_view(trainset.data(), dim * n_rows_train); - linalg::map_offset(handle, trainset_view, [p, trainset_ratio, dim] __device__(size_t i) { - auto col = i % dim; - return utils::mapping{}(p[(i - col) * size_t(trainset_ratio) + col]); - }); - } else { - // data is not available: first copy, then map inplace - auto trainset_tmp = reinterpret_cast(reinterpret_cast(trainset.data()) + - (sizeof(float) - sizeof(T)) * index.dim()); - // We copy the data in strides, one row at a time, and place the smaller rows of type T - // at the end of float rows. - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset_tmp, - sizeof(float) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - // Transform the input `{T -> float}`, one row per warp. - // The threads in each warp copy the data synchronously; this and the layout of the data - // (content is aligned to the end of the rows) together allow doing the transform in-place. - copy_warped(trainset.data(), - index.dim(), - trainset_tmp, - index.dim() * sizeof(float) / sizeof(T), - index.dim(), - n_rows_train, - stream); - } + // TODO(tfeher): Enable codebook generation with any type T, and then remove trainset tmp. + auto trainset_tmp = make_device_mdarray( + handle, &managed_mr, make_extents(n_rows_train, dim)); + raft::matrix::detail::sample_rows( + handle, random_state, dataset, n_rows, trainset_tmp.view()); + + raft::linalg::unaryOp(trainset.data_handle(), + trainset_tmp.data_handle(), + trainset.size(), + utils::mapping{}, + raft::resource::get_cuda_stream(handle)); } // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, // dim_ext]! rmm::device_uvector cluster_centers_buf( - index.n_lists() * index.dim(), stream, device_memory); + index.n_lists() * index.dim(), stream, device_mr); auto cluster_centers = cluster_centers_buf.data(); // Train balanced hierarchical kmeans clustering - auto trainset_const_view = raft::make_device_matrix_view( - trainset.data(), n_rows_train, index.dim()); - auto centers_view = raft::make_device_matrix_view( + auto trainset_const_view = raft::make_const_mdspan(trainset.view()); + auto centers_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; @@ -1781,7 +1714,7 @@ auto build(raft::resources const& handle, handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); // Trainset labels are needed for training PQ codebooks - rmm::device_uvector labels(n_rows_train, stream, device_memory); + rmm::device_uvector labels(n_rows_train, stream, device_mr); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); auto labels_view = @@ -1808,19 +1741,21 @@ auto build(raft::resources const& handle, train_per_subset(handle, index, n_rows_train, - trainset.data(), + trainset.data_handle(), labels.data(), params.kmeans_n_iters, - &managed_memory_upstream); + params.max_train_points_per_pq_code, + &managed_mr); break; case codebook_gen::PER_CLUSTER: train_per_cluster(handle, index, n_rows_train, - trainset.data(), + trainset.data_handle(), labels.data(), params.kmeans_n_iters, - &managed_memory_upstream); + params.max_train_points_per_pq_code, + &managed_mr); break; default: RAFT_FAIL("Unreachable code"); } diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 81e2886b18..d7d685f1e0 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -104,6 +104,14 @@ struct index_params : ann::index_params { * flag to `true` if you prefer to use as little GPU memory for the database as possible. */ bool conservative_memory_allocation = false; + /** + * The max number of data points to use per PQ code during PQ codebook training. Using more data + * points per PQ code may increase the quality of PQ codebook but may also increase the build + * time. The parameter is applied to both PQ codebook generation methods, i.e., PER_SUBSPACE and + * PER_CLUSTER. In both cases, we will use `pq_book_size * max_train_points_per_pq_code` training + * points to train each codebook. + */ + uint32_t max_train_points_per_pq_code = 256; }; struct search_params : ann::search_params { diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 4ebe02027f..3cf24e4fbf 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -68,7 +68,7 @@ struct ivf_pq_inputs { ivf_pq_inputs() { index_params.n_lists = max(32u, min(1024u, num_db_vecs / 128u)); - index_params.kmeans_trainset_fraction = 1.0; + index_params.kmeans_trainset_fraction = 0.95; } }; diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index afb4ed18ea..e003aa879c 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -38,6 +38,7 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of | `pq_bits` | `build` | N | Positive Integer. [4-8] | 8 | Bit length of the vector element after quantization. | | `codebook_kind` | `build` | N | ["cluster", "subspace"] | "subspace" | Type of codebook. See the [API docs](https://docs.rapids.ai/api/raft/nightly/cpp_api/neighbors_ivf_pq/#_CPPv412codebook_gen) for more detail | | `dataset_memory_type` | `build` | N | ["device", "host", "mmap"] | "host" | What memory type should the dataset reside? | +| `max_train_points_per_pq_code` | `build` | N | Positive Number >=1 | 256 | Max number of data points per PQ code used for PQ code book creation. Depending on input dataset size, the data points could be less than what user specifies. | | `query_memory_type` | `search` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? | | `nprobe` | `search` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | | `internalDistanceDtype` | `search` | N | [`float`, `half`] | `half` | The precision to use for the distance computations. Lower precision can increase performance at the cost of accuracy. | diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index 895abbadca..930c3245f1 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -78,6 +78,7 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \ codebook_gen codebook_kind bool force_random_rotation bool conservative_memory_allocation + uint32_t max_train_points_per_pq_code cdef cppclass index[IdxT](ann_index): index(const device_resources& handle, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 5b89f0d9a5..7081b65ce3 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -156,6 +156,14 @@ cdef class IndexParams: repeated calls to `extend` (extending the database). To disable this behavior and use as little GPU memory for the database as possible, set this flat to `True`. + max_train_points_per_pq_code : int, default = 256 + The max number of data points to use per PQ code during PQ codebook + training. Using more data points per PQ code may increase the + quality of PQ codebook but may also increase the build time. The + parameter is applied to both PQ codebook generation methods, i.e., + PER_SUBSPACE and PER_CLUSTER. In both cases, we will use + pq_book_size * max_train_points_per_pq_code training points to + train each codebook. """ def __init__(self, *, n_lists=1024, @@ -167,7 +175,8 @@ cdef class IndexParams: codebook_kind="subspace", force_random_rotation=False, add_data_on_build=True, - conservative_memory_allocation=False): + conservative_memory_allocation=False, + max_train_points_per_pq_code=256): self.params.n_lists = n_lists self.params.metric = _get_metric(metric) self.params.metric_arg = 0 @@ -185,6 +194,8 @@ cdef class IndexParams: self.params.add_data_on_build = add_data_on_build self.params.conservative_memory_allocation = \ conservative_memory_allocation + self.params.max_train_points_per_pq_code = \ + max_train_points_per_pq_code @property def n_lists(self): @@ -226,6 +237,9 @@ cdef class IndexParams: def conservative_memory_allocation(self): return self.params.conservative_memory_allocation + @property + def max_train_points_per_pq_code(self): + return self.params.max_train_points_per_pq_code cdef class Index: # We store a pointer to the index because it dose not have a trivial