diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b7e362f704..2f011f2a9b 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -48,6 +49,7 @@ namespace raft::neighbors::cagra { * * The following distance metrics are supported: * - L2Expanded + * - InnerProduct * * Usage example: * @code{.cpp} @@ -79,6 +81,7 @@ template void build_knn_graph(raft::resources const& res, mdspan, row_major, accessor> dataset, raft::host_matrix_view knn_graph, + raft::distance::DistanceType metric, std::optional refine_rate = std::nullopt, std::optional build_params = std::nullopt, std::optional search_params = std::nullopt) @@ -93,7 +96,7 @@ void build_knn_graph(raft::resources const& res, dataset.data_handle(), dataset.extent(0), dataset.extent(1)); cagra::detail::build_knn_graph( - res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); + res, dataset_internal, knn_graph_internal, metric, refine_rate, build_params, search_params); } /** diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d91e45257e..c929840129 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -18,6 +18,7 @@ #include "../../cagra_types.hpp" #include "../../vpq_dataset.cuh" #include "graph_core.cuh" +#include "raft/util/cudart_utils.hpp" #include #include @@ -44,11 +45,14 @@ template void build_knn_graph(raft::resources const& res, mdspan, row_major, accessor> dataset, raft::host_matrix_view knn_graph, + raft::distance::DistanceType metric, std::optional refine_rate = std::nullopt, std::optional build_params = std::nullopt, std::optional search_params = std::nullopt) { - RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded, + std::cout << "metric from build_knn_graph" << metric<< std::endl; + RAFT_EXPECTS(!build_params || build_params->metric == metric, "Mismatch between index metric and IVF-PQ metric"); + RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || metric == distance::DistanceType::InnerProduct, "Currently only L2Expanded metric is supported"); uint32_t node_degree = knn_graph.extent(1); @@ -65,6 +69,7 @@ void build_knn_graph(raft::resources const& res, build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10; build_params->kmeans_n_iters = 25; build_params->add_data_on_build = true; + build_params->metric = metric; } // Make model name @@ -148,6 +153,8 @@ void build_knn_graph(raft::resources const& res, distances.data_handle(), batch.size(), distances.extent(1)); ivf_pq::search(res, *search_params, index, queries_view, neighbors_view, distances_view); + raft::resource::sync_stream(res); + raft::print_device_vector("distances vector", distances.data_handle(), distances.extent(1), std::cout); if constexpr (is_host_mdspan_v) { raft::copy(neighbors_host.data_handle(), neighbors.data_handle(), @@ -174,6 +181,7 @@ void build_knn_graph(raft::resources const& res, refined_neighbors_host_view, refined_distances_host_view, build_params->metric); + raft::print_host_vector("host_distances", refined_distances_host.data_handle(), top_k, std::cout); } else { auto neighbor_candidates_view = make_device_matrix_view( neighbors.data_handle(), batch.size(), gpu_top_k); @@ -197,6 +205,7 @@ void build_knn_graph(raft::resources const& res, refined_neighbors_view.size(), resource::get_cuda_stream(res)); resource::sync_stream(res); + raft::print_device_vector("device_distances", refined_distances.data_handle(), top_k, std::cout); } // omit itself & write out // TODO(tfeher): do this in parallel with GPU processing of next batch @@ -321,7 +330,7 @@ index build( raft::make_host_matrix(dataset.extent(0), intermediate_degree)); if (params.build_algo == graph_build_algo::IVF_PQ) { - build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params); + build_knn_graph(res, dataset, knn_graph->view(), params.metric, refine_rate, pq_build_params, search_params); } else { // Use nn-descent to build CAGRA knn graph diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 86a5b34e8a..c454f95d2f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -271,7 +271,7 @@ void search_main(raft::resources const& res, strided_dset->n_rows(), strided_dset->dim(), strided_dset->stride()); - + std::cout << "index.metric from search_main" << index.metric() << std::endl; search_main_core(res, params, dataset_desc, diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh index 60bcf7d007..8053074af6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh @@ -18,6 +18,7 @@ #include "compute_distance.hpp" +#include #include namespace raft::neighbors::cagra::detail { @@ -168,22 +169,13 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t( query_ptr))[device::swizzling(d / 2)]; - if (metric == raft::distance::L2Expanded) { - // Loading PQ code book in smem - diff2 -= *(reinterpret_cast(smem_pq_code_book_ptr + - (1 << PQ_BITS) * 2 * (m / 2) + - (2 * (pq_code & 0xff)))); - diff2 -= vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2]; - norm2 += diff2 * diff2; - } else { - half2 multiplier2 = *(reinterpret_cast(smem_pq_code_book_ptr + - (1 << PQ_BITS) * 2 * (m / 2) + - (2 * (pq_code & 0xff)))) + - vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2]; - norm2 -= diff2 * multiplier2; - } - pq_code >>= 8; + // Loading PQ code book in smem + diff2 -= *(reinterpret_cast( + smem_pq_code_book_ptr + (1 << PQ_BITS) * 2 * (m / 2) + (2 * (pq_code & 0xff)))); + diff2 -= vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2]; + norm2 += diff2 * diff2; } + pq_code >>= 8; } } norm += static_cast(norm2.x + norm2.y); @@ -219,17 +211,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t= dataset_dim) break; DISTANCE_T diff = query_ptr[d]; // (from smem) - if (metric == raft::distance::L2Expanded) { - // Loading PQ code book in smem - diff -= pq_scale * static_cast(pq_vals.data[m]); - diff -= vq_scale * static_cast(vq_vals[d1 / vlen].val.data[d1 % vlen]); - norm += diff * diff; - } else { - DISTANCE_T multiplier = - pq_scale * static_cast(pq_vals.data[m]) + - vq_scale * static_cast(vq_vals[d1 / vlen].val.data[d1 % vlen]); - norm -= diff * multiplier; - } + diff -= pq_scale * static_cast(pq_vals.data[m]); + diff -= vq_scale * static_cast(vq_vals[d1 / vlen].val.data[d1 % vlen]); + norm += diff * diff; } pq_code >>= 8; } @@ -245,4 +229,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { num_queries, stream); if (topk_distances_ptr) { - bool invert = this->metric == distance::InnerProduct; + bool invert = this->metric == distance::DistanceType::InnerProduct; batched_memcpy(topk_distances_ptr, topk, result_distances_ptr, @@ -982,6 +982,7 @@ struct search : search_plan_impl { num_queries, stream, invert); + raft::print_device_vector("result_distances_ptr", result_distances_ptr, topk, std::cout); } if (num_executed_iterations) { diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index a725bb9880..731dc23066 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -395,9 +395,9 @@ class AnnCagraSortTest : public ::testing::TestWithParam { if (ps.build_algo == graph_build_algo::IVF_PQ) { if (ps.host_dataset) { - cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + cagra::build_knn_graph(handle_, database_host_view, knn_graph.view(), ps.metric); } else { - cagra::build_knn_graph(handle_, database_view, knn_graph.view()); + cagra::build_knn_graph(handle_, database_view, knn_graph.view(), ps.metric); } } else { auto nn_descent_idx_params = experimental::nn_descent::index_params{};