Skip to content

Commit

Permalink
Add index class for brute_force knn (#1817)
Browse files Browse the repository at this point in the history
This adds an index class to match the ANN methods. This allows us to precompute norms for the dataset in `brute_force::build` and then use them in `brute_force::search` - meaning we don't have to compute norms for the entire dataset on every query.

Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1817
  • Loading branch information
benfred authored Sep 27, 2023
1 parent 6c7cada commit 25858c5
Show file tree
Hide file tree
Showing 13 changed files with 446 additions and 52 deletions.
32 changes: 20 additions & 12 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ include(rapids-find)

option(BUILD_CPU_ONLY "Build CPU only components. Applies to RAFT ANN benchmarks currently" OFF)

# workaround for rapids_cuda_init_architectures not working for arch detection with enable_language(CUDA)
# workaround for rapids_cuda_init_architectures not working for arch detection with
# enable_language(CUDA)
set(lang_list "CXX")

if(NOT BUILD_CPU_ONLY)
Expand Down Expand Up @@ -286,7 +287,8 @@ endif()
set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled)

if(RAFT_COMPILE_LIBRARY)
add_library(raft_objs OBJECT
add_library(
raft_objs OBJECT
src/core/logger.cpp
src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
Expand Down Expand Up @@ -331,6 +333,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu
src/neighbors/brute_force_knn_int_float_int.cu
src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu
src/neighbors/brute_force_knn_index_float.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
Expand Down Expand Up @@ -452,18 +455,21 @@ if(RAFT_COMPILE_LIBRARY)
src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu
src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu
src/util/memory_pool.cpp
)
)
set_target_properties(
raft_objs
PROPERTIES CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)
POSITION_INDEPENDENT_CODE ON
)

target_compile_definitions(raft_objs PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY")
target_compile_options(raft_objs PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
"$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>")
target_compile_options(
raft_objs PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
"$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
)

add_library(raft_lib SHARED $<TARGET_OBJECTS:raft_objs>)
add_library(raft_lib_static STATIC $<TARGET_OBJECTS:raft_objs>)
Expand All @@ -477,13 +483,15 @@ if(RAFT_COMPILE_LIBRARY)
)

foreach(target raft_lib raft_lib_static raft_objs)
target_link_libraries(${target} PUBLIC
raft::raft
${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
# will just be cublas
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>)
target_link_libraries(
${target}
PUBLIC raft::raft
${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
# will just be cublas
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
)

#So consumers know when using libraft.so/libraft.a
# So consumers know when using libraft.so/libraft.a
target_compile_definitions(${target} PUBLIC "RAFT_COMPILED")
# ensure CUDA symbols aren't relocated to the middle of the debug build binaries
target_link_options(${target} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld")
Expand Down
39 changes: 38 additions & 1 deletion cpp/include/raft/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/core/resources.hpp> // raft::resources
#include <raft/distance/distance_types.hpp> // raft::distance::DistanceType
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <raft/neighbors/brute_force_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

Expand All @@ -38,6 +39,19 @@ inline void knn_merge_parts(
size_t n_samples,
std::optional<raft::device_vector_view<idx_t, idx_t>> translations = std::nullopt) RAFT_EXPLICIT;

template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0) RAFT_EXPLICIT;

template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances) RAFT_EXPLICIT;

template <typename idx_t,
typename value_t,
typename matrix_idx,
Expand Down Expand Up @@ -93,6 +107,29 @@ instantiate_raft_neighbors_brute_force_knn(

#undef instantiate_raft_neighbors_brute_force_knn

namespace raft::neighbors::brute_force {

extern template void search<float, int>(
raft::resources const& res,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template void search<float, int64_t>(
raft::resources const& res,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
raft::device_matrix_view<const float, int64_t, row_major> dataset,
raft::distance::DistanceType metric,
float metric_arg);
} // namespace raft::neighbors::brute_force

#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \
value_t, idx_t, idx_layout, query_layout) \
extern template void raft::neighbors::brute_force::fused_l2_knn( \
Expand Down
98 changes: 97 additions & 1 deletion cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/brute_force_types.hpp>
#include <raft/neighbors/detail/knn_brute_force.cuh>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>

Expand Down Expand Up @@ -280,6 +281,101 @@ void fused_l2_knn(raft::resources const& handle,
metric);
}

/** @} */ // end group brute_force_knn
/**
* @brief Build the index from the dataset for efficient search.
*
* @tparam T data element type
*
* @param[in] res
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
* @param[in] metric: distance metric to use. Euclidean (L2) is used by default
* @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This
* is ignored if the metric_type is not Minkowski.
*
* @return the constructed brute force index
*/
template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0)
{
// certain distance metrics can benefit by pre-calculating the norms for the index dataset
// which lets us avoid calculating these at query time
std::optional<device_vector<T, int64_t>> norms;
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::CosineExpanded) {
norms = make_device_vector<T, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == raft::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
}
}

return index<T>(res, dataset, std::move(norms), metric, metric_arg);
}

/**
* @brief Brute Force search using the constructed index.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] idx brute force index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
{
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
"Number of columns in queries must match brute force index");

auto k = neighbors.extent(1);
auto d = idx.dataset().extent(1);

std::vector<T*> dataset = {const_cast<T*>(idx.dataset().data_handle())};
std::vector<int64_t> sizes = {idx.dataset().extent(0)};
std::vector<T*> norms;
if (idx.has_norms()) { norms.push_back(const_cast<T*>(idx.norms().data_handle())); }

detail::brute_force_knn_impl<int64_t, IdxT, T>(res,
dataset,
sizes,
d,
const_cast<T*>(queries.data_handle()),
queries.extent(0),
neighbors.data_handle(),
distances.data_handle(),
k,
true,
true,
nullptr,
idx.metric(),
idx.metric_arg(),
raft::identity_op(),
norms.size() ? &norms : nullptr);
}
/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
144 changes: 144 additions & 0 deletions cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "ann_types.hpp"
#include <raft/core/resource/cuda_stream.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>

#include <raft/core/logger.hpp>

namespace raft::neighbors::brute_force {
/**
* @addtogroup brute_force
* @{
*/

/**
* @brief Brute Force index.
*
* The index stores the dataset and norms for the dataset in device memory.
*
* @tparam T data element type
*/
template <typename T>
struct index : ann::index {
public:
/** Distance metric used for retrieval */
[[nodiscard]] constexpr inline raft::distance::DistanceType metric() const noexcept
{
return metric_;
}

/** Total length of the index (number of vectors). */
[[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); }

/** Dimensionality of the data. */
[[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); }

/** Dataset [size, dim] */
[[nodiscard]] inline auto dataset() const noexcept
-> device_matrix_view<const T, int64_t, row_major>
{
return dataset_view_;
}

/** Dataset norms */
[[nodiscard]] inline auto norms() const -> device_vector_view<const T, int64_t, row_major>
{
return make_const_mdspan(norms_.value().view());
}

/** Whether ot not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }

// Don't allow copying the index for performance reasons (try avoiding copying data)
index(const index&) = delete;
index(index&&) = default;
auto operator=(const index&) -> index& = delete;
auto operator=(index&&) -> index& = default;
~index() = default;

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* If the dataset is already in GPU memory, then this class stores a non-owning reference to
* the dataset. If the dataset is in host memory, it will be copied to the device and the
* index will own the device memory.
*/
template <typename data_accessor>
index(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset,
std::optional<raft::device_vector<T, int64_t>>&& norms,
raft::distance::DistanceType metric,
T metric_arg = 0.0)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
update_dataset(res, dataset);
resource::sync_stream(res);
}

private:
/**
* Replace the dataset with a new dataset.
*/
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
dataset_view_ = dataset;
}

/**
* Replace the dataset with a new dataset.
*
* We create a copy of the dataset on the device. The index manages the lifetime of this copy.
*/
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
dataset_ = make_device_matrix<T, int64_t>(dataset.extents(0), dataset.extents(1));
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
dataset.size(),
resource::get_cuda_stream(res));
dataset_view_ = make_const_mdspan(dataset_.view());
}

raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
std::optional<raft::device_vector<T, int64_t>> norms_;
raft::device_matrix_view<const T, int64_t, row_major> dataset_view_;
T metric_arg_;
};

/** @} */

} // namespace raft::neighbors::brute_force
Loading

0 comments on commit 25858c5

Please sign in to comment.