Skip to content

Commit

Permalink
add cpp index and python load,search methods with runtime API
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Nov 28, 2023
1 parent 7f04f97 commit 4b2b2c6
Show file tree
Hide file tree
Showing 10 changed files with 858 additions and 6 deletions.
6 changes: 6 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON)
option(BUILD_TESTS "Build raft unit-tests" ON)
option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF)
option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF)
option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON)
option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF)
option(CUDA_ENABLE_LINEINFO
"Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF
Expand Down Expand Up @@ -195,13 +196,18 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH)
rapids_cpm_gbench()
endif()

if (BUILD_CAGRA_HNSWLIB)
include(cmake/thirdparty/get_hnswlib.cmake)
endif()

# ##################################################################################################
# * raft ---------------------------------------------------------------------
add_library(raft INTERFACE)
add_library(raft::raft ALIAS raft)

# Public include directories for the header-only raft target.
# The hnswlib headers live inside the build tree (fetched by get_hnswlib.cmake), so that
# path is restricted to BUILD_INTERFACE: a build-tree path must not be exported through
# the install interface of the target.
target_include_directories(
  raft
  INTERFACE
    "$<BUILD_INTERFACE:${RAFT_SOURCE_DIR}/include>"
    "$<INSTALL_INTERFACE:include>"
    "$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib>>"
)

if(NOT BUILD_CPU_ONLY)
Expand Down
90 changes: 90 additions & 0 deletions cpp/include/raft/neighbors/cagra_hnswlib.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "cagra_hnswlib_types.hpp"
#include "detail/cagra_hnswlib.hpp"

#include <cstddef>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

namespace raft::neighbors::cagra_hnswlib {

/**
 * @brief Search hnswlib base layer only index constructed from a CAGRA index
 *
 * All inputs and outputs are host matrices. The shapes of the arguments are
 * validated here before the call is forwarded to detail::search.
 *
 * See the [cagra::build](#cagra::build) documentation for a usage example.
 *
 * @tparam T data element type
 *
 * @param[in] res raft resources
 * @param[in] params configure the search
 * @param[in] idx base-layer-only hnswlib index loaded from a serialized CAGRA index
 * @param[in] queries a host matrix view to a row-major matrix [n_queries, idx.dim()]
 * @param[out] neighbors a host matrix view to the indices of the neighbors in the source dataset
 * [n_queries, k]
 * @param[out] distances a host matrix view to the distances to the selected neighbors [n_queries,
 * k]
 *
 * Usage example:
 * @code{.cpp}
 *   // Build a CAGRA index
 *   using namespace raft::neighbors;
 *   // use default index parameters
 *   cagra::index_params index_params;
 *   // create and fill the index from a [N, D] dataset
 *   auto index = cagra::build(res, index_params, dataset);
 *
 *   // Save CAGRA index as base layer only hnswlib index
 *   cagra::serialize_to_hnswlib(res, "my_index.bin", index);
 *
 *   // Load CAGRA index as base layer only hnswlib index
 *   cagra_hnswlib::index<float> hnsw_index("my_index.bin", D, raft::distance::L2Expanded);
 *
 *   // Search K nearest neighbors as an hnswlib index
 *   // using host threads for concurrency
 *   cagra_hnswlib::search_params search_params;
 *   search_params.ef          = 50;  // ef >= K
 *   search_params.num_threads = 10;
 *   auto neighbors = raft::make_host_matrix<uint64_t, int64_t>(res, n_queries, k);
 *   auto distances = raft::make_host_matrix<float, int64_t>(res, n_queries, k);
 *   // `queries` is a host matrix view of shape [n_queries, D]
 *   cagra_hnswlib::search(
 *     res, search_params, hnsw_index, queries, neighbors.view(), distances.view());
 * @endcode
 */
template <typename T>
void search(raft::resources const& res,
            const search_params& params,
            const index<T>& idx,
            raft::host_matrix_view<const T, int64_t, row_major> queries,
            raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors,
            raft::host_matrix_view<float, int64_t, row_major> distances)
{
  // Every query needs exactly one output row in both result matrices.
  RAFT_EXPECTS(
    queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
    "Number of rows in output neighbors and distances matrices must equal the number of queries.");

  // k is implied by the width of the outputs, so the two widths have to agree.
  RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
               "Number of columns in output neighbors and distances matrices must equal k");

  // Queries must have the dimensionality the index was loaded with.
  RAFT_EXPECTS(queries.extent(1) == idx.dim(),
               "Number of query dimensions should equal number of dimensions in the index.");

  detail::search(res, params, idx, queries, neighbors, distances);
}

} // namespace raft::neighbors::cagra_hnswlib
108 changes: 108 additions & 0 deletions cpp/include/raft/neighbors/cagra_hnswlib_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "ann_types.hpp"
#include <memory>
#include <raft/distance/distance_types.hpp>

#include <cstdint>
#include <hnswlib.h>
#include <sys/types.h>
#include <type_traits>

namespace raft::neighbors::cagra_hnswlib {

/**
 * @brief Maps a dataset element type to the distance type used for it.
 *
 * `float` maps to `float`; `std::uint8_t` and `std::int8_t` map to `int`;
 * every other type maps to `void`.
 *
 * @tparam T dataset element type
 */
template <typename T>
struct hnsw_dist_t {
  using type = std::conditional_t<
    std::is_same_v<T, float>,
    float,
    std::conditional_t<std::is_same_v<T, std::uint8_t> || std::is_same_v<T, std::int8_t>,
                       int,
                       void>>;
};

/**
 * @brief Parameters controlling a search.
 *
 * Both fields carry a default so that a default-constructed object never
 * holds an indeterminate value; callers are expected to tune `ef` for their
 * workload (the usage example for search() notes ef >= k).
 */
struct search_params : ann::search_params {
  int ef          = 200;  // size of the candidate list
  int num_threads = 1;    // number of host threads to use for concurrent searches
};

/**
 * @brief Owning wrapper around an hnswlib index loaded from file and searched
 * through its base layer only.
 *
 * The wrapper owns both the hnswlib distance space and the hnswlib index that
 * was constructed from a pointer to that space.
 *
 * @tparam T data element type (float, std::int8_t or std::uint8_t)
 */
template <typename T>
struct index : ann::index {
 private:
  using dist_t       = typename hnsw_dist_t<T>::type;
  using hnsw_index_t = hnswlib::HierarchicalNSW<dist_t>;
  using hnsw_space_t = hnswlib::SpaceInterface<dist_t>;

  // Fail early with a readable message instead of an error deep inside hnswlib.
  static_assert(!std::is_void_v<dist_t>,
                "cagra_hnswlib::index only supports float, int8_t and uint8_t element types");

 public:
  /**
   * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index
   *
   * For float data the supported metrics are L2Expanded and InnerProduct; for
   * int8/uint8 data only L2Expanded is supported. Any other combination fails
   * the RAFT_EXPECTS check below.
   *
   * @param[in] filepath path to the index
   * @param[in] dim dimensions of the training dataset
   * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct")
   */
  index(std::string filepath, int dim, raft::distance::DistanceType metric)
    : dim_{dim}, metric_{metric}
  {
    // Select the hnswlib space matching the element type and the metric.
    if constexpr (std::is_same_v<T, float>) {
      if (metric == raft::distance::L2Expanded) {
        space_ = std::make_unique<hnswlib::L2Space>(dim_);
      } else if (metric == raft::distance::InnerProduct) {
        space_ = std::make_unique<hnswlib::InnerProductSpace>(dim_);
      }
    } else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
      if (metric == raft::distance::L2Expanded) {
        space_ = std::make_unique<hnswlib::L2SpaceI>(dim_);
      }
    }

    // space_ is still null when no branch above matched the requested metric.
    RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used");

    appr_alg_ = std::make_unique<hnsw_index_t>(space_.get(), filepath);

    appr_alg_->base_layer_only = true;
  }

  /**
  @brief Get hnswlib index
  */
  auto get_index() const -> hnsw_index_t const* { return appr_alg_.get(); }

  /** @brief Get the dimensionality passed at construction */
  auto dim() const -> int { return dim_; }

  /** @brief Get the distance metric passed at construction */
  auto metric() const -> raft::distance::DistanceType { return metric_; }

 private:
  int dim_;
  raft::distance::DistanceType metric_;

  // Declaration order matters: members are destroyed in reverse order, and the
  // hnswlib index was constructed from a raw pointer to the space, so the
  // space is declared first in order to outlive the index.
  std::unique_ptr<hnsw_space_t> space_;
  std::unique_ptr<hnsw_index_t> appr_alg_;
};

} // namespace raft::neighbors::cagra_hnswlib
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ void serialize_to_hnswlib(raft::resources const& res,
auto zero = 0;
os.write(reinterpret_cast<char*>(&zero), sizeof(int));
}
// delete [] host_graph;
}

template <typename T, typename IdxT>
Expand Down
Loading

0 comments on commit 4b2b2c6

Please sign in to comment.