Skip to content

Commit

Permalink
Merge branch 'rapidsai:branch-24.08' into fea-nnd-dist-epilogue
Browse files Browse the repository at this point in the history
  • Loading branch information
jinsolp authored Jun 27, 2024
2 parents a30f85f + 36f77a1 commit 86f91cb
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 23 deletions.
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class RaftCagraHnswlib : public ANN<T>, public AnnGPU {

RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
: ANN<T>(metric, dim),
cagra_build_{metric, dim, param, concurrent_searches},
cagra_build_{metric, dim, param, concurrent_searches, true},
// HnswLib param values don't matter since we don't build with HnswLib
hnswlib_search_{metric, dim, typename HnswLib<T>::BuildParam{50, 100}}
{
Expand Down
12 changes: 9 additions & 3 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,16 @@ class RaftCagra : public ANN<T>, public AnnGPU {
std::optional<raft::neighbors::ivf_pq::search_params> ivf_pq_search_params = std::nullopt;
};

RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
RaftCagra(Metric metric,
int dim,
const BuildParam& param,
int concurrent_searches = 1,
bool shall_include_dataset = false)
: ANN<T>(metric, dim),
index_params_(param),
dimension_(dim),
need_dataset_update_(true),
shall_include_dataset_(shall_include_dataset),
dataset_(std::make_shared<raft::device_matrix<T, int64_t, row_major>>(
std::move(make_device_matrix<T, int64_t>(handle_, 0, 0)))),
graph_(std::make_shared<raft::device_matrix<IdxT, int64_t, row_major>>(
Expand Down Expand Up @@ -135,6 +140,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
float refine_ratio_;
BuildParam index_params_;
bool need_dataset_update_;
bool shall_include_dataset_;
raft::neighbors::cagra::search_params search_params_;
std::shared_ptr<raft::neighbors::cagra::index<T, IdxT>> index_;
int dimension_;
Expand All @@ -161,7 +167,7 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)
auto& params = index_params_.cagra_params;

// Do include the compressed dataset for the CAGRA-Q
bool shall_include_dataset = params.compression.has_value();
bool include_dataset = params.compression.has_value() || shall_include_dataset_;

index_ = std::make_shared<raft::neighbors::cagra::index<T, IdxT>>(
std::move(raft::neighbors::cagra::detail::build(handle_,
Expand All @@ -171,7 +177,7 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)
index_params_.ivf_pq_refine_rate,
index_params_.ivf_pq_build_params,
index_params_.ivf_pq_search_params,
shall_include_dataset)));
include_dataset)));
}

inline std::string allocator_to_string(AllocatorType mem_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <memory>
#include <numeric>
Expand Down Expand Up @@ -209,7 +210,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(

#if 0
/* debug */
for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += BLOCK_SIZE) {
for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) {
result_indices_buffer[i] = utils::get_max_value<INDEX_T>();
result_distances_buffer[i] = utils::get_max_value<DISTANCE_T>();
}
Expand Down Expand Up @@ -351,16 +352,19 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
}

#ifdef _CLK_BREAKDOWN
if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && (blockIdx.x == 0) &&
if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && (blockIdx.x == 0) &&
((query_id * 3) % gridDim.y < 3)) {
RAFT_LOG_DEBUG(
printf(
"%s:%d "
"query, %d, thread, %d"
", init, %d"
", init, %lu"
", 1st_distance, %lu"
", topk, %lu"
", pickup_parents, %lu"
", distance, %lu"
"\n",
__FILE__,
__LINE__,
query_id,
threadIdx.x,
clk_init,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <memory>
#include <numeric>
Expand Down Expand Up @@ -448,14 +449,6 @@ __device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr,
}
}

template <class T, unsigned BLOCK_SIZE>
__device__ inline void set_value_device(T* const ptr, const T fill, const std::uint32_t count)
{
for (std::uint32_t i = threadIdx.x; i < count; i += BLOCK_SIZE) {
ptr[i] = fill;
}
}

// One query one thread block
template <uint32_t TEAM_SIZE,
uint32_t DATASET_BLOCK_DIM,
Expand Down Expand Up @@ -791,17 +784,20 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
num_executed_iterations[query_id] = iter + 1;
}
#ifdef _CLK_BREAKDOWN
if ((threadIdx.x == 0 || threadIdx.x == BLOCK_SIZE - 1) && ((query_id * 3) % gridDim.y < 3)) {
RAFT_LOG_DEBUG(
if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && ((query_id * 3) % gridDim.y < 3)) {
printf(
"%s:%d "
"query, %d, thread, %d"
", init, %d"
", init, %lu"
", 1st_distance, %lu"
", topk, %lu"
", reset_hash, %lu"
", pickup_parents, %lu"
", restore_hash, %lu"
", distance, %lu"
"\n",
__FILE__,
__LINE__,
query_id,
threadIdx.x,
clk_init,
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ void search(raft::resources const& res,
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<float, int64_t, row_major> distances)
{
idx.set_ef(params.ef);
auto const* hnswlib_index =
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
idx.get_index());
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/detail/hnsw_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ struct index_impl : index<T> {
*/
auto get_index() const -> void const* override { return appr_alg_.get(); }

/**
@brief Set ef for search
*/
void set_ef(int ef) const override { appr_alg_->ef_ = ef; }

private:
std::unique_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
std::unique_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace raft::neighbors::hnsw {

/**
* @brief Construct an hnswlib base-layer-only index from a CAGRA index
* NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/cagra_index.bin`
* NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/<random_number>.bin`
* before reading it as an hnswlib index, then deleting the temporary file.
* 2. This function is only offered as a compiled symbol in `libraft.so`
*
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/hnsw_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ struct index : ann::index {

auto metric() const -> raft::distance::DistanceType { return metric_; }

/**
@brief Set ef for search
*/
virtual void set_ef(int ef) const;

private:
int dim_;
raft::distance::DistanceType metric_;
Expand Down
8 changes: 7 additions & 1 deletion cpp/src/raft_runtime/neighbors/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@
#include <raft_runtime/neighbors/hnsw.hpp>

#include <filesystem>
#include <random>
#include <string>

namespace raft::neighbors::hnsw {
#define RAFT_INST_HNSW(T) \
template <> \
std::unique_ptr<raft::neighbors::hnsw::index<T>> from_cagra( \
raft::resources const& res, raft::neighbors::cagra::index<T, uint32_t> cagra_index) \
{ \
std::string filepath = "/tmp/cagra_index.bin"; \
std::random_device dev; \
std::mt19937 rng(dev()); \
std::uniform_int_distribution<std::mt19937::result_type> dist(0); \
auto uuid = std::to_string(dist(rng)); \
std::string filepath = "/tmp/" + uuid + ".bin"; \
raft::runtime::neighbors::cagra::serialize_to_hnswlib(res, filepath, cagra_index); \
auto hnsw_index = raft::runtime::neighbors::hnsw::deserialize_file<T>( \
res, filepath, cagra_index.dim(), cagra_index.metric()); \
Expand Down
6 changes: 4 additions & 2 deletions python/pylibraft/pylibraft/neighbors/hnsw.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ from pylibraft.common.mdspan cimport (
from pylibraft.neighbors.common cimport _get_metric_string

import os
import uuid

import numpy as np

Expand Down Expand Up @@ -292,7 +293,7 @@ def from_cagra(Index index, handle=None):
Returns an hnswlib base-layer-only index from a CAGRA index.
NOTE: This method uses the filesystem to write the CAGRA index in
`/tmp/cagra_index.bin` before reading it as an hnswlib index,
`/tmp/<random_number>.bin` before reading it as an hnswlib index,
then deleting the temporary file.
Saving / loading the index is experimental. The serialization format is
Expand Down Expand Up @@ -320,7 +321,8 @@ def from_cagra(Index index, handle=None):
>>> # Serialize the CAGRA index to hnswlib base layer only index format
>>> hnsw_index = hnsw.from_cagra(index, handle=handle)
"""
filename = "/tmp/cagra_index.bin"
uuid_num = uuid.uuid4()
filename = f"/tmp/{uuid_num}.bin"
save(filename, index, handle=handle)
hnsw_index = load(filename, index.dim, np.dtype(index.active_index_type),
_get_metric_string(index.metric), handle=handle)
Expand Down

0 comments on commit 86f91cb

Please sign in to comment.