diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index c54ed32b77..80abf1e6c4 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -27,48 +27,56 @@ namespace raft::runtime::neighbors::cagra { // Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors -#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - auto build(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset) \ - ->raft::neighbors::cagra::index; \ - \ - void build_device(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void build_host(raft::resources const& handle, \ - const raft::neighbors::cagra::index_params& params, \ - raft::host_matrix_view dataset, \ - raft::neighbors::cagra::index& idx); \ - \ - void search(raft::resources const& handle, \ - raft::neighbors::cagra::search_params const& params, \ - const raft::neighbors::cagra::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index); \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset = true); \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ +#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + void build_device(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void build_host(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra::search_params const& params, \ + const raft::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void serialize_to_hnswlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ + \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index); \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ raft::neighbors::cagra::index* index); RAFT_INST_CAGRA_FUNCS(float, uint32_t); diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index 69b48b93a4..4ede7ecc25 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -24,39 +24,55 @@ namespace raft::runtime::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - raft::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::cagra::deserialize(handle, is); \ +#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ + }; \ + \ + void serialize_to_hswnlib_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index) \ + { \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ + str = os.str(); \ + } \ + \ + void serialize_to_hnswlib(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::cagra::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, is); \ } RAFT_INST_CAGRA_SERIALIZE(float); diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 45cd9f74e6..d5b8a80761 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx brute_force.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx cagra_hnswlib.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index c11d933b27..977bdfaee0 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -142,8 +142,6 @@ cdef class IndexParams: cdef class Index: - cdef readonly bool trained - cdef str active_index_type def __cinit__(self): self.trained = False diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 7e22f274e9..3d61fd2203 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -174,6 +174,10 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ string& str, const index[float, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnwslib( + const device_resources& handle, + string& str, + const index[float, uint32_t]& index) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -184,6 +188,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnwslib( + const device_resources& handle, + string& str, + const index[uint8_t, uint32_t]& index) except + + cdef void deserialize(const device_resources& handle, const string& str, index[uint8_t, uint32_t]* index) except + @@ -193,6 +202,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[int8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnwslib( + const device_resources& handle, + string& str, + const index[int8_t, uint32_t]& index) except + + cdef void deserialize(const device_resources& handle, const string& str, index[int8_t, uint32_t]* index) except + @@ -202,6 +216,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[float, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[float, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[float, uint32_t]* index) except + @@ -211,6 +230,11 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[uint8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[uint8_t, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[uint8_t, uint32_t]* index) except + @@ -220,6 +244,24 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ const index[int8_t, uint32_t]& index, bool include_dataset) except + + cdef void serialize_to_hnswlib_file( + const device_resources& handle, + const string& filename, + const index[int8_t, uint32_t]& index) except + + cdef void deserialize_file(const device_resources& handle, const string& filename, index[int8_t, uint32_t]* index) except + + +cdef class Index: + cdef readonly bool trained + cdef str active_index_type + +cdef class IndexFloat(Index): + pass + +cdef class IndexInt8(Index): + pass + +cdef class IndexUint8(Index): + pass diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx new file mode 100644 index 0000000000..a69d27967d --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx @@ -0,0 +1,104 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, uint8_t, uint32_t +from libcpp.string cimport string + +cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra + +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + +from pylibraft.common import DeviceResources + + +@auto_sync_handle +def save(filename, c_cagra.Index index, handle=None): + """ + Saves the index to a file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained CAGRA index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import cagra_hnswlib + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> # Serialize and deserialize the cagra index built + >>> cagra_hnswlib.save("my_index.bin", index, handle=handle) + """ + if not index.trained: + raise ValueError("Index need to be built before saving it.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + + cdef c_cagra.IndexFloat idx_float + cdef c_cagra.IndexInt8 idx_int8 + cdef c_cagra.IndexUint8 idx_uint8 + + cdef c_cagra.index[float, uint32_t] * c_index_float + cdef c_cagra.index[int8_t, uint32_t] * c_index_int8 + cdef c_cagra.index[uint8_t, uint32_t] * c_index_uint8 + + if index.active_index_type == "float32": + idx_float = index + c_index_float = \ + idx_float.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_float)) + elif index.active_index_type == "byte": + idx_int8 = index + c_index_int8 = \ + idx_int8.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_int8)) + elif index.active_index_type == "ubyte": + idx_uint8 = index + c_index_uint8 = \ + idx_uint8.index + c_cagra.serialize_to_hnswlib_file( + deref(handle_), c_filename, deref(c_index_uint8)) + else: + raise ValueError( + "Index dtype %s not supported" % index.active_index_type)