Skip to content

Commit

Permalink
Fix lauch_vpq_search_main_core
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Mar 13, 2024
1 parent 5d037b3 commit 239b29a
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,19 @@ template <class T,
class CagraSampleFilterT>
void lauch_vpq_search_main_core(
raft::resources const& res,
const vpq_dataset<DatasetT, DatasetIdxT>* dataset,
const vpq_dataset<DatasetT, DatasetIdxT>* vpq_dset,
search_params params,
raft::device_matrix_view<const InternalIdxT, int64_t, row_major> graph,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<InternalIdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter)
{
if (dataset->pq_bits() == 8) {
if (dataset->pq_len() == 2) {
const float vq_scale = 1.0f;
const float pq_scale = 1.0f;

if (vpq_dset->pq_bits() == 8) {
if (vpq_dset->pq_len() == 2) {
using dataset_desc_t = cagra_q_dataset_descriptor_t<T,
DatasetT,
8 /*PQ bit*/,
Expand All @@ -175,18 +178,18 @@ void lauch_vpq_search_main_core(
DistanceT,
InternalIdxT,
0>;
dataset_desc_t dataset_desc(dataset->data.data_handle(),
dataset->pq_dim(),
dataset->vq_n_centers(),
dataset->vq_code_book.data_handle(),
0,
dataset->pq_code_book.data_handle(),
0,
dataset->n_rows(),
dataset->dim());
dataset_desc_t dataset_desc(vpq_dset->data.data_handle(),
vpq_dset->encoded_row_length(),
vpq_dset->pq_dim(),
vpq_dset->vq_code_book.data_handle(),
vq_scale,
vpq_dset->pq_code_book.data_handle(),
pq_scale,
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
} else if (dataset->pq_len() == 4) {
} else if (vpq_dset->pq_len() == 4) {
using dataset_desc_t = cagra_q_dataset_descriptor_t<T,
DatasetT,
8 /*PQ bit*/,
Expand All @@ -195,15 +198,15 @@ void lauch_vpq_search_main_core(
DistanceT,
InternalIdxT,
0>;
dataset_desc_t dataset_desc(dataset->data.data_handle(),
dataset->pq_dim(),
dataset->vq_n_centers(),
dataset->vq_code_book.data_handle(),
0,
dataset->pq_code_book.data_handle(),
0,
dataset->n_rows(),
dataset->dim());
dataset_desc_t dataset_desc(vpq_dset->data.data_handle(),
vpq_dset->encoded_row_length(),
vpq_dset->pq_dim(),
vpq_dset->vq_code_book.data_handle(),
vq_scale,
vpq_dset->pq_code_book.data_handle(),
pq_scale,
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
} else {
Expand Down

0 comments on commit 239b29a

Please sign in to comment.