Skip to content

Commit

Permalink
hnswlib serialize python API
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Nov 23, 2023
1 parent ea8c3ed commit c60ae05
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 78 deletions.
92 changes: 50 additions & 42 deletions cpp/include/raft_runtime/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,48 +27,56 @@
namespace raft::runtime::neighbors::cagra {

// Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors
#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT>; \
\
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT>; \
\
void build_device(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx); \
\
void build_host(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx); \
\
void search(raft::resources const& handle, \
raft::neighbors::cagra::search_params const& params, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors, \
raft::device_matrix_view<float, int64_t, row_major> distances); \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT>; \
\
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT>; \
\
void build_device(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx); \
\
void build_host(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx); \
\
void search(raft::resources const& handle, \
raft::neighbors::cagra::search_params const& params, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors, \
raft::device_matrix_view<float, int64_t, row_major> distances); \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void serialize_to_hnswlib_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void serialize_to_hnswlib(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::cagra::index<T, IdxT>* index);

RAFT_INST_CAGRA_FUNCS(float, uint32_t);
Expand Down
82 changes: 49 additions & 33 deletions cpp/src/raft_runtime/neighbors/cagra_serialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,55 @@

namespace raft::runtime::neighbors::cagra {

#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \
}; \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::cagra::deserialize<DTYPE, uint32_t>(handle, filename); \
}; \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
std::stringstream os; \
raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \
str = os.str(); \
} \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
std::istringstream is(str); \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::cagra::deserialize<DTYPE, uint32_t>(handle, is); \
#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \
}; \
\
void serialize_to_hswnlib_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \
}; \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::cagra::deserialize<DTYPE, uint32_t>(handle, filename); \
}; \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
std::stringstream os; \
raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \
str = os.str(); \
} \
\
void serialize_to_hnswlib(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
std::stringstream os; \
raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \
str = os.str(); \
} \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
std::istringstream is(str); \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::cagra::deserialize<DTYPE, uint32_t>(handle, is); \
}

RAFT_INST_CAGRA_SERIALIZE(float);
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# =============================================================================

# Set the list of Cython files to build
set(cython_sources common.pyx refine.pyx brute_force.pyx)
set(cython_sources common.pyx refine.pyx brute_force.pyx cagra_hnswlib.pyx)
set(linked_libraries raft::raft raft::compiled)

# Build all of the Cython targets
Expand Down
2 changes: 0 additions & 2 deletions python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ cdef class IndexParams:


cdef class Index:
cdef readonly bool trained
cdef str active_index_type

def __cinit__(self):
self.trained = False
Expand Down
42 changes: 42 additions & 0 deletions python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \
string& str,
const index[float, uint32_t]& index,
bool include_dataset) except +
cdef void serialize_to_hnwslib(
const device_resources& handle,
string& str,
const index[float, uint32_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& str,
Expand All @@ -184,6 +188,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \
const index[uint8_t, uint32_t]& index,
bool include_dataset) except +

cdef void serialize_to_hnwslib(
const device_resources& handle,
string& str,
const index[uint8_t, uint32_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& str,
index[uint8_t, uint32_t]* index) except +
Expand All @@ -193,6 +202,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \
const index[int8_t, uint32_t]& index,
bool include_dataset) except +

cdef void serialize_to_hnwslib(
const device_resources& handle,
string& str,
const index[int8_t, uint32_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& str,
index[int8_t, uint32_t]* index) except +
Expand All @@ -202,6 +216,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \
const index[float, uint32_t]& index,
bool include_dataset) except +

cdef void serialize_to_hnswlib_file(
const device_resources& handle,
const string& filename,
const index[float, uint32_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[float, uint32_t]* index) except +
Expand All @@ -211,6 +230,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \
const index[uint8_t, uint32_t]& index,
bool include_dataset) except +

cdef void serialize_to_hnswlib_file(
const device_resources& handle,
const string& filename,
const index[uint8_t, uint32_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[uint8_t, uint32_t]* index) except +
Expand All @@ -220,6 +244,24 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \
const index[int8_t, uint32_t]& index,
bool include_dataset) except +

cdef void serialize_to_hnswlib_file(
const device_resources& handle,
const string& filename,
const index[int8_t, uint32_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[int8_t, uint32_t]* index) except +

cdef class Index:
cdef readonly bool trained
cdef str active_index_type

cdef class IndexFloat(Index):
pass

cdef class IndexInt8(Index):
pass

cdef class IndexUint8(Index):
pass
Loading

0 comments on commit c60ae05

Please sign in to comment.