diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b7e362f704..5263ef73e7 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -48,13 +49,14 @@ namespace raft::neighbors::cagra { * * The following distance metrics are supported: * - L2Expanded + * - InnerProduct * * Usage example: * @code{.cpp} * using namespace raft::neighbors; - * // use default index parameters - * ivf_pq::index_params build_params; - * ivf_pq::search_params search_params + * // use default index parameters based on shape of the dataset + * ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset); + * ivf_pq::search_params search_params; * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d63f865c39..40dcf68e68 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -50,8 +51,9 @@ void build_knn_graph(raft::resources const& res, std::optional build_params = std::nullopt, std::optional search_params = std::nullopt) { - RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded, - "Currently only L2Expanded metric is supported"); + RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded || + build_params->metric == distance::DistanceType::InnerProduct, + "Currently only L2Expanded or InnerProduct metric are supported"); uint32_t node_degree = knn_graph.extent(1); common::nvtx::range fun_scope("cagra::build_graph(%zu, %zu, %u)", @@ -59,15 +61,7 @@ void build_knn_graph(raft::resources const& res, size_t(dataset.extent(1)), node_degree); - if (!build_params) { - build_params = ivf_pq::index_params{}; - build_params->n_lists = dataset.extent(0) < 4 * 2500 ? 4 : (uint32_t)(dataset.extent(0) / 2500); - build_params->pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2); - build_params->pq_bits = 8; - build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10; - build_params->kmeans_n_iters = 25; - build_params->add_data_on_build = true; - } + if (!build_params) { build_params = ivf_pq::index_params::from_dataset(dataset); } // Make model name const std::string model_name = [&]() { @@ -324,8 +318,10 @@ index build( if (params.build_algo == graph_build_algo::IVF_PQ) { build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params); - } else { + RAFT_EXPECTS( + params.metric == raft::distance::DistanceType::L2Expanded, + "L2Expanded is the only distance metrics supported for CAGRA build with nn_descent"); // Use nn-descent to build CAGRA knn graph if (!nn_descent_params) { nn_descent_params = experimental::nn_descent::index_params(); @@ -348,6 +344,8 @@ index build( // Construct an index from dataset and optimized knn graph. if (construct_index_with_dataset) { if (params.compression.has_value()) { + RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded, + "VPQ compression is only supported with L2Expanded distance mertric"); index idx(res, params.metric); idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); idx.update_dataset( diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index b9edbbfc4a..67fad2e46a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -87,7 +88,8 @@ void search_main_core( raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) + CagraSampleFilterT sample_filter = CagraSampleFilterT(), + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) { RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", static_cast(dataset_desc.size), @@ -112,7 +114,7 @@ void search_main_core( using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; std::unique_ptr> plan = factory::create( - res, params, dataset_desc.dim, graph.extent(1), topk); + res, params, dataset_desc.dim, graph.extent(1), topk, metric); plan->check(topk); @@ -163,7 +165,8 @@ void launch_vpq_search_main_core( raft::device_matrix_view queries, raft::device_matrix_view neighbors, raft::device_matrix_view distances, - CagraSampleFilterT sample_filter) + CagraSampleFilterT sample_filter, + const raft::distance::DistanceType metric) { RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now"); RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4, @@ -192,7 +195,7 @@ void launch_vpq_search_main_core( size_t(vpq_dset->n_rows()), vpq_dset->dim()); search_main_core( - res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter); + res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric); } else if (vpq_dset->pq_len() == 4) { using dataset_desc_t = cagra_q_dataset_descriptor_tn_rows()), vpq_dset->dim()); search_main_core( - res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter); + res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric); } else { RAFT_FAIL("Subspace dimension must be 2 or 4"); } @@ -268,9 +271,15 @@ void search_main(raft::resources const& res, strided_dset->n_rows(), strided_dset->dim(), strided_dset->stride()); - - search_main_core( - res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter); + search_main_core(res, + params, + dataset_desc, + graph_internal, + queries, + neighbors, + distances, + sample_filter, + index.metric()); } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { // Search using a compressed dataset @@ -278,7 +287,15 @@ void search_main(raft::resources const& res, } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { launch_vpq_search_main_core( - res, vpq_dset, params, graph_internal, queries, neighbors, distances, sample_filter); + res, + vpq_dset, + params, + graph_internal, + queries, + neighbors, + distances, + sample_filter, + index.metric()); } else if (auto* empty_dset = dynamic_cast*>(&index.data()); empty_dset != nullptr) { // Forgot to add a dataset. diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 49e14be73d..80ee7a36f1 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -19,6 +19,8 @@ #include "hashmap.hpp" #include "utils.hpp" +#include +#include #include #include @@ -54,6 +56,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( const uint32_t num_seeds, INDEX_T* const visited_hash_ptr, const uint32_t hash_bitlen, + const raft::distance::DistanceType metric, const uint32_t block_id = 0, const uint32_t num_blocks = 1) { @@ -78,8 +81,22 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( } } - const auto norm2 = dataset_desc.template compute_similarity( - query_buffer, seed_index, valid_i); + DISTANCE_T norm2; + switch (metric) { + case raft::distance::L2Expanded: + norm2 = dataset_desc.template compute_similarity( + query_buffer, seed_index, valid_i); + break; + case raft::distance::InnerProduct: + norm2 = dataset_desc.template compute_similarity( + query_buffer, seed_index, valid_i); + break; + default: break; + } if (valid_i && (norm2 < best_norm2_team_local)) { best_norm2_team_local = norm2; @@ -121,7 +138,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes( const std::uint32_t hash_bitlen, const INDEX_T* const parent_indices, const INDEX_T* const internal_topk_list, - const std::uint32_t search_width) + const std::uint32_t search_width, + const raft::distance::DistanceType metric) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; const INDEX_T invalid_index = utils::get_max_value(); @@ -153,8 +171,22 @@ _RAFT_DEVICE void compute_distance_to_child_nodes( INDEX_T child_id = invalid_index; if (valid_i) { child_id = result_child_indices_ptr[i]; } - const auto norm2 = dataset_desc.template compute_similarity( - query_buffer, child_id, child_id != invalid_index); + DISTANCE_T norm2; + switch (metric) { + case raft::distance::L2Expanded: + norm2 = + dataset_desc + .template compute_similarity( + query_buffer, child_id, child_id != invalid_index); + break; + case raft::distance::InnerProduct: + norm2 = dataset_desc.template compute_similarity( + query_buffer, child_id, child_id != invalid_index); + break; + default: break; + } // Store the distance const unsigned lane_id = threadIdx.x % TEAM_SIZE; @@ -220,7 +252,22 @@ struct standard_dataset_descriptor_t } } - template + template + std::enable_if_t __device__ + dist_op(T a, T b) const + { + T diff = a - b; + return diff * diff; + } + + template + std::enable_if_t __device__ + dist_op(T a, T b) const + { + return -a * b; + } + + template __device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr, const INDEX_T dataset_i, const bool valid) const @@ -252,9 +299,9 @@ struct standard_dataset_descriptor_t // because: // - Above the last element (dataset_dim-1), the query array is filled with zeros. // - The data buffer has to be also padded with zeros. - DISTANCE_T diff = query_ptr[device::swizzling(kv)]; - diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].val.data[v]); - norm2 += diff * diff; + DISTANCE_T d = query_ptr[device::swizzling(kv)]; + norm2 += dist_op( + d, spatial::knn::detail::utils::mapping{}(dl_buff[e].val.data[v])); } } } diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh index e73d24bfb6..c922a0d7f4 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh @@ -18,6 +18,7 @@ #include "compute_distance.hpp" +#include #include namespace raft::neighbors::cagra::detail { @@ -112,7 +113,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t + template __device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr, const INDEX_T node_id, const bool valid) const @@ -227,4 +228,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t>( new single_cta_search:: search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric)); } else if (plan.algo == search_algo::MULTI_CTA) { return std::unique_ptr>( new multi_cta_search:: search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric)); } else { return std::unique_ptr>( new multi_kernel_search:: search( - res, plan, plan.dim, plan.graph_degree, plan.topk)); + res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric)); } } }; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 8192b1ae51..4b979bcae8 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -24,11 +24,14 @@ #include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible #include "utils.hpp" +#include #include #include #include #include #include +#include +#include #include #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp @@ -96,8 +99,10 @@ struct search : public search_plan_impl { search_params params, int64_t dim, int64_t graph_degree, - uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk), + uint32_t topk, + raft::distance::DistanceType metric) + : search_plan_impl( + res, params, dim, graph_degree, topk, metric), intermediate_indices(0, resource::get_cuda_stream(res)), intermediate_distances(0, resource::get_cuda_stream(res)), topk_workspace(0, resource::get_cuda_stream(res)) @@ -235,6 +240,7 @@ struct search : public search_plan_impl { min_iterations, max_iterations, sample_filter, + this->metric, stream); RAFT_CUDA_TRY(cudaPeekAtLastError()); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index 50f9e69593..35f4f0e1c9 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -54,6 +54,7 @@ void select_and_run( size_t min_iterations, size_t max_iterations, SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric, cudaStream_t stream) RAFT_EXPLICIT; #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -88,6 +89,7 @@ void select_and_run( size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); instantiate_kernel_selection( @@ -172,6 +174,7 @@ instantiate_kernel_selection( size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); instantiate_q_kernel_selection( diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 48c22d9d14..cfbb1e100c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -149,7 +150,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( const uint32_t min_iteration, const uint32_t max_iteration, uint32_t* const num_executed_iterations, /* stats */ - SAMPLE_FILTER_T sample_filter) + SAMPLE_FILTER_T sample_filter, + const raft::distance::DistanceType metric) { using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; @@ -227,6 +229,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; uint32_t block_id = cta_id + (num_cta_per_query * query_id); uint32_t num_blocks = num_cta_per_query * num_queries; + device::compute_distance_to_random_nodes(result_indices_buffer, result_distances_buffer, query_buffer, @@ -238,6 +241,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( num_seeds, local_visited_hashmap_ptr, hash_bitlen, + metric, block_id, num_blocks); __syncthreads(); @@ -282,7 +286,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( hash_bitlen, parent_indices_buffer, result_indices_buffer, - search_width); + search_width, + metric); _CLK_REC(clk_compute_distance); __syncthreads(); @@ -459,6 +464,7 @@ void select_and_run( size_t min_iterations, size_t max_iterations, SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric, cudaStream_t stream) { auto kernel = @@ -484,6 +490,7 @@ void select_and_run( num_cta_per_query, num_queries, smem_size); + kernel<<>>(topk_indices_ptr, topk_distances_ptr, dataset_desc, @@ -501,7 +508,8 @@ void select_and_run( min_iterations, max_iterations, num_executed_iterations, - sample_filter); + sample_filter, + metric); } } // namespace multi_cta_search diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 10788da432..31c4bc5dca 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -100,7 +101,8 @@ RAFT_KERNEL random_pickup_kernel( typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] const std::uint32_t ldr, // (*) ldr >= num_pickup typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen) + const std::uint32_t hash_bitlen, + const raft::distance::DistanceType metric) { using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; @@ -137,8 +139,22 @@ RAFT_KERNEL random_pickup_kernel( device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_desc.size; } - const auto norm2 = dataset_desc.template compute_similarity( - query_buffer, seed_index, true); + DISTANCE_T norm2; + switch (metric) { + case distance::DistanceType::L2Expanded: + norm2 = dataset_desc.template compute_similarity( + query_buffer, seed_index, true); + break; + case distance::DistanceType::InnerProduct: + norm2 = dataset_desc.template compute_similarity( + query_buffer, seed_index, true); + break; + default: break; + } if (norm2 < best_norm2_team_local) { best_norm2_team_local = norm2; @@ -175,6 +191,7 @@ void random_pickup( const std::size_t ldr, // (*) ldr >= num_pickup typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] const std::uint32_t hash_bitlen, + const raft::distance::DistanceType metric, cudaStream_t const cuda_stream = 0) { const auto block_size = 256u; @@ -198,7 +215,8 @@ void random_pickup( result_distances_ptr, ldr, visited_hashmap_ptr, - hash_bitlen); + hash_bitlen, + metric); } template @@ -325,7 +343,8 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldd] typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - SAMPLE_FILTER_T sample_filter) + SAMPLE_FILTER_T sample_filter, + const raft::distance::DistanceType metric) { using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; @@ -371,8 +390,22 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( const auto compute_distance_flag = hashmap::insert( visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id); - const auto norm2 = dataset_desc.template compute_similarity( - query_buffer, child_id, compute_distance_flag); + DISTANCE_T norm2; + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + norm2 = dataset_desc.template compute_similarity( + query_buffer, child_id, compute_distance_flag); + break; + case raft::distance::DistanceType::InnerProduct: + norm2 = dataset_desc.template compute_similarity( + query_buffer, child_id, compute_distance_flag); + break; + default: break; + } if (compute_distance_flag) { if (threadIdx.x % TEAM_SIZE == 0) { @@ -421,6 +454,7 @@ void compute_distance_to_child_nodes( typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree SAMPLE_FILTER_T sample_filter, + const raft::distance::DistanceType metric, cudaStream_t cuda_stream = 0) { const auto block_size = 128; @@ -452,7 +486,8 @@ void compute_distance_to_child_nodes( result_indices_ptr, result_distances_ptr, ldd, - sample_filter); + sample_filter, + metric); } template @@ -660,8 +695,10 @@ struct search : search_plan_impl { search_params params, int64_t dim, int64_t graph_degree, - uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk), + uint32_t topk, + raft::distance::DistanceType metric) + : search_plan_impl( + res, params, dim, graph_degree, topk, metric), result_indices(0, resource::get_cuda_stream(res)), result_distances(0, resource::get_cuda_stream(res)), parent_node_list(0, resource::get_cuda_stream(res)), @@ -835,6 +872,7 @@ struct search : search_plan_impl { result_buffer_allocation_size, hashmap.data(), hash_bitlen, + this->metric, stream); unsigned iter = 0; @@ -904,6 +942,7 @@ struct search : search_plan_impl { result_distances.data() + itopk_size, result_buffer_allocation_size, sample_filter, + this->metric, stream); iter++; @@ -1020,8 +1059,10 @@ struct search(res, params, dim, graph_degree, topk) + uint32_t topk, + raft::distance::DistanceType metric) + : search_plan_impl( + res, params, dim, graph_degree, topk, metric) { THROW("The multi-kernel mode does not support VPQ"); } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index be5ac0554f..b35d96e9f5 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -35,8 +36,13 @@ struct search_plan_impl_base : public search_params { int64_t dim; int64_t graph_degree; uint32_t topk; - search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) - : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk) + raft::distance::DistanceType metric; + search_plan_impl_base(search_params params, + int64_t dim, + int64_t graph_degree, + uint32_t topk, + raft::distance::DistanceType metric) + : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk), metric(metric) { set_dataset_block_and_team_size(dim); if (algo == search_algo::AUTO) { @@ -97,8 +103,9 @@ struct search_plan_impl : public search_plan_impl_base { search_params params, int64_t dim, int64_t graph_degree, - uint32_t topk) - : search_plan_impl_base(params, dim, graph_degree, topk), + uint32_t topk, + raft::distance::DistanceType metric) + : search_plan_impl_base(params, dim, graph_degree, topk, metric), hashmap(0, resource::get_cuda_stream(res)), num_executed_iterations(0, resource::get_cuda_stream(res)), dev_seed(0, resource::get_cuda_stream(res)), diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 4430b929fb..0771652787 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -94,8 +94,10 @@ struct search : search_plan_impl { search_params params, int64_t dim, int64_t graph_degree, - uint32_t topk) - : search_plan_impl(res, params, dim, graph_degree, topk) + uint32_t topk, + raft::distance::DistanceType metric) + : search_plan_impl( + res, params, dim, graph_degree, topk, metric) { set_params(res); } @@ -244,6 +246,7 @@ struct search : search_plan_impl { min_iterations, max_iterations, sample_filter, + this->metric, stream); } }; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index a836334667..510219ab5d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -54,6 +54,7 @@ void select_and_run( // raft::resources const& res, size_t min_iterations, size_t max_iterations, SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric, cudaStream_t stream) RAFT_EXPLICIT; #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -90,6 +91,7 @@ void select_and_run( // raft::resources const& res, size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); instantiate_single_cta_select_and_run( @@ -175,6 +177,7 @@ instantiate_single_cta_select_and_run( size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); instantiate_q_single_cta_select_and_run( diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index a697f9512c..e8104bd6f6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -485,7 +486,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( const std::uint32_t hash_bitlen, const std::uint32_t small_hash_bitlen, const std::uint32_t small_hash_reset_interval, - SAMPLE_FILTER_T sample_filter) + SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric) { using LOAD_T = device::LOAD_128BIT_T; @@ -581,7 +583,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( local_seed_ptr, num_seeds, local_visited_hashmap_ptr, - hash_bitlen); + hash_bitlen, + metric); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -718,7 +721,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( hash_bitlen, parent_list_buffer, result_indices_buffer, - search_width); + search_width, + metric); __syncthreads(); _CLK_REC(clk_compute_distance); @@ -930,6 +934,7 @@ void select_and_run( size_t min_iterations, size_t max_iterations, SAMPLE_FILTER_T sample_filter, + raft::distance::DistanceType metric, cudaStream_t stream) { auto kernel = @@ -962,7 +967,8 @@ void select_and_run( hash_bitlen, small_hash_bitlen, small_hash_reset_interval, - sample_filter); + sample_filter, + metric); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // namespace single_cta_search diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 81e2886b18..3ee350c6fb 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -104,6 +104,36 @@ struct index_params : ann::index_params { * flag to `true` if you prefer to use as little GPU memory for the database as possible. */ bool conservative_memory_allocation = false; + + /** + * Creates index_params based on shape of the input dataset. + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * raft::resources res; + * // create index_params for a [N. D] dataset and have InnerProduct as the distance metric + * auto dataset = raft::make_device_matrix(res, N, D); + * ivf_pq::index_params index_params = + * ivf_pq::index_params::from_dataset(dataset.view(), raft::distance::InnerProduct); + * // modify/update index_params as needed + * index_params.add_data_on_build = true; + * @endcode + */ + template + static index_params from_dataset( + mdspan, row_major, Accessor> dataset, + raft::distance::DistanceType metric = raft::distance::L2Expanded) + { + index_params params; + params.n_lists = + dataset.extent(0) < 4 * 2500 ? 4 : static_cast(std::sqrt(dataset.extent(0))); + params.pq_dim = + round_up_safe(static_cast(dataset.extent(1) / 4), static_cast(8)); + params.pq_bits = 8; + params.kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 0.1; + params.metric = metric; + return params; + } }; struct search_params : ann::search_params { diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 179bf8f20f..542fdaad1f 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -44,6 +44,7 @@ namespace raft::neighbors::cagra::detail::multi_cta_search { size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); #define COMMA , diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 7fb705a2d2..855b104670 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -45,6 +45,7 @@ namespace raft::neighbors::cagra::detail::single_cta_search { size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); #define COMMA , diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 7278f71a24..715a94403f 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -85,25 +86,49 @@ void RandomSuffle(raft::host_matrix_view index) template testing::AssertionResult CheckOrder(raft::host_matrix_view index_test, - raft::host_matrix_view dataset) + raft::host_matrix_view dataset, + raft::distance::DistanceType metric) { for (IdxT i = 0; i < index_test.extent(0); i++) { const DatatT* const base_vec = dataset.data_handle() + i * dataset.extent(1); const IdxT* const index_row = index_test.data_handle() + i * index_test.extent(1); - DistanceT prev_distance = 0; + DistanceT prev_distance = metric == raft::distance::DistanceType::L2Expanded + ? 0 + : std::numeric_limits::max(); for (unsigned j = 0; j < index_test.extent(1) - 1; j++) { const DatatT* const target_vec = dataset.data_handle() + index_row[j] * dataset.extent(1); DistanceT distance = 0; - for (unsigned l = 0; l < dataset.extent(1); l++) { - const auto diff = - static_cast(target_vec[l]) - static_cast(base_vec[l]); - distance += diff * diff; - } - if (prev_distance > distance) { - return testing::AssertionFailure() - << "Wrong index order (row = " << i << ", neighbor_id = " << j - << "). (distance[neighbor_id-1] = " << prev_distance - << "should be larger than distance[neighbor_id] = " << distance << ")"; + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + for (unsigned l = 0; l < dataset.extent(1); l++) { + const auto diff = + static_cast(target_vec[l]) - static_cast(base_vec[l]); + distance += diff * diff; + } + if (prev_distance > distance) { + return testing::AssertionFailure() + << "Wrong index order (row = " << i << ", neighbor_id = " << j + << "). (distance[neighbor_id-1] = " << prev_distance + << "should be lesser than distance[neighbor_id] = " << distance << ")"; + } + break; + case raft::distance::DistanceType::InnerProduct: + for (unsigned l = 0; l < dataset.extent(1); l++) { + const auto prod = + static_cast(target_vec[l]) * static_cast(base_vec[l]); + distance += prod; + } + if (prev_distance < distance) { + return testing::AssertionFailure() + << "Wrong index order (row = " << i << ", neighbor_id = " << j + << "). (distance[neighbor_id-1] = " << prev_distance + << "should be greater than distance[neighbor_id] = " << distance << ")"; + } + break; + default: + return testing::AssertionFailure() + << "Distance metric " << metric + << " not supported. Only L2Expanded and InnerProduct are supported"; } prev_distance = distance; } @@ -221,6 +246,11 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { + // TODO (tarang-jain): remove when NN Descent index building support InnerProduct. Reference + // issue: https://github.com/rapidsai/raft/issues/2276 + if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) + GTEST_SKIP(); + size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size); @@ -301,6 +331,7 @@ class AnnCagraTest : public ::testing::TestWithParam { // print_vector("T", distances_naive.data() + i * ps.k, ps.k, std::cout); // print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout); // } + double min_recall = ps.min_recall; EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, @@ -368,6 +399,9 @@ class AnnCagraSortTest : public ::testing::TestWithParam { protected: void testCagraSort() { + if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) + GTEST_SKIP(); + { // Step 1: Build a sorted KNN graph by CAGRA knn build auto database_view = raft::make_device_matrix_view( @@ -383,10 +417,13 @@ class AnnCagraSortTest : public ::testing::TestWithParam { raft::make_host_matrix(ps.n_rows, index_params.intermediate_graph_degree); if (ps.build_algo == graph_build_algo::IVF_PQ) { + auto build_params = ivf_pq::index_params::from_dataset(database_view, ps.metric); if (ps.host_dataset) { - cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + cagra::build_knn_graph( + handle_, database_host_view, knn_graph.view(), 2, build_params); } else { - cagra::build_knn_graph(handle_, database_view, knn_graph.view()); + cagra::build_knn_graph( + handle_, database_view, knn_graph.view(), 2, build_params); } } else { auto nn_descent_idx_params = experimental::nn_descent::index_params{}; @@ -403,14 +440,16 @@ class AnnCagraSortTest : public ::testing::TestWithParam { } handle_.sync_stream(); - ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view())); + ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view(), ps.metric)); - RandomSuffle(knn_graph.view()); + if (ps.metric != raft::distance::DistanceType::InnerProduct) { + RandomSuffle(knn_graph.view()); - cagra::sort_knn_graph(handle_, database_view, knn_graph.view()); - handle_.sync_stream(); + cagra::sort_knn_graph(handle_, database_view, knn_graph.view()); + handle_.sync_stream(); - ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view())); + ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view(), ps.metric)); + } } } @@ -453,6 +492,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { protected: void testCagraFilter() { + if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) + GTEST_SKIP(); + size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size); @@ -575,6 +617,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { void testCagraRemoved() { + if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) + GTEST_SKIP(); + size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size); @@ -741,7 +786,7 @@ inline std::vector generate_inputs() {0}, {256}, {1}, - {raft::distance::DistanceType::L2Expanded}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false}, {true}, {0.995}); @@ -757,7 +802,7 @@ inline std::vector generate_inputs() {0}, {256}, {1}, - {raft::distance::DistanceType::L2Expanded}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false}, {true}, {99. / 100} @@ -776,7 +821,7 @@ inline std::vector generate_inputs() {0}, {64}, {1}, - {raft::distance::DistanceType::L2Expanded}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false}, {true}, {0.995}); @@ -792,7 +837,7 @@ inline std::vector generate_inputs() {0, 4, 8, 16, 32}, // team_size {64}, {1}, - {raft::distance::DistanceType::L2Expanded}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false}, {false}, {0.995}); @@ -809,7 +854,7 @@ inline std::vector generate_inputs() {0}, // team_size {32, 64, 128, 256, 512, 768}, {1}, - {raft::distance::DistanceType::L2Expanded}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false}, {true}, {0.995}); @@ -826,27 +871,27 @@ inline std::vector generate_inputs() {0}, // team_size {64}, {1}, - {raft::distance::DistanceType::L2Expanded}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}, {false}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {20000}, - {32}, - {2048}, // k - {graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, - {4096}, // itopk_size - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {false}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {20000}, + {32}, + {2048}, // k + {graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, + {4096}, // itopk_size + {1}, + {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, + {false}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); return inputs; diff --git a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh index 5cca6d561a..412e71bff1 100644 --- a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh +++ b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh @@ -51,6 +51,7 @@ namespace multi_cta_search { size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); instantiate_kernel_selection(standard_dataset_descriptor_t, @@ -118,6 +119,7 @@ namespace single_cta_search { size_t min_iterations, \ size_t max_iterations, \ SAMPLE_FILTER_T sample_filter, \ + raft::distance::DistanceType metric, \ cudaStream_t stream); instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index df31d2560b..0e488a51ca 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -97,9 +97,11 @@ cdef class IndexParams: Parameters ---------- metric : string denoting the metric type, default="sqeuclidean" - Valid values for metric: ["sqeuclidean"], where + Valid values for metric: ["sqeuclidean", "inner_product"], where - sqeuclidean is the euclidean distance without the square root operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2 + - inner_product is the dot product between two vectors i.e.: + distance(a, b) = \\sum_i (a_i * b_i) intermediate_graph_degree : int, default = 128 graph_degree : int, default = 64 @@ -355,6 +357,7 @@ def build(IndexParams index_params, dataset, handle=None): The following distance metrics are supported: - L2 + - inner_product Parameters ---------- diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index be53b33da3..ef8e54917a 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -29,7 +29,7 @@ def run_cagra_build_search_test( n_queries=100, k=10, dtype=np.float32, - metric="euclidean", + metric="sqeuclidean", intermediate_graph_degree=128, graph_degree=64, build_algo="ivf_pq", @@ -143,7 +143,7 @@ def test_cagra_dataset_dtype_host_device( "graph_degree": 32, "add_data_on_build": True, "k": 1, - "metric": "euclidean", + "metric": "sqeuclidean", "build_algo": "ivf_pq", }, { @@ -159,7 +159,7 @@ def test_cagra_dataset_dtype_host_device( "graph_degree": 32, "add_data_on_build": True, "k": 10, - "metric": "inner_product", + "metric": "sqeuclidean", "build_algo": "nn_descent", }, ],