Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

libraft and pylibraft API for CAGRA build and HNSW search #2022

Merged
merged 42 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
c60ae05
hnswlib serialize python API
divyegala Nov 23, 2023
7f04f97
fix typo, refactor cython
divyegala Nov 23, 2023
4b2b2c6
add cpp index and python load,search methods with runtime API
divyegala Nov 28, 2023
c62c8e3
Merge remote-tracking branch 'upstream/branch-24.02' into cagra_hnswl…
divyegala Nov 28, 2023
92e5717
add static_assert guard for (u)int hnswlib serialize
divyegala Nov 28, 2023
7f78851
passing float tests
divyegala Nov 28, 2023
29f0774
fix docs
divyegala Nov 28, 2023
e2e1fc3
try to write in native dtype
divyegala Nov 28, 2023
b8a4c50
update mypy, solve error
divyegala Nov 28, 2023
e02a0e3
attempt to use rapids_cpm_find for hnswlib
divyegala Nov 29, 2023
54106a1
rework to not expose hnswlib headers in runtime API
divyegala Dec 14, 2023
ebdce10
readd copyright check
divyegala Dec 14, 2023
c29313a
Merge remote-tracking branch 'upstream/branch-24.02' into cagra_hnswl…
divyegala Dec 14, 2023
1da21e4
print binary dir
divyegala Dec 14, 2023
e5cd5f6
address review
divyegala Dec 14, 2023
460bc98
address review, enable int8 in hnswlib
divyegala Dec 15, 2023
ce1108a
Merge remote-tracking branch 'upstream/branch-24.02' into cagra_hnswl…
divyegala Dec 15, 2023
fdc015f
missed template
divyegala Dec 15, 2023
211ba39
fix docs
divyegala Dec 15, 2023
466e141
move serialize back to cagra::
divyegala Dec 18, 2023
d3fad16
use unique_ptr for deserialization
divyegala Dec 18, 2023
e4e635f
add composite function from_cagra
divyegala Dec 18, 2023
ad0a25f
fix ann-bench compiler error
divyegala Dec 18, 2023
397176e
update docs from review
divyegala Dec 18, 2023
114d176
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
cjnolet Jan 6, 2024
691d254
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
vyasr Jan 19, 2024
c30671b
Fix some style
vyasr Jan 19, 2024
44c88e7
Get wheel builds to compile
vyasr Jan 19, 2024
304ae7a
fix docs
divyegala Jan 19, 2024
d10bb4f
Export hnswlib dependency in raft
vyasr Jan 19, 2024
1d8e804
fix doc again
divyegala Jan 19, 2024
05bc41b
Merge remote-tracking branch 'origin/cagra_hnswlib_pylibraft' into ca…
divyegala Jan 22, 2024
306de24
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
divyegala Jan 22, 2024
ae543e7
Specify the hnswlib version in the export
vyasr Jan 23, 2024
ad05438
more doc fixes
divyegala Jan 23, 2024
d8a3e06
Merge branch 'cagra_hnswlib_pylibraft' of github.com:divyegala/raft i…
divyegala Jan 23, 2024
94d08cd
hopefully final doc fix
divyegala Jan 23, 2024
54bf32c
style fix
divyegala Jan 23, 2024
8e8549b
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
divyegala Jan 23, 2024
aa16015
merging upstream
divyegala Jan 26, 2024
e62d3bc
Merge remote-tracking branch 'origin/cagra_hnswlib_pylibraft' into ca…
divyegala Jan 26, 2024
1897cf7
pre commit fixes
divyegala Jan 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
types_or: [python, cython]
additional_dependencies: ["flake8-force"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.971'
rev: 'v1.3.0'
hooks:
- id: mypy
additional_dependencies: [types-cachetools]
Expand Down
7 changes: 7 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON)
option(BUILD_TESTS "Build raft unit-tests" ON)
option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF)
option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF)
option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON)
option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF)
option(CUDA_ENABLE_LINEINFO
"Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF
Expand Down Expand Up @@ -195,13 +196,18 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH)
rapids_cpm_gbench()
endif()

if (BUILD_CAGRA_HNSWLIB)
include(cmake/thirdparty/get_hnswlib.cmake)
endif()

# ##################################################################################################
# * raft ---------------------------------------------------------------------
add_library(raft INTERFACE)
add_library(raft::raft ALIAS raft)

target_include_directories(
raft INTERFACE "$<BUILD_INTERFACE:${RAFT_SOURCE_DIR}/include>" "$<INSTALL_INTERFACE:include>"
"$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>"
)

if(NOT BUILD_CPU_ONLY)
Expand Down Expand Up @@ -410,6 +416,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/neighbors/cagra_build.cu
src/raft_runtime/neighbors/cagra_search.cu
src/raft_runtime/neighbors/cagra_serialize.cu
src/raft_runtime/neighbors/hnsw.cpp
src/raft_runtime/neighbors/ivf_flat_build.cu
src/raft_runtime/neighbors/ivf_flat_search.cu
src/raft_runtime/neighbors/ivf_flat_serialize.cu
Expand Down
9 changes: 7 additions & 2 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ struct hnsw_dist_t<uint8_t> {
using type = int;
};

template <>
struct hnsw_dist_t<int8_t> {
using type = int;
};

template <typename T>
class HnswLib : public ANN<T> {
public:
Expand Down Expand Up @@ -135,7 +140,7 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
space_ = std::make_shared<hnswlib::L2Space>(dim_);
}
} else if constexpr (std::is_same_v<T, uint8_t>) {
space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}

appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
Expand Down Expand Up @@ -205,7 +210,7 @@ void HnswLib<T>::load(const std::string& path_to_index)
space_ = std::make_shared<hnswlib::L2Space>(dim_);
}
} else if constexpr (std::is_same_v<T, uint8_t>) {
space_ = std::make_shared<hnswlib::L2SpaceI>(dim_);
space_ = std::make_shared<hnswlib::L2SpaceI<T>>(dim_);
}

appr_alg_ = std::make_shared<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
Expand Down
57 changes: 57 additions & 0 deletions cpp/cmake/patches/hnswlib.patch
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,63 @@ index e95e0b5..f0fe50a 100644
}
}
}
diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h
index 4413537..c3240f3 100644
--- a/hnswlib/space_l2.h
+++ b/hnswlib/space_l2.h
@@ -252,13 +252,14 @@ namespace hnswlib {
~L2Space() {}
};

+ template <typename T>
static int
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {

size_t qty = *((size_t *) qty_ptr);
int res = 0;
- unsigned char *a = (unsigned char *) pVect1;
- unsigned char *b = (unsigned char *) pVect2;
+ T *a = (T *) pVect1;
+ T *b = (T *) pVect2;

qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
@@ -279,11 +280,12 @@ namespace hnswlib {
return (res);
}

+ template <typename T>
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
size_t qty = *((size_t*)qty_ptr);
int res = 0;
- unsigned char* a = (unsigned char*)pVect1;
- unsigned char* b = (unsigned char*)pVect2;
+ T* a = (T*)pVect1;
+ T* b = (T*)pVect2;

for(size_t i = 0; i < qty; i++)
{
@@ -294,6 +296,7 @@ namespace hnswlib {
return (res);
}

+ template <typename T>
class L2SpaceI : public SpaceInterface<int> {

DISTFUNC<int> fstdistfunc_;
@@ -302,10 +305,10 @@ namespace hnswlib {
public:
L2SpaceI(size_t dim) {
if(dim % 4 == 0) {
- fstdistfunc_ = L2SqrI4x;
+ fstdistfunc_ = L2SqrI4x<T>;
}
else {
- fstdistfunc_ = L2SqrI;
+ fstdistfunc_ = L2SqrI<T>;
}
dim_ = dim;
data_size_ = dim * sizeof(unsigned char);
diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h
index 5e1a4a5..4195ebd 100644
--- a/hnswlib/visited_list_pool.h
Expand Down
9 changes: 9 additions & 0 deletions cpp/cmake/thirdparty/get_hnswlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,20 @@ function(find_and_configure_hnswlib)
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps )

message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}")
message("WORKING DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}")
divyegala marked this conversation as resolved.
Show resolved Hide resolved
execute_process (
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src
)
endif ()
# rapids_cpm_find(hnswlib ${PKG_VERSION}
divyegala marked this conversation as resolved.
Show resolved Hide resolved
# GLOBAL_TARGETS hnswlib
# BUILD_EXPORT_SET raft-exports
# INSTALL_EXPORT_SET raft-exports
# CPM_ARGS
# GIT_REPOSITORY https://github.com/${PKG_FORK}/hnswlib.git
# GIT_TAG ${PKG_PINNED_TAG}
# EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL})

include(cmake/modules/FindAVX.cmake)

Expand Down
10 changes: 5 additions & 5 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void serialize(raft::resources const& handle,
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cagra::build(...);`
* raft::serialize_to_hnswlib(handle, os, index);
* raft::serialize(handle, os, index);
* @endcode
*
* @tparam T data element type
Expand All @@ -120,13 +120,13 @@ void serialize(raft::resources const& handle,
template <typename T, typename IdxT>
void serialize_to_hnswlib(raft::resources const& handle,
std::ostream& os,
const index<T, IdxT>& index)
const raft::neighbors::cagra::index<T, IdxT>& index)
{
detail::serialize_to_hnswlib<T, IdxT>(handle, os, index);
}

/**
* Write the CAGRA built index as a base layer HNSW index to file
* Save a CAGRA build index in hnswlib base-layer-only serialized format
*
* Experimental, both the API and the serialization format are subject to change.
*
Expand All @@ -138,7 +138,7 @@ void serialize_to_hnswlib(raft::resources const& handle,
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = cagra::build(...);`
* raft::serialize_to_hnswlib(handle, filename, index);
* raft::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
Expand All @@ -152,7 +152,7 @@ void serialize_to_hnswlib(raft::resources const& handle,
template <typename T, typename IdxT>
void serialize_to_hnswlib(raft::resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
const raft::neighbors::cagra::index<T, IdxT>& index)
{
detail::serialize_to_hnswlib<T, IdxT>(handle, filename, index);
}
Expand Down
35 changes: 18 additions & 17 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
void serialize_to_hnswlib(raft::resources const& res,
std::ostream& os,
const index<T, IdxT>& index_)
const raft::neighbors::cagra::index<T, IdxT>& index_)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize_to_hnswlib");
// static_assert(std::is_same_v<IdxT, int> or std::is_same_v<IdxT, uint32_t>,
divyegala marked this conversation as resolved.
Show resolved Hide resolved
// "An hnswlib index can only be trained with int32 or uint32 IdxT");
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize");
RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u",
static_cast<size_t>(index_.size()),
index_.dim());
Expand All @@ -120,14 +122,14 @@ void serialize_to_hnswlib(raft::resources const& res,
// Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t,
// labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) +
// dim * sizeof(data_t) + sizeof(labeltype)
auto size_data_per_element =
static_cast<std::size_t>(index_.graph_degree() * 4 + 4 + index_.dim() * 4 + 8);
auto size_data_per_element = static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 +
index_.dim() * sizeof(T) + 8);
os.write(reinterpret_cast<char*>(&size_data_per_element), sizeof(std::size_t));
// label_offset
std::size_t label_offset = size_data_per_element - 8;
os.write(reinterpret_cast<char*>(&label_offset), sizeof(std::size_t));
// offset_data
auto offset_data = static_cast<std::size_t>(index_.graph_degree() * 4 + 4);
auto offset_data = static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4);
os.write(reinterpret_cast<char*>(&offset_data), sizeof(std::size_t));
// max_level
int max_level = 1;
Expand Down Expand Up @@ -184,17 +186,17 @@ void serialize_to_hnswlib(raft::resources const& res,
}

auto data_row = host_dataset.data_handle() + (index_.dim() * i);
if constexpr (std::is_same_v<T, float>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = host_dataset(i, j);
os.write(reinterpret_cast<char*>(&data_elem), sizeof(T));
}
} else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = static_cast<int>(host_dataset(i, j));
os.write(reinterpret_cast<char*>(&data_elem), sizeof(int));
}
// if constexpr (std::is_same_v<T, float>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = host_dataset(i, j);
os.write(reinterpret_cast<char*>(&data_elem), sizeof(T));
}
// } else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
// for (std::size_t j = 0; j < index_.dim(); ++j) {
// auto data_elem = static_cast<int>(host_dataset(i, j));
// os.write(reinterpret_cast<char*>(&data_elem), sizeof(int));
// }
// }

os.write(reinterpret_cast<char*>(&i), sizeof(std::size_t));
}
Expand All @@ -204,13 +206,12 @@ void serialize_to_hnswlib(raft::resources const& res,
auto zero = 0;
os.write(reinterpret_cast<char*>(&zero), sizeof(int));
}
// delete [] host_graph;
}

template <typename T, typename IdxT>
void serialize_to_hnswlib(raft::resources const& res,
const std::string& filename,
const index<T, IdxT>& index_)
const raft::neighbors::cagra::index<T, IdxT>& index_)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }
Expand Down
82 changes: 82 additions & 0 deletions cpp/include/raft/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "hnsw_types.hpp"

#include <cstdint>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

#include <omp.h>

#include <hnswlib.h>

namespace raft::neighbors::hnsw::detail {

template <typename T>
void get_search_knn_results(hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* idx,
const T* query,
int k,
uint64_t* indices,
float* distances)
{
auto result = idx->searchKnn(query, k);
assert(result.size() >= static_cast<size_t>(k));

for (int i = k - 1; i >= 0; --i) {
indices[i] = result.top().second;
distances[i] = result.top().first;
result.pop();
}
}

template <typename T>
void search(raft::resources const& res,
const search_params& params,
const index<T>& idx,
raft::host_matrix_view<const T, int64_t, row_major> queries,
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<float, int64_t, row_major> distances)
{
auto const* hnswlib_index =
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
idx.get_index());

// when num_threads == 0, automatically maximize parallelism
if (params.num_threads) {
#pragma omp parallel for num_threads(params.num_threads)
for (int64_t i = 0; i < queries.extent(0); ++i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
}
} else {
#pragma omp parallel for
for (int64_t i = 0; i < queries.extent(0); ++i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
}
}
}

} // namespace raft::neighbors::hnsw::detail
Loading
Loading