Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

libraft and pylibraft API for CAGRA build and HNSW search #2022

Merged
merged 42 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
c60ae05
hnswlib serialize python API
divyegala Nov 23, 2023
7f04f97
fix typo, refactor cython
divyegala Nov 23, 2023
4b2b2c6
add cpp index and python load,search methods with runtime API
divyegala Nov 28, 2023
c62c8e3
Merge remote-tracking branch 'upstream/branch-24.02' into cagra_hnswl…
divyegala Nov 28, 2023
92e5717
add static_assert guard for (u)int hnswlib serialize
divyegala Nov 28, 2023
7f78851
passing float tests
divyegala Nov 28, 2023
29f0774
fix docs
divyegala Nov 28, 2023
e2e1fc3
try to write in native dtype
divyegala Nov 28, 2023
b8a4c50
update mypy, solve error
divyegala Nov 28, 2023
e02a0e3
attempt to use rapids_cpm_find for hnswlib
divyegala Nov 29, 2023
54106a1
rework to not expose hnswlib headers in runtime API
divyegala Dec 14, 2023
ebdce10
readd copyright check
divyegala Dec 14, 2023
c29313a
Merge remote-tracking branch 'upstream/branch-24.02' into cagra_hnswl…
divyegala Dec 14, 2023
1da21e4
print binary dir
divyegala Dec 14, 2023
e5cd5f6
address review
divyegala Dec 14, 2023
460bc98
address review, enable int8 in hnswlib
divyegala Dec 15, 2023
ce1108a
Merge remote-tracking branch 'upstream/branch-24.02' into cagra_hnswl…
divyegala Dec 15, 2023
fdc015f
missed template
divyegala Dec 15, 2023
211ba39
fix docs
divyegala Dec 15, 2023
466e141
move serialize back to cagra::
divyegala Dec 18, 2023
d3fad16
use unique_ptr for deserialization
divyegala Dec 18, 2023
e4e635f
add composite function from_cagra
divyegala Dec 18, 2023
ad0a25f
fix ann-bench compiler error
divyegala Dec 18, 2023
397176e
update docs from review
divyegala Dec 18, 2023
114d176
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
cjnolet Jan 6, 2024
691d254
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
vyasr Jan 19, 2024
c30671b
Fix some style
vyasr Jan 19, 2024
44c88e7
Get wheel builds to compile
vyasr Jan 19, 2024
304ae7a
fix docs
divyegala Jan 19, 2024
d10bb4f
Export hnswlib dependency in raft
vyasr Jan 19, 2024
1d8e804
fix doc again
divyegala Jan 19, 2024
05bc41b
Merge remote-tracking branch 'origin/cagra_hnswlib_pylibraft' into ca…
divyegala Jan 22, 2024
306de24
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
divyegala Jan 22, 2024
ae543e7
Specify the hnswlib version in the export
vyasr Jan 23, 2024
ad05438
more doc fixes
divyegala Jan 23, 2024
d8a3e06
Merge branch 'cagra_hnswlib_pylibraft' of github.com:divyegala/raft i…
divyegala Jan 23, 2024
94d08cd
hopefully final doc fix
divyegala Jan 23, 2024
54bf32c
style fix
divyegala Jan 23, 2024
8e8549b
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
divyegala Jan 23, 2024
aa16015
merging upstream
divyegala Jan 26, 2024
e62d3bc
Merge remote-tracking branch 'origin/cagra_hnswlib_pylibraft' into ca…
divyegala Jan 26, 2024
1897cf7
pre commit fixes
divyegala Jan 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
types_or: [python, cython]
additional_dependencies: ["flake8-force"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.971'
rev: 'v1.3.0'
hooks:
- id: mypy
additional_dependencies: [types-cachetools]
Expand Down
8 changes: 7 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON)
option(BUILD_TESTS "Build raft unit-tests" ON)
option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF)
option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF)
option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON)
option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF)
option(CUDA_ENABLE_LINEINFO
"Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF
Expand Down Expand Up @@ -195,18 +196,23 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH)
rapids_cpm_gbench()
endif()

if (BUILD_CAGRA_HNSWLIB)
include(cmake/thirdparty/get_hnswlib.cmake)
endif()

# ##################################################################################################
# * raft ---------------------------------------------------------------------
add_library(raft INTERFACE)
add_library(raft::raft ALIAS raft)

target_include_directories(
raft INTERFACE "$<BUILD_INTERFACE:${RAFT_SOURCE_DIR}/include>" "$<INSTALL_INTERFACE:include>"
# "$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>"
divyegala marked this conversation as resolved.
Show resolved Hide resolved
)

if(NOT BUILD_CPU_ONLY)
# Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target.
target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass raft::Thrust)
target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass raft::Thrust "$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:hnswlib>")
endif()

target_compile_features(raft INTERFACE cxx_std_17 $<BUILD_INTERFACE:cuda_std_17>)
Expand Down
30 changes: 19 additions & 11 deletions cpp/cmake/thirdparty/get_hnswlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,27 @@ function(find_and_configure_hnswlib)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )

set ( EXTERNAL_INCLUDES_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} )
if( NOT EXISTS ${EXTERNAL_INCLUDES_DIRECTORY}/_deps/hnswlib-src )
# set ( EXTERNAL_INCLUDES_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} )
# if( NOT EXISTS ${EXTERNAL_INCLUDES_DIRECTORY}/_deps/hnswlib-src )

execute_process (
COMMAND git clone --branch=v0.6.2 https://github.com/nmslib/hnswlib.git hnswlib-src
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps )
# execute_process (
# COMMAND git clone --branch=v0.6.2 https://github.com/nmslib/hnswlib.git hnswlib-src
# WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps )

message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}")
execute_process (
COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src
)
endif ()
# message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}")
# execute_process (
# COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch
# WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src
# )
# endif ()
rapids_cpm_find(hnswlib ${PKG_VERSION}
GLOBAL_TARGETS hnswlib
BUILD_EXPORT_SET raft-exports
INSTALL_EXPORT_SET raft-exports
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/hnswlib.git
GIT_TAG ${PKG_PINNED_TAG}
EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL})

include(cmake/modules/FindAVX.cmake)

Expand Down
88 changes: 88 additions & 0 deletions cpp/include/raft/neighbors/cagra_hnswlib.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "cagra_hnswlib_types.hpp"
#include "detail/cagra_hnswlib.hpp"

#include <cstddef>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

namespace raft::neighbors::cagra_hnswlib {

divyegala marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Search hnswlib base layer only index constructed from a CAGRA index
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] queries a host matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a host matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a host matrix view to the distances to the selected neighbors [n_queries,
* k]
*
* Usage example:
* @code{.cpp}
* // Build a CAGRA index
* using namespace raft::neighbors;
* // use default index parameters
* cagra::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = cagra::build(res, index_params, dataset);
*
* // Save CAGRA index as base layer only hnswlib index
* cagra::serialize_to_hnswlib(res, "my_index.bin", index);
*
* // Load CAGRA index as base layer only hnswlib index
* cagra_hnswlib::index(D, "my_index.bin", raft::distance::L2Expanded);
*
* // Search K nearest neighbors as an hnswlib index
* // using host threads for concurrency
* cagra_hnswlib::search_params search_params;
* search_params.ef = 50 // ef >= K;
* search_params.num_threads = 10;
* auto neighbors = raft::make_host_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_host_matrix<float>(res, n_queries, k);
* cagra_hnswlib::search(res, search_params, index, queries, neighbors, distances);
* @endcode
*/
template <typename T>
void search(raft::resources const& res,
const search_params& params,
const index<T>& idx,
raft::host_matrix_view<const T, int64_t, row_major> queries,
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<float, int64_t, row_major> distances)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");
RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

detail::search(res, params, idx, queries, neighbors, distances);
}

} // namespace raft::neighbors::cagra_hnswlib
108 changes: 108 additions & 0 deletions cpp/include/raft/neighbors/cagra_hnswlib_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "ann_types.hpp"
#include <memory>
#include <raft/distance/distance_types.hpp>

#include <cstdint>
#include <hnswlib.h>
#include <sys/types.h>
#include <type_traits>

namespace raft::neighbors::cagra_hnswlib {
divyegala marked this conversation as resolved.
Show resolved Hide resolved
divyegala marked this conversation as resolved.
Show resolved Hide resolved

template <typename T>
struct hnsw_dist_t {
using type = void;
};

template <>
struct hnsw_dist_t<float> {
using type = float;
};

template <>
struct hnsw_dist_t<std::uint8_t> {
using type = int;
};

template <>
struct hnsw_dist_t<std::int8_t> {
using type = int;
};

struct search_params : ann::search_params {
int ef; // size of the candidate list
int num_threads = 1; // number of host threads to use for concurrent searches
divyegala marked this conversation as resolved.
Show resolved Hide resolved
};

template <typename T>
struct index : ann::index {
public:
/**
* @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index
*
* @param[in] filepath path to the index
* @param[in] dim dimensions of the training dataset
* @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct")
*/
index(std::string filepath, int dim, raft::distance::DistanceType metric)
: dim_{dim}, metric_{metric}
{
if constexpr (std::is_same_v<T, float>) {
if (metric == raft::distance::L2Expanded) {
space_ = std::make_unique<hnswlib::L2Space>(dim_);
} else if (metric == raft::distance::InnerProduct) {
space_ = std::make_unique<hnswlib::InnerProductSpace>(dim_);
}
} else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
if (metric == raft::distance::L2Expanded) {
space_ = std::make_unique<hnswlib::L2SpaceI>(dim_);
}
}

RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used");

appr_alg_ = std::make_unique<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>(
space_.get(), filepath);

appr_alg_->base_layer_only = true;
}

/**
@brief Get hnswlib index
*/
auto get_index() const -> hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*
{
return appr_alg_.get();
}

auto dim() const -> int const { return dim_; }

auto metric() const -> raft::distance::DistanceType { return metric_; }

private:
int dim_;
raft::distance::DistanceType metric_;

std::unique_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
std::unique_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
};

} // namespace raft::neighbors::cagra_hnswlib
29 changes: 15 additions & 14 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ void serialize_to_hnswlib(raft::resources const& res,
std::ostream& os,
const index<T, IdxT>& index_)
{
// static_assert(std::is_same_v<IdxT, int> or std::is_same_v<IdxT, uint32_t>,
divyegala marked this conversation as resolved.
Show resolved Hide resolved
// "An hnswlib index can only be trained with int32 or uint32 IdxT");
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize_to_hnswlib");
RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u",
static_cast<size_t>(index_.size()),
Expand All @@ -120,14 +122,14 @@ void serialize_to_hnswlib(raft::resources const& res,
// Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t,
// labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) +
// dim * sizeof(data_t) + sizeof(labeltype)
auto size_data_per_element =
static_cast<std::size_t>(index_.graph_degree() * 4 + 4 + index_.dim() * 4 + 8);
auto size_data_per_element = static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 +
index_.dim() * sizeof(T) + 8);
os.write(reinterpret_cast<char*>(&size_data_per_element), sizeof(std::size_t));
// label_offset
std::size_t label_offset = size_data_per_element - 8;
os.write(reinterpret_cast<char*>(&label_offset), sizeof(std::size_t));
// offset_data
auto offset_data = static_cast<std::size_t>(index_.graph_degree() * 4 + 4);
auto offset_data = static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4);
os.write(reinterpret_cast<char*>(&offset_data), sizeof(std::size_t));
// max_level
int max_level = 1;
Expand Down Expand Up @@ -184,17 +186,17 @@ void serialize_to_hnswlib(raft::resources const& res,
}

auto data_row = host_dataset.data_handle() + (index_.dim() * i);
if constexpr (std::is_same_v<T, float>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = host_dataset(i, j);
os.write(reinterpret_cast<char*>(&data_elem), sizeof(T));
}
} else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = static_cast<int>(host_dataset(i, j));
os.write(reinterpret_cast<char*>(&data_elem), sizeof(int));
}
// if constexpr (std::is_same_v<T, float>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = host_dataset(i, j);
os.write(reinterpret_cast<char*>(&data_elem), sizeof(T));
}
// } else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
// for (std::size_t j = 0; j < index_.dim(); ++j) {
// auto data_elem = static_cast<int>(host_dataset(i, j));
// os.write(reinterpret_cast<char*>(&data_elem), sizeof(int));
// }
// }

os.write(reinterpret_cast<char*>(&i), sizeof(std::size_t));
}
Expand All @@ -204,7 +206,6 @@ void serialize_to_hnswlib(raft::resources const& res,
auto zero = 0;
os.write(reinterpret_cast<char*>(&zero), sizeof(int));
}
// delete [] host_graph;
}

template <typename T, typename IdxT>
Expand Down
Loading
Loading