Skip to content

Commit

Permalink
Fix dataset indices, Sync and print results
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Sep 18, 2023
1 parent 074353e commit c44d751
Showing 1 changed file with 51 additions and 16 deletions.
67 changes: 51 additions & 16 deletions cpp/template/src/ivf_flat_example.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,38 @@
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/random/sample_without_replacement.cuh>
#include <raft/util/cudart_utils.hpp>

#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>

// Copy the results to host and print a few samples
void print_results(raft::device_resources const& dev_resources,
raft::device_matrix_view<int64_t, int64_t> neighbors,
raft::device_matrix_view<float, int64_t> distances)
{
int64_t topk = neighbors.extent(1);
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(neighbors.extent(0), topk);
auto distances_host = raft::make_host_matrix<float, int64_t>(distances.extent(0), topk);

cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources);

raft::copy(neighbors_host.data_handle(), neighbors.data_handle(), neighbors.size(), stream);
raft::copy(distances_host.data_handle(), distances.data_handle(), distances.size(), stream);

// The calls to ivf_flat::search and raft::copy is asyncronous.
// We need to sync the stream before accessing the data.
raft::resource::sync_stream(dev_resources, stream);

for (int query_id = 0; query_id < 2; query_id++) {
std::cout << "Query " << query_id << " neighbor indices: ";
raft::print_host_vector("", &neighbors_host(query_id, 0), topk, std::cout);
std::cout << "Query " << query_id << " neighbor distances: ";
raft::print_host_vector("", &distances_host(query_id, 0), topk, std::cout);
}
}

void ivf_flat_build_search_simple(raft::device_resources const& dev_resources,
raft::device_matrix_view<const float, int64_t> dataset,
raft::device_matrix_view<const float, int64_t> queries)
Expand All @@ -46,7 +73,7 @@ void ivf_flat_build_search_simple(raft::device_resources const& dev_resources,
<< index.size() << std::endl;

// Create output arrays.
int64_t topk = 12;
int64_t topk = 10;
int64_t n_queries = queries.extent(0);
auto neighbors = raft::make_device_matrix<int64_t>(dev_resources, n_queries, topk);
auto distances = raft::make_device_matrix<float>(dev_resources, n_queries, topk);
Expand All @@ -58,12 +85,18 @@ void ivf_flat_build_search_simple(raft::device_resources const& dev_resources,
// Search K nearest neighbors for each of the queries.
ivf_flat::search(
dev_resources, search_params, index, queries, neighbors.view(), distances.view());

// The call to ivf_flat::search is asyncronous. Before accessing the data, sync by calling
// raft::resource::sync_stream(dev_resources);

print_results(dev_resources, neighbors.view(), distances.view());
}

/** Subsample the dataset to create a training set*/
raft::device_matrix<float, int64_t> subsample(
raft::device_resources const& dev_resources,
raft::device_matrix_view<const float, int64_t> dataset,
raft::device_vector_view<const int64_t, int64_t> data_indices,
float fraction)
{
int64_t n_samples = dataset.extent(0);
Expand All @@ -73,19 +106,10 @@ raft::device_matrix<float, int64_t> subsample(

int seed = 137;
raft::random::RngState rng(seed);
auto data_indices = raft::make_device_vector<int64_t>(dev_resources, n_samples);
auto train_indices = raft::make_device_vector<int64_t>(dev_resources, n_train);

thrust::counting_iterator<int64_t> first(0);
thrust::device_ptr<int64_t> ptr(data_indices.data_handle());
thrust::copy(raft::resource::get_thrust_policy(dev_resources), first, first + n_samples, ptr);

raft::random::sample_without_replacement(dev_resources,
rng,
raft::make_const_mdspan(data_indices.view()),
std::nullopt,
train_indices.view(),
std::nullopt);
raft::random::sample_without_replacement(
dev_resources, rng, data_indices, std::nullopt, train_indices.view(), std::nullopt);

raft::matrix::copy_rows(
dev_resources, dataset, trainset.view(), raft::make_const_mdspan(train_indices.view()));
Expand All @@ -99,8 +123,16 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources,
{
using namespace raft::neighbors;

// Define dataset indices.
auto data_indices = raft::make_device_vector<int64_t, int64_t>(dev_resources, dataset.extent(0));
thrust::counting_iterator<int64_t> first(0);
thrust::device_ptr<int64_t> ptr(data_indices.data_handle());
thrust::copy(
raft::resource::get_thrust_policy(dev_resources), first, first + dataset.extent(0), ptr);

// Sub-sample the dataset to create a training set.
auto trainset = subsample(dev_resources, dataset, 0.1);
auto trainset =
subsample(dev_resources, dataset, raft::make_const_mdspan(data_indices.view()), 0.1);

ivf_flat::index_params index_params;
index_params.n_lists = 100;
Expand All @@ -115,9 +147,7 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources,
<< index.size() << std::endl;

std::cout << "Filling index with the dataset vectors" << std::endl;

auto data_indices = raft::make_device_vector<int64_t, int64_t>(dev_resources, dataset.extent(1));
index = ivf_flat::extend(dev_resources,
index = ivf_flat::extend(dev_resources,
dataset,
std::make_optional(raft::make_const_mdspan(data_indices.view())),
index);
Expand All @@ -137,6 +167,11 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources,
// Search K nearest neighbors for each queries.
ivf_flat::search(
dev_resources, search_params, index, queries, neighbors.view(), distances.view());

// The call to ivf_flat::search is asyncronous. Before accessing the data, sync using:
// raft::resource::sync_stream(dev_resources);

print_results(dev_resources, neighbors.view(), distances.view());
}

int main()
Expand Down

0 comments on commit c44d751

Please sign in to comment.