Skip to content

Commit

Permalink
libraft and pylibraft API for CAGRA build and HNSW search (#2022)
Browse files Browse the repository at this point in the history
Closes #1772

Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2022
  • Loading branch information
divyegala authored Jan 26, 2024
1 parent e272176 commit 2822532
Show file tree
Hide file tree
Showing 34 changed files with 1,691 additions and 158 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
types_or: [python, cython]
additional_dependencies: ["flake8-force"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.971'
rev: 'v1.3.0'
hooks:
- id: mypy
additional_dependencies: [types-cachetools]
Expand Down
9 changes: 9 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON)
option(BUILD_TESTS "Build raft unit-tests" ON)
option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF)
option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF)
option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON)
option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF)
option(CUDA_ENABLE_LINEINFO
"Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF
Expand Down Expand Up @@ -195,6 +196,10 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH)
rapids_cpm_gbench()
endif()

# Pull in the hnswlib third-party dependency only when the CAGRA+hnswlib
# interface is enabled (BUILD_CAGRA_HNSWLIB option, ON by default).
if(BUILD_CAGRA_HNSWLIB)
include(cmake/thirdparty/get_hnswlib.cmake)
endif()

# ##################################################################################################
# * raft ---------------------------------------------------------------------
add_library(raft INTERFACE)
Expand All @@ -203,6 +208,9 @@ add_library(raft::raft ALIAS raft)
target_include_directories(
raft INTERFACE "$<BUILD_INTERFACE:${RAFT_SOURCE_DIR}/include>" "$<INSTALL_INTERFACE:include>"
)
# When the CAGRA+hnswlib interface is enabled, add hnswlib::hnswlib to the link
# interface of the `raft` INTERFACE target so that consumers of raft pick it up.
# NOTE(review): src/raft_runtime/neighbors/hnsw.cpp is added to the compiled
# source list without a guard visible in this diff -- confirm the build still
# succeeds with BUILD_CAGRA_HNSWLIB=OFF.
if(BUILD_CAGRA_HNSWLIB)
target_link_libraries(raft INTERFACE hnswlib::hnswlib)
endif()

if(NOT BUILD_CPU_ONLY)
# Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target.
Expand Down Expand Up @@ -425,6 +433,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/neighbors/cagra_search.cu
src/raft_runtime/neighbors/cagra_serialize.cu
src/raft_runtime/neighbors/eps_neighborhood.cu
src/raft_runtime/neighbors/hnsw.cpp
src/raft_runtime/neighbors/ivf_flat_build.cu
src/raft_runtime/neighbors/ivf_flat_search.cu
src/raft_runtime/neighbors/ivf_flat_serialize.cu
Expand Down
16 changes: 3 additions & 13 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,7 @@ endfunction()

if(RAFT_ANN_BENCH_USE_HNSWLIB)
ConfigureAnnBench(
NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp
LINKS
hnswlib::hnswlib
NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib
)

endif()
Expand Down Expand Up @@ -276,12 +274,7 @@ endif()

if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
ConfigureAnnBench(
NAME
RAFT_CAGRA_HNSWLIB
PATH
bench/ann/src/raft/raft_cagra_hnswlib.cu
LINKS
raft::compiled
NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled
hnswlib::hnswlib
)
endif()
Expand Down Expand Up @@ -336,10 +329,7 @@ endif()

if(RAFT_ANN_BENCH_USE_GGNN)
include(cmake/thirdparty/get_glog.cmake)
ConfigureAnnBench(
NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu
LINKS glog::glog ggnn::ggnn
)
ConfigureAnnBench(NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn)
endif()

# ##################################################################################################
Expand Down
9 changes: 7 additions & 2 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ struct hnsw_dist_t<uint8_t> {
using type = int;
};

// Distance-type trait for signed 8-bit data: as with the uint8_t specialization
// above, int8_t vectors use `int` as the hnswlib distance type (consumed as
// hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>).
// NOTE(review): the build()/load() hunks visible in this diff only construct a
// space for float and uint8_t -- confirm that T = int8_t also gets an
// hnswlib::L2SpaceI<T>, otherwise space_ is never assigned in those branches.
template <>
struct hnsw_dist_t<int8_t> {
using type = int;
};

template <typename T>
class HnswLib : public ANN<T> {
public:
Expand Down Expand Up @@ -135,7 +140,7 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
space_ = std::make_shared<hnswlib::L2Space>(dim_);
}
} else if constexpr (std::is_same_v<T, uint8_t>) {
space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}

appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
Expand Down Expand Up @@ -205,7 +210,7 @@ void HnswLib<T>::load(const std::string& path_to_index)
space_ = std::make_shared<hnswlib::L2Space>(dim_);
}
} else if constexpr (std::is_same_v<T, uint8_t>) {
space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}

appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
Expand Down
57 changes: 57 additions & 0 deletions cpp/cmake/patches/hnswlib.diff
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,63 @@
}
}
}
diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h
index 4413537..c3240f3 100644
--- a/hnswlib/space_l2.h
+++ b/hnswlib/space_l2.h
@@ -252,13 +252,14 @@ namespace hnswlib {
~L2Space() {}
};

+ template <typename T>
static int
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {

size_t qty = *((size_t *) qty_ptr);
int res = 0;
- unsigned char *a = (unsigned char *) pVect1;
- unsigned char *b = (unsigned char *) pVect2;
+ T *a = (T *) pVect1;
+ T *b = (T *) pVect2;

qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
@@ -279,11 +280,12 @@ namespace hnswlib {
return (res);
}

+ template <typename T>
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
size_t qty = *((size_t*)qty_ptr);
int res = 0;
- unsigned char* a = (unsigned char*)pVect1;
- unsigned char* b = (unsigned char*)pVect2;
+ T* a = (T*)pVect1;
+ T* b = (T*)pVect2;

for(size_t i = 0; i < qty; i++)
{
@@ -294,6 +296,7 @@ namespace hnswlib {
return (res);
}

+ template <typename T>
class L2SpaceI : public SpaceInterface<int> {

DISTFUNC<int> fstdistfunc_;
@@ -302,10 +305,10 @@ namespace hnswlib {
public:
L2SpaceI(size_t dim) {
if(dim % 4 == 0) {
- fstdistfunc_ = L2SqrI4x;
+ fstdistfunc_ = L2SqrI4x<T>;
}
else {
- fstdistfunc_ = L2SqrI;
+ fstdistfunc_ = L2SqrI<T>;
}
dim_ = dim;
data_size_ = dim * sizeof(unsigned char);
diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h
index 5e1a4a5..4195ebd 100644
--- a/hnswlib/visited_list_pool.h
Expand Down
6 changes: 5 additions & 1 deletion cpp/cmake/thirdparty/get_hnswlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ function(find_and_configure_hnswlib)
rapids_cpm_find(
hnswlib ${PKG_VERSION}
GLOBAL_TARGETS hnswlib::hnswlib
BUILD_EXPORT_SET raft-exports
INSTALL_EXPORT_SET raft-exports
CPM_ARGS
GIT_REPOSITORY ${PKG_REPOSITORY}
GIT_TAG ${PKG_PINNED_TAG}
Expand All @@ -51,11 +53,13 @@ function(find_and_configure_hnswlib)
# write export rules
rapids_export(
BUILD hnswlib
VERSION ${PKG_VERSION}
EXPORT_SET hnswlib-exports
GLOBAL_TARGETS hnswlib
NAMESPACE hnswlib::)
rapids_export(
INSTALL hnswlib
VERSION ${PKG_VERSION}
EXPORT_SET hnswlib-exports
GLOBAL_TARGETS hnswlib
NAMESPACE hnswlib::)
Expand All @@ -74,5 +78,5 @@ endif()
find_and_configure_hnswlib(VERSION 0.6.2
REPOSITORY ${RAFT_HNSWLIB_GIT_REPOSITORY}
PINNED_TAG ${RAFT_HNSWLIB_GIT_TAG}
EXCLUDE_FROM_ALL ON
EXCLUDE_FROM_ALL OFF
)
34 changes: 20 additions & 14 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,13 +32,14 @@ namespace raft::neighbors::cagra {
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cagra::build(...);`
* raft::serialize(handle, os, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize(handle, os, index);
* @endcode
*
* @tparam T data element type
Expand Down Expand Up @@ -66,13 +67,14 @@ void serialize(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = cagra::build(...);`
* raft::serialize(handle, filename, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
Expand Down Expand Up @@ -100,13 +102,14 @@ void serialize(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cagra::build(...);`
* raft::serialize_to_hnswlib(handle, os, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index);
* @endcode
*
* @tparam T data element type
Expand All @@ -120,25 +123,26 @@ void serialize(raft::resources const& handle,
template <typename T, typename IdxT>
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const index<T, IdxT>& index)
const raft::neighbors::cagra::index<T, IdxT>& index)
{
detail::serialize_to_hnswlib<T, IdxT>(handle, os, index);
}

/**
* Write the CAGRA built index as a base layer HNSW index to file
* Save a CAGRA-built index in hnswlib base-layer-only serialized format
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = cagra::build(...);`
* raft::serialize_to_hnswlib(handle, filename, index);
* // create an index with `auto index = raft::neighbors::cagra::build(...);`
* raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index);
* @endcode
*
* @tparam T data element type
Expand All @@ -152,7 +156,7 @@ void serialize_to_hnswlib(raft::resources const& handle,
template <typename T, typename IdxT>
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
const raft::neighbors::cagra::index<T, IdxT>& index)
{
detail::serialize_to_hnswlib<T, IdxT>(handle, filename, index);
}
Expand All @@ -164,14 +168,15 @@ void serialize_to_hnswlib(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::deserialize<T, IdxT>(handle, is);
* auto index = raft::neighbors::cagra::deserialize<T, IdxT>(handle, is);
* @endcode
*
* @tparam T data element type
Expand All @@ -195,14 +200,15 @@ index<T, IdxT> deserialize(raft::resources const& handle, std::istream& is)
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/cagra_serialize.cuh>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::deserialize<T, IdxT>(handle, filename);
* auto index = raft::neighbors::cagra::deserialize<T, IdxT>(handle, filename);
* @endcode
*
* @tparam T data element type
Expand Down
Loading

0 comments on commit 2822532

Please sign in to comment.