Skip to content

Commit

Permalink
SNMG ANN
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Nov 14, 2023
1 parent 28eb0b3 commit 3b74685
Show file tree
Hide file tree
Showing 10 changed files with 834 additions and 1 deletion.
3 changes: 2 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;NEIGHBORS_ANN_MG_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
Expand Down Expand Up @@ -325,6 +325,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_MG_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_TEST"* || \
$CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
Expand Down
3 changes: 3 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/refine_float_float.cu
src/neighbors/refine_int8_t_float.cu
src/neighbors/refine_uint8_t_float.cu
src/neighbors/ann_mg.cu
src/raft_runtime/cluster/cluster_cost.cuh
src/raft_runtime/cluster/cluster_cost_double.cu
src/raft_runtime/cluster/cluster_cost_float.cu
Expand Down Expand Up @@ -489,6 +490,8 @@ if(RAFT_COMPILE_LIBRARY)
${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
# will just be cublas
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
nccl
ucp
)

# So consumers know when using libraft.so/libraft.a
Expand Down
114 changes: 114 additions & 0 deletions cpp/include/raft/neighbors/ann_mg-ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/detail/ann_mg.cuh>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::neighbors::mg {
using namespace raft::neighbors::mg;

template<typename T, typename IdxT>
auto build(const std::vector<int> device_ids,
raft::neighbors::mg::dist_mode mode,
const ivf_flat::index_params& index_params,
raft::host_matrix_view<const T, IdxT, row_major> index_dataset)
-> detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT> RAFT_EXPLICIT;

template<typename T>
auto build(const std::vector<int> device_ids,
raft::neighbors::mg::dist_mode mode,
const ivf_pq::index_params& index_params,
raft::host_matrix_view<const T, uint32_t, row_major> index_dataset)
-> detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t> RAFT_EXPLICIT;

template<typename T, typename IdxT>
void extend(detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
raft::host_matrix_view<const IdxT, IdxT, row_major> new_indices) RAFT_EXPLICIT;

template<typename T>
void extend(detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>& index,
raft::host_matrix_view<const T, uint32_t, row_major> new_vectors,
raft::host_matrix_view<const uint32_t, uint32_t, row_major> new_indices) RAFT_EXPLICIT;

template<typename T, typename IdxT>
void search(detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index,
const ivf_flat::search_params& search_params,
IdxT n_neighbors,
raft::host_matrix_view<const T, IdxT, row_major> query_dataset,
raft::host_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::host_matrix_view<float, IdxT, row_major> distances) RAFT_EXPLICIT;

template<typename T>
void search(detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>& index,
const ivf_pq::search_params& search_params,
uint32_t n_neighbors,
raft::host_matrix_view<const T, uint32_t, row_major> query_dataset,
raft::host_matrix_view<uint32_t, uint32_t, row_major> neighbors,
raft::host_matrix_view<float, uint32_t, row_major> distances) RAFT_EXPLICIT;

}

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_neighbors_ann_mg_build(T, IdxT) \
extern template auto raft::neighbors::mg::build<T, IdxT>( \
const std::vector<int> device_ids, \
raft::neighbors::mg::dist_mode mode, \
const ivf_flat::index_params& index_params, \
raft::host_matrix_view<const T, IdxT, row_major> index_dataset) \
-> detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>; \
\
extern template auto raft::neighbors::mg::build<T>( \
const std::vector<int> device_ids, \
raft::neighbors::mg::dist_mode mode, \
const ivf_pq::index_params& index_params, \
raft::host_matrix_view<const T, uint32_t, row_major> index_dataset) \
-> detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>; \
\
extern template void raft::neighbors::mg::extend<T, IdxT>( \
detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index, \
raft::host_matrix_view<const T, IdxT, row_major> new_vectors, \
raft::host_matrix_view<const IdxT, IdxT, row_major> new_indices); \
\
extern template void raft::neighbors::mg::extend<T>( \
detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>& index, \
raft::host_matrix_view<const T, uint32_t, row_major> new_vectors, \
raft::host_matrix_view<const uint32_t, uint32_t, row_major> new_indices); \
\
extern template void raft::neighbors::mg::search<T, IdxT>( \
detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index, \
const ivf_flat::search_params& search_params, \
IdxT n_neighbors, \
raft::host_matrix_view<const T, IdxT, row_major> query_dataset, \
raft::host_matrix_view<IdxT, IdxT, row_major> neighbors, \
raft::host_matrix_view<float, IdxT, row_major> distances); \
\
extern template void raft::neighbors::mg::search<T>( \
detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>& index, \
const ivf_pq::search_params& search_params, \
uint32_t n_neighbors, \
raft::host_matrix_view<const T, uint32_t, row_major> query_dataset, \
raft::host_matrix_view<uint32_t, uint32_t, row_major> neighbors, \
raft::host_matrix_view<float, uint32_t, row_major> distances); \

instantiate_raft_neighbors_ann_mg_build(float, uint32_t);

#undef instantiate_raft_neighbors_ann_mg_build
80 changes: 80 additions & 0 deletions cpp/include/raft/neighbors/ann_mg-inl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/detail/ann_mg.cuh>

namespace raft::neighbors::mg {

template<typename T, typename IdxT>
auto build(const std::vector<int> device_ids,
raft::neighbors::mg::dist_mode mode,
const ivf_flat::index_params& index_params,
raft::host_matrix_view<const T, IdxT, row_major> index_dataset)
-> detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>
{
return mg::detail::build<T, IdxT>(device_ids, mode, index_params, index_dataset);
}

template<typename T>
auto build(const std::vector<int> device_ids,
raft::neighbors::mg::dist_mode mode,
const ivf_pq::index_params& index_params,
raft::host_matrix_view<const T, uint32_t, row_major> index_dataset)
-> detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>
{
return mg::detail::build<T>(device_ids, mode, index_params, index_dataset);
}

template<typename T, typename IdxT>
void extend(detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index,
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
raft::host_matrix_view<const IdxT, IdxT, row_major> new_indices)
{
mg::detail::extend<T, IdxT>(index, new_vectors, new_indices);
}

template<typename T>
void extend(detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>& index,
raft::host_matrix_view<const T, uint32_t, row_major> new_vectors,
raft::host_matrix_view<const uint32_t, uint32_t, row_major> new_indices)
{
mg::detail::extend<T>(index, new_vectors, new_indices);
}

template<typename T, typename IdxT>
void search(detail::ann_mg_index<ivf_flat::index<T, IdxT>, T, IdxT>& index,
const ivf_flat::search_params& search_params,
IdxT n_neighbors,
raft::host_matrix_view<const T, IdxT, row_major> query_dataset,
raft::host_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::host_matrix_view<float, IdxT, row_major> distances)
{
mg::detail::search<T, IdxT>(index, search_params, n_neighbors, query_dataset, neighbors, distances);
}

template<typename T>
void search(detail::ann_mg_index<ivf_pq::index<uint32_t>, T, uint32_t>& index,
const ivf_pq::search_params& search_params,
uint32_t n_neighbors,
raft::host_matrix_view<const T, uint32_t, row_major> query_dataset,
raft::host_matrix_view<uint32_t, uint32_t, row_major> neighbors,
raft::host_matrix_view<float, uint32_t, row_major> distances)
{
mg::detail::search<T>(index, search_params, n_neighbors, query_dataset, neighbors, distances);
}
}
24 changes: 24 additions & 0 deletions cpp/include/raft/neighbors/ann_mg.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "ann_mg-inl.cuh"
#endif

#ifdef RAFT_COMPILED
#include "ann_mg-ext.cuh"
#endif
Loading

0 comments on commit 3b74685

Please sign in to comment.