diff --git a/cpp/template/src/ivf_flat_example.cu b/cpp/template/src/ivf_flat_example.cu index b891c3cf4a..1e4b790c0c 100644 --- a/cpp/template/src/ivf_flat_example.cu +++ b/cpp/template/src/ivf_flat_example.cu @@ -23,11 +23,38 @@ #include #include #include +#include #include #include #include +// Copy the results to host and print a few samples +void print_results(raft::device_resources const& dev_resources, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + int64_t topk = neighbors.extent(1); + auto neighbors_host = raft::make_host_matrix(neighbors.extent(0), topk); + auto distances_host = raft::make_host_matrix(distances.extent(0), topk); + + cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources); + + raft::copy(neighbors_host.data_handle(), neighbors.data_handle(), neighbors.size(), stream); + raft::copy(distances_host.data_handle(), distances.data_handle(), distances.size(), stream); + + // The calls to ivf_flat::search and raft::copy are asynchronous. + // We need to sync the stream before accessing the data. + raft::resource::sync_stream(dev_resources, stream); + + for (int query_id = 0; query_id < 2; query_id++) { + std::cout << "Query " << query_id << " neighbor indices: "; + raft::print_host_vector("", &neighbors_host(query_id, 0), topk, std::cout); + std::cout << "Query " << query_id << " neighbor distances: "; + raft::print_host_vector("", &distances_host(query_id, 0), topk, std::cout); + } +} + void ivf_flat_build_search_simple(raft::device_resources const& dev_resources, raft::device_matrix_view dataset, raft::device_matrix_view queries) @@ -46,7 +73,7 @@ void ivf_flat_build_search_simple(raft::device_resources const& dev_resources, << index.size() << std::endl; // Create output arrays. 
- int64_t topk = 12; + int64_t topk = 10; int64_t n_queries = queries.extent(0); auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); @@ -58,12 +85,18 @@ void ivf_flat_build_search_simple(raft::device_resources const& dev_resources, // Search K nearest neighbors for each of the queries. ivf_flat::search( dev_resources, search_params, index, queries, neighbors.view(), distances.view()); + + // The call to ivf_flat::search is asynchronous. Before accessing the data, sync by calling + // raft::resource::sync_stream(dev_resources); + + print_results(dev_resources, neighbors.view(), distances.view()); } /** Subsample the dataset to create a training set*/ raft::device_matrix subsample( raft::device_resources const& dev_resources, raft::device_matrix_view dataset, + raft::device_vector_view data_indices, float fraction) { int64_t n_samples = dataset.extent(0); @@ -73,19 +106,10 @@ raft::device_matrix subsample( int seed = 137; raft::random::RngState rng(seed); - auto data_indices = raft::make_device_vector(dev_resources, n_samples); auto train_indices = raft::make_device_vector(dev_resources, n_train); - thrust::counting_iterator first(0); - thrust::device_ptr ptr(data_indices.data_handle()); - thrust::copy(raft::resource::get_thrust_policy(dev_resources), first, first + n_samples, ptr); - - raft::random::sample_without_replacement(dev_resources, - rng, - raft::make_const_mdspan(data_indices.view()), - std::nullopt, - train_indices.view(), - std::nullopt); + raft::random::sample_without_replacement( + dev_resources, rng, data_indices, std::nullopt, train_indices.view(), std::nullopt); raft::matrix::copy_rows( dev_resources, dataset, trainset.view(), raft::make_const_mdspan(train_indices.view())); @@ -99,8 +123,16 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources, { using namespace raft::neighbors; + // Define dataset indices. 
+ auto data_indices = raft::make_device_vector(dev_resources, dataset.extent(0)); + thrust::counting_iterator first(0); + thrust::device_ptr ptr(data_indices.data_handle()); + thrust::copy( + raft::resource::get_thrust_policy(dev_resources), first, first + dataset.extent(0), ptr); + // Sub-sample the dataset to create a training set. - auto trainset = subsample(dev_resources, dataset, 0.1); + auto trainset = + subsample(dev_resources, dataset, raft::make_const_mdspan(data_indices.view()), 0.1); ivf_flat::index_params index_params; index_params.n_lists = 100; @@ -115,9 +147,7 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources, << index.size() << std::endl; std::cout << "Filling index with the dataset vectors" << std::endl; - - auto data_indices = raft::make_device_vector(dev_resources, dataset.extent(1)); - index = ivf_flat::extend(dev_resources, + index = ivf_flat::extend(dev_resources, dataset, std::make_optional(raft::make_const_mdspan(data_indices.view())), index); @@ -137,6 +167,11 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources, // Search K nearest neighbors for each queries. ivf_flat::search( dev_resources, search_params, index, queries, neighbors.view(), distances.view()); + + // The call to ivf_flat::search is asynchronous. Before accessing the data, sync using: + // raft::resource::sync_stream(dev_resources); + + print_results(dev_resources, neighbors.view(), distances.view()); } int main()