diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index c65ac502ce..c5b52d8db2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -157,7 +157,7 @@ template void lauch_vpq_search_main_core( raft::resources const& res, - const vpq_dataset* dataset, + const vpq_dataset* vpq_dset, search_params params, raft::device_matrix_view graph, raft::device_matrix_view queries, @@ -165,8 +165,11 @@ void lauch_vpq_search_main_core( raft::device_matrix_view distances, CagraSampleFilterT sample_filter) { - if (dataset->pq_bits() == 8) { - if (dataset->pq_len() == 2) { + const float vq_scale = 1.0f; + const float pq_scale = 1.0f; + + if (vpq_dset->pq_bits() == 8) { + if (vpq_dset->pq_len() == 2) { using dataset_desc_t = cagra_q_dataset_descriptor_t; - dataset_desc_t dataset_desc(dataset->data.data_handle(), - dataset->pq_dim(), - dataset->vq_n_centers(), - dataset->vq_code_book.data_handle(), - 0, - dataset->pq_code_book.data_handle(), - 0, - dataset->n_rows(), - dataset->dim()); + dataset_desc_t dataset_desc(vpq_dset->data.data_handle(), + vpq_dset->encoded_row_length(), + vpq_dset->pq_dim(), + vpq_dset->vq_code_book.data_handle(), + vq_scale, + vpq_dset->pq_code_book.data_handle(), + pq_scale, + size_t(vpq_dset->n_rows()), + vpq_dset->dim()); search_main_core( res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter); - } else if (dataset->pq_len() == 4) { + } else if (vpq_dset->pq_len() == 4) { using dataset_desc_t = cagra_q_dataset_descriptor_t; - dataset_desc_t dataset_desc(dataset->data.data_handle(), - dataset->pq_dim(), - dataset->vq_n_centers(), - dataset->vq_code_book.data_handle(), - 0, - dataset->pq_code_book.data_handle(), - 0, - dataset->n_rows(), - dataset->dim()); + dataset_desc_t dataset_desc(vpq_dset->data.data_handle(), + vpq_dset->encoded_row_length(), + vpq_dset->pq_dim(), + vpq_dset->vq_code_book.data_handle(), + vq_scale, + vpq_dset->pq_code_book.data_handle(), + pq_scale, + size_t(vpq_dset->n_rows()), + vpq_dset->dim()); search_main_core( res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter); } else {