Skip to content

Commit

Permalink
add extend API
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Nov 22, 2024
1 parent 05ceb9c commit 3bd6180
Show file tree
Hide file tree
Showing 12 changed files with 379 additions and 8 deletions.
49 changes: 49 additions & 0 deletions cpp/include/cuvs/neighbors/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,41 @@ cuvsError_t cuvsHnswIndexCreate(cuvsHnswIndex_t* index);
*/
cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index);

/**
* @}
*/

/**
* @defgroup hnsw_c_extend_params Parameters for extending HNSW index
*/

struct cuvsHnswExtendParams {
int num_threads;
};

typedef struct cuvsHnswExtendParams* cuvsHnswExtendParams_t;

/**
* @brief Allocate HNSW extend params, and populate with default values
*
* @param[in] params cuvsHnswExtendParams_t to allocate
* @return cuvsError_t
*/
cuvsError_t cuvsHnswExtendParamsCreate(cuvsHnswExtendParams_t* params);

/**
* @brief De-allocate HNSW extend params
*
* @param[in] params cuvsHnswExtendParams_t to de-allocate
* @return cuvsError_t
*/

cuvsError_t cuvsHnswExtendParamsDestroy(cuvsHnswExtendParams_t params);

/**
* @}
*/

/**
* @defgroup hnsw_c_index_load Load CAGRA index as hnswlib index
* @{
Expand All @@ -122,6 +157,20 @@ cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
* @}
*/

/**
* @defgroup hnsw_c_index_extend Extend HNSW index with additional vectors
* @{
*/

cuvsError_t cuvsHnswExtend(cuvsResources_t res,
cuvsHnswExtendParams_t params,
DLManagedTensor* additional_dataset,
cuvsHnswIndex_t index);

/**
* @}
*/

/**
* @defgroup hnsw_c_search_params C API for hnswlib wrapper search params
* @{
Expand Down
33 changes: 33 additions & 0 deletions cpp/include/cuvs/neighbors/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ struct index : cuvs::neighbors::index {

/**@}*/

/**
* @defgroup hnsw_cpp_extend_params HNSW index extend parameters
* @{
*/

struct extend_params {
/** Number of host threads to use to add additional vectors to the index.
Value of 0 automatically maximizes parallelism. */
int num_threads = 0;
};

/**
* @defgroup hnsw_cpp_index_load Load CAGRA index as hnswlib index
* @{
Expand Down Expand Up @@ -205,6 +216,28 @@ std::unique_ptr<index<int8_t>> from_cagra(

/**@}*/

/**
* @defgroup hnsw_cpp_index_extend Extend HNSW index with additional vectors
* @{
*/

void extend(raft::resources const& res,
const extend_params& params,
raft::host_matrix_view<const float, int64_t, raft::row_major> additional_dataset,
index<float>& idx);

void extend(raft::resources const& res,
const extend_params& params,
raft::host_matrix_view<const uint8_t, int64_t, raft::row_major> additional_dataset,
index<uint8_t>& idx);

void extend(raft::resources const& res,
const extend_params& params,
raft::host_matrix_view<const int8_t, int64_t, raft::row_major> additional_dataset,
index<int8_t>& idx);

/**@} */

/**
* @defgroup hnsw_cpp_search_params Build CAGRA index and search with hnswlib
* @{
Expand Down
27 changes: 25 additions & 2 deletions cpp/src/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ struct index_impl : index<T> {
index_impl(int dim, cuvs::distance::DistanceType metric, HnswHierarchy hierarchy)
: index<T>{dim, metric, hierarchy}
{
std::cout << "dim: " << dim << std::endl;
if constexpr (std::is_same_v<T, float>) {
if (metric == cuvs::distance::DistanceType::L2Expanded) {
space_ = std::make_unique<hnswlib::L2Space>(dim);
Expand Down Expand Up @@ -191,7 +190,6 @@ std::enable_if_t<hierarchy == HnswHierarchy::CPU, std::unique_ptr<index<T>>> fro
raft::host_matrix_view<const T, int64_t, raft::row_major> host_dataset_view(
host_dataset.data_handle(), host_dataset.extent(0), host_dataset.extent(1));
if (dataset.has_value()) {
std::cout << "Using dataset provided by user" << std::endl;
host_dataset_view = dataset.value();
} else {
// move dataset to host, remove padding
Expand Down Expand Up @@ -266,6 +264,31 @@ std::unique_ptr<index<T>> from_cagra(
}
}

template <typename T>
void extend(raft::resources const& res,
const extend_params& params,
raft::host_matrix_view<const T, int64_t, raft::row_major> additional_dataset,
index<T>& idx)
{
auto* hnswlib_index = reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>*>(
const_cast<void*>(idx.get_index()));
auto current_element_count = hnswlib_index->getCurrentElementCount();
auto new_element_count = additional_dataset.extent(0);
auto num_threads = params.num_threads == 0 ? std::thread::hardware_concurrency()
: static_cast<size_t>(params.num_threads);

hnswlib_index->resizeIndex(current_element_count + new_element_count);
ParallelFor(current_element_count,
current_element_count + new_element_count,
num_threads,
[&](size_t i, size_t threadId) {
hnswlib_index->addPoint(
(void*)(additional_dataset.data_handle() +
(i - current_element_count) * additional_dataset.extent(1)),
i);
});
}

template <typename T>
void get_search_knn_results(hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* idx,
const T* query,
Expand Down
15 changes: 15 additions & 0 deletions cpp/src/neighbors/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ CUVS_INST_HNSW_FROM_CAGRA(int8_t);

#undef CUVS_INST_HNSW_FROM_CAGRA

#define CUVS_INST_HNSW_EXTEND(T) \
void extend(raft::resources const& res, \
const extend_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> additional_dataset, \
index<T>& idx) \
{ \
detail::extend<T>(res, params, additional_dataset, idx); \
}

CUVS_INST_HNSW_EXTEND(float);
CUVS_INST_HNSW_EXTEND(uint8_t);
CUVS_INST_HNSW_EXTEND(int8_t);

#undef CUVS_INST_HNSW_EXTEND

#define CUVS_INST_HNSW_SEARCH(T) \
void search(raft::resources const& res, \
const search_params& params, \
Expand Down
46 changes: 46 additions & 0 deletions cpp/src/neighbors/hnsw_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ void _from_cagra(cuvsResources_t res,
hnsw_index->addr = reinterpret_cast<uintptr_t>(hnsw_index_ptr);
}

template <typename T>
void _extend(cuvsResources_t res,
cuvsHnswExtendParams_t params,
DLManagedTensor* additional_dataset,
cuvsHnswIndex index)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::hnsw::index<T>*>(index.addr);
auto cpp_params = cuvs::neighbors::hnsw::extend_params();
cpp_params.num_threads = params->num_threads;

using additional_dataset_mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
auto additional_dataset_mds =
cuvs::core::from_dlpack<additional_dataset_mdspan_type>(additional_dataset);
cuvs::neighbors::hnsw::extend(*res_ptr, cpp_params, additional_dataset_mds, *index_ptr);
}

template <typename T>
void _search(cuvsResources_t res,
cuvsHnswSearchParams params,
Expand Down Expand Up @@ -138,6 +155,17 @@ extern "C" cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index_c_ptr)
});
}

extern "C" cuvsError_t cuvsHnswExtendParamsCreate(cuvsHnswExtendParams_t* params)
{
return cuvs::core::translate_exceptions(
[=] { *params = new cuvsHnswExtendParams{.num_threads = 0}; });
}

extern "C" cuvsError_t cuvsHnswExtendParamsDestroy(cuvsHnswExtendParams_t params)
{
return cuvs::core::translate_exceptions([=] { delete params; });
}

extern "C" cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
cuvsHnswIndexParams_t params,
cuvsCagraIndex_t cagra_index,
Expand All @@ -158,6 +186,24 @@ extern "C" cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
});
}

extern "C" cuvsError_t cuvsHnswExtend(cuvsResources_t res,
cuvsHnswExtendParams_t params,
DLManagedTensor* additional_dataset,
cuvsHnswIndex_t index)
{
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat) {
_extend<float>(res, params, additional_dataset, *index);
} else if (index->dtype.code == kDLUInt) {
_extend<uint8_t>(res, params, additional_dataset, *index);
} else if (index->dtype.code == kDLInt) {
_extend<int8_t>(res, params, additional_dataset, *index);
} else {
RAFT_FAIL("Unsupported dtype: %d", index->dtype.code);
}
});
}

extern "C" cuvsError_t cuvsHnswSearchParamsCreate(cuvsHnswSearchParams_t* params)
{
return cuvs::core::translate_exceptions(
Expand Down
15 changes: 15 additions & 0 deletions docs/source/c_api/neighbors_hnsw_c.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ Index
:members:
:content-only:

Index extend parameters
-----------------------

.. doxygengroup:: hnsw_c_extend_params
:project: cuvs
:members:
:content-only:

Index extend
------------
.. doxygengroup:: hnsw_c_index_extend
:project: cuvs
:members:
:content-only:

Index load
----------
.. doxygengroup:: hnsw_c_index_load
Expand Down
19 changes: 17 additions & 2 deletions docs/source/cpp_api/neighbors_hnsw.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,25 @@ Index
:members:
:content-only:

Index load
Index extend parameters
-----------------------

.. doxygengroup:: hnsw_cpp_extend_params
:project: cuvs
:members:
:content-only:

Index extend
------------
.. doxygengroup:: hnsw_cpp_index_extend
:project: cuvs
:members:
:content-only:

.. doxygengroup:: hnsw_cpp_index_search
Index load
----------

.. doxygengroup:: hnsw_cpp_index_load
:project: cuvs
:members:
:content-only:
Expand Down
4 changes: 4 additions & 0 deletions python/cuvs/cuvs/neighbors/hnsw/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@


from .hnsw import (
ExtendParams,
Index,
IndexParams,
SearchParams,
extend,
from_cagra,
load,
save,
Expand All @@ -26,6 +28,8 @@
__all__ = [
"IndexParams",
"Index",
"ExtendParams",
"extend",
"SearchParams",
"load",
"save",
Expand Down
16 changes: 15 additions & 1 deletion python/cuvs/cuvs/neighbors/hnsw/hnsw.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,24 @@ cdef extern from "cuvs/neighbors/hnsw.h" nogil:

cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index)

ctypedef struct cuvsHnswExtendParams:
int32_t num_threads

ctypedef cuvsHnswExtendParams* cuvsHnswExtendParams_t

cuvsError_t cuvsHnswExtendParamsCreate(cuvsHnswExtendParams_t* params)

cuvsError_t cuvsHnswExtendParamsDestroy(cuvsHnswExtendParams_t params)

cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
cuvsHnswIndexParams_t params,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index)
cuvsHnswIndex_t hnsw_index) except +

cuvsError_t cuvsHnswExtend(cuvsResources_t res,
cuvsHnswExtendParams_t params,
DLManagedTensor* data,
cuvsHnswIndex_t index) except +

ctypedef struct cuvsHnswSearchParams:
int32_t ef
Expand Down
Loading

0 comments on commit 3bd6180

Please sign in to comment.