Skip to content

Commit

Permalink
Merge branch 'branch-24.02' into fea-mdbuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks authored Nov 28, 2023
2 parents 3da9348 + 88e9a55 commit 3b1f245
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 85 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class RaftIvfFlatGpu : public ANN<T> {
AlgoProperty get_preference() const override
{
AlgoProperty property;
property.dataset_memory_type = MemoryType::Device;
property.dataset_memory_type = MemoryType::HostMmap;
property.query_memory_type = MemoryType::Device;
return property;
}
Expand Down
106 changes: 74 additions & 32 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
uint32_t* list_sizes_ptr,
IdxT n_rows,
uint32_t dim,
uint32_t veclen)
uint32_t veclen,
IdxT batch_offset = 0)
{
const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x;
if (i >= n_rows) { return; }
Expand All @@ -131,7 +132,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
auto* list_data = list_data_ptrs[list_id];

// Record the source vector id in the index
list_index[inlist_id] = source_ixs == nullptr ? i : source_ixs[i];
list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i];

// The data is written in interleaved groups of `index::kGroupSize` vectors
using interleaved_group = Pow2<kIndexGroupSize>;
Expand Down Expand Up @@ -180,16 +181,33 @@ void extend(raft::resources const& handle,

auto new_labels = raft::make_device_vector<LabelT, IdxT>(handle, n_rows);
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.metric = index->metric();
auto new_vectors_view = raft::make_device_matrix_view<const T, IdxT>(new_vectors, n_rows, dim);
kmeans_params.metric = index->metric();
auto orig_centroids_view =
raft::make_device_matrix_view<const float, IdxT>(index->centers().data_handle(), n_lists, dim);
raft::cluster::kmeans_balanced::predict(handle,
kmeans_params,
new_vectors_view,
orig_centroids_view,
new_labels.view(),
utils::mapping<float>{});
// Calculate the batch size for the input data if it's not accessible directly from the device
constexpr size_t kReasonableMaxBatchSize = 65536;
size_t max_batch_size = std::min<size_t>(n_rows, kReasonableMaxBatchSize);

// Predict the cluster labels for the new data, in batches if necessary
utils::batch_load_iterator<T> vec_batches(new_vectors,
n_rows,
index->dim(),
max_batch_size,
stream,
resource::get_workspace_resource(handle));

for (const auto& batch : vec_batches) {
auto batch_data_view =
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
auto batch_labels_view = raft::make_device_vector_view<LabelT, IdxT>(
new_labels.data_handle() + batch.offset(), batch.size());
raft::cluster::kmeans_balanced::predict(handle,
kmeans_params,
batch_data_view,
orig_centroids_view,
batch_labels_view,
utils::mapping<float>{});
}

auto* list_sizes_ptr = index->list_sizes().data_handle();
auto old_list_sizes_dev = raft::make_device_vector<uint32_t, IdxT>(handle, n_lists);
Expand All @@ -202,14 +220,19 @@ void extend(raft::resources const& handle,
auto list_sizes_view =
raft::make_device_vector_view<std::remove_pointer_t<decltype(list_sizes_ptr)>, IdxT>(
list_sizes_ptr, n_lists);
auto const_labels_view = make_const_mdspan(new_labels.view());
raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle,
new_vectors_view,
const_labels_view,
centroids_view,
list_sizes_view,
false,
utils::mapping<float>{});
for (const auto& batch : vec_batches) {
auto batch_data_view =
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
auto batch_labels_view = raft::make_device_vector_view<const LabelT, IdxT>(
new_labels.data_handle() + batch.offset(), batch.size());
raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle,
batch_data_view,
batch_labels_view,
centroids_view,
list_sizes_view,
false,
utils::mapping<float>{});
}
} else {
raft::stats::histogram<uint32_t, IdxT>(raft::stats::HistTypeAuto,
reinterpret_cast<int32_t*>(list_sizes_ptr),
Expand Down Expand Up @@ -244,20 +267,39 @@ void extend(raft::resources const& handle,
// we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter.
raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream);

// Kernel to insert the new vectors
const dim3 block_dim(256);
const dim3 grid_dim(raft::ceildiv<IdxT>(n_rows, block_dim.x));
build_index_kernel<<<grid_dim, block_dim, 0, stream>>>(new_labels.data_handle(),
new_vectors,
new_indices,
index->data_ptrs().data_handle(),
index->inds_ptrs().data_handle(),
list_sizes_ptr,
n_rows,
dim,
index->veclen());
RAFT_CUDA_TRY(cudaPeekAtLastError());

utils::batch_load_iterator<IdxT> vec_indices(
new_indices, n_rows, 1, max_batch_size, stream, resource::get_workspace_resource(handle));
utils::batch_load_iterator<IdxT> idx_batch = vec_indices.begin();
size_t next_report_offset = 0;
size_t d_report_offset = n_rows * 5 / 100;
for (const auto& batch : vec_batches) {
auto batch_data_view =
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
// Kernel to insert the new vectors
const dim3 block_dim(256);
const dim3 grid_dim(raft::ceildiv<IdxT>(batch.size(), block_dim.x));
build_index_kernel<T, IdxT, LabelT>
<<<grid_dim, block_dim, 0, stream>>>(new_labels.data_handle() + batch.offset(),
batch_data_view.data_handle(),
idx_batch->data(),
index->data_ptrs().data_handle(),
index->inds_ptrs().data_handle(),
list_sizes_ptr,
batch.size(),
dim,
index->veclen(),
batch.offset());
RAFT_CUDA_TRY(cudaPeekAtLastError());

if (batch.offset() > next_report_offset) {
float progress = batch.offset() * 100.0f / n_rows;
RAFT_LOG_DEBUG("ivf_flat::extend added vectors %zu, %6.1f%% complete",
static_cast<size_t>(batch.offset()),
progress);
next_report_offset += d_report_offset;
}
++idx_batch;
}
// Precompute the centers vector norms for L2Expanded distance
if (!index->center_norms().has_value()) {
index->allocate_center_norms(handle);
Expand Down
52 changes: 51 additions & 1 deletion cpp/include/raft/neighbors/ivf_flat-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ void build(raft::resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> dataset,
raft::neighbors::ivf_flat::index<T, IdxT>& idx) RAFT_EXPLICIT;

template <typename T, typename IdxT>
auto build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset)
-> index<T, IdxT> RAFT_EXPLICIT;

template <typename T, typename IdxT>
void build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset,
raft::neighbors::ivf_flat::index<T, IdxT>& idx) RAFT_EXPLICIT;

template <typename T, typename IdxT>
auto extend(raft::resources const& handle,
const index<T, IdxT>& orig_index,
Expand All @@ -74,6 +86,19 @@ void extend(raft::resources const& handle,
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices,
index<T, IdxT>* index) RAFT_EXPLICIT;

template <typename T, typename IdxT>
auto extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index)
-> raft::neighbors::ivf_flat::index<T, IdxT> RAFT_EXPLICIT;

template <typename T, typename IdxT>
void extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
index<T, IdxT>* index) RAFT_EXPLICIT;

template <typename T, typename IdxT, typename IvfSampleFilterT>
void search_with_filtering(raft::resources const& handle,
const search_params& params,
Expand Down Expand Up @@ -137,6 +162,18 @@ void search(raft::resources const& handle,
raft::resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_flat::index<T, IdxT>& idx); \
\
extern template auto raft::neighbors::ivf_flat::build<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::host_matrix_view<const T, IdxT, row_major> dataset) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
extern template void raft::neighbors::ivf_flat::build<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::host_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_flat::index<T, IdxT>& idx);

instantiate_raft_neighbors_ivf_flat_build(float, int64_t);
Expand Down Expand Up @@ -171,7 +208,20 @@ instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t);
raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* index);
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
\
extern template void raft::neighbors::ivf_flat::extend<T, IdxT>( \
raft::resources const& handle, \
raft::host_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
\
extern template auto raft::neighbors::ivf_flat::extend<T, IdxT>( \
const raft::resources& handle, \
raft::host_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
const raft::neighbors::ivf_flat::index<T, IdxT>& idx) \
->raft::neighbors::ivf_flat::index<T, IdxT>;

instantiate_raft_neighbors_ivf_flat_extend(float, int64_t);
instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t);
Expand Down
75 changes: 69 additions & 6 deletions cpp/include/raft/neighbors/ivf_flat-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace raft::neighbors::ivf_flat {
*
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] dataset a host or device pointer to a row-major matrix [n_rows, dim]
* @param[in] n_rows the number of samples
* @param[in] dim the dimensionality of the data
*
Expand Down Expand Up @@ -102,7 +102,7 @@ auto build(raft::resources const& handle,
*
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
* @param[in] dataset a device matrix [n_rows, dim]
*
* @return the constructed ivf-flat index
*/
Expand All @@ -118,6 +118,20 @@ auto build(raft::resources const& handle,
static_cast<IdxT>(dataset.extent(1)));
}

/**
* @brief Build the index from a dataset in host memory.
*/
template <typename T, typename IdxT>
auto build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset) -> index<T, IdxT>
{
return raft::neighbors::ivf_flat::detail::build(handle,
params,
dataset.data_handle(),
static_cast<IdxT>(dataset.extent(0)),
static_cast<IdxT>(dataset.extent(1)));
}
/**
* @brief Build the index from the dataset for efficient search.
*
Expand Down Expand Up @@ -162,6 +176,21 @@ void build(raft::resources const& handle,
static_cast<IdxT>(dataset.extent(1)));
}

/**
* @brief Build the index from a dataset in host memory.
*/
template <typename T, typename IdxT>
void build(raft::resources const& handle,
const index_params& params,
raft::host_matrix_view<const T, IdxT, row_major> dataset,
raft::neighbors::ivf_flat::index<T, IdxT>& idx)
{
idx = raft::neighbors::ivf_flat::detail::build(handle,
params,
dataset.data_handle(),
static_cast<IdxT>(dataset.extent(0)),
static_cast<IdxT>(dataset.extent(1)));
}
/** @} */

/**
Expand All @@ -188,8 +217,8 @@ void build(raft::resources const& handle,
*
* @param[in] handle
* @param[in] orig_index original index
* @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device pointer to a vector of indices [n_rows].
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param[in] n_rows number of rows in `new_vectors`
Expand Down Expand Up @@ -257,6 +286,23 @@ auto extend(raft::resources const& handle,
new_vectors.extent(0));
}

/**
* @brief Extend the index with additional vectors.
*
* This overloads takes input data in host memory.
*/
template <typename T, typename IdxT>
auto extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
const index<T, IdxT>& orig_index) -> index<T, IdxT>
{
return extend<T, IdxT>(handle,
orig_index,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
new_vectors.extent(0));
}
/** @} */

/**
Expand All @@ -279,8 +325,8 @@ auto extend(raft::resources const& handle,
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device pointer to a vector of indices [n_rows].
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param[in] n_rows the number of samples
Expand Down Expand Up @@ -339,6 +385,23 @@ void extend(raft::resources const& handle,
static_cast<IdxT>(new_vectors.extent(0)));
}

/**
* @brief Extend the index with additional vectors.
*
* This overloads takes input data in host memory.
*/
template <typename T, typename IdxT>
void extend(raft::resources const& handle,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
index<T, IdxT>* index)
{
extend(handle,
index,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
static_cast<IdxT>(new_vectors.extent(0)));
}
/** @} */

/**
Expand Down
Loading

0 comments on commit 3b1f245

Please sign in to comment.