diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index acb77ec8c7..5b028a2456 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON) option(BUILD_TESTS "Build raft unit-tests" ON) option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF) option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF) +option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON) option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF) option(CUDA_ENABLE_LINEINFO "Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF @@ -195,6 +196,10 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH) rapids_cpm_gbench() endif() +if (BUILD_CAGRA_HNSWLIB) + include(cmake/thirdparty/get_hnswlib.cmake) +endif() + # ################################################################################################## # * raft --------------------------------------------------------------------- add_library(raft INTERFACE) @@ -202,6 +207,7 @@ add_library(raft::raft ALIAS raft) target_include_directories( raft INTERFACE "$" "$" + "$<$:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>" ) if(NOT BUILD_CPU_ONLY) diff --git a/cpp/include/raft/neighbors/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/cagra_hnswlib.hpp new file mode 100644 index 0000000000..cb02b10b4c --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_hnswlib.hpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cagra_hnswlib_types.hpp" +#include "detail/cagra_hnswlib.hpp" + +#include +#include +#include + +namespace raft::neighbors::cagra_hnswlib { + +/** + * @brief Search hnswlib base layer only index constructed from a CAGRA index + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * + * Usage example: + * @code{.cpp} + * // Build a CAGRA index + * using namespace raft::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * + * // Save CAGRA index as base layer only hnswlib index + * cagra::serialize_to_hnswlib(res, "my_index.bin", index); + * + * // Load CAGRA index as base layer only hnswlib index + * cagra_hnswlib::index(D, "my_index.bin", raft::distance::L2Expanded); + * + * // Search K nearest neighbors as an hnswlib index + * // using host threads for concurrency + * cagra_hnswlib::search_params search_params; + * search_params.ef = 50 // ef >= K; + * search_params.num_threads = 10; + * auto neighbors = raft::make_host_matrix(res, n_queries, k); + * auto distances = raft::make_host_matrix(res, n_queries, k); + * cagra_hnswlib::search(res, search_params, index, queries, neighbors, distances); + * @endcode + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + detail::search(res, params, idx, queries, neighbors, distances); +} + +} // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp new file mode 100644 index 0000000000..bf1725ea86 --- /dev/null +++ b/cpp/include/raft/neighbors/cagra_hnswlib_types.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include +#include + +#include +#include +#include +#include + +namespace raft::neighbors::cagra_hnswlib { + +template +struct hnsw_dist_t { + using type = void; +}; + +template <> +struct hnsw_dist_t { + using type = float; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +template <> +struct hnsw_dist_t { + using type = int; +}; + +struct search_params : ann::search_params { + int ef; // size of the candidate list + int num_threads = 1; // number of host threads to use for concurrent searches +}; + +template +struct index : ann::index { + public: + /** + * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index + * + * @param[in] filepath path to the index + * @param[in] dim dimensions of the training dataset + * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") + */ + index(std::string filepath, int dim, raft::distance::DistanceType metric) + : dim_{dim}, metric_{metric} + { + if constexpr (std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique(dim_); + } else if (metric == raft::distance::InnerProduct) { + space_ = std::make_unique(dim_); + } + } else if constexpr (std::is_same_v or std::is_same_v) { + if (metric == raft::distance::L2Expanded) { + space_ = std::make_unique(dim_); + } + } + + RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used"); + + appr_alg_ = std::make_unique::type>>( + space_.get(), filepath); + + appr_alg_->base_layer_only = true; + } + + /** + @brief Get hnswlib index + */ + auto get_index() const -> hnswlib::HierarchicalNSW::type> const* + { + return appr_alg_.get(); + } + + auto dim() const -> int const { return dim_; } + + auto metric() const -> raft::distance::DistanceType { return metric_; } + + private: + int dim_; + raft::distance::DistanceType metric_; + + std::unique_ptr::type>> appr_alg_; + std::unique_ptr::type>> space_; +}; + +} // namespace raft::neighbors::cagra_hnswlib diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 51c9475434..8ed01b062d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -204,7 +204,6 @@ void serialize_to_hnswlib(raft::resources const& res, auto zero = 0; os.write(reinterpret_cast(&zero), sizeof(int)); } - // delete [] host_graph; } template diff --git a/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp new file mode 100644 index 0000000000..ba826db0c5 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../cagra_hnswlib_types.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::cagra_hnswlib::detail { + +class FixedThreadPool { + public: + FixedThreadPool(int num_threads) + { + if (num_threads < 1) { + throw std::runtime_error("num_threads must >= 1"); + } else if (num_threads == 1) { + return; + } + + tasks_ = new Task_[num_threads]; + + threads_.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + threads_.emplace_back([&, i] { + auto& task = tasks_[i]; + while (true) { + std::unique_lock lock(task.mtx); + task.cv.wait(lock, + [&] { return task.has_task || finished_.load(std::memory_order_relaxed); }); + if (finished_.load(std::memory_order_relaxed)) { break; } + + task.task(); + task.has_task = false; + } + }); + } + } + + ~FixedThreadPool() + { + if (threads_.empty()) { return; } + + finished_.store(true, std::memory_order_relaxed); + for (unsigned i = 0; i < threads_.size(); ++i) { + auto& task = tasks_[i]; + std::lock_guard(task.mtx); + + task.cv.notify_one(); + threads_[i].join(); + } + + delete[] tasks_; + } + + template + void submit(Func f, IdxT len) + { + // Run functions in main thread if thread pool has no threads + if (threads_.empty()) { + for (IdxT i = 0; i < len; ++i) { + f(i); + } + return; + } + + const int num_threads = threads_.size(); + // one extra part for competition among threads + const IdxT items_per_thread = len / (num_threads + 1); + std::atomic cnt(items_per_thread * num_threads); + + // Wrap function + auto wrapped_f = [&](IdxT start, IdxT end) { + for (IdxT i = start; i < end; ++i) { + f(i); + } + + while (true) { + IdxT i = cnt.fetch_add(1, std::memory_order_relaxed); + if (i >= len) { break; } + f(i); + } + }; + + std::vector> futures; + futures.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + IdxT start = i * items_per_thread; + auto& task = tasks_[i]; + { + std::lock_guard lock(task.mtx); + (void)lock; // stop nvcc warning + task.task = std::packaged_task([=] { wrapped_f(start, start + items_per_thread); }); + futures.push_back(task.task.get_future()); + task.has_task = true; + } + task.cv.notify_one(); + } + + for (auto& fut : futures) { + fut.wait(); + } + return; + } + + private: + struct alignas(64) Task_ { + std::mutex mtx; + std::condition_variable cv; + bool has_task = false; + std::packaged_task task; + }; + + Task_* tasks_; + std::vector threads_; + std::atomic finished_{false}; +}; + +template +void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, + const T* query, + int k, + uint64_t* indices, + float* distances) +{ + auto result = idx->searchKnn(query, k); + assert(result.size() >= static_cast(k)); + + for (int i = k - 1; i >= 0; --i) { + indices[i] = result.top().second; + distances[i] = result.top().first; + result.pop(); + } +} + +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::host_matrix_view queries, + raft::host_matrix_view neighbors, + raft::host_matrix_view distances) +{ + auto const* hnswlib_index = idx.get_index(); + + // no-op when num_threads == 1, no synchronization overhead + FixedThreadPool thread_pool{params.num_threads}; + + auto f = [&](auto const& i) { + get_search_knn_results(hnswlib_index, + queries.data_handle() + i * queries.extent(1), + neighbors.extent(1), + neighbors.data_handle() + i * neighbors.extent(1), + distances.data_handle() + i * distances.extent(1)); + }; + + thread_pool.submit(f, queries.extent(0)); +} + +} // namespace raft::neighbors::cagra_hnswlib::detail diff --git a/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp b/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp new file mode 100644 index 0000000000..8a9edd03ce --- /dev/null +++ b/cpp/include/raft_runtime/neighbors/cagra_hnswlib.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::runtime::neighbors::cagra_hnswlib { + +#define RAFT_INST_CAGRA_HNSWLIB_FUNCS(T) \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra_hnswlib::search_params const& params, \ + raft::neighbors::cagra_hnswlib::index const& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances) \ + { \ + raft::neighbors::cagra_hnswlib::search(handle, params, index, queries, neighbors, distances); \ + } + +RAFT_INST_CAGRA_HNSWLIB_FUNCS(float); +RAFT_INST_CAGRA_HNSWLIB_FUNCS(int8_t); +RAFT_INST_CAGRA_HNSWLIB_FUNCS(uint8_t); + +} // namespace raft::runtime::neighbors::cagra_hnswlib diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 17dd2d8bfd..bdff331153 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -19,7 +19,7 @@ # cython: embedsignature = True # cython: language_level = 3 -from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t from libcpp.string cimport string from pylibraft.common.cpp.mdspan cimport ( @@ -79,6 +79,9 @@ cdef host_matrix_view[int64_t, int64_t, row_major] get_hmv_int64( cdef host_matrix_view[uint32_t, int64_t, row_major] get_hmv_uint32( array, check_shape) except * +cdef host_matrix_view[uint64_t, int64_t, row_major] get_hmv_uint64( + array, check_shape) except * + cdef host_matrix_view[const_float, int64_t, row_major] get_const_hmv_float( array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 7442a6bb89..d2b63ce549 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -290,7 +290,7 @@ cdef host_matrix_view[int64_t, int64_t, row_major] \ cdef host_matrix_view[uint32_t, int64_t, row_major] \ get_hmv_uint32(cai, check_shape) except *: - if cai.dtype != np.int64: + if cai.dtype != np.uint32: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) @@ -299,6 +299,17 @@ cdef host_matrix_view[uint32_t, int64_t, row_major] \ cai.data, shape[0], shape[1]) +cdef host_matrix_view[uint64_t, int64_t, row_major] \ + get_hmv_uint64(cai, check_shape) except *: + if cai.dtype != np.uint64: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[uint64_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + cdef host_matrix_view[const_float, int64_t, row_major] \ get_const_hmv_float(cai, check_shape) except *: if cai.dtype != np.float32: diff --git a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx index 4faeb89ef5..db75524a9c 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra_hnswlib.pyx @@ -20,6 +20,7 @@ from cython.operator cimport dereference as deref from libc.stdint cimport int8_t, uint8_t, uint32_t +from libcpp cimport bool from libcpp.string cimport string cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra @@ -34,13 +35,126 @@ from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport device_resources -from pylibraft.common import DeviceResources +from pylibraft.common import DeviceResources, ai_wrapper, auto_convert_output + +cimport pylibraft.neighbors.cpp.cagra_hnswlib as c_cagra_hnswlib + +from pylibraft.neighbors.common import _check_input_array, _get_metric + +from pylibraft.common.mdspan cimport ( + get_hmv_float, + get_hmv_int8, + get_hmv_uint8, + get_hmv_uint64, +) +from pylibraft.neighbors.common cimport _get_metric_string + +import numpy as np + + +cdef class CagraHnswlibIndex: + cdef readonly bool trained + cdef str active_index_type + + def __cinit__(self): + self.trained = False + self.active_index_type = None + +cdef class CagraHnswlibIndexFloat(CagraHnswlibIndex): + cdef c_cagra_hnswlib.index[float] * index + + def __cinit__(self, filepath, dim, metric): + cdef string c_filepath = filepath.encode('utf-8') + + self.index = new c_cagra_hnswlib.index[float]( + c_filepath, + dim, + _get_metric(metric)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def metric(self): + return self.index[0].metric() + + def __dealloc__(self): + if self.index is not NULL: + del self.index + +cdef class CagraHnswlibIndexInt8(CagraHnswlibIndex): + cdef c_cagra_hnswlib.index[int8_t] * index + + def __cinit__(self, filepath, dim, metric): + cdef string c_filepath = filepath.encode('utf-8') + + self.index = new c_cagra_hnswlib.index[int8_t]( + c_filepath, + dim, + _get_metric(metric)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def metric(self): + return self.index[0].metric() + + def __dealloc__(self): + if self.index is not NULL: + del self.index + +cdef class CagraHnswlibIndexUint8(CagraHnswlibIndex): + cdef c_cagra_hnswlib.index[uint8_t] * index + + def __cinit__(self, filepath, dim, metric): + cdef string c_filepath = filepath.encode('utf-8') + + self.index = new c_cagra_hnswlib.index[uint8_t]( + c_filepath, + dim, + _get_metric(metric)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["dim"]] + attr_str = [m_str] + attr_str + return "Index(type=CAGRA_hnswlib, " + (", ".join(attr_str)) + ")" + + @property + def dim(self): + return self.index[0].dim() + + @property + def metric(self): + return self.index[0].metric() + + def __dealloc__(self): + if self.index is not NULL: + del self.index @auto_sync_handle def save(filename, Index index, handle=None): """ - Saves the index to a file. + Saves the CAGRA index as an hnswlib base layer only index to a file. Saving / loading the index is experimental. The serialization format is subject to change. @@ -66,7 +180,7 @@ def save(filename, Index index, handle=None): >>> # Build index >>> handle = DeviceResources() >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) - >>> # Serialize and deserialize the cagra index built + >>> # Serialize the CAGRA index to hnswlib base layer only index format >>> cagra_hnswlib.save("my_index.bin", index, handle=handle) """ if not index.trained: @@ -108,3 +222,224 @@ def save(filename, Index index, handle=None): else: raise ValueError( "Index dtype %s not supported" % index.active_index_type) + + +def load(filename, dim, dtype, metric="sqeuclidean"): + """ + Loads base layer only hnswlib index from file, which was originally + saved as a built CAGRA index. + + Saving / loading the index is experimental. The serialization format is + subject to change, therefore loading an index saved with a previous + version of raft is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. + dim : int + Dimensions of the training dataest + dtype : np.dtype of the saved index + Valid values for dtype: ["float", "byte", "ubyte"] + metric : string denoting the metric type, default="sqeuclidean" + Valid values for metric: ["sqeuclidean", "inner_product"], where + - sqeuclidean is the euclidean distance without the square root + operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, + - inner product distance is defined as + distance(a, b) = \\sum_i a_i * b_i. + + Returns + ------- + index : CagraHnswlibIndex + + Examples + -------- + >>> from pylibraft.neighbors import cagra_hnswlib + >>> dim = 50 # Assuming training dataset has 50 dimensions + >>> index = cagra_hnswlib.load("my_index.bin", dim, "sqeuclidean") + """ + cdef string c_filename = filename.encode('utf-8') + cdef CagraHnswlibIndexFloat idx_float + cdef CagraHnswlibIndexInt8 idx_int8 + cdef CagraHnswlibIndexUint8 idx_uint8 + + if dtype == np.float32: + idx_float = CagraHnswlibIndexFloat(filename, dim, metric) + idx_float.trained = True + idx_float.active_index_type = 'float32' + return idx_float + elif dtype == np.byte: + idx_int8 = CagraHnswlibIndexInt8(filename, dim, metric) + idx_int8.trained = True + idx_int8.active_index_type = 'byte' + return idx_int8 + elif dtype == np.ubyte: + idx_uint8 = CagraHnswlibIndexUint8(filename, dim, metric) + idx_uint8.trained = True + idx_uint8.active_index_type = 'ubyte' + return idx_uint8 + else: + raise ValueError("Dataset dtype %s not supported" % dtype) + + +cdef class SearchParams: + """ + CAGRA search parameters + + Parameters + ---------- + ef: int + Size of list from which final neighbors k will be selected. + ef should be greater than or equal to k. + num_threads: int, default=1 + Number of host threads to use to search the hnswlib index + and increase concurrency + """ + cdef c_cagra_hnswlib.search_params params + + def __init__(self, ef, num_threads=1): + self.params.ef = ef + self.params.num_threads = num_threads + + def __repr__(self): + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in [ + "ef", "num_threads"]] + return "SearchParams(type=CAGRA_hnswlib, " + ( + ", ".join(attr_str)) + ")" + + @property + def ef(self): + return self.params.ef + + @property + def num_threads(self): + return self.params.num_threads + + +@auto_sync_handle +@auto_convert_output +def search(SearchParams search_params, + CagraHnswlibIndex index, + queries, + k, + neighbors=None, + distances=None, + handle=None): + """ + Find the k nearest neighbors for each query. + + Parameters + ---------- + search_params : SearchParams + index : Index + Trained CAGRA index. + queries : array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int8, uint8] + k : int + The number of neighbors. + neighbors : Optional array interface compliant matrix shape + (n_queries, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. (default None) + distances : Optional array interface compliant matrix shape + (n_queries, k) If supplied, the distances to the + neighbors will be written here in-place. (default None) + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + >>> import numpy as np + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + >>> from pylibraft.neighbors import cagra_hnswlib + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> + >>> Save CAGRA built index as base layer only hnswlib index + >>> cagra_hnswlib.save("my_index.bin", index) + >>> + >>> Load saved base layer only hnswlib index + >>> index_hnswlib.load("my_index.bin", n_features, dataset.dtype) + >>> + >>> # Search hnswlib using the loaded index + >>> queries = np.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 10 + >>> search_params = cagra_hnswlib.SearchParams( + ... ef=20, + ... num_threads=5 + ... ) + >>> distances, neighbors = cagra_hnswlib.search(search_params, index, + ... queries, k, handle=handle) + """ + + if not index.trained: + raise ValueError("Index need to be built before calling search.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + queries_ai = ai_wrapper(queries) + queries_dt = queries_ai.dtype + cdef uint32_t n_queries = queries_ai.shape[0] + + _check_input_array(queries_ai, [np.dtype('float32'), np.dtype('byte'), + np.dtype('ubyte')], + exp_cols=index.dim) + + if neighbors is None: + neighbors = np.empty((n_queries, k), dtype='uint64') + + neighbors_ai = ai_wrapper(neighbors) + _check_input_array(neighbors_ai, [np.dtype('uint64')], + exp_rows=n_queries, exp_cols=k) + + if distances is None: + distances = np.empty((n_queries, k), dtype='float32') + + distances_ai = ai_wrapper(distances) + _check_input_array(distances_ai, [np.dtype('float32')], + exp_rows=n_queries, exp_cols=k) + + cdef c_cagra_hnswlib.search_params params = search_params.params + cdef CagraHnswlibIndexFloat idx_float + cdef CagraHnswlibIndexInt8 idx_int8 + cdef CagraHnswlibIndexUint8 idx_uint8 + + if queries_dt == np.float32: + idx_float = index + c_cagra_hnswlib.search(deref(handle_), + params, + deref(idx_float.index), + get_hmv_float(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + elif queries_dt == np.byte: + idx_int8 = index + c_cagra_hnswlib.search(deref(handle_), + params, + deref(idx_int8.index), + get_hmv_int8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + elif queries_dt == np.ubyte: + idx_uint8 = index + c_cagra_hnswlib.search(deref(handle_), + params, + deref(idx_uint8.index), + get_hmv_uint8(queries_ai, check_shape=True), + get_hmv_uint64(neighbors_ai, check_shape=True), + get_hmv_float(distances_ai, check_shape=True)) + else: + raise ValueError("query dtype %s not supported" % queries_dt) + + return (distances, neighbors) diff --git a/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd b/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd new file mode 100644 index 0000000000..6a4154e151 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/cagra_hnswlib.pxd @@ -0,0 +1,75 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t +from libcpp.string cimport string + +from pylibraft.common.cpp.mdspan cimport ( + device_vector_view, + host_matrix_view, + row_major, +) +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType +from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( + ann_index, + ann_search_params, +) + + +cdef extern from "raft/neighbors/cagra_hnswlib_types.hpp" \ + namespace "raft::neighbors::cagra_hnswlib" nogil: + + cpdef cppclass search_params(ann_search_params): + int ef + int num_threads + + cdef cppclass index[T](ann_index): + index(string filepath, int dim, DistanceType metric) + + int dim() + DistanceType metric() + + +cdef extern from "raft_runtime/neighbors/cagra_hnswlib.hpp" \ + namespace "raft::runtime::neighbors::cagra_hnswlib" nogil: + cdef void search( + const device_resources& handle, + const search_params& params, + const index[float]& index, + host_matrix_view[float, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[int8_t]& index, + host_matrix_view[int8_t, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[uint8_t]& index, + host_matrix_view[uint8_t, int64_t, row_major] queries, + host_matrix_view[uint64_t, int64_t, row_major] neighbors, + host_matrix_view[float, int64_t, row_major] distances) except +