-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add cpp index and python load,search methods with runtime API
- Loading branch information
Showing
10 changed files
with
858 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "cagra_hnswlib_types.hpp" | ||
#include "detail/cagra_hnswlib.hpp" | ||
|
||
#include <cstddef> | ||
#include <raft/core/host_mdspan.hpp> | ||
#include <raft/core/resources.hpp> | ||
|
||
namespace raft::neighbors::cagra_hnswlib { | ||
|
||
/** | ||
* @brief Search hnswlib base layer only index constructed from a CAGRA index | ||
* | ||
* See the [cagra::build](#cagra::build) documentation for a usage example. | ||
* | ||
* @tparam T data element type | ||
* @tparam IdxT type of the indices | ||
* | ||
* @param[in] res raft resources | ||
* @param[in] params configure the search | ||
* @param[in] idx cagra index | ||
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] | ||
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset | ||
* [n_queries, k] | ||
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, | ||
* k] | ||
* | ||
* Usage example: | ||
* @code{.cpp} | ||
* // Build a CAGRA index | ||
* using namespace raft::neighbors; | ||
* // use default index parameters | ||
* cagra::index_params index_params; | ||
* // create and fill the index from a [N, D] dataset | ||
* auto index = cagra::build(res, index_params, dataset); | ||
* | ||
* // Save CAGRA index as base layer only hnswlib index | ||
* cagra::serialize_to_hnswlib(res, "my_index.bin", index); | ||
* | ||
* // Load CAGRA index as base layer only hnswlib index | ||
* cagra_hnswlib::index(D, "my_index.bin", raft::distance::L2Expanded); | ||
* | ||
* // Search K nearest neighbors as an hnswlib index | ||
* // using host threads for concurrency | ||
* cagra_hnswlib::search_params search_params; | ||
* search_params.ef = 50 // ef >= K; | ||
* search_params.num_threads = 10; | ||
* auto neighbors = raft::make_host_matrix<uint32_t>(res, n_queries, k); | ||
* auto distances = raft::make_host_matrix<float>(res, n_queries, k); | ||
* cagra_hnswlib::search(res, search_params, index, queries, neighbors, distances); | ||
* @endcode | ||
*/ | ||
template <typename T> | ||
void search(raft::resources const& res, | ||
const search_params& params, | ||
const index<T>& idx, | ||
raft::host_matrix_view<const T, int64_t, row_major> queries, | ||
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors, | ||
raft::host_matrix_view<float, int64_t, row_major> distances) | ||
{ | ||
RAFT_EXPECTS( | ||
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), | ||
"Number of rows in output neighbors and distances matrices must equal the number of queries."); | ||
|
||
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), | ||
"Number of columns in output neighbors and distances matrices must equal k"); | ||
RAFT_EXPECTS(queries.extent(1) == idx.dim(), | ||
"Number of query dimensions should equal number of dimensions in the index."); | ||
|
||
detail::search(res, params, idx, queries, neighbors, distances); | ||
} | ||
|
||
} // namespace raft::neighbors::cagra_hnswlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "ann_types.hpp" | ||
#include <memory> | ||
#include <raft/distance/distance_types.hpp> | ||
|
||
#include <cstdint> | ||
#include <hnswlib.h> | ||
#include <sys/types.h> | ||
#include <type_traits> | ||
|
||
namespace raft::neighbors::cagra_hnswlib { | ||
|
||
template <typename T> | ||
struct hnsw_dist_t { | ||
using type = void; | ||
}; | ||
|
||
template <> | ||
struct hnsw_dist_t<float> { | ||
using type = float; | ||
}; | ||
|
||
template <> | ||
struct hnsw_dist_t<std::uint8_t> { | ||
using type = int; | ||
}; | ||
|
||
template <> | ||
struct hnsw_dist_t<std::int8_t> { | ||
using type = int; | ||
}; | ||
|
||
struct search_params : ann::search_params { | ||
int ef; // size of the candidate list | ||
int num_threads = 1; // number of host threads to use for concurrent searches | ||
}; | ||
|
||
template <typename T> | ||
struct index : ann::index { | ||
public: | ||
/** | ||
* @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index | ||
* | ||
* @param[in] filepath path to the index | ||
* @param[in] dim dimensions of the training dataset | ||
* @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") | ||
*/ | ||
index(std::string filepath, int dim, raft::distance::DistanceType metric) | ||
: dim_{dim}, metric_{metric} | ||
{ | ||
if constexpr (std::is_same_v<T, float>) { | ||
if (metric == raft::distance::L2Expanded) { | ||
space_ = std::make_unique<hnswlib::L2Space>(dim_); | ||
} else if (metric == raft::distance::InnerProduct) { | ||
space_ = std::make_unique<hnswlib::InnerProductSpace>(dim_); | ||
} | ||
} else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) { | ||
if (metric == raft::distance::L2Expanded) { | ||
space_ = std::make_unique<hnswlib::L2SpaceI>(dim_); | ||
} | ||
} | ||
|
||
RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used"); | ||
|
||
appr_alg_ = std::make_unique<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>>( | ||
space_.get(), filepath); | ||
|
||
appr_alg_->base_layer_only = true; | ||
} | ||
|
||
/** | ||
@brief Get hnswlib index | ||
*/ | ||
auto get_index() const -> hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* | ||
{ | ||
return appr_alg_.get(); | ||
} | ||
|
||
auto dim() const -> int const { return dim_; } | ||
|
||
auto metric() const -> raft::distance::DistanceType { return metric_; } | ||
|
||
private: | ||
int dim_; | ||
raft::distance::DistanceType metric_; | ||
|
||
std::unique_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_; | ||
std::unique_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_; | ||
}; | ||
|
||
} // namespace raft::neighbors::cagra_hnswlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.