Skip to content

Commit

Permalink
Fix query id offset
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Sep 8, 2023
1 parent 8bbd285 commit a083244
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
#include <raft/neighbors/sample_filter_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <raft/core/device_mdspan.hpp>
Expand All @@ -32,6 +33,48 @@

namespace raft::neighbors::cagra::detail {

template <class CagraSampleFilterT>
struct CagraSampleFilterWithQueryIdOffset {
const std::size_t offset;
CagraSampleFilterT filter;

CagraSampleFilterWithQueryIdOffset(const std::size_t offset, const CagraSampleFilterT filter)
: offset(offset), filter(filter)
{
}

_RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id)
{
return filter(query_id + offset, sample_id);
}
};

template <class CagraSampleFilterT>
struct CagraSampleFilterT_Selector {
using type = CagraSampleFilterWithQueryIdOffset<CagraSampleFilterT>;
};
template <>
struct CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter> {
using type = raft::neighbors::filtering::none_cagra_sample_filter;
};

// A helper function to set a query id offset
template <class CagraSampleFilterT>
inline typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type set_offset(
CagraSampleFilterT filter, const uint32_t offset)
{
typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type new_filter(offset, filter);
return new_filter;
}
template <>
inline
typename CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter>::type
set_offset<raft::neighbors::filtering::none_cagra_sample_filter>(
raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t)
{
return filter;
}

/**
* @brief Search ANN using the constructed index.
*
Expand Down Expand Up @@ -76,8 +119,9 @@ void search_main(raft::resources const& res,

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT, CagraSampleFilterT>> plan =
factory<T, internal_IdxT, DistanceT, CagraSampleFilterT>::create(
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>> plan =
factory<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>::create(
res, params, index.dim(), index.graph_degree(), topk);

plan->check(neighbors.extent(1));
Expand Down Expand Up @@ -119,7 +163,7 @@ void search_main(raft::resources const& res,
_seed_ptr,
_num_executed_iterations,
topk,
sample_filter);
set_offset(sample_filter, qid));
}

static_assert(std::is_same_v<DistanceT, float>,
Expand Down

0 comments on commit a083244

Please sign in to comment.