Skip to content

Commit

Permalink
Add option to brute_force index to maintain reference to non-owning n…
Browse files Browse the repository at this point in the history
…orms (#1865)

This makes the FAISS integration substantially easier, since we can reuse the norms that have already been calculated in GpuDistanceParams::vectorNorms rather than requiring an owned copy that lives in the brute force index.

Authors:
   - Ben Frederickson (https://github.com/benfred)

Approvers:
   - Corey J. Nolet (https://github.com/cjnolet)
  • Loading branch information
benfred authored Oct 2, 2023
1 parent 1ee423b commit 120cff4
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ struct index : ann::index {
/**
 * Dataset norms.
 *
 * Returns a read-only device vector view of the norms. The view is read out of
 * a std::optional with `value()`, which throws std::bad_optional_access when
 * the optional is empty, so callers should check `has_norms()` first.
 */
[[nodiscard]] inline auto norms() const -> device_vector_view<const T, int64_t, row_major>
{
return make_const_mdspan(norms_.value().view());
return norms_view_.value();
}

/** Whether or not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); }
[[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }

/** Argument of the distance metric, as stored when the index was constructed. */
[[nodiscard]] inline auto metric_arg() const noexcept -> T { return metric_arg_; }

Expand Down Expand Up @@ -102,10 +102,30 @@ struct index : ann::index {
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); }
update_dataset(res, dataset);
resource::sync_stream(res);
}

/** Construct a brute force index from non-owning views of the dataset and norms
 *
 * This constructor stores only the views passed in: `dataset_` is created as an
 * empty 0x0 matrix and no data is copied. Because the index merely refers to
 * the caller's memory, that memory has to remain valid for as long as the
 * index is used.
 * Having precomputed norms gives us a performance advantage at query time.
 *
 * @param res          raft resources, used here to create the empty `dataset_` placeholder
 * @param dataset_view read-only view of the dataset; stored as-is
 * @param norms_view   optional read-only view of precomputed dataset norms; stored as-is
 *                     (pass std::nullopt when no norms are available)
 * @param metric       distance metric of the index
 * @param metric_arg   argument of the metric (defaults to 0.0)
 */
index(raft::resources const& res,
      raft::device_matrix_view<const T, int64_t, row_major> dataset_view,
      std::optional<raft::device_vector_view<const T, int64_t>> norms_view,
      raft::distance::DistanceType metric,
      T metric_arg = 0.0)
  // Initializers are listed in member declaration order (norms_view_ is
  // declared before dataset_view_): that is the order in which they actually
  // run, and listing them any other way triggers -Wreorder.
  : ann::index(),
    metric_(metric),
    dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
    norms_view_(norms_view),
    dataset_view_(dataset_view),
    metric_arg_(metric_arg)
{
}

private:
/**
* Replace the dataset with a new dataset.
Expand Down Expand Up @@ -135,6 +155,7 @@ struct index : ann::index {
/** Distance metric supplied at construction. */
raft::distance::DistanceType metric_;
/** Dataset storage held by the index; the view-based constructor leaves it as an empty 0x0 matrix. */
raft::device_matrix<T, int64_t, row_major> dataset_;
/** Norms storage held by the index; only set by the constructor that takes norms by value. */
std::optional<raft::device_vector<T, int64_t>> norms_;
/** Read-only view of the norms: a view of `norms_` when that is set, otherwise whatever view the caller supplied (possibly empty). */
std::optional<raft::device_vector_view<const T, int64_t>> norms_view_;
/** Read-only view of the dataset; the view-based constructor stores the caller's view here directly. */
raft::device_matrix_view<const T, int64_t, row_major> dataset_view_;
/** Metric argument, returned by `metric_arg()`. */
T metric_arg_;
};
Expand Down

0 comments on commit 120cff4

Please sign in to comment.