From c60ae057c0bdba39c1a82fac64280619574605b3 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 Nov 2023 02:51:08 +0000 Subject: [PATCH 01/31] hnswlib serialize python API --- cpp/include/raft_runtime/neighbors/cagra.hpp | 92 +++++++++------- .../raft_runtime/neighbors/cagra_serialize.cu | 82 ++++++++------ .../pylibraft/neighbors/CMakeLists.txt | 2 +- .../pylibraft/neighbors/cagra/cagra.pyx | 2 - .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 42 +++++++ .../pylibraft/neighbors/cagra_hnswlib.pyx | 104 ++++++++++++++++++ 6 files changed, 246 insertions(+), 78 deletions(-) create mode 100644 python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index c54ed32b77..80abf1e6c4 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -27,48 +27,56 @@ namespace raft::runtime::neighbors::cagra { // Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors -#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - void build_device(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void build_host(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void search(raft::resources const& handle, \ - raft::neighbors::cagra::search_params const& params, \ - const raft::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index); \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ +#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + void build_device(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void build_host(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra::search_params const& params, \ + const raft::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void serialize_to_hnswlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index); \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ raft::neighbors::cagra::index* index); RAFT_INST_CAGRA_FUNCS(float, uint32_t); diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index 69b48b93a4..4ede7ecc25 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -24,39 +24,55 @@ namespace raft::runtime::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - raft::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, is); \ +#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ + }; \ + \ + void serialize_to_hswnlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index) \ + { \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ + str = os.str(); \ + } \ + \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::cagra::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, is); \ } RAFT_INST_CAGRA_SERIALIZE(float); diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 45cd9f74e6..d5b8a80761 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx brute_force.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx cagra_hnswlib.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index c11d933b27..977bdfaee0 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -142,8 +142,6 @@ cdef class IndexParams: cdef class Index: - cdef readonly bool trained - cdef str active_index_type def __cinit__(self): self.trained = False diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 7e22f274e9..3d61fd2203 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -174,6 +174,10 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ string& str, const index[float, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnwslib( + const device_resources& handle, + string& str, + const index[float, uint32_t]& index) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -184,6 +188,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnwslib( + const device_resources& handle, + string& str, + const index[uint8_t, uint32_t]& index) except + + cdef void deserialize(const device_resources& handle, const string& str, index[uint8_t, uint32_t]* index) except + @@ -193,6 +202,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[int8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnwslib( + const device_resources& handle, + string& str, + const index[int8_t, uint32_t]& index) except + + cdef void deserialize(const device_resources& handle, const string& str, index[int8_t, uint32_t]* index) except + @@ -202,6 +216,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[float, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[float, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[float, uint32_t]* index) except + @@ -211,6 +230,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[uint8_t, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[uint8_t, uint32_t]* index) except + @@ -220,6 +244,24 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[int8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[int8_t, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[int8_t, uint32_t]* index) except + + +cdef class Index: + cdef readonly bool trained + cdef str active_index_type + +cdef class IndexFloat(Index): + pass + +cdef class IndexInt8(Index): + pass + +cdef class IndexUint8(Index): + pass diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx new file mode 100644 index 0000000000..a69d27967d --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx @@ -0,0 +1,104 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, uint8_t, uint32_t +from libcpp.string cimport string + +cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra + +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + +from pylibraft.common import DeviceResources + + +@auto_sync_handle +def save(filename, c_cagra.Index index, handle=None): + """ + Saves the index to a file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained CAGRA index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import cagra_hnswlib + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> # Serialize and deserialize the cagra index built + >>> cagra_hnswlib.save("my_index.bin", index, handle=handle) + """ + if not index.trained: + raise ValueError("Index need to be built before saving it.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + + cdef c_cagra.IndexFloat idx_float + cdef c_cagra.IndexInt8 idx_int8 + cdef c_cagra.IndexUint8 idx_uint8 + + cdef c_cagra.index[float, uint32_t] * c_index_float + cdef c_cagra.index[int8_t, uint32_t] * c_index_int8 + cdef c_cagra.index[uint8_t, uint32_t] * c_index_uint8 + + if index.active_index_type == "float32": + idx_float = index + c_index_float = \ + idx_float.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_float)) + elif index.active_index_type == "byte": + idx_int8 = index + c_index_int8 = \ + idx_int8.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_int8)) + elif index.active_index_type == "ubyte": + idx_uint8 = index + c_index_uint8 = \ + idx_uint8.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_uint8)) + else: + raise ValueError( + "Index dtype %s not supported" % index.active_index_type) From 7f04f973dea3def3aeed3a3b473771010e6d253d Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 Nov 2023 03:48:38 +0000 Subject: [PATCH 02/31] fix typo, refactor cython --- .../raft_runtime/neighbors/cagra_serialize.cu | 2 +- .../pylibraft/pylibraft/neighbors/__init__.py | 9 ++++- .../pylibraft/neighbors/cagra/cagra.pxd | 39 +++++++++++++++++++ .../pylibraft/neighbors/cagra/cagra.pyx | 3 -- .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 13 ------- .../pylibraft/neighbors/cagra_hnswlib.pyx | 14 +++++-- 6 files changed, 58 insertions(+), 22 deletions(-) create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index 4ede7ecc25..c4599b560b 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::cagra { raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ }; \ \ - void serialize_to_hswnlib_file(raft::resources const& handle, \ + void serialize_to_hnswlib_file(raft::resources const& handle, \ const std::string& filename, \ const raft::neighbors::cagra::index& index) \ { \ diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index 325ea5842e..b4cac76019 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -17,4 +17,11 @@ from .refine import refine -__all__ = ["common", "refine", "brute_force", "ivf_flat", "ivf_pq", "cagra"] +__all__ = [ + "common", + "refine", + "brute_force", + "ivf_flat", + "ivf_pq", + "cagra", +] diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd new file mode 100644 index 0000000000..f5f6da93f9 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd @@ -0,0 +1,39 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int8_t, uint8_t, uint32_t +from libcpp cimport bool +from libcpp.string cimport string + +cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra + + +cdef class Index: + cdef readonly bool trained + cdef str active_index_type + +cdef class IndexFloat(Index): + cdef c_cagra.index[float, uint32_t] * index + +cdef class IndexInt8(Index): + cdef c_cagra.index[int8_t, uint32_t] * index + +cdef class IndexUint8(Index): + cdef c_cagra.index[uint8_t, uint32_t] * index diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index 977bdfaee0..6f7d6cc5b2 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -149,7 +149,6 @@ cdef class Index: cdef class IndexFloat(Index): - cdef c_cagra.index[float, uint32_t] * index def __cinit__(self, handle=None): if handle is None: @@ -214,7 +213,6 @@ cdef class IndexFloat(Index): cdef class IndexInt8(Index): - cdef c_cagra.index[int8_t, uint32_t] * index def __cinit__(self, handle=None): if handle is None: @@ -279,7 +277,6 @@ cdef class IndexInt8(Index): cdef class IndexUint8(Index): - cdef c_cagra.index[uint8_t, uint32_t] * index def __cinit__(self, handle=None): if handle is None: diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 3d61fd2203..8cd5cb0b44 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -252,16 +252,3 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void deserialize_file(const device_resources& handle, const string& filename, index[int8_t, uint32_t]* index) except + - -cdef class Index: - cdef readonly bool trained - cdef str active_index_type - -cdef class IndexFloat(Index): - pass - -cdef class IndexInt8(Index): - pass - -cdef class IndexUint8(Index): - pass diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx index a69d27967d..4faeb89ef5 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx @@ -23,6 +23,12 @@ from libc.stdint cimport int8_t, uint8_t, uint32_t from libcpp.string cimport string cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra +from pylibraft.neighbors.cagra.cagra cimport ( + Index, + IndexFloat, + IndexInt8, + IndexUint8, +) from pylibraft.common.handle import auto_sync_handle @@ -32,7 +38,7 @@ from pylibraft.common import DeviceResources @auto_sync_handle -def save(filename, c_cagra.Index index, handle=None): +def save(filename, Index index, handle=None): """ Saves the index to a file. @@ -73,9 +79,9 @@ def save(filename, c_cagra.Index index, handle=None): cdef string c_filename = filename.encode('utf-8') - cdef c_cagra.IndexFloat idx_float - cdef c_cagra.IndexInt8 idx_int8 - cdef c_cagra.IndexUint8 idx_uint8 + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 cdef c_cagra.index[float, uint32_t] * c_index_float cdef c_cagra.index[int8_t, uint32_t] * c_index_int8 From 4b2b2c603fbe2053b43622b8898b146c641fc242 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 28 Nov 2023 03:20:51 +0000 Subject: [PATCH 03/31] add cpp index and python load,search methods with runtime API --- cpp/CMakeLists.txt | 6 + cpp/include/raft/neighbors/cagra_hnswlib.hpp | 90 +++++ .../raft/neighbors/cagra_hnswlib_types.hpp | 108 ++++++ .../detail/cagra/cagra_serialize.cuh | 1 - .../raft/neighbors/detail/cagra_hnswlib.hpp | 186 ++++++++++ .../raft_runtime/neighbors/cagra_hnswlib.hpp | 39 ++ python/pylibraft/pylibraft/common/mdspan.pxd | 5 +- python/pylibraft/pylibraft/common/mdspan.pyx | 13 +- .../pylibraft/neighbors/cagra_hnswlib.pyx | 341 +++++++++++++++++- .../pylibraft/neighbors/cpp/cagra_hnswlib.pxd | 75 ++++ 10 files changed, 858 insertions(+), 6 deletions(-) create mode 100644 cpp/include/raft/neighbors/cagra_hnswlib.hpp create mode 100644 cpp/include/raft/neighbors/cagra_hnswlib_types.hpp create mode 100644 cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp create mode 100644 cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index acb77ec8c7..5b028a2456 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON) option(BUILD_TESTS "Build raft unit-tests" ON) option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF) option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF) +option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON) option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF) option(CUDA_ENABLE_LINEINFO "Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF @@ -195,6 +196,10 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH) rapids_cpm_gbench() endif() +if (BUILD_CAGRA_HNSWLIB) + include(cmake/thirdparty/get_hnswlib.cmake) +endif() + # ################################################################################################## # * raft --------------------------------------------------------------------- add_library(raft INTERFACE) @@ -202,6 +207,7 @@ add_library(raft::raft ALIAS raft) target_include_directories( raft INTERFACE "$" "$" + "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" ) if(NOT BUILD_CPU_ONLY) diff --git a/cpp/include/raft/neighbors/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/cagra_hnswlib.hpp new file mode 100644 index 0000000000..cb02b10b4c --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_hnswlib.hpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cagra_hnswlib_types.hpp" +#include "detail/cagra_hnswlib.hpp" + +#include +#include +#include + +namespace raft::neighbors::cagra_hnswlib { + +/** + * @brief Search hnswlib base layer only index constructed from a CAGRA index + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * + * Usage example: + * @code{.cpp} + * // Build a CAGRA index + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * + * // Save CAGRA index as base layer only hnswlib index + * cagra::serialize_to_hnswlib(res, "my_index.bin", index); + * + * // Load CAGRA index as base layer only hnswlib index + * cagra_hnswlib::index(D, "my_index.bin", raft::distance::L2Expanded); + * + * // Search K nearest neighbors as an hnswlib index + * // using host threads for concurrency + * cagra_hnswlib::search_params search_params; + * search_params.ef = 50 // ef >= K; + * search_params.num_threads = 10; + * auto neighbors = raft::make_host_matrix(res, n_queries, k); + * auto distances = raft::make_host_matrix(res, n_queries, k); + * cagra_hnswlib::search(res, search_params, index, queries, neighbors, distances); + * @endcode + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + detail::search(res, params, idx, queries, neighbors, distances); +} + +} // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp new file mode 100644 index 0000000000..bf1725ea86 --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include +#include + +#include +#include +#include +#include + +namespace raft::neighbors::cagra_hnswlib { + +template +struct hnsw_dist_t { + using type = void; +}; + +template <> +struct hnsw_dist_t { + using type = float; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +struct search_params : ann::search_params { + int ef; // size of the candidate list + int num_threads = 1; // number of host threads to use for concurrent searches +}; + +template +struct index : ann::index { + public: + /** + * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index + * + * @param[in] filepath path to the index + * @param[in] dim dimensions of the training dataset + * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") + */ + index(std::string filepath, int dim, raft::distance::DistanceType metric) + : dim_{dim}, metric_{metric} + { + if constexpr (std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique(dim_); + } else if (metric == raft::distance::InnerProduct) { + space_ = std::make_unique(dim_); + } + } else if constexpr (std::is_same_v or std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique(dim_); + } + } + + RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used"); + + appr_alg_ = std::make_unique::type>>( + space_.get(), filepath); + + appr_alg_->base_layer_only = true; + } + + /** + @brief Get hnswlib index + */ + auto get_index() const -> hnswlib::HierarchicalNSW::type> const* + { + return appr_alg_.get(); + } + + auto dim() const -> int const { return dim_; } + + auto metric() const -> raft::distance::DistanceType { return metric_; } + + private: + int dim_; + raft::distance::DistanceType metric_; + + std::unique_ptr::type>> appr_alg_; + std::unique_ptr::type>> space_; +}; + +} // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 51c9475434..8ed01b062d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -204,7 +204,6 @@ void serialize_to_hnswlib(raft::resources const& res, auto zero = 0; os.write(reinterpret_cast(&zero), sizeof(int)); } - // delete [] host_graph; } template diff --git a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp new file mode 100644 index 0000000000..ba826db0c5 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../cagra_hnswlib_types.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::cagra_hnswlib::detail { + +class FixedThreadPool { + public: + FixedThreadPool(int num_threads) + { + if (num_threads < 1) { + throw std::runtime_error("num_threads must >= 1"); + } else if (num_threads == 1) { + return; + } + + tasks_ = new Task_[num_threads]; + + threads_.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + threads_.emplace_back([&, i] { + auto& task = tasks_[i]; + while (true) { + std::unique_lock lock(task.mtx); + task.cv.wait(lock, + [&] { return task.has_task || finished_.load(std::memory_order_relaxed); }); + if (finished_.load(std::memory_order_relaxed)) { break; } + + task.task(); + task.has_task = false; + } + }); + } + } + + ~FixedThreadPool() + { + if (threads_.empty()) { return; } + + finished_.store(true, std::memory_order_relaxed); + for (unsigned i = 0; i < threads_.size(); ++i) { + auto& task = tasks_[i]; + std::lock_guard(task.mtx); + + task.cv.notify_one(); + threads_[i].join(); + } + + delete[] tasks_; + } + + template + void submit(Func f, IdxT len) + { + // Run functions in main thread if thread pool has no threads + if (threads_.empty()) { + for (IdxT i = 0; i < len; ++i) { + f(i); + } + return; + } + + const int num_threads = threads_.size(); + // one extra part for competition among threads + const IdxT items_per_thread = len / (num_threads + 1); + std::atomic cnt(items_per_thread * num_threads); + + // Wrap function + auto wrapped_f = [&](IdxT start, IdxT end) { + for (IdxT i = start; i < end; ++i) { + f(i); + } + + while (true) { + IdxT i = cnt.fetch_add(1, std::memory_order_relaxed); + if (i >= len) { break; } + f(i); + } + }; + + std::vector> futures; + futures.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + IdxT start = i * items_per_thread; + auto& task = tasks_[i]; + { + std::lock_guard lock(task.mtx); + (void)lock; // stop nvcc warning + task.task = std::packaged_task([=] { wrapped_f(start, start + items_per_thread); }); + futures.push_back(task.task.get_future()); + task.has_task = true; + } + task.cv.notify_one(); + } + + for (auto& fut : futures) { + fut.wait(); + } + return; + } + + private: + struct alignas(64) Task_ { + std::mutex mtx; + std::condition_variable cv; + bool has_task = false; + std::packaged_task task; + }; + + Task_* tasks_; + std::vector threads_; + std::atomic finished_{false}; +}; + +template +void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, + const T* query, + int k, + uint64_t* indices, + float* distances) +{ + auto result = idx->searchKnn(query, k); + assert(result.size() >= static_cast(k)); + + for (int i = k - 1; i >= 0; --i) { + indices[i] = result.top().second; + distances[i] = result.top().first; + result.pop(); + } +} + +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + auto const* hnswlib_index = idx.get_index(); + + // no-op when num_threads == 1, no synchronization overhead + FixedThreadPool thread_pool{params.num_threads}; + + auto f = [&](auto const& i) { + get_search_knn_results(hnswlib_index, + queries.data_handle() + i * queries.extent(1), + neighbors.extent(1), + neighbors.data_handle() + i * neighbors.extent(1), + distances.data_handle() + i * distances.extent(1)); + }; + + thread_pool.submit(f, queries.extent(0)); +} + +} // namespace raft::neighbors::cagra_hnswlib::detail diff --git a/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp b/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp new file mode 100644 index 0000000000..8a9edd03ce --- /dev/null +++ b/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::runtime::neighbors::cagra_hnswlib { + +#define RAFT_INST_CAGRA_HNSWLIB_FUNCS(T) \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra_hnswlib::search_params const& params, \ + raft::neighbors::cagra_hnswlib::index const& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + raft::neighbors::cagra_hnswlib::search(handle, params, index, queries, neighbors, distances); \ + } + +RAFT_INST_CAGRA_HNSWLIB_FUNCS(float); +RAFT_INST_CAGRA_HNSWLIB_FUNCS(int8_t); +RAFT_INST_CAGRA_HNSWLIB_FUNCS(uint8_t); + +} // namespace raft::runtime::neighbors::cagra_hnswlib diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 17dd2d8bfd..bdff331153 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -19,7 +19,7 @@ # cython: embedsignature = True # cython: language_level = 3 -from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t from libcpp.string cimport string from pylibraft.common.cpp.mdspan cimport ( @@ -79,6 +79,9 @@ cdef host_matrix_view[int64_t, int64_t, row_major] get_hmv_int64( cdef host_matrix_view[uint32_t, int64_t, row_major] get_hmv_uint32( array, check_shape) except * +cdef host_matrix_view[uint64_t, int64_t, row_major] get_hmv_uint64( + array, check_shape) except * + cdef host_matrix_view[const_float, int64_t, row_major] get_const_hmv_float( array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 7442a6bb89..d2b63ce549 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -290,7 +290,7 @@ cdef host_matrix_view[int64_t, int64_t, row_major] \ cdef host_matrix_view[uint32_t, int64_t, row_major] \ get_hmv_uint32(cai, check_shape) except *: - if cai.dtype != np.int64: + if cai.dtype != np.uint32: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) @@ -299,6 +299,17 @@ cdef host_matrix_view[uint32_t, int64_t, row_major] \ cai.data, shape[0], shape[1]) +cdef host_matrix_view[uint64_t, int64_t, row_major] \ + get_hmv_uint64(cai, check_shape) except *: + if cai.dtype != np.uint64: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[uint64_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + cdef host_matrix_view[const_float, int64_t, row_major] \ get_const_hmv_float(cai, check_shape) except *: if cai.dtype != np.float32: diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx index 4faeb89ef5..db75524a9c 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx @@ -20,6 +20,7 @@ from cython.operator cimport dereference as deref from libc.stdint cimport int8_t, uint8_t, uint32_t +from libcpp cimport bool from libcpp.string cimport string cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra @@ -34,13 +35,126 @@ from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport device_resources -from pylibraft.common import DeviceResources +from pylibraft.common import DeviceResources, ai_wrapper, auto_convert_output + +cimport pylibraft.neighbors.cpp.cagra_hnswlib as c_cagra_hnswlib + +from pylibraft.neighbors.common import _check_input_array, _get_metric + +from pylibraft.common.mdspan cimport ( + get_hmv_float, + get_hmv_int8, + get_hmv_uint8, + get_hmv_uint64, +) +from pylibraft.neighbors.common cimport _get_metric_string + +import numpy as np + + +cdef class CagraHnswlibIndex: + cdef readonly bool trained + cdef str active_index_type + + def __cinit__(self): + self.trained = False + self.active_index_type = None + +cdef class CagraHnswlibIndexFloat(CagraHnswlibIndex): + cdef c_cagra_hnswlib.index[float] * index + + def __cinit__(self, filepath, dim, metric): + cdef string c_filepath = filepath.encode('utf-8') + + self.index = new c_cagra_hnswlib.index[float]( + c_filepath, + dim, + _get_metric(metric)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def metric(self): + return self.index[0].metric() + + def __dealloc__(self): + if self.index is not NULL: + del self.index + +cdef class CagraHnswlibIndexInt8(CagraHnswlibIndex): + cdef c_cagra_hnswlib.index[int8_t] * index + + def __cinit__(self, filepath, dim, metric): + cdef string c_filepath = filepath.encode('utf-8') + + self.index = new c_cagra_hnswlib.index[int8_t]( + c_filepath, + dim, + _get_metric(metric)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def metric(self): + return self.index[0].metric() + + def __dealloc__(self): + if self.index is not NULL: + del self.index + +cdef class CagraHnswlibIndexUint8(CagraHnswlibIndex): + cdef c_cagra_hnswlib.index[uint8_t] * index + + def __cinit__(self, filepath, dim, metric): + cdef string c_filepath = filepath.encode('utf-8') + + self.index = new c_cagra_hnswlib.index[uint8_t]( + c_filepath, + dim, + _get_metric(metric)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def metric(self): + return self.index[0].metric() + + def __dealloc__(self): + if self.index is not NULL: + del self.index @auto_sync_handle def save(filename, Index index, handle=None): """ - Saves the index to a file. + Saves the CAGRA index as an hnswlib base layer only index to a file. Saving / loading the index is experimental. The serialization format is subject to change. @@ -66,7 +180,7 @@ def save(filename, Index index, handle=None): >>> # Build index >>> handle = DeviceResources() >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) - >>> # Serialize and deserialize the cagra index built + >>> # Serialize the CAGRA index to hnswlib base layer only index format >>> cagra_hnswlib.save("my_index.bin", index, handle=handle) """ if not index.trained: @@ -108,3 +222,224 @@ def save(filename, Index index, handle=None): else: raise ValueError( "Index dtype %s not supported" % index.active_index_type) + + +def load(filename, dim, dtype, metric="sqeuclidean"): + """ + Loads base layer only hnswlib index from file, which was originally + saved as a built CAGRA index. + + Saving / loading the index is experimental. The serialization format is + subject to change, therefore loading an index saved with a previous + version of raft is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. + dim : int + Dimensions of the training dataest + dtype : np.dtype of the saved index + Valid values for dtype: ["float", "byte", "ubyte"] + metric : string denoting the metric type, default="sqeuclidean" + Valid values for metric: ["sqeuclidean", "inner_product"], where + - sqeuclidean is the euclidean distance without the square root + operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, + - inner product distance is defined as + distance(a, b) = \\sum_i a_i * b_i. + + Returns + ------- + index : CagraHnswlibIndex + + Examples + -------- + >>> from pylibraft.neighbors import cagra_hnswlib + >>> dim = 50 # Assuming training dataset has 50 dimensions + >>> index = cagra_hnswlib.load("my_index.bin", dim, "sqeuclidean") + """ + cdef string c_filename = filename.encode('utf-8') + cdef CagraHnswlibIndexFloat idx_float + cdef CagraHnswlibIndexInt8 idx_int8 + cdef CagraHnswlibIndexUint8 idx_uint8 + + if dtype == np.float32: + idx_float = CagraHnswlibIndexFloat(filename, dim, metric) + idx_float.trained = True + idx_float.active_index_type = 'float32' + return idx_float + elif dtype == np.byte: + idx_int8 = CagraHnswlibIndexInt8(filename, dim, metric) + idx_int8.trained = True + idx_int8.active_index_type = 'byte' + return idx_int8 + elif dtype == np.ubyte: + idx_uint8 = CagraHnswlibIndexUint8(filename, dim, metric) + idx_uint8.trained = True + idx_uint8.active_index_type = 'ubyte' + return idx_uint8 + else: + raise ValueError("Dataset dtype %s not supported" % dtype) + + +cdef class SearchParams: + """ + CAGRA search parameters + + Parameters + ---------- + ef: int + Size of list from which final neighbors k will be selected. + ef should be greater than or equal to k. + num_threads: int, default=1 + Number of host threads to use to search the hnswlib index + and increase concurrency + """ + cdef c_cagra_hnswlib.search_params params + + def __init__(self, ef, num_threads=1): + self.params.ef = ef + self.params.num_threads = num_threads + + def __repr__(self): + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in [ + "ef", "num_threads"]] + return "SearchParams(type=CAGRA_hnswlib, " + ( + ", ".join(attr_str)) + ")" + + @property + def ef(self): + return self.params.ef + + @property + def num_threads(self): + return self.params.num_threads + + +@auto_sync_handle +@auto_convert_output +def search(SearchParams search_params, + CagraHnswlibIndex index, + queries, + k, + neighbors=None, + distances=None, + handle=None): + """ + Find the k nearest neighbors for each query. + + Parameters + ---------- + search_params : SearchParams + index : Index + Trained CAGRA index. + queries : array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int8, uint8] + k : int + The number of neighbors. + neighbors : Optional array interface compliant matrix shape + (n_queries, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. (default None) + distances : Optional array interface compliant matrix shape + (n_queries, k) If supplied, the distances to the + neighbors will be written here in-place. (default None) + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> import numpy as np + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import cagra_hnswlib + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> + >>> Save CAGRA built index as base layer only hnswlib index + >>> cagra_hnswlib.save("my_index.bin", index) + >>> + >>> Load saved base layer only hnswlib index + >>> index_hnswlib.load("my_index.bin", n_features, dataset.dtype) + >>> + >>> # Search hnswlib using the loaded index + >>> queries = np.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 10 + >>> search_params = cagra_hnswlib.SearchParams( + ... ef=20, + ... num_threads=5 + ... ) + >>> distances, neighbors = cagra_hnswlib.search(search_params, index, + ... queries, k, handle=handle) + """ + + if not index.trained: + raise ValueError("Index need to be built before calling search.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + queries_ai = ai_wrapper(queries) + queries_dt = queries_ai.dtype + cdef uint32_t n_queries = queries_ai.shape[0] + + _check_input_array(queries_ai, [np.dtype('float32'), np.dtype('byte'), + np.dtype('ubyte')], + exp_cols=index.dim) + + if neighbors is None: + neighbors = np.empty((n_queries, k), dtype='uint64') + + neighbors_ai = ai_wrapper(neighbors) + _check_input_array(neighbors_ai, [np.dtype('uint64')], + exp_rows=n_queries, exp_cols=k) + + if distances is None: + distances = np.empty((n_queries, k), dtype='float32') + + distances_ai = ai_wrapper(distances) + _check_input_array(distances_ai, [np.dtype('float32')], + exp_rows=n_queries, exp_cols=k) + + cdef c_cagra_hnswlib.search_params params = search_params.params + cdef CagraHnswlibIndexFloat idx_float + cdef CagraHnswlibIndexInt8 idx_int8 + cdef CagraHnswlibIndexUint8 idx_uint8 + + if queries_dt == np.float32: + idx_float = index + c_cagra_hnswlib.search(deref(handle_), + params, + deref(idx_float.index), + get_hmv_float(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + elif queries_dt == np.byte: + idx_int8 = index + c_cagra_hnswlib.search(deref(handle_), + params, + deref(idx_int8.index), + get_hmv_int8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + elif queries_dt == np.ubyte: + idx_uint8 = index + c_cagra_hnswlib.search(deref(handle_), + params, + deref(idx_uint8.index), + get_hmv_uint8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + else: + raise ValueError("query dtype %s not supported" % queries_dt) + + return (distances, neighbors) diff --git a/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd b/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd new file mode 100644 index 0000000000..6a4154e151 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd @@ -0,0 +1,75 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t +from libcpp.string cimport string + +from pylibraft.common.cpp.mdspan cimport ( + device_vector_view, + host_matrix_view, + row_major, +) +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType +from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( + ann_index, + ann_search_params, +) + + +cdef extern from "raft/neighbors/cagra_hnswlib_types.hpp" \ + namespace "raft::neighbors::cagra_hnswlib" nogil: + + cpdef cppclass search_params(ann_search_params): + int ef + int num_threads + + cdef cppclass index[T](ann_index): + index(string filepath, int dim, DistanceType metric) + + int dim() + DistanceType metric() + + +cdef extern from "raft_runtime/neighbors/cagra_hnswlib.hpp" \ + namespace "raft::runtime::neighbors::cagra_hnswlib" nogil: + cdef void search( + const device_resources& handle, + const search_params& params, + const index[float]& index, + host_matrix_view[float, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[int8_t]& index, + host_matrix_view[int8_t, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[uint8_t]& index, + host_matrix_view[uint8_t, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + From 92e57177c5e1accc78500a49c6f4c3cdcb82df6e Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 28 Nov 2023 04:22:29 +0000 Subject: [PATCH 04/31] add static_assert guard for (u)int hnswlib serialize --- cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 8ed01b062d..a1775c2d4a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -103,6 +103,8 @@ void serialize_to_hnswlib(raft::resources const& res, std::ostream& os, const index& index_) { + static_assert(std::is_same_v or std::is_same_v, + "An hnswlib index can only be trained with int32 or uint32 IdxT"); common::nvtx::range fun_scope("cagra::serialize_to_hnswlib"); RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", static_cast(index_.size()), From 7f7885189156e84ee405cb556028831781009469 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 28 Nov 2023 04:27:21 +0000 Subject: [PATCH 05/31] passing float tests --- .../pylibraft/pylibraft/neighbors/__init__.py | 9 +- .../pylibraft/neighbors/cagra_hnswlib.pyx | 10 +-- python/pylibraft/pylibraft/test/__init__py | 0 python/pylibraft/pylibraft/test/ann_utils.py | 35 ++++++++ python/pylibraft/pylibraft/test/test_cagra.py | 22 +---- .../pylibraft/test/test_cagra_hnswlib.py | 85 +++++++++++++++++++ 6 files changed, 134 insertions(+), 27 deletions(-) create mode 100644 python/pylibraft/pylibraft/test/__init__py create mode 100644 python/pylibraft/pylibraft/test/ann_utils.py create mode 100644 python/pylibraft/pylibraft/test/test_cagra_hnswlib.py diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index b4cac76019..14f6f2e108 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. # -from pylibraft.neighbors import brute_force, cagra, ivf_flat, ivf_pq +from pylibraft.neighbors import ( + brute_force, + cagra, + cagra_hnswlib, + ivf_flat, + ivf_pq, +) from .refine import refine @@ -24,4 +30,5 @@ "ivf_flat", "ivf_pq", "cagra", + "cagra_hnswlib", ] diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx index db75524a9c..9979f11a59 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx @@ -284,11 +284,11 @@ def load(filename, dim, dtype, metric="sqeuclidean"): cdef class SearchParams: """ - CAGRA search parameters + Hnswlib search parameters Parameters ---------- - ef: int + ef: int, default=200 Size of list from which final neighbors k will be selected. ef should be greater than or equal to k. num_threads: int, default=1 @@ -297,7 +297,7 @@ cdef class SearchParams: """ cdef c_cagra_hnswlib.search_params params - def __init__(self, ef, num_threads=1): + def __init__(self, ef=200, num_threads=1): self.params.ef = ef self.params.num_threads = num_threads @@ -332,8 +332,8 @@ def search(SearchParams search_params, Parameters ---------- search_params : SearchParams - index : Index - Trained CAGRA index. + index : CagraHnswlibIndex + Trained CAGRA index saved as base layer only hnswlib index. queries : array interface compliant matrix shape (n_samples, dim) Supported dtype [float, int8, uint8] k : int diff --git a/python/pylibraft/pylibraft/test/__init__py b/python/pylibraft/pylibraft/test/__init__py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/test/ann_utils.py b/python/pylibraft/pylibraft/test/ann_utils.py new file mode 100644 index 0000000000..d9348fe100 --- /dev/null +++ b/python/pylibraft/pylibraft/test/ann_utils.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# h ttp://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +def generate_data(shape, dtype): + if dtype == np.byte: + x = np.random.randint(-127, 128, size=shape, dtype=np.byte) + elif dtype == np.ubyte: + x = np.random.randint(0, 255, size=shape, dtype=np.ubyte) + else: + x = np.random.random_sample(shape).astype(dtype) + + return x + + +def calc_recall(ann_idx, true_nn_idx): + assert ann_idx.shape == true_nn_idx.shape + n = 0 + for i in range(ann_idx.shape[0]): + n += np.intersect1d(ann_idx[i, :], true_nn_idx[i, :]).size + recall = n / ann_idx.size + return recall diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 24126c0c5a..65ad3d1fcf 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -20,27 +20,7 @@ from pylibraft.common import device_ndarray from pylibraft.neighbors import cagra - - -# todo (dantegd): consolidate helper utils of ann methods -def generate_data(shape, dtype): - if dtype == np.byte: - x = np.random.randint(-127, 128, size=shape, dtype=np.byte) - elif dtype == np.ubyte: - x = np.random.randint(0, 255, size=shape, dtype=np.ubyte) - else: - x = np.random.random_sample(shape).astype(dtype) - - return x - - -def calc_recall(ann_idx, true_nn_idx): - assert ann_idx.shape == true_nn_idx.shape - n = 0 - for i in range(ann_idx.shape[0]): - n += np.intersect1d(ann_idx[i, :], true_nn_idx[i, :]).size - recall = n / ann_idx.size - return recall +from pylibraft.test.ann_utils import calc_recall, generate_data def run_cagra_build_search_test( diff --git a/python/pylibraft/pylibraft/test/test_cagra_hnswlib.py b/python/pylibraft/pylibraft/test/test_cagra_hnswlib.py new file mode 100644 index 0000000000..c1cbbb8fb9 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_cagra_hnswlib.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# h ttp://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize + +from pylibraft.neighbors import cagra, cagra_hnswlib +from pylibraft.test.ann_utils import calc_recall, generate_data + + +def run_cagra_hnswlib_build_search_test( + n_rows=10000, + n_cols=10, + n_queries=100, + k=10, + dtype=np.float32, + metric="sqeuclidean", + intermediate_graph_degree=128, + graph_degree=64, + search_params={}, +): + dataset = generate_data((n_rows, n_cols), dtype) + if metric == "inner_product": + dataset = normalize(dataset, norm="l2", axis=1) + + build_params = cagra.IndexParams( + metric=metric, + intermediate_graph_degree=intermediate_graph_degree, + graph_degree=graph_degree, + ) + + index = cagra.build(build_params, dataset) + + assert index.trained + + filename = "my_index.bin" + cagra_hnswlib.save(filename, index) + + index_hnswlib = cagra_hnswlib.load( + filename, n_cols, dataset.dtype, metric=metric + ) + + queries = generate_data((n_queries, n_cols), dtype) + out_idx = np.zeros((n_queries, k), dtype=np.uint32) + + search_params = cagra_hnswlib.SearchParams(**search_params) + + out_dist, out_idx = cagra_hnswlib.search( + search_params, index_hnswlib, queries, k + ) + + # Calculate reference values with sklearn + nn_skl = NearestNeighbors(n_neighbors=k, algorithm="brute", metric=metric) + nn_skl.fit(dataset) + skl_idx = nn_skl.kneighbors(queries, return_distance=False) + + recall = calc_recall(out_idx, skl_idx) + print(recall) + assert recall > 0.95 + + +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("k", [10, 20]) +@pytest.mark.parametrize("ef", [30, 40]) +@pytest.mark.parametrize("num_threads", [2, 4]) +def test_cagra_hnswlib(dtype, k, ef, num_threads): + # Note that inner_product tests use normalized input which we cannot + # represent in int8, therefore we test only sqeuclidean metric here. + run_cagra_hnswlib_build_search_test( + dtype=dtype, k=k, search_params={"ef": ef, "num_threads": num_threads} + ) From 29f077420977a922f988569432a31f87385c39d1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 28 Nov 2023 04:29:03 +0000 Subject: [PATCH 06/31] fix docs --- cpp/include/raft/neighbors/cagra_hnswlib.hpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/cagra_hnswlib.hpp index cb02b10b4c..e21de7686b 100644 --- a/cpp/include/raft/neighbors/cagra_hnswlib.hpp +++ b/cpp/include/raft/neighbors/cagra_hnswlib.hpp @@ -28,18 +28,16 @@ namespace raft::neighbors::cagra_hnswlib { /** * @brief Search hnswlib base layer only index constructed from a CAGRA index * - * See the [cagra::build](#cagra::build) documentation for a usage example. - * * @tparam T data element type * @tparam IdxT type of the indices * * @param[in] res raft resources * @param[in] params configure the search * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * @param[in] queries a host matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a host matrix view to the indices of the neighbors in the source dataset * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * @param[out] distances a host matrix view to the distances to the selected neighbors [n_queries, * k] * * Usage example: From e2e1fc34c2f7a354bbde618da277c72e843399e1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 28 Nov 2023 15:52:44 +0000 Subject: [PATCH 07/31] try to write in native dtype --- .../detail/cagra/cagra_serialize.cuh | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index a1775c2d4a..e5d60acaa7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -103,8 +103,8 @@ void serialize_to_hnswlib(raft::resources const& res, std::ostream& os, const index& index_) { - static_assert(std::is_same_v or std::is_same_v, - "An hnswlib index can only be trained with int32 or uint32 IdxT"); + // static_assert(std::is_same_v or std::is_same_v, + // "An hnswlib index can only be trained with int32 or uint32 IdxT"); common::nvtx::range fun_scope("cagra::serialize_to_hnswlib"); RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", static_cast(index_.size()), @@ -122,14 +122,14 @@ void serialize_to_hnswlib(raft::resources const& res, // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + // dim * sizeof(data_t) + sizeof(labeltype) - auto size_data_per_element = - static_cast(index_.graph_degree() * 4 + 4 + index_.dim() * 4 + 8); + auto size_data_per_element = static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + + index_.dim() * sizeof(T) + 8); os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); // label_offset std::size_t label_offset = size_data_per_element - 8; os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); // offset_data - auto offset_data = static_cast(index_.graph_degree() * 4 + 4); + auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); // max_level int max_level = 1; @@ -186,17 +186,17 @@ void serialize_to_hnswlib(raft::resources const& res, } auto data_row = host_dataset.data_handle() + (index_.dim() * i); - if constexpr (std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = host_dataset(i, j); - os.write(reinterpret_cast(&data_elem), sizeof(T)); - } - } else if constexpr (std::is_same_v or std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = static_cast(host_dataset(i, j)); - os.write(reinterpret_cast(&data_elem), sizeof(int)); - } + // if constexpr (std::is_same_v) { + for (std::size_t j = 0; j < index_.dim(); ++j) { + auto data_elem = host_dataset(i, j); + os.write(reinterpret_cast(&data_elem), sizeof(T)); } + // } else if constexpr (std::is_same_v or std::is_same_v) { + // for (std::size_t j = 0; j < index_.dim(); ++j) { + // auto data_elem = static_cast(host_dataset(i, j)); + // os.write(reinterpret_cast(&data_elem), sizeof(int)); + // } + // } os.write(reinterpret_cast(&i), sizeof(std::size_t)); } From b8a4c503ffbdd615ceccfc70ab4815ea59c3692b Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 28 Nov 2023 19:45:14 +0000 Subject: [PATCH 08/31] update mypy, solve error --- .pre-commit-config.yaml | 2 +- pyproject.toml | 3 +++ python/pylibraft/pylibraft/neighbors/__init__.py | 10 +++------- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80ad3614bc..dc638c54e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: types_or: [python, cython] additional_dependencies: ["flake8-force"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v0.971' + rev: 'v1.3.0' hooks: - id: mypy additional_dependencies: [types-cachetools] diff --git a/pyproject.toml b/pyproject.toml index 2982db2a23..1e4ba0b369 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ ignore_missing_imports = true # If we don't specify this, then mypy will check excluded files if # they are imported by a checked file. follow_imports = "skip" +exclude = [ + "pylibraft/pylibraft/test", + ] [tool.codespell] # note: pre-commit passes explicit lists of files here, which this skip file list doesn't override - diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index 14f6f2e108..8ca4fb1769 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -13,13 +13,9 @@ # limitations under the License. # -from pylibraft.neighbors import ( - brute_force, - cagra, - cagra_hnswlib, - ivf_flat, - ivf_pq, -) +from pylibraft.neighbors import brute_force # type: ignore +from pylibraft.neighbors import cagra_hnswlib # type: ignore +from pylibraft.neighbors import cagra, ivf_flat, ivf_pq from .refine import refine From e02a0e3a95b01d68ec24d354bdd7da1321835cf1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 29 Nov 2023 20:27:51 +0000 Subject: [PATCH 09/31] attempt to use rapids_cpm_find for hnswlib --- cpp/CMakeLists.txt | 4 ++-- cpp/cmake/thirdparty/get_hnswlib.cmake | 30 ++++++++++++++++---------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 5b028a2456..2029b83e43 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -207,12 +207,12 @@ add_library(raft::raft ALIAS raft) target_include_directories( raft INTERFACE "$" "$" - "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" + # "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" ) if(NOT BUILD_CPU_ONLY) # Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target. - target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass raft::Thrust) + target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass raft::Thrust "$<$:hnswlib>") endif() target_compile_features(raft INTERFACE cxx_std_17 $) diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index a4ceacae38..d49bb4ed29 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -19,19 +19,27 @@ function(find_and_configure_hnswlib) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - set ( EXTERNAL_INCLUDES_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) - if( NOT EXISTS ${EXTERNAL_INCLUDES_DIRECTORY}/_deps/hnswlib-src ) + # set ( EXTERNAL_INCLUDES_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) + # if( NOT EXISTS ${EXTERNAL_INCLUDES_DIRECTORY}/_deps/hnswlib-src ) - execute_process ( - COMMAND git clone --branch=v0.6.2 https://github.com/nmslib/hnswlib.git hnswlib-src - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps ) + # execute_process ( + # COMMAND git clone --branch=v0.6.2 https://github.com/nmslib/hnswlib.git hnswlib-src + # WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps ) - message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}") - execute_process ( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src - ) - endif () + # message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}") + # execute_process ( + # COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch + # WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src + # ) + # endif () + rapids_cpm_find(hnswlib ${PKG_VERSION} + GLOBAL_TARGETS hnswlib + BUILD_EXPORT_SET raft-exports + INSTALL_EXPORT_SET raft-exports + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/hnswlib.git + GIT_TAG ${PKG_PINNED_TAG} + EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL}) include(cmake/modules/FindAVX.cmake) From 54106a137478ddbd4c4cdc1b95655df2d60e9d63 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 14 Dec 2023 11:45:24 +0000 Subject: [PATCH 10/31] rework to not expose hnswlib headers in runtime API --- .pre-commit-config.yaml | 12 +-- cpp/CMakeLists.txt | 5 +- cpp/cmake/thirdparty/get_hnswlib.cmake | 38 ++++---- .../raft/neighbors/cagra_hnswlib_types.hpp | 60 ++---------- .../raft/neighbors/detail/cagra_hnswlib.hpp | 7 +- cpp/include/raft/neighbors/hnswlib_types.hpp | 94 +++++++++++++++++++ .../raft_runtime/neighbors/cagra_hnswlib.hpp | 26 ++--- .../raft_runtime/neighbors/cagra_hnswlib.cpp | 49 ++++++++++ .../pylibraft/neighbors/cagra_hnswlib.pyx | 50 +++++----- .../pylibraft/neighbors/cpp/cagra_hnswlib.pxd | 20 +++- 10 files changed, 241 insertions(+), 120 deletions(-) create mode 100644 cpp/include/raft/neighbors/hnswlib_types.hpp create mode 100644 cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc638c54e1..740df4ed5b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,12 +81,12 @@ repos: verbose: true require_serial: true exclude: .*/thirdparty/.* - - id: copyright-check - name: copyright-check - entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year - language: python - pass_filenames: false - additional_dependencies: [gitpython] + # - id: copyright-check + # name: copyright-check + # entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year + # language: python + # pass_filenames: false + # additional_dependencies: [gitpython] - id: include-check name: include-check entry: python ./cpp/scripts/include_checker.py cpp/bench cpp/include cpp/test diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2029b83e43..55cfaca271 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -207,12 +207,12 @@ add_library(raft::raft ALIAS raft) target_include_directories( raft INTERFACE "$" "$" - # "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" + "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" ) if(NOT BUILD_CPU_ONLY) # Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target. - target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass raft::Thrust "$<$:hnswlib>") + target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass raft::Thrust) endif() target_compile_features(raft INTERFACE cxx_std_17 $) @@ -424,6 +424,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/matrix/select_k_float_int64_t.cu src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu src/raft_runtime/neighbors/cagra_build.cu + src/raft_runtime/neighbors/cagra_hnswlib.cpp src/raft_runtime/neighbors/cagra_search.cu src/raft_runtime/neighbors/cagra_serialize.cu src/raft_runtime/neighbors/ivf_flat_build.cu diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index d49bb4ed29..2820bf3b69 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -19,27 +19,27 @@ function(find_and_configure_hnswlib) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - # set ( EXTERNAL_INCLUDES_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) - # if( NOT EXISTS ${EXTERNAL_INCLUDES_DIRECTORY}/_deps/hnswlib-src ) + set ( EXTERNAL_INCLUDES_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) + if( NOT EXISTS ${EXTERNAL_INCLUDES_DIRECTORY}/_deps/hnswlib-src ) - # execute_process ( - # COMMAND git clone --branch=v0.6.2 https://github.com/nmslib/hnswlib.git hnswlib-src - # WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps ) + execute_process ( + COMMAND git clone --branch=v0.6.2 https://github.com/nmslib/hnswlib.git hnswlib-src + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps ) - # message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}") - # execute_process ( - # COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch - # WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src - # ) - # endif () - rapids_cpm_find(hnswlib ${PKG_VERSION} - GLOBAL_TARGETS hnswlib - BUILD_EXPORT_SET raft-exports - INSTALL_EXPORT_SET raft-exports - CPM_ARGS - GIT_REPOSITORY https://github.com/${PKG_FORK}/hnswlib.git - GIT_TAG ${PKG_PINNED_TAG} - EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL}) + message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}") + execute_process ( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src + ) + endif () + # rapids_cpm_find(hnswlib ${PKG_VERSION} + # GLOBAL_TARGETS hnswlib + # BUILD_EXPORT_SET raft-exports + # INSTALL_EXPORT_SET raft-exports + # CPM_ARGS + # GIT_REPOSITORY https://github.com/${PKG_FORK}/hnswlib.git + # GIT_TAG ${PKG_PINNED_TAG} + # EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL}) include(cmake/modules/FindAVX.cmake) diff --git a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp index bf1725ea86..cd28b1ac42 100644 --- a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp +++ b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp @@ -21,32 +21,11 @@ #include #include -#include #include #include namespace raft::neighbors::cagra_hnswlib { -template -struct hnsw_dist_t { - using type = void; -}; - -template <> -struct hnsw_dist_t { - using type = float; -}; - -template <> -struct hnsw_dist_t { - using type = int; -}; - -template <> -struct hnsw_dist_t { - using type = int; -}; - struct search_params : ann::search_params { int ef; // size of the candidate list int num_threads = 1; // number of host threads to use for concurrent searches @@ -56,42 +35,20 @@ template struct index : ann::index { public: /** - * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index + * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index. + * This is a virtual class and it cannot be used directly. To create an index, construct + * an instance of `raft::neighbors::cagra_hnswlib::hnswlib_index` from the header + * `raft/neighbores/hnswlib_types.hpp` * - * @param[in] filepath path to the index * @param[in] dim dimensions of the training dataset * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") */ - index(std::string filepath, int dim, raft::distance::DistanceType metric) - : dim_{dim}, metric_{metric} - { - if constexpr (std::is_same_v) { - if (metric == raft::distance::L2Expanded) { - space_ = std::make_unique(dim_); - } else if (metric == raft::distance::InnerProduct) { - space_ = std::make_unique(dim_); - } - } else if constexpr (std::is_same_v or std::is_same_v) { - if (metric == raft::distance::L2Expanded) { - space_ = std::make_unique(dim_); - } - } - - RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used"); - - appr_alg_ = std::make_unique::type>>( - space_.get(), filepath); - - appr_alg_->base_layer_only = true; - } + index(int dim, raft::distance::DistanceType metric) : dim_{dim}, metric_{metric} {} /** - @brief Get hnswlib index + @brief Get underlying index */ - auto get_index() const -> hnswlib::HierarchicalNSW::type> const* - { - return appr_alg_.get(); - } + virtual auto get_index() const -> void const* = 0; auto dim() const -> int const { return dim_; } @@ -100,9 +57,6 @@ struct index : ann::index { private: int dim_; raft::distance::DistanceType metric_; - - std::unique_ptr::type>> appr_alg_; - std::unique_ptr::type>> space_; }; } // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp index ba826db0c5..7787bc7e58 100644 --- a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp +++ b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp @@ -16,7 +16,7 @@ #pragma once -#include "../cagra_hnswlib_types.hpp" +#include "../hnswlib_types.hpp" #include #include @@ -25,7 +25,6 @@ #include #include #include -#include #include #include #include @@ -167,7 +166,9 @@ void search(raft::resources const& res, raft::host_matrix_view neighbors, raft::host_matrix_view distances) { - auto const* hnswlib_index = idx.get_index(); + auto const* hnswlib_index = + reinterpret_cast::type> const*>( + idx.get_index()); // no-op when num_threads == 1, no synchronization overhead FixedThreadPool thread_pool{params.num_threads}; diff --git a/cpp/include/raft/neighbors/hnswlib_types.hpp b/cpp/include/raft/neighbors/hnswlib_types.hpp new file mode 100644 index 0000000000..2aa039a6c3 --- /dev/null +++ b/cpp/include/raft/neighbors/hnswlib_types.hpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cagra_hnswlib_types.hpp" +#include +#include +#include + +#include +#include +#include +#include + +namespace raft::neighbors::cagra_hnswlib { + +template +struct hnsw_dist_t { + using type = void; +}; + +template <> +struct hnsw_dist_t { + using type = float; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +template +struct hnswlib_index : index { + public: + /** + * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index + * + * @param[in] filepath path to the index + * @param[in] dim dimensions of the training dataset + * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") + */ + hnswlib_index(std::string filepath, int dim, raft::distance::DistanceType metric) + : index{dim, metric} + { + if constexpr (std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique(dim); + } else if (metric == raft::distance::InnerProduct) { + space_ = std::make_unique(dim); + } + } else if constexpr (std::is_same_v or std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique(dim); + } + } + + RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used"); + + appr_alg_ = std::make_unique::type>>( + space_.get(), filepath); + + appr_alg_->base_layer_only = true; + } + + /** + @brief Get hnswlib index + */ + auto get_index() const -> void const* override { return appr_alg_.get(); } + + private: + std::unique_ptr::type>> appr_alg_; + std::unique_ptr::type>> space_; +}; + +} // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp b/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp index 8a9edd03ce..66edaa2981 100644 --- a/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp @@ -16,21 +16,25 @@ #pragma once -#include +#include +#include #include namespace raft::runtime::neighbors::cagra_hnswlib { -#define RAFT_INST_CAGRA_HNSWLIB_FUNCS(T) \ - void search(raft::resources const& handle, \ - raft::neighbors::cagra_hnswlib::search_params const& params, \ - raft::neighbors::cagra_hnswlib::index const& index, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - raft::neighbors::cagra_hnswlib::search(handle, params, index, queries, neighbors, distances); \ - } +#define RAFT_INST_CAGRA_HNSWLIB_FUNCS(T) \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra_hnswlib::search_params const& params, \ + raft::neighbors::cagra_hnswlib::index const& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra_hnswlib::index*& index, \ + int dim, \ + raft::distance::DistanceType metric); RAFT_INST_CAGRA_HNSWLIB_FUNCS(float); RAFT_INST_CAGRA_HNSWLIB_FUNCS(int8_t); diff --git a/cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp b/cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp new file mode 100644 index 0000000000..663429bddb --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors::cagra_hnswlib { + +#define RAFT_INST_CAGRA_HNSWLIB(T) \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra_hnswlib::search_params const& params, \ + const raft::neighbors::cagra_hnswlib::index& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + raft::neighbors::cagra_hnswlib::search( \ + handle, params, index, queries, neighbors, distances); \ + } \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra_hnswlib::index*& index, \ + int dim, \ + raft::distance::DistanceType metric) \ + { \ + index = new raft::neighbors::cagra_hnswlib::hnswlib_index(filename, dim, metric); \ + RAFT_EXPECTS(index, "Could not set index pointer"); \ + } + +RAFT_INST_CAGRA_HNSWLIB(float); +RAFT_INST_CAGRA_HNSWLIB(int8_t); +RAFT_INST_CAGRA_HNSWLIB(uint8_t); + +#undef RAFT_INST_CAGRA_HNSWLIB + +} // namespace raft::runtime::neighbors::cagra_hnswlib diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx index 9979f11a59..15f92f9ea1 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx @@ -24,6 +24,7 @@ from libcpp cimport bool from libcpp.string cimport string cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra +from pylibraft.distance.distance_type cimport DistanceType from pylibraft.neighbors.cagra.cagra cimport ( Index, IndexFloat, @@ -63,13 +64,8 @@ cdef class CagraHnswlibIndex: cdef class CagraHnswlibIndexFloat(CagraHnswlibIndex): cdef c_cagra_hnswlib.index[float] * index - def __cinit__(self, filepath, dim, metric): - cdef string c_filepath = filepath.encode('utf-8') - - self.index = new c_cagra_hnswlib.index[float]( - c_filepath, - dim, - _get_metric(metric)) + def __cinit__(self): + pass def __repr__(self): m_str = "metric=" + _get_metric_string(self.index.metric()) @@ -93,13 +89,8 @@ cdef class CagraHnswlibIndexFloat(CagraHnswlibIndex): cdef class CagraHnswlibIndexInt8(CagraHnswlibIndex): cdef c_cagra_hnswlib.index[int8_t] * index - def __cinit__(self, filepath, dim, metric): - cdef string c_filepath = filepath.encode('utf-8') - - self.index = new c_cagra_hnswlib.index[int8_t]( - c_filepath, - dim, - _get_metric(metric)) + def __cinit__(self): + pass def __repr__(self): m_str = "metric=" + _get_metric_string(self.index.metric()) @@ -123,13 +114,8 @@ cdef class CagraHnswlibIndexInt8(CagraHnswlibIndex): cdef class CagraHnswlibIndexUint8(CagraHnswlibIndex): cdef c_cagra_hnswlib.index[uint8_t] * index - def __cinit__(self, filepath, dim, metric): - cdef string c_filepath = filepath.encode('utf-8') - - self.index = new c_cagra_hnswlib.index[uint8_t]( - c_filepath, - dim, - _get_metric(metric)) + def __cinit__(self): + pass def __repr__(self): m_str = "metric=" + _get_metric_string(self.index.metric()) @@ -224,7 +210,7 @@ def save(filename, Index index, handle=None): "Index dtype %s not supported" % index.active_index_type) -def load(filename, dim, dtype, metric="sqeuclidean"): +def load(filename, dim, dtype, metric="sqeuclidean", handle=None): """ Loads base layer only hnswlib index from file, which was originally saved as a built CAGRA index. @@ -247,6 +233,7 @@ def load(filename, dim, dtype, metric="sqeuclidean"): operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, - inner product distance is defined as distance(a, b) = \\sum_i a_i * b_i. + {handle_docstring} Returns ------- @@ -258,23 +245,36 @@ def load(filename, dim, dtype, metric="sqeuclidean"): >>> dim = 50 # Assuming training dataset has 50 dimensions >>> index = cagra_hnswlib.load("my_index.bin", dim, "sqeuclidean") """ + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + cdef string c_filename = filename.encode('utf-8') cdef CagraHnswlibIndexFloat idx_float cdef CagraHnswlibIndexInt8 idx_int8 cdef CagraHnswlibIndexUint8 idx_uint8 + cdef DistanceType c_metric = _get_metric(metric) + if dtype == np.float32: - idx_float = CagraHnswlibIndexFloat(filename, dim, metric) + idx_float = CagraHnswlibIndexFloat() + c_cagra_hnswlib.deserialize_file( + deref(handle_), c_filename, idx_float.index, dim, c_metric) idx_float.trained = True idx_float.active_index_type = 'float32' return idx_float elif dtype == np.byte: - idx_int8 = CagraHnswlibIndexInt8(filename, dim, metric) + idx_int8 = CagraHnswlibIndexInt8(dim, metric) + c_cagra_hnswlib.deserialize_file( + deref(handle_), c_filename, idx_int8.index, dim, c_metric) idx_int8.trained = True idx_int8.active_index_type = 'byte' return idx_int8 elif dtype == np.ubyte: - idx_uint8 = CagraHnswlibIndexUint8(filename, dim, metric) + idx_uint8 = CagraHnswlibIndexUint8(dim, metric) + c_cagra_hnswlib.deserialize_file( + deref(handle_), c_filename, idx_uint8.index, dim, c_metric) idx_uint8.trained = True idx_uint8.active_index_type = 'ubyte' return idx_uint8 diff --git a/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd b/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd index 6a4154e151..942f1859e3 100644 --- a/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd +++ b/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd @@ -42,7 +42,7 @@ cdef extern from "raft/neighbors/cagra_hnswlib_types.hpp" \ int num_threads cdef cppclass index[T](ann_index): - index(string filepath, int dim, DistanceType metric) + index(int dim, DistanceType metric) int dim() DistanceType metric() @@ -73,3 +73,21 @@ cdef extern from "raft_runtime/neighbors/cagra_hnswlib.hpp" \ host_matrix_view[uint8_t, int64_t, row_major] queries, host_matrix_view[uint64_t, int64_t, row_major] neighbors, host_matrix_view[float, int64_t, row_major] distances) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[float]*& index, + int dim, + DistanceType metric) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[int8_t]*& index, + int dim, + DistanceType metric) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[uint8_t]*& index, + int dim, + DistanceType metric) except + From ebdce10274028070b999e9fd0f250af84fc5a0ed Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 14 Dec 2023 11:46:00 +0000 Subject: [PATCH 11/31] readd copyright check --- .pre-commit-config.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 740df4ed5b..dc638c54e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,12 +81,12 @@ repos: verbose: true require_serial: true exclude: .*/thirdparty/.* - # - id: copyright-check - # name: copyright-check - # entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year - # language: python - # pass_filenames: false - # additional_dependencies: [gitpython] + - id: copyright-check + name: copyright-check + entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year + language: python + pass_filenames: false + additional_dependencies: [gitpython] - id: include-check name: include-check entry: python ./cpp/scripts/include_checker.py cpp/bench cpp/include cpp/test From 1da21e45c9157182e6b6533e26237bc3f58328c5 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 14 Dec 2023 12:47:16 +0000 Subject: [PATCH 12/31] print binary dir --- cpp/cmake/thirdparty/get_hnswlib.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index 2820bf3b69..96446f75b3 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -27,6 +27,7 @@ function(find_and_configure_hnswlib) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps ) message("SOURCE ${CMAKE_CURRENT_SOURCE_DIR}") + message("WORKING DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}") execute_process ( COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/patches/hnswlib.patch WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src From e5cd5f6c90d9a167a781a107157bda5fb5e6d317 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 14 Dec 2023 13:07:45 +0000 Subject: [PATCH 13/31] address review --- cpp/include/raft/neighbors/cagra_hnswlib.hpp | 7 + .../raft/neighbors/cagra_hnswlib_types.hpp | 10 +- .../raft/neighbors/detail/cagra_hnswlib.hpp | 149 +++--------------- cpp/include/raft/neighbors/hnswlib_types.hpp | 7 + docs/source/cpp_api/neighbors_cagra.rst | 11 ++ docs/source/pylibraft_api/neighbors.rst | 14 ++ 6 files changed, 70 insertions(+), 128 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/cagra_hnswlib.hpp index e21de7686b..c3255e990a 100644 --- a/cpp/include/raft/neighbors/cagra_hnswlib.hpp +++ b/cpp/include/raft/neighbors/cagra_hnswlib.hpp @@ -25,6 +25,11 @@ namespace raft::neighbors::cagra_hnswlib { +/** + * @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib + * @{ + */ + /** * @brief Search hnswlib base layer only index constructed from a CAGRA index * @@ -85,4 +90,6 @@ void search(raft::resources const& res, detail::search(res, params, idx, queries, neighbors, distances); } +/**@}*/ + } // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp index cd28b1ac42..e2e23fb391 100644 --- a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp +++ b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp @@ -26,9 +26,15 @@ namespace raft::neighbors::cagra_hnswlib { +/** + * @defgroup cagra_hnswlib Build CAGRA index and search with hnswlib + * @{ + */ + struct search_params : ann::search_params { int ef; // size of the candidate list - int num_threads = 1; // number of host threads to use for concurrent searches + int num_threads = 0; // number of host threads to use for concurrent searches. Value of 0 + // automatically maximizes parallelism }; template @@ -59,4 +65,6 @@ struct index : ann::index { raft::distance::DistanceType metric_; }; +/**@}*/ + } // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp index 7787bc7e58..6f5a255ac3 100644 --- a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp +++ b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp @@ -18,129 +18,16 @@ #include "../hnswlib_types.hpp" +#include #include #include -#include -#include -#include -#include -#include -#include -#include +#include #include namespace raft::neighbors::cagra_hnswlib::detail { -class FixedThreadPool { - public: - FixedThreadPool(int num_threads) - { - if (num_threads < 1) { - throw std::runtime_error("num_threads must >= 1"); - } else if (num_threads == 1) { - return; - } - - tasks_ = new Task_[num_threads]; - - threads_.reserve(num_threads); - for (int i = 0; i < num_threads; ++i) { - threads_.emplace_back([&, i] { - auto& task = tasks_[i]; - while (true) { - std::unique_lock lock(task.mtx); - task.cv.wait(lock, - [&] { return task.has_task || finished_.load(std::memory_order_relaxed); }); - if (finished_.load(std::memory_order_relaxed)) { break; } - - task.task(); - task.has_task = false; - } - }); - } - } - - ~FixedThreadPool() - { - if (threads_.empty()) { return; } - - finished_.store(true, std::memory_order_relaxed); - for (unsigned i = 0; i < threads_.size(); ++i) { - auto& task = tasks_[i]; - std::lock_guard(task.mtx); - - task.cv.notify_one(); - threads_[i].join(); - } - - delete[] tasks_; - } - - template - void submit(Func f, IdxT len) - { - // Run functions in main thread if thread pool has no threads - if (threads_.empty()) { - for (IdxT i = 0; i < len; ++i) { - f(i); - } - return; - } - - const int num_threads = threads_.size(); - // one extra part for competition among threads - const IdxT items_per_thread = len / (num_threads + 1); - std::atomic cnt(items_per_thread * num_threads); - - // Wrap function - auto wrapped_f = [&](IdxT start, IdxT end) { - for (IdxT i = start; i < end; ++i) { - f(i); - } - - while (true) { - IdxT i = cnt.fetch_add(1, std::memory_order_relaxed); - if (i >= len) { break; } - f(i); - } - }; - - std::vector> futures; - futures.reserve(num_threads); - for (int i = 0; i < num_threads; ++i) { - IdxT start = i * items_per_thread; - auto& task = tasks_[i]; - { - std::lock_guard lock(task.mtx); - (void)lock; // stop nvcc warning - task.task = std::packaged_task([=] { wrapped_f(start, start + items_per_thread); }); - futures.push_back(task.task.get_future()); - task.has_task = true; - } - task.cv.notify_one(); - } - - for (auto& fut : futures) { - fut.wait(); - } - return; - } - - private: - struct alignas(64) Task_ { - std::mutex mtx; - std::condition_variable cv; - bool has_task = false; - std::packaged_task task; - }; - - Task_* tasks_; - std::vector threads_; - std::atomic finished_{false}; -}; - template void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, const T* query, @@ -170,18 +57,26 @@ void search(raft::resources const& res, reinterpret_cast::type> const*>( idx.get_index()); - // no-op when num_threads == 1, no synchronization overhead - FixedThreadPool thread_pool{params.num_threads}; - - auto f = [&](auto const& i) { - get_search_knn_results(hnswlib_index, - queries.data_handle() + i * queries.extent(1), - neighbors.extent(1), - neighbors.data_handle() + i * neighbors.extent(1), - distances.data_handle() + i * distances.extent(1)); - }; - - thread_pool.submit(f, queries.extent(0)); + // when num_threads == 0, automatically maximize parallelism + if (params.num_threads) { +#pragma omp parallel for num_threads(params.num_threads) + for (int64_t i = 0; i < queries.extent(0); ++i) { + get_search_knn_results(hnswlib_index, + queries.data_handle() + i * queries.extent(1), + neighbors.extent(1), + neighbors.data_handle() + i * neighbors.extent(1), + distances.data_handle() + i * distances.extent(1)); + } + } else { +#pragma omp parallel for + for (int64_t i = 0; i < queries.extent(0); ++i) { + get_search_knn_results(hnswlib_index, + queries.data_handle() + i * queries.extent(1), + neighbors.extent(1), + neighbors.data_handle() + i * neighbors.extent(1), + distances.data_handle() + i * distances.extent(1)); + } + } } } // namespace raft::neighbors::cagra_hnswlib::detail diff --git a/cpp/include/raft/neighbors/hnswlib_types.hpp b/cpp/include/raft/neighbors/hnswlib_types.hpp index 2aa039a6c3..6e47ca53c7 100644 --- a/cpp/include/raft/neighbors/hnswlib_types.hpp +++ b/cpp/include/raft/neighbors/hnswlib_types.hpp @@ -28,6 +28,11 @@ namespace raft::neighbors::cagra_hnswlib { +/** + * @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib + * @{ + */ + template struct hnsw_dist_t { using type = void; @@ -91,4 +96,6 @@ struct hnswlib_index : index { std::unique_ptr::type>> space_; }; +/**@}*/ + } // namespace raft::neighbors::cagra_hnswlib diff --git a/docs/source/cpp_api/neighbors_cagra.rst b/docs/source/cpp_api/neighbors_cagra.rst index 99ecd3a985..f09ad23798 100644 --- a/docs/source/cpp_api/neighbors_cagra.rst +++ b/docs/source/cpp_api/neighbors_cagra.rst @@ -29,3 +29,14 @@ namespace *raft::neighbors::cagra* :project: RAFT :members: :content-only: + +CAGRA index build and hnswlib search +------------------------------------ +``#include `` + +namespace *raft::neighbors::cagra_hnswlib* + +.. doxygengroup:: cagra_hnswlib + :project: RAFT + :members: + :content-only: diff --git a/docs/source/pylibraft_api/neighbors.rst b/docs/source/pylibraft_api/neighbors.rst index 680a2982cb..5da5d760df 100644 --- a/docs/source/pylibraft_api/neighbors.rst +++ b/docs/source/pylibraft_api/neighbors.rst @@ -33,6 +33,20 @@ Serializer Methods .. autofunction:: pylibraft.neighbors.cagra.load +CAGRA hnswlib +############# + +.. autoclass:: pylibraft.neighbors.cagra_hnswlib.SearchParams + :members: + +.. autofunction:: pylibraft.neighbors.cagra_hnswlib.search + +Serializer Methods +------------------ +.. autofunction:: pylibraft.neighbors.cagra_hnswlib.save + +.. autofunction:: pylibraft.neighbors.cagra_hnswlib.load + IVF-Flat ######## From 460bc9800591dc5d50889ed24641355e45b3585a Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 15 Dec 2023 08:45:11 +0000 Subject: [PATCH 14/31] address review, enable int8 in hnswlib --- cpp/CMakeLists.txt | 2 +- cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h | 7 +- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 3 +- cpp/cmake/patches/hnswlib.patch | 57 ++++++ .../raft/neighbors/cagra_serialize.cuh | 64 ------- .../detail/cagra/cagra_serialize.cuh | 124 ------------- .../detail/{cagra_hnswlib.hpp => hnsw.hpp} | 6 +- .../raft/neighbors/detail/hnsw_serialize.cuh | 172 ++++++++++++++++++ .../hnsw_types.hpp} | 12 +- .../neighbors/{cagra_hnswlib.hpp => hnsw.hpp} | 21 ++- cpp/include/raft/neighbors/hnsw_serialize.cuh | 140 ++++++++++++++ ...cagra_hnswlib_types.hpp => hnsw_types.hpp} | 6 +- cpp/include/raft_runtime/neighbors/cagra.hpp | 92 +++++----- .../raft_runtime/neighbors/cagra_hnswlib.hpp | 43 ----- cpp/include/raft_runtime/neighbors/hnsw.hpp | 49 +++++ .../raft_runtime/neighbors/cagra_hnswlib.cpp | 49 ----- .../raft_runtime/neighbors/cagra_serialize.cu | 82 ++++----- cpp/src/raft_runtime/neighbors/hnsw.cu | 64 +++++++ .../pylibraft/neighbors/CMakeLists.txt | 2 +- .../pylibraft/pylibraft/neighbors/__init__.py | 4 +- .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 29 --- .../cpp/{cagra_hnswlib.pxd => hnsw.pxd} | 41 ++++- .../neighbors/{cagra_hnswlib.pyx => hnsw.pyx} | 120 ++++++------ .../{test_cagra_hnswlib.py => test_hnsw.py} | 20 +- 24 files changed, 697 insertions(+), 512 deletions(-) rename cpp/include/raft/neighbors/detail/{cagra_hnswlib.hpp => hnsw.hpp} (95%) create mode 100644 cpp/include/raft/neighbors/detail/hnsw_serialize.cuh rename cpp/include/raft/neighbors/{hnswlib_types.hpp => detail/hnsw_types.hpp} (89%) rename cpp/include/raft/neighbors/{cagra_hnswlib.hpp => hnsw.hpp} (84%) create mode 100644 cpp/include/raft/neighbors/hnsw_serialize.cuh rename cpp/include/raft/neighbors/{cagra_hnswlib_types.hpp => hnsw_types.hpp} (92%) delete mode 100644 cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp create mode 100644 cpp/include/raft_runtime/neighbors/hnsw.hpp delete mode 100644 cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp create mode 100644 cpp/src/raft_runtime/neighbors/hnsw.cu rename python/pylibraft/pylibraft/neighbors/cpp/{cagra_hnswlib.pxd => hnsw.pxd} (70%) rename python/pylibraft/pylibraft/neighbors/{cagra_hnswlib.pyx => hnsw.pyx} (78%) rename python/pylibraft/pylibraft/test/{test_cagra_hnswlib.py => test_hnsw.py} (81%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b26d6c2a5b..1788c5da41 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -414,9 +414,9 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/matrix/select_k_float_int64_t.cu src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu src/raft_runtime/neighbors/cagra_build.cu - src/raft_runtime/neighbors/cagra_hnswlib.cpp src/raft_runtime/neighbors/cagra_search.cu src/raft_runtime/neighbors/cagra_serialize.cu + src/raft_runtime/neighbors/hnsw.cu src/raft_runtime/neighbors/ivf_flat_build.cu src/raft_runtime/neighbors/ivf_flat_search.cu src/raft_runtime/neighbors/ivf_flat_serialize.cu diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index 2a5177d295..86b1ca2fd0 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -52,6 +52,11 @@ struct hnsw_dist_t { using type = int; }; +template <> +struct hnsw_dist_t { + using type = int; +}; + template class HnswLib : public ANN { public: @@ -135,7 +140,7 @@ void HnswLib::build(const T* dataset, size_t nrow, cudaStream_t) space_ = std::make_shared(dim_); } } else if constexpr (std::is_same_v) { - space_ = std::make_shared(dim_); + space_ = std::make_shared>(dim_); } appr_alg_ = std::make_shared::type>>( diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index ec71de9cff..6fe48f539f 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -246,7 +247,7 @@ void RaftCagra::save(const std::string& file) const template void RaftCagra::save_to_hnswlib(const std::string& file) const { - raft::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_); + raft::neighbors::hnsw::serialize(handle_, file, *index_); } template diff --git a/cpp/cmake/patches/hnswlib.patch b/cpp/cmake/patches/hnswlib.patch index 32c1537c58..51dda73c15 100644 --- a/cpp/cmake/patches/hnswlib.patch +++ b/cpp/cmake/patches/hnswlib.patch @@ -107,6 +107,63 @@ index e95e0b5..f0fe50a 100644 } } } +diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h +index 4413537..c3240f3 100644 +--- a/hnswlib/space_l2.h ++++ b/hnswlib/space_l2.h +@@ -252,13 +252,14 @@ namespace hnswlib { + ~L2Space() {} + }; + ++ template + static int + L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + + size_t qty = *((size_t *) qty_ptr); + int res = 0; +- unsigned char *a = (unsigned char *) pVect1; +- unsigned char *b = (unsigned char *) pVect2; ++ T *a = (T *) pVect1; ++ T *b = (T *) pVect2; + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { +@@ -279,11 +280,12 @@ namespace hnswlib { + return (res); + } + ++ template + static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + int res = 0; +- unsigned char* a = (unsigned char*)pVect1; +- unsigned char* b = (unsigned char*)pVect2; ++ T* a = (T*)pVect1; ++ T* b = (T*)pVect2; + + for(size_t i = 0; i < qty; i++) + { +@@ -294,6 +296,7 @@ namespace hnswlib { + return (res); + } + ++ template + class L2SpaceI : public SpaceInterface { + + DISTFUNC fstdistfunc_; +@@ -302,10 +305,10 @@ namespace hnswlib { + public: + L2SpaceI(size_t dim) { + if(dim % 4 == 0) { +- fstdistfunc_ = L2SqrI4x; ++ fstdistfunc_ = L2SqrI4x; + } + else { +- fstdistfunc_ = L2SqrI; ++ fstdistfunc_ = L2SqrI; + } + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h index 5e1a4a5..4195ebd 100644 --- a/hnswlib/visited_list_pool.h diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index c801bc9eda..0a806402d2 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -93,70 +93,6 @@ void serialize(raft::resources const& handle, detail::serialize(handle, filename, index, include_dataset); } -/** - * Write the CAGRA built index as a base layer HNSW index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize_to_hnswlib(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index CAGRA index - * - */ -template -void serialize_to_hnswlib(raft::resources const& handle, - std::ostream& os, - const index& index) -{ - detail::serialize_to_hnswlib(handle, os, index); -} - -/** - * Write the CAGRA built index as a base layer HNSW index to file - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize_to_hnswlib(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index CAGRA index - * - */ -template -void serialize_to_hnswlib(raft::resources const& handle, - const std::string& filename, - const index& index) -{ - detail::serialize_to_hnswlib(handle, filename, index); -} - /** * Load index from input stream * diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index e5d60acaa7..0d01b17a26 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -98,130 +98,6 @@ void serialize(raft::resources const& res, if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } -template -void serialize_to_hnswlib(raft::resources const& res, - std::ostream& os, - const index& index_) -{ - // static_assert(std::is_same_v or std::is_same_v, - // "An hnswlib index can only be trained with int32 or uint32 IdxT"); - common::nvtx::range fun_scope("cagra::serialize_to_hnswlib"); - RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", - static_cast(index_.size()), - index_.dim()); - - // offset_level_0 - std::size_t offset_level_0 = 0; - os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); - // max_element - std::size_t max_element = index_.size(); - os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); - // curr_element_count - std::size_t curr_element_count = index_.size(); - os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); - // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, - // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + - // dim * sizeof(data_t) + sizeof(labeltype) - auto size_data_per_element = static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + - index_.dim() * sizeof(T) + 8); - os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); - // label_offset - std::size_t label_offset = size_data_per_element - 8; - os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); - // offset_data - auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); - os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); - // max_level - int max_level = 1; - os.write(reinterpret_cast(&max_level), sizeof(int)); - // entrypoint_node - auto entrypoint_node = static_cast(index_.size() / 2); - os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); - // max_M - auto max_M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); - // max_M0 - std::size_t max_M0 = index_.graph_degree(); - os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); - // M - auto M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&M), sizeof(std::size_t)); - // mult, can be anything - double mult = 0.42424242; - os.write(reinterpret_cast(&mult), sizeof(double)); - // efConstruction, can be anything - std::size_t efConstruction = 500; - os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); - - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - - auto graph = index_.graph(); - auto host_graph = - raft::make_host_matrix(graph.extent(0), graph.extent(1)); - raft::copy(host_graph.data_handle(), - graph.data_handle(), - graph.size(), - raft::resource::get_cuda_stream(res)); - resource::sync_stream(res); - - // Write one dataset and graph row at a time - for (std::size_t i = 0; i < index_.size(); i++) { - auto graph_degree = static_cast(index_.graph_degree()); - os.write(reinterpret_cast(&graph_degree), sizeof(int)); - - for (std::size_t j = 0; j < index_.graph_degree(); ++j) { - auto graph_elem = host_graph(i, j); - os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); - } - - auto data_row = host_dataset.data_handle() + (index_.dim() * i); - // if constexpr (std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = host_dataset(i, j); - os.write(reinterpret_cast(&data_elem), sizeof(T)); - } - // } else if constexpr (std::is_same_v or std::is_same_v) { - // for (std::size_t j = 0; j < index_.dim(); ++j) { - // auto data_elem = static_cast(host_dataset(i, j)); - // os.write(reinterpret_cast(&data_elem), sizeof(int)); - // } - // } - - os.write(reinterpret_cast(&i), sizeof(std::size_t)); - } - - for (std::size_t i = 0; i < index_.size(); i++) { - // zeroes - auto zero = 0; - os.write(reinterpret_cast(&zero), sizeof(int)); - } -} - -template -void serialize_to_hnswlib(raft::resources const& res, - const std::string& filename, - const index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize_to_hnswlib(res, of, index_); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - /** Load an index from file. * * Experimental, both the API and the serialization format are subject to change. diff --git a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/detail/hnsw.hpp similarity index 95% rename from cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp rename to cpp/include/raft/neighbors/detail/hnsw.hpp index 6f5a255ac3..2033dfb888 100644 --- a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw.hpp @@ -16,7 +16,7 @@ #pragma once -#include "../hnswlib_types.hpp" +#include "hnsw_types.hpp" #include #include @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::cagra_hnswlib::detail { +namespace raft::neighbors::hnsw::detail { template void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, @@ -79,4 +79,4 @@ void search(raft::resources const& res, } } -} // namespace raft::neighbors::cagra_hnswlib::detail +} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh b/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh new file mode 100644 index 0000000000..c444c62364 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../hnsw_types.hpp" +#include "hnsw_types.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::neighbors::hnsw::detail { + +template +void serialize(raft::resources const& res, + std::ostream& os, + const raft::neighbors::cagra::index& index_) +{ + // static_assert(std::is_same_v or std::is_same_v, + // "An hnswlib index can only be trained with int32 or uint32 IdxT"); + common::nvtx::range fun_scope("cagra::serialize"); + RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", + static_cast(index_.size()), + index_.dim()); + + // offset_level_0 + std::size_t offset_level_0 = 0; + os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); + // max_element + std::size_t max_element = index_.size(); + os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); + // curr_element_count + std::size_t curr_element_count = index_.size(); + os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); + // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, + // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + + // dim * sizeof(data_t) + sizeof(labeltype) + auto size_data_per_element = static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + + index_.dim() * sizeof(T) + 8); + os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); + // label_offset + std::size_t label_offset = size_data_per_element - 8; + os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); + // offset_data + auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); + os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); + // max_level + int max_level = 1; + os.write(reinterpret_cast(&max_level), sizeof(int)); + // entrypoint_node + auto entrypoint_node = static_cast(index_.size() / 2); + os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); + // max_M + auto max_M = static_cast(index_.graph_degree() / 2); + os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); + // max_M0 + std::size_t max_M0 = index_.graph_degree(); + os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); + // M + auto M = static_cast(index_.graph_degree() / 2); + os.write(reinterpret_cast(&M), sizeof(std::size_t)); + // mult, can be anything + double mult = 0.42424242; + os.write(reinterpret_cast(&mult), sizeof(double)); + // efConstruction, can be anything + std::size_t efConstruction = 500; + os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); + + auto dataset = index_.dataset(); + // Remove padding before saving the dataset + auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), + sizeof(T) * host_dataset.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.stride(0), + sizeof(T) * host_dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + + auto graph = index_.graph(); + auto host_graph = + raft::make_host_matrix(graph.extent(0), graph.extent(1)); + raft::copy(host_graph.data_handle(), + graph.data_handle(), + graph.size(), + raft::resource::get_cuda_stream(res)); + resource::sync_stream(res); + + // Write one dataset and graph row at a time + for (std::size_t i = 0; i < index_.size(); i++) { + auto graph_degree = static_cast(index_.graph_degree()); + os.write(reinterpret_cast(&graph_degree), sizeof(int)); + + for (std::size_t j = 0; j < index_.graph_degree(); ++j) { + auto graph_elem = host_graph(i, j); + os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); + } + + auto data_row = host_dataset.data_handle() + (index_.dim() * i); + // if constexpr (std::is_same_v) { + for (std::size_t j = 0; j < index_.dim(); ++j) { + auto data_elem = host_dataset(i, j); + os.write(reinterpret_cast(&data_elem), sizeof(T)); + } + // } else if constexpr (std::is_same_v or std::is_same_v) { + // for (std::size_t j = 0; j < index_.dim(); ++j) { + // auto data_elem = static_cast(host_dataset(i, j)); + // os.write(reinterpret_cast(&data_elem), sizeof(int)); + // } + // } + + os.write(reinterpret_cast(&i), sizeof(std::size_t)); + } + + for (std::size_t i = 0; i < index_.size(); i++) { + // zeroes + auto zero = 0; + os.write(reinterpret_cast(&zero), sizeof(int)); + } +} + +template +void serialize(raft::resources const& res, + const std::string& filename, + const raft::neighbors::cagra::index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize(res, of, index_); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + +template +void deserialize(raft::resources const& handle, + const std::string& filename, + index*& index, + int dim, + raft::distance::DistanceType metric) +{ + index = new index_impl(filename, dim, metric); + RAFT_EXPECTS(index, "Could not set index pointer"); +} + +} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/hnswlib_types.hpp b/cpp/include/raft/neighbors/detail/hnsw_types.hpp similarity index 89% rename from cpp/include/raft/neighbors/hnswlib_types.hpp rename to cpp/include/raft/neighbors/detail/hnsw_types.hpp index 6e47ca53c7..a7ab0fc62f 100644 --- a/cpp/include/raft/neighbors/hnswlib_types.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw_types.hpp @@ -16,7 +16,7 @@ #pragma once -#include "cagra_hnswlib_types.hpp" +#include "../hnsw_types.hpp" #include #include #include @@ -26,7 +26,7 @@ #include #include -namespace raft::neighbors::cagra_hnswlib { +namespace raft::neighbors::hnsw::detail { /** * @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib @@ -54,7 +54,7 @@ struct hnsw_dist_t { }; template -struct hnswlib_index : index { +struct index_impl : index { public: /** * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index @@ -63,7 +63,7 @@ struct hnswlib_index : index { * @param[in] dim dimensions of the training dataset * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") */ - hnswlib_index(std::string filepath, int dim, raft::distance::DistanceType metric) + index_impl(std::string filepath, int dim, raft::distance::DistanceType metric) : index{dim, metric} { if constexpr (std::is_same_v) { @@ -74,7 +74,7 @@ struct hnswlib_index : index { } } else if constexpr (std::is_same_v or std::is_same_v) { if (metric == raft::distance::L2Expanded) { - space_ = std::make_unique(dim); + space_ = std::make_unique>(dim); } } @@ -98,4 +98,4 @@ struct hnswlib_index : index { /**@}*/ -} // namespace raft::neighbors::cagra_hnswlib +} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/hnsw.hpp similarity index 84% rename from cpp/include/raft/neighbors/cagra_hnswlib.hpp rename to cpp/include/raft/neighbors/hnsw.hpp index c3255e990a..8c876b8dec 100644 --- a/cpp/include/raft/neighbors/cagra_hnswlib.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -16,17 +16,17 @@ #pragma once -#include "cagra_hnswlib_types.hpp" -#include "detail/cagra_hnswlib.hpp" +#include "detail/hnsw.hpp" +#include "hnsw.hpp" #include #include #include -namespace raft::neighbors::cagra_hnswlib { +namespace raft::neighbors::hnsw { /** - * @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib + * @addtogroup hnsw Build CAGRA index and search with hnswlib * @{ */ @@ -55,19 +55,22 @@ namespace raft::neighbors::cagra_hnswlib { * auto index = cagra::build(res, index_params, dataset); * * // Save CAGRA index as base layer only hnswlib index - * cagra::serialize_to_hnswlib(res, "my_index.bin", index); + * hnsw::serialize(res, "my_index.bin", index); * * // Load CAGRA index as base layer only hnswlib index - * cagra_hnswlib::index(D, "my_index.bin", raft::distance::L2Expanded); + * raft::neighbors::hnsw::index* hnsw_index; + * hnsw::deserialize(D, "my_index.bin", hnsw_index, D,raft::distance::L2Expanded); * * // Search K nearest neighbors as an hnswlib index * // using host threads for concurrency - * cagra_hnswlib::search_params search_params; + * h::seanswrch_params search_params; * search_params.ef = 50 // ef >= K; * search_params.num_threads = 10; * auto neighbors = raft::make_host_matrix(res, n_queries, k); * auto distances = raft::make_host_matrix(res, n_queries, k); - * cagra_hnswlib::search(res, search_params, index, queries, neighbors, distances); + * hnsw::search(res, search_params, *index, queries, neighbors, distances); + * // de-allocate hnsw_index + * delete hnsw_index; * @endcode */ template @@ -92,4 +95,4 @@ void search(raft::resources const& res, /**@}*/ -} // namespace raft::neighbors::cagra_hnswlib +} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft/neighbors/hnsw_serialize.cuh b/cpp/include/raft/neighbors/hnsw_serialize.cuh new file mode 100644 index 0000000000..e75d01a717 --- /dev/null +++ b/cpp/include/raft/neighbors/hnsw_serialize.cuh @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/hnsw_serialize.cuh" +#include "hnsw_types.hpp" +#include + +#include + +namespace raft::neighbors::hnsw { + +/** + * @addtogroup hnsw Build CAGRA index and search with hnswlib + * @{ + */ + +/** + * Write the CAGRA built index as a base layer HNSW index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * + */ +template +void serialize(raft::resources const& handle, + std::ostream& os, + const raft::neighbors::cagra::index& index) +{ + detail::serialize(handle, os, index); +} + +/** + * Save a CAGRA build index in hnswlib base-layer-only serialized format + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * + */ +template +void serialize(raft::resources const& handle, + const std::string& filename, + const raft::neighbors::cagra::index& index) +{ + detail::serialize(handle, filename, index); +} + +/** + * Load an hnswlib index which was serialized from a CAGRA index + * + * NOTE: This function allocates the index on the heap, and it is + * the user's responsibility to de-allocate the index + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an an unallocated pointer + * raft::neighbors::hnsw* index; + * raft::deserialize(handle, filename, index); + * // use the index, then delete when done + * delete index; + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[out] index CAGRA index + * @param[in] dim dimensionality of the index + * @param[in] metric metric used to build the index + * + */ +template +void deserialize(raft::resources const& handle, + const std::string& filename, + index*& index, + int dim, + raft::distance::DistanceType metric) +{ + detail::deserialize(handle, filename, index, dim, metric); +} + +/**@}*/ + +} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp b/cpp/include/raft/neighbors/hnsw_types.hpp similarity index 92% rename from cpp/include/raft/neighbors/cagra_hnswlib_types.hpp rename to cpp/include/raft/neighbors/hnsw_types.hpp index e2e23fb391..06e96ab020 100644 --- a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp +++ b/cpp/include/raft/neighbors/hnsw_types.hpp @@ -24,10 +24,10 @@ #include #include -namespace raft::neighbors::cagra_hnswlib { +namespace raft::neighbors::hnsw { /** - * @defgroup cagra_hnswlib Build CAGRA index and search with hnswlib + * @defgroup hnsw Build CAGRA index and search with hnswlib * @{ */ @@ -67,4 +67,4 @@ struct index : ann::index { /**@}*/ -} // namespace raft::neighbors::cagra_hnswlib +} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index 80abf1e6c4..c54ed32b77 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -27,56 +27,48 @@ namespace raft::runtime::neighbors::cagra { // Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors -#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - void build_device(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void build_host(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void search(raft::resources const& handle, \ - raft::neighbors::cagra::search_params const& params, \ - const raft::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void serialize_to_hnswlib_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index); \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index); \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void serialize_to_hnswlib(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index); \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ +#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + void build_device(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void build_host(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra::search_params const& params, \ + const raft::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ raft::neighbors::cagra::index* index); RAFT_INST_CAGRA_FUNCS(float, uint32_t); diff --git a/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp b/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp deleted file mode 100644 index 66edaa2981..0000000000 --- a/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace raft::runtime::neighbors::cagra_hnswlib { - -#define RAFT_INST_CAGRA_HNSWLIB_FUNCS(T) \ - void search(raft::resources const& handle, \ - raft::neighbors::cagra_hnswlib::search_params const& params, \ - raft::neighbors::cagra_hnswlib::index const& index, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra_hnswlib::index*& index, \ - int dim, \ - raft::distance::DistanceType metric); - -RAFT_INST_CAGRA_HNSWLIB_FUNCS(float); -RAFT_INST_CAGRA_HNSWLIB_FUNCS(int8_t); -RAFT_INST_CAGRA_HNSWLIB_FUNCS(uint8_t); - -} // namespace raft::runtime::neighbors::cagra_hnswlib diff --git a/cpp/include/raft_runtime/neighbors/hnsw.hpp b/cpp/include/raft_runtime/neighbors/hnsw.hpp new file mode 100644 index 0000000000..062636234f --- /dev/null +++ b/cpp/include/raft_runtime/neighbors/hnsw.hpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::runtime::neighbors::hnsw { + +#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ + void search(raft::resources const& handle, \ + raft::neighbors::hnsw::search_params const& params, \ + raft::neighbors::hnsw::index const& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index); \ + void serialize_to_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index); \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::hnsw::index*& index, \ + int dim, \ + raft::distance::DistanceType metric); + +RAFT_INST_HNSW_FUNCS(float, uint32_t); +RAFT_INST_HNSW_FUNCS(int8_t, uint32_t); +RAFT_INST_HNSW_FUNCS(uint8_t, uint32_t); + +} // namespace raft::runtime::neighbors::hnsw diff --git a/cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp b/cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp deleted file mode 100644 index 663429bddb..0000000000 --- a/cpp/src/raft_runtime/neighbors/cagra_hnswlib.cpp +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::runtime::neighbors::cagra_hnswlib { - -#define RAFT_INST_CAGRA_HNSWLIB(T) \ - void search(raft::resources const& handle, \ - raft::neighbors::cagra_hnswlib::search_params const& params, \ - const raft::neighbors::cagra_hnswlib::index& index, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances) \ - { \ - raft::neighbors::cagra_hnswlib::search( \ - handle, params, index, queries, neighbors, distances); \ - } \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra_hnswlib::index*& index, \ - int dim, \ - raft::distance::DistanceType metric) \ - { \ - index = new raft::neighbors::cagra_hnswlib::hnswlib_index(filename, dim, metric); \ - RAFT_EXPECTS(index, "Could not set index pointer"); \ - } - -RAFT_INST_CAGRA_HNSWLIB(float); -RAFT_INST_CAGRA_HNSWLIB(int8_t); -RAFT_INST_CAGRA_HNSWLIB(uint8_t); - -#undef RAFT_INST_CAGRA_HNSWLIB - -} // namespace raft::runtime::neighbors::cagra_hnswlib diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index c4599b560b..69b48b93a4 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -24,55 +24,39 @@ namespace raft::runtime::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void serialize_to_hnswlib_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index) \ - { \ - raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void serialize_to_hnswlib(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index) \ - { \ - std::stringstream os; \ - raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - raft::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, is); \ +#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::cagra::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, is); \ } RAFT_INST_CAGRA_SERIALIZE(float); diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cu b/cpp/src/raft_runtime/neighbors/hnsw.cu new file mode 100644 index 0000000000..39bf380550 --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/hnsw.cu @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +namespace raft::runtime::neighbors::hnsw { + +#define RAFT_INST_HNSW(T) \ + void search(raft::resources const& handle, \ + raft::neighbors::hnsw::search_params const& params, \ + const raft::neighbors::hnsw::index& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + raft::neighbors::hnsw::search(handle, params, index, queries, neighbors, distances); \ + } \ + \ + void serialize_to_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index) \ + { \ + raft::neighbors::hnsw::serialize(handle, filename, index); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::hnsw::serialize(handle, os, index); \ + str = os.str(); \ + } \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::hnsw::index*& index, \ + int dim, \ + raft::distance::DistanceType metric) \ + { \ + raft::neighbors::hnsw::deserialize(handle, filename, index, dim, metric); \ + } + +RAFT_INST_HNSW(float); +RAFT_INST_HNSW(int8_t); +RAFT_INST_HNSW(uint8_t); + +#undef RAFT_INST_CAGRA_HNSWLIB + +} // namespace raft::runtime::neighbors::hnsw diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index d5b8a80761..9e45712b40 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx brute_force.pyx cagra_hnswlib.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx hnsw.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index 8ca4fb1769..d2f7092421 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -14,7 +14,7 @@ # from pylibraft.neighbors import brute_force # type: ignore -from pylibraft.neighbors import cagra_hnswlib # type: ignore +from pylibraft.neighbors import hnsw # type: ignore from pylibraft.neighbors import cagra, ivf_flat, ivf_pq from .refine import refine @@ -26,5 +26,5 @@ "ivf_flat", "ivf_pq", "cagra", - "cagra_hnswlib", + "hnsw", ] diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 8cd5cb0b44..7e22f274e9 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -174,10 +174,6 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ string& str, const index[float, uint32_t]& index, bool include_dataset) except + - cdef void serialize_to_hnwslib( - const device_resources& handle, - string& str, - const index[float, uint32_t]& index) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -188,11 +184,6 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + - cdef void serialize_to_hnwslib( - const device_resources& handle, - string& str, - const index[uint8_t, uint32_t]& index) except + - cdef void deserialize(const device_resources& handle, const string& str, index[uint8_t, uint32_t]* index) except + @@ -202,11 +193,6 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[int8_t, uint32_t]& index, bool include_dataset) except + - cdef void serialize_to_hnwslib( - const device_resources& handle, - string& str, - const index[int8_t, uint32_t]& index) except + - cdef void deserialize(const device_resources& handle, const string& str, index[int8_t, uint32_t]* index) except + @@ -216,11 +202,6 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[float, uint32_t]& index, bool include_dataset) except + - cdef void serialize_to_hnswlib_file( - const device_resources& handle, - const string& filename, - const index[float, uint32_t]& index) except + - cdef void deserialize_file(const device_resources& handle, const string& filename, index[float, uint32_t]* index) except + @@ -230,11 +211,6 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + - cdef void serialize_to_hnswlib_file( - const device_resources& handle, - const string& filename, - const index[uint8_t, uint32_t]& index) except + - cdef void deserialize_file(const device_resources& handle, const string& filename, index[uint8_t, uint32_t]* index) except + @@ -244,11 +220,6 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[int8_t, uint32_t]& index, bool include_dataset) except + - cdef void serialize_to_hnswlib_file( - const device_resources& handle, - const string& filename, - const index[int8_t, uint32_t]& index) except + - cdef void deserialize_file(const device_resources& handle, const string& filename, index[int8_t, uint32_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd similarity index 70% rename from python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd rename to python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd index 942f1859e3..fb322fc71b 100644 --- a/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd +++ b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd @@ -18,7 +18,7 @@ # cython: embedsignature = True # cython: language_level = 3 -from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t from libcpp.string cimport string from pylibraft.common.cpp.mdspan cimport ( @@ -28,14 +28,15 @@ from pylibraft.common.cpp.mdspan cimport ( ) from pylibraft.common.handle cimport device_resources from pylibraft.distance.distance_type cimport DistanceType +from pylibraft.neighbors.cagra.cpp.c_cagra cimport index as cagra_index from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( ann_index, ann_search_params, ) -cdef extern from "raft/neighbors/cagra_hnswlib_types.hpp" \ - namespace "raft::neighbors::cagra_hnswlib" nogil: +cdef extern from "raft/neighbors/hnsw.hpp" \ + namespace "raft::neighbors::hnsw" nogil: cpdef cppclass search_params(ann_search_params): int ef @@ -48,8 +49,8 @@ cdef extern from "raft/neighbors/cagra_hnswlib_types.hpp" \ DistanceType metric() -cdef extern from "raft_runtime/neighbors/cagra_hnswlib.hpp" \ - namespace "raft::runtime::neighbors::cagra_hnswlib" nogil: +cdef extern from "raft_runtime/neighbors/hnsw.hpp" \ + namespace "raft::runtime::neighbors::hnsw" nogil: cdef void search( const device_resources& handle, const search_params& params, @@ -74,6 +75,36 @@ cdef extern from "raft_runtime/neighbors/cagra_hnswlib.hpp" \ host_matrix_view[uint64_t, int64_t, row_major] neighbors, host_matrix_view[float, int64_t, row_major] distances) except + + cdef void serialize( + const device_resources& handle, + string& str, + const cagra_index[float, uint32_t]& index) except + + + cdef void serialize( + const device_resources& handle, + string& str, + const cagra_index[uint8_t, uint32_t]& index) except + + + cdef void serialize( + const device_resources& handle, + string& str, + const cagra_index[int8_t, uint32_t]& index) except + + + cdef void serialize_to_file( + const device_resources& handle, + const string& filename, + const cagra_index[float, uint32_t]& index) except + + + cdef void serialize_to_file( + const device_resources& handle, + const string& filename, + const cagra_index[uint8_t, uint32_t]& index) except + + + cdef void serialize_to_file( + const device_resources& handle, + const string& filename, + const cagra_index[int8_t, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[float]*& index, diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx similarity index 78% rename from python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx rename to python/pylibraft/pylibraft/neighbors/hnsw.pyx index 15f92f9ea1..2473fdefd5 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -38,7 +38,7 @@ from pylibraft.common.handle cimport device_resources from pylibraft.common import DeviceResources, ai_wrapper, auto_convert_output -cimport pylibraft.neighbors.cpp.cagra_hnswlib as c_cagra_hnswlib +cimport pylibraft.neighbors.cpp.hnsw as c_hnsw from pylibraft.neighbors.common import _check_input_array, _get_metric @@ -53,7 +53,7 @@ from pylibraft.neighbors.common cimport _get_metric_string import numpy as np -cdef class CagraHnswlibIndex: +cdef class HnswIndex: cdef readonly bool trained cdef str active_index_type @@ -61,8 +61,8 @@ cdef class CagraHnswlibIndex: self.trained = False self.active_index_type = None -cdef class CagraHnswlibIndexFloat(CagraHnswlibIndex): - cdef c_cagra_hnswlib.index[float] * index +cdef class HnswIndexFloat(HnswIndex): + cdef c_hnsw.index[float] * index def __cinit__(self): pass @@ -72,7 +72,7 @@ cdef class CagraHnswlibIndexFloat(CagraHnswlibIndex): attr_str = [attr + "=" + str(getattr(self, attr)) for attr in ["dim"]] attr_str = [m_str] + attr_str - return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + return "Index(type=hnsw, " + (", ".join(attr_str)) + ")" @property def dim(self): @@ -86,8 +86,8 @@ cdef class CagraHnswlibIndexFloat(CagraHnswlibIndex): if self.index is not NULL: del self.index -cdef class CagraHnswlibIndexInt8(CagraHnswlibIndex): - cdef c_cagra_hnswlib.index[int8_t] * index +cdef class HnswIndexInt8(HnswIndex): + cdef c_hnsw.index[int8_t] * index def __cinit__(self): pass @@ -97,7 +97,7 @@ cdef class CagraHnswlibIndexInt8(CagraHnswlibIndex): attr_str = [attr + "=" + str(getattr(self, attr)) for attr in ["dim"]] attr_str = [m_str] + attr_str - return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + return "Index(type=hnsw, " + (", ".join(attr_str)) + ")" @property def dim(self): @@ -111,8 +111,8 @@ cdef class CagraHnswlibIndexInt8(CagraHnswlibIndex): if self.index is not NULL: del self.index -cdef class CagraHnswlibIndexUint8(CagraHnswlibIndex): - cdef c_cagra_hnswlib.index[uint8_t] * index +cdef class HnswIndexUint8(HnswIndex): + cdef c_hnsw.index[uint8_t] * index def __cinit__(self): pass @@ -122,7 +122,7 @@ cdef class CagraHnswlibIndexUint8(CagraHnswlibIndex): attr_str = [attr + "=" + str(getattr(self, attr)) for attr in ["dim"]] attr_str = [m_str] + attr_str - return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + return "Index(type=hnsw, " + (", ".join(attr_str)) + ")" @property def dim(self): @@ -158,7 +158,7 @@ def save(filename, Index index, handle=None): >>> import cupy as cp >>> from pylibraft.common import DeviceResources >>> from pylibraft.neighbors import cagra - >>> from pylibraft.neighbors import cagra_hnswlib + >>> from pylibraft.neighbors import hnsw >>> n_samples = 50000 >>> n_features = 50 >>> dataset = cp.random.random_sample((n_samples, n_features), @@ -167,7 +167,7 @@ def save(filename, Index index, handle=None): >>> handle = DeviceResources() >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) >>> # Serialize the CAGRA index to hnswlib base layer only index format - >>> cagra_hnswlib.save("my_index.bin", index, handle=handle) + >>> hnsw.save("my_index.bin", index, handle=handle) """ if not index.trained: raise ValueError("Index need to be built before saving it.") @@ -191,19 +191,19 @@ def save(filename, Index index, handle=None): idx_float = index c_index_float = \ idx_float.index - c_cagra.serialize_to_hnswlib_file( + c_hnsw.serialize_to_file( deref(handle_), c_filename, deref(c_index_float)) elif index.active_index_type == "byte": idx_int8 = index c_index_int8 = \ idx_int8.index - c_cagra.serialize_to_hnswlib_file( + c_hnsw.serialize_to_file( deref(handle_), c_filename, deref(c_index_int8)) elif index.active_index_type == "ubyte": idx_uint8 = index c_index_uint8 = \ idx_uint8.index - c_cagra.serialize_to_hnswlib_file( + c_hnsw.serialize_to_file( deref(handle_), c_filename, deref(c_index_uint8)) else: raise ValueError( @@ -237,13 +237,13 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): Returns ------- - index : CagraHnswlibIndex + index : HnswIndex Examples -------- - >>> from pylibraft.neighbors import cagra_hnswlib + >>> from pylibraft.neighbors import hnsw >>> dim = 50 # Assuming training dataset has 50 dimensions - >>> index = cagra_hnswlib.load("my_index.bin", dim, "sqeuclidean") + >>> index = hnsw.load("my_index.bin", dim, "sqeuclidean") """ if handle is None: handle = DeviceResources() @@ -251,29 +251,29 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): handle.getHandle() cdef string c_filename = filename.encode('utf-8') - cdef CagraHnswlibIndexFloat idx_float - cdef CagraHnswlibIndexInt8 idx_int8 - cdef CagraHnswlibIndexUint8 idx_uint8 + cdef HnswIndexFloat idx_float + cdef HnswIndexInt8 idx_int8 + cdef HnswIndexUint8 idx_uint8 cdef DistanceType c_metric = _get_metric(metric) if dtype == np.float32: - idx_float = CagraHnswlibIndexFloat() - c_cagra_hnswlib.deserialize_file( + idx_float = HnswIndexFloat() + c_hnsw.deserialize_file( deref(handle_), c_filename, idx_float.index, dim, c_metric) idx_float.trained = True idx_float.active_index_type = 'float32' return idx_float elif dtype == np.byte: - idx_int8 = CagraHnswlibIndexInt8(dim, metric) - c_cagra_hnswlib.deserialize_file( + idx_int8 = HnswIndexInt8(dim, metric) + c_hnsw.deserialize_file( deref(handle_), c_filename, idx_int8.index, dim, c_metric) idx_int8.trained = True idx_int8.active_index_type = 'byte' return idx_int8 elif dtype == np.ubyte: - idx_uint8 = CagraHnswlibIndexUint8(dim, metric) - c_cagra_hnswlib.deserialize_file( + idx_uint8 = HnswIndexUint8(dim, metric) + c_hnsw.deserialize_file( deref(handle_), c_filename, idx_uint8.index, dim, c_metric) idx_uint8.trained = True idx_uint8.active_index_type = 'ubyte' @@ -295,7 +295,7 @@ cdef class SearchParams: Number of host threads to use to search the hnswlib index and increase concurrency """ - cdef c_cagra_hnswlib.search_params params + cdef c_hnsw.search_params params def __init__(self, ef=200, num_threads=1): self.params.ef = ef @@ -305,7 +305,7 @@ cdef class SearchParams: attr_str = [attr + "=" + str(getattr(self, attr)) for attr in [ "ef", "num_threads"]] - return "SearchParams(type=CAGRA_hnswlib, " + ( + return "SearchParams(type=hnsw, " + ( ", ".join(attr_str)) + ")" @property @@ -320,7 +320,7 @@ cdef class SearchParams: @auto_sync_handle @auto_convert_output def search(SearchParams search_params, - CagraHnswlibIndex index, + HnswIndex index, queries, k, neighbors=None, @@ -332,7 +332,7 @@ def search(SearchParams search_params, Parameters ---------- search_params : SearchParams - index : CagraHnswlibIndex + index : HnswIndex Trained CAGRA index saved as base layer only hnswlib index. queries : array interface compliant matrix shape (n_samples, dim) Supported dtype [float, int8, uint8] @@ -352,7 +352,7 @@ def search(SearchParams search_params, >>> import numpy as np >>> from pylibraft.common import DeviceResources >>> from pylibraft.neighbors import cagra - >>> from pylibraft.neighbors import cagra_hnswlib + >>> from pylibraft.neighbors import hnsw >>> n_samples = 50000 >>> n_features = 50 >>> n_queries = 1000 @@ -363,21 +363,21 @@ def search(SearchParams search_params, >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) >>> >>> Save CAGRA built index as base layer only hnswlib index - >>> cagra_hnswlib.save("my_index.bin", index) + >>> hnsw.save("my_index.bin", index) >>> >>> Load saved base layer only hnswlib index - >>> index_hnswlib.load("my_index.bin", n_features, dataset.dtype) + >>> hnsw_index = hnsw.load("my_index.bin", n_features, dataset.dtype) >>> >>> # Search hnswlib using the loaded index >>> queries = np.random.random_sample((n_queries, n_features), ... dtype=cp.float32) >>> k = 10 - >>> search_params = cagra_hnswlib.SearchParams( + >>> search_params = hnsw.SearchParams( ... ef=20, ... num_threads=5 ... ) - >>> distances, neighbors = cagra_hnswlib.search(search_params, index, - ... queries, k, handle=handle) + >>> distances, neighbors = hnsw.search(search_params, hnsw_index, + ... queries, k, handle=handle) """ if not index.trained: @@ -410,35 +410,35 @@ def search(SearchParams search_params, _check_input_array(distances_ai, [np.dtype('float32')], exp_rows=n_queries, exp_cols=k) - cdef c_cagra_hnswlib.search_params params = search_params.params - cdef CagraHnswlibIndexFloat idx_float - cdef CagraHnswlibIndexInt8 idx_int8 - cdef CagraHnswlibIndexUint8 idx_uint8 + cdef c_hnsw.search_params params = search_params.params + cdef HnswIndexFloat idx_float + cdef HnswIndexInt8 idx_int8 + cdef HnswIndexUint8 idx_uint8 if queries_dt == np.float32: idx_float = index - c_cagra_hnswlib.search(deref(handle_), - params, - deref(idx_float.index), - get_hmv_float(queries_ai, check_shape=True), - get_hmv_uint64(neighbors_ai, check_shape=True), - get_hmv_float(distances_ai, check_shape=True)) + c_hnsw.search(deref(handle_), + params, + deref(idx_float.index), + get_hmv_float(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) elif queries_dt == np.byte: idx_int8 = index - c_cagra_hnswlib.search(deref(handle_), - params, - deref(idx_int8.index), - get_hmv_int8(queries_ai, check_shape=True), - get_hmv_uint64(neighbors_ai, check_shape=True), - get_hmv_float(distances_ai, check_shape=True)) + c_hnsw.search(deref(handle_), + params, + deref(idx_int8.index), + get_hmv_int8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) elif queries_dt == np.ubyte: idx_uint8 = index - c_cagra_hnswlib.search(deref(handle_), - params, - deref(idx_uint8.index), - get_hmv_uint8(queries_ai, check_shape=True), - get_hmv_uint64(neighbors_ai, check_shape=True), - get_hmv_float(distances_ai, check_shape=True)) + c_hnsw.search(deref(handle_), + params, + deref(idx_uint8.index), + get_hmv_uint8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) else: raise ValueError("query dtype %s not supported" % queries_dt) diff --git a/python/pylibraft/pylibraft/test/test_cagra_hnswlib.py b/python/pylibraft/pylibraft/test/test_hnsw.py similarity index 81% rename from python/pylibraft/pylibraft/test/test_cagra_hnswlib.py rename to python/pylibraft/pylibraft/test/test_hnsw.py index c1cbbb8fb9..1cc16abaf6 100644 --- a/python/pylibraft/pylibraft/test/test_cagra_hnswlib.py +++ b/python/pylibraft/pylibraft/test/test_hnsw.py @@ -18,11 +18,11 @@ from sklearn.neighbors import NearestNeighbors from sklearn.preprocessing import normalize -from pylibraft.neighbors import cagra, cagra_hnswlib +from pylibraft.neighbors import cagra, hnsw from pylibraft.test.ann_utils import calc_recall, generate_data -def run_cagra_hnswlib_build_search_test( +def run_hnsw_build_search_test( n_rows=10000, n_cols=10, n_queries=100, @@ -48,20 +48,16 @@ def run_cagra_hnswlib_build_search_test( assert index.trained filename = "my_index.bin" - cagra_hnswlib.save(filename, index) + hnsw.save(filename, index) - index_hnswlib = cagra_hnswlib.load( - filename, n_cols, dataset.dtype, metric=metric - ) + hnsw_index = hnsw.load(filename, n_cols, dataset.dtype, metric=metric) queries = generate_data((n_queries, n_cols), dtype) out_idx = np.zeros((n_queries, k), dtype=np.uint32) - search_params = cagra_hnswlib.SearchParams(**search_params) + search_params = hnsw.SearchParams(**search_params) - out_dist, out_idx = cagra_hnswlib.search( - search_params, index_hnswlib, queries, k - ) + out_dist, out_idx = hnsw.search(search_params, hnsw_index, queries, k) # Calculate reference values with sklearn nn_skl = NearestNeighbors(n_neighbors=k, algorithm="brute", metric=metric) @@ -77,9 +73,9 @@ def run_cagra_hnswlib_build_search_test( @pytest.mark.parametrize("k", [10, 20]) @pytest.mark.parametrize("ef", [30, 40]) @pytest.mark.parametrize("num_threads", [2, 4]) -def test_cagra_hnswlib(dtype, k, ef, num_threads): +def test_hnsw(dtype, k, ef, num_threads): # Note that inner_product tests use normalized input which we cannot # represent in int8, therefore we test only sqeuclidean metric here. - run_cagra_hnswlib_build_search_test( + run_hnsw_build_search_test( dtype=dtype, k=k, search_params={"ef": ef, "num_threads": num_threads} ) From fdc015f7b2e8e66c453a2993d9e289d50c7ec7b8 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 15 Dec 2023 09:22:25 +0000 Subject: [PATCH 15/31] missed template --- cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index 86b1ca2fd0..f23aca11c5 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -210,7 +210,7 @@ void HnswLib::load(const std::string& path_to_index) space_ = std::make_shared(dim_); } } else if constexpr (std::is_same_v) { - space_ = std::make_shared(dim_); + space_ = std::make_shared>(dim_); } appr_alg_ = std::make_shared::type>>( From 211ba394098ec04555cd0093280451ea4c583f53 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 15 Dec 2023 09:37:28 +0000 Subject: [PATCH 16/31] fix docs --- cpp/include/raft/neighbors/hnsw.hpp | 2 +- cpp/include/raft/neighbors/hnsw_serialize.cuh | 3 +- docs/source/cpp_api/neighbors_cagra.rst | 11 ------- docs/source/cpp_api/neighbors_hnsw.rst | 29 +++++++++++++++++++ 4 files changed, 31 insertions(+), 14 deletions(-) create mode 100644 docs/source/cpp_api/neighbors_hnsw.rst diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp index 8c876b8dec..9819b2bf26 100644 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -59,7 +59,7 @@ namespace raft::neighbors::hnsw { * * // Load CAGRA index as base layer only hnswlib index * raft::neighbors::hnsw::index* hnsw_index; - * hnsw::deserialize(D, "my_index.bin", hnsw_index, D,raft::distance::L2Expanded); + * hnsw::deserialize(res, "my_index.bin", hnsw_index, D,raft::distance::L2Expanded); * * // Search K nearest neighbors as an hnswlib index * // using host threads for concurrency diff --git a/cpp/include/raft/neighbors/hnsw_serialize.cuh b/cpp/include/raft/neighbors/hnsw_serialize.cuh index e75d01a717..33cb8d0320 100644 --- a/cpp/include/raft/neighbors/hnsw_serialize.cuh +++ b/cpp/include/raft/neighbors/hnsw_serialize.cuh @@ -25,7 +25,7 @@ namespace raft::neighbors::hnsw { /** - * @addtogroup hnsw Build CAGRA index and search with hnswlib + * @defgroup hnsw_serialize HNSW Serialize * @{ */ @@ -116,7 +116,6 @@ void serialize(raft::resources const& handle, * @endcode * * @tparam T data element type - * @tparam IdxT type of the indices * * @param[in] handle the raft handle * @param[in] filename the file name for saving the index diff --git a/docs/source/cpp_api/neighbors_cagra.rst b/docs/source/cpp_api/neighbors_cagra.rst index f09ad23798..99ecd3a985 100644 --- a/docs/source/cpp_api/neighbors_cagra.rst +++ b/docs/source/cpp_api/neighbors_cagra.rst @@ -29,14 +29,3 @@ namespace *raft::neighbors::cagra* :project: RAFT :members: :content-only: - -CAGRA index build and hnswlib search ------------------------------------- -``#include `` - -namespace *raft::neighbors::cagra_hnswlib* - -.. doxygengroup:: cagra_hnswlib - :project: RAFT - :members: - :content-only: diff --git a/docs/source/cpp_api/neighbors_hnsw.rst b/docs/source/cpp_api/neighbors_hnsw.rst new file mode 100644 index 0000000000..86f9544c35 --- /dev/null +++ b/docs/source/cpp_api/neighbors_hnsw.rst @@ -0,0 +1,29 @@ +HNSW +===== + +HNSW is a graph-based nearest neighbors implementation for the CPU. +This implementation provides the ability to serialize a CAGRA graph and read it as a base-layer-only hnswlib graph. + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::neighbors::hnsw* + +.. doxygengroup:: hnsw + :project: RAFT + :members: + :content-only: + +Serializer Methods +------------------ +``#include `` + +namespace *raft::neighbors::hnsw* + +.. doxygengroup:: hnsw_serialize + :project: RAFT + :members: + :content-only: From 466e141796d0b2b0aaa189c9ed998c23cbd36965 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 18 Dec 2023 08:22:05 +0000 Subject: [PATCH 17/31] move serialize back to cagra:: --- cpp/CMakeLists.txt | 2 +- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 2 +- .../raft/neighbors/cagra_serialize.cuh | 64 +++++++++ .../detail/cagra/cagra_serialize.cuh | 124 ++++++++++++++++++ .../raft/neighbors/detail/hnsw_serialize.cuh | 124 ------------------ cpp/include/raft/neighbors/hnsw_serialize.cuh | 64 --------- cpp/include/raft_runtime/neighbors/cagra.hpp | 89 +++++++------ cpp/include/raft_runtime/neighbors/hnsw.hpp | 28 ++-- .../raft_runtime/neighbors/cagra_serialize.cu | 81 +++++++----- .../neighbors/{hnsw.cu => hnsw.cpp} | 14 -- .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 30 +++++ .../pylibraft/neighbors/cpp/hnsw.pxd | 31 ----- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 6 +- python/pylibraft/pylibraft/test/test_hnsw.py | 1 - 14 files changed, 329 insertions(+), 331 deletions(-) rename cpp/src/raft_runtime/neighbors/{hnsw.cu => hnsw.cpp} (66%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 1788c5da41..5da3fff8bf 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -416,7 +416,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/neighbors/cagra_build.cu src/raft_runtime/neighbors/cagra_search.cu src/raft_runtime/neighbors/cagra_serialize.cu - src/raft_runtime/neighbors/hnsw.cu + src/raft_runtime/neighbors/hnsw.cpp src/raft_runtime/neighbors/ivf_flat_build.cu src/raft_runtime/neighbors/ivf_flat_search.cu src/raft_runtime/neighbors/ivf_flat_serialize.cu diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 6fe48f539f..a923756ff4 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -247,7 +247,7 @@ void RaftCagra::save(const std::string& file) const template void RaftCagra::save_to_hnswlib(const std::string& file) const { - raft::neighbors::hnsw::serialize(handle_, file, *index_); + raft::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_); } template diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index 0a806402d2..fc88926077 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -93,6 +93,70 @@ void serialize(raft::resources const& handle, detail::serialize(handle, filename, index, include_dataset); } +/** + * Write the CAGRA built index as a base layer HNSW index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize(handle, os, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * + */ +template +void serialize_to_hnswlib(raft::resources const& handle, + std::ostream& os, + const raft::neighbors::cagra::index& index) +{ + detail::serialize_to_hnswlib(handle, os, index); +} + +/** + * Save a CAGRA build index in hnswlib base-layer-only serialized format + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cagra::build(...);` + * raft::serialize(handle, filename, index); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * + */ +template +void serialize_to_hnswlib(raft::resources const& handle, + const std::string& filename, + const raft::neighbors::cagra::index& index) +{ + detail::serialize_to_hnswlib(handle, filename, index); +} + /** * Load index from input stream * diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 0d01b17a26..5fecadbd63 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -98,6 +98,130 @@ void serialize(raft::resources const& res, if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } } +template +void serialize_to_hnswlib(raft::resources const& res, + std::ostream& os, + const raft::neighbors::cagra::index& index_) +{ + // static_assert(std::is_same_v or std::is_same_v, + // "An hnswlib index can only be trained with int32 or uint32 IdxT"); + common::nvtx::range fun_scope("cagra::serialize"); + RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", + static_cast(index_.size()), + index_.dim()); + + // offset_level_0 + std::size_t offset_level_0 = 0; + os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); + // max_element + std::size_t max_element = index_.size(); + os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); + // curr_element_count + std::size_t curr_element_count = index_.size(); + os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); + // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, + // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + + // dim * sizeof(data_t) + sizeof(labeltype) + auto size_data_per_element = static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + + index_.dim() * sizeof(T) + 8); + os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); + // label_offset + std::size_t label_offset = size_data_per_element - 8; + os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); + // offset_data + auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); + os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); + // max_level + int max_level = 1; + os.write(reinterpret_cast(&max_level), sizeof(int)); + // entrypoint_node + auto entrypoint_node = static_cast(index_.size() / 2); + os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); + // max_M + auto max_M = static_cast(index_.graph_degree() / 2); + os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); + // max_M0 + std::size_t max_M0 = index_.graph_degree(); + os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); + // M + auto M = static_cast(index_.graph_degree() / 2); + os.write(reinterpret_cast(&M), sizeof(std::size_t)); + // mult, can be anything + double mult = 0.42424242; + os.write(reinterpret_cast(&mult), sizeof(double)); + // efConstruction, can be anything + std::size_t efConstruction = 500; + os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); + + auto dataset = index_.dataset(); + // Remove padding before saving the dataset + auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), + sizeof(T) * host_dataset.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.stride(0), + sizeof(T) * host_dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + + auto graph = index_.graph(); + auto host_graph = + raft::make_host_matrix(graph.extent(0), graph.extent(1)); + raft::copy(host_graph.data_handle(), + graph.data_handle(), + graph.size(), + raft::resource::get_cuda_stream(res)); + resource::sync_stream(res); + + // Write one dataset and graph row at a time + for (std::size_t i = 0; i < index_.size(); i++) { + auto graph_degree = static_cast(index_.graph_degree()); + os.write(reinterpret_cast(&graph_degree), sizeof(int)); + + for (std::size_t j = 0; j < index_.graph_degree(); ++j) { + auto graph_elem = host_graph(i, j); + os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); + } + + auto data_row = host_dataset.data_handle() + (index_.dim() * i); + // if constexpr (std::is_same_v) { + for (std::size_t j = 0; j < index_.dim(); ++j) { + auto data_elem = host_dataset(i, j); + os.write(reinterpret_cast(&data_elem), sizeof(T)); + } + // } else if constexpr (std::is_same_v or std::is_same_v) { + // for (std::size_t j = 0; j < index_.dim(); ++j) { + // auto data_elem = static_cast(host_dataset(i, j)); + // os.write(reinterpret_cast(&data_elem), sizeof(int)); + // } + // } + + os.write(reinterpret_cast(&i), sizeof(std::size_t)); + } + + for (std::size_t i = 0; i < index_.size(); i++) { + // zeroes + auto zero = 0; + os.write(reinterpret_cast(&zero), sizeof(int)); + } +} + +template +void serialize_to_hnswlib(raft::resources const& res, + const std::string& filename, + const raft::neighbors::cagra::index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + detail::serialize_to_hnswlib(res, of, index_); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + /** Load an index from file. * * Experimental, both the API and the serialization format are subject to change. diff --git a/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh b/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh index c444c62364..ed760c4f2c 100644 --- a/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh @@ -34,130 +34,6 @@ namespace raft::neighbors::hnsw::detail { -template -void serialize(raft::resources const& res, - std::ostream& os, - const raft::neighbors::cagra::index& index_) -{ - // static_assert(std::is_same_v or std::is_same_v, - // "An hnswlib index can only be trained with int32 or uint32 IdxT"); - common::nvtx::range fun_scope("cagra::serialize"); - RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", - static_cast(index_.size()), - index_.dim()); - - // offset_level_0 - std::size_t offset_level_0 = 0; - os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); - // max_element - std::size_t max_element = index_.size(); - os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); - // curr_element_count - std::size_t curr_element_count = index_.size(); - os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); - // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, - // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + - // dim * sizeof(data_t) + sizeof(labeltype) - auto size_data_per_element = static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + - index_.dim() * sizeof(T) + 8); - os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); - // label_offset - std::size_t label_offset = size_data_per_element - 8; - os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); - // offset_data - auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); - os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); - // max_level - int max_level = 1; - os.write(reinterpret_cast(&max_level), sizeof(int)); - // entrypoint_node - auto entrypoint_node = static_cast(index_.size() / 2); - os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); - // max_M - auto max_M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); - // max_M0 - std::size_t max_M0 = index_.graph_degree(); - os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); - // M - auto M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&M), sizeof(std::size_t)); - // mult, can be anything - double mult = 0.42424242; - os.write(reinterpret_cast(&mult), sizeof(double)); - // efConstruction, can be anything - std::size_t efConstruction = 500; - os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); - - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - - auto graph = index_.graph(); - auto host_graph = - raft::make_host_matrix(graph.extent(0), graph.extent(1)); - raft::copy(host_graph.data_handle(), - graph.data_handle(), - graph.size(), - raft::resource::get_cuda_stream(res)); - resource::sync_stream(res); - - // Write one dataset and graph row at a time - for (std::size_t i = 0; i < index_.size(); i++) { - auto graph_degree = static_cast(index_.graph_degree()); - os.write(reinterpret_cast(&graph_degree), sizeof(int)); - - for (std::size_t j = 0; j < index_.graph_degree(); ++j) { - auto graph_elem = host_graph(i, j); - os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); - } - - auto data_row = host_dataset.data_handle() + (index_.dim() * i); - // if constexpr (std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = host_dataset(i, j); - os.write(reinterpret_cast(&data_elem), sizeof(T)); - } - // } else if constexpr (std::is_same_v or std::is_same_v) { - // for (std::size_t j = 0; j < index_.dim(); ++j) { - // auto data_elem = static_cast(host_dataset(i, j)); - // os.write(reinterpret_cast(&data_elem), sizeof(int)); - // } - // } - - os.write(reinterpret_cast(&i), sizeof(std::size_t)); - } - - for (std::size_t i = 0; i < index_.size(); i++) { - // zeroes - auto zero = 0; - os.write(reinterpret_cast(&zero), sizeof(int)); - } -} - -template -void serialize(raft::resources const& res, - const std::string& filename, - const raft::neighbors::cagra::index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize(res, of, index_); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - template void deserialize(raft::resources const& handle, const std::string& filename, diff --git a/cpp/include/raft/neighbors/hnsw_serialize.cuh b/cpp/include/raft/neighbors/hnsw_serialize.cuh index 33cb8d0320..0f992e3035 100644 --- a/cpp/include/raft/neighbors/hnsw_serialize.cuh +++ b/cpp/include/raft/neighbors/hnsw_serialize.cuh @@ -29,70 +29,6 @@ namespace raft::neighbors::hnsw { * @{ */ -/** - * Write the CAGRA built index as a base layer HNSW index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index CAGRA index - * - */ -template -void serialize(raft::resources const& handle, - std::ostream& os, - const raft::neighbors::cagra::index& index) -{ - detail::serialize(handle, os, index); -} - -/** - * Save a CAGRA build index in hnswlib base-layer-only serialized format - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index CAGRA index - * - */ -template -void serialize(raft::resources const& handle, - const std::string& filename, - const raft::neighbors::cagra::index& index) -{ - detail::serialize(handle, filename, index); -} - /** * Load an hnswlib index which was serialized from a CAGRA index * diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index c54ed32b77..549acff60b 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -27,48 +27,53 @@ namespace raft::runtime::neighbors::cagra { // Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors -#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - void build_device(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void build_host(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void search(raft::resources const& handle, \ - raft::neighbors::cagra::search_params const& params, \ - const raft::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index); \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ +#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + void build_device(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void build_host(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra::search_params const& params, \ + const raft::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index); \ + void serialize_to_hnswlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index); \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ raft::neighbors::cagra::index* index); RAFT_INST_CAGRA_FUNCS(float, uint32_t); diff --git a/cpp/include/raft_runtime/neighbors/hnsw.hpp b/cpp/include/raft_runtime/neighbors/hnsw.hpp index 062636234f..eabbc9d3d3 100644 --- a/cpp/include/raft_runtime/neighbors/hnsw.hpp +++ b/cpp/include/raft_runtime/neighbors/hnsw.hpp @@ -23,23 +23,17 @@ namespace raft::runtime::neighbors::hnsw { -#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ - void search(raft::resources const& handle, \ - raft::neighbors::hnsw::search_params const& params, \ - raft::neighbors::hnsw::index const& index, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index); \ - void serialize_to_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index); \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::hnsw::index*& index, \ - int dim, \ +#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ + void search(raft::resources const& handle, \ + raft::neighbors::hnsw::search_params const& params, \ + raft::neighbors::hnsw::index const& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::hnsw::index*& index, \ + int dim, \ raft::distance::DistanceType metric); RAFT_INST_HNSW_FUNCS(float, uint32_t); diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index 69b48b93a4..2ebbad5084 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -24,39 +24,54 @@ namespace raft::runtime::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - raft::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, is); \ +#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ + str = os.str(); \ + } \ + \ + void serialize_to_hnswlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index) \ + { \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ + }; \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::cagra::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, is); \ } RAFT_INST_CAGRA_SERIALIZE(float); diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cu b/cpp/src/raft_runtime/neighbors/hnsw.cpp similarity index 66% rename from cpp/src/raft_runtime/neighbors/hnsw.cu rename to cpp/src/raft_runtime/neighbors/hnsw.cpp index 39bf380550..86fb5cd119 100644 --- a/cpp/src/raft_runtime/neighbors/hnsw.cu +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -32,20 +32,6 @@ namespace raft::runtime::neighbors::hnsw { raft::neighbors::hnsw::search(handle, params, index, queries, neighbors, distances); \ } \ \ - void serialize_to_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index) \ - { \ - raft::neighbors::hnsw::serialize(handle, filename, index); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index) \ - { \ - std::stringstream os; \ - raft::neighbors::hnsw::serialize(handle, os, index); \ - str = os.str(); \ - } \ void deserialize_file(raft::resources const& handle, \ const std::string& filename, \ raft::neighbors::hnsw::index*& index, \ diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 7e22f274e9..5f659e4754 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -211,6 +211,36 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib( + const device_resources& handle, + string& str, + const index[float, uint32_t]& index) except + + + cdef void serialize_to_hnswlib( + const device_resources& handle, + string& str, + const index[uint8_t, uint32_t]& index) except + + + cdef void serialize_to_hnswlib( + const device_resources& handle, + string& str, + const index[int8_t, uint32_t]& index) except + + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[float, uint32_t]& index) except + + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[uint8_t, uint32_t]& index) except + + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[int8_t, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[uint8_t, uint32_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd index fb322fc71b..abc4f47c89 100644 --- a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd +++ b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd @@ -28,7 +28,6 @@ from pylibraft.common.cpp.mdspan cimport ( ) from pylibraft.common.handle cimport device_resources from pylibraft.distance.distance_type cimport DistanceType -from pylibraft.neighbors.cagra.cpp.c_cagra cimport index as cagra_index from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( ann_index, ann_search_params, @@ -75,36 +74,6 @@ cdef extern from "raft_runtime/neighbors/hnsw.hpp" \ host_matrix_view[uint64_t, int64_t, row_major] neighbors, host_matrix_view[float, int64_t, row_major] distances) except + - cdef void serialize( - const device_resources& handle, - string& str, - const cagra_index[float, uint32_t]& index) except + - - cdef void serialize( - const device_resources& handle, - string& str, - const cagra_index[uint8_t, uint32_t]& index) except + - - cdef void serialize( - const device_resources& handle, - string& str, - const cagra_index[int8_t, uint32_t]& index) except + - - cdef void serialize_to_file( - const device_resources& handle, - const string& filename, - const cagra_index[float, uint32_t]& index) except + - - cdef void serialize_to_file( - const device_resources& handle, - const string& filename, - const cagra_index[uint8_t, uint32_t]& index) except + - - cdef void serialize_to_file( - const device_resources& handle, - const string& filename, - const cagra_index[int8_t, uint32_t]& index) except + - cdef void deserialize_file(const device_resources& handle, const string& filename, index[float]*& index, diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index 2473fdefd5..43612f476b 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -191,19 +191,19 @@ def save(filename, Index index, handle=None): idx_float = index c_index_float = \ idx_float.index - c_hnsw.serialize_to_file( + c_cagra.serialize_to_hnswlib_file( deref(handle_), c_filename, deref(c_index_float)) elif index.active_index_type == "byte": idx_int8 = index c_index_int8 = \ idx_int8.index - c_hnsw.serialize_to_file( + c_cagra.serialize_to_hnswlib_file( deref(handle_), c_filename, deref(c_index_int8)) elif index.active_index_type == "ubyte": idx_uint8 = index c_index_uint8 = \ idx_uint8.index - c_hnsw.serialize_to_file( + c_cagra.serialize_to_hnswlib_file( deref(handle_), c_filename, deref(c_index_uint8)) else: raise ValueError( diff --git a/python/pylibraft/pylibraft/test/test_hnsw.py b/python/pylibraft/pylibraft/test/test_hnsw.py index 1cc16abaf6..c3c0445bc1 100644 --- a/python/pylibraft/pylibraft/test/test_hnsw.py +++ b/python/pylibraft/pylibraft/test/test_hnsw.py @@ -65,7 +65,6 @@ def run_hnsw_build_search_test( skl_idx = nn_skl.kneighbors(queries, return_distance=False) recall = calc_recall(out_idx, skl_idx) - print(recall) assert recall > 0.95 From d3fad16f1230d3dd1dbc3775ed1fe3ea336b1aef Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 18 Dec 2023 08:54:58 +0000 Subject: [PATCH 18/31] use unique_ptr for deserialization --- .../raft/neighbors/detail/hnsw_serialize.cuh | 12 ++--- cpp/include/raft/neighbors/hnsw.hpp | 2 +- cpp/include/raft/neighbors/hnsw_serialize.cuh | 24 ++++----- cpp/include/raft_runtime/neighbors/hnsw.hpp | 25 ++++++---- cpp/src/raft_runtime/neighbors/hnsw.cpp | 15 +++--- .../pylibraft/neighbors/cpp/hnsw.pxd | 31 ++++++------ python/pylibraft/pylibraft/neighbors/hnsw.pyx | 49 +++++++------------ 7 files changed, 75 insertions(+), 83 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh b/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh index ed760c4f2c..a997ee064f 100644 --- a/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh @@ -35,14 +35,12 @@ namespace raft::neighbors::hnsw::detail { template -void deserialize(raft::resources const& handle, - const std::string& filename, - index*& index, - int dim, - raft::distance::DistanceType metric) +std::unique_ptr> deserialize(raft::resources const& handle, + const std::string& filename, + int dim, + raft::distance::DistanceType metric) { - index = new index_impl(filename, dim, metric); - RAFT_EXPECTS(index, "Could not set index pointer"); + return std::unique_ptr>(new index_impl(filename, dim, metric)); } } // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp index 9819b2bf26..63fdde3243 100644 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -59,7 +59,7 @@ namespace raft::neighbors::hnsw { * * // Load CAGRA index as base layer only hnswlib index * raft::neighbors::hnsw::index* hnsw_index; - * hnsw::deserialize(res, "my_index.bin", hnsw_index, D,raft::distance::L2Expanded); + * auto hnsw_index = hnsw::deserialize(res, "my_index.bin", D, raft::distance::L2Expanded); * * // Search K nearest neighbors as an hnswlib index * // using host threads for concurrency diff --git a/cpp/include/raft/neighbors/hnsw_serialize.cuh b/cpp/include/raft/neighbors/hnsw_serialize.cuh index 0f992e3035..695e88f9c3 100644 --- a/cpp/include/raft/neighbors/hnsw_serialize.cuh +++ b/cpp/include/raft/neighbors/hnsw_serialize.cuh @@ -32,9 +32,6 @@ namespace raft::neighbors::hnsw { /** * Load an hnswlib index which was serialized from a CAGRA index * - * NOTE: This function allocates the index on the heap, and it is - * the user's responsibility to de-allocate the index - * * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} @@ -45,29 +42,28 @@ namespace raft::neighbors::hnsw { * // create a string with a filepath * std::string filename("/path/to/index"); * // create an an unallocated pointer - * raft::neighbors::hnsw* index; - * raft::deserialize(handle, filename, index); - * // use the index, then delete when done - * delete index; + * int dim = 10; + * raft::distance::DistanceType = raft::distance::L2Expanded + * auto index = raft::deserialize(handle, filename, dim, metric); * @endcode * * @tparam T data element type * * @param[in] handle the raft handle * @param[in] filename the file name for saving the index - * @param[out] index CAGRA index * @param[in] dim dimensionality of the index * @param[in] metric metric used to build the index * + * @return std::unique_ptr> + * */ template -void deserialize(raft::resources const& handle, - const std::string& filename, - index*& index, - int dim, - raft::distance::DistanceType metric) +std::unique_ptr> deserialize(raft::resources const& handle, + const std::string& filename, + int dim, + raft::distance::DistanceType metric) { - detail::deserialize(handle, filename, index, dim, metric); + return detail::deserialize(handle, filename, dim, metric); } /**@}*/ diff --git a/cpp/include/raft_runtime/neighbors/hnsw.hpp b/cpp/include/raft_runtime/neighbors/hnsw.hpp index eabbc9d3d3..018bc17cb1 100644 --- a/cpp/include/raft_runtime/neighbors/hnsw.hpp +++ b/cpp/include/raft_runtime/neighbors/hnsw.hpp @@ -23,21 +23,28 @@ namespace raft::runtime::neighbors::hnsw { -#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ +#define RAFT_INST_HNSW_FUNCS(T) \ void search(raft::resources const& handle, \ raft::neighbors::hnsw::search_params const& params, \ raft::neighbors::hnsw::index const& index, \ raft::host_matrix_view queries, \ raft::host_matrix_view neighbors, \ raft::host_matrix_view distances); \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::hnsw::index*& index, \ - int dim, \ - raft::distance::DistanceType metric); + template \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ + raft::distance::DistanceType metric); \ + template <> \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ + raft::distance::DistanceType metric); -RAFT_INST_HNSW_FUNCS(float, uint32_t); -RAFT_INST_HNSW_FUNCS(int8_t, uint32_t); -RAFT_INST_HNSW_FUNCS(uint8_t, uint32_t); +RAFT_INST_HNSW_FUNCS(float); +RAFT_INST_HNSW_FUNCS(int8_t); +RAFT_INST_HNSW_FUNCS(uint8_t); } // namespace raft::runtime::neighbors::hnsw diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cpp b/cpp/src/raft_runtime/neighbors/hnsw.cpp index 86fb5cd119..f606f592ff 100644 --- a/cpp/src/raft_runtime/neighbors/hnsw.cpp +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -32,19 +32,20 @@ namespace raft::runtime::neighbors::hnsw { raft::neighbors::hnsw::search(handle, params, index, queries, neighbors, distances); \ } \ \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::hnsw::index*& index, \ - int dim, \ - raft::distance::DistanceType metric) \ + template <> \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ + raft::distance::DistanceType metric) \ { \ - raft::neighbors::hnsw::deserialize(handle, filename, index, dim, metric); \ + return raft::neighbors::hnsw::deserialize(handle, filename, dim, metric); \ } RAFT_INST_HNSW(float); RAFT_INST_HNSW(int8_t); RAFT_INST_HNSW(uint8_t); -#undef RAFT_INST_CAGRA_HNSWLIB +#undef RAFT_INST_HNSW } // namespace raft::runtime::neighbors::hnsw diff --git a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd index abc4f47c89..e9f3d7d2b7 100644 --- a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd +++ b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd @@ -19,6 +19,7 @@ # cython: language_level = 3 from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t +from libcpp.memory cimport unique_ptr from libcpp.string cimport string from pylibraft.common.cpp.mdspan cimport ( @@ -74,20 +75,20 @@ cdef extern from "raft_runtime/neighbors/hnsw.hpp" \ host_matrix_view[uint64_t, int64_t, row_major] neighbors, host_matrix_view[float, int64_t, row_major] distances) except + - cdef void deserialize_file(const device_resources& handle, - const string& filename, - index[float]*& index, - int dim, - DistanceType metric) except + + cdef unique_ptr[index[float]] deserialize_file[float]( + const device_resources& handle, + const string& filename, + int dim, + DistanceType metric) except + - cdef void deserialize_file(const device_resources& handle, - const string& filename, - index[int8_t]*& index, - int dim, - DistanceType metric) except + + cdef unique_ptr[index[int8_t]] deserialize_file[int8_t]( + const device_resources& handle, + const string& filename, + int dim, + DistanceType metric) except + - cdef void deserialize_file(const device_resources& handle, - const string& filename, - index[uint8_t]*& index, - int dim, - DistanceType metric) except + + cdef unique_ptr[index[uint8_t]] deserialize_file[uint8_t]( + const device_resources& handle, + const string& filename, + int dim, + DistanceType metric) except + diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index 43612f476b..266d56f7b8 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -21,6 +21,7 @@ from cython.operator cimport dereference as deref from libc.stdint cimport int8_t, uint8_t, uint32_t from libcpp cimport bool +from libcpp.memory cimport unique_ptr from libcpp.string cimport string cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra @@ -62,13 +63,13 @@ cdef class HnswIndex: self.active_index_type = None cdef class HnswIndexFloat(HnswIndex): - cdef c_hnsw.index[float] * index + cdef unique_ptr[c_hnsw.index[float]] index def __cinit__(self): pass def __repr__(self): - m_str = "metric=" + _get_metric_string(self.index.metric()) + m_str = "metric=" + _get_metric_string(self.metric) attr_str = [attr + "=" + str(getattr(self, attr)) for attr in ["dim"]] attr_str = [m_str] + attr_str @@ -76,24 +77,20 @@ cdef class HnswIndexFloat(HnswIndex): @property def dim(self): - return self.index[0].dim() + return self.index.get()[0].dim() @property def metric(self): - return self.index[0].metric() - - def __dealloc__(self): - if self.index is not NULL: - del self.index + return self.index.get()[0].metric() cdef class HnswIndexInt8(HnswIndex): - cdef c_hnsw.index[int8_t] * index + cdef unique_ptr[c_hnsw.index[int8_t]] index def __cinit__(self): pass def __repr__(self): - m_str = "metric=" + _get_metric_string(self.index.metric()) + m_str = "metric=" + _get_metric_string(self.metric) attr_str = [attr + "=" + str(getattr(self, attr)) for attr in ["dim"]] attr_str = [m_str] + attr_str @@ -101,24 +98,20 @@ cdef class HnswIndexInt8(HnswIndex): @property def dim(self): - return self.index[0].dim() + return self.index.get()[0].dim() @property def metric(self): - return self.index[0].metric() - - def __dealloc__(self): - if self.index is not NULL: - del self.index + return self.index.get()[0].metric() cdef class HnswIndexUint8(HnswIndex): - cdef c_hnsw.index[uint8_t] * index + cdef unique_ptr[c_hnsw.index[uint8_t]] index def __cinit__(self): pass def __repr__(self): - m_str = "metric=" + _get_metric_string(self.index.metric()) + m_str = "metric=" + _get_metric_string(self.metric) attr_str = [attr + "=" + str(getattr(self, attr)) for attr in ["dim"]] attr_str = [m_str] + attr_str @@ -126,15 +119,11 @@ cdef class HnswIndexUint8(HnswIndex): @property def dim(self): - return self.index[0].dim() + return self.index.get()[0].dim() @property def metric(self): - return self.index[0].metric() - - def __dealloc__(self): - if self.index is not NULL: - del self.index + return self.index.get()[0].metric() @auto_sync_handle @@ -259,22 +248,22 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): if dtype == np.float32: idx_float = HnswIndexFloat() - c_hnsw.deserialize_file( - deref(handle_), c_filename, idx_float.index, dim, c_metric) + idx_float.index = c_hnsw.deserialize_file[float]( + deref(handle_), c_filename, dim, c_metric) idx_float.trained = True idx_float.active_index_type = 'float32' return idx_float elif dtype == np.byte: idx_int8 = HnswIndexInt8(dim, metric) - c_hnsw.deserialize_file( - deref(handle_), c_filename, idx_int8.index, dim, c_metric) + idx_int8.index = c_hnsw.deserialize_file[int8_t]( + deref(handle_), c_filename, dim, c_metric) idx_int8.trained = True idx_int8.active_index_type = 'byte' return idx_int8 elif dtype == np.ubyte: idx_uint8 = HnswIndexUint8(dim, metric) - c_hnsw.deserialize_file( - deref(handle_), c_filename, idx_uint8.index, dim, c_metric) + idx_uint8.index = c_hnsw.deserialize_file[uint8_t]( + deref(handle_), c_filename, dim, c_metric) idx_uint8.trained = True idx_uint8.active_index_type = 'ubyte' return idx_uint8 From e4e635f8444a84ba74ae82f06e58c07709195780 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 18 Dec 2023 09:47:41 +0000 Subject: [PATCH 19/31] add composite function from_cagra --- cpp/include/raft/neighbors/hnsw.hpp | 46 ++++++++++++++- cpp/include/raft_runtime/neighbors/hnsw.hpp | 10 ++-- cpp/src/raft_runtime/neighbors/hnsw.cpp | 22 +++++++ python/pylibraft/pylibraft/neighbors/hnsw.pyx | 58 ++++++++++++++++--- python/pylibraft/pylibraft/test/test_hnsw.py | 5 +- 5 files changed, 124 insertions(+), 17 deletions(-) diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp index 63fdde3243..1649627f69 100644 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -20,8 +20,11 @@ #include "hnsw.hpp" #include + +#include #include #include +#include namespace raft::neighbors::hnsw { @@ -31,7 +34,48 @@ namespace raft::neighbors::hnsw { */ /** - * @brief Search hnswlib base layer only index constructed from a CAGRA index + * @brief Construct an hnswlib base-layer-only index from a CAGRA index + * NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/cagra_index.bin` + * before reading it as an hnswlib index, then deleting the temporary file. + * 2. This function is only offered as a compiled symbol in `libraft.so` + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] cagra_index cagra index + * + * Usage example: + * @code{.cpp} + * // Build a CAGRA index + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * + * // Load CAGRA index as base-layer-only hnswlib index + * auto hnsw_index = hnsw::from_cagra(res, index); + * @endcode + */ +template +std::unique_ptr> from_cagra(raft::resources const& res, + raft::neighbors::cagra::index cagra_index); + +template <> +std::unique_ptr> from_cagra( + raft::resources const& res, raft::neighbors::cagra::index cagra_index); + +template <> +std::unique_ptr> from_cagra( + raft::resources const& res, raft::neighbors::cagra::index cagra_index); + +template <> +std::unique_ptr> from_cagra( + raft::resources const& res, raft::neighbors::cagra::index cagra_index); + +/** + * @brief Search hnswlib base-layer-only index constructed from a CAGRA index * * @tparam T data element type * @tparam IdxT type of the indices diff --git a/cpp/include/raft_runtime/neighbors/hnsw.hpp b/cpp/include/raft_runtime/neighbors/hnsw.hpp index 018bc17cb1..70adee3ad0 100644 --- a/cpp/include/raft_runtime/neighbors/hnsw.hpp +++ b/cpp/include/raft_runtime/neighbors/hnsw.hpp @@ -23,7 +23,9 @@ namespace raft::runtime::neighbors::hnsw { -#define RAFT_INST_HNSW_FUNCS(T) \ +#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ + std::unique_ptr> from_cagra( \ + raft::resources const& res, raft::neighbors::cagra::index); \ void search(raft::resources const& handle, \ raft::neighbors::hnsw::search_params const& params, \ raft::neighbors::hnsw::index const& index, \ @@ -43,8 +45,8 @@ namespace raft::runtime::neighbors::hnsw { int dim, \ raft::distance::DistanceType metric); -RAFT_INST_HNSW_FUNCS(float); -RAFT_INST_HNSW_FUNCS(int8_t); -RAFT_INST_HNSW_FUNCS(uint8_t); +RAFT_INST_HNSW_FUNCS(float, uint32_t); +RAFT_INST_HNSW_FUNCS(int8_t, uint32_t); +RAFT_INST_HNSW_FUNCS(uint8_t, uint32_t); } // namespace raft::runtime::neighbors::hnsw diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cpp b/cpp/src/raft_runtime/neighbors/hnsw.cpp index f606f592ff..a082b06ed8 100644 --- a/cpp/src/raft_runtime/neighbors/hnsw.cpp +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -14,11 +14,33 @@ * limitations under the License. */ +#include #include #include +#include #include +namespace raft::neighbors::hnsw { +#define RAFT_INST_HNSW(T) \ + template <> \ + std::unique_ptr> from_cagra( \ + raft::resources const& res, raft::neighbors::cagra::index cagra_index) \ + { \ + std::string filepath = "/tmp/cagra_index.bin"; \ + raft::runtime::neighbors::cagra::serialize_to_hnswlib(res, filepath, cagra_index); \ + auto hnsw_index = raft::runtime::neighbors::hnsw::deserialize_file( \ + res, filepath, cagra_index.dim(), cagra_index.metric()); \ + std::filesystem::remove(filepath); \ + return hnsw_index; \ + } + +RAFT_INST_HNSW(float); +RAFT_INST_HNSW(int8_t); +RAFT_INST_HNSW(uint8_t); +#undef RAFT_INST_HNSW +} // namespace raft::neighbors::hnsw + namespace raft::runtime::neighbors::hnsw { #define RAFT_INST_HNSW(T) \ diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index 266d56f7b8..c07b58ad78 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -51,6 +51,8 @@ from pylibraft.common.mdspan cimport ( ) from pylibraft.neighbors.common cimport _get_metric_string +import os + import numpy as np @@ -129,7 +131,7 @@ cdef class HnswIndexUint8(HnswIndex): @auto_sync_handle def save(filename, Index index, handle=None): """ - Saves the CAGRA index as an hnswlib base layer only index to a file. + Saves the CAGRA index as an hnswlib base-layer-only index to a file. Saving / loading the index is experimental. The serialization format is subject to change. @@ -199,9 +201,10 @@ def save(filename, Index index, handle=None): "Index dtype %s not supported" % index.active_index_type) +@auto_sync_handle def load(filename, dim, dtype, metric="sqeuclidean", handle=None): """ - Loads base layer only hnswlib index from file, which was originally + Loads base-layer-only hnswlib index from file, which was originally saved as a built CAGRA index. Saving / loading the index is experimental. The serialization format is @@ -271,6 +274,48 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): raise ValueError("Dataset dtype %s not supported" % dtype) +@auto_sync_handle +def from_cagra(Index index, handle=None): + """ + Returns an hnswlib base-layer-only index from a CAGRA index. + + NOTE: This method uses the filesystem to write the CAGRA index in + `/tmp/cagra_index.bin` before reading it as an hnswlib index, + then deleting the temporary file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + index : Index + Trained CAGRA index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import hnsw + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> # Serialize the CAGRA index to hnswlib base layer only index format + >>> hnsw_index = hnsw.from_cagra(index, handle=handle) + """ + filename = "/tmp/cagra_index.bin" + save(filename, index, handle=handle) + hnsw_index = load(filename, index.dim, np.dtype(index.active_index_type), + _get_metric_string(index.metric), handle=handle) + os.remove(filename) + return hnsw_index + + cdef class SearchParams: """ Hnswlib search parameters @@ -322,7 +367,7 @@ def search(SearchParams search_params, ---------- search_params : SearchParams index : HnswIndex - Trained CAGRA index saved as base layer only hnswlib index. + Trained CAGRA index saved as base-layer-only hnswlib index. queries : array interface compliant matrix shape (n_samples, dim) Supported dtype [float, int8, uint8] k : int @@ -351,11 +396,8 @@ def search(SearchParams search_params, >>> handle = DeviceResources() >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) >>> - >>> Save CAGRA built index as base layer only hnswlib index - >>> hnsw.save("my_index.bin", index) - >>> - >>> Load saved base layer only hnswlib index - >>> hnsw_index = hnsw.load("my_index.bin", n_features, dataset.dtype) + >>> Load saved base-layer-only hnswlib index from CAGRA index + >>> hnsw_index = hnsw.from_cagra(index, handle=handle) >>> >>> # Search hnswlib using the loaded index >>> queries = np.random.random_sample((n_queries, n_features), diff --git a/python/pylibraft/pylibraft/test/test_hnsw.py b/python/pylibraft/pylibraft/test/test_hnsw.py index c3c0445bc1..6b4b649bdd 100644 --- a/python/pylibraft/pylibraft/test/test_hnsw.py +++ b/python/pylibraft/pylibraft/test/test_hnsw.py @@ -47,10 +47,7 @@ def run_hnsw_build_search_test( assert index.trained - filename = "my_index.bin" - hnsw.save(filename, index) - - hnsw_index = hnsw.load(filename, n_cols, dataset.dtype, metric=metric) + hnsw_index = hnsw.from_cagra(index) queries = generate_data((n_queries, n_cols), dtype) out_idx = np.zeros((n_queries, k), dtype=np.uint32) From ad0a25fd118c0e18bcd9fc92162fb35152da771f Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 18 Dec 2023 10:14:07 +0000 Subject: [PATCH 20/31] fix ann-bench compiler error --- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 1 - .../neighbors/detail/{hnsw_serialize.cuh => hnsw_serialize.hpp} | 0 .../raft/neighbors/{hnsw_serialize.cuh => hnsw_serialize.hpp} | 2 +- cpp/src/raft_runtime/neighbors/hnsw.cpp | 2 +- 4 files changed, 2 insertions(+), 3 deletions(-) rename cpp/include/raft/neighbors/detail/{hnsw_serialize.cuh => hnsw_serialize.hpp} (100%) rename cpp/include/raft/neighbors/{hnsw_serialize.cuh => hnsw_serialize.hpp} (98%) diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index a923756ff4..ec71de9cff 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/neighbors/detail/hnsw_serialize.cuh b/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp similarity index 100% rename from cpp/include/raft/neighbors/detail/hnsw_serialize.cuh rename to cpp/include/raft/neighbors/detail/hnsw_serialize.hpp diff --git a/cpp/include/raft/neighbors/hnsw_serialize.cuh b/cpp/include/raft/neighbors/hnsw_serialize.hpp similarity index 98% rename from cpp/include/raft/neighbors/hnsw_serialize.cuh rename to cpp/include/raft/neighbors/hnsw_serialize.hpp index 695e88f9c3..ce5f7de1e3 100644 --- a/cpp/include/raft/neighbors/hnsw_serialize.cuh +++ b/cpp/include/raft/neighbors/hnsw_serialize.hpp @@ -16,7 +16,7 @@ #pragma once -#include "detail/hnsw_serialize.cuh" +#include "detail/hnsw_serialize.hpp" #include "hnsw_types.hpp" #include diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cpp b/cpp/src/raft_runtime/neighbors/hnsw.cpp index a082b06ed8..f750f828a2 100644 --- a/cpp/src/raft_runtime/neighbors/hnsw.cpp +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include From 397176efc28a5ef737885bbfb00bf937dcc96fa5 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 18 Dec 2023 12:46:39 +0000 Subject: [PATCH 21/31] update docs from review --- cpp/include/raft/neighbors/hnsw.hpp | 2 +- cpp/include/raft/neighbors/hnsw_types.hpp | 6 +++--- docs/source/pylibraft_api/neighbors.rst | 14 ++++++++------ 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp index 1649627f69..62ecf4af99 100644 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -107,7 +107,7 @@ std::unique_ptr> from_cagra( * * // Search K nearest neighbors as an hnswlib index * // using host threads for concurrency - * h::seanswrch_params search_params; + * hnsw::search_params search_params; * search_params.ef = 50 // ef >= K; * search_params.num_threads = 10; * auto neighbors = raft::make_host_matrix(res, n_queries, k); diff --git a/cpp/include/raft/neighbors/hnsw_types.hpp b/cpp/include/raft/neighbors/hnsw_types.hpp index 06e96ab020..d65a7b9a4c 100644 --- a/cpp/include/raft/neighbors/hnsw_types.hpp +++ b/cpp/include/raft/neighbors/hnsw_types.hpp @@ -42,9 +42,9 @@ struct index : ann::index { public: /** * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index. - * This is a virtual class and it cannot be used directly. To create an index, construct - * an instance of `raft::neighbors::cagra_hnswlib::hnswlib_index` from the header - * `raft/neighbores/hnswlib_types.hpp` + * This is a virtual class and it cannot be used directly. To create an index, use the factory + * function `raft::neighbors::hnsw::from_cagra` from the header + * `raft/neighbors/hnsw.hpp` * * @param[in] dim dimensions of the training dataset * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") diff --git a/docs/source/pylibraft_api/neighbors.rst b/docs/source/pylibraft_api/neighbors.rst index 5da5d760df..e9e890fccb 100644 --- a/docs/source/pylibraft_api/neighbors.rst +++ b/docs/source/pylibraft_api/neighbors.rst @@ -33,19 +33,21 @@ Serializer Methods .. autofunction:: pylibraft.neighbors.cagra.load -CAGRA hnswlib -############# +HNSW +#### -.. autoclass:: pylibraft.neighbors.cagra_hnswlib.SearchParams +.. autoclass:: pylibraft.neighbors.hnsw.SearchParams :members: -.. autofunction:: pylibraft.neighbors.cagra_hnswlib.search +.. autofunction:: pylibraft.neighbors.hnsw.from_cagra + +.. autofunction:: pylibraft.neighbors.hnsw.search Serializer Methods ------------------ -.. autofunction:: pylibraft.neighbors.cagra_hnswlib.save +.. autofunction:: pylibraft.neighbors.hnsw.save -.. autofunction:: pylibraft.neighbors.cagra_hnswlib.load +.. autofunction:: pylibraft.neighbors.hnsw.load IVF-Flat ######## From c30671b6f8db1745dd0696ad9c823b865815eee2 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Fri, 19 Jan 2024 04:45:05 +0000 Subject: [PATCH 22/31] Fix some style --- cpp/CMakeLists.txt | 10 ++++++---- cpp/bench/ann/CMakeLists.txt | 16 +++------------- cpp/include/raft/neighbors/cagra_serialize.cuh | 2 +- .../neighbors/detail/cagra/cagra_serialize.cuh | 2 +- cpp/include/raft/neighbors/detail/hnsw.hpp | 2 +- .../raft/neighbors/detail/hnsw_serialize.hpp | 2 +- cpp/include/raft/neighbors/detail/hnsw_types.hpp | 2 +- cpp/include/raft/neighbors/hnsw.hpp | 2 +- cpp/include/raft/neighbors/hnsw_serialize.hpp | 2 +- cpp/include/raft/neighbors/hnsw_types.hpp | 2 +- cpp/include/raft_runtime/neighbors/cagra.hpp | 2 +- cpp/include/raft_runtime/neighbors/hnsw.hpp | 2 +- .../raft_runtime/neighbors/cagra_serialize.cu | 2 +- cpp/src/raft_runtime/neighbors/hnsw.cpp | 2 +- python/pylibraft/pylibraft/common/mdspan.pxd | 2 +- python/pylibraft/pylibraft/common/mdspan.pyx | 2 +- .../pylibraft/pylibraft/neighbors/CMakeLists.txt | 2 +- python/pylibraft/pylibraft/neighbors/__init__.py | 2 +- .../pylibraft/neighbors/cagra/cagra.pxd | 2 +- .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 2 +- .../pylibraft/pylibraft/neighbors/cpp/hnsw.pxd | 2 +- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 2 +- python/pylibraft/pylibraft/test/ann_utils.py | 2 +- python/pylibraft/pylibraft/test/test_cagra.py | 2 +- python/pylibraft/pylibraft/test/test_hnsw.py | 2 +- 25 files changed, 32 insertions(+), 40 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 407e1dd792..783a6171c0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -196,7 +196,7 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH) rapids_cpm_gbench() endif() -if (BUILD_CAGRA_HNSWLIB) +if(BUILD_CAGRA_HNSWLIB) include(cmake/thirdparty/get_hnswlib.cmake) endif() @@ -206,8 +206,10 @@ add_library(raft INTERFACE) add_library(raft::raft ALIAS raft) target_include_directories( - raft INTERFACE "$" "$" - "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" + raft + INTERFACE + "$" "$" + "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" ) if(NOT BUILD_CPU_ONLY) diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index de980e8945..ee84f7515a 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -225,9 +225,7 @@ endfunction() if(RAFT_ANN_BENCH_USE_HNSWLIB) ConfigureAnnBench( - NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp - LINKS - hnswlib::hnswlib + NAME HNSWLIB PATH bench/ann/src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib ) endif() @@ -276,12 +274,7 @@ endif() if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ConfigureAnnBench( - NAME - RAFT_CAGRA_HNSWLIB - PATH - bench/ann/src/raft/raft_cagra_hnswlib.cu - LINKS - raft::compiled + NAME RAFT_CAGRA_HNSWLIB PATH bench/ann/src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled hnswlib::hnswlib ) endif() @@ -336,10 +329,7 @@ endif() if(RAFT_ANN_BENCH_USE_GGNN) include(cmake/thirdparty/get_glog.cmake) - ConfigureAnnBench( - NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu - LINKS glog::glog ggnn::ggnn - ) + ConfigureAnnBench(NAME GGNN PATH bench/ann/src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn) endif() # ################################################################################################## diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index fc88926077..62928ef7da 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 5fecadbd63..42a979f059 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/detail/hnsw.hpp b/cpp/include/raft/neighbors/detail/hnsw.hpp index 2033dfb888..4d462d9bf1 100644 --- a/cpp/include/raft/neighbors/detail/hnsw.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp b/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp index a997ee064f..8103ffc5ab 100644 --- a/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/detail/hnsw_types.hpp b/cpp/include/raft/neighbors/detail/hnsw_types.hpp index a7ab0fc62f..ff1ef43bc7 100644 --- a/cpp/include/raft/neighbors/detail/hnsw_types.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp index 62ecf4af99..dceb98c5aa 100644 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/hnsw_serialize.hpp b/cpp/include/raft/neighbors/hnsw_serialize.hpp index ce5f7de1e3..45819c8fb5 100644 --- a/cpp/include/raft/neighbors/hnsw_serialize.hpp +++ b/cpp/include/raft/neighbors/hnsw_serialize.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/hnsw_types.hpp b/cpp/include/raft/neighbors/hnsw_types.hpp index d65a7b9a4c..aa4cefbc30 100644 --- a/cpp/include/raft/neighbors/hnsw_types.hpp +++ b/cpp/include/raft/neighbors/hnsw_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index 549acff60b..8389929b15 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft_runtime/neighbors/hnsw.hpp b/cpp/include/raft_runtime/neighbors/hnsw.hpp index 70adee3ad0..e8b932d490 100644 --- a/cpp/include/raft_runtime/neighbors/hnsw.hpp +++ b/cpp/include/raft_runtime/neighbors/hnsw.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index 2ebbad5084..bf8e7bf6ee 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cpp b/cpp/src/raft_runtime/neighbors/hnsw.cpp index f750f828a2..1f9e6b0a0b 100644 --- a/cpp/src/raft_runtime/neighbors/hnsw.cpp +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index bdff331153..dc3d6a5b45 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index d2b63ce549..71266e62ed 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 9e45712b40..db772c6ff1 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index d2f7092421..699a69ff3a 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd index f5f6da93f9..98537f8357 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 5f659e4754..1dffd40186 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd index e9f3d7d2b7..75c0c14aad 100644 --- a/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd +++ b/python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index c07b58ad78..8b8ecd48b7 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/test/ann_utils.py b/python/pylibraft/pylibraft/test/ann_utils.py index d9348fe100..60db7f3273 100644 --- a/python/pylibraft/pylibraft/test/ann_utils.py +++ b/python/pylibraft/pylibraft/test/ann_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 65ad3d1fcf..be53b33da3 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/test/test_hnsw.py b/python/pylibraft/pylibraft/test/test_hnsw.py index 6b4b649bdd..487f190e4e 100644 --- a/python/pylibraft/pylibraft/test/test_hnsw.py +++ b/python/pylibraft/pylibraft/test/test_hnsw.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 44c88e77ca9fbf55bd64020c3ed2b9378da51400 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Fri, 19 Jan 2024 18:32:21 +0000 Subject: [PATCH 23/31] Get wheel builds to compile --- cpp/CMakeLists.txt | 8 ++++---- cpp/cmake/thirdparty/get_hnswlib.cmake | 2 +- cpp/include/raft/neighbors/detail/hnsw.hpp | 2 +- cpp/include/raft/neighbors/detail/hnsw_types.hpp | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 783a6171c0..76d82192bf 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -206,11 +206,11 @@ add_library(raft INTERFACE) add_library(raft::raft ALIAS raft) target_include_directories( - raft - INTERFACE - "$" "$" - "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" + raft INTERFACE "$" "$" ) +if(BUILD_CAGRA_HNSWLIB) + target_link_libraries(raft INTERFACE hnswlib::hnswlib) +endif() if(NOT BUILD_CPU_ONLY) # Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target. diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index 82e95803f3..7ffcbd015a 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -74,5 +74,5 @@ endif() find_and_configure_hnswlib(VERSION 0.6.2 REPOSITORY ${RAFT_HNSWLIB_GIT_REPOSITORY} PINNED_TAG ${RAFT_HNSWLIB_GIT_TAG} - EXCLUDE_FROM_ALL ON + EXCLUDE_FROM_ALL OFF ) diff --git a/cpp/include/raft/neighbors/detail/hnsw.hpp b/cpp/include/raft/neighbors/detail/hnsw.hpp index 4d462d9bf1..69478205a9 100644 --- a/cpp/include/raft/neighbors/detail/hnsw.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw.hpp @@ -24,7 +24,7 @@ #include -#include +#include namespace raft::neighbors::hnsw::detail { diff --git a/cpp/include/raft/neighbors/detail/hnsw_types.hpp b/cpp/include/raft/neighbors/detail/hnsw_types.hpp index ff1ef43bc7..94ade95965 100644 --- a/cpp/include/raft/neighbors/detail/hnsw_types.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw_types.hpp @@ -22,7 +22,7 @@ #include #include -#include +#include #include #include From 304ae7a498adedd914e41a923298aa39aa8d3603 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 19 Jan 2024 19:42:39 +0000 Subject: [PATCH 24/31] fix docs --- .../raft/neighbors/cagra_serialize.cuh | 26 ++++++++++++------- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 8 +++--- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index 62928ef7da..83830c7457 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -32,13 +32,14 @@ namespace raft::neighbors::cagra { * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create an output stream * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, os, index); + * // create an index with `auto index = raft::cagra::build(...);` + * raft::cagra::serialize(handle, os, index); * @endcode * * @tparam T data element type @@ -66,13 +67,14 @@ void serialize(raft::resources const& handle, * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, filename, index); + * // create an index with `auto index = raft::cagra::build(...);` + * raft::cagra::serialize(handle, filename, index); * @endcode * * @tparam T data element type @@ -100,13 +102,14 @@ void serialize(raft::resources const& handle, * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create an output stream * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, os, index); + * // create an index with `auto index = raft::cagra::build(...);` + * raft::cagra::serialize_to_hnswlib(handle, os, index); * @endcode * * @tparam T data element type @@ -132,13 +135,14 @@ void serialize_to_hnswlib(raft::resources const& handle, * * @code{.cpp} * #include + * #include * * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); - * // create an index with `auto index = cagra::build(...);` - * raft::serialize(handle, filename, index); + * // create an index with `auto index = raft::cagra::build(...);` + * raft::cagra::serialize_to_hnswlib(handle, filename, index); * @endcode * * @tparam T data element type @@ -164,6 +168,7 @@ void serialize_to_hnswlib(raft::resources const& handle, * * @code{.cpp} * #include + * #include * * raft::resources handle; * @@ -171,7 +176,7 @@ void serialize_to_hnswlib(raft::resources const& handle, * std::istream is(std::cin.rdbuf()); * using T = float; // data element type * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, is); + * auto index = raft::cagra::deserialize(handle, is); * @endcode * * @tparam T data element type @@ -195,6 +200,7 @@ index deserialize(raft::resources const& handle, std::istream& is) * * @code{.cpp} * #include + * #include * * raft::resources handle; * @@ -202,7 +208,7 @@ index deserialize(raft::resources const& handle, std::istream& is) * std::string filename("/path/to/index"); * using T = float; // data element type * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, filename); + * auto index = raft::cagra::deserialize(handle, filename); * @endcode * * @tparam T data element type diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index 8b8ecd48b7..964f16b0e6 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -218,7 +218,7 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): dim : int Dimensions of the training dataest dtype : np.dtype of the saved index - Valid values for dtype: ["float", "byte", "ubyte"] + Valid values for dtype: [np.float32, np.byte, np.ubyte] metric : string denoting the metric type, default="sqeuclidean" Valid values for metric: ["sqeuclidean", "inner_product"], where - sqeuclidean is the euclidean distance without the square root @@ -235,7 +235,7 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): -------- >>> from pylibraft.neighbors import hnsw >>> dim = 50 # Assuming training dataset has 50 dimensions - >>> index = hnsw.load("my_index.bin", dim, "sqeuclidean") + >>> index = hnsw.load("my_index.bin", dim, np.float32, "sqeuclidean") """ if handle is None: handle = DeviceResources() @@ -400,8 +400,8 @@ def search(SearchParams search_params, >>> hnsw_index = hnsw.from_cagra(index, handle=handle) >>> >>> # Search hnswlib using the loaded index - >>> queries = np.random.random_sample((n_queries, n_features), - ... dtype=cp.float32) + >>> queries = np.random.random_sample((n_queries, n_features)). + ... astype(np.float32) >>> k = 10 >>> search_params = hnsw.SearchParams( ... ef=20, From d10bb4fa3e6c53fd47517071576d5a87444c8c11 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Fri, 19 Jan 2024 20:00:46 +0000 Subject: [PATCH 25/31] Export hnswlib dependency in raft --- cpp/cmake/thirdparty/get_hnswlib.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index 7ffcbd015a..060ce2ba5d 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -30,6 +30,8 @@ function(find_and_configure_hnswlib) rapids_cpm_find( hnswlib ${PKG_VERSION} GLOBAL_TARGETS hnswlib::hnswlib + BUILD_EXPORT_SET raft-exports + INSTALL_EXPORT_SET raft-exports CPM_ARGS GIT_REPOSITORY ${PKG_REPOSITORY} GIT_TAG ${PKG_PINNED_TAG} From 1d8e80429d65c39a82b7170eca3b7a8bf074715e Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 19 Jan 2024 21:51:49 +0000 Subject: [PATCH 26/31] fix doc again --- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index 964f16b0e6..ebf7c6df89 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -400,8 +400,8 @@ def search(SearchParams search_params, >>> hnsw_index = hnsw.from_cagra(index, handle=handle) >>> >>> # Search hnswlib using the loaded index - >>> queries = np.random.random_sample((n_queries, n_features)). - ... astype(np.float32) + >>> queries = np.random.random_sample((n_queries, n_features)).astype( + ... np.float32) >>> k = 10 >>> search_params = hnsw.SearchParams( ... ef=20, From ae543e7637da27d2468aca830283abb6b67b6dd8 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Tue, 23 Jan 2024 00:12:40 +0000 Subject: [PATCH 27/31] Specify the hnswlib version in the export --- cpp/cmake/thirdparty/get_hnswlib.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/cmake/thirdparty/get_hnswlib.cmake b/cpp/cmake/thirdparty/get_hnswlib.cmake index 060ce2ba5d..f4fe777379 100644 --- a/cpp/cmake/thirdparty/get_hnswlib.cmake +++ b/cpp/cmake/thirdparty/get_hnswlib.cmake @@ -53,11 +53,13 @@ function(find_and_configure_hnswlib) # write export rules rapids_export( BUILD hnswlib + VERSION ${PKG_VERSION} EXPORT_SET hnswlib-exports GLOBAL_TARGETS hnswlib NAMESPACE hnswlib::) rapids_export( INSTALL hnswlib + VERSION ${PKG_VERSION} EXPORT_SET hnswlib-exports GLOBAL_TARGETS hnswlib NAMESPACE hnswlib::) From ad05438282b0866e878d33b6fdca2139fa51bce6 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 23 Jan 2024 02:09:36 +0000 Subject: [PATCH 28/31] more doc fixes --- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index ebf7c6df89..514f67e8a6 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -233,8 +233,19 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): Examples -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra >>> from pylibraft.neighbors import hnsw - >>> dim = 50 # Assuming training dataset has 50 dimensions + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> # Serialize the CAGRA index to hnswlib base layer only index format + >>> hnsw.save("my_index.bin", index, handle=handle) >>> index = hnsw.load("my_index.bin", dim, np.float32, "sqeuclidean") """ if handle is None: @@ -396,7 +407,7 @@ def search(SearchParams search_params, >>> handle = DeviceResources() >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) >>> - >>> Load saved base-layer-only hnswlib index from CAGRA index + >>> # Load saved base-layer-only hnswlib index from CAGRA index >>> hnsw_index = hnsw.from_cagra(index, handle=handle) >>> >>> # Search hnswlib using the loaded index From 94d08cd6925f8af009107ac82fa66290fecf8fb2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 23 Jan 2024 17:25:05 +0000 Subject: [PATCH 29/31] hopefully final doc fix --- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index 514f67e8a6..9f8be25ae6 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -246,7 +246,7 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) >>> # Serialize the CAGRA index to hnswlib base layer only index format >>> hnsw.save("my_index.bin", index, handle=handle) - >>> index = hnsw.load("my_index.bin", dim, np.float32, "sqeuclidean") + >>> index = hnsw.load("my_index.bin", n_features, np.float32, "sqeuclidean") """ if handle is None: handle = DeviceResources() From 54bf32c6c302c406188382ac9a7b8f63db6173f0 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 23 Jan 2024 17:32:32 +0000 Subject: [PATCH 30/31] style fix --- python/pylibraft/pylibraft/neighbors/hnsw.pyx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/neighbors/hnsw.pyx b/python/pylibraft/pylibraft/neighbors/hnsw.pyx index 9f8be25ae6..aa589ffb65 100644 --- a/python/pylibraft/pylibraft/neighbors/hnsw.pyx +++ b/python/pylibraft/pylibraft/neighbors/hnsw.pyx @@ -246,7 +246,8 @@ def load(filename, dim, dtype, metric="sqeuclidean", handle=None): >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) >>> # Serialize the CAGRA index to hnswlib base layer only index format >>> hnsw.save("my_index.bin", index, handle=handle) - >>> index = hnsw.load("my_index.bin", n_features, np.float32, "sqeuclidean") + >>> index = hnsw.load("my_index.bin", n_features, np.float32, + ... "sqeuclidean") """ if handle is None: handle = DeviceResources() From 1897cf73c589898fad5a56a58875b12879c19a98 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 25 Jan 2024 16:49:47 -0800 Subject: [PATCH 31/31] pre commit fixes --- python/pylibraft/pylibraft/neighbors/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index daf7f182f5..86612b2fbb 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -15,7 +15,8 @@ from pylibraft.neighbors import brute_force # type: ignore from pylibraft.neighbors import hnsw # type: ignore -from pylibraft.neighbors import cagra, ivf_flat, ivf_pq, rbc +from pylibraft.neighbors import rbc # type: ignore +from pylibraft.neighbors import cagra, ivf_flat, ivf_pq from .refine import refine @@ -27,5 +28,5 @@ "ivf_pq", "cagra", "hnsw", - "rbc" + "rbc", ]