diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ada7753857..5d2864e2e0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -22,7 +22,8 @@ include(rapids-find) option(BUILD_CPU_ONLY "Build CPU only components. Applies to RAFT ANN benchmarks currently" OFF) -# workaround for rapids_cuda_init_architectures not working for arch detection with enable_language(CUDA) +# workaround for rapids_cuda_init_architectures not working for arch detection with +# enable_language(CUDA) set(lang_list "CXX") if(NOT BUILD_CPU_ONLY) @@ -286,7 +287,8 @@ endif() set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) - add_library(raft_objs OBJECT + add_library( + raft_objs OBJECT src/core/logger.cpp src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -331,6 +333,7 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu src/neighbors/brute_force_knn_int_float_int.cu src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/brute_force_knn_index_float.cu src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu @@ -452,18 +455,21 @@ if(RAFT_COMPILE_LIBRARY) src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu src/util/memory_pool.cpp - ) + ) set_target_properties( raft_objs PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON CUDA_STANDARD 17 CUDA_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON) + POSITION_INDEPENDENT_CODE ON + ) target_compile_definitions(raft_objs PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") - target_compile_options(raft_objs PRIVATE "$<$:${RAFT_CXX_FLAGS}>" - "$<$:${RAFT_CUDA_FLAGS}>") + target_compile_options( + raft_objs PRIVATE "$<$:${RAFT_CXX_FLAGS}>" + 
"$<$:${RAFT_CUDA_FLAGS}>" + ) add_library(raft_lib SHARED $) add_library(raft_lib_static STATIC $) @@ -477,13 +483,15 @@ if(RAFT_COMPILE_LIBRARY) ) foreach(target raft_lib raft_lib_static raft_objs) - target_link_libraries(${target} PUBLIC - raft::raft - ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this - # will just be cublas - $) + target_link_libraries( + ${target} + PUBLIC raft::raft + ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this + # will just be cublas + $ + ) - #So consumers know when using libraft.so/libraft.a + # So consumers know when using libraft.so/libraft.a target_compile_definitions(${target} PUBLIC "RAFT_COMPILED") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(${target} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index 862db75866..b8c00616da 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -22,7 +22,8 @@ #include // raft::identity_op #include // raft::resources #include // raft::distance::DistanceType -#include // RAFT_EXPLICIT +#include +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -38,6 +39,19 @@ inline void knn_merge_parts( size_t n_samples, std::optional> translations = std::nullopt) RAFT_EXPLICIT; +template +index build(raft::resources const& res, + mdspan, row_major, Accessor> dataset, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + T metric_arg = 0.0) RAFT_EXPLICIT; + +template +void search(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + template ( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view 
queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +extern template void search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +extern template raft::neighbors::brute_force::index build( + raft::resources const& res, + raft::device_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); +} // namespace raft::neighbors::brute_force + #define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ value_t, idx_t, idx_layout, query_layout) \ extern template void raft::neighbors::brute_force::fused_l2_knn( \ diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index bc9e09e5b0..88439a738b 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -280,6 +281,101 @@ void fused_l2_knn(raft::resources const& handle, metric); } -/** @} */ // end group brute_force_knn +/** + * @brief Build the index from the dataset for efficient search. + * + * @tparam T data element type + * + * @param[in] res + * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] + * @param[in] metric: distance metric to use. Euclidean (L2) is used by default + * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This + * is ignored if the metric_type is not Minkowski. 
+ * + * @return the constructed brute force index + */ +template +index build(raft::resources const& res, + mdspan, row_major, Accessor> dataset, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + T metric_arg = 0.0) +{ + // the expanded L2 and cosine metrics use per-row norms of the index dataset; pre-calculating + // them here lets us avoid calculating these at query time + std::optional> norms; + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded) { + norms = make_device_vector(res, dataset.extent(0)); + // cosine needs the l2norm, whereas l2 distances need the squared norm (no sqrt_op applied) + if (metric == raft::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(res, + dataset, + norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op{}); + } else { + raft::linalg::norm(res, + dataset, + norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + } + } + + return index(res, dataset, std::move(norms), metric, metric_arg); +} +/** + * @brief Brute Force search using the constructed index.
+ * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] idx brute force index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + + auto k = neighbors.extent(1); + auto d = idx.dataset().extent(1); + + std::vector dataset = {const_cast(idx.dataset().data_handle())}; + std::vector sizes = {idx.dataset().extent(0)}; + std::vector norms; + if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } + + detail::brute_force_knn_impl(res, + dataset, + sizes, + d, + const_cast(queries.data_handle()), + queries.extent(0), + neighbors.data_handle(), + distances.data_handle(), + k, + true, + true, + nullptr, + idx.metric(), + idx.metric_arg(), + raft::identity_op(), + norms.size() ? &norms : nullptr); +} +/** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp new file mode 100644 index 0000000000..cc934b7a98 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::brute_force { +/** + * @addtogroup brute_force + * @{ + */ + +/** + * @brief Brute Force index. + * + * The index stores the dataset and norms for the dataset in device memory. + * + * @tparam T data element type + */ +template +struct index : ann::index { + public: + /** Distance metric used for retrieval */ + [[nodiscard]] constexpr inline raft::distance::DistanceType metric() const noexcept + { + return metric_; + } + + /** Total length of the index (number of vectors). */ + [[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); } + + /** Dimensionality of the data. 
*/ + [[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); } + + /** Dataset [size, dim] */ + [[nodiscard]] inline auto dataset() const noexcept + -> device_matrix_view + { + return dataset_view_; + } + + /** Dataset norms. Throws std::bad_optional_access if no norms are stored; check has_norms() */ + [[nodiscard]] inline auto norms() const -> device_vector_view + { + return make_const_mdspan(norms_.value().view()); + } + + /** Whether or not this index has dataset norms */ + [[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); } + + [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } + + // Don't allow copying the index for performance reasons (try avoiding copying data) + index(const index&) = delete; + index(index&&) = default; + auto operator=(const index&) -> index& = delete; + auto operator=(index&&) -> index& = default; + ~index() = default; + + /** Construct a brute force index from dataset + * + * Constructs a brute force index from a dataset. This lets us precompute norms for + * the dataset, providing a speed benefit over doing this at query time. + * + * If the dataset is already in GPU memory, then this class stores a non-owning reference to + * the dataset. If the dataset is in host memory, it will be copied to the device and the + * index will own the device memory. + */ + template + index(raft::resources const& res, + mdspan, row_major, data_accessor> dataset, + std::optional>&& norms, + raft::distance::DistanceType metric, + T metric_arg = 0.0) + : ann::index(), + metric_(metric), + dataset_(make_device_matrix(res, 0, 0)), + norms_(std::move(norms)), + metric_arg_(metric_arg) + { + update_dataset(res, dataset); + resource::sync_stream(res); + } + + private: + /** + * Replace the dataset with a device dataset; only the non-owning view is stored, no copy is made. + */ + void update_dataset(raft::resources const& res, + raft::device_matrix_view dataset) + { + dataset_view_ = dataset; + } + + /** + * Replace the dataset with a new dataset.
+ * + * We create a copy of the dataset on the device (issued on the CUDA stream of `res`). The index manages the lifetime of this copy. + */ + void update_dataset(raft::resources const& res, + raft::host_matrix_view dataset) + { + dataset_ = make_device_matrix(res, dataset.extent(0), dataset.extent(1)); + raft::copy(dataset_.data_handle(), + dataset.data_handle(), + dataset.size(), + resource::get_cuda_stream(res)); + dataset_view_ = make_const_mdspan(dataset_.view()); + } + + raft::distance::DistanceType metric_; + raft::device_matrix dataset_; + std::optional> norms_; + raft::device_matrix_view dataset_view_; + T metric_arg_; +}; + +/** @} */ + +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 123a902ef9..be05d5545f 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -64,10 +64,11 @@ void tiled_brute_force_knn(const raft::resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - DistanceEpilogue distance_epilogue = raft::identity_op()) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op(), + const ElementType* precomputed_index_norms = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -97,7 +98,7 @@ void tiled_brute_force_knn(const raft::resources& handle, metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::CosineExpanded) { search_norms.resize(m, stream); - index_norms.resize(n, stream); + if (!precomputed_index_norms) { index_norms.resize(n, stream); } // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric ==
raft::distance::DistanceType::CosineExpanded) { raft::linalg::rowNorm(search_norms.data(), @@ -108,19 +109,24 @@ void tiled_brute_force_knn(const raft::resources& handle, true, stream, raft::sqrt_op{}); - raft::linalg::rowNorm(index_norms.data(), - index, - d, - n, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); + if (!precomputed_index_norms) { + raft::linalg::rowNorm(index_norms.data(), + index, + d, + n, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } } else { raft::linalg::rowNorm( search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - raft::linalg::rowNorm( - index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + + if (!precomputed_index_norms) { + raft::linalg::rowNorm( + index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + } } pairwise_metric = raft::distance::DistanceType::InnerProduct; } @@ -178,7 +184,7 @@ void tiled_brute_force_knn(const raft::resources& handle, if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { auto row_norms = search_norms.data(); - auto col_norms = index_norms.data(); + auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( @@ -200,7 +206,7 @@ void tiled_brute_force_knn(const raft::resources& handle, }); } else if (metric == raft::distance::DistanceType::CosineExpanded) { auto row_norms = search_norms.data(); - auto col_norms = index_norms.data(); + auto col_norms = precomputed_index_norms ? 
precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( @@ -330,7 +336,8 @@ void brute_force_knn_impl( std::vector* translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metricArg = 0, - DistanceEpilogue distance_epilogue = raft::identity_op()) + DistanceEpilogue distance_epilogue = raft::identity_op(), + std::vector* input_norms = nullptr) { auto userStream = resource::get_cuda_stream(handle); @@ -424,7 +431,8 @@ void brute_force_knn_impl( rowMajorIndex, rowMajorQuery, stream, - metric); + metric, + input_norms ? (*input_norms)[i] : nullptr); // Perform necessary post-processing if (metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -473,7 +481,8 @@ void brute_force_knn_impl( metricArg, 0, 0, - distance_epilogue); + distance_epilogue, + input_norms ? (*input_norms)[i] : nullptr); break; } } diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh index 390436939f..1a48e1adde 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh @@ -36,7 +36,9 @@ void fusedL2Knn(size_t D, bool rowMajorIndex, bool rowMajorQuery, cudaStream_t stream, - raft::distance::DistanceType metric) RAFT_EXPLICIT; + raft::distance::DistanceType metric, + const value_t* index_norms = nullptr, + const value_t* query_norms = nullptr) RAFT_EXPLICIT; } // namespace raft::spatial::knn::detail @@ -56,7 +58,9 @@ void fusedL2Knn(size_t D, bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms) instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); diff --git
a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh index 4a571c1447..67abab3d1e 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh @@ -706,6 +706,8 @@ template void fusedL2ExpKnnImpl(const DataT* x, const DataT* y, + const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, @@ -787,19 +789,25 @@ void fusedL2ExpKnnImpl(const DataT* x, } } - DataT* xn = (DataT*)workspace; - DataT* yn = (DataT*)workspace; - - if (x != y) { - yn += m; - raft::linalg::rowNorm( - xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - raft::linalg::rowNorm( - yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } else { + // calculate norms if they haven't been passed in + if (!xn) { + DataT* xn_ = (DataT*)workspace; + workspace = xn_ + m; raft::linalg::rowNorm( - xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + xn_, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + xn = xn_; } + if (!yn) { + if (x == y) { + yn = xn; + } else { + DataT* yn_ = (DataT*)(workspace); + raft::linalg::rowNorm( + yn_, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + yn = yn_; + } + } + fusedL2ExpKnnRowMajor<<>>(x, y, xn, @@ -836,6 +844,8 @@ void fusedL2ExpKnn(IdxT m, IdxT ldd, const DataT* x, const DataT* y, + const DataT* xn, + const DataT* yn, bool sqrt, OutT* out_dists, IdxT* out_inds, @@ -850,6 +860,8 @@ void fusedL2ExpKnn(IdxT m, fusedL2ExpKnnImpl( x, y, + xn, + yn, m, n, k, @@ -867,6 +879,8 @@ void fusedL2ExpKnn(IdxT m, fusedL2ExpKnnImpl( x, y, + xn, + yn, m, n, k, @@ -883,6 +897,8 @@ void fusedL2ExpKnn(IdxT m, } else { fusedL2ExpKnnImpl(x, y, + xn, + yn, m, n, k, @@ -927,7 +943,9 @@ void fusedL2Knn(size_t D, bool rowMajorIndex, bool rowMajorQuery, cudaStream_t stream, - raft::distance::DistanceType metric) + 
raft::distance::DistanceType metric, + const value_t* index_norms = nullptr, + const value_t* query_norms = nullptr) { // Validate the input data ASSERT(k > 0, "l2Knn: k must be > 0"); @@ -968,6 +986,8 @@ void fusedL2Knn(size_t D, ldd, query, index, + query_norms, + index_norms, sqrt, out_dists, out_inds, @@ -985,6 +1005,8 @@ void fusedL2Knn(size_t D, ldd, query, index, + query_norms, + index_norms, sqrt, out_dists, out_inds, diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu new file mode 100644 index 0000000000..f2fda93a97 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -0,0 +1,39 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include +#include + +template void raft::neighbors::brute_force::search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +template void raft::neighbors::brute_force::search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +template raft::neighbors::brute_force::index raft::neighbors::brute_force::build( + raft::resources const& res, + raft::device_matrix_view dataset, + raft::distance::DistanceType metric, + float metric_arg); diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu index 67b08655e6..b73cf31c58 100644 --- a/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu @@ -32,7 +32,9 @@ bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms) instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu index 3c0d13710e..35ef37c984 100644 --- a/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu @@ -32,7 +32,9 @@ bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms) instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t,
float, true); instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu index e799c5181f..ff23d9c41b 100644 --- a/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu @@ -32,7 +32,9 @@ bool rowMajorIndex, \ bool rowMajorQuery, \ cudaStream_t stream, \ - raft::distance::DistanceType metric) + raft::distance::DistanceType metric, \ + const Mvalue_t* index_norms, \ + const Mvalue_t* query_norms) // These are used by brute_force_knn: instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); diff --git a/cpp/template/CMakeLists.txt b/cpp/template/CMakeLists.txt index a1341f3609..538eac07ef 100644 --- a/cpp/template/CMakeLists.txt +++ b/cpp/template/CMakeLists.txt @@ -39,4 +39,3 @@ target_link_libraries(CAGRA_EXAMPLE PRIVATE raft::raft raft::compiled) add_executable(IVF_FLAT_EXAMPLE src/ivf_flat_example.cu) target_link_libraries(IVF_FLAT_EXAMPLE PRIVATE raft::raft raft::compiled) - diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 2ab82b845e..ebde8e6d35 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -180,6 +180,36 @@ class TiledKNNTest : public ::testing::TestWithParam { float(0.001), stream_, true)); + + // Also test out the 'index' api - where we can use precomputed norms + if (params_.row_major) { + auto idx = + raft::neighbors::brute_force::build(handle_, + raft::make_device_matrix_view( + database.data(), params_.num_db_vecs, params_.dim), + metric, + metric_arg); + + raft::neighbors::brute_force::search( + handle_, + idx, + raft::make_device_matrix_view( + search_queries.data(), params_.num_queries, params_.dim), + raft::make_device_matrix_view( + raft_indices_.data(), params_.num_queries, params_.k), + raft::make_device_matrix_view( + raft_distances_.data(), 
params_.num_queries, params_.k)); + + ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(ref_indices_.data(), + raft_indices_.data(), + ref_distances_.data(), + raft_distances_.data(), + num_queries, + k_, + float(0.001), + stream_, + true)); + } } void SetUp() override