From 2822532f06a5879c2b92d338b6efb75a1b984dc9 Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Thu, 25 Jan 2024 21:53:47 -0500 Subject: [PATCH] `libraft` and `pylibraft` API for CAGRA build and HNSW search (#2022) Closes #1772 Authors: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2022 --- .pre-commit-config.yaml | 2 +- cpp/CMakeLists.txt | 9 + cpp/bench/ann/CMakeLists.txt | 16 +- cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h | 9 +- cpp/cmake/patches/hnswlib.diff | 57 ++ cpp/cmake/thirdparty/get_hnswlib.cmake | 6 +- .../raft/neighbors/cagra_serialize.cuh | 34 +- .../detail/cagra/cagra_serialize.cuh | 37 +- cpp/include/raft/neighbors/detail/hnsw.hpp | 82 +++ .../raft/neighbors/detail/hnsw_serialize.hpp | 46 ++ .../raft/neighbors/detail/hnsw_types.hpp | 101 ++++ cpp/include/raft/neighbors/hnsw.hpp | 142 +++++ cpp/include/raft/neighbors/hnsw_serialize.hpp | 71 +++ cpp/include/raft/neighbors/hnsw_types.hpp | 70 +++ cpp/include/raft_runtime/neighbors/cagra.hpp | 91 ++-- cpp/include/raft_runtime/neighbors/hnsw.hpp | 52 ++ .../raft_runtime/neighbors/cagra_serialize.cu | 81 +-- cpp/src/raft_runtime/neighbors/hnsw.cpp | 73 +++ docs/source/cpp_api/neighbors_hnsw.rst | 29 ++ docs/source/pylibraft_api/neighbors.rst | 16 + pyproject.toml | 3 + python/pylibraft/pylibraft/common/mdspan.pxd | 5 +- python/pylibraft/pylibraft/common/mdspan.pyx | 13 +- .../pylibraft/neighbors/CMakeLists.txt | 2 +- .../pylibraft/pylibraft/neighbors/__init__.py | 8 +- .../pylibraft/neighbors/cagra/cagra.pxd | 39 ++ .../pylibraft/neighbors/cagra/cagra.pyx | 5 - .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 32 +- .../pylibraft/neighbors/cpp/hnsw.pxd | 94 ++++ python/pylibraft/pylibraft/neighbors/hnsw.pyx | 488 ++++++++++++++++++ python/pylibraft/pylibraft/test/__init__py | 0 
python/pylibraft/pylibraft/test/ann_utils.py | 35 ++ python/pylibraft/pylibraft/test/test_cagra.py | 24 +- python/pylibraft/pylibraft/test/test_hnsw.py | 77 +++ 34 files changed, 1691 insertions(+), 158 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/hnsw.hpp create mode 100644 cpp/include/raft/neighbors/detail/hnsw_serialize.hpp create mode 100644 cpp/include/raft/neighbors/detail/hnsw_types.hpp create mode 100644 cpp/include/raft/neighbors/hnsw.hpp create mode 100644 cpp/include/raft/neighbors/hnsw_serialize.hpp create mode 100644 cpp/include/raft/neighbors/hnsw_types.hpp create mode 100644 cpp/include/raft_runtime/neighbors/hnsw.hpp create mode 100644 cpp/src/raft_runtime/neighbors/hnsw.cpp create mode 100644 docs/source/cpp_api/neighbors_hnsw.rst create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd create mode 100644 python/pylibraft/pylibraft/neighbors/hnsw.pyx create mode 100644 python/pylibraft/pylibraft/test/__init__py create mode 100644 python/pylibraft/pylibraft/test/ann_utils.py create mode 100644 python/pylibraft/pylibraft/test/test_hnsw.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c2e6d9fce4..0d6ab7ee54 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: types_or: [python, cython] additional_dependencies: ["flake8-force"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v0.971' + rev: 'v1.3.0' hooks: - id: mypy additional_dependencies: [types-cachetools] diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 517e6d3f49..27d25a64ee 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON) option(BUILD_TESTS "Build raft unit-tests" ON) option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF) option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF) +option(BUILD_CAGRA_HNSWLIB "Build 
CAGRA+hnswlib interface" ON) option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF) option(CUDA_ENABLE_LINEINFO "Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF @@ -195,6 +196,10 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH) rapids_cpm_gbench() endif() +if(BUILD_CAGRA_HNSWLIB) + include(cmake/thirdparty/get_hnswlib.cmake) +endif() + # ################################################################################################## # * raft --------------------------------------------------------------------- add_library(raft INTERFACE) @@ -203,6 +208,9 @@ add_library(raft::raft ALIAS raft) target_include_directories( raft INTERFACE "$" "$" ) +if(BUILD_CAGRA_HNSWLIB) + target_link_libraries(raft INTERFACE hnswlib::hnswlib) +endif() if(NOT BUILD_CPU_ONLY) # Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target. @@ -425,6 +433,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/neighbors/cagra_search.cu src/raft_runtime/neighbors/cagra_serialize.cu src/raft_runtime/neighbors/eps_neighborhood.cu + src/raft_runtime/neighbors/hnsw.cpp src/raft_runtime/neighbors/ivf_flat_build.cu src/raft_runtime/neighbors/ivf_flat_search.cu src/raft_runtime/neighbors/ivf_flat_serialize.cu diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index de980e8945..ee84f7515a 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -225,9 +225,7 @@ endfunction() if(RAFT_ANN_BENCH_USE_HNSWLIB) ConfigureAnnBench( - NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp - LINKS - hnswlib::hnswlib + NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib ) endif() @@ -276,12 +274,7 @@ endif() if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ConfigureAnnBench( - NAME - RAFT_CAGRA_HNSWLIB - PATH - bench/ann/src/raft/raft_cagra_hnswlib.cu - LINKS - raft::compiled + NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS 
raft::compiled hnswlib::hnswlib ) endif() @@ -336,10 +329,7 @@ endif() if(RAFT_ANN_BENCH_USE_GGNN) include(cmake/thirdparty/get_glog.cmake) - ConfigureAnnBench( - NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu - LINKS glog::glog ggnn::ggnn - ) + ConfigureAnnBench(NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn) endif() # ################################################################################################## diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index 5ddfc58677..08b2f188c5 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -52,6 +52,11 @@ struct hnsw_dist_t { using type = int; }; +template <> +struct hnsw_dist_t { + using type = int; +}; + template class HnswLib : public ANN { public: @@ -135,7 +140,7 @@ void HnswLib::build(const T* dataset, size_t nrow, cudaStream_t) space_ = std::make_shared(dim_); } } else if constexpr (std::is_same_v) { - space_ = std::make_shared(dim_); + space_ = std::make_shared>(dim_); } appr_alg_ = std::make_shared::type>>( @@ -205,7 +210,7 @@ void HnswLib::load(const std::string& path_to_index) space_ = std::make_shared(dim_); } } else if constexpr (std::is_same_v) { - space_ = std::make_shared(dim_); + space_ = std::make_shared>(dim_); } appr_alg_ = std::make_shared::type>>( diff --git a/cpp/cmake/patches/hnswlib.diff b/cpp/cmake/patches/hnswlib.diff index 0007ed6425..e7f89a8cc9 100644 --- a/cpp/cmake/patches/hnswlib.diff +++ b/cpp/cmake/patches/hnswlib.diff @@ -105,6 +105,63 @@ } } } +diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h +index 4413537..c3240f3 100644 +--- a/hnswlib/space_l2.h ++++ b/hnswlib/space_l2.h +@@ -252,13 +252,14 @@ namespace hnswlib { + ~L2Space() {} + }; + ++ template + static int + L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + + size_t qty = *((size_t *) qty_ptr); + int res = 0; 
+- unsigned char *a = (unsigned char *) pVect1; +- unsigned char *b = (unsigned char *) pVect2; ++ T *a = (T *) pVect1; ++ T *b = (T *) pVect2; + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { +@@ -279,11 +280,12 @@ namespace hnswlib { + return (res); + } + ++ template + static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + int res = 0; +- unsigned char* a = (unsigned char*)pVect1; +- unsigned char* b = (unsigned char*)pVect2; ++ T* a = (T*)pVect1; ++ T* b = (T*)pVect2; + + for(size_t i = 0; i < qty; i++) + { +@@ -294,6 +296,7 @@ namespace hnswlib { + return (res); + } + ++ template + class L2SpaceI : public SpaceInterface { + + DISTFUNC fstdistfunc_; +@@ -302,10 +305,10 @@ namespace hnswlib { + public: + L2SpaceI(size_t dim) { + if(dim % 4 == 0) { +- fstdistfunc_ = L2SqrI4x; ++ fstdistfunc_ = L2SqrI4x; + } + else { +- fstdistfunc_ = L2SqrI; ++ fstdistfunc_ = L2SqrI; + } + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h index 5e1a4a5..4195ebd 100644 --- a/hnswlib/visited_list_pool.h diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index 82e95803f3..f4fe777379 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -30,6 +30,8 @@ function(find_and_configure_hnswlib) rapids_cpm_find( hnswlib ${PKG_VERSION} GLOBAL_TARGETS hnswlib::hnswlib + BUILD_EXPORT_SET raft-exports + INSTALL_EXPORT_SET raft-exports CPM_ARGS GIT_REPOSITORY ${PKG_REPOSITORY} GIT_TAG ${PKG_PINNED_TAG} @@ -51,11 +53,13 @@ function(find_and_configure_hnswlib) # write export rules rapids_export( BUILD hnswlib + VERSION ${PKG_VERSION} EXPORT_SET hnswlib-exports GLOBAL_TARGETS hnswlib NAMESPACE hnswlib::) rapids_export( INSTALL hnswlib + VERSION ${PKG_VERSION} EXPORT_SET hnswlib-exports GLOBAL_TARGETS hnswlib NAMESPACE hnswlib::) @@ 
-74,5 +78,5 @@ endif() find_and_configure_hnswlib(VERSION 0.6.2 REPOSITORY ${RAFT_HNSWLIB_GIT_REPOSITORY} PINNED_TAG ${RAFT_HNSWLIB_GIT_TAG} - EXCLUDE_FROM_ALL ON + EXCLUDE_FROM_ALL OFF ) diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index c801bc9eda..83830c7457 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,13 +32,14 @@ namespace raft::neighbors::cagra { * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create an output stream * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, os, index); + * // create an index with `auto index = raft::cagra::build(...);` + * raft::cagra::serialize(handle, os, index); * @endcode * * @tparam T data element type @@ -66,13 +67,14 @@ void serialize(raft::resources const& handle, * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, filename, index); + * // create an index with `auto index = raft::cagra::build(...);` + * raft::cagra::serialize(handle, filename, index); * @endcode * * @tparam T data element type @@ -100,13 +102,14 @@ void serialize(raft::resources const& handle, * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create an output stream * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize_to_hnswlib(handle, os, index); + * // create an index with `auto index = 
raft::cagra::build(...);` + * raft::cagra::serialize_to_hnswlib(handle, os, index); * @endcode * * @tparam T data element type @@ -120,25 +123,26 @@ void serialize(raft::resources const& handle, template void serialize_to_hnswlib(raft::resources const& handle, std::ostream& os, - const index& index) + const raft::neighbors::cagra::index& index) { detail::serialize_to_hnswlib(handle, os, index); } /** - * Write the CAGRA built index as a base layer HNSW index to file + * Save a CAGRA build index in hnswlib base-layer-only serialized format * * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize_to_hnswlib(handle, filename, index); + * // create an index with `auto index = raft::cagra::build(...);` + * raft::cagra::serialize_to_hnswlib(handle, filename, index); * @endcode * * @tparam T data element type @@ -152,7 +156,7 @@ void serialize_to_hnswlib(raft::resources const& handle, template void serialize_to_hnswlib(raft::resources const& handle, const std::string& filename, - const index& index) + const raft::neighbors::cagra::index& index) { detail::serialize_to_hnswlib(handle, filename, index); } @@ -164,6 +168,7 @@ void serialize_to_hnswlib(raft::resources const& handle, * * @code{.cpp} * #include + * #include * * raft::resources handle; * @@ -171,7 +176,7 @@ void serialize_to_hnswlib(raft::resources const& handle, * std::istream is(std::cin.rdbuf()); * using T = float; // data element type * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, is); + * auto index = raft::cagra::deserialize(handle, is); * @endcode * * @tparam T data element type @@ -195,6 +200,7 @@ index deserialize(raft::resources const& handle, std::istream& is) * * @code{.cpp} * #include + * #include * * 
raft::resources handle; * @@ -202,7 +208,7 @@ index deserialize(raft::resources const& handle, std::istream& is) * std::string filename("/path/to/index"); * using T = float; // data element type * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, filename); + * auto index = raft::cagra::deserialize(handle, filename); * @endcode * * @tparam T data element type diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 51c9475434..42a979f059 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -101,9 +101,11 @@ void serialize(raft::resources const& res, template void serialize_to_hnswlib(raft::resources const& res, std::ostream& os, - const index& index_) + const raft::neighbors::cagra::index& index_) { - common::nvtx::range fun_scope("cagra::serialize_to_hnswlib"); + // static_assert(std::is_same_v or std::is_same_v, + // "An hnswlib index can only be trained with int32 or uint32 IdxT"); + common::nvtx::range fun_scope("cagra::serialize"); RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", static_cast(index_.size()), index_.dim()); @@ -120,14 +122,14 @@ void serialize_to_hnswlib(raft::resources const& res, // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + // dim * sizeof(data_t) + sizeof(labeltype) - auto size_data_per_element = - static_cast(index_.graph_degree() * 4 + 4 + index_.dim() * 4 + 8); + auto size_data_per_element = static_cast(index_.graph_degree() * 
sizeof(IdxT) + 4 + + index_.dim() * sizeof(T) + 8); os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); // label_offset std::size_t label_offset = size_data_per_element - 8; os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); // offset_data - auto offset_data = static_cast(index_.graph_degree() * 4 + 4); + auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); // max_level int max_level = 1; @@ -184,17 +186,17 @@ void serialize_to_hnswlib(raft::resources const& res, } auto data_row = host_dataset.data_handle() + (index_.dim() * i); - if constexpr (std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = host_dataset(i, j); - os.write(reinterpret_cast(&data_elem), sizeof(T)); - } - } else if constexpr (std::is_same_v or std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = static_cast(host_dataset(i, j)); - os.write(reinterpret_cast(&data_elem), sizeof(int)); - } + // if constexpr (std::is_same_v) { + for (std::size_t j = 0; j < index_.dim(); ++j) { + auto data_elem = host_dataset(i, j); + os.write(reinterpret_cast(&data_elem), sizeof(T)); } + // } else if constexpr (std::is_same_v or std::is_same_v) { + // for (std::size_t j = 0; j < index_.dim(); ++j) { + // auto data_elem = static_cast(host_dataset(i, j)); + // os.write(reinterpret_cast(&data_elem), sizeof(int)); + // } + // } os.write(reinterpret_cast(&i), sizeof(std::size_t)); } @@ -204,13 +206,12 @@ void serialize_to_hnswlib(raft::resources const& res, auto zero = 0; os.write(reinterpret_cast(&zero), sizeof(int)); } - // delete [] host_graph; } template void serialize_to_hnswlib(raft::resources const& res, const std::string& filename, - const index& index_) + const raft::neighbors::cagra::index& index_) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } 
diff --git a/cpp/include/raft/neighbors/detail/hnsw.hpp b/cpp/include/raft/neighbors/detail/hnsw.hpp new file mode 100644 index 0000000000..69478205a9 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/hnsw.hpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "hnsw_types.hpp" + +#include +#include +#include + +#include + +#include + +namespace raft::neighbors::hnsw::detail { + +template +void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, + const T* query, + int k, + uint64_t* indices, + float* distances) +{ + auto result = idx->searchKnn(query, k); + assert(result.size() >= static_cast(k)); + + for (int i = k - 1; i >= 0; --i) { + indices[i] = result.top().second; + distances[i] = result.top().first; + result.pop(); + } +} + +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + auto const* hnswlib_index = + reinterpret_cast::type> const*>( + idx.get_index()); + + // when num_threads == 0, automatically maximize parallelism + if (params.num_threads) { +#pragma omp parallel for num_threads(params.num_threads) + for (int64_t i = 0; i < queries.extent(0); ++i) { + get_search_knn_results(hnswlib_index, + queries.data_handle() + i * queries.extent(1), + neighbors.extent(1), + 
neighbors.data_handle() + i * neighbors.extent(1), + distances.data_handle() + i * distances.extent(1)); + } + } else { +#pragma omp parallel for + for (int64_t i = 0; i < queries.extent(0); ++i) { + get_search_knn_results(hnswlib_index, + queries.data_handle() + i * queries.extent(1), + neighbors.extent(1), + neighbors.data_handle() + i * neighbors.extent(1), + distances.data_handle() + i * distances.extent(1)); + } + } +} + +} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp b/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp new file mode 100644 index 0000000000..8103ffc5ab --- /dev/null +++ b/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "../hnsw_types.hpp" +#include "hnsw_types.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::neighbors::hnsw::detail { + +template +std::unique_ptr> deserialize(raft::resources const& handle, + const std::string& filename, + int dim, + raft::distance::DistanceType metric) +{ + return std::unique_ptr>(new index_impl(filename, dim, metric)); +} + +} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/detail/hnsw_types.hpp b/cpp/include/raft/neighbors/detail/hnsw_types.hpp new file mode 100644 index 0000000000..94ade95965 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/hnsw_types.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "../hnsw_types.hpp" +#include +#include +#include + +#include +#include +#include +#include + +namespace raft::neighbors::hnsw::detail { + +/** + * @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib + * @{ + */ + +template +struct hnsw_dist_t { + using type = void; +}; + +template <> +struct hnsw_dist_t { + using type = float; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +template +struct index_impl : index { + public: + /** + * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index + * + * @param[in] filepath path to the index + * @param[in] dim dimensions of the training dataset + * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") + */ + index_impl(std::string filepath, int dim, raft::distance::DistanceType metric) + : index{dim, metric} + { + if constexpr (std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique(dim); + } else if (metric == raft::distance::InnerProduct) { + space_ = std::make_unique(dim); + } + } else if constexpr (std::is_same_v or std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique>(dim); + } + } + + RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used"); + + appr_alg_ = std::make_unique::type>>( + space_.get(), filepath); + + appr_alg_->base_layer_only = true; + } + + /** + @brief Get hnswlib index + */ + auto get_index() const -> void const* override { return appr_alg_.get(); } + + private: + std::unique_ptr::type>> appr_alg_; + std::unique_ptr::type>> space_; +}; + +/**@}*/ + +} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp new file mode 100644 index 0000000000..dceb98c5aa --- /dev/null +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 
2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/hnsw.hpp" +#include "hnsw.hpp" + +#include + +#include +#include +#include +#include + +namespace raft::neighbors::hnsw { + +/** + * @addtogroup hnsw Build CAGRA index and search with hnswlib + * @{ + */ + +/** + * @brief Construct an hnswlib base-layer-only index from a CAGRA index + * NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/cagra_index.bin` + * before reading it as an hnswlib index, then deleting the temporary file. + * 2. 
This function is only offered as a compiled symbol in `libraft.so` + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] cagra_index cagra index + * + * Usage example: + * @code{.cpp} + * // Build a CAGRA index + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * + * // Load CAGRA index as base-layer-only hnswlib index + * auto hnsw_index = hnsw::from_cagra(res, index); + * @endcode + */ +template +std::unique_ptr> from_cagra(raft::resources const& res, + raft::neighbors::cagra::index cagra_index); + +template <> +std::unique_ptr> from_cagra( + raft::resources const& res, raft::neighbors::cagra::index cagra_index); + +template <> +std::unique_ptr> from_cagra( + raft::resources const& res, raft::neighbors::cagra::index cagra_index); + +template <> +std::unique_ptr> from_cagra( + raft::resources const& res, raft::neighbors::cagra::index cagra_index); + +/** + * @brief Search hnswlib base-layer-only index constructed from a CAGRA index + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx hnswlib base-layer-only index constructed from a CAGRA index + * @param[in] queries a host matrix view to a row-major matrix [n_queries, idx.dim()] + * @param[out] neighbors a host matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a host matrix view to the distances to the selected neighbors [n_queries, + * k] + * + * Usage example: + * @code{.cpp} + * // Build a CAGRA index + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * + * // 
Save CAGRA index as base layer only hnswlib index + * cagra::serialize_to_hnswlib(res, "my_index.bin", index); + * + * // Load CAGRA index as base layer only hnswlib index + * // (hnsw::deserialize returns a std::unique_ptr that owns the loaded index) + * auto hnsw_index = hnsw::deserialize(res, "my_index.bin", D, raft::distance::L2Expanded); + * + * // Search K nearest neighbors as an hnswlib index + * // using host threads for concurrency + * hnsw::search_params search_params; + * search_params.ef = 50; // ef >= K + * search_params.num_threads = 10; + * auto neighbors = raft::make_host_matrix(res, n_queries, k); + * auto distances = raft::make_host_matrix(res, n_queries, k); + * hnsw::search(res, search_params, *hnsw_index, queries, neighbors, distances); + * // hnsw_index is released automatically when the std::unique_ptr goes out of scope; + * // no manual delete is required + * @endcode + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + detail::search(res, params, idx, queries, neighbors, distances); +} + +/**@}*/ + +} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft/neighbors/hnsw_serialize.hpp b/cpp/include/raft/neighbors/hnsw_serialize.hpp new file mode 100644 index 0000000000..45819c8fb5 --- /dev/null +++ b/cpp/include/raft/neighbors/hnsw_serialize.hpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/hnsw_serialize.hpp" +#include "hnsw_types.hpp" +#include + +#include + +namespace raft::neighbors::hnsw { + +/** + * @defgroup hnsw_serialize HNSW Serialize + * @{ + */ + +/** + * Load an hnswlib index which was serialized from a CAGRA index + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // specify the dimensionality and the metric the index was built with + * int dim = 10; + * raft::distance::DistanceType metric = raft::distance::L2Expanded; + * auto index = raft::neighbors::hnsw::deserialize<float>(handle, filename, dim, metric); + * @endcode + * + * @tparam T data element type + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file to load the index from + * @param[in] dim dimensionality of the index + * @param[in] metric metric used to build the index + * + * @return std::unique_ptr> + * + */ +template +std::unique_ptr> deserialize(raft::resources const& handle, + const std::string& filename, + int dim, + raft::distance::DistanceType metric) +{ + return detail::deserialize(handle, filename, dim, metric); +} + +/**@}*/ + +} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft/neighbors/hnsw_types.hpp b/cpp/include/raft/neighbors/hnsw_types.hpp new file mode 100644 index 0000000000..aa4cefbc30 --- /dev/null 
+++ b/cpp/include/raft/neighbors/hnsw_types.hpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include +#include + +#include +#include +#include + +namespace raft::neighbors::hnsw { + +/** + * @defgroup hnsw Build CAGRA index and search with hnswlib + * @{ + */ + +struct search_params : ann::search_params { + int ef; // size of the candidate list + int num_threads = 0; // number of host threads to use for concurrent searches. Value of 0 + // automatically maximizes parallelism +}; + +template +struct index : ann::index { + public: + /** + * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index. + * This is a virtual class and it cannot be used directly. To create an index, use the factory + * function `raft::neighbors::hnsw::from_cagra` from the header + * `raft/neighbors/hnsw.hpp` + * + * @param[in] dim dimensions of the training dataset + * @param[in] metric distance metric to search. 
Supported metrics ("L2Expanded", "InnerProduct") + */ + index(int dim, raft::distance::DistanceType metric) : dim_{dim}, metric_{metric} {} + + /** + @brief Get underlying index + */ + virtual auto get_index() const -> void const* = 0; + + auto dim() const -> int const { return dim_; } + + auto metric() const -> raft::distance::DistanceType { return metric_; } + + private: + int dim_; + raft::distance::DistanceType metric_; +}; + +/**@}*/ + +} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index c54ed32b77..8389929b15 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,48 +27,53 @@ namespace raft::runtime::neighbors::cagra { // Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors -#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - void build_device(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void build_host(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void search(raft::resources const& handle, \ - 
raft::neighbors::cagra::search_params const& params, \ - const raft::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index); \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ +#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + void build_device(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void build_host(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra::search_params const& params, \ + const raft::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + 
void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index); \ + void serialize_to_hnswlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index); \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ raft::neighbors::cagra::index* index); RAFT_INST_CAGRA_FUNCS(float, uint32_t); diff --git a/cpp/include/raft_runtime/neighbors/hnsw.hpp b/cpp/include/raft_runtime/neighbors/hnsw.hpp new file mode 100644 index 0000000000..e8b932d490 --- /dev/null +++ b/cpp/include/raft_runtime/neighbors/hnsw.hpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::runtime::neighbors::hnsw { + +#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ + std::unique_ptr> from_cagra( \ + raft::resources const& res, raft::neighbors::cagra::index); \ + void search(raft::resources const& handle, \ + raft::neighbors::hnsw::search_params const& params, \ + raft::neighbors::hnsw::index const& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + template \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ + raft::distance::DistanceType metric); \ + template <> \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ + raft::distance::DistanceType metric); + +RAFT_INST_HNSW_FUNCS(float, uint32_t); +RAFT_INST_HNSW_FUNCS(int8_t, uint32_t); +RAFT_INST_HNSW_FUNCS(uint8_t, uint32_t); + +} // namespace raft::runtime::neighbors::hnsw diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index adde8663f4..f386bcce8e 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -26,39 +26,54 @@ namespace raft::runtime::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - 
std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - raft::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, is); \ +#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ + str = os.str(); \ + } \ + \ + void serialize_to_hnswlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index) \ + { \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ + }; \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::cagra::index* index) \ + { \ + 
std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, is); \ } RAFT_INST_CAGRA_SERIALIZE(float); diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cpp b/cpp/src/raft_runtime/neighbors/hnsw.cpp new file mode 100644 index 0000000000..1f9e6b0a0b --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include +#include + +namespace raft::neighbors::hnsw { +#define RAFT_INST_HNSW(T) \ + template <> \ + std::unique_ptr> from_cagra( \ + raft::resources const& res, raft::neighbors::cagra::index cagra_index) \ + { \ + std::string filepath = "/tmp/cagra_index.bin"; \ + raft::runtime::neighbors::cagra::serialize_to_hnswlib(res, filepath, cagra_index); \ + auto hnsw_index = raft::runtime::neighbors::hnsw::deserialize_file( \ + res, filepath, cagra_index.dim(), cagra_index.metric()); \ + std::filesystem::remove(filepath); \ + return hnsw_index; \ + } + +RAFT_INST_HNSW(float); +RAFT_INST_HNSW(int8_t); +RAFT_INST_HNSW(uint8_t); +#undef RAFT_INST_HNSW +} // namespace raft::neighbors::hnsw + +namespace raft::runtime::neighbors::hnsw { + +#define RAFT_INST_HNSW(T) \ + void search(raft::resources const& handle, \ + raft::neighbors::hnsw::search_params const& params, \ + const raft::neighbors::hnsw::index& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + raft::neighbors::hnsw::search(handle, params, index, queries, neighbors, distances); \ + } \ + \ + template <> \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ + raft::distance::DistanceType metric) \ + { \ + return raft::neighbors::hnsw::deserialize(handle, filename, dim, metric); \ + } + +RAFT_INST_HNSW(float); +RAFT_INST_HNSW(int8_t); +RAFT_INST_HNSW(uint8_t); + +#undef RAFT_INST_HNSW + +} // namespace raft::runtime::neighbors::hnsw diff --git a/docs/source/cpp_api/neighbors_hnsw.rst b/docs/source/cpp_api/neighbors_hnsw.rst new file mode 100644 index 0000000000..86f9544c35 --- /dev/null +++ b/docs/source/cpp_api/neighbors_hnsw.rst @@ -0,0 +1,29 @@ +HNSW +===== + +HNSW is a graph-based nearest neighbors implementation for the CPU. 
+This implementation provides the ability to serialize a CAGRA graph and read it as a base-layer-only hnswlib graph. + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::neighbors::hnsw* + +.. doxygengroup:: hnsw + :project: RAFT + :members: + :content-only: + +Serializer Methods +------------------ +``#include `` + +namespace *raft::neighbors::hnsw* + +.. doxygengroup:: hnsw_serialize + :project: RAFT + :members: + :content-only: diff --git a/docs/source/pylibraft_api/neighbors.rst b/docs/source/pylibraft_api/neighbors.rst index 680a2982cb..e9e890fccb 100644 --- a/docs/source/pylibraft_api/neighbors.rst +++ b/docs/source/pylibraft_api/neighbors.rst @@ -33,6 +33,22 @@ Serializer Methods .. autofunction:: pylibraft.neighbors.cagra.load +HNSW +#### + +.. autoclass:: pylibraft.neighbors.hnsw.SearchParams + :members: + +.. autofunction:: pylibraft.neighbors.hnsw.from_cagra + +.. autofunction:: pylibraft.neighbors.hnsw.search + +Serializer Methods +------------------ +.. autofunction:: pylibraft.neighbors.hnsw.save + +.. autofunction:: pylibraft.neighbors.hnsw.load + IVF-Flat ######## diff --git a/pyproject.toml b/pyproject.toml index 2982db2a23..1e4ba0b369 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ ignore_missing_imports = true # If we don't specify this, then mypy will check excluded files if # they are imported by a checked file. 
follow_imports = "skip" +exclude = [ + "pylibraft/pylibraft/test", + ] [tool.codespell] # note: pre-commit passes explicit lists of files here, which this skip file list doesn't override - diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 2c488ef427..9da3957f03 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -19,7 +19,7 @@ # cython: embedsignature = True # cython: language_level = 3 -from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t from libcpp cimport bool from libcpp.string cimport string @@ -83,6 +83,9 @@ cdef host_matrix_view[int64_t, int64_t, row_major] get_hmv_int64( cdef host_matrix_view[uint32_t, int64_t, row_major] get_hmv_uint32( array, check_shape) except * +cdef host_matrix_view[uint64_t, int64_t, row_major] get_hmv_uint64( + array, check_shape) except * + cdef host_matrix_view[const_float, int64_t, row_major] get_const_hmv_float( array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 9a994e2ec9..c1a9188585 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -303,7 +303,7 @@ cdef host_matrix_view[int64_t, int64_t, row_major] \ cdef host_matrix_view[uint32_t, int64_t, row_major] \ get_hmv_uint32(cai, check_shape) except *: - if cai.dtype != np.int64: + if cai.dtype != np.uint32: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) @@ -312,6 +312,17 @@ cdef host_matrix_view[uint32_t, int64_t, row_major] \ cai.data, shape[0], shape[1]) +cdef host_matrix_view[uint64_t, int64_t, row_major] \ + get_hmv_uint64(cai, check_shape) except *: + if cai.dtype != np.uint64: + raise TypeError("dtype %s not supported" % 
cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[uint64_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + cdef host_matrix_view[const_float, int64_t, row_major] \ get_const_hmv_float(cai, check_shape) except *: if cai.dtype != np.float32: diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index e64032408a..069038a0e8 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx brute_force.pyx rbc.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx hnsw.pyx rbc.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index 972058aaee..86612b2fbb 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -13,7 +13,10 @@ # limitations under the License. 
# -from pylibraft.neighbors import brute_force, cagra, ivf_flat, ivf_pq, rbc +from pylibraft.neighbors import brute_force # type: ignore +from pylibraft.neighbors import hnsw # type: ignore +from pylibraft.neighbors import rbc # type: ignore +from pylibraft.neighbors import cagra, ivf_flat, ivf_pq from .refine import refine @@ -23,6 +26,7 @@ "brute_force", "ivf_flat", "ivf_pq", - "rbc", "cagra", + "hnsw", + "rbc", ] diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd new file mode 100644 index 0000000000..98537f8357 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd @@ -0,0 +1,39 @@ +# +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int8_t, uint8_t, uint32_t +from libcpp cimport bool +from libcpp.string cimport string + +cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra + + +cdef class Index: + cdef readonly bool trained + cdef str active_index_type + +cdef class IndexFloat(Index): + cdef c_cagra.index[float, uint32_t] * index + +cdef class IndexInt8(Index): + cdef c_cagra.index[int8_t, uint32_t] * index + +cdef class IndexUint8(Index): + cdef c_cagra.index[uint8_t, uint32_t] * index diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index c19faa826d..df31d2560b 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -142,8 +142,6 @@ cdef class IndexParams: cdef class Index: - cdef readonly bool trained - cdef str active_index_type def __cinit__(self): self.trained = False @@ -151,7 +149,6 @@ cdef class Index: cdef class IndexFloat(Index): - cdef c_cagra.index[float, uint32_t] * index def __cinit__(self, handle=None): if handle is None: @@ -216,7 +213,6 @@ cdef class IndexFloat(Index): cdef class IndexInt8(Index): - cdef c_cagra.index[int8_t, uint32_t] * index def __cinit__(self, handle=None): if handle is None: @@ -281,7 +277,6 @@ cdef class IndexInt8(Index): cdef class IndexUint8(Index): - cdef c_cagra.index[uint8_t, uint32_t] * index def __cinit__(self, handle=None): if handle is None: diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 7e22f274e9..1dffd40186 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -211,6 +211,36 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib( + const device_resources& handle, + string& str, + const index[float, uint32_t]& index) except + + + cdef void serialize_to_hnswlib( + const device_resources& handle, + string& str, + const index[uint8_t, uint32_t]& index) except + + + cdef void serialize_to_hnswlib( + const device_resources& handle, + string& str, + const index[int8_t, uint32_t]& index) except + + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[float, uint32_t]& index) except + + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[uint8_t, uint32_t]& index) except + + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[int8_t, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[uint8_t, uint32_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd new file mode 100644 index 0000000000..75c0c14aad --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd @@ -0,0 +1,94 @@ +# +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t +from libcpp.memory cimport unique_ptr +from libcpp.string cimport string + +from pylibraft.common.cpp.mdspan cimport ( + device_vector_view, + host_matrix_view, + row_major, +) +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType +from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( + ann_index, + ann_search_params, +) + + +cdef extern from "raft/neighbors/hnsw.hpp" \ + namespace "raft::neighbors::hnsw" nogil: + + cpdef cppclass search_params(ann_search_params): + int ef + int num_threads + + cdef cppclass index[T](ann_index): + index(int dim, DistanceType metric) + + int dim() + DistanceType metric() + + +cdef extern from "raft_runtime/neighbors/hnsw.hpp" \ + namespace "raft::runtime::neighbors::hnsw" nogil: + cdef void search( + const device_resources& handle, + const search_params& params, + const index[float]& index, + host_matrix_view[float, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[int8_t]& index, + host_matrix_view[int8_t, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + + + cdef 
void search( + const device_resources& handle, + const search_params& params, + const index[uint8_t]& index, + host_matrix_view[uint8_t, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + + + cdef unique_ptr[index[float]] deserialize_file[float]( + const device_resources& handle, + const string& filename, + int dim, + DistanceType metric) except + + + cdef unique_ptr[index[int8_t]] deserialize_file[int8_t]( + const device_resources& handle, + const string& filename, + int dim, + DistanceType metric) except + + + cdef unique_ptr[index[uint8_t]] deserialize_file[uint8_t]( + const device_resources& handle, + const string& filename, + int dim, + DistanceType metric) except + diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx new file mode 100644 index 0000000000..aa589ffb65 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -0,0 +1,488 @@ +# +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, uint8_t, uint32_t +from libcpp cimport bool +from libcpp.memory cimport unique_ptr +from libcpp.string cimport string + +cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra +from pylibraft.distance.distance_type cimport DistanceType +from pylibraft.neighbors.cagra.cagra cimport ( + Index, + IndexFloat, + IndexInt8, + IndexUint8, +) + +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + +from pylibraft.common import DeviceResources, ai_wrapper, auto_convert_output + +cimport pylibraft.neighbors.cpp.hnsw as c_hnsw + +from pylibraft.neighbors.common import _check_input_array, _get_metric + +from pylibraft.common.mdspan cimport ( + get_hmv_float, + get_hmv_int8, + get_hmv_uint8, + get_hmv_uint64, +) +from pylibraft.neighbors.common cimport _get_metric_string + +import os + +import numpy as np + + +cdef class HnswIndex: + cdef readonly bool trained + cdef str active_index_type + + def __cinit__(self): + self.trained = False + self.active_index_type = None + +cdef class HnswIndexFloat(HnswIndex): + cdef unique_ptr[c_hnsw.index[float]] index + + def __cinit__(self): + pass + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.metric) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=hnsw, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index.get()[0].dim() + + @property + def metric(self): + return self.index.get()[0].metric() + +cdef class HnswIndexInt8(HnswIndex): + cdef unique_ptr[c_hnsw.index[int8_t]] index + + def __cinit__(self): + pass + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.metric) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in 
["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=hnsw, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index.get()[0].dim() + + @property + def metric(self): + return self.index.get()[0].metric() + +cdef class HnswIndexUint8(HnswIndex): + cdef unique_ptr[c_hnsw.index[uint8_t]] index + + def __cinit__(self): + pass + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.metric) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=hnsw, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index.get()[0].dim() + + @property + def metric(self): + return self.index.get()[0].metric() + + +@auto_sync_handle +def save(filename, Index index, handle=None): + """ + Saves the CAGRA index as an hnswlib base-layer-only index to a file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained CAGRA index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... 
dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> # Serialize the CAGRA index to hnswlib base layer only index format + >>> hnsw.save("my_index.bin", index, handle=handle) + """ + if not index.trained: + raise ValueError("Index need to be built before saving it.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + cdef c_cagra.index[float, uint32_t] * c_index_float + cdef c_cagra.index[int8_t, uint32_t] * c_index_int8 + cdef c_cagra.index[uint8_t, uint32_t] * c_index_uint8 + + if index.active_index_type == "float32": + idx_float = index + c_index_float = \ + idx_float.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_float)) + elif index.active_index_type == "byte": + idx_int8 = index + c_index_int8 = \ + idx_int8.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_int8)) + elif index.active_index_type == "ubyte": + idx_uint8 = index + c_index_uint8 = \ + idx_uint8.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_uint8)) + else: + raise ValueError( + "Index dtype %s not supported" % index.active_index_type) + + +@auto_sync_handle +def load(filename, dim, dtype, metric="sqeuclidean", handle=None): + """ + Loads base-layer-only hnswlib index from file, which was originally + saved as a built CAGRA index. + + Saving / loading the index is experimental. The serialization format is + subject to change, therefore loading an index saved with a previous + version of raft is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. 
+ dim : int + Dimensions of the training dataset + dtype : np.dtype of the saved index + Valid values for dtype: [np.float32, np.byte, np.ubyte] + metric : string denoting the metric type, default="sqeuclidean" + Valid values for metric: ["sqeuclidean", "inner_product"], where + - sqeuclidean is the euclidean distance without the square root + operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, + - inner product distance is defined as + distance(a, b) = \\sum_i a_i * b_i. + {handle_docstring} + + Returns + ------- + index : HnswIndex + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> # Serialize the CAGRA index to hnswlib base layer only index format + >>> hnsw.save("my_index.bin", index, handle=handle) + >>> index = hnsw.load("my_index.bin", n_features, np.float32, + ... 
"sqeuclidean") + """ + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + cdef HnswIndexFloat idx_float + cdef HnswIndexInt8 idx_int8 + cdef HnswIndexUint8 idx_uint8 + + cdef DistanceType c_metric = _get_metric(metric) + + if dtype == np.float32: + idx_float = HnswIndexFloat() + idx_float.index = c_hnsw.deserialize_file[float]( + deref(handle_), c_filename, dim, c_metric) + idx_float.trained = True + idx_float.active_index_type = 'float32' + return idx_float + elif dtype == np.byte: + idx_int8 = HnswIndexInt8(dim, metric) + idx_int8.index = c_hnsw.deserialize_file[int8_t]( + deref(handle_), c_filename, dim, c_metric) + idx_int8.trained = True + idx_int8.active_index_type = 'byte' + return idx_int8 + elif dtype == np.ubyte: + idx_uint8 = HnswIndexUint8(dim, metric) + idx_uint8.index = c_hnsw.deserialize_file[uint8_t]( + deref(handle_), c_filename, dim, c_metric) + idx_uint8.trained = True + idx_uint8.active_index_type = 'ubyte' + return idx_uint8 + else: + raise ValueError("Dataset dtype %s not supported" % dtype) + + +@auto_sync_handle +def from_cagra(Index index, handle=None): + """ + Returns an hnswlib base-layer-only index from a CAGRA index. + + NOTE: This method uses the filesystem to write the CAGRA index in + `/tmp/cagra_index.bin` before reading it as an hnswlib index, + then deleting the temporary file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + index : Index + Trained CAGRA index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... 
dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> # Serialize the CAGRA index to hnswlib base layer only index format + >>> hnsw_index = hnsw.from_cagra(index, handle=handle) + """ + filename = "/tmp/cagra_index.bin" + save(filename, index, handle=handle) + hnsw_index = load(filename, index.dim, np.dtype(index.active_index_type), + _get_metric_string(index.metric), handle=handle) + os.remove(filename) + return hnsw_index + + +cdef class SearchParams: + """ + Hnswlib search parameters + + Parameters + ---------- + ef: int, default=200 + Size of list from which final neighbors k will be selected. + ef should be greater than or equal to k. + num_threads: int, default=1 + Number of host threads to use to search the hnswlib index + and increase concurrency + """ + cdef c_hnsw.search_params params + + def __init__(self, ef=200, num_threads=1): + self.params.ef = ef + self.params.num_threads = num_threads + + def __repr__(self): + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in [ + "ef", "num_threads"]] + return "SearchParams(type=hnsw, " + ( + ", ".join(attr_str)) + ")" + + @property + def ef(self): + return self.params.ef + + @property + def num_threads(self): + return self.params.num_threads + + +@auto_sync_handle +@auto_convert_output +def search(SearchParams search_params, + HnswIndex index, + queries, + k, + neighbors=None, + distances=None, + handle=None): + """ + Find the k nearest neighbors for each query. + + Parameters + ---------- + search_params : SearchParams + index : HnswIndex + Trained CAGRA index saved as base-layer-only hnswlib index. + queries : array interface compliant matrix shape (n_queries, dim) + Supported dtype [float, int8, uint8] + k : int + The number of neighbors. + neighbors : Optional array interface compliant matrix shape + (n_queries, k), dtype uint64. If supplied, neighbor + indices will be written here in-place.
(default None) + distances : Optional array interface compliant matrix shape + (n_queries, k) If supplied, the distances to the + neighbors will be written here in-place. (default None) + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> import numpy as np + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> + >>> # Load saved base-layer-only hnswlib index from CAGRA index + >>> hnsw_index = hnsw.from_cagra(index, handle=handle) + >>> + >>> # Search hnswlib using the loaded index + >>> queries = np.random.random_sample((n_queries, n_features)).astype( + ... np.float32) + >>> k = 10 + >>> search_params = hnsw.SearchParams( + ... ef=20, + ... num_threads=5 + ... ) + >>> distances, neighbors = hnsw.search(search_params, hnsw_index, + ... 
queries, k, handle=handle) + """ + + if not index.trained: + raise ValueError("Index need to be built before calling search.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + queries_ai = ai_wrapper(queries) + queries_dt = queries_ai.dtype + cdef uint32_t n_queries = queries_ai.shape[0] + + _check_input_array(queries_ai, [np.dtype('float32'), np.dtype('byte'), + np.dtype('ubyte')], + exp_cols=index.dim) + + if neighbors is None: + neighbors = np.empty((n_queries, k), dtype='uint64') + + neighbors_ai = ai_wrapper(neighbors) + _check_input_array(neighbors_ai, [np.dtype('uint64')], + exp_rows=n_queries, exp_cols=k) + + if distances is None: + distances = np.empty((n_queries, k), dtype='float32') + + distances_ai = ai_wrapper(distances) + _check_input_array(distances_ai, [np.dtype('float32')], + exp_rows=n_queries, exp_cols=k) + + cdef c_hnsw.search_params params = search_params.params + cdef HnswIndexFloat idx_float + cdef HnswIndexInt8 idx_int8 + cdef HnswIndexUint8 idx_uint8 + + if queries_dt == np.float32: + idx_float = index + c_hnsw.search(deref(handle_), + params, + deref(idx_float.index), + get_hmv_float(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + elif queries_dt == np.byte: + idx_int8 = index + c_hnsw.search(deref(handle_), + params, + deref(idx_int8.index), + get_hmv_int8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + elif queries_dt == np.ubyte: + idx_uint8 = index + c_hnsw.search(deref(handle_), + params, + deref(idx_uint8.index), + get_hmv_uint8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + else: + raise ValueError("query dtype %s not supported" % queries_dt) + + return (distances, neighbors) diff --git 
a/python/pylibraft/pylibraft/test/__init__py b/python/pylibraft/pylibraft/test/__init__py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/test/ann_utils.py b/python/pylibraft/pylibraft/test/ann_utils.py new file mode 100644 index 0000000000..60db7f3273 --- /dev/null +++ b/python/pylibraft/pylibraft/test/ann_utils.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +def generate_data(shape, dtype): + if dtype == np.byte: + x = np.random.randint(-127, 128, size=shape, dtype=np.byte) + elif dtype == np.ubyte: + x = np.random.randint(0, 255, size=shape, dtype=np.ubyte) + else: + x = np.random.random_sample(shape).astype(dtype) + + return x + + +def calc_recall(ann_idx, true_nn_idx): + assert ann_idx.shape == true_nn_idx.shape + n = 0 + for i in range(ann_idx.shape[0]): + n += np.intersect1d(ann_idx[i, :], true_nn_idx[i, :]).size + recall = n / ann_idx.size + return recall diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 24126c0c5a..be53b33da3 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,27 +20,7 @@ from pylibraft.common import device_ndarray from pylibraft.neighbors import cagra - - -# todo (dantegd): consolidate helper utils of ann methods -def generate_data(shape, dtype): - if dtype == np.byte: - x = np.random.randint(-127, 128, size=shape, dtype=np.byte) - elif dtype == np.ubyte: - x = np.random.randint(0, 255, size=shape, dtype=np.ubyte) - else: - x = np.random.random_sample(shape).astype(dtype) - - return x - - -def calc_recall(ann_idx, true_nn_idx): - assert ann_idx.shape == true_nn_idx.shape - n = 0 - for i in range(ann_idx.shape[0]): - n += np.intersect1d(ann_idx[i, :], true_nn_idx[i, :]).size - recall = n / ann_idx.size - return recall +from pylibraft.test.ann_utils import calc_recall, generate_data def run_cagra_build_search_test( diff --git a/python/pylibraft/pylibraft/test/test_hnsw.py b/python/pylibraft/pylibraft/test/test_hnsw.py new file mode 100644 index 0000000000..487f190e4e --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_hnsw.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +import numpy as np +import pytest +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize + +from pylibraft.neighbors import cagra, hnsw +from pylibraft.test.ann_utils import calc_recall, generate_data + + +def run_hnsw_build_search_test( + n_rows=10000, + n_cols=10, + n_queries=100, + k=10, + dtype=np.float32, + metric="sqeuclidean", + intermediate_graph_degree=128, + graph_degree=64, + search_params={}, +): + dataset = generate_data((n_rows, n_cols), dtype) + if metric == "inner_product": + dataset = normalize(dataset, norm="l2", axis=1) + + build_params = cagra.IndexParams( + metric=metric, + intermediate_graph_degree=intermediate_graph_degree, + graph_degree=graph_degree, + ) + + index = cagra.build(build_params, dataset) + + assert index.trained + + hnsw_index = hnsw.from_cagra(index) + + queries = generate_data((n_queries, n_cols), dtype) + out_idx = np.zeros((n_queries, k), dtype=np.uint32) + + search_params = hnsw.SearchParams(**search_params) + + out_dist, out_idx = hnsw.search(search_params, hnsw_index, queries, k) + + # Calculate reference values with sklearn + nn_skl = NearestNeighbors(n_neighbors=k, algorithm="brute", metric=metric) + nn_skl.fit(dataset) + skl_idx = nn_skl.kneighbors(queries, return_distance=False) + + recall = calc_recall(out_idx, skl_idx) + assert recall > 0.95 + + +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("k", [10, 20]) +@pytest.mark.parametrize("ef", [30, 40]) +@pytest.mark.parametrize("num_threads", [2, 4]) +def test_hnsw(dtype, k, ef, num_threads): + # Note that inner_product tests use normalized input which we cannot + # represent in int8, therefore we test only sqeuclidean metric here. + run_hnsw_build_search_test( + dtype=dtype, k=k, search_params={"ef": ef, "num_threads": num_threads} + )