Skip to content

Commit

Permalink
mdspan feature for build and extend
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Jul 8, 2024
1 parent 666d47f commit 1a559a6
Showing 1 changed file with 43 additions and 38 deletions.
81 changes: 43 additions & 38 deletions cpp/include/raft/neighbors/detail/ann_mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include <raft/neighbors/brute_force.cuh>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/cagra.cuh>
#define RAFT_EXPLICIT_INSTANTIATE_ONLY

#include <raft_runtime/neighbors/ivf_flat.hpp>
Expand All @@ -38,63 +41,65 @@
#include <raft/neighbors/cagra_serialize.cuh>


// Number of rows per batch (search on shards)
#define N_ROWS_PER_BATCH 2^24

namespace raft::neighbors::mg::detail {
using namespace raft::neighbors;

template <typename AnnIndexType, typename T, typename IdxT>
class ann_interface {
public:

template <typename Accessor>
void build(raft::resources const& handle,
const ann::index_params* index_params,
raft::host_matrix_view<const T, IdxT, row_major> h_index_dataset)
raft::mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> index_dataset)
{
IdxT n_rows = h_index_dataset.extent(0);
IdxT n_dims = h_index_dataset.extent(1);
auto d_index_dataset = raft::make_device_matrix<T, IdxT, row_major>(handle, n_rows, n_dims);
raft::copy(d_index_dataset.data_handle(), h_index_dataset.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle));
raft::device_matrix_view<const T, IdxT, row_major> d_index_dataset_view = raft::make_device_matrix_view<const T, IdxT, row_major>(d_index_dataset.data_handle(), n_rows, n_dims);

if constexpr (std::is_same<AnnIndexType, ivf_flat::index<T, IdxT>>::value) {
index_.emplace(std::move(raft::runtime::neighbors::ivf_flat::build(
handle, *static_cast<const ivf_flat::index_params*>(index_params), d_index_dataset_view)));
auto idx = raft::neighbors::ivf_flat::build(handle,
*static_cast<const ivf_flat::index_params*>(index_params),
index_dataset.data_handle(),
index_dataset.extent(0),
index_dataset.extent(1));
index_.emplace(std::move(idx));
} else if constexpr (std::is_same<AnnIndexType, ivf_pq::index<IdxT>>::value) {
index_.emplace(std::move(raft::runtime::neighbors::ivf_pq::build(
handle, *static_cast<const ivf_pq::index_params*>(index_params), d_index_dataset_view)));
auto idx = raft::neighbors::ivf_pq::build(handle,
*static_cast<const ivf_pq::index_params*>(index_params),
index_dataset.data_handle(),
index_dataset.extent(0),
index_dataset.extent(1));
index_.emplace(std::move(idx));
} else if constexpr (std::is_same<AnnIndexType, cagra::index<T, IdxT>>::value) {
index_.emplace(std::move(raft::runtime::neighbors::cagra::build(
handle, *static_cast<const cagra::index_params*>(index_params), d_index_dataset_view)));
auto extents = raft::make_extents<int64_t>(index_dataset.extent(0), index_dataset.extent(1));
const bool host_acc = decltype(index_dataset)::accessor_type::is_host_type::value;
const bool device_acc = decltype(index_dataset)::accessor_type::is_device_type::value;
auto dataset = raft::make_mdspan<const T, int64_t, row_major, host_acc, device_acc>(index_dataset.data_handle(), extents);
cagra::index<T, IdxT> idx(handle);
idx = raft::neighbors::cagra::build<T, IdxT>(handle,
*static_cast<const cagra::index_params*>(index_params),
dataset);
index_.emplace(std::move(idx));
}
resource::sync_stream(handle);
}

template <typename Accessor1, typename Accessor2>
void extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> h_new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> h_new_indices)
raft::mdspan<const T, matrix_extent<IdxT>, row_major, Accessor1> new_vectors,
std::optional<raft::mdspan<const IdxT, vector_extent<IdxT>, layout_c_contiguous, Accessor2>> new_indices)
{
IdxT n_rows = h_new_vectors.extent(0);
IdxT n_dims = h_new_vectors.extent(1);
auto d_new_vectors = raft::make_device_matrix<T, IdxT, row_major>(handle, n_rows, n_dims);
raft::copy(d_new_vectors.data_handle(), h_new_vectors.data_handle(), n_rows * n_dims, resource::get_cuda_stream(handle));
raft::device_matrix_view<const T, IdxT, row_major> d_new_vectors_view = \
raft::make_device_matrix_view<const T, IdxT, row_major>(d_new_vectors.data_handle(), n_rows, n_dims);

std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices_opt = std::nullopt;
if (h_new_indices.has_value()) {
auto d_new_indices = raft::make_device_vector<IdxT, IdxT>(handle, n_rows);
raft::copy(d_new_indices.data_handle(), h_new_indices.value().data_handle(), n_rows, resource::get_cuda_stream(handle));
auto d_new_indices_view = raft::device_vector_view<const IdxT, IdxT>(d_new_indices.data_handle(), n_rows);
new_indices_opt = std::move(d_new_indices_view);
}

if constexpr (std::is_same<AnnIndexType, ivf_flat::index<T, IdxT>>::value) {
index_.emplace(std::move(raft::runtime::neighbors::ivf_flat::extend(
handle, d_new_vectors_view, new_indices_opt, index_.value())));
auto idx = raft::neighbors::ivf_flat::extend(handle,
index_.value(),
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
new_vectors.extent(0));
index_.emplace(std::move(idx));
} else if constexpr (std::is_same<AnnIndexType, ivf_pq::index<IdxT>>::value) {
index_.emplace(std::move(raft::runtime::neighbors::ivf_pq::extend(
handle, d_new_vectors_view, new_indices_opt, index_.value())));
auto idx = raft::neighbors::ivf_pq::extend(handle,
index_.value(),
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
new_vectors.extent(0));
index_.emplace(std::move(idx));
} else if constexpr (std::is_same<AnnIndexType, cagra::index<T, IdxT>>::value) {
RAFT_FAIL("CAGRA does not implement the extend method");
}
Expand Down Expand Up @@ -406,7 +411,7 @@ class ann_mg_index {
root_handle, n_rows_per_batch, n_neighbors);

for (IdxT batch_idx = 0; batch_idx < n_batches; batch_idx++) {
IdxT offset = batch_idx * N_ROWS_PER_BATCH;
IdxT offset = batch_idx * n_rows_per_batch;
IdxT query_offset = offset * n_cols;
IdxT output_offset = offset * n_neighbors;
IdxT n_rows_of_current_batch = std::min((IdxT)n_rows_per_batch, n_rows - offset);
Expand Down

0 comments on commit 1a559a6

Please sign in to comment.