diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b8258297e6..2446a95ea7 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -279,62 +280,6 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } -/** - * @brief Search ANN using the constructed index. - * - * See the [cagra::build](#cagra::build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - raft::neighbors::filtering::none_cagra_sample_filter()); -} - /** * @brief Search ANN using the constructed index with the given sample filter. * @@ -401,8 +346,63 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); + // n_rows has the same type as the dataset index (the array extents type) + using ds_idx_type = decltype(idx.dataset().n_rows()); + // Dispatch search parameters based on the dataset kind. + if (auto* strided_dset = dynamic_cast*>(&idx.dataset()); + strided_dset != nullptr) { + // Search using a plain (strided) row-major dataset + return cagra::detail::search_main( + res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); + } else if (auto* vpq_dset = + dynamic_cast*>(&idx.dataset()); + vpq_dset != nullptr) { + // Search using a compressed dataset + RAFT_FAIL("FP32 VPQ dataset support is coming soon"); + } else if (auto* vpq_dset = + dynamic_cast*>(&idx.dataset()); + vpq_dset != nullptr) { + // Search using a compressed dataset + RAFT_FAIL("FP16 VPQ dataset support is coming soon"); + } else if (auto* empty_dset = dynamic_cast*>(&idx.dataset()); + empty_dset != nullptr) { + // Forgot to add a dataset. + RAFT_FAIL( + "Attempted to search without a dataset. Please call index.update_dataset(...) first."); + } else { + // This is a logic error. + RAFT_FAIL("Unrecognized dataset format"); + } +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter; + return cagra::search_with_filtering( + res, params, idx, queries, neighbors, distances, none_filter_type{}); } /** @} */ // end group cagra