From 2bf77dbb8b42f0c4b701b1eb2e1bb4e9815170fc Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Mon, 18 Mar 2024 10:07:35 +0100 Subject: [PATCH] Introduce internal_idx_t in ivf_pq::build() --- .../raft/neighbors/detail/ivf_pq_build.cuh | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index c85cbf9a98..91e523b94f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -61,6 +61,8 @@ namespace raft::neighbors::ivf_pq::detail { using namespace raft::spatial::knn::detail; // NOLINT +using internal_idx_t = int64_t; // The default mdspan extent type used internally. + template __launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel( T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) @@ -442,15 +444,15 @@ void train_per_subset(raft::resources const& handle, stream); // train PQ codebook for this subspace - auto sub_trainset_view = raft::make_device_matrix_view( + auto sub_trainset_view = raft::make_device_matrix_view( sub_trainset.data(), n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( + auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j, index.pq_book_size(), index.pq_len()); auto sub_labels_view = - raft::make_device_vector_view(sub_labels.data(), n_rows); - auto cluster_sizes_view = raft::make_device_vector_view( + raft::make_device_vector_view(sub_labels.data(), n_rows); + auto cluster_sizes_view = raft::make_device_vector_view( pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = kmeans_n_iters; @@ -526,16 +528,16 @@ void train_per_cluster(raft::resources const& handle, size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim()); auto pq_n_rows = uint32_t(std::min(big_enough, available_rows)); // train PQ codebook for this cluster - auto rot_vectors_view = raft::make_device_matrix_view( + auto rot_vectors_view = raft::make_device_matrix_view( rot_vectors.data(), pq_n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( + auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + static_cast(index.pq_book_size()) * static_cast(index.pq_len()) * static_cast(l), index.pq_book_size(), index.pq_len()); auto pq_labels_view = - raft::make_device_vector_view(pq_labels.data(), pq_n_rows); - auto pq_cluster_sizes_view = raft::make_device_vector_view( + raft::make_device_vector_view(pq_labels.data(), pq_n_rows); + auto pq_cluster_sizes_view = raft::make_device_vector_view( pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = kmeans_n_iters; @@ -1588,11 +1590,11 @@ void extend(raft::resources const& handle, cudaMemcpyDefault, stream)); for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( + auto batch_data_view = raft::make_device_matrix_view( + batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( new_data_labels.data() + batch.offset(), batch.size()); - auto centers_view = raft::make_device_matrix_view( + auto centers_view = raft::make_device_matrix_view( cluster_centers.data(), n_clusters, index->dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.metric = index->metric(); @@ -1768,10 +1770,10 @@ auto build(raft::resources const& handle, auto cluster_centers = cluster_centers_buf.data(); // Train balanced hierarchical kmeans clustering - auto trainset_const_view = raft::make_device_matrix_view( + auto trainset_const_view = raft::make_device_matrix_view( trainset.data(), n_rows_train, index.dim()); - auto centers_view = - raft::make_device_matrix_view(cluster_centers, index.n_lists(), index.dim()); + auto centers_view = raft::make_device_matrix_view( + cluster_centers, index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = index.metric(); @@ -1780,10 +1782,10 @@ auto build(raft::resources const& handle, // Trainset labels are needed for training PQ codebooks rmm::device_uvector labels(n_rows_train, stream, device_memory); - auto centers_const_view = raft::make_device_matrix_view( + auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); auto labels_view = - raft::make_device_vector_view(labels.data(), n_rows_train); + raft::make_device_vector_view(labels.data(), n_rows_train); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, trainset_const_view,