Skip to content

Commit

Permalink
Add a stub for the search function
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 11, 2024
1 parent 999d343 commit 4498a22
Showing 1 changed file with 58 additions and 58 deletions.
116 changes: 58 additions & 58 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

#include <rmm/cuda_stream_view.hpp>

Expand Down Expand Up @@ -279,62 +280,6 @@ index<T, IdxT> build(raft::resources const& res,
return detail::build<T, IdxT, Accessor>(res, params, dataset);
}

/**
* @brief Search ANN using the constructed index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");
RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

using internal_IdxT = typename std::make_unsigned<IdxT>::type;
auto queries_internal = raft::make_device_matrix_view<const T, int64_t, row_major>(
queries.data_handle(), queries.extent(0), queries.extent(1));
auto neighbors_internal = raft::make_device_matrix_view<internal_IdxT, int64_t, row_major>(
reinterpret_cast<internal_IdxT*>(neighbors.data_handle()),
neighbors.extent(0),
neighbors.extent(1));
auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

cagra::detail::search_main<T,
internal_IdxT,
decltype(raft::neighbors::filtering::none_cagra_sample_filter()),
IdxT>(res,
params,
idx,
queries_internal,
neighbors_internal,
distances_internal,
raft::neighbors::filtering::none_cagra_sample_filter());
}

/**
* @brief Search ANN using the constructed index with the given sample filter.
*
Expand Down Expand Up @@ -401,8 +346,63 @@ void search_with_filtering(raft::resources const& res,
auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
// n_rows has the same type as the dataset index (the array extents type)
using ds_idx_type = decltype(idx.dataset().n_rows());
// Dispatch search parameters based on the dataset kind.
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&idx.dataset());
strided_dset != nullptr) {
// Search using a plain (strided) row-major dataset
return cagra::detail::search_main<T, internal_IdxT, CagraSampleFilterT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
} else if (auto* vpq_dset =
dynamic_cast<const strided_dataset<float, ds_idx_type>*>(&idx.dataset());
vpq_dset != nullptr) {
// Search using a compressed dataset
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset =
dynamic_cast<const strided_dataset<half, ds_idx_type>*>(&idx.dataset());
vpq_dset != nullptr) {
// Search using a compressed dataset
RAFT_FAIL("FP16 VPQ dataset support is coming soon");
} else if (auto* empty_dset = dynamic_cast<const empty_dataset<ds_idx_type>*>(&idx.dataset());
empty_dset != nullptr) {
// Forgot to add a dataset.
RAFT_FAIL(
"Attempted to search without a dataset. Please call index.update_dataset(...) first.");
} else {
// This is a logic error.
RAFT_FAIL("Unrecognized dataset format");
}
}

/**
* @brief Search ANN using the constructed index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
{
using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter;
return cagra::search_with_filtering<T, IdxT, none_filter_type>(
res, params, idx, queries, neighbors, distances, none_filter_type{});
}

/** @} */ // end group cagra
Expand Down

0 comments on commit 4498a22

Please sign in to comment.