From 67bc44120228ead70ebdff54f688ae1616e87b5e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 8 Nov 2024 16:08:08 -0500 Subject: [PATCH] Removing tests --- build.sh | 4 +- .../sparse/distance/detail/bin_distance.cuh | 228 ---- .../raft/sparse/distance/detail/common.hpp | 59 - .../raft/sparse/distance/detail/coo_spmv.cuh | 211 ---- .../distance/detail/coo_spmv_kernel.cuh | 224 ---- .../coo_spmv_strategies/base_strategy.cuh | 149 --- .../coo_mask_row_iterators.cuh | 232 ---- .../dense_smem_strategy.cuh | 119 -- .../coo_spmv_strategies/hash_strategy.cuh | 296 ----- .../sparse/distance/detail/ip_distance.cuh | 90 -- .../sparse/distance/detail/l2_distance.cuh | 499 -------- .../sparse/distance/detail/lp_distance.cuh | 333 ----- .../raft/sparse/distance/detail/utils.cuh | 172 --- cpp/include/raft/sparse/distance/distance.cuh | 224 ---- cpp/include/raft/sparse/hierarchy/common.h | 36 - .../raft/sparse/hierarchy/single_linkage.cuh | 34 - .../raft/sparse/neighbors/brute_force.cuh | 94 -- .../sparse/neighbors/cross_component_nn.cuh | 99 -- .../neighbors/detail/cross_component_nn.cuh | 541 -------- .../raft/sparse/neighbors/detail/knn.cuh | 432 ------- .../sparse/neighbors/detail/knn_graph.cuh | 148 --- cpp/include/raft/sparse/neighbors/knn.cuh | 106 -- .../raft/sparse/neighbors/knn_graph.cuh | 55 - .../raft/sparse/neighbors/specializations.cuh | 24 - cpp/test/CMakeLists.txt | 206 +--- cpp/test/cluster/cluster_solvers.cu | 104 -- .../cluster/cluster_solvers_deprecated.cu | 59 - cpp/test/cluster/kmeans.cu | 363 ------ cpp/test/cluster/kmeans_balanced.cu | 240 ---- cpp/test/cluster/kmeans_find_k.cu | 142 --- cpp/test/cluster/linkage.cu | 674 ---------- cpp/test/cluster/spectral.cu | 109 -- cpp/test/distance/dist_adj.cu | 196 --- cpp/test/distance/dist_adj.cuh | 72 -- .../distance/dist_adj_distance_instance.cu | 65 - cpp/test/distance/dist_adj_threshold.cuh | 36 - cpp/test/distance/dist_canberra.cu | 70 -- cpp/test/distance/dist_correlation.cu | 94 -- cpp/test/distance/dist_cos.cu | 112 -- cpp/test/distance/dist_dice.cu | 112 -- cpp/test/distance/dist_hamming.cu | 71 -- cpp/test/distance/dist_hellinger.cu | 71 -- cpp/test/distance/dist_inner_product.cu | 74 -- cpp/test/distance/dist_jensen_shannon.cu | 71 -- cpp/test/distance/dist_kl_divergence.cu | 71 -- cpp/test/distance/dist_l1.cu | 70 -- cpp/test/distance/dist_l2_exp.cu | 115 -- cpp/test/distance/dist_l2_sqrt_exp.cu | 74 -- cpp/test/distance/dist_l2_unexp.cu | 71 -- cpp/test/distance/dist_l_inf.cu | 70 -- cpp/test/distance/dist_lp_unexp.cu | 71 -- cpp/test/distance/dist_russell_rao.cu | 71 -- cpp/test/distance/distance_base.cuh | 708 ----------- cpp/test/distance/fused_cosine_nn.cu | 420 ------- cpp/test/distance/fused_l2_nn.cu | 437 ------- cpp/test/distance/gram.cu | 174 --- cpp/test/distance/gram_base.cuh | 90 -- cpp/test/distance/masked_nn.cu | 438 ------- .../distance/masked_nn_compress_to_bits.cu | 220 ---- cpp/test/neighbors/ann_brute_force.cuh | 253 ---- .../neighbors/ann_brute_force/test_float.cu | 28 - cpp/test/neighbors/ann_cagra.cuh | 949 -------------- .../ann_cagra/search_kernel_uint64_t.cuh | 155 --- .../neighbors/ann_cagra/test_float_int64_t.cu | 29 - .../ann_cagra/test_float_uint32_t.cu | 40 - .../neighbors/ann_cagra/test_half_int64_t.cu | 29 - .../neighbors/ann_cagra/test_half_uint32_t.cu | 40 - .../ann_cagra/test_int8_t_uint32_t.cu | 38 - .../ann_cagra/test_uint8_t_uint32_t.cu | 40 - cpp/test/neighbors/ann_cagra_vpq.cuh | 336 ----- .../ann_cagra_vpq/test_float_int64_t.cu | 29 - .../ann_cagra_vpq/test_float_uint32_t.cu | 28 - cpp/test/neighbors/ann_ivf_flat.cuh | 675 ---------- .../ann_ivf_flat/test_filter_float_int64_t.cu | 29 - .../ann_ivf_flat/test_float_int64_t.cu | 32 - .../ann_ivf_flat/test_int8_t_int64_t.cu | 28 - .../ann_ivf_flat/test_uint8_t_int64_t.cu | 28 - cpp/test/neighbors/ann_ivf_pq.cuh | 1095 ----------------- .../ann_ivf_pq/ivf_pq_build_float_uint32_t.cu | 37 - .../ann_ivf_pq/ivf_pq_build_test-ext.cuh | 38 - .../ivf_pq_search_float_uint32_t.cu | 67 - .../ann_ivf_pq/test_filter_float_int64_t.cu | 28 - .../ann_ivf_pq/test_filter_int8_t_int64_t.cu | 29 - .../ann_ivf_pq/test_float_int64_t.cu | 27 - .../ann_ivf_pq/test_float_uint32_t.cu | 34 - .../ann_ivf_pq/test_int8_t_int64_t.cu | 28 - .../ann_ivf_pq/test_uint8_t_int64_t.cu | 27 - cpp/test/neighbors/ann_nn_descent.cuh | 332 ----- .../test_batch_float_uint32_t.cu | 30 - .../ann_nn_descent/test_float_uint32_t.cu | 28 - .../ann_nn_descent/test_int8_t_uint32_t.cu | 28 - .../ann_nn_descent/test_uint8_t_uint32_t.cu | 28 - cpp/test/neighbors/ann_utils.cuh | 335 ----- cpp/test/neighbors/fused_l2_knn.cu | 173 --- cpp/test/neighbors/knn.cu | 197 --- cpp/test/neighbors/refine.cu | 129 -- cpp/test/neighbors/tiled_knn.cu | 352 ------ cpp/test/sparse/gram.cu | 332 ----- cpp/test/stats/neighborhood_recall.cu | 177 --- cpp/test/stats/silhouette_score.cu | 230 ---- cpp/test/stats/trustworthiness.cu | 354 ------ 101 files changed, 4 insertions(+), 17367 deletions(-) delete mode 100644 cpp/include/raft/sparse/distance/detail/bin_distance.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/common.hpp delete mode 100644 cpp/include/raft/sparse/distance/detail/coo_spmv.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/ip_distance.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/l2_distance.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/lp_distance.cuh delete mode 100644 cpp/include/raft/sparse/distance/detail/utils.cuh delete mode 100644 cpp/include/raft/sparse/distance/distance.cuh delete mode 100644 cpp/include/raft/sparse/hierarchy/common.h delete mode 100644 cpp/include/raft/sparse/hierarchy/single_linkage.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/brute_force.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/cross_component_nn.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/detail/knn.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/knn.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/knn_graph.cuh delete mode 100644 cpp/include/raft/sparse/neighbors/specializations.cuh delete mode 100644 cpp/test/cluster/cluster_solvers.cu delete mode 100644 cpp/test/cluster/cluster_solvers_deprecated.cu delete mode 100644 cpp/test/cluster/kmeans.cu delete mode 100644 cpp/test/cluster/kmeans_balanced.cu delete mode 100644 cpp/test/cluster/kmeans_find_k.cu delete mode 100644 cpp/test/cluster/linkage.cu delete mode 100644 cpp/test/cluster/spectral.cu delete mode 100644 cpp/test/distance/dist_adj.cu delete mode 100644 cpp/test/distance/dist_adj.cuh delete mode 100644 cpp/test/distance/dist_adj_distance_instance.cu delete mode 100644 cpp/test/distance/dist_adj_threshold.cuh delete mode 100644 cpp/test/distance/dist_canberra.cu delete mode 100644 cpp/test/distance/dist_correlation.cu delete mode 100644 cpp/test/distance/dist_cos.cu delete mode 100644 cpp/test/distance/dist_dice.cu delete mode 100644 cpp/test/distance/dist_hamming.cu delete mode 100644 cpp/test/distance/dist_hellinger.cu delete mode 100644 cpp/test/distance/dist_inner_product.cu delete mode 100644 cpp/test/distance/dist_jensen_shannon.cu delete mode 100644 cpp/test/distance/dist_kl_divergence.cu delete mode 100644 cpp/test/distance/dist_l1.cu delete mode 100644 cpp/test/distance/dist_l2_exp.cu delete mode 100644 cpp/test/distance/dist_l2_sqrt_exp.cu delete mode 100644 cpp/test/distance/dist_l2_unexp.cu delete mode 100644 cpp/test/distance/dist_l_inf.cu delete mode 100644 cpp/test/distance/dist_lp_unexp.cu delete mode 100644 cpp/test/distance/dist_russell_rao.cu delete mode 100644 cpp/test/distance/distance_base.cuh delete mode 100644 cpp/test/distance/fused_cosine_nn.cu delete mode 100644 cpp/test/distance/fused_l2_nn.cu delete mode 100644 cpp/test/distance/gram.cu delete mode 100644 cpp/test/distance/gram_base.cuh delete mode 100644 cpp/test/distance/masked_nn.cu delete mode 100644 cpp/test/distance/masked_nn_compress_to_bits.cu delete mode 100644 cpp/test/neighbors/ann_brute_force.cuh delete mode 100644 cpp/test/neighbors/ann_brute_force/test_float.cu delete mode 100644 cpp/test/neighbors/ann_cagra.cuh delete mode 100644 cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh delete mode 100644 cpp/test/neighbors/ann_cagra/test_float_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_cagra/test_half_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_cagra_vpq.cuh delete mode 100644 cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_flat.cuh delete mode 100644 cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq.cuh delete mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh delete mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu delete mode 100644 cpp/test/neighbors/ann_nn_descent.cuh delete mode 100644 cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu delete mode 100644 cpp/test/neighbors/ann_utils.cuh delete mode 100644 cpp/test/neighbors/fused_l2_knn.cu delete mode 100644 cpp/test/neighbors/knn.cu delete mode 100644 cpp/test/neighbors/refine.cu delete mode 100644 cpp/test/neighbors/tiled_knn.cu delete mode 100644 cpp/test/sparse/gram.cu delete mode 100644 cpp/test/stats/neighborhood_recall.cu delete mode 100644 cpp/test/stats/silhouette_score.cu delete mode 100644 cpp/test/stats/trustworthiness.cu diff --git a/build.sh b/build.sh index 04d330f03d..ae20f324d2 100755 --- a/build.sh +++ b/build.sh @@ -75,8 +75,8 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_BRUTE_FORCE_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" -BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" +TEST_TARGETS="CORE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;STATS_TEST;UTILS_TEST" +BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" NVTX=ON diff --git a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh deleted file mode 100644 index 7a2396c2cd..0000000000 --- a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "common.hpp" - -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { -// @TODO: Move this into sparse prims (coo_norm) -template -RAFT_KERNEL compute_binary_row_norm_kernel(value_t* out, - const value_idx* __restrict__ coo_rows, - const value_t* __restrict__ data, - value_idx nnz) -{ - value_idx i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < nnz) { - // We do conditional here only because it's - // possible there could be some stray zeros in - // the sparse structure and removing them would be - // more expensive. - atomicAdd(&out[coo_rows[i]], data[i] == 1.0); - } -} - -template -RAFT_KERNEL compute_binary_warp_kernel(value_t* __restrict__ C, - const value_t* __restrict__ Q_norms, - const value_t* __restrict__ R_norms, - value_idx n_rows, - value_idx n_cols, - expansion_f expansion_func) -{ - std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; - value_idx i = tid / n_cols; - value_idx j = tid % n_cols; - - if (i >= n_rows || j >= n_cols) return; - - value_t q_norm = Q_norms[i]; - value_t r_norm = R_norms[j]; - value_t dot = C[(size_t)i * n_cols + j]; - C[(size_t)i * n_cols + j] = expansion_func(dot, q_norm, r_norm); -} - -template -void compute_binary(value_t* C, - const value_t* Q_norms, - const value_t* R_norms, - value_idx n_rows, - value_idx n_cols, - expansion_f expansion_func, - cudaStream_t stream) -{ - int blocks = raft::ceildiv((size_t)n_rows * n_cols, tpb); - compute_binary_warp_kernel<<>>( - C, Q_norms, R_norms, n_rows, n_cols, expansion_func); -} - -template -void compute_bin_distance(value_t* out, - const value_idx* Q_coo_rows, - const value_t* Q_data, - value_idx Q_nnz, - const value_idx* R_coo_rows, - const value_t* R_data, - value_idx R_nnz, - value_idx m, - value_idx n, - cudaStream_t stream, - expansion_f expansion_func) -{ - rmm::device_uvector Q_norms(m, stream); - rmm::device_uvector R_norms(n, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Q_norms.data(), 0, Q_norms.size() * sizeof(value_t))); - RAFT_CUDA_TRY(cudaMemsetAsync(R_norms.data(), 0, R_norms.size() * sizeof(value_t))); - - compute_binary_row_norm_kernel<<>>( - Q_norms.data(), Q_coo_rows, Q_data, Q_nnz); - compute_binary_row_norm_kernel<<>>( - R_norms.data(), R_coo_rows, R_data, R_nnz); - - compute_binary(out, Q_norms.data(), R_norms.data(), m, n, expansion_func, stream); -} - -/** - * Jaccard distance using the expanded form: - * 1 - (sum(x_k * y_k) / ((sum(x_k) + sum(y_k)) - sum(x_k * y_k)) - */ -template -class jaccard_expanded_distances_t : public distances_t { - public: - explicit jaccard_expanded_distances_t(const distances_config_t& config) - : config_(&config), workspace(0, resource::get_cuda_stream(config.handle)), ip_dists(config) - { - } - - void compute(value_t* out_dists) - { - ip_dists.compute(out_dists); - - value_idx* b_indices = ip_dists.b_rows_coo(); - value_t* b_data = ip_dists.b_data_coo(); - - rmm::device_uvector search_coo_rows(config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - raft::sparse::convert::csr_to_coo(config_->a_indptr, - config_->a_nrows, - search_coo_rows.data(), - config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - - compute_bin_distance(out_dists, - search_coo_rows.data(), - config_->a_data, - config_->a_nnz, - b_indices, - b_data, - config_->b_nnz, - config_->a_nrows, - config_->b_nrows, - resource::get_cuda_stream(config_->handle), - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { - value_t q_r_union = q_norm + r_norm; - value_t denom = q_r_union - dot; - - value_t jacc = ((denom != 0) * dot) / ((denom == 0) + denom); - - // flip the similarity when both rows are 0 - bool both_empty = q_r_union == 0; - return 1 - ((!both_empty * jacc) + both_empty); - }); - } - - ~jaccard_expanded_distances_t() = default; - - private: - const distances_config_t* config_; - rmm::device_uvector workspace; - ip_distances_t ip_dists; -}; - -/** - * Dice distance using the expanded form: - * 1 - ((2 * sum(x_k * y_k)) / (sum(x_k) + sum(y_k))) - */ -template -class dice_expanded_distances_t : public distances_t { - public: - explicit dice_expanded_distances_t(const distances_config_t& config) - : config_(&config), workspace(0, resource::get_cuda_stream(config.handle)), ip_dists(config) - { - } - - void compute(value_t* out_dists) - { - ip_dists.compute(out_dists); - - value_idx* b_indices = ip_dists.b_rows_coo(); - value_t* b_data = ip_dists.b_data_coo(); - - rmm::device_uvector search_coo_rows(config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - raft::sparse::convert::csr_to_coo(config_->a_indptr, - config_->a_nrows, - search_coo_rows.data(), - config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - - compute_bin_distance(out_dists, - search_coo_rows.data(), - config_->a_data, - config_->a_nnz, - b_indices, - b_data, - config_->b_nnz, - config_->a_nrows, - config_->b_nrows, - resource::get_cuda_stream(config_->handle), - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { - value_t q_r_union = q_norm + r_norm; - value_t dice = (2 * dot) / q_r_union; - bool both_empty = q_r_union == 0; - return 1 - ((!both_empty * dice) + both_empty); - }); - } - - ~dice_expanded_distances_t() = default; - - private: - const distances_config_t* config_; - rmm::device_uvector workspace; - ip_distances_t ip_dists; -}; - -} // END namespace detail -}; // END namespace distance -}; // END namespace sparse -}; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/common.hpp b/cpp/include/raft/sparse/distance/detail/common.hpp deleted file mode 100644 index 0f463dac80..0000000000 --- a/cpp/include/raft/sparse/distance/detail/common.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -struct distances_config_t { - distances_config_t(raft::resources const& handle_) : handle(handle_) {} - - // left side - value_idx a_nrows; - value_idx a_ncols; - value_idx a_nnz; - value_idx* a_indptr; - value_idx* a_indices; - value_t* a_data; - - // right side - value_idx b_nrows; - value_idx b_ncols; - value_idx b_nnz; - value_idx* b_indptr; - value_idx* b_indices; - value_t* b_data; - - raft::resources const& handle; -}; - -template -class distances_t { - public: - virtual void compute(value_t* out) {} - virtual ~distances_t() = default; -}; - -}; // namespace detail -}; // namespace distance -}; // namespace sparse -}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh deleted file mode 100644 index b0469f559a..0000000000 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../../csr.hpp" -#include "../../detail/utils.h" -#include "common.hpp" -#include "coo_spmv_strategies/dense_smem_strategy.cuh" -#include "coo_spmv_strategies/hash_strategy.cuh" - -#include -#include -#include -#include - -#include -#include - -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -inline void balanced_coo_pairwise_generalized_spmv( - value_t* out_dists, - const distances_config_t& config_, - value_idx* coo_rows_b, - product_f product_func, - accum_f accum_func, - write_f write_func, - strategy_t strategy, - int chunk_size = 500000) -{ - uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; - RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, resource::get_cuda_stream(config_.handle))); - - strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); -}; - -/** - * Performs generalized sparse-matrix-sparse-matrix multiplication via a - * sparse-matrix-sparse-vector layout `out=A*B` where generalized product() - * and sum() operations can be used in place of the standard sum and product: - * - * out_ij = sum_k(product(A_ik, B_ik)) The sum goes through values of - * k=0..n_cols-1 where B_kj is nonzero. - * - * The product and sum operations shall form a semiring algebra with the - * following properties: - * 1. {+, 0} is a commutative sum reduction monoid with identity element 0 - * 2. {*, 1} is a product monoid with identity element 1 - * 3. Multiplication by 0 annihilates x. e.g. product(x, 0) = 0 - * - * Each vector of A is loaded into shared memory in dense form and the - * non-zeros of B load balanced across the threads of each block. - * @tparam value_idx index type - * @tparam value_t value type - * @tparam threads_per_block block size - * @tparam product_f semiring product() function - * @tparam accum_f semiring sum() function - * @tparam write_f atomic semiring sum() function - * @param[out] out_dists dense array of out distances of size m * n in row-major - * format. - * @param[in] config_ distance config object - * @param[in] coo_rows_b coo row array for B - * @param[in] product_func semiring product() function - * @param[in] accum_func semiring sum() function - * @param[in] write_func atomic semiring sum() function - * @param[in] chunk_size number of nonzeros of B to process for each row of A - * this value was found through profiling and represents a reasonable - * setting for both large and small densities - */ -template -inline void balanced_coo_pairwise_generalized_spmv( - value_t* out_dists, - const distances_config_t& config_, - value_idx* coo_rows_b, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size = 500000) -{ - uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; - RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, resource::get_cuda_stream(config_.handle))); - - int max_cols = max_cols_per_block(); - - if (max_cols > config_.a_ncols) { - dense_smem_strategy strategy(config_); - strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); - } else { - hash_strategy strategy(config_); - strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); - } -}; - -template -inline void balanced_coo_pairwise_generalized_spmv_rev( - value_t* out_dists, - const distances_config_t& config_, - value_idx* coo_rows_a, - product_f product_func, - accum_f accum_func, - write_f write_func, - strategy_t strategy, - int chunk_size = 500000) -{ - strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, write_func, chunk_size); -}; - -/** - * Used for computing distances where the reduction (e.g. product()) function - * requires an implicit union (product(x, 0) = x) to capture the difference A-B. - * This is necessary in some applications because the standard semiring algebra - * endowed with the default multiplication product monoid will only - * compute the intersection & B-A. - * - * This particular function is meant to accompany the function - * `balanced_coo_pairwise_generalized_spmv` and executes the product operation - * on only those columns that exist in B and not A. - * - * The product and sum operations shall enable the computation of a - * non-annihilating semiring algebra with the following properties: - * 1. {+, 0} is a commutative sum reduction monoid with identity element 0 - * 2. {*, 0} is a product monoid with identity element 0 - * 3. Multiplication by 0 does not annihilate x. e.g. product(x, 0) = x - * - * Manattan distance sum(abs(x_k-y_k)) is a great example of when this type of - * execution pattern is necessary. - * - * @tparam value_idx index type - * @tparam value_t value type - * @tparam threads_per_block block size - * @tparam product_f semiring product() function - * @tparam accum_f semiring sum() function - * @tparam write_f atomic semiring sum() function - * @param[out] out_dists dense array of out distances of size m * n - * @param[in] config_ distance config object - * @param[in] coo_rows_a coo row array for A - * @param[in] product_func semiring product() function - * @param[in] accum_func semiring sum() function - * @param[in] write_func atomic semiring sum() function - * @param[in] chunk_size number of nonzeros of B to process for each row of A - * this value was found through profiling and represents a reasonable - * setting for both large and small densities - */ -template -inline void balanced_coo_pairwise_generalized_spmv_rev( - value_t* out_dists, - const distances_config_t& config_, - value_idx* coo_rows_a, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size = 500000) -{ - // try dense first - int max_cols = max_cols_per_block(); - - if (max_cols > config_.b_ncols) { - dense_smem_strategy strategy(config_); - strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, write_func, chunk_size); - } else { - hash_strategy strategy(config_); - strategy.dispatch_rev(out_dists, coo_rows_a, product_func, accum_func, write_func, chunk_size); - } -}; - -} // namespace detail -} // namespace distance -} // namespace sparse -}; // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh deleted file mode 100644 index d8f5b2eca5..0000000000 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_kernel.cuh +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { -/** - * Load-balanced sparse-matrix-sparse-matrix multiplication (SPMM) kernel with - * sparse-matrix-sparse-vector multiplication layout (SPMV). - * This is intended to be scheduled n_chunks_b times for each row of a. - * The steps are as follows: - * - * 1. Load row from A into dense vector in shared memory. - * This can be further chunked in the future if necessary to support larger - * column sizes. - * 2. Threads of block all step through chunks of B in parallel. - * When a new row is encountered in row_indices_b, a segmented - * reduction is performed across the warps and then across the - * block and the final value written out to host memory. - * - * Reference: https://www.icl.utk.edu/files/publications/2020/icl-utk-1421-2020.pdf - * - * @tparam value_idx index type - * @tparam value_t value type - * @tparam tpb threads per block configured on launch - * @tparam rev if this is true, the reduce/accumulate functions are only - * executed when A[col] == 0.0. when executed before/after !rev - * and A & B are reversed, this allows the full symmetric difference - * and intersection to be computed. - * @tparam kv_t data type stored in shared mem cache - * @tparam product_f reduce function type (semiring product() function). - * accepts two arguments of value_t and returns a value_t - * @tparam accum_f accumulation function type (semiring sum() function). - * accepts two arguments of value_t and returns a value_t - * @tparam write_f function to write value out. this should be mathematically - * equivalent to the accumulate function but implemented as - * an atomic operation on global memory. Accepts two arguments - * of value_t* and value_t and updates the value given by the - * pointer. - * @param[in] indptrA column pointer array for A - * @param[in] indicesA column indices array for A - * @param[in] dataA data array for A - * @param[in] rowsB coo row array for B - * @param[in] indicesB column indices array for B - * @param[in] dataB data array for B - * @param[in] m number of rows in A - * @param[in] n number of rows in B - * @param[in] dim number of features - * @param[in] nnz_b number of nonzeros in B - * @param[out] out array of size m*n - * @param[in] n_blocks_per_row number of blocks of B per row of A - * @param[in] chunk_size number of nnz for B to use for each row of A - * @param[in] buffer_size amount of smem to use for each row of A - * @param[in] product_func semiring product() function - * @param[in] accum_func semiring sum() function - * @param[in] write_func atomic semiring sum() function - */ -template -RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy, - indptr_it indptrA, - value_idx* indicesA, - value_t* dataA, - value_idx nnz_a, - value_idx* rowsB, - value_idx* indicesB, - value_t* dataB, - value_idx m, - value_idx n, - int dim, - value_idx nnz_b, - value_t* out, - int n_blocks_per_row, - int chunk_size, - value_idx b_ncols, - product_f product_func, - accum_f accum_func, - write_f write_func) -{ - typedef cub::WarpReduce warp_reduce; - - value_idx cur_row_a = indptrA.get_row_idx(n_blocks_per_row); - value_idx cur_chunk_offset = blockIdx.x % n_blocks_per_row; - - // chunk starting offset - value_idx ind_offset = cur_chunk_offset * chunk_size * tpb; - // how many total cols will be processed by this block (should be <= chunk_size * n_threads) - value_idx active_chunk_size = min(chunk_size * tpb, nnz_b - ind_offset); - - int tid = threadIdx.x; - int warp_id = tid / raft::warp_size(); - - // compute id relative to current warp - unsigned int lane_id = tid & (raft::warp_size() - 1); - value_idx ind = ind_offset + threadIdx.x; - - extern __shared__ char smem[]; - - typename strategy_t::smem_type A = (typename strategy_t::smem_type)(smem); - typename warp_reduce::TempStorage* temp_storage = (typename warp_reduce::TempStorage*)(A + dim); - - auto inserter = strategy.init_insert(A, dim); - - __syncthreads(); - - value_idx start_offset_a, stop_offset_a; - bool first_a_chunk, last_a_chunk; - indptrA.get_row_offsets( - cur_row_a, start_offset_a, stop_offset_a, n_blocks_per_row, first_a_chunk, last_a_chunk); - - // Convert current row vector in A to dense - for (int i = tid; i <= (stop_offset_a - start_offset_a); i += blockDim.x) { - strategy.insert(inserter, indicesA[start_offset_a + i], dataA[start_offset_a + i]); - } - - __syncthreads(); - - auto finder = strategy.init_find(A, dim); - - if (cur_row_a > m || cur_chunk_offset > n_blocks_per_row) return; - if (ind >= nnz_b) return; - - value_idx start_index_a = 0, stop_index_a = b_ncols - 1; - indptrA.get_indices_boundary(indicesA, - cur_row_a, - start_offset_a, - stop_offset_a, - start_index_a, - stop_index_a, - first_a_chunk, - last_a_chunk); - - value_idx cur_row_b = -1; - value_t c = 0.0; - - auto warp_red = warp_reduce(*(temp_storage + warp_id)); - - if (tid < active_chunk_size) { - cur_row_b = rowsB[ind]; - - auto index_b = indicesB[ind]; - auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); - - if (in_bounds) { - value_t a_col = strategy.find(finder, index_b); - if (!rev || a_col == 0.0) { c = product_func(a_col, dataB[ind]); } - } - } - - // loop through chunks in parallel, reducing when a new row is - // encountered by each thread - for (int i = tid; i < active_chunk_size; i += blockDim.x) { - value_idx ind_next = ind + blockDim.x; - value_idx next_row_b = -1; - - if (i + blockDim.x < active_chunk_size) next_row_b = rowsB[ind_next]; - - bool diff_rows = next_row_b != cur_row_b; - - if (__any_sync(0xffffffff, diff_rows)) { - // grab the threads currently participating in loops. - // because any other threads should have returned already. - unsigned int peer_group = __match_any_sync(0xffffffff, cur_row_b); - bool is_leader = get_lowest_peer(peer_group) == lane_id; - value_t v = warp_red.HeadSegmentedReduce(c, is_leader, accum_func); - - // thread with lowest lane id among peers writes out - if (is_leader && v != 0.0) { - // this conditional should be uniform, since rev is constant - size_t idx = !rev ? (size_t)cur_row_a * n + cur_row_b : (size_t)cur_row_b * m + cur_row_a; - write_func(out + idx, v); - } - - c = 0.0; - } - - if (next_row_b != -1) { - ind = ind_next; - - auto index_b = indicesB[ind]; - auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); - if (in_bounds) { - value_t a_col = strategy.find(finder, index_b); - - if (!rev || a_col == 0.0) { c = accum_func(c, product_func(a_col, dataB[ind])); } - } - - cur_row_b = next_row_b; - } - } -} - -} // namespace detail -} // namespace distance -} // namespace sparse -} // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh deleted file mode 100644 index fc5881e2d4..0000000000 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../common.hpp" -#include "../coo_spmv_kernel.cuh" -#include "../utils.cuh" -#include "coo_mask_row_iterators.cuh" - -#include - -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -class coo_spmv_strategy { - public: - coo_spmv_strategy(const distances_config_t& config_) : config(config_) - { - smem = raft::getSharedMemPerBlock(); - } - - template - void _dispatch_base(strategy_t& strategy, - int smem_dim, - indptr_it& a_indptr, - value_t* out_dists, - value_idx* coo_rows_b, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size, - int n_blocks, - int n_blocks_per_row) - { - RAFT_CUDA_TRY(cudaFuncSetCacheConfig(balanced_coo_generalized_spmv_kernel, - cudaFuncCachePreferShared)); - - balanced_coo_generalized_spmv_kernel - <<>>(strategy, - a_indptr, - config.a_indices, - config.a_data, - config.a_nnz, - coo_rows_b, - config.b_indices, - config.b_data, - config.a_nrows, - config.b_nrows, - smem_dim, - config.b_nnz, - out_dists, - n_blocks_per_row, - chunk_size, - config.b_ncols, - product_func, - accum_func, - write_func); - } - - template - void _dispatch_base_rev(strategy_t& strategy, - int smem_dim, - indptr_it& b_indptr, - value_t* out_dists, - value_idx* coo_rows_a, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size, - int n_blocks, - int n_blocks_per_row) - { - RAFT_CUDA_TRY(cudaFuncSetCacheConfig(balanced_coo_generalized_spmv_kernel, - cudaFuncCachePreferShared)); - - balanced_coo_generalized_spmv_kernel - <<>>(strategy, - b_indptr, - config.b_indices, - config.b_data, - config.b_nnz, - coo_rows_a, - config.a_indices, - config.a_data, - config.b_nrows, - config.a_nrows, - smem_dim, - config.a_nnz, - out_dists, - n_blocks_per_row, - chunk_size, - config.a_ncols, - product_func, - accum_func, - write_func); - } - - protected: - int smem; - const distances_config_t& config; -}; - -} // namespace detail -} // namespace distance -} // namespace sparse -} // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh deleted file mode 100644 index 38aa106d78..0000000000 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../common.hpp" -#include "../utils.cuh" - -#include - -#include -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -class mask_row_it { - public: - mask_row_it(const value_idx* full_indptr_, - const value_idx& n_rows_, - value_idx* mask_row_idx_ = NULL) - : full_indptr(full_indptr_), mask_row_idx(mask_row_idx_), n_rows(n_rows_) - { - } - - __device__ inline value_idx get_row_idx(const int& n_blocks_nnz_b) - { - if (mask_row_idx != NULL) { - return mask_row_idx[blockIdx.x / n_blocks_nnz_b]; - } else { - return blockIdx.x / n_blocks_nnz_b; - } - } - - __device__ inline void get_row_offsets(const value_idx& row_idx, - value_idx& start_offset, - value_idx& stop_offset, - const value_idx& n_blocks_nnz_b, - bool& first_a_chunk, - bool& last_a_chunk) - { - start_offset = full_indptr[row_idx]; - stop_offset = full_indptr[row_idx + 1] - 1; - } - - __device__ constexpr inline void get_indices_boundary(const value_idx* indices, - value_idx& indices_len, - value_idx& start_offset, - value_idx& stop_offset, - value_idx& start_index, - value_idx& stop_index, - bool& first_a_chunk, - bool& last_a_chunk) - { - // do nothing; - } - - __device__ constexpr inline bool check_indices_bounds(value_idx& start_index_a, - value_idx& stop_index_a, - value_idx& index_b) - { - return true; - } - - const value_idx *full_indptr, &n_rows; - value_idx* mask_row_idx; -}; - -template -RAFT_KERNEL fill_chunk_indices_kernel(value_idx* n_chunks_per_row, - value_idx* chunk_indices, - value_idx n_rows) -{ - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid < n_rows) { - auto start = n_chunks_per_row[tid]; - auto end = n_chunks_per_row[tid + 1]; - -#pragma unroll - for (int i = start; i < end; i++) { - chunk_indices[i] = tid; - } - } -} - -template -class chunked_mask_row_it : public mask_row_it { - public: - chunked_mask_row_it(const value_idx* full_indptr_, - const value_idx& n_rows_, - value_idx* mask_row_idx_, - int row_chunk_size_, - const value_idx* n_chunks_per_row_, - const value_idx* chunk_indices_, - const cudaStream_t stream_) - : mask_row_it(full_indptr_, n_rows_, mask_row_idx_), - row_chunk_size(row_chunk_size_), - n_chunks_per_row(n_chunks_per_row_), - chunk_indices(chunk_indices_), - stream(stream_) - { - } - - static void init(const value_idx* indptr, - const value_idx* mask_row_idx, - const value_idx& n_rows, - const int row_chunk_size, - rmm::device_uvector& n_chunks_per_row, - rmm::device_uvector& chunk_indices, - cudaStream_t stream) - { - auto policy = rmm::exec_policy(stream); - - constexpr value_idx first_element = 0; - n_chunks_per_row.set_element_async(0, first_element, stream); - n_chunks_per_row_functor chunk_functor(indptr, row_chunk_size); - thrust::transform( - policy, mask_row_idx, mask_row_idx + n_rows, n_chunks_per_row.begin() + 1, chunk_functor); - - thrust::inclusive_scan( - policy, n_chunks_per_row.begin() + 1, n_chunks_per_row.end(), n_chunks_per_row.begin() + 1); - - raft::update_host(&total_row_blocks, n_chunks_per_row.data() + n_rows, 1, stream); - - fill_chunk_indices(n_rows, n_chunks_per_row, chunk_indices, stream); - } - - __device__ inline value_idx get_row_idx(const int& n_blocks_nnz_b) - { - return this->mask_row_idx[chunk_indices[blockIdx.x / n_blocks_nnz_b]]; - } - - __device__ inline void get_row_offsets(const value_idx& row_idx, - value_idx& start_offset, - value_idx& stop_offset, - const int& n_blocks_nnz_b, - bool& first_a_chunk, - bool& last_a_chunk) - { - auto chunk_index = blockIdx.x / n_blocks_nnz_b; - auto chunk_val = chunk_indices[chunk_index]; - auto prev_n_chunks = n_chunks_per_row[chunk_val]; - auto relative_chunk = chunk_index - prev_n_chunks; - first_a_chunk = relative_chunk == 0; - - start_offset = this->full_indptr[row_idx] + relative_chunk * row_chunk_size; - stop_offset = start_offset + row_chunk_size; - - auto final_stop_offset = this->full_indptr[row_idx + 1]; - - last_a_chunk = stop_offset >= final_stop_offset; - stop_offset = last_a_chunk ? final_stop_offset - 1 : stop_offset - 1; - } - - __device__ inline void get_indices_boundary(const value_idx* indices, - value_idx& row_idx, - value_idx& start_offset, - value_idx& stop_offset, - value_idx& start_index, - value_idx& stop_index, - bool& first_a_chunk, - bool& last_a_chunk) - { - start_index = first_a_chunk ? start_index : indices[start_offset - 1] + 1; - stop_index = last_a_chunk ? stop_index : indices[stop_offset]; - } - - __device__ inline bool check_indices_bounds(value_idx& start_index_a, - value_idx& stop_index_a, - value_idx& index_b) - { - return (index_b >= start_index_a && index_b <= stop_index_a); - } - - inline static value_idx total_row_blocks = 0; - const cudaStream_t stream; - const value_idx *n_chunks_per_row, *chunk_indices; - value_idx row_chunk_size; - - struct n_chunks_per_row_functor { - public: - n_chunks_per_row_functor(const value_idx* indptr_, value_idx row_chunk_size_) - : indptr(indptr_), row_chunk_size(row_chunk_size_) - { - } - - __host__ __device__ value_idx operator()(const value_idx& i) - { - auto degree = indptr[i + 1] - indptr[i]; - return raft::ceildiv(degree, (value_idx)row_chunk_size); - } - - const value_idx* indptr; - value_idx row_chunk_size; - }; - - private: - static void fill_chunk_indices(const value_idx& n_rows, - rmm::device_uvector& n_chunks_per_row, - rmm::device_uvector& chunk_indices, - cudaStream_t stream) - { - auto n_threads = std::min(n_rows, 256); - auto n_blocks = raft::ceildiv(n_rows, (value_idx)n_threads); - - chunk_indices.resize(total_row_blocks, stream); - - fill_chunk_indices_kernel - <<>>(n_chunks_per_row.data(), chunk_indices.data(), n_rows); - } -}; - -} // namespace detail -} // namespace distance -} // namespace sparse -} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh deleted file mode 100644 index 5a1c152bd0..0000000000 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "base_strategy.cuh" - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -class dense_smem_strategy : public coo_spmv_strategy { - public: - using smem_type = value_t*; - using insert_type = smem_type; - using find_type = smem_type; - - dense_smem_strategy(const distances_config_t& config_) - : coo_spmv_strategy(config_) - { - } - - inline static int smem_per_block(int n_cols) - { - return (n_cols * sizeof(value_t)) + ((1024 / raft::warp_size()) * sizeof(value_t)); - } - - template - void dispatch(value_t* out_dists, - value_idx* coo_rows_b, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size) - { - auto n_blocks_per_row = raft::ceildiv(this->config.b_nnz, chunk_size * 1024); - auto n_blocks = this->config.a_nrows * n_blocks_per_row; - - mask_row_it a_indptr(this->config.a_indptr, this->config.a_nrows); - - this->_dispatch_base(*this, - this->config.b_ncols, - a_indptr, - out_dists, - coo_rows_b, - product_func, - accum_func, - write_func, - chunk_size, - n_blocks, - n_blocks_per_row); - } - - template - void dispatch_rev(value_t* out_dists, - value_idx* coo_rows_a, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size) - { - auto n_blocks_per_row = raft::ceildiv(this->config.a_nnz, chunk_size * 1024); - auto n_blocks = this->config.b_nrows * n_blocks_per_row; - - mask_row_it b_indptr(this->config.b_indptr, this->config.b_nrows); - - this->_dispatch_base_rev(*this, - this->config.a_ncols, - b_indptr, - out_dists, - coo_rows_a, - product_func, - accum_func, - write_func, - chunk_size, - n_blocks, - n_blocks_per_row); - } - - __device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size) - { - for (int k = threadIdx.x; k < cache_size; k += blockDim.x) { - cache[k] = 0.0; - } - return cache; - } - - __device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value) - { - cache[key] = value; - } - - __device__ inline find_type init_find(smem_type cache, const value_idx& cache_size) - { - return cache; - } - - __device__ inline value_t find(find_type cache, const value_idx& key) { return cache[key]; } -}; - -} // namespace detail -} // namespace distance -} // namespace sparse -} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh deleted file mode 100644 index 8c267c5e63..0000000000 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "base_strategy.cuh" - -#include -#include - -#include -#include -#include - -// this is needed by cuco as key, value must be bitwise comparable. -// compilers don't declare float/double as bitwise comparable -// but that is too strict -// for example, the following is true (or 0): -// float a = 5; -// float b = 5; -// memcmp(&a, &b, sizeof(float)); -CUCO_DECLARE_BITWISE_COMPARABLE(float); -CUCO_DECLARE_BITWISE_COMPARABLE(double); - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -class hash_strategy : public coo_spmv_strategy { - public: - using insert_type = typename cuco::legacy:: - static_map::device_mutable_view; - using smem_type = typename insert_type::slot_type*; - using find_type = - typename cuco::legacy::static_map::device_view; - - hash_strategy(const distances_config_t& config_, - float capacity_threshold_ = 0.5, - int map_size_ = get_map_size()) - : coo_spmv_strategy(config_), - capacity_threshold(capacity_threshold_), - map_size(map_size_) - { - } - - void chunking_needed(const value_idx* indptr, - const value_idx n_rows, - rmm::device_uvector& mask_indptr, - std::tuple& n_rows_divided, - cudaStream_t stream) - { - auto policy = resource::get_thrust_policy(this->config.handle); - - auto less = thrust::copy_if(policy, - thrust::make_counting_iterator(value_idx(0)), - thrust::make_counting_iterator(n_rows), - mask_indptr.data(), - fits_in_hash_table(indptr, 0, capacity_threshold * map_size)); - std::get<0>(n_rows_divided) = less - mask_indptr.data(); - - auto more = thrust::copy_if( - policy, - thrust::make_counting_iterator(value_idx(0)), - thrust::make_counting_iterator(n_rows), - less, - fits_in_hash_table( - indptr, capacity_threshold * map_size, std::numeric_limits::max())); - std::get<1>(n_rows_divided) = more - less; - } - - template - void dispatch(value_t* out_dists, - value_idx* coo_rows_b, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size) - { - auto n_blocks_per_row = raft::ceildiv(this->config.b_nnz, chunk_size * tpb); - rmm::device_uvector mask_indptr(this->config.a_nrows, - resource::get_cuda_stream(this->config.handle)); - std::tuple n_rows_divided; - - chunking_needed(this->config.a_indptr, - this->config.a_nrows, - mask_indptr, - n_rows_divided, - resource::get_cuda_stream(this->config.handle)); - - auto less_rows = std::get<0>(n_rows_divided); - if (less_rows > 0) { - mask_row_it less(this->config.a_indptr, less_rows, mask_indptr.data()); - - auto n_less_blocks = less_rows * n_blocks_per_row; - this->_dispatch_base(*this, - map_size, - less, - out_dists, - coo_rows_b, - product_func, - accum_func, - write_func, - chunk_size, - n_less_blocks, - n_blocks_per_row); - } - - auto more_rows = std::get<1>(n_rows_divided); - if (more_rows > 0) { - rmm::device_uvector n_chunks_per_row( - more_rows + 1, resource::get_cuda_stream(this->config.handle)); - rmm::device_uvector chunk_indices(0, - resource::get_cuda_stream(this->config.handle)); - chunked_mask_row_it::init(this->config.a_indptr, - mask_indptr.data() + less_rows, - more_rows, - capacity_threshold * map_size, - n_chunks_per_row, - chunk_indices, - resource::get_cuda_stream(this->config.handle)); - - chunked_mask_row_it more(this->config.a_indptr, - more_rows, - mask_indptr.data() + less_rows, - capacity_threshold * map_size, - n_chunks_per_row.data(), - chunk_indices.data(), - resource::get_cuda_stream(this->config.handle)); - - auto n_more_blocks = more.total_row_blocks * n_blocks_per_row; - this->_dispatch_base(*this, - map_size, - more, - out_dists, - coo_rows_b, - product_func, - accum_func, - write_func, - chunk_size, - n_more_blocks, - n_blocks_per_row); - } - } - - template - void dispatch_rev(value_t* out_dists, - value_idx* coo_rows_a, - product_f product_func, - accum_f accum_func, - write_f write_func, - int chunk_size) - { - auto n_blocks_per_row = raft::ceildiv(this->config.a_nnz, chunk_size * tpb); - rmm::device_uvector mask_indptr(this->config.b_nrows, - resource::get_cuda_stream(this->config.handle)); - std::tuple n_rows_divided; - - chunking_needed(this->config.b_indptr, - this->config.b_nrows, - mask_indptr, - n_rows_divided, - resource::get_cuda_stream(this->config.handle)); - - auto less_rows = std::get<0>(n_rows_divided); - if (less_rows > 0) { - mask_row_it less(this->config.b_indptr, less_rows, mask_indptr.data()); - - auto n_less_blocks = less_rows * n_blocks_per_row; - this->_dispatch_base_rev(*this, - map_size, - less, - out_dists, - coo_rows_a, - product_func, - accum_func, - write_func, - chunk_size, - n_less_blocks, - n_blocks_per_row); - } - - auto more_rows = std::get<1>(n_rows_divided); - if (more_rows > 0) { - rmm::device_uvector n_chunks_per_row( - more_rows + 1, resource::get_cuda_stream(this->config.handle)); - rmm::device_uvector chunk_indices(0, - resource::get_cuda_stream(this->config.handle)); - chunked_mask_row_it::init(this->config.b_indptr, - mask_indptr.data() + less_rows, - more_rows, - capacity_threshold * map_size, - n_chunks_per_row, - chunk_indices, - resource::get_cuda_stream(this->config.handle)); - - chunked_mask_row_it more(this->config.b_indptr, - more_rows, - mask_indptr.data() + less_rows, - capacity_threshold * map_size, - n_chunks_per_row.data(), - chunk_indices.data(), - resource::get_cuda_stream(this->config.handle)); - - auto n_more_blocks = more.total_row_blocks * n_blocks_per_row; - this->_dispatch_base_rev(*this, - map_size, - more, - out_dists, - coo_rows_a, - product_func, - accum_func, - write_func, - chunk_size, - n_more_blocks, - n_blocks_per_row); - } - } - - __device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size) - { - return insert_type::make_from_uninitialized_slots(cooperative_groups::this_thread_block(), - cache, - cache_size, - cuco::empty_key{value_idx{-1}}, - cuco::empty_value{value_t{0}}); - } - - __device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value) - { - auto success = cache.insert(cuco::pair(key, value)); - } - - __device__ inline find_type init_find(smem_type cache, const value_idx& cache_size) - { - return find_type( - cache, cache_size, cuco::empty_key{value_idx{-1}}, cuco::empty_value{value_t{0}}); - } - - __device__ inline value_t find(find_type cache, const value_idx& key) - { - auto a_pair = cache.find(key); - - value_t a_col = 0.0; - if (a_pair != cache.end()) { a_col = a_pair->second; } - return a_col; - } - - struct fits_in_hash_table { - public: - fits_in_hash_table(const value_idx* indptr_, value_idx degree_l_, value_idx degree_r_) - : indptr(indptr_), degree_l(degree_l_), degree_r(degree_r_) - { - } - - __host__ __device__ bool operator()(const value_idx& i) - { - auto degree = indptr[i + 1] - indptr[i]; - - return degree >= degree_l && degree < degree_r; - } - - private: - const value_idx* indptr; - const value_idx degree_l, degree_r; - }; - - inline static int get_map_size() - { - return (raft::getSharedMemPerBlock() - ((tpb / raft::warp_size()) * sizeof(value_t))) / - sizeof(typename insert_type::slot_type); - } - - private: - float capacity_threshold; - int map_size; -}; - -} // namespace detail -} // namespace distance -} // namespace sparse -} // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh deleted file mode 100644 index 84229120ea..0000000000 --- a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "common.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -class ip_distances_t : public distances_t { - public: - /** - * Computes simple sparse inner product distances as sum(x_y * y_k) - * @param[in] config specifies inputs, outputs, and sizes - */ - ip_distances_t(const distances_config_t& config) - : config_(&config), coo_rows_b(config.b_nnz, resource::get_cuda_stream(config.handle)) - { - raft::sparse::convert::csr_to_coo(config_->b_indptr, - config_->b_nrows, - coo_rows_b.data(), - config_->b_nnz, - resource::get_cuda_stream(config_->handle)); - } - - /** - * Performs pairwise distance computation and computes output distances - * @param out_distances dense output matrix (size a_nrows * b_nrows) - */ - void compute(value_t* out_distances) - { - /** - * Compute pairwise distances and return dense matrix in row-major format - */ - balanced_coo_pairwise_generalized_spmv(out_distances, - *config_, - coo_rows_b.data(), - raft::mul_op(), - raft::add_op(), - raft::atomic_add_op()); - } - - value_idx* b_rows_coo() { return coo_rows_b.data(); } - - value_t* b_data_coo() { return config_->b_data; } - - private: - const distances_config_t* config_; - rmm::device_uvector coo_rows_b; -}; - -}; // END namespace detail -}; // END namespace distance -}; // END namespace sparse -}; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh deleted file mode 100644 index 93d1f048ec..0000000000 --- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh +++ /dev/null @@ -1,499 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "common.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include - -#include -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -// @TODO: Move this into sparse prims (coo_norm) -template -RAFT_KERNEL compute_row_norm_kernel(value_t* out, - const value_idx* __restrict__ coo_rows, - const value_t* __restrict__ data, - value_idx nnz) -{ - value_idx i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < nnz) { atomicAdd(&out[coo_rows[i]], data[i] * data[i]); } -} - -template -RAFT_KERNEL compute_row_sum_kernel(value_t* out, - const value_idx* __restrict__ coo_rows, - const value_t* __restrict__ data, - value_idx nnz) -{ - value_idx i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < nnz) { atomicAdd(&out[coo_rows[i]], data[i]); } -} - -template -RAFT_KERNEL compute_euclidean_warp_kernel(value_t* __restrict__ C, - const value_t* __restrict__ Q_sq_norms, - const value_t* __restrict__ R_sq_norms, - value_idx n_rows, - value_idx n_cols, - expansion_f expansion_func) -{ - std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; - value_idx i = tid / n_cols; - value_idx j = tid % n_cols; - - if (i >= n_rows || j >= n_cols) return; - - value_t dot = C[(size_t)i * n_cols + j]; - - // e.g. Euclidean expansion func = -2.0 * dot + q_norm + r_norm - value_t val = expansion_func(dot, Q_sq_norms[i], R_sq_norms[j]); - - // correct for small instabilities - C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001); -} - -template -RAFT_KERNEL compute_correlation_warp_kernel(value_t* __restrict__ C, - const value_t* __restrict__ Q_sq_norms, - const value_t* __restrict__ R_sq_norms, - const value_t* __restrict__ Q_norms, - const value_t* __restrict__ R_norms, - value_idx n_rows, - value_idx n_cols, - value_idx n) -{ - std::size_t tid = blockDim.x * blockIdx.x + threadIdx.x; - value_idx i = tid / n_cols; - value_idx j = tid % n_cols; - - if (i >= n_rows || j >= n_cols) return; - - value_t dot = C[(size_t)i * n_cols + j]; - value_t Q_l1 = Q_norms[i]; - value_t R_l1 = R_norms[j]; - - value_t Q_l2 = Q_sq_norms[i]; - value_t R_l2 = R_sq_norms[j]; - - value_t numer = n * dot - (Q_l1 * R_l1); - value_t Q_denom = n * Q_l2 - (Q_l1 * Q_l1); - value_t R_denom = n * R_l2 - (R_l1 * R_l1); - - value_t val = 1 - (numer / raft::sqrt(Q_denom * R_denom)); - - // correct for small instabilities - C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001); -} - -template -void compute_euclidean(value_t* C, - const value_t* Q_sq_norms, - const value_t* R_sq_norms, - value_idx n_rows, - value_idx n_cols, - cudaStream_t stream, - expansion_f expansion_func) -{ - int blocks = raft::ceildiv((size_t)n_rows * n_cols, tpb); - compute_euclidean_warp_kernel<<>>( - C, Q_sq_norms, R_sq_norms, n_rows, n_cols, expansion_func); -} - -template -void compute_l2(value_t* out, - const value_idx* Q_coo_rows, - const value_t* Q_data, - value_idx Q_nnz, - const value_idx* R_coo_rows, - const value_t* R_data, - value_idx R_nnz, - value_idx m, - value_idx n, - cudaStream_t stream, - expansion_f expansion_func) -{ - rmm::device_uvector Q_sq_norms(m, stream); - rmm::device_uvector R_sq_norms(n, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Q_sq_norms.data(), 0, Q_sq_norms.size() * sizeof(value_t))); - RAFT_CUDA_TRY(cudaMemsetAsync(R_sq_norms.data(), 0, R_sq_norms.size() * sizeof(value_t))); - - compute_row_norm_kernel<<>>( - Q_sq_norms.data(), Q_coo_rows, Q_data, Q_nnz); - compute_row_norm_kernel<<>>( - R_sq_norms.data(), R_coo_rows, R_data, R_nnz); - - compute_euclidean(out, Q_sq_norms.data(), R_sq_norms.data(), m, n, stream, expansion_func); -} - -template -void compute_correlation(value_t* C, - const value_t* Q_sq_norms, - const value_t* R_sq_norms, - const value_t* Q_norms, - const value_t* R_norms, - value_idx n_rows, - value_idx n_cols, - value_idx n, - cudaStream_t stream) -{ - int blocks = raft::ceildiv((size_t)n_rows * n_cols, tpb); - compute_correlation_warp_kernel<<>>( - C, Q_sq_norms, R_sq_norms, Q_norms, R_norms, n_rows, n_cols, n); -} - -template -void compute_corr(value_t* out, - const value_idx* Q_coo_rows, - const value_t* Q_data, - value_idx Q_nnz, - const value_idx* R_coo_rows, - const value_t* R_data, - value_idx R_nnz, - value_idx m, - value_idx n, - value_idx n_cols, - cudaStream_t stream) -{ - // sum_sq for std dev - rmm::device_uvector Q_sq_norms(m, stream); - rmm::device_uvector R_sq_norms(n, stream); - - // sum for mean - rmm::device_uvector Q_norms(m, stream); - rmm::device_uvector R_norms(n, stream); - - RAFT_CUDA_TRY(cudaMemsetAsync(Q_sq_norms.data(), 0, Q_sq_norms.size() * sizeof(value_t))); - RAFT_CUDA_TRY(cudaMemsetAsync(R_sq_norms.data(), 0, R_sq_norms.size() * sizeof(value_t))); - - RAFT_CUDA_TRY(cudaMemsetAsync(Q_norms.data(), 0, Q_norms.size() * sizeof(value_t))); - RAFT_CUDA_TRY(cudaMemsetAsync(R_norms.data(), 0, R_norms.size() * sizeof(value_t))); - - compute_row_norm_kernel<<>>( - Q_sq_norms.data(), Q_coo_rows, Q_data, Q_nnz); - compute_row_norm_kernel<<>>( - R_sq_norms.data(), R_coo_rows, R_data, R_nnz); - - compute_row_sum_kernel<<>>( - Q_norms.data(), Q_coo_rows, Q_data, Q_nnz); - compute_row_sum_kernel<<>>( - R_norms.data(), R_coo_rows, R_data, R_nnz); - - compute_correlation(out, - Q_sq_norms.data(), - R_sq_norms.data(), - Q_norms.data(), - R_norms.data(), - m, - n, - n_cols, - stream); -} - -/** - * L2 distance using the expanded form: sum(x_k)^2 + sum(y_k)^2 - 2 * sum(x_k * y_k) - * The expanded form is more efficient for sparse data. - */ -template -class l2_expanded_distances_t : public distances_t { - public: - explicit l2_expanded_distances_t(const distances_config_t& config) - : config_(&config), ip_dists(config) - { - } - - void compute(value_t* out_dists) - { - ip_dists.compute(out_dists); - - value_idx* b_indices = ip_dists.b_rows_coo(); - value_t* b_data = ip_dists.b_data_coo(); - - rmm::device_uvector search_coo_rows(config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - raft::sparse::convert::csr_to_coo(config_->a_indptr, - config_->a_nrows, - search_coo_rows.data(), - config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - - compute_l2(out_dists, - search_coo_rows.data(), - config_->a_data, - config_->a_nnz, - b_indices, - b_data, - config_->b_nnz, - config_->a_nrows, - config_->b_nrows, - resource::get_cuda_stream(config_->handle), - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { - return -2 * dot + q_norm + r_norm; - }); - } - - ~l2_expanded_distances_t() = default; - - protected: - const distances_config_t* config_; - ip_distances_t ip_dists; -}; - -/** - * L2 sqrt distance performing the sqrt operation after the distance computation - * The expanded form is more efficient for sparse data. - */ -template -class l2_sqrt_expanded_distances_t : public l2_expanded_distances_t { - public: - explicit l2_sqrt_expanded_distances_t(const distances_config_t& config) - : l2_expanded_distances_t(config) - { - } - - void compute(value_t* out_dists) override - { - l2_expanded_distances_t::compute(out_dists); - // Sqrt Post-processing - raft::linalg::unaryOp( - out_dists, - out_dists, - this->config_->a_nrows * this->config_->b_nrows, - [] __device__(value_t input) { - int neg = input < 0 ? -1 : 1; - return raft::sqrt(abs(input) * neg); - }, - resource::get_cuda_stream(this->config_->handle)); - } - - ~l2_sqrt_expanded_distances_t() = default; -}; - -template -class correlation_expanded_distances_t : public distances_t { - public: - explicit correlation_expanded_distances_t(const distances_config_t& config) - : config_(&config), ip_dists(config) - { - } - - void compute(value_t* out_dists) - { - ip_dists.compute(out_dists); - - value_idx* b_indices = ip_dists.b_rows_coo(); - value_t* b_data = ip_dists.b_data_coo(); - - rmm::device_uvector search_coo_rows(config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - raft::sparse::convert::csr_to_coo(config_->a_indptr, - config_->a_nrows, - search_coo_rows.data(), - config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - - compute_corr(out_dists, - search_coo_rows.data(), - config_->a_data, - config_->a_nnz, - b_indices, - b_data, - config_->b_nnz, - config_->a_nrows, - config_->b_nrows, - config_->b_ncols, - resource::get_cuda_stream(config_->handle)); - } - - ~correlation_expanded_distances_t() = default; - - protected: - const distances_config_t* config_; - ip_distances_t ip_dists; -}; - -/** - * Cosine distance using the expanded form: 1 - ( sum(x_k * y_k) / (sqrt(sum(x_k)^2) * - * sqrt(sum(y_k)^2))) The expanded form is more efficient for sparse data. - */ -template -class cosine_expanded_distances_t : public distances_t { - public: - explicit cosine_expanded_distances_t(const distances_config_t& config) - : config_(&config), workspace(0, resource::get_cuda_stream(config.handle)), ip_dists(config) - { - } - - void compute(value_t* out_dists) - { - ip_dists.compute(out_dists); - - value_idx* b_indices = ip_dists.b_rows_coo(); - value_t* b_data = ip_dists.b_data_coo(); - - rmm::device_uvector search_coo_rows(config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - raft::sparse::convert::csr_to_coo(config_->a_indptr, - config_->a_nrows, - search_coo_rows.data(), - config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - - compute_l2(out_dists, - search_coo_rows.data(), - config_->a_data, - config_->a_nnz, - b_indices, - b_data, - config_->b_nnz, - config_->a_nrows, - config_->b_nrows, - resource::get_cuda_stream(config_->handle), - [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { - value_t norms = raft::sqrt(q_norm) * raft::sqrt(r_norm); - // deal with potential for 0 in denominator by forcing 0/1 instead - value_t cos = ((norms != 0) * dot) / ((norms == 0) + norms); - - // flip the similarity when both rows are 0 - bool both_empty = (q_norm == 0) && (r_norm == 0); - return 1 - ((!both_empty * cos) + both_empty); - }); - } - - ~cosine_expanded_distances_t() = default; - - private: - const distances_config_t* config_; - rmm::device_uvector workspace; - ip_distances_t ip_dists; -}; - -/** - * Hellinger distance using the expanded form: sqrt(1 - sum(sqrt(x_k) * sqrt(y_k))) - * The expanded form is more efficient for sparse data. - * - * This distance computation modifies A and B by computing a sqrt - * and then performing a `pow(x, 2)` to convert it back. Because of this, - * it is possible that the values in A and B might differ slightly - * after this is invoked. - */ -template -class hellinger_expanded_distances_t : public distances_t { - public: - explicit hellinger_expanded_distances_t(const distances_config_t& config) - : config_(&config), workspace(0, resource::get_cuda_stream(config.handle)) - { - } - - void compute(value_t* out_dists) - { - rmm::device_uvector coo_rows(std::max(config_->b_nnz, config_->a_nnz), - resource::get_cuda_stream(config_->handle)); - - raft::sparse::convert::csr_to_coo(config_->b_indptr, - config_->b_nrows, - coo_rows.data(), - config_->b_nnz, - resource::get_cuda_stream(config_->handle)); - - balanced_coo_pairwise_generalized_spmv( - out_dists, - *config_, - coo_rows.data(), - [] __device__(value_t a, value_t b) { return raft::sqrt(a) * raft::sqrt(b); }, - raft::add_op(), - raft::atomic_add_op()); - - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { - // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - bool rectifier = (1 - input) > 0; - return raft::sqrt(rectifier * (1 - input)); - }, - resource::get_cuda_stream(config_->handle)); - } - - ~hellinger_expanded_distances_t() = default; - - private: - const distances_config_t* config_; - rmm::device_uvector workspace; -}; - -template -class russelrao_expanded_distances_t : public distances_t { - public: - explicit russelrao_expanded_distances_t(const distances_config_t& config) - : config_(&config), workspace(0, resource::get_cuda_stream(config.handle)), ip_dists(config) - { - } - - void compute(value_t* out_dists) - { - ip_dists.compute(out_dists); - - value_t n_cols = config_->a_ncols; - value_t n_cols_inv = 1.0 / n_cols; - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return (n_cols - input) * n_cols_inv; }, - resource::get_cuda_stream(config_->handle)); - - auto exec_policy = rmm::exec_policy(resource::get_cuda_stream(config_->handle)); - auto diags = thrust::counting_iterator(0); - value_idx b_nrows = config_->b_nrows; - thrust::for_each(exec_policy, diags, diags + config_->a_nrows, [=] __device__(value_idx input) { - out_dists[input * b_nrows + input] = 0.0; - }); - } - - ~russelrao_expanded_distances_t() = default; - - private: - const distances_config_t* config_; - rmm::device_uvector workspace; - ip_distances_t ip_dists; -}; - -}; // END namespace detail -}; // END namespace distance -}; // END namespace sparse -}; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh deleted file mode 100644 index b178e02c34..0000000000 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ /dev/null @@ -1,333 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "common.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -template -void unexpanded_lp_distances(value_t* out_dists, - const distances_config_t* config_, - product_f product_func, - accum_f accum_func, - write_f write_func) -{ - rmm::device_uvector coo_rows(std::max(config_->b_nnz, config_->a_nnz), - resource::get_cuda_stream(config_->handle)); - - raft::sparse::convert::csr_to_coo(config_->b_indptr, - config_->b_nrows, - coo_rows.data(), - config_->b_nnz, - resource::get_cuda_stream(config_->handle)); - - balanced_coo_pairwise_generalized_spmv( - out_dists, *config_, coo_rows.data(), product_func, accum_func, write_func); - - raft::sparse::convert::csr_to_coo(config_->a_indptr, - config_->a_nrows, - coo_rows.data(), - config_->a_nnz, - resource::get_cuda_stream(config_->handle)); - - balanced_coo_pairwise_generalized_spmv_rev( - out_dists, *config_, coo_rows.data(), product_func, accum_func, write_func); -} - -/** - * Computes L1 distances for sparse input. This does not have - * an equivalent expanded form, so it is only executed in - * an unexpanded form. - * @tparam value_idx - * @tparam value_t - */ -template -class l1_unexpanded_distances_t : public distances_t { - public: - l1_unexpanded_distances_t(const distances_config_t& config) : config_(&config) - { - } - - void compute(value_t* out_dists) - { - unexpanded_lp_distances( - out_dists, config_, raft::absdiff_op(), raft::add_op(), raft::atomic_add_op()); - } - - private: - const distances_config_t* config_; -}; - -template -class l2_unexpanded_distances_t : public distances_t { - public: - l2_unexpanded_distances_t(const distances_config_t& config) : config_(&config) - { - } - - void compute(value_t* out_dists) - { - unexpanded_lp_distances( - out_dists, config_, raft::sqdiff_op(), raft::add_op(), raft::atomic_add_op()); - } - - protected: - const distances_config_t* config_; -}; - -template -class l2_sqrt_unexpanded_distances_t : public l2_unexpanded_distances_t { - public: - l2_sqrt_unexpanded_distances_t(const distances_config_t& config) - : l2_unexpanded_distances_t(config) - { - } - - void compute(value_t* out_dists) - { - l2_unexpanded_distances_t::compute(out_dists); - - uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; - // Sqrt Post-processing - raft::linalg::unaryOp( - out_dists, - out_dists, - n, - [] __device__(value_t input) { - int neg = input < 0 ? -1 : 1; - return raft::sqrt(abs(input) * neg); - }, - resource::get_cuda_stream(this->config_->handle)); - } -}; - -template -class linf_unexpanded_distances_t : public distances_t { - public: - explicit linf_unexpanded_distances_t(const distances_config_t& config) - : config_(&config) - { - } - - void compute(value_t* out_dists) - { - unexpanded_lp_distances( - out_dists, config_, raft::absdiff_op(), raft::max_op(), raft::atomic_max_op()); - } - - private: - const distances_config_t* config_; -}; - -template -class canberra_unexpanded_distances_t : public distances_t { - public: - explicit canberra_unexpanded_distances_t(const distances_config_t& config) - : config_(&config) - { - } - - void compute(value_t* out_dists) - { - unexpanded_lp_distances( - out_dists, - config_, - [] __device__(value_t a, value_t b) { - value_t d = fabs(a) + fabs(b); - - // deal with potential for 0 in denominator by - // forcing 1/0 instead - return ((d != 0) * fabs(a - b)) / (d + (d == 0)); - }, - raft::add_op(), - raft::atomic_add_op()); - } - - private: - const distances_config_t* config_; -}; - -template -class lp_unexpanded_distances_t : public distances_t { - public: - explicit lp_unexpanded_distances_t(const distances_config_t& config, - value_t p_) - : config_(&config), p(p_) - { - } - - void compute(value_t* out_dists) - { - unexpanded_lp_distances( - out_dists, - config_, - raft::compose_op(raft::pow_const_op(p), raft::sub_op()), - raft::add_op(), - raft::atomic_add_op()); - - uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; - value_t one_over_p = value_t{1} / p; - raft::linalg::unaryOp(out_dists, - out_dists, - n, - raft::pow_const_op(one_over_p), - resource::get_cuda_stream(config_->handle)); - } - - private: - const distances_config_t* config_; - value_t p; -}; - -template -class hamming_unexpanded_distances_t : public distances_t { - public: - explicit hamming_unexpanded_distances_t(const distances_config_t& config) - : config_(&config) - { - } - - void compute(value_t* out_dists) - { - unexpanded_lp_distances( - out_dists, config_, raft::notequal_op(), raft::add_op(), raft::atomic_add_op()); - - uint64_t n = (uint64_t)config_->a_nrows * (uint64_t)config_->b_nrows; - value_t n_cols = 1.0 / config_->a_ncols; - raft::linalg::unaryOp(out_dists, - out_dists, - n, - raft::mul_const_op(n_cols), - resource::get_cuda_stream(config_->handle)); - } - - private: - const distances_config_t* config_; -}; - -template -class jensen_shannon_unexpanded_distances_t : public distances_t { - public: - explicit jensen_shannon_unexpanded_distances_t( - const distances_config_t& config) - : config_(&config) - { - } - - void compute(value_t* out_dists) - { - unexpanded_lp_distances( - out_dists, - config_, - [] __device__(value_t a, value_t b) { - value_t m = 0.5f * (a + b); - bool a_zero = a == 0; - bool b_zero = b == 0; - - value_t x = (!a_zero * m) / (a_zero + a); - value_t y = (!b_zero * m) / (b_zero + b); - - bool x_zero = x == 0; - bool y_zero = y == 0; - - return (-a * (!x_zero * log(x + x_zero))) + (-b * (!y_zero * log(y + y_zero))); - }, - raft::add_op(), - raft::atomic_add_op()); - - uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; - raft::linalg::unaryOp( - out_dists, - out_dists, - n, - [=] __device__(value_t input) { return raft::sqrt(0.5 * input); }, - resource::get_cuda_stream(config_->handle)); - } - - private: - const distances_config_t* config_; -}; - -template -class kl_divergence_unexpanded_distances_t : public distances_t { - public: - explicit kl_divergence_unexpanded_distances_t( - const distances_config_t& config) - : config_(&config) - { - } - - void compute(value_t* out_dists) - { - rmm::device_uvector coo_rows(std::max(config_->b_nnz, config_->a_nnz), - resource::get_cuda_stream(config_->handle)); - - raft::sparse::convert::csr_to_coo(config_->b_indptr, - config_->b_nrows, - coo_rows.data(), - config_->b_nnz, - resource::get_cuda_stream(config_->handle)); - - balanced_coo_pairwise_generalized_spmv( - out_dists, - *config_, - coo_rows.data(), - [] __device__(value_t a, value_t b) { return a * log(a / b); }, - raft::add_op(), - raft::atomic_add_op()); - - uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; - raft::linalg::unaryOp(out_dists, - out_dists, - n, - raft::mul_const_op(0.5), - resource::get_cuda_stream(config_->handle)); - } - - private: - const distances_config_t* config_; -}; - -}; // END namespace detail -}; // END namespace distance -}; // END namespace sparse -}; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh deleted file mode 100644 index 864d61ba2f..0000000000 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include - -namespace raft { -namespace sparse { -namespace distance { -namespace detail { - -/** - * Computes the maximum number of columns that can be stored - * in shared memory in dense form with the given block size - * and precision. - * @return the maximum number of columns that can be stored in smem - */ -template -inline int max_cols_per_block() -{ - // max cols = (total smem available - cub reduction smem) - return (raft::getSharedMemPerBlock() - ((tpb / raft::warp_size()) * sizeof(value_t))) / - sizeof(value_t); -} - -template -RAFT_KERNEL faster_dot_on_csr_kernel(dot_t* __restrict__ dot, - const value_idx* __restrict__ indptr, - const value_idx* __restrict__ cols, - const value_t* __restrict__ A, - const value_t* __restrict__ B, - const value_idx nnz, - const value_idx n_rows, - const value_idx dim) -{ - auto vec_id = threadIdx.x; - auto lane_id = threadIdx.x & 0x1f; - - extern __shared__ char smem[]; - value_t* s_A = (value_t*)smem; - value_idx cur_row = -1; - - for (int row = blockIdx.x; row < n_rows; row += gridDim.x) { - for (int dot_id = blockIdx.y + indptr[row]; dot_id < indptr[row + 1]; dot_id += gridDim.y) { - if (dot_id >= nnz) { return; } - const value_idx col = cols[dot_id] * dim; - const value_t* __restrict__ B_col = B + col; - - if (threadIdx.x == 0) { dot[dot_id] = 0.0; } - __syncthreads(); - - if (cur_row != row) { - for (value_idx k = vec_id; k < dim; k += blockDim.x) { - s_A[k] = A[row * dim + k]; - } - cur_row = row; - } - - dot_t l_dot_ = 0.0; - for (value_idx k = vec_id; k < dim; k += blockDim.x) { - asm("prefetch.global.L2 [%0];" ::"l"(B_col + k + blockDim.x)); - if constexpr ((std::is_same_v && std::is_same_v)) { - l_dot_ += __half2float(s_A[k]) * __half2float(__ldcg(B_col + k)); - } else { - l_dot_ += s_A[k] * __ldcg(B_col + k); - } - } - - typedef cub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage; - dot_t warp_sum = WarpReduce(temp_storage).Sum(l_dot_); - - if (lane_id == 0) { atomicAdd_block(dot + dot_id, warp_sum); } - } - } -} - -template -void faster_dot_on_csr(raft::resources const& handle, - dot_t* dot, - const value_idx nnz, - const value_idx* indptr, - const value_idx* cols, - const value_t* A, - const value_t* B, - const value_idx n_rows, - const value_idx dim) -{ - if (nnz == 0 || n_rows == 0) return; - - auto stream = resource::get_cuda_stream(handle); - - constexpr value_idx MAX_ROW_PER_ITER = 500; - int dev_id, sm_count, blocks_per_sm; - - const int smem_size = dim * sizeof(value_t); - cudaGetDevice(&dev_id); - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - - if (dim < 128) { - constexpr int tpb = 64; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); - auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); - auto block_y = - (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; - dim3 blocks(block_x, block_y, 1); - - faster_dot_on_csr_kernel - <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); - - } else if (dim < 256) { - constexpr int tpb = 128; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); - auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); - auto block_y = - (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; - dim3 blocks(block_x, block_y, 1); - - faster_dot_on_csr_kernel - <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); - } else if (dim < 512) { - constexpr int tpb = 256; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); - auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); - auto block_y = - (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; - dim3 blocks(block_x, block_y, 1); - - faster_dot_on_csr_kernel - <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); - } else { - constexpr int tpb = 512; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); - auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); - auto block_y = - (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; - dim3 blocks(block_x, block_y, 1); - - faster_dot_on_csr_kernel - <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); - } - - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -} // namespace detail -} // namespace distance -} // namespace sparse -} // namespace raft diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh deleted file mode 100644 index ead44f0c51..0000000000 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __SPARSE_DIST_H -#define __SPARSE_DIST_H - -#pragma once - -#include "detail/common.hpp" - -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace sparse { -namespace distance { - -static const std::unordered_set supportedDistance{ - raft::distance::DistanceType::L2Expanded, - raft::distance::DistanceType::L2Unexpanded, - raft::distance::DistanceType::L2SqrtExpanded, - raft::distance::DistanceType::L2SqrtUnexpanded, - raft::distance::DistanceType::InnerProduct, - raft::distance::DistanceType::L1, - raft::distance::DistanceType::Canberra, - raft::distance::DistanceType::Linf, - raft::distance::DistanceType::LpUnexpanded, - raft::distance::DistanceType::JaccardExpanded, - raft::distance::DistanceType::CosineExpanded, - raft::distance::DistanceType::HellingerExpanded, - raft::distance::DistanceType::DiceExpanded, - raft::distance::DistanceType::CorrelationExpanded, - raft::distance::DistanceType::RusselRaoExpanded, - raft::distance::DistanceType::HammingUnexpanded, - raft::distance::DistanceType::JensenShannon, - raft::distance::DistanceType::KLDivergence}; - -/** - * Compute pairwise distances between A and B, using the provided - * input configuration and distance function. - * - * @tparam value_idx index type - * @tparam value_t value type - * @param[out] out dense output array (size A.nrows * B.nrows) - * @param[in] input_config input argument configuration - * @param[in] metric distance metric to use - * @param[in] metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwiseDistance(value_t* out, - detail::distances_config_t input_config, - raft::distance::DistanceType metric, - float metric_arg) -{ - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - detail::l2_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::L2SqrtExpanded: - detail::l2_sqrt_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::InnerProduct: - detail::ip_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::L2Unexpanded: - detail::l2_unexpanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - detail::l2_sqrt_unexpanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::L1: - detail::l1_unexpanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::LpUnexpanded: - detail::lp_unexpanded_distances_t(input_config, metric_arg).compute(out); - break; - case raft::distance::DistanceType::Linf: - detail::linf_unexpanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::Canberra: - detail::canberra_unexpanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::JaccardExpanded: - detail::jaccard_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::CosineExpanded: - detail::cosine_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::HellingerExpanded: - detail::hellinger_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::DiceExpanded: - detail::dice_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::CorrelationExpanded: - detail::correlation_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::RusselRaoExpanded: - detail::russelrao_expanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::HammingUnexpanded: - detail::hamming_unexpanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::JensenShannon: - detail::jensen_shannon_unexpanded_distances_t(input_config).compute(out); - break; - case raft::distance::DistanceType::KLDivergence: - detail::kl_divergence_unexpanded_distances_t(input_config).compute(out); - break; - - default: THROW("Unsupported distance: %d", metric); - } -} - -/** - * @defgroup sparse_distance Sparse Pairwise Distance - * @{ - */ - -/** - * @brief Compute pairwise distances between x and y, using the provided - * input configuration and distance function. - * - * @code{.cpp} - * #include - * #include - * #include - * - * int x_n_rows = 100000; - * int y_n_rows = 50000; - * int n_cols = 10000; - * - * raft::device_resources handle; - * auto x = raft::make_device_csr_matrix(handle, x_n_rows, n_cols); - * auto y = raft::make_device_csr_matrix(handle, y_n_rows, n_cols); - * - * ... - * // populate data - * ... - * - * auto out = raft::make_device_matrix(handle, x_nrows, y_nrows); - * auto metric = raft::distance::DistanceType::L2Expanded; - * raft::sparse::distance::pairwise_distance(handle, x.view(), y.view(), out, metric); - * @endcode - * - * @tparam DeviceCSRMatrix raft::device_csr_matrix or raft::device_csr_matrix_view - * @tparam ElementType data-type of inputs and output - * @tparam IndexType data-type for indexing - * - * @param[in] handle raft::resources - * @param[in] x raft::device_csr_matrix_view - * @param[in] y raft::device_csr_matrix_view - * @param[out] dist raft::device_matrix_view dense matrix - * @param[in] metric distance metric to use - * @param[in] metric_arg metric argument (used for Minkowski distance) - */ -template >> -void pairwise_distance(raft::resources const& handle, - DeviceCSRMatrix x, - DeviceCSRMatrix y, - raft::device_matrix_view dist, - raft::distance::DistanceType metric, - float metric_arg = 2.0f) -{ - auto x_structure = x.structure_view(); - auto y_structure = y.structure_view(); - - RAFT_EXPECTS(x_structure.get_n_cols() == y_structure.get_n_cols(), - "Number of columns must be equal"); - - RAFT_EXPECTS(dist.extent(0) == x_structure.get_n_rows(), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y_structure.get_n_rows(), - "Number of columns in output must be equal to " - "number of rows in Y"); - - detail::distances_config_t input_config(handle); - input_config.a_nrows = x_structure.get_n_rows(); - input_config.a_ncols = x_structure.get_n_cols(); - input_config.a_nnz = x_structure.get_nnz(); - input_config.a_indptr = const_cast(x_structure.get_indptr().data()); - input_config.a_indices = const_cast(x_structure.get_indices().data()); - input_config.a_data = const_cast(x.get_elements().data()); - - input_config.b_nrows = y_structure.get_n_rows(); - input_config.b_ncols = y_structure.get_n_cols(); - input_config.b_nnz = y_structure.get_nnz(); - input_config.b_indptr = const_cast(y_structure.get_indptr().data()); - input_config.b_indices = const_cast(y_structure.get_indices().data()); - input_config.b_data = const_cast(y.get_elements().data()); - - pairwiseDistance(dist.data_handle(), input_config, metric, metric_arg); -} - -/** @} */ // end of sparse_distance - -}; // namespace distance -}; // namespace sparse -}; // namespace raft - -#endif \ No newline at end of file diff --git a/cpp/include/raft/sparse/hierarchy/common.h b/cpp/include/raft/sparse/hierarchy/common.h deleted file mode 100644 index 6ac0fc3b4b..0000000000 --- a/cpp/include/raft/sparse/hierarchy/common.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. - */ - -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use raft/cluster/single_linkage_types.hpp instead.") -#endif - -#include - -namespace raft::hierarchy { -using raft::cluster::linkage_output; -using raft::cluster::linkage_output_int; -using raft::cluster::linkage_output_int64; -using raft::cluster::LinkageDistance; -} // namespace raft::hierarchy diff --git a/cpp/include/raft/sparse/hierarchy/single_linkage.cuh b/cpp/include/raft/sparse/hierarchy/single_linkage.cuh deleted file mode 100644 index d21b2a87a6..0000000000 --- a/cpp/include/raft/sparse/hierarchy/single_linkage.cuh +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. - */ - -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use the raft/cluster version instead.") -#endif - -#include -#include - -namespace raft::hierarchy { -using raft::cluster::single_linkage; -} diff --git a/cpp/include/raft/sparse/neighbors/brute_force.cuh b/cpp/include/raft/sparse/neighbors/brute_force.cuh deleted file mode 100644 index 47e00a012f..0000000000 --- a/cpp/include/raft/sparse/neighbors/brute_force.cuh +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -namespace raft::sparse::neighbors::brute_force { - -/** - * Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors - * using some distance implementation - * @param[in] idxIndptr csr indptr of the index matrix (size n_idx_rows + 1) - * @param[in] idxIndices csr column indices array of the index matrix (size n_idx_nnz) - * @param[in] idxData csr data array of the index matrix (size idxNNZ) - * @param[in] idxNNZ number of non-zeros for sparse index matrix - * @param[in] n_idx_rows number of data samples in index matrix - * @param[in] n_idx_cols - * @param[in] queryIndptr csr indptr of the query matrix (size n_query_rows + 1) - * @param[in] queryIndices csr indices array of the query matrix (size queryNNZ) - * @param[in] queryData csr data array of the query matrix (size queryNNZ) - * @param[in] queryNNZ number of non-zeros for sparse query matrix - * @param[in] n_query_rows number of data samples in query matrix - * @param[in] n_query_cols number of features in query matrix - * @param[out] output_indices dense matrix for output indices (size n_query_rows * k) - * @param[out] output_dists dense matrix for output distances (size n_query_rows * k) - * @param[in] k the number of neighbors to query - * @param[in] handle CUDA resource::get_cuda_stream(handle) to order operations with respect to - * @param[in] batch_size_index maximum number of rows to use from index matrix per batch - * @param[in] batch_size_query maximum number of rows to use from query matrix per batch - * @param[in] metric distance metric/measure to use - * @param[in] metricArg potential argument for metric (currently unused) - */ -template -void knn(const value_idx* idxIndptr, - const value_idx* idxIndices, - const value_t* idxData, - size_t idxNNZ, - int n_idx_rows, - int n_idx_cols, - const value_idx* queryIndptr, - const value_idx* queryIndices, - const value_t* queryData, - size_t queryNNZ, - int n_query_rows, - int n_query_cols, - value_idx* output_indices, - value_t* output_dists, - int k, - raft::resources const& handle, - size_t batch_size_index = 2 << 14, // approx 1M - size_t batch_size_query = 2 << 14, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0) -{ - detail::sparse_knn_t(idxIndptr, - idxIndices, - idxData, - idxNNZ, - n_idx_rows, - n_idx_cols, - queryIndptr, - queryIndices, - queryData, - queryNNZ, - n_query_rows, - n_query_cols, - output_indices, - output_dists, - k, - handle, - batch_size_index, - batch_size_query, - metric, - metricArg) - .run(); -} - -}; // namespace raft::sparse::neighbors::brute_force diff --git a/cpp/include/raft/sparse/neighbors/cross_component_nn.cuh b/cpp/include/raft/sparse/neighbors/cross_component_nn.cuh deleted file mode 100644 index c94c6254c3..0000000000 --- a/cpp/include/raft/sparse/neighbors/cross_component_nn.cuh +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace raft::sparse::neighbors { - -template -using FixConnectivitiesRedOp = detail::FixConnectivitiesRedOp; - -/** - * Gets the number of unique components from array of - * colors or labels. This does not assume the components are - * drawn from a monotonically increasing set. - * @tparam value_idx - * @param[in] colors array of components - * @param[in] n_rows size of components array - * @param[in] stream cuda stream for which to order cuda operations - * @return total number of components - */ -template -value_idx get_n_components(value_idx* colors, size_t n_rows, cudaStream_t stream) -{ - return detail::get_n_components(colors, n_rows, stream); -} - -/** - * Connects the components of an otherwise unconnected knn graph - * by computing a 1-nn to neighboring components of each data point - * (e.g. component(nn) != component(self)) and reducing the results to - * include the set of smallest destination components for each source - * component. The result will not necessarily contain - * n_components^2 - n_components number of elements because many components - * will likely not be contained in the neighborhoods of 1-nns. - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[out] out output edge list containing nearest cross-component - * edges. - * @param[in] X original (row-major) dense matrix for which knn graph should be constructed. - * @param[in] orig_colors array containing component number for each row of X - * @param[in] n_rows number of rows in X - * @param[in] n_cols number of cols in X - * @param[in] reduction_op reduction operation for computing nearest neighbors. The reduction - * operation must have `gather` and `scatter` functions defined - * @param[in] row_batch_size the batch size for computing nearest neighbors. This parameter controls - * the number of samples for which the nearest neighbors are computed at once. Therefore, it affects - * the memory consumption mainly by reducing the size of the adjacency matrix for masked nearest - * neighbors computation - * @param[in] col_batch_size the input data is sorted and 'unsorted' based on color. An additional - * scratch space buffer of shape (n_rows, col_batch_size) is created for this. Usually, this - * parameter affects the memory consumption more drastically than the row_batch_size with a marginal - * increase in compute time as the col_batch_size is reduced - * @param[in] metric distance metric - */ -template -void cross_component_nn( - raft::resources const& handle, - raft::sparse::COO& out, - const value_t* X, - const value_idx* orig_colors, - size_t n_rows, - size_t n_cols, - red_op reduction_op, - size_t row_batch_size = 0, - size_t col_batch_size = 0, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) -{ - detail::cross_component_nn(handle, - out, - X, - orig_colors, - n_rows, - n_cols, - reduction_op, - row_batch_size, - col_batch_size, - metric); -} - -}; // end namespace raft::sparse::neighbors \ No newline at end of file diff --git a/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh b/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh deleted file mode 100644 index a47d5a6f34..0000000000 --- a/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh +++ /dev/null @@ -1,541 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace raft::sparse::neighbors::detail { - -/** - * Base functor with reduction ops for performing masked 1-nn - * computation. - * @tparam value_idx - * @tparam value_t - */ -template -struct FixConnectivitiesRedOp { - value_idx m; - - // default constructor for cutlass - DI FixConnectivitiesRedOp() : m(0) {} - - FixConnectivitiesRedOp(value_idx m_) : m(m_){}; - - typedef typename raft::KeyValuePair KVP; - DI void operator()(value_idx rit, KVP* out, const KVP& other) const - { - if (rit < m && other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - - DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const - { - if (rit < m && a.value < b.value) { - return a; - } else - return b; - } - - DI void init(value_t* out, value_t maxVal) const { *out = maxVal; } - DI void init(KVP* out, value_t maxVal) const - { - out->key = -1; - out->value = maxVal; - } - - DI void init_key(value_t& out, value_idx idx) const { return; } - DI void init_key(KVP& out, value_idx idx) const { out.key = idx; } - - DI value_t get_value(KVP& out) const { return out.value; } - - DI value_t get_value(value_t& out) const { return out; } - - /** The gather and scatter ensure that operator() is still consistent after rearranging the data. - * TODO (tarang-jain): refactor cross_component_nn API to separate out the gather and scatter - * functions from the reduction op. Reference: https://github.com/rapidsai/raft/issues/1614 */ - void gather(const raft::resources& handle, value_idx* map) {} - - void scatter(const raft::resources& handle, value_idx* map) {} -}; - -/** - * Assumes 3-iterator tuple containing COO rows, cols, and - * a cub keyvalue pair object. Sorts the 3 arrays in - * ascending order: row->col->keyvaluepair - */ -struct TupleComp { - template - __host__ __device__ bool operator()(const one& t1, const two& t2) - { - // sort first by each sample's color, - if (thrust::get<0>(t1) < thrust::get<0>(t2)) return true; - if (thrust::get<0>(t1) > thrust::get<0>(t2)) return false; - - // then by the color of each sample's closest neighbor, - if (thrust::get<1>(t1) < thrust::get<1>(t2)) return true; - if (thrust::get<1>(t1) > thrust::get<1>(t2)) return false; - - // then sort by value in descending order - return thrust::get<2>(t1).value < thrust::get<2>(t2).value; - } -}; - -template -struct CubKVPMinReduce { - typedef raft::KeyValuePair KVP; - - DI KVP - - operator()(LabelT rit, const KVP& a, const KVP& b) - { - return b.value < a.value ? b : a; - } - - DI KVP - - operator()(const KVP& a, const KVP& b) - { - return b.value < a.value ? b : a; - } - -}; // KVPMinReduce - -/** - * Gets the number of unique components from array of - * colors or labels. This does not assume the components are - * drawn from a monotonically increasing set. - * @tparam value_idx - * @param[in] colors array of components - * @param[in] n_rows size of components array - * @param[in] stream cuda stream for which to order cuda operations - * @return total number of components - */ -template -value_idx get_n_components(value_idx* colors, size_t n_rows, cudaStream_t stream) -{ - rmm::device_uvector map_ids(0, stream); - int num_clusters = raft::label::getUniquelabels(map_ids, colors, n_rows, stream); - return num_clusters; -} - -/** - * Functor to look up a component for a vertex - * @tparam value_idx - * @tparam value_t - */ -template -struct LookupColorOp { - value_idx* colors; - - LookupColorOp(value_idx* colors_) : colors(colors_) {} - - DI value_idx - - operator()(const raft::KeyValuePair& kvp) - { - return colors[kvp.key]; - } -}; - -/** - * Compute the cross-component 1-nearest neighbors for each row in X using - * the given array of components - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[out] kvp mapping of closest neighbor vertex and distance for each vertex in the given - * array of components - * @param[out] nn_colors components of nearest neighbors for each vertex - * @param[in] colors components of each vertex - * @param[in] X original dense data - * @param[in] n_rows number of rows in original dense data - * @param[in] n_cols number of columns in original dense data - * @param[in] row_batch_size row batch size for computing nearest neighbors - * @param[in] col_batch_size column batch size for sorting and 'unsorting' - * @param[in] reduction_op reduction operation for computing nearest neighbors - */ -template -void perform_1nn(raft::resources const& handle, - raft::KeyValuePair* kvp, - value_idx* nn_colors, - value_idx* colors, - const value_t* X, - size_t n_rows, - size_t n_cols, - size_t row_batch_size, - size_t col_batch_size, - red_op reduction_op) -{ - auto stream = resource::get_cuda_stream(handle); - auto exec_policy = resource::get_thrust_policy(handle); - - auto sort_plan = raft::make_device_vector(handle, (value_idx)n_rows); - raft::linalg::map_offset(handle, sort_plan.view(), [] __device__(value_idx idx) { return idx; }); - - thrust::sort_by_key( - resource::get_thrust_policy(handle), colors, colors + n_rows, sort_plan.data_handle()); - - // Modify the reduction operation based on the sort plan. - reduction_op.gather(handle, sort_plan.data_handle()); - - auto X_mutable_view = - raft::make_device_matrix_view(const_cast(X), n_rows, n_cols); - auto sort_plan_const_view = - raft::make_device_vector_view(sort_plan.data_handle(), n_rows); - raft::matrix::gather(handle, X_mutable_view, sort_plan_const_view, (value_idx)col_batch_size); - - // Get the number of unique components from the array of colors - value_idx n_components = get_n_components(colors, n_rows, stream); - - // colors_group_idxs is an array containing the *end* indices of each color - // component in colors. That is, the value of colors_group_idxs[j] indicates - // the start of color j + 1, i.e., it is the inclusive scan of the sizes of - // the color components. - auto colors_group_idxs = raft::make_device_vector(handle, n_components + 1); - raft::sparse::convert::sorted_coo_to_csr( - colors, n_rows, colors_group_idxs.data_handle(), n_components + 1, stream); - - auto group_idxs_view = raft::make_device_vector_view( - colors_group_idxs.data_handle() + 1, n_components); - - auto x_norm = raft::make_device_vector(handle, (value_idx)n_rows); - raft::linalg::rowNorm( - x_norm.data_handle(), X, n_cols, n_rows, raft::linalg::L2Norm, true, stream); - - auto adj = raft::make_device_matrix(handle, row_batch_size, n_components); - using OutT = raft::KeyValuePair; - using ParamT = raft::distance::masked_l2_nn_params; - - bool apply_sqrt = true; - bool init_out_buffer = true; - ParamT params{reduction_op, reduction_op, apply_sqrt, init_out_buffer}; - - auto X_full_view = raft::make_device_matrix_view(X, n_rows, n_cols); - - size_t n_batches = raft::ceildiv(n_rows, row_batch_size); - - for (size_t bid = 0; bid < n_batches; bid++) { - size_t batch_offset = bid * row_batch_size; - size_t rows_per_batch = min(row_batch_size, n_rows - batch_offset); - - auto X_batch_view = raft::make_device_matrix_view( - X + batch_offset * n_cols, rows_per_batch, n_cols); - - auto x_norm_batch_view = raft::make_device_vector_view( - x_norm.data_handle() + batch_offset, rows_per_batch); - - auto mask_op = [colors, - n_components = raft::util::FastIntDiv(n_components), - batch_offset] __device__(value_idx idx) { - value_idx row = idx / n_components; - value_idx col = idx % n_components; - return colors[batch_offset + row] != col; - }; - - auto adj_vector_view = raft::make_device_vector_view( - adj.data_handle(), rows_per_batch * n_components); - - raft::linalg::map_offset(handle, adj_vector_view, mask_op); - - auto adj_view = raft::make_device_matrix_view( - adj.data_handle(), rows_per_batch, n_components); - - auto kvp_view = - raft::make_device_vector_view, value_idx>( - kvp + batch_offset, rows_per_batch); - - raft::distance::masked_l2_nn(handle, - params, - X_batch_view, - X_full_view, - x_norm_batch_view, - x_norm.view(), - adj_view, - group_idxs_view, - kvp_view); - } - - // Transform the keys so that they correctly point to the unpermuted indices. - thrust::transform(exec_policy, - kvp, - kvp + n_rows, - kvp, - [sort_plan = sort_plan.data_handle()] __device__(OutT KVP) { - OutT res; - res.value = KVP.value; - res.key = sort_plan[KVP.key]; - return res; - }); - - // Undo permutation of the rows of X by scattering in place. - raft::matrix::scatter(handle, X_mutable_view, sort_plan_const_view, (value_idx)col_batch_size); - - // Undo permutation of the key-value pair and color vectors. This is not done - // inplace, so using two temporary vectors. - auto tmp_colors = raft::make_device_vector(handle, n_rows); - auto tmp_kvp = raft::make_device_vector(handle, n_rows); - - thrust::scatter(exec_policy, kvp, kvp + n_rows, sort_plan.data_handle(), tmp_kvp.data_handle()); - thrust::scatter( - exec_policy, colors, colors + n_rows, sort_plan.data_handle(), tmp_colors.data_handle()); - reduction_op.scatter(handle, sort_plan.data_handle()); - - raft::copy_async(colors, tmp_colors.data_handle(), n_rows, stream); - raft::copy_async(kvp, tmp_kvp.data_handle(), n_rows, stream); - - LookupColorOp extract_colors_op(colors); - thrust::transform(exec_policy, kvp, kvp + n_rows, nn_colors, extract_colors_op); -} - -/** - * Sort nearest neighboring components wrt component of source vertices - * @tparam value_idx - * @tparam value_t - * @param[inout] colors components array of source vertices - * @param[inout] nn_colors nearest neighbors components array - * @param[inout] kvp nearest neighbor source vertex / distance array - * @param[inout] src_indices array of source vertex indices which will become arg_sort - * indices - * @param n_rows number of components in `colors` - * @param stream stream for which to order CUDA operations - */ -template -void sort_by_color(raft::resources const& handle, - value_idx* colors, - value_idx* nn_colors, - raft::KeyValuePair* kvp, - value_idx* src_indices, - size_t n_rows) -{ - auto exec_policy = resource::get_thrust_policy(handle); - thrust::counting_iterator arg_sort_iter(0); - thrust::copy(exec_policy, arg_sort_iter, arg_sort_iter + n_rows, src_indices); - - auto keys = thrust::make_zip_iterator( - thrust::make_tuple(colors, nn_colors, (KeyValuePair*)kvp)); - auto vals = thrust::make_zip_iterator(thrust::make_tuple(src_indices)); - // get all the colors in contiguous locations so we can map them to warps. - thrust::sort_by_key(exec_policy, keys, keys + n_rows, vals, TupleComp()); -} - -template -RAFT_KERNEL min_components_by_color_kernel(value_idx* out_rows, - value_idx* out_cols, - value_t* out_vals, - const value_idx* out_index, - const value_idx* indices, - const raft::KeyValuePair* kvp, - size_t nnz) -{ - size_t tid = blockDim.x * blockIdx.x + threadIdx.x; - - if (tid >= nnz) return; - - int idx = out_index[tid]; - - if ((tid == 0 || (out_index[tid - 1] != idx))) { - out_rows[idx] = indices[tid]; - out_cols[idx] = kvp[tid].key; - out_vals[idx] = kvp[tid].value; - } -} - -/** - * Computes the min set of unique components that neighbor the - * components of each source vertex. - * @tparam value_idx - * @tparam value_t - * @param[out] coo output edge list - * @param[in] out_index output indptr for ordering edge list - * @param[in] indices indices of source vertices for each component - * @param[in] kvp indices and distances of each destination vertex for each component - * @param[in] n_colors number of components - * @param[in] stream cuda stream for which to order cuda operations - */ -template -void min_components_by_color(raft::sparse::COO& coo, - const value_idx* out_index, - const value_idx* indices, - const raft::KeyValuePair* kvp, - size_t nnz, - cudaStream_t stream) -{ - /** - * Arrays should be ordered by: colors_indptr->colors_n->kvp.value - * so the last element of each column in the input CSR should be - * the min. - */ - min_components_by_color_kernel<<>>( - coo.rows(), coo.cols(), coo.vals(), out_index, indices, kvp, nnz); -} - -/** - * Connects the components of an otherwise unconnected knn graph - * by computing a 1-nn to neighboring components of each data point - * (e.g. component(nn) != component(self)) and reducing the results to - * include the set of smallest destination components for each source - * component. The result will not necessarily contain - * n_components^2 - n_components number of elements because many components - * will likely not be contained in the neighborhoods of 1-nns. - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[out] out output edge list containing nearest cross-component - * edges. - * @param[in] X original (row-major) dense matrix for which knn graph should be constructed. - * @param[in] orig_colors array containing component number for each row of X - * @param[in] n_rows number of rows in X - * @param[in] n_cols number of cols in X - * @param[in] reduction_op reduction operation for computing nearest neighbors. The reduction - * operation must have `gather` and `scatter` functions defined - * @param[in] row_batch_size the batch size for computing nearest neighbors. This parameter controls - * the number of samples for which the nearest neighbors are computed at once. Therefore, it affects - * the memory consumption mainly by reducing the size of the adjacency matrix for masked nearest - * neighbors computation. default 0 indicates that no batching is done - * @param[in] col_batch_size the input data is sorted and 'unsorted' based on color. An additional - * scratch space buffer of shape (n_rows, col_batch_size) is created for this. Usually, this - * parameter affects the memory consumption more drastically than the col_batch_size with a marginal - * increase in compute time as the col_batch_size is reduced. default 0 indicates that no batching - * is done - * @param[in] metric distance metric - */ -template -void cross_component_nn( - raft::resources const& handle, - raft::sparse::COO& out, - const value_t* X, - const value_idx* orig_colors, - size_t n_rows, - size_t n_cols, - red_op reduction_op, - size_t row_batch_size, - size_t col_batch_size, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) -{ - auto stream = resource::get_cuda_stream(handle); - - RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded, - "Fixing connectivities for an unconnected k-NN graph only " - "supports L2SqrtExpanded currently."); - - if (row_batch_size == 0 || row_batch_size > n_rows) { row_batch_size = n_rows; } - - if (col_batch_size == 0 || col_batch_size > n_cols) { col_batch_size = n_cols; } - - rmm::device_uvector colors(n_rows, stream); - - // Normalize colors so they are drawn from a monotonically increasing set - constexpr bool zero_based = true; - raft::label::make_monotonic( - colors.data(), const_cast(orig_colors), n_rows, stream, zero_based); - - /** - * First compute 1-nn for all colors where the color of each data point - * is guaranteed to be != color of its nearest neighbor. - */ - rmm::device_uvector nn_colors(n_rows, stream); - rmm::device_uvector> temp_inds_dists(n_rows, stream); - rmm::device_uvector src_indices(n_rows, stream); - - perform_1nn(handle, - temp_inds_dists.data(), - nn_colors.data(), - colors.data(), - X, - n_rows, - n_cols, - row_batch_size, - col_batch_size, - reduction_op); - - /** - * Sort data points by color (neighbors are not sorted) - */ - // max_color + 1 = number of connected components - // sort nn_colors by key w/ original colors - sort_by_color( - handle, colors.data(), nn_colors.data(), temp_inds_dists.data(), src_indices.data(), n_rows); - - /** - * Take the min for any duplicate colors - */ - // Compute mask of duplicates - rmm::device_uvector out_index(n_rows + 1, stream); - raft::sparse::op::compute_duplicates_mask( - out_index.data(), colors.data(), nn_colors.data(), n_rows, stream); - - thrust::exclusive_scan(resource::get_thrust_policy(handle), - out_index.data(), - out_index.data() + out_index.size(), - out_index.data()); - - // compute final size - value_idx size = 0; - raft::update_host(&size, out_index.data() + (out_index.size() - 1), 1, stream); - resource::sync_stream(handle, stream); - - size++; - - raft::sparse::COO min_edges(stream); - min_edges.allocate(size, n_rows, n_rows, true, stream); - - min_components_by_color( - min_edges, out_index.data(), src_indices.data(), temp_inds_dists.data(), n_rows, stream); - - /** - * Symmetrize resulting edge list - */ - raft::sparse::linalg::symmetrize( - handle, min_edges.rows(), min_edges.cols(), min_edges.vals(), n_rows, n_rows, size, out); -} - -}; // end namespace raft::sparse::neighbors::detail diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh deleted file mode 100644 index 68bba31360..0000000000 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ /dev/null @@ -1,432 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace raft::sparse::neighbors::detail { - -template -struct csr_batcher_t { - csr_batcher_t(value_idx batch_size, - value_idx n_rows, - const value_idx* csr_indptr, - const value_idx* csr_indices, - const value_t* csr_data) - : batch_start_(0), - batch_stop_(0), - batch_rows_(0), - total_rows_(n_rows), - batch_size_(batch_size), - csr_indptr_(csr_indptr), - csr_indices_(csr_indices), - csr_data_(csr_data), - batch_csr_start_offset_(0), - batch_csr_stop_offset_(0) - { - } - - void set_batch(int batch_num) - { - batch_start_ = batch_num * batch_size_; - batch_stop_ = batch_start_ + batch_size_ - 1; // zero-based indexing - - if (batch_stop_ >= total_rows_) batch_stop_ = total_rows_ - 1; // zero-based indexing - - batch_rows_ = (batch_stop_ - batch_start_) + 1; - } - - value_idx get_batch_csr_indptr_nnz(value_idx* batch_indptr, cudaStream_t stream) - { - raft::sparse::op::csr_row_slice_indptr(batch_start_, - batch_stop_, - csr_indptr_, - batch_indptr, - &batch_csr_start_offset_, - &batch_csr_stop_offset_, - stream); - - return batch_csr_stop_offset_ - batch_csr_start_offset_; - } - - void get_batch_csr_indices_data(value_idx* csr_indices, value_t* csr_data, cudaStream_t stream) - { - raft::sparse::op::csr_row_slice_populate(batch_csr_start_offset_, - batch_csr_stop_offset_, - csr_indices_, - csr_data_, - csr_indices, - csr_data, - stream); - } - - value_idx batch_rows() const { return batch_rows_; } - - value_idx batch_start() const { return batch_start_; } - - value_idx batch_stop() const { return batch_stop_; } - - private: - value_idx batch_size_; - value_idx batch_start_; - value_idx batch_stop_; - value_idx batch_rows_; - - value_idx total_rows_; - - const value_idx* csr_indptr_; - const value_idx* csr_indices_; - const value_t* csr_data_; - - value_idx batch_csr_start_offset_; - value_idx batch_csr_stop_offset_; -}; - -template -class sparse_knn_t { - public: - sparse_knn_t(const value_idx* idxIndptr_, - const value_idx* idxIndices_, - const value_t* idxData_, - size_t idxNNZ_, - int n_idx_rows_, - int n_idx_cols_, - const value_idx* queryIndptr_, - const value_idx* queryIndices_, - const value_t* queryData_, - size_t queryNNZ_, - int n_query_rows_, - int n_query_cols_, - value_idx* output_indices_, - value_t* output_dists_, - int k_, - raft::resources const& handle_, - size_t batch_size_index_ = 2 << 14, // approx 1M - size_t batch_size_query_ = 2 << 14, - raft::distance::DistanceType metric_ = raft::distance::DistanceType::L2Expanded, - float metricArg_ = 0) - : idxIndptr(idxIndptr_), - idxIndices(idxIndices_), - idxData(idxData_), - idxNNZ(idxNNZ_), - n_idx_rows(n_idx_rows_), - n_idx_cols(n_idx_cols_), - queryIndptr(queryIndptr_), - queryIndices(queryIndices_), - queryData(queryData_), - queryNNZ(queryNNZ_), - n_query_rows(n_query_rows_), - n_query_cols(n_query_cols_), - output_indices(output_indices_), - output_dists(output_dists_), - k(k_), - handle(handle_), - batch_size_index(batch_size_index_), - batch_size_query(batch_size_query_), - metric(metric_), - metricArg(metricArg_) - { - } - - void run() - { - using namespace raft::sparse; - - int n_batches_query = raft::ceildiv((size_t)n_query_rows, batch_size_query); - csr_batcher_t query_batcher( - batch_size_query, n_query_rows, queryIndptr, queryIndices, queryData); - - size_t rows_processed = 0; - - for (int i = 0; i < n_batches_query; i++) { - /** - * Compute index batch info - */ - query_batcher.set_batch(i); - - /** - * Slice CSR to rows in batch - */ - - rmm::device_uvector query_batch_indptr(query_batcher.batch_rows() + 1, - resource::get_cuda_stream(handle)); - - value_idx n_query_batch_nnz = query_batcher.get_batch_csr_indptr_nnz( - query_batch_indptr.data(), resource::get_cuda_stream(handle)); - - rmm::device_uvector query_batch_indices(n_query_batch_nnz, - resource::get_cuda_stream(handle)); - rmm::device_uvector query_batch_data(n_query_batch_nnz, - resource::get_cuda_stream(handle)); - - query_batcher.get_batch_csr_indices_data( - query_batch_indices.data(), query_batch_data.data(), resource::get_cuda_stream(handle)); - - // A 3-partition temporary merge space to scale the batching. 2 parts for subsequent - // batches and 1 space for the results of the merge, which get copied back to the top - rmm::device_uvector merge_buffer_indices(0, resource::get_cuda_stream(handle)); - rmm::device_uvector merge_buffer_dists(0, resource::get_cuda_stream(handle)); - - value_t* dists_merge_buffer_ptr; - value_idx* indices_merge_buffer_ptr; - - int n_batches_idx = raft::ceildiv((size_t)n_idx_rows, batch_size_index); - csr_batcher_t idx_batcher( - batch_size_index, n_idx_rows, idxIndptr, idxIndices, idxData); - - for (int j = 0; j < n_batches_idx; j++) { - idx_batcher.set_batch(j); - - merge_buffer_indices.resize(query_batcher.batch_rows() * k * 3, - resource::get_cuda_stream(handle)); - merge_buffer_dists.resize(query_batcher.batch_rows() * k * 3, - resource::get_cuda_stream(handle)); - - /** - * Slice CSR to rows in batch - */ - rmm::device_uvector idx_batch_indptr(idx_batcher.batch_rows() + 1, - resource::get_cuda_stream(handle)); - rmm::device_uvector idx_batch_indices(0, resource::get_cuda_stream(handle)); - rmm::device_uvector idx_batch_data(0, resource::get_cuda_stream(handle)); - - value_idx idx_batch_nnz = idx_batcher.get_batch_csr_indptr_nnz( - idx_batch_indptr.data(), resource::get_cuda_stream(handle)); - - idx_batch_indices.resize(idx_batch_nnz, resource::get_cuda_stream(handle)); - idx_batch_data.resize(idx_batch_nnz, resource::get_cuda_stream(handle)); - - idx_batcher.get_batch_csr_indices_data( - idx_batch_indices.data(), idx_batch_data.data(), resource::get_cuda_stream(handle)); - - /** - * Compute distances - */ - uint64_t dense_size = - (uint64_t)idx_batcher.batch_rows() * (uint64_t)query_batcher.batch_rows(); - rmm::device_uvector batch_dists(dense_size, resource::get_cuda_stream(handle)); - - RAFT_CUDA_TRY(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t))); - - compute_distances(idx_batcher, - query_batcher, - idx_batch_nnz, - n_query_batch_nnz, - idx_batch_indptr.data(), - idx_batch_indices.data(), - idx_batch_data.data(), - query_batch_indptr.data(), - query_batch_indices.data(), - query_batch_data.data(), - batch_dists.data()); - - // Build batch indices array - rmm::device_uvector batch_indices(batch_dists.size(), - resource::get_cuda_stream(handle)); - - // populate batch indices array - value_idx batch_rows = query_batcher.batch_rows(), batch_cols = idx_batcher.batch_rows(); - - iota_fill(batch_indices.data(), batch_rows, batch_cols, resource::get_cuda_stream(handle)); - - /** - * Perform k-selection on batch & merge with other k-selections - */ - size_t merge_buffer_offset = batch_rows * k; - dists_merge_buffer_ptr = merge_buffer_dists.data() + merge_buffer_offset; - indices_merge_buffer_ptr = merge_buffer_indices.data() + merge_buffer_offset; - - perform_k_selection(idx_batcher, - query_batcher, - batch_dists.data(), - batch_indices.data(), - dists_merge_buffer_ptr, - indices_merge_buffer_ptr); - - value_t* dists_merge_buffer_tmp_ptr = dists_merge_buffer_ptr; - value_idx* indices_merge_buffer_tmp_ptr = indices_merge_buffer_ptr; - - // Merge results of difference batches if necessary - if (idx_batcher.batch_start() > 0) { - size_t merge_buffer_tmp_out = batch_rows * k * 2; - dists_merge_buffer_tmp_ptr = merge_buffer_dists.data() + merge_buffer_tmp_out; - indices_merge_buffer_tmp_ptr = merge_buffer_indices.data() + merge_buffer_tmp_out; - - merge_batches(idx_batcher, - query_batcher, - merge_buffer_dists.data(), - merge_buffer_indices.data(), - dists_merge_buffer_tmp_ptr, - indices_merge_buffer_tmp_ptr); - } - - // copy merged output back into merge buffer partition for next iteration - raft::copy_async(merge_buffer_indices.data(), - indices_merge_buffer_tmp_ptr, - batch_rows * k, - resource::get_cuda_stream(handle)); - raft::copy_async(merge_buffer_dists.data(), - dists_merge_buffer_tmp_ptr, - batch_rows * k, - resource::get_cuda_stream(handle)); - } - - // Copy final merged batch to output array - raft::copy_async(output_indices + (rows_processed * k), - merge_buffer_indices.data(), - query_batcher.batch_rows() * k, - resource::get_cuda_stream(handle)); - raft::copy_async(output_dists + (rows_processed * k), - merge_buffer_dists.data(), - query_batcher.batch_rows() * k, - resource::get_cuda_stream(handle)); - - rows_processed += query_batcher.batch_rows(); - } - } - - private: - void merge_batches(csr_batcher_t& idx_batcher, - csr_batcher_t& query_batcher, - value_t* merge_buffer_dists, - value_idx* merge_buffer_indices, - value_t* out_dists, - value_idx* out_indices) - { - // build translation buffer to shift resulting indices by the batch - std::vector id_ranges; - id_ranges.push_back(0); - id_ranges.push_back(idx_batcher.batch_start()); - - rmm::device_uvector trans(id_ranges.size(), resource::get_cuda_stream(handle)); - raft::update_device( - trans.data(), id_ranges.data(), id_ranges.size(), resource::get_cuda_stream(handle)); - - // combine merge buffers only if there's more than 1 partition to combine - raft::spatial::knn::knn_merge_parts(merge_buffer_dists, - merge_buffer_indices, - out_dists, - out_indices, - query_batcher.batch_rows(), - 2, - k, - resource::get_cuda_stream(handle), - trans.data()); - } - - void perform_k_selection(csr_batcher_t idx_batcher, - csr_batcher_t query_batcher, - value_t* batch_dists, - value_idx* batch_indices, - value_t* out_dists, - value_idx* out_indices) - { - // populate batch indices array - value_idx batch_rows = query_batcher.batch_rows(), batch_cols = idx_batcher.batch_rows(); - - // build translation buffer to shift resulting indices by the batch - std::vector id_ranges; - id_ranges.push_back(0); - id_ranges.push_back(idx_batcher.batch_start()); - - // in the case where the number of idx rows in the batch is < k, we - // want to adjust k. - value_idx n_neighbors = std::min(static_cast(k), batch_cols); - - bool ascending = raft::distance::is_min_close(metric); - - // kernel to slice first (min) k cols and copy into batched merge buffer - raft::matrix::select_k( - handle, - make_device_matrix_view(batch_dists, batch_rows, batch_cols), - make_device_matrix_view(batch_indices, batch_rows, batch_cols), - make_device_matrix_view(out_dists, batch_rows, n_neighbors), - make_device_matrix_view(out_indices, batch_rows, n_neighbors), - ascending, - true); - } - - void compute_distances(csr_batcher_t& idx_batcher, - csr_batcher_t& query_batcher, - size_t idx_batch_nnz, - size_t query_batch_nnz, - value_idx* idx_batch_indptr, - value_idx* idx_batch_indices, - value_t* idx_batch_data, - value_idx* query_batch_indptr, - value_idx* query_batch_indices, - value_t* query_batch_data, - value_t* batch_dists) - { - /** - * Compute distances - */ - raft::sparse::distance::detail::distances_config_t dist_config(handle); - dist_config.b_nrows = idx_batcher.batch_rows(); - dist_config.b_ncols = n_idx_cols; - dist_config.b_nnz = idx_batch_nnz; - - dist_config.b_indptr = idx_batch_indptr; - dist_config.b_indices = idx_batch_indices; - dist_config.b_data = idx_batch_data; - - dist_config.a_nrows = query_batcher.batch_rows(); - dist_config.a_ncols = n_query_cols; - dist_config.a_nnz = query_batch_nnz; - - dist_config.a_indptr = query_batch_indptr; - dist_config.a_indices = query_batch_indices; - dist_config.a_data = query_batch_data; - - if (raft::sparse::distance::supportedDistance.find(metric) == - raft::sparse::distance::supportedDistance.end()) - THROW("DistanceType not supported: %d", metric); - - raft::sparse::distance::pairwiseDistance(batch_dists, dist_config, metric, metricArg); - } - - const value_idx *idxIndptr, *idxIndices, *queryIndptr, *queryIndices; - value_idx* output_indices; - const value_t *idxData, *queryData; - value_t* output_dists; - - size_t idxNNZ, queryNNZ, batch_size_index, batch_size_query; - - raft::distance::DistanceType metric; - - float metricArg; - - int n_idx_rows, n_idx_cols, n_query_rows, n_query_cols, k; - - raft::resources const& handle; -}; - -}; // namespace raft::sparse::neighbors::detail diff --git a/cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh b/cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh deleted file mode 100644 index 4e46904c83..0000000000 --- a/cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include - -#include -#include - -namespace raft::sparse::neighbors::detail { - -/** - * Fills indices array of pairwise distance array - * @tparam value_idx - * @param indices - * @param m - */ -template -RAFT_KERNEL fill_indices(value_idx* indices, size_t m, size_t nnz) -{ - value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x; - if (tid >= nnz) return; - value_idx v = tid / m; - indices[tid] = v; -} - -template -value_idx build_k(value_idx n_samples, int c) -{ - // from "kNN-MST-Agglomerative: A fast & scalable graph-based data clustering - // approach on GPU" - return std::min(n_samples, std::max((value_idx)2, (value_idx)floor(log2(n_samples)) + c)); -} - -template -RAFT_KERNEL conv_indices_kernel(in_t* inds, out_t* out, size_t nnz) -{ - size_t tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid >= nnz) return; - out_t v = inds[tid]; - out[tid] = v; -} - -template -void conv_indices(in_t* inds, out_t* out, size_t size, cudaStream_t stream) -{ - size_t blocks = ceildiv(size, (size_t)tpb); - conv_indices_kernel<<>>(inds, out, size); -} - -/** - * Constructs a (symmetrized) knn graph edge list from - * dense input vectors. - * - * Note: The resulting KNN graph is not guaranteed to be connected. - * - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[in] X dense matrix of input data samples and observations - * @param[in] m number of data samples (rows) in X - * @param[in] n number of observations (columns) in X - * @param[in] metric distance metric to use when constructing neighborhoods - * @param[out] out output edge list - * @param[out] out output edge list - * @param c - */ -template -void knn_graph(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - raft::sparse::COO& out, - int c = 15) -{ - size_t k = build_k(m, c); - - auto stream = resource::get_cuda_stream(handle); - - size_t nnz = m * k; - - rmm::device_uvector rows(nnz, stream); - rmm::device_uvector indices(nnz, stream); - rmm::device_uvector data(nnz, stream); - - size_t blocks = ceildiv(nnz, (size_t)256); - fill_indices<<>>(rows.data(), k, nnz); - - std::vector inputs; - inputs.push_back(const_cast(X)); - - std::vector sizes; - sizes.push_back(m); - - // This is temporary. Once faiss is updated, we should be able to - // pass value_idx through to knn. - rmm::device_uvector int64_indices(nnz, stream); - - raft::spatial::knn::brute_force_knn(handle, - inputs, - sizes, - n, - const_cast(X), - m, - int64_indices.data(), - data.data(), - k, - true, - true, - nullptr, - metric); - - // convert from current knn's 64-bit to 32-bit. - conv_indices(int64_indices.data(), indices.data(), nnz, stream); - - raft::sparse::linalg::symmetrize( - handle, rows.data(), indices.data(), data.data(), m, k, nnz, out); -} - -}; // namespace raft::sparse::neighbors::detail diff --git a/cpp/include/raft/sparse/neighbors/knn.cuh b/cpp/include/raft/sparse/neighbors/knn.cuh deleted file mode 100644 index 2cf68818aa..0000000000 --- a/cpp/include/raft/sparse/neighbors/knn.cuh +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. - */ - -/** - * DISCLAIMER: this file is deprecated: use knn.cuh instead - */ - -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use the sparse/spatial version instead.") -#endif - -#include -#include - -namespace raft::sparse::neighbors { - -/** - * Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors - * using some distance implementation - * @param[in] idxIndptr csr indptr of the index matrix (size n_idx_rows + 1) - * @param[in] idxIndices csr column indices array of the index matrix (size n_idx_nnz) - * @param[in] idxData csr data array of the index matrix (size idxNNZ) - * @param[in] idxNNZ number of non-zeros for sparse index matrix - * @param[in] n_idx_rows number of data samples in index matrix - * @param[in] n_idx_cols - * @param[in] queryIndptr csr indptr of the query matrix (size n_query_rows + 1) - * @param[in] queryIndices csr indices array of the query matrix (size queryNNZ) - * @param[in] queryData csr data array of the query matrix (size queryNNZ) - * @param[in] queryNNZ number of non-zeros for sparse query matrix - * @param[in] n_query_rows number of data samples in query matrix - * @param[in] n_query_cols number of features in query matrix - * @param[out] output_indices dense matrix for output indices (size n_query_rows * k) - * @param[out] output_dists dense matrix for output distances (size n_query_rows * k) - * @param[in] k the number of neighbors to query - * @param[in] handle CUDA resource::get_cuda_stream(handle) to order operations with respect to - * @param[in] batch_size_index maximum number of rows to use from index matrix per batch - * @param[in] batch_size_query maximum number of rows to use from query matrix per batch - * @param[in] metric distance metric/measure to use - * @param[in] metricArg potential argument for metric (currently unused) - */ -template -void brute_force_knn(const value_idx* idxIndptr, - const value_idx* idxIndices, - const value_t* idxData, - size_t idxNNZ, - int n_idx_rows, - int n_idx_cols, - const value_idx* queryIndptr, - const value_idx* queryIndices, - const value_t* queryData, - size_t queryNNZ, - int n_query_rows, - int n_query_cols, - value_idx* output_indices, - value_t* output_dists, - int k, - raft::resources const& handle, - size_t batch_size_index = 2 << 14, // approx 1M - size_t batch_size_query = 2 << 14, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0) -{ - brute_force::knn(idxIndptr, - idxIndices, - idxData, - idxNNZ, - n_idx_rows, - n_idx_cols, - queryIndptr, - queryIndices, - queryData, - queryNNZ, - n_query_rows, - n_query_cols, - output_indices, - output_dists, - k, - handle, - batch_size_index, - batch_size_query, - metric, - metricArg); -} - -}; // namespace raft::sparse::neighbors diff --git a/cpp/include/raft/sparse/neighbors/knn_graph.cuh b/cpp/include/raft/sparse/neighbors/knn_graph.cuh deleted file mode 100644 index 8257afc16f..0000000000 --- a/cpp/include/raft/sparse/neighbors/knn_graph.cuh +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include - -namespace raft::sparse::neighbors { - -/** - * Constructs a (symmetrized) knn graph edge list from - * dense input vectors. - * - * Note: The resulting KNN graph is not guaranteed to be connected. - * - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[in] X dense matrix of input data samples and observations - * @param[in] m number of data samples (rows) in X - * @param[in] n number of observations (columns) in X - * @param[in] metric distance metric to use when constructing neighborhoods - * @param[out] out output edge list - * @param c - */ -template -void knn_graph(raft::resources const& handle, - const value_t* X, - std::size_t m, - std::size_t n, - raft::distance::DistanceType metric, - raft::sparse::COO& out, - int c = 15) -{ - detail::knn_graph(handle, X, m, n, metric, out, c); -} - -}; // namespace raft::sparse::neighbors diff --git a/cpp/include/raft/sparse/neighbors/specializations.cuh b/cpp/include/raft/sparse/neighbors/specializations.cuh deleted file mode 100644 index e85b05575f..0000000000 --- a/cpp/include/raft/sparse/neighbors/specializations.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message( \ - __FILE__ \ - " is deprecated and will be removed." \ - " Including specializations is not necessary any more." \ - " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") -#endif diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 5d504d2100..caf0372d8a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -95,10 +95,6 @@ endfunction() # * distance tests ------------------------------------------------------------------------- if(BUILD_TESTS) - ConfigureTest( - NAME CLUSTER_TEST PATH cluster/kmeans.cu cluster/kmeans_balanced.cu cluster/kmeans_find_k.cu - cluster/cluster_solvers.cu cluster/linkage.cu cluster/spectral.cu LIB EXPLICIT_INSTANTIATE_ONLY - ) ConfigureTest( NAME @@ -139,71 +135,6 @@ if(BUILD_TESTS) NOCUDA ) - ConfigureTest( - NAME - DISTANCE_TEST - PATH - distance/dist_adj.cu - distance/dist_adj_distance_instance.cu - distance/dist_canberra.cu - distance/dist_correlation.cu - distance/dist_cos.cu - distance/dist_dice.cu - distance/dist_hamming.cu - distance/dist_hellinger.cu - distance/dist_inner_product.cu - distance/dist_jensen_shannon.cu - distance/dist_kl_divergence.cu - distance/dist_l1.cu - distance/dist_l2_exp.cu - distance/dist_l2_unexp.cu - distance/dist_l2_sqrt_exp.cu - distance/dist_l_inf.cu - distance/dist_lp_unexp.cu - distance/dist_russell_rao.cu - distance/masked_nn.cu - distance/masked_nn_compress_to_bits.cu - distance/fused_l2_nn.cu - distance/fused_cosine_nn.cu - distance/gram.cu - LIB - EXPLICIT_INSTANTIATE_ONLY - ) - - list( - APPEND - EXT_HEADER_TEST_SOURCES - ext_headers/raft_neighbors_brute_force.cu - ext_headers/raft_distance_distance.cu - ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu - ext_headers/raft_matrix_detail_select_k.cu - ext_headers/raft_neighbors_ball_cover.cu - ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu - ext_headers/raft_distance_fused_l2_nn.cu - ext_headers/raft_neighbors_ivf_pq.cu - ext_headers/raft_neighbors_ivf_flat.cu - ext_headers/raft_core_logger.cpp - ext_headers/raft_neighbors_refine.cu - ext_headers/raft_neighbors_detail_ivf_flat_search.cu - ext_headers/raft_linalg_detail_coalesced_reduction.cu - ext_headers/raft_sparse_matrix_detail_select_k.cu - ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu - ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu - ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu - ) - - # Test that the split headers compile in isolation with: - # - # * EXT_HEADERS_TEST_COMPILED_EXPLICIT: RAFT_COMPILED, RAFT_EXPLICIT_INSTANTIATE_ONLY defined - # * EXT_HEADERS_TEST_COMPILED_IMPLICIT: RAFT_COMPILED defined - # * EXT_HEADERS_TEST_IMPLICIT: no macros defined. - ConfigureTest( - NAME EXT_HEADERS_TEST_COMPILED_EXPLICIT PATH ${EXT_HEADER_TEST_SOURCES} LIB - EXPLICIT_INSTANTIATE_ONLY - ) - ConfigureTest(NAME EXT_HEADERS_TEST_COMPILED_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES} LIB) - ConfigureTest(NAME EXT_HEADERS_TEST_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES}) - ConfigureTest(NAME LABEL_TEST PATH label/label.cu label/merge_labels.cu) ConfigureTest( @@ -292,8 +223,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SOLVERS_TEST PATH cluster/cluster_solvers_deprecated.cu linalg/eigen_solvers.cu lap/lap.cu - sparse/mst.cu LIB EXPLICIT_INSTANTIATE_ONLY + NAME SOLVERS_TEST PATH linalg/eigen_solvers.cu lap/lap.cu sparse/mst.cu LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -321,136 +252,6 @@ if(BUILD_TESTS) sparse/symmetrize.cu ) - ConfigureTest( - NAME SPARSE_DIST_TEST PATH sparse/dist_coo_spmv.cu sparse/distance.cu sparse/gram.cu LIB - EXPLICIT_INSTANTIATE_ONLY - ) - - ConfigureTest( - NAME SPARSE_NEIGHBORS_TEST PATH sparse/neighbors/cross_component_nn.cu - sparse/neighbors/brute_force.cu sparse/neighbors/knn_graph.cu LIB EXPLICIT_INSTANTIATE_ONLY - ) - - ConfigureTest( - NAME - NEIGHBORS_TEST - PATH - neighbors/knn.cu - neighbors/fused_l2_knn.cu - neighbors/tiled_knn.cu - neighbors/haversine.cu - neighbors/ball_cover.cu - neighbors/epsilon_neighborhood.cu - neighbors/refine.cu - LIB - EXPLICIT_INSTANTIATE_ONLY - ) - - ConfigureTest( - NAME NEIGHBORS_ANN_BRUTE_FORCE_TEST PATH neighbors/ann_brute_force/test_float.cu LIB - EXPLICIT_INSTANTIATE_ONLY GPUS 1 PERCENT 100 - ) - - ConfigureTest( - NAME - NEIGHBORS_ANN_CAGRA_TEST - PATH - neighbors/ann_cagra/test_float_uint32_t.cu - neighbors/ann_cagra/test_half_uint32_t.cu - neighbors/ann_cagra/test_int8_t_uint32_t.cu - neighbors/ann_cagra/test_uint8_t_uint32_t.cu - neighbors/ann_cagra/test_float_int64_t.cu - neighbors/ann_cagra/test_half_int64_t.cu - neighbors/ann_cagra_vpq/test_float_int64_t.cu - neighbors/ann_cagra_vpq/test_float_uint32_t.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim128_t8.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim256_t16.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim512_t32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim1024_t32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim128_t8.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim256_t16.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim512_t32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim1024_t32.cu - LIB - EXPLICIT_INSTANTIATE_ONLY - GPUS - 1 - PERCENT - 100 - ) - - ConfigureTest( - NAME - NEIGHBORS_ANN_IVF_TEST - PATH - neighbors/ann_ivf_flat/test_filter_float_int64_t.cu - neighbors/ann_ivf_flat/test_float_int64_t.cu - neighbors/ann_ivf_flat/test_int8_t_int64_t.cu - neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu - neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu - neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu - neighbors/ann_ivf_pq/test_float_uint32_t.cu - neighbors/ann_ivf_pq/test_float_int64_t.cu - neighbors/ann_ivf_pq/test_int8_t_int64_t.cu - neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu - neighbors/ann_ivf_pq/test_filter_float_int64_t.cu - neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu - LIB - EXPLICIT_INSTANTIATE_ONLY - GPUS - 1 - PERCENT - 100 - ) - - ConfigureTest( - NAME - NEIGHBORS_ANN_NN_DESCENT_TEST - PATH - neighbors/ann_nn_descent/test_float_uint32_t.cu - neighbors/ann_nn_descent/test_int8_t_uint32_t.cu - neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu - # TODO: Investigate why this test is failing Reference issue - # https://github.com/rapidsai/raft/issues/2450 - # neighbors/ann_nn_descent/test_batch_float_uint32_t.cu - LIB - EXPLICIT_INSTANTIATE_ONLY - GPUS - 1 - PERCENT - 100 - ) - ConfigureTest( NAME STATS_TEST @@ -471,14 +272,11 @@ if(BUILD_TESTS) stats/mean_center.cu stats/minmax.cu stats/mutual_info_score.cu - stats/neighborhood_recall.cu stats/r2_score.cu stats/rand_index.cu stats/regression_metrics.cu - stats/silhouette_score.cu stats/stddev.cu stats/sum.cu - stats/trustworthiness.cu stats/weighted_mean.cu stats/v_measure.cu LIB diff --git a/cpp/test/cluster/cluster_solvers.cu b/cpp/test/cluster/cluster_solvers.cu deleted file mode 100644 index cc0a381bbf..0000000000 --- a/cpp/test/cluster/cluster_solvers.cu +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include - -#include - -#include -#include - -namespace raft { -namespace spectral { - -TEST(Raft, ClusterSolvers) -{ - using namespace matrix; - using index_type = int; - using value_type = double; - - raft::resources h; - - index_type maxiter{100}; - value_type tol{1.0e-10}; - unsigned long long seed{100110021003}; - - auto stream = resource::get_cuda_stream(h); - - index_type n{100}; - index_type d{10}; - index_type k{5}; - - // nullptr expected to trigger exceptions: - // - value_type* eigvecs{nullptr}; - index_type* codes{nullptr}; - - cluster_solver_config_t cfg{k, maxiter, tol, seed}; - - kmeans_solver_t cluster_solver{cfg}; - - EXPECT_ANY_THROW(cluster_solver.solve(h, n, d, eigvecs, codes)); -} - -TEST(Raft, ModularitySolvers) -{ - using namespace matrix; - using index_type = int; - using value_type = double; - - raft::resources h; - ASSERT_EQ(0, resource::get_device_id(h)); - - index_type neigvs{10}; - index_type maxiter{100}; - index_type restart_iter{10}; - value_type tol{1.0e-10}; - bool reorthog{true}; - - // nullptr expected to trigger exceptions: - // - index_type* clusters{nullptr}; - value_type* eigvals{nullptr}; - value_type* eigvecs{nullptr}; - - unsigned long long seed{100110021003}; - - eigen_solver_config_t eig_cfg{ - neigvs, maxiter, restart_iter, tol, reorthog, seed}; - lanczos_solver_t eig_solver{eig_cfg}; - - index_type k{5}; - - cluster_solver_config_t clust_cfg{k, maxiter, tol, seed}; - kmeans_solver_t cluster_solver{clust_cfg}; - - auto stream = resource::get_cuda_stream(h); - sparse_matrix_t sm{h, nullptr, nullptr, nullptr, 0, 0}; - - EXPECT_ANY_THROW(spectral::modularity_maximization( - h, sm, eig_solver, cluster_solver, clusters, eigvals, eigvecs)); - - value_type modularity{0}; - EXPECT_ANY_THROW(spectral::analyzeModularity(h, sm, k, clusters, modularity)); -} - -} // namespace spectral -} // namespace raft diff --git a/cpp/test/cluster/cluster_solvers_deprecated.cu b/cpp/test/cluster/cluster_solvers_deprecated.cu deleted file mode 100644 index 954e3e5bb6..0000000000 --- a/cpp/test/cluster/cluster_solvers_deprecated.cu +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -#include -#include - -namespace raft { -namespace spectral { - -TEST(Raft, ClusterSolvers) -{ - using namespace matrix; - using index_type = int; - using value_type = double; - - raft::resources h; - - index_type maxiter{100}; - value_type tol{1.0e-10}; - unsigned long long seed{100110021003}; - - auto stream = resource::get_cuda_stream(h); - - index_type n{100}; - index_type d{10}; - index_type k{5}; - - // nullptr expected to trigger exceptions: - // - value_type* eigvecs{nullptr}; - index_type* codes{nullptr}; - - cluster_solver_config_deprecated_t cfg{k, maxiter, tol, seed}; - kmeans_solver_deprecated_t cluster_solver{cfg}; - - EXPECT_ANY_THROW(cluster_solver.solve(h, n, d, eigvecs, codes)); -} - -} // namespace spectral -} // namespace raft diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu deleted file mode 100644 index 33c17c527d..0000000000 --- a/cpp/test/cluster/kmeans.cu +++ /dev/null @@ -1,363 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -#include -#include - -namespace raft { - -template -struct KmeansInputs { - int n_row; - int n_col; - int n_clusters; - T tol; - bool weighted; -}; - -template -void run_cluster_cost(const raft::resources& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost) -{ - raft::cluster::kmeans::cluster_cost( - handle, minClusterDistance, workspace, clusterCost, raft::add_op{}); -} - -template -class KmeansTest : public ::testing::TestWithParam> { - protected: - KmeansTest() - : d_labels(0, resource::get_cuda_stream(handle)), - d_labels_ref(0, resource::get_cuda_stream(handle)), - d_centroids(0, resource::get_cuda_stream(handle)), - d_sample_weight(0, resource::get_cuda_stream(handle)) - { - } - - void apiTest() - { - testparams = ::testing::TestWithParam>::GetParam(); - - auto stream = resource::get_cuda_stream(handle); - int n_samples = testparams.n_row; - int n_features = testparams.n_col; - params.n_clusters = testparams.n_clusters; - params.tol = testparams.tol; - params.n_init = 1; - params.rng_state.seed = 1; - params.oversampling_factor = 0; - - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - - auto X = raft::make_device_matrix(handle, n_samples, n_features); - auto labels = raft::make_device_vector(handle, n_samples); - - raft::random::make_blobs(X.data_handle(), - labels.data_handle(), - n_samples, - n_features, - params.n_clusters, - stream, - true, - nullptr, - nullptr, - T(1.0), - false, - (T)-10.0f, - (T)10.0f, - (uint64_t)1234); - d_labels.resize(n_samples, stream); - d_labels_ref.resize(n_samples, stream); - d_centroids.resize(params.n_clusters * n_features, stream); - raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); - rmm::device_uvector d_sample_weight(n_samples, stream); - thrust::fill( - thrust::cuda::par.on(stream), d_sample_weight.data(), d_sample_weight.data() + n_samples, 1); - auto weight_view = - raft::make_device_vector_view(d_sample_weight.data(), n_samples); - - T inertia = 0; - int n_iter = 0; - rmm::device_uvector workspace(0, stream); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - rmm::device_uvector inRankCp(0, stream); - auto X_view = raft::make_const_mdspan(X.view()); - auto centroids_view = - raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); - auto miniX = raft::make_device_matrix(handle, n_samples / 4, n_features); - - // Initialize kmeans on a portion of X - raft::cluster::kmeans::shuffle_and_gather( - handle, - X_view, - raft::make_device_matrix_view(miniX.data_handle(), miniX.extent(0), miniX.extent(1)), - miniX.extent(0), - params.rng_state.seed); - - raft::cluster::kmeans::init_plus_plus( - handle, params, raft::make_const_mdspan(miniX.view()), centroids_view, workspace); - - auto minClusterDistance = raft::make_device_vector(handle, n_samples); - auto minClusterAndDistance = - raft::make_device_vector, int>(handle, n_samples); - auto L2NormX = raft::make_device_vector(handle, n_samples); - auto clusterCostBefore = raft::make_device_scalar(handle, 0); - auto clusterCostAfter = raft::make_device_scalar(handle, 0); - - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - - raft::cluster::kmeans::min_cluster_distance(handle, - X_view, - centroids_view, - minClusterDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - run_cluster_cost(handle, minClusterDistance.view(), workspace, clusterCostBefore.view()); - - // Run a fit of kmeans - raft::cluster::kmeans::fit_main(handle, - params, - X_view, - weight_view, - centroids_view, - raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter), - workspace); - - // Check that the cluster cost decreased - raft::cluster::kmeans::min_cluster_distance(handle, - X_view, - centroids_view, - minClusterDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - run_cluster_cost(handle, minClusterDistance.view(), workspace, clusterCostAfter.view()); - T h_clusterCostBefore = T(0); - T h_clusterCostAfter = T(0); - raft::update_host(&h_clusterCostBefore, clusterCostBefore.data_handle(), 1, stream); - raft::update_host(&h_clusterCostAfter, clusterCostAfter.data_handle(), 1, stream); - ASSERT_TRUE(h_clusterCostAfter < h_clusterCostBefore); - - // Count samples in clusters using 2 methods and compare them - // Fill minClusterAndDistance - raft::cluster::kmeans::min_cluster_and_distance( - handle, - X_view, - raft::make_device_matrix_view( - d_centroids.data(), params.n_clusters, n_features), - minClusterAndDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - raft::cluster::kmeans::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - auto sampleCountInCluster = raft::make_device_vector(handle, params.n_clusters); - auto weigthInCluster = raft::make_device_vector(handle, params.n_clusters); - auto newCentroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - raft::cluster::kmeans::update_centroids(handle, - X_view, - weight_view, - raft::make_device_matrix_view( - d_centroids.data(), params.n_clusters, n_features), - itr, - weigthInCluster.view(), - newCentroids.view()); - raft::cluster::kmeans::count_samples_in_cluster(handle, - params, - X_view, - L2NormX.view(), - newCentroids.view(), - workspace, - sampleCountInCluster.view()); - - ASSERT_TRUE(devArrMatch(sampleCountInCluster.data_handle(), - weigthInCluster.data_handle(), - params.n_clusters, - CompareApprox(params.tol))); - } - - void basicTest() - { - testparams = ::testing::TestWithParam>::GetParam(); - - int n_samples = testparams.n_row; - int n_features = testparams.n_col; - params.n_clusters = testparams.n_clusters; - params.tol = testparams.tol; - params.n_init = 5; - params.rng_state.seed = 1; - params.oversampling_factor = 0; - - auto X = raft::make_device_matrix(handle, n_samples, n_features); - auto labels = raft::make_device_vector(handle, n_samples); - auto stream = resource::get_cuda_stream(handle); - - raft::random::make_blobs(X.data_handle(), - labels.data_handle(), - n_samples, - n_features, - params.n_clusters, - stream, - true, - nullptr, - nullptr, - T(1.0), - false, - (T)-10.0f, - (T)10.0f, - (uint64_t)1234); - - d_labels.resize(n_samples, stream); - d_labels_ref.resize(n_samples, stream); - d_centroids.resize(params.n_clusters * n_features, stream); - - std::optional> d_sw = std::nullopt; - auto d_centroids_view = - raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); - if (testparams.weighted) { - d_sample_weight.resize(n_samples, stream); - d_sw = std::make_optional( - raft::make_device_vector_view(d_sample_weight.data(), n_samples)); - thrust::fill(thrust::cuda::par.on(stream), - d_sample_weight.data(), - d_sample_weight.data() + n_samples, - 1); - } - - raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); - - T inertia = 0; - int n_iter = 0; - auto X_view = raft::make_const_mdspan(X.view()); - - raft::cluster::kmeans_fit_predict( - handle, - params, - X_view, - d_sw, - d_centroids_view, - raft::make_device_vector_view(d_labels.data(), n_samples), - raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); - - resource::sync_stream(handle, stream); - - score = raft::stats::adjusted_rand_index( - d_labels_ref.data(), d_labels.data(), n_samples, resource::get_cuda_stream(handle)); - - if (score < 1.0) { - std::stringstream ss; - ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); - std::cout << (ss.str().c_str()) << '\n'; - ss.str(std::string()); - ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); - std::cout << (ss.str().c_str()) << '\n'; - std::cout << "Score = " << score << '\n'; - } - } - - void SetUp() override - { - basicTest(); - apiTest(); - } - - protected: - raft::resources handle; - KmeansInputs testparams; - rmm::device_uvector d_labels; - rmm::device_uvector d_labels_ref; - rmm::device_uvector d_centroids; - rmm::device_uvector d_sample_weight; - double score; - raft::cluster::KMeansParams params; -}; - -const std::vector> inputsf2 = {{1000, 32, 5, 0.0001f, true}, - {1000, 32, 5, 0.0001f, false}, - {1000, 100, 20, 0.0001f, true}, - {1000, 100, 20, 0.0001f, false}, - {10000, 32, 10, 0.0001f, true}, - {10000, 32, 10, 0.0001f, false}, - {10000, 100, 50, 0.0001f, true}, - {10000, 100, 50, 0.0001f, false}, - {10000, 500, 100, 0.0001f, true}, - {10000, 500, 100, 0.0001f, false}}; - -const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, - {1000, 32, 5, 0.0001, false}, - {1000, 100, 20, 0.0001, true}, - {1000, 100, 20, 0.0001, false}, - {10000, 32, 10, 0.0001, true}, - {10000, 32, 10, 0.0001, false}, - {10000, 100, 50, 0.0001, true}, - {10000, 100, 50, 0.0001, false}, - {10000, 500, 100, 0.0001, true}, - {10000, 500, 100, 0.0001, false}}; - -typedef KmeansTest KmeansTestF; -TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); } - -typedef KmeansTest KmeansTestD; -TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score == 1.0); } - -INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); - -INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); - -} // namespace raft diff --git a/cpp/test/cluster/kmeans_balanced.cu b/cpp/test/cluster/kmeans_balanced.cu deleted file mode 100644 index 5009eaf122..0000000000 --- a/cpp/test/cluster/kmeans_balanced.cu +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -#include -#include - -/* This test takes advantage of the fact that make_blobs generates balanced clusters. - * It doesn't currently test whether the algorithm can make balanced clusters with an imbalanced - * dataset. - */ - -namespace raft { - -template -struct KmeansBalancedInputs { - IdxT n_rows; - IdxT n_cols; - IdxT n_clusters; - raft::cluster::kmeans_balanced_params kb_params; - MathT tol; -}; - -template -::std::ostream& operator<<(::std::ostream& os, const KmeansBalancedInputs& p) -{ - os << "{ " << p.n_rows << ", " << p.n_cols << ", " << p.n_clusters << ", " << p.kb_params.n_iters - << static_cast(p.kb_params.metric) << '}' << std::endl; - return os; -} - -template -class KmeansBalancedTest : public ::testing::TestWithParam> { - protected: - KmeansBalancedTest() - : stream(resource::get_cuda_stream(handle)), - d_labels(0, stream), - d_labels_ref(0, stream), - d_centroids(0, stream) - { - } - - void basicTest() - { - MappingOpT op{}; - - auto p = ::testing::TestWithParam>::GetParam(); - - auto X = raft::make_device_matrix(handle, p.n_rows, p.n_cols); - auto blob_labels = raft::make_device_vector(handle, p.n_rows); - - MathT* blobs_ptr; - rmm::device_uvector blobs(0, stream); - if constexpr (!std::is_same_v) { - blobs.resize(p.n_rows * p.n_cols, stream); - blobs_ptr = blobs.data(); - } else { - blobs_ptr = X.data_handle(); - } - - raft::random::make_blobs(blobs_ptr, - blob_labels.data_handle(), - p.n_rows, - p.n_cols, - p.n_clusters, - stream, - true, - nullptr, - nullptr, - MathT{0.1}, - true, - MathT{-1}, - MathT{1}, - (uint64_t)1234); - - // Convert blobs dataset to DataT if necessary - if constexpr (!std::is_same_v) { - raft::linalg::unaryOp( - X.data_handle(), blobs.data(), p.n_rows * p.n_cols, op.reverse_op, stream); - } - - d_labels.resize(p.n_rows, stream); - d_labels_ref.resize(p.n_rows, stream); - d_centroids.resize(p.n_clusters * p.n_cols, stream); - - raft::linalg::unaryOp( - d_labels_ref.data(), blob_labels.data_handle(), p.n_rows, raft::cast_op(), stream); - - auto X_view = - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); - auto d_centroids_view = - raft::make_device_matrix_view(d_centroids.data(), p.n_clusters, p.n_cols); - auto d_labels_view = raft::make_device_vector_view(d_labels.data(), p.n_rows); - - raft::cluster::kmeans_balanced::fit_predict( - handle, p.kb_params, X_view, d_centroids_view, d_labels_view, op); - - resource::sync_stream(handle, stream); - - score = raft::stats::adjusted_rand_index( - d_labels_ref.data(), d_labels.data(), p.n_rows, resource::get_cuda_stream(handle)); - - if (score < 1.0) { - std::stringstream ss; - ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); - std::cout << (ss.str().c_str()) << '\n'; - ss.str(std::string()); - ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); - std::cout << (ss.str().c_str()) << '\n'; - std::cout << "Score = " << score << '\n'; - } - } - - void SetUp() override { basicTest(); } - - protected: - raft::handle_t handle; - cudaStream_t stream; - rmm::device_uvector d_labels; - rmm::device_uvector d_labels_ref; - rmm::device_uvector d_centroids; - double score; -}; - -template -std::vector> get_kmeans_balanced_inputs() -{ - std::vector> out; - KmeansBalancedInputs p; - p.kb_params.n_iters = 20; - p.kb_params.metric = raft::distance::DistanceType::L2Expanded; - p.tol = MathT{0.0001}; - std::vector> row_cols_k = {{1000, 32, 5}, - {1000, 100, 20}, - {10000, 32, 10}, - {10000, 100, 50}, - {10000, 500, 100}, - {1000000, 128, 10}}; - for (auto& rck : row_cols_k) { - p.n_rows = static_cast(std::get<0>(rck)); - p.n_cols = static_cast(std::get<1>(rck)); - p.n_clusters = static_cast(std::get<2>(rck)); - out.push_back(p); - } - return out; -} - -const auto inputsf_i32 = get_kmeans_balanced_inputs(); -const auto inputsd_i32 = get_kmeans_balanced_inputs(); -const auto inputsf_i64 = get_kmeans_balanced_inputs(); -const auto inputsd_i64 = get_kmeans_balanced_inputs(); - -#define KB_TEST(test_type, test_name, test_inputs) \ - typedef RAFT_DEPAREN(test_type) test_name; \ - TEST_P(test_name, Result) { ASSERT_TRUE(score == 1.0); } \ - INSTANTIATE_TEST_CASE_P(KmeansBalancedTests, test_name, ::testing::ValuesIn(test_inputs)) - -/* - * First set of tests: no conversion - */ - -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestFFU32I32, - inputsf_i32); -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestDDU32I32, - inputsd_i32); -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestFFU32I64, - inputsf_i64); -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestDDU32I64, - inputsd_i64); -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestFFI32I32, - inputsf_i32); -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestFFI32I64, - inputsf_i64); -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestFFI64I32, - inputsf_i32); -KB_TEST((KmeansBalancedTest), - KmeansBalancedTestFFI64I64, - inputsf_i64); - -/* - * Second set of tests: integer dataset with conversion - */ - -template -struct i2f_scaler { - // Note: with a scaling factor of 42, and generating blobs with centers between -1 and 1 with a - // standard deviation of 0.1, it's statistically very unlikely that we'd overflow - const raft::compose_op, raft::cast_op> op{ - raft::div_const_op{42}, raft::cast_op{}}; - const raft::compose_op, raft::mul_const_op> reverse_op{ - raft::cast_op{}, raft::mul_const_op{42}}; - - RAFT_INLINE_FUNCTION auto operator()(const DataT& x) const { return op(x); }; -}; - -KB_TEST((KmeansBalancedTest>), - KmeansBalancedTestFI8U32I32, - inputsf_i32); -KB_TEST((KmeansBalancedTest>), - KmeansBalancedTestDI8U32I32, - inputsd_i32); - -} // namespace raft diff --git a/cpp/test/cluster/kmeans_find_k.cu b/cpp/test/cluster/kmeans_find_k.cu deleted file mode 100644 index 8e05ad3695..0000000000 --- a/cpp/test/cluster/kmeans_find_k.cu +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.h" - -#include -#include -#include -#include -#include -#include - -#include - -#include -#include - -namespace raft { - -template -struct KmeansFindKInputs { - int n_row; - int n_col; - int n_clusters; - T tol; - bool weighted; -}; - -template -class KmeansFindKTest : public ::testing::TestWithParam> { - protected: - KmeansFindKTest() - : stream(resource::get_cuda_stream(handle)), best_k(raft::make_host_scalar(0)) - { - } - - void basicTest() - { - testparams = ::testing::TestWithParam>::GetParam(); - - int n_samples = testparams.n_row; - int n_features = testparams.n_col; - int n_clusters = testparams.n_clusters; - - auto X = raft::make_device_matrix(handle, n_samples, n_features); - auto labels = raft::make_device_vector(handle, n_samples); - - raft::random::make_blobs(X.data_handle(), - labels.data_handle(), - n_samples, - n_features, - n_clusters, - stream, - true, - nullptr, - nullptr, - T(.001), - false, - (T)-10.0f, - (T)10.0f, - (uint64_t)1234); - - auto inertia = raft::make_host_scalar(0); - auto n_iter = raft::make_host_scalar(0); - - auto X_view = - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); - - raft::cluster::kmeans::find_k( - handle, X_view, best_k.view(), inertia.view(), n_iter.view(), n_clusters); - - resource::sync_stream(handle, stream); - } - - void SetUp() override { basicTest(); } - - protected: - raft::resources handle; - cudaStream_t stream; - KmeansFindKInputs testparams; - raft::host_scalar best_k; -}; - -const std::vector> inputsf2 = {{1000, 32, 8, 0.001f, true}, - {1000, 32, 8, 0.001f, false}, - {1000, 100, 20, 0.001f, true}, - {1000, 100, 20, 0.001f, false}, - {10000, 32, 10, 0.001f, true}, - {10000, 32, 10, 0.001f, false}, - {10000, 100, 50, 0.001f, true}, - {10000, 100, 50, 0.001f, false}, - {10000, 500, 100, 0.001f, true}, - {10000, 500, 100, 0.001f, false}}; - -const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, - {1000, 32, 5, 0.0001, false}, - {1000, 100, 20, 0.0001, true}, - {1000, 100, 20, 0.0001, false}, - {10000, 32, 10, 0.0001, true}, - {10000, 32, 10, 0.0001, false}, - {10000, 100, 50, 0.0001, true}, - {10000, 100, 50, 0.0001, false}, - {10000, 500, 100, 0.0001, true}, - {10000, 500, 100, 0.0001, false}}; - -typedef KmeansFindKTest KmeansFindKTestF; -TEST_P(KmeansFindKTestF, Result) -{ - if (best_k.view()[0] != testparams.n_clusters) { - std::cout << best_k.view()[0] << " " << testparams.n_clusters << std::endl; - } - ASSERT_TRUE(best_k.view()[0] == testparams.n_clusters); -} - -typedef KmeansFindKTest KmeansFindKTestD; -TEST_P(KmeansFindKTestD, Result) -{ - if (best_k.view()[0] != testparams.n_clusters) { - std::cout << best_k.view()[0] << " " << testparams.n_clusters << std::endl; - } - - ASSERT_TRUE(best_k.view()[0] == testparams.n_clusters); -} - -INSTANTIATE_TEST_CASE_P(KmeansFindKTests, KmeansFindKTestF, ::testing::ValuesIn(inputsf2)); - -INSTANTIATE_TEST_CASE_P(KmeansFindKTests, KmeansFindKTestD, ::testing::ValuesIn(inputsd2)); - -} // namespace raft diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu deleted file mode 100644 index ba7ed4254e..0000000000 --- a/cpp/test/cluster/linkage.cu +++ /dev/null @@ -1,674 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// XXX: We allow the instantiation of masked_l2_nn here: -// raft::linkage::FixConnectivitiesRedOp red_op(params.n_row); -// raft::linkage::cross_component_nn( -// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); -// -// TODO: consider adding this to libraft.so or creating an instance in a -// separate translation unit for this test. -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY - -#include "../test_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -namespace raft { - -using namespace std; - -template -struct LinkageInputs { - IdxT n_row; - IdxT n_col; - - std::vector data; - - std::vector expected_labels; - - int n_clusters; - - bool use_knn; - - int c; -}; - -/** - * @brief kernel to calculate the values of a and b - * @param firstClusterArray: the array of classes of type T - * @param secondClusterArray: the array of classes of type T - * @param size: the size of the data points - * @param a: number of pairs of points that both the clusters have classified the same - * @param b: number of pairs of points that both the clusters have classified differently - */ -template -RAFT_KERNEL computeTheNumerator( - const T* firstClusterArray, const T* secondClusterArray, uint64_t size, uint64_t* a, uint64_t* b) -{ - // calculating the indices of pairs of datapoints compared by the current thread - uint64_t j = threadIdx.x + blockIdx.x * blockDim.x; - uint64_t i = threadIdx.y + blockIdx.y * blockDim.y; - - // thread-local variables to count a and b - uint64_t myA = 0, myB = 0; - - if (i < size && j < size && j < i) { - // checking if the pair have been classified the same by both the clusters - if (firstClusterArray[i] == firstClusterArray[j] && - secondClusterArray[i] == secondClusterArray[j]) { - ++myA; - } - - // checking if the pair have been classified differently by both the clusters - else if (firstClusterArray[i] != firstClusterArray[j] && - secondClusterArray[i] != secondClusterArray[j]) { - ++myB; - } - } - - // specialize blockReduce for a 2D block of 1024 threads of type uint64_t - typedef cub::BlockReduce - BlockReduce; - - // Allocate shared memory for blockReduce - __shared__ typename BlockReduce::TempStorage temp_storage; - - // summing up thread-local counts specific to a block - myA = BlockReduce(temp_storage).Sum(myA); - __syncthreads(); - myB = BlockReduce(temp_storage).Sum(myB); - __syncthreads(); - - // executed once per block - if (threadIdx.x == 0 && threadIdx.y == 0) { - raft::myAtomicAdd((unsigned long long int*)a, myA); - raft::myAtomicAdd((unsigned long long int*)b, myB); - } -} - -/** - * @brief Function to calculate RandIndex - * more info on rand index - * @param firstClusterArray: the array of classes of type T - * @param secondClusterArray: the array of classes of type T - * @param size: the size of the data points of type uint64_t - * @param stream: the cudaStream object - */ -template -double compute_rand_index(T* firstClusterArray, - T* secondClusterArray, - uint64_t size, - cudaStream_t stream) -{ - // rand index for size less than 2 is not defined - ASSERT(size >= 2, "Rand Index for size less than 2 not defined!"); - - // allocating and initializing memory for a and b in the GPU - rmm::device_uvector arr_buf(2, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(arr_buf.data(), 0, 2 * sizeof(uint64_t), stream)); - - // kernel configuration - static const int BLOCK_DIM_Y = 16, BLOCK_DIM_X = 16; - dim3 numThreadsPerBlock(BLOCK_DIM_X, BLOCK_DIM_Y); - dim3 numBlocks(raft::ceildiv(size, numThreadsPerBlock.x), - raft::ceildiv(size, numThreadsPerBlock.y)); - - // calling the kernel - computeTheNumerator<<>>( - firstClusterArray, secondClusterArray, size, arr_buf.data(), arr_buf.data() + 1); - - // synchronizing and updating the calculated values of a and b from device to host - uint64_t ab_host[2] = {0}; - raft::update_host(ab_host, arr_buf.data(), 2, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - - // error handling - RAFT_CUDA_TRY(cudaGetLastError()); - - // denominator - uint64_t nChooseTwo = size * (size - 1) / 2; - - // calculating the rand_index - return (double)(((double)(ab_host[0] + ab_host[1])) / (double)nChooseTwo); -} - -template -::std::ostream& operator<<(::std::ostream& os, const LinkageInputs& dims) -{ - return os; -} - -template -class LinkageTest : public ::testing::TestWithParam> { - public: - LinkageTest() - : params(::testing::TestWithParam>::GetParam()), - labels(0, resource::get_cuda_stream(handle)), - labels_ref(0, resource::get_cuda_stream(handle)) - { - } - - protected: - void basicTest() - { - auto stream = resource::get_cuda_stream(handle); - - labels.resize(params.n_row, stream); - labels_ref.resize(params.n_row, stream); - rmm::device_uvector data(params.n_row * params.n_col, stream); - - raft::copy(data.data(), params.data.data(), data.size(), stream); - raft::copy(labels_ref.data(), params.expected_labels.data(), params.n_row, stream); - - rmm::device_uvector out_children(params.n_row * 2, stream); - - auto data_view = raft::make_device_matrix_view( - data.data(), params.n_row, params.n_col); - auto dendrogram_view = - raft::make_device_matrix_view(out_children.data(), params.n_row, 2); - auto labels_view = raft::make_device_vector_view(labels.data(), params.n_row); - - if (params.use_knn) { - raft::cluster::hierarchy:: - single_linkage( - handle, - data_view, - dendrogram_view, - labels_view, - raft::distance::DistanceType::L2SqrtExpanded, - params.n_clusters, - std::make_optional(params.c)); - - } else { - raft::cluster::hierarchy:: - single_linkage( - handle, - data_view, - dendrogram_view, - labels_view, - raft::distance::DistanceType::L2SqrtExpanded, - params.n_clusters, - std::make_optional(params.c)); - } - - resource::sync_stream(handle, stream); - - score = compute_rand_index(labels.data(), labels_ref.data(), params.n_row, stream); - } - - void SetUp() override { basicTest(); } - - protected: - raft::resources handle; - - LinkageInputs params; - rmm::device_uvector labels, labels_ref; - double score; -}; - -const std::vector> linkage_inputsf2 = { - // Test n_clusters == n_points - {10, - 5, - {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, 0.77782677, 0.43772379, - 0.4035871, 0.3282796, 0.47544681, 0.59862974, 0.12319357, 0.06239463, 0.28200272, 0.1345717, - 0.50498218, 0.5113505, 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, - 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, 0.84854131, 0.28890216, - 0.85267903, 0.74703138, 0.83842071, 0.34942792, 0.27864171, 0.70911132, 0.21338564, 0.32035554, - 0.73788331, 0.46926692, 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, - 0.76166195, 0.66613745}, - {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, - 10, - true, - -1}, - // // Test outlier points - {9, - 2, - {-1, -50, 3, 4, 5000, 10000, 1, 3, 4, 5, 0.000005, 0.00002, 2000000, 500000, 10, 50, 30, 5}, - {6, 0, 5, 0, 0, 4, 3, 2, 1}, - 7, - true, - -1}, - - // Test n_clusters == (n_points / 2) - {10, - 5, - {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, 0.77782677, 0.43772379, - 0.4035871, 0.3282796, 0.47544681, 0.59862974, 0.12319357, 0.06239463, 0.28200272, 0.1345717, - 0.50498218, 0.5113505, 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, - 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, 0.84854131, 0.28890216, - 0.85267903, 0.74703138, 0.83842071, 0.34942792, 0.27864171, 0.70911132, 0.21338564, 0.32035554, - 0.73788331, 0.46926692, 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, - 0.76166195, 0.66613745}, - {1, 0, 4, 0, 0, 3, 2, 0, 2, 1}, - 5, - true, - -1}, - - // Test n_points == 100 - {100, - 10, - {6.26168372e-01, 9.30437651e-01, 6.02450208e-01, 2.73025296e-01, 9.53050619e-01, 3.32164396e-01, - 6.88942598e-01, 5.79163537e-01, 6.70341547e-01, 2.70140602e-02, 9.30429671e-01, 7.17721157e-01, - 9.89948537e-01, 7.75253347e-01, 1.34491522e-02, 2.48522428e-02, 3.51413378e-01, 7.64405834e-01, - 7.86373507e-01, 7.18748577e-01, 8.66998621e-01, 6.80316582e-01, 2.51288712e-01, 4.91078420e-01, - 3.76246281e-01, 4.86828710e-01, 5.67464772e-01, 5.30734742e-01, 8.99478296e-01, 7.66699088e-01, - 9.49339111e-01, 3.55248484e-01, 9.06046929e-01, 4.48407772e-01, 6.96395305e-01, 2.44277335e-01, - 7.74840000e-01, 5.21046603e-01, 4.66423971e-02, 5.12019638e-02, 8.95019614e-01, 5.28956953e-01, - 4.31536306e-01, 5.83857744e-01, 4.41787364e-01, 4.68656523e-01, 5.73971433e-01, 6.79989654e-01, - 3.19650588e-01, 6.12579596e-01, 6.49126442e-02, 8.39131142e-01, 2.85252117e-01, 5.84848929e-01, - 9.46507115e-01, 8.58440748e-01, 3.61528940e-01, 2.44215959e-01, 3.80101125e-01, 4.57128957e-02, - 8.82216988e-01, 8.31498633e-01, 7.23474381e-01, 7.75788607e-01, 1.40864146e-01, 6.62092382e-01, - 5.13985168e-01, 3.00686418e-01, 8.70109949e-01, 2.43187753e-01, 2.89391938e-01, 2.84214238e-01, - 8.70985521e-01, 8.77491176e-01, 6.72537226e-01, 3.30929686e-01, 1.85934324e-01, 9.16222614e-01, - 6.18239142e-01, 2.64768597e-01, 5.76145451e-01, 8.62961369e-01, 6.84757925e-01, 7.60549082e-01, - 1.27645356e-01, 4.51004673e-01, 3.92292980e-01, 4.63170803e-01, 4.35449330e-02, 2.17583404e-01, - 5.71832605e-02, 2.06763039e-01, 3.70116249e-01, 2.09750028e-01, 6.17283019e-01, 8.62549231e-01, - 9.84156240e-02, 2.66249156e-01, 3.87635103e-01, 2.85591012e-02, 4.24826068e-01, 4.45795088e-01, - 6.86227676e-01, 1.08848960e-01, 5.96731841e-02, 3.71770228e-01, 1.91548833e-01, 6.95136078e-01, - 9.00700636e-01, 8.76363105e-01, 2.67334632e-01, 1.80619709e-01, 7.94060419e-01, 1.42854171e-02, - 1.09372387e-01, 8.74028108e-01, 6.46403232e-01, 4.86588834e-01, 5.93446175e-02, 6.11886291e-01, - 8.83865057e-01, 3.15879821e-01, 2.27043992e-01, 9.76764951e-01, 6.15620336e-01, 9.76199360e-01, - 2.40548962e-01, 3.21795663e-01, 8.75087904e-02, 8.11234663e-01, 6.96070480e-01, 8.12062321e-01, - 1.21958818e-01, 3.44348628e-02, 8.72630414e-01, 3.06162776e-01, 1.76043529e-02, 9.45894971e-01, - 5.33896401e-01, 6.21642973e-01, 4.93062535e-01, 4.48984262e-01, 2.24560379e-01, 4.24052195e-02, - 4.43447610e-01, 8.95646149e-01, 6.05220676e-01, 1.81840491e-01, 9.70831206e-01, 2.12563586e-02, - 6.92582693e-01, 7.55946922e-01, 7.95086143e-01, 6.05328941e-01, 3.99350764e-01, 4.32846636e-01, - 9.81114529e-01, 4.98266428e-01, 6.37127930e-03, 1.59085889e-01, 6.34682067e-05, 5.59429440e-01, - 7.38827633e-01, 8.93214770e-01, 2.16494306e-01, 9.35430573e-02, 4.75665868e-02, 7.80503518e-01, - 7.86240041e-01, 7.06854594e-01, 2.13725879e-02, 7.68246091e-01, 4.50234808e-01, 5.21231104e-01, - 5.01989826e-03, 4.22081572e-02, 1.65337732e-01, 8.54134740e-01, 4.99430262e-01, 8.94525601e-01, - 1.14028379e-01, 3.69739861e-01, 1.32955599e-01, 2.65563824e-01, 2.52811151e-01, 1.44792843e-01, - 6.88449594e-01, 4.44921417e-01, 8.23296587e-01, 1.93266317e-01, 1.19033309e-01, 1.36368966e-01, - 3.42600285e-01, 5.64505195e-01, 5.57594559e-01, 7.44257892e-01, 8.38231569e-02, 4.11548847e-01, - 3.21010077e-01, 8.55081359e-01, 4.30105779e-01, 1.16229135e-01, 9.87731964e-02, 3.14712335e-01, - 4.50880592e-01, 2.72289598e-01, 6.31615256e-01, 8.97432958e-01, 4.44764250e-01, 8.03776440e-01, - 2.68767748e-02, 2.43374608e-01, 4.02141103e-01, 4.98881209e-01, 5.33173003e-01, 8.82890436e-01, - 7.16149148e-01, 4.19664401e-01, 2.29335357e-01, 2.88637806e-01, 3.44696803e-01, 6.78171906e-01, - 5.69849716e-01, 5.86454477e-01, 3.54474989e-01, 9.03876540e-01, 6.45980000e-01, 6.34887593e-01, - 7.88039746e-02, 2.04814126e-01, 7.82251754e-01, 2.43147074e-01, 7.50951808e-01, 1.72799092e-02, - 2.95349590e-01, 6.57991826e-01, 8.81214312e-01, 5.73970708e-01, 2.77610881e-01, 1.82155097e-01, - 7.69797417e-02, 6.44792402e-01, 9.46950998e-01, 7.73064845e-01, 6.04733624e-01, 5.80094567e-01, - 1.67498426e-01, 2.66514296e-01, 6.50140368e-01, 1.91170299e-01, 2.08752199e-01, 3.01664091e-01, - 9.85033484e-01, 2.92909152e-01, 8.65816607e-01, 1.85222119e-01, 2.28814559e-01, 1.34286382e-02, - 2.89234322e-01, 8.18668708e-01, 4.71706924e-01, 9.23199803e-01, 2.80879188e-01, 1.47319284e-01, - 4.13915748e-01, 9.31274932e-02, 6.66322195e-01, 9.66953974e-01, 3.19405786e-01, 6.69486551e-01, - 5.03096313e-02, 6.95225201e-01, 5.78469859e-01, 6.29481655e-01, 1.39252534e-01, 1.22564968e-01, - 6.80663678e-01, 6.34607157e-01, 6.42765834e-01, 1.57127410e-02, 2.92132086e-01, 5.24423878e-01, - 4.68676824e-01, 2.86003928e-01, 7.18608322e-01, 8.95617933e-01, 5.48844309e-01, 1.74517278e-01, - 5.24379196e-01, 2.13526524e-01, 5.88375435e-01, 9.88560185e-01, 4.17435771e-01, 6.14438688e-01, - 9.53760881e-01, 5.27151288e-01, 7.03017278e-01, 3.44448559e-01, 4.47059676e-01, 2.83414901e-01, - 1.98979011e-01, 4.24917361e-01, 5.73172761e-01, 2.32398853e-02, 1.65887230e-01, 4.05552785e-01, - 9.29665524e-01, 2.26135696e-01, 9.20563384e-01, 7.65259963e-01, 4.54820075e-01, 8.97710267e-01, - 3.78559302e-03, 9.15219382e-01, 3.55705698e-01, 6.94905124e-01, 8.58540202e-01, 3.89790666e-01, - 2.49478206e-01, 7.93679304e-01, 4.75830027e-01, 4.40425353e-01, 3.70579459e-01, 1.40578049e-01, - 1.70386675e-01, 7.04056121e-01, 4.85963102e-01, 9.68450060e-01, 6.77178001e-01, 2.65934654e-01, - 2.58915007e-01, 6.70052890e-01, 2.61945109e-01, 8.46207759e-01, 1.01928951e-01, 2.85611334e-01, - 2.45776933e-01, 2.66658783e-01, 3.71724077e-01, 4.34319025e-01, 4.24407347e-01, 7.15417683e-01, - 8.07997684e-01, 1.64296275e-01, 6.01638065e-01, 8.60606804e-02, 2.68719187e-01, 5.11764101e-01, - 9.75844338e-01, 7.81226782e-01, 2.20925515e-01, 7.18135040e-01, 9.82395577e-01, 8.39160243e-01, - 9.08058083e-01, 6.88010677e-01, 8.14271847e-01, 5.12460821e-01, 1.17311345e-01, 5.96075228e-01, - 9.17455497e-01, 2.12052706e-01, 7.04074603e-01, 8.72872565e-02, 8.76047818e-01, 6.96235046e-01, - 8.54801557e-01, 2.49729159e-01, 9.76594604e-01, 2.87386363e-01, 2.36461559e-02, 9.94075254e-01, - 4.25193986e-01, 7.61869994e-01, 5.13334255e-01, 6.44711165e-02, 8.92156689e-01, 3.55235167e-01, - 1.08154647e-01, 8.78446825e-01, 2.43833016e-01, 9.23071293e-01, 2.72724115e-01, 9.46631338e-01, - 3.74510294e-01, 4.08451278e-02, 9.78392777e-01, 3.65079221e-01, 6.37199516e-01, 5.51144906e-01, - 5.25978080e-01, 1.42803678e-01, 4.05451674e-01, 7.79788219e-01, 6.26009784e-01, 3.35249497e-01, - 1.43159543e-02, 1.80363779e-01, 5.05096904e-01, 2.82619947e-01, 5.83561392e-01, 3.10951324e-01, - 8.73223968e-01, 4.38545619e-01, 4.81348800e-01, 6.68497085e-01, 3.79345401e-01, 9.58832501e-01, - 1.89869550e-01, 2.34083070e-01, 2.94066207e-01, 5.74892667e-02, 6.92106828e-02, 9.61127686e-02, - 6.72650672e-02, 8.47345378e-01, 2.80916761e-01, 7.32177357e-03, 9.80785961e-01, 5.73192225e-02, - 8.48781331e-01, 8.83225408e-01, 7.34398275e-01, 7.70381941e-01, 6.20778343e-01, 8.96822048e-01, - 5.40732486e-01, 3.69704071e-01, 5.77305837e-01, 2.08221827e-01, 7.34275341e-01, 1.06110900e-01, - 3.49496706e-01, 8.34948910e-01, 1.56403291e-02, 6.78576376e-01, 8.96141268e-01, 5.94835119e-01, - 1.43943153e-01, 3.49618530e-01, 2.10440392e-01, 3.46585620e-01, 1.05153093e-01, 3.45446174e-01, - 2.72177079e-01, 7.07946300e-01, 4.33717726e-02, 3.31232203e-01, 3.91874320e-01, 4.76338141e-01, - 6.22777789e-01, 2.95989228e-02, 4.32855769e-01, 7.61049310e-01, 3.63279149e-01, 9.47210350e-01, - 6.43721247e-01, 6.58025802e-01, 1.05247633e-02, 5.29974442e-01, 7.30675767e-01, 4.30041079e-01, - 6.62634841e-01, 8.25936616e-01, 9.91253704e-01, 6.79399281e-01, 5.44177006e-01, 7.52876048e-01, - 3.32139049e-01, 7.98732398e-01, 7.38865223e-01, 9.16055132e-01, 6.11736493e-01, 9.63672879e-01, - 1.83778839e-01, 7.27558919e-02, 5.91602822e-01, 3.25235484e-01, 2.34741217e-01, 9.52346277e-01, - 9.18556407e-01, 9.35373324e-01, 6.89209070e-01, 2.56049054e-01, 6.17975395e-01, 7.82285691e-01, - 9.84983432e-01, 6.62322741e-01, 2.04144457e-01, 3.98446577e-01, 1.38918297e-01, 3.05919921e-01, - 3.14043787e-01, 5.91072666e-01, 7.44703771e-01, 8.92272567e-01, 9.78017873e-01, 9.01203161e-01, - 1.41526372e-01, 4.14878484e-01, 6.80683651e-01, 5.01733152e-02, 8.14635389e-01, 2.27926375e-01, - 9.03269815e-01, 8.68443745e-01, 9.86939190e-01, 7.40779486e-01, 2.61005311e-01, 3.19276232e-01, - 9.69509248e-01, 1.11908818e-01, 4.49198556e-01, 1.27056715e-01, 3.84064823e-01, 5.14591811e-01, - 2.10747488e-01, 9.53884090e-01, 8.43167950e-01, 4.51187972e-01, 3.75331782e-01, 6.23566461e-01, - 3.55290379e-01, 2.95705968e-01, 1.69622690e-01, 1.42981830e-01, 2.72180991e-01, 9.46468040e-01, - 3.70932500e-01, 9.94292830e-01, 4.62587505e-01, 7.14817405e-01, 2.45370540e-02, 3.00906377e-01, - 5.75768304e-01, 9.71448393e-01, 6.95574827e-02, 3.93693854e-01, 5.29306116e-01, 5.04694554e-01, - 6.73797120e-02, 6.76596969e-01, 5.50948898e-01, 3.24909641e-01, 7.70337719e-01, 6.51842631e-03, - 3.03264879e-01, 7.61037886e-03, 2.72289601e-01, 1.50502041e-01, 6.71103888e-02, 7.41503703e-01, - 1.92088941e-01, 2.19043977e-01, 9.09320161e-01, 2.37993569e-01, 6.18107973e-02, 8.31447852e-01, - 2.23355609e-01, 1.84789435e-01, 4.16104518e-01, 4.21573859e-01, 8.72446305e-02, 2.97294197e-01, - 4.50328256e-01, 8.72199917e-01, 2.51279916e-01, 4.86219272e-01, 7.57071329e-01, 4.85655942e-01, - 1.06187277e-01, 4.92341327e-01, 1.46017513e-01, 5.25421017e-01, 4.22637906e-01, 2.24685018e-01, - 8.72648431e-01, 5.54051490e-01, 1.80745062e-01, 2.12756336e-01, 5.20883169e-01, 7.60363654e-01, - 8.30254678e-01, 5.00003328e-01, 4.69017439e-01, 6.38105527e-01, 3.50638261e-02, 5.22217353e-02, - 9.06516882e-02, 8.52975842e-01, 1.19985883e-01, 3.74926753e-01, 6.50302066e-01, 1.98875727e-01, - 6.28362507e-02, 4.32693501e-01, 3.10500685e-01, 6.20732833e-01, 4.58503272e-01, 3.20790034e-01, - 7.91284868e-01, 7.93054570e-01, 2.93406765e-01, 8.95399023e-01, 1.06441034e-01, 7.53085241e-02, - 8.67523104e-01, 1.47963482e-01, 1.25584706e-01, 3.81545040e-02, 6.34338619e-01, 1.76368938e-02, - 5.75553531e-02, 5.31607516e-01, 2.63869588e-01, 9.41945823e-01, 9.24028838e-02, 5.21496463e-01, - 7.74866558e-01, 5.65210610e-01, 7.28015327e-02, 6.51963790e-01, 8.94727453e-01, 4.49571590e-01, - 1.29932405e-01, 8.64026259e-01, 9.92599934e-01, 7.43721560e-01, 8.87300215e-01, 1.06369925e-01, - 8.11335531e-01, 7.87734900e-01, 9.87344678e-01, 5.32502820e-01, 4.42612382e-01, 9.64041183e-01, - 1.66085871e-01, 1.12937664e-01, 5.24423470e-01, 6.54689333e-01, 4.59119726e-01, 5.22774091e-01, - 3.08722276e-02, 6.26979315e-01, 4.49754105e-01, 8.07495757e-01, 2.34199499e-01, 1.67765675e-01, - 9.22168418e-01, 3.73210378e-01, 8.04432575e-01, 5.61890354e-01, 4.47025593e-01, 6.43155678e-01, - 2.40407640e-01, 5.91631279e-01, 1.59369206e-01, 7.75799090e-01, 8.32067212e-01, 5.59791576e-02, - 6.39105224e-01, 4.85274738e-01, 2.12630838e-01, 2.81431312e-02, 7.16205363e-01, 6.83885011e-01, - 5.23869697e-01, 9.99418314e-01, 8.35331599e-01, 4.69877463e-02, 6.74712562e-01, 7.99273684e-01, - 2.77001890e-02, 5.75809742e-01, 2.78513031e-01, 8.36209905e-01, 7.25472379e-01, 4.87173943e-01, - 7.88311357e-01, 9.64676177e-01, 1.75752651e-01, 4.98112580e-01, 8.08850418e-02, 6.40981131e-01, - 4.06647450e-01, 8.46539387e-01, 2.12620694e-01, 9.11012851e-01, 8.25041445e-01, 8.90065575e-01, - 9.63626055e-01, 5.96689242e-01, 1.63372670e-01, 4.51640148e-01, 3.43026542e-01, 5.80658851e-01, - 2.82327625e-01, 4.75535418e-01, 6.27760926e-01, 8.46314115e-01, 9.61961932e-01, 3.19806094e-01, - 5.05508062e-01, 5.28102944e-01, 6.13045057e-01, 7.44714938e-01, 1.50586073e-01, 7.91878033e-01, - 4.89839179e-01, 3.10496849e-01, 8.82309038e-01, 2.86922314e-01, 4.84687559e-01, 5.20838630e-01, - 4.62955493e-01, 2.38185305e-01, 5.47259907e-02, 7.10916137e-01, 7.31887202e-01, 6.25602317e-01, - 8.77741168e-01, 4.19881322e-01, 4.81222328e-01, 1.28224501e-01, 2.46034010e-01, 3.34971854e-01, - 7.37216484e-01, 5.62134821e-02, 7.14089724e-01, 9.85549393e-01, 4.66295827e-01, 3.08722434e-03, - 4.70237690e-01, 2.66524167e-01, 7.93875484e-01, 4.54795911e-02, 8.09702944e-01, 1.47709735e-02, - 1.70082405e-01, 6.35905179e-01, 3.75379109e-01, 4.30315011e-01, 3.15788760e-01, 5.58065230e-01, - 2.24643800e-01, 2.42142981e-01, 6.57283636e-01, 3.34921891e-01, 1.26588975e-01, 7.68064155e-01, - 9.43856291e-01, 4.47518596e-01, 5.44453573e-01, 9.95764932e-01, 7.16444391e-01, 8.51019765e-01, - 1.01179183e-01, 4.45473958e-01, 4.60327322e-01, 4.96895844e-02, 4.72907738e-01, 5.58987444e-01, - 3.41027487e-01, 1.56175026e-01, 7.58283148e-01, 6.83600909e-01, 2.14623396e-01, 3.27348880e-01, - 3.92517893e-01, 6.70418431e-01, 5.16440832e-01, 8.63140348e-01, 5.73277464e-01, 3.46608058e-01, - 7.39396341e-01, 7.20852434e-01, 2.35653246e-02, 3.89935659e-01, 7.53783745e-01, 6.34563528e-01, - 8.79339335e-01, 7.41599159e-02, 5.62433904e-01, 6.15553852e-01, 4.56956324e-01, 5.20047447e-01, - 5.26845015e-02, 5.58471266e-01, 1.63632233e-01, 5.38936665e-02, 6.49593683e-01, 2.56838748e-01, - 8.99035326e-01, 7.20847756e-01, 5.68954684e-01, 7.43684755e-01, 5.70924238e-01, 3.82318724e-01, - 4.89328290e-01, 5.62208561e-01, 4.97540804e-02, 4.18011085e-01, 6.88041565e-01, 2.16234653e-01, - 7.89548214e-01, 8.46136387e-01, 8.46816189e-01, 1.73842353e-01, 6.11627842e-02, 8.44440559e-01, - 4.50646654e-01, 3.74785037e-01, 4.87196697e-01, 4.56276448e-01, 9.13284391e-01, 4.15715464e-01, - 7.13597697e-01, 1.23641270e-02, 5.10031271e-01, 4.74601930e-02, 2.55731159e-01, 3.22090006e-01, - 1.91165703e-01, 4.51170940e-01, 7.50843157e-01, 4.42420576e-01, 4.25380660e-01, 4.50667257e-01, - 6.55689206e-01, 9.68257670e-02, 1.96528793e-01, 8.97343028e-01, 4.99940904e-01, 6.65504083e-01, - 9.41828079e-01, 4.54397338e-01, 5.61893331e-01, 5.09839880e-01, 4.53117514e-01, 8.96804127e-02, - 1.74888861e-01, 6.65641378e-01, 2.81668336e-01, 1.89532742e-01, 5.61668382e-01, 8.68330157e-02, - 8.25092797e-01, 5.18106324e-01, 1.71904024e-01, 3.68385523e-01, 1.62005436e-01, 7.48507399e-01, - 9.30274827e-01, 2.38198517e-01, 9.52222901e-01, 5.23587800e-01, 6.94384557e-01, 1.09338652e-01, - 4.83356794e-01, 2.73050402e-01, 3.68027050e-01, 5.92366466e-01, 1.83192289e-01, 8.60376029e-01, - 7.13926203e-01, 8.16750052e-01, 1.57890291e-01, 6.25691951e-01, 5.24831646e-01, 1.73873797e-01, - 1.02429784e-01, 9.17488471e-01, 4.03584434e-01, 9.31170884e-01, 2.79386137e-01, 8.77745206e-01, - 2.45200576e-01, 1.28896951e-01, 3.15713052e-01, 5.27874291e-01, 2.16444335e-01, 7.03883817e-01, - 7.74738919e-02, 8.42422142e-01, 3.75598924e-01, 3.51002411e-01, 6.22752776e-01, 4.82407943e-01, - 7.43107867e-01, 9.46182666e-01, 9.44344819e-01, 3.28124763e-01, 1.06147431e-01, 1.65102684e-01, - 3.84060507e-01, 2.91057722e-01, 7.68173662e-02, 1.03543651e-01, 6.76698940e-01, 1.43141994e-01, - 7.21342202e-01, 6.69471294e-03, 9.07298311e-01, 5.57080171e-01, 8.10954489e-01, 4.11120526e-01, - 2.06407453e-01, 2.59590556e-01, 7.58512718e-01, 5.79873897e-01, 2.92875650e-01, 2.83686529e-01, - 2.42829343e-01, 9.19323719e-01, 3.46832864e-01, 3.58238858e-01, 7.42827585e-01, 2.05760059e-01, - 9.58438860e-01, 5.66326411e-01, 6.60292846e-01, 5.61095078e-02, 6.79465531e-01, 7.05118513e-01, - 4.44713264e-01, 2.09732933e-01, 5.22732436e-01, 1.74396512e-01, 5.29356748e-01, 4.38475687e-01, - 4.94036404e-01, 4.09785794e-01, 6.40025507e-01, 5.79371821e-01, 1.57726118e-01, 6.04572263e-01, - 5.41072639e-01, 5.18847173e-01, 1.97093284e-01, 8.91767002e-01, 4.29050835e-01, 8.25490570e-01, - 3.87699807e-01, 4.50705808e-01, 2.49371643e-01, 3.36074898e-01, 9.29925118e-01, 6.65393649e-01, - 9.07275994e-01, 3.73075859e-01, 4.14044139e-03, 2.37463702e-01, 2.25893784e-01, 2.46900245e-01, - 4.50350196e-01, 3.48618117e-01, 5.07193932e-01, 5.23435142e-01, 8.13611417e-01, 8.92715622e-01, - 1.02623450e-01, 3.06088345e-01, 7.80461650e-01, 2.21453645e-01, 2.01419652e-01, 2.84254457e-01, - 3.68286735e-01, 7.39358243e-01, 8.97879394e-01, 9.81599566e-01, 7.56526442e-01, 7.37645545e-01, - 4.23976657e-02, 8.25922012e-01, 2.60956996e-01, 2.90702065e-01, 8.98388344e-01, 3.03733299e-01, - 8.49071471e-01, 3.45835425e-01, 7.65458276e-01, 5.68094872e-01, 8.93770930e-01, 9.93161641e-01, - 5.63368667e-02, 4.26548945e-01, 5.46745780e-01, 5.75674571e-01, 7.94599487e-01, 7.18935553e-02, - 4.46492976e-01, 6.40240123e-01, 2.73246969e-01, 2.00465968e-01, 1.30718835e-01, 1.92492005e-01, - 1.96617189e-01, 6.61271644e-01, 8.12687657e-01, 8.66342445e-01 - - }, - {0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 4, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - 10, - true, - -4}, - {10, - 5, - {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, 0.77782677, 0.43772379, - 0.4035871, 0.3282796, 0.47544681, 0.59862974, 0.12319357, 0.06239463, 0.28200272, 0.1345717, - 0.50498218, 0.5113505, 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, - 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, 0.84854131, 0.28890216, - 0.85267903, 0.74703138, 0.83842071, 0.34942792, 0.27864171, 0.70911132, 0.21338564, 0.32035554, - 0.73788331, 0.46926692, 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, - 0.76166195, 0.66613745}, - {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, - 10, - false, - 5}, - // Test outlier points - {9, - 2, - {-1, -50, 3, 4, 5000, 10000, 1, 3, 4, 5, 0.000005, 0.00002, 2000000, 500000, 10, 50, 30, 5}, - {6, 0, 5, 0, 0, 4, 3, 2, 1}, - 7, - false, - 5}, - - // Test n_clusters == (n_points / 2) - {10, - 5, - {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, 0.77782677, 0.43772379, - 0.4035871, 0.3282796, 0.47544681, 0.59862974, 0.12319357, 0.06239463, 0.28200272, 0.1345717, - 0.50498218, 0.5113505, 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, - 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, 0.84854131, 0.28890216, - 0.85267903, 0.74703138, 0.83842071, 0.34942792, 0.27864171, 0.70911132, 0.21338564, 0.32035554, - 0.73788331, 0.46926692, 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, - 0.76166195, 0.66613745}, - {1, 0, 4, 0, 0, 3, 2, 0, 2, 1}, - 5, - false, - 10}, - - // Test n_points == 100 - {100, - 10, - {6.26168372e-01, 9.30437651e-01, 6.02450208e-01, 2.73025296e-01, 9.53050619e-01, 3.32164396e-01, - 6.88942598e-01, 5.79163537e-01, 6.70341547e-01, 2.70140602e-02, 9.30429671e-01, 7.17721157e-01, - 9.89948537e-01, 7.75253347e-01, 1.34491522e-02, 2.48522428e-02, 3.51413378e-01, 7.64405834e-01, - 7.86373507e-01, 7.18748577e-01, 8.66998621e-01, 6.80316582e-01, 2.51288712e-01, 4.91078420e-01, - 3.76246281e-01, 4.86828710e-01, 5.67464772e-01, 5.30734742e-01, 8.99478296e-01, 7.66699088e-01, - 9.49339111e-01, 3.55248484e-01, 9.06046929e-01, 4.48407772e-01, 6.96395305e-01, 2.44277335e-01, - 7.74840000e-01, 5.21046603e-01, 4.66423971e-02, 5.12019638e-02, 8.95019614e-01, 5.28956953e-01, - 4.31536306e-01, 5.83857744e-01, 4.41787364e-01, 4.68656523e-01, 5.73971433e-01, 6.79989654e-01, - 3.19650588e-01, 6.12579596e-01, 6.49126442e-02, 8.39131142e-01, 2.85252117e-01, 5.84848929e-01, - 9.46507115e-01, 8.58440748e-01, 3.61528940e-01, 2.44215959e-01, 3.80101125e-01, 4.57128957e-02, - 8.82216988e-01, 8.31498633e-01, 7.23474381e-01, 7.75788607e-01, 1.40864146e-01, 6.62092382e-01, - 5.13985168e-01, 3.00686418e-01, 8.70109949e-01, 2.43187753e-01, 2.89391938e-01, 2.84214238e-01, - 8.70985521e-01, 8.77491176e-01, 6.72537226e-01, 3.30929686e-01, 1.85934324e-01, 9.16222614e-01, - 6.18239142e-01, 2.64768597e-01, 5.76145451e-01, 8.62961369e-01, 6.84757925e-01, 7.60549082e-01, - 1.27645356e-01, 4.51004673e-01, 3.92292980e-01, 4.63170803e-01, 4.35449330e-02, 2.17583404e-01, - 5.71832605e-02, 2.06763039e-01, 3.70116249e-01, 2.09750028e-01, 6.17283019e-01, 8.62549231e-01, - 9.84156240e-02, 2.66249156e-01, 3.87635103e-01, 2.85591012e-02, 4.24826068e-01, 4.45795088e-01, - 6.86227676e-01, 1.08848960e-01, 5.96731841e-02, 3.71770228e-01, 1.91548833e-01, 6.95136078e-01, - 9.00700636e-01, 8.76363105e-01, 2.67334632e-01, 1.80619709e-01, 7.94060419e-01, 1.42854171e-02, - 1.09372387e-01, 8.74028108e-01, 6.46403232e-01, 4.86588834e-01, 5.93446175e-02, 6.11886291e-01, - 8.83865057e-01, 3.15879821e-01, 2.27043992e-01, 9.76764951e-01, 6.15620336e-01, 9.76199360e-01, - 2.40548962e-01, 3.21795663e-01, 8.75087904e-02, 8.11234663e-01, 6.96070480e-01, 8.12062321e-01, - 1.21958818e-01, 3.44348628e-02, 8.72630414e-01, 3.06162776e-01, 1.76043529e-02, 9.45894971e-01, - 5.33896401e-01, 6.21642973e-01, 4.93062535e-01, 4.48984262e-01, 2.24560379e-01, 4.24052195e-02, - 4.43447610e-01, 8.95646149e-01, 6.05220676e-01, 1.81840491e-01, 9.70831206e-01, 2.12563586e-02, - 6.92582693e-01, 7.55946922e-01, 7.95086143e-01, 6.05328941e-01, 3.99350764e-01, 4.32846636e-01, - 9.81114529e-01, 4.98266428e-01, 6.37127930e-03, 1.59085889e-01, 6.34682067e-05, 5.59429440e-01, - 7.38827633e-01, 8.93214770e-01, 2.16494306e-01, 9.35430573e-02, 4.75665868e-02, 7.80503518e-01, - 7.86240041e-01, 7.06854594e-01, 2.13725879e-02, 7.68246091e-01, 4.50234808e-01, 5.21231104e-01, - 5.01989826e-03, 4.22081572e-02, 1.65337732e-01, 8.54134740e-01, 4.99430262e-01, 8.94525601e-01, - 1.14028379e-01, 3.69739861e-01, 1.32955599e-01, 2.65563824e-01, 2.52811151e-01, 1.44792843e-01, - 6.88449594e-01, 4.44921417e-01, 8.23296587e-01, 1.93266317e-01, 1.19033309e-01, 1.36368966e-01, - 3.42600285e-01, 5.64505195e-01, 5.57594559e-01, 7.44257892e-01, 8.38231569e-02, 4.11548847e-01, - 3.21010077e-01, 8.55081359e-01, 4.30105779e-01, 1.16229135e-01, 9.87731964e-02, 3.14712335e-01, - 4.50880592e-01, 2.72289598e-01, 6.31615256e-01, 8.97432958e-01, 4.44764250e-01, 8.03776440e-01, - 2.68767748e-02, 2.43374608e-01, 4.02141103e-01, 4.98881209e-01, 5.33173003e-01, 8.82890436e-01, - 7.16149148e-01, 4.19664401e-01, 2.29335357e-01, 2.88637806e-01, 3.44696803e-01, 6.78171906e-01, - 5.69849716e-01, 5.86454477e-01, 3.54474989e-01, 9.03876540e-01, 6.45980000e-01, 6.34887593e-01, - 7.88039746e-02, 2.04814126e-01, 7.82251754e-01, 2.43147074e-01, 7.50951808e-01, 1.72799092e-02, - 2.95349590e-01, 6.57991826e-01, 8.81214312e-01, 5.73970708e-01, 2.77610881e-01, 1.82155097e-01, - 7.69797417e-02, 6.44792402e-01, 9.46950998e-01, 7.73064845e-01, 6.04733624e-01, 5.80094567e-01, - 1.67498426e-01, 2.66514296e-01, 6.50140368e-01, 1.91170299e-01, 2.08752199e-01, 3.01664091e-01, - 9.85033484e-01, 2.92909152e-01, 8.65816607e-01, 1.85222119e-01, 2.28814559e-01, 1.34286382e-02, - 2.89234322e-01, 8.18668708e-01, 4.71706924e-01, 9.23199803e-01, 2.80879188e-01, 1.47319284e-01, - 4.13915748e-01, 9.31274932e-02, 6.66322195e-01, 9.66953974e-01, 3.19405786e-01, 6.69486551e-01, - 5.03096313e-02, 6.95225201e-01, 5.78469859e-01, 6.29481655e-01, 1.39252534e-01, 1.22564968e-01, - 6.80663678e-01, 6.34607157e-01, 6.42765834e-01, 1.57127410e-02, 2.92132086e-01, 5.24423878e-01, - 4.68676824e-01, 2.86003928e-01, 7.18608322e-01, 8.95617933e-01, 5.48844309e-01, 1.74517278e-01, - 5.24379196e-01, 2.13526524e-01, 5.88375435e-01, 9.88560185e-01, 4.17435771e-01, 6.14438688e-01, - 9.53760881e-01, 5.27151288e-01, 7.03017278e-01, 3.44448559e-01, 4.47059676e-01, 2.83414901e-01, - 1.98979011e-01, 4.24917361e-01, 5.73172761e-01, 2.32398853e-02, 1.65887230e-01, 4.05552785e-01, - 9.29665524e-01, 2.26135696e-01, 9.20563384e-01, 7.65259963e-01, 4.54820075e-01, 8.97710267e-01, - 3.78559302e-03, 9.15219382e-01, 3.55705698e-01, 6.94905124e-01, 8.58540202e-01, 3.89790666e-01, - 2.49478206e-01, 7.93679304e-01, 4.75830027e-01, 4.40425353e-01, 3.70579459e-01, 1.40578049e-01, - 1.70386675e-01, 7.04056121e-01, 4.85963102e-01, 9.68450060e-01, 6.77178001e-01, 2.65934654e-01, - 2.58915007e-01, 6.70052890e-01, 2.61945109e-01, 8.46207759e-01, 1.01928951e-01, 2.85611334e-01, - 2.45776933e-01, 2.66658783e-01, 3.71724077e-01, 4.34319025e-01, 4.24407347e-01, 7.15417683e-01, - 8.07997684e-01, 1.64296275e-01, 6.01638065e-01, 8.60606804e-02, 2.68719187e-01, 5.11764101e-01, - 9.75844338e-01, 7.81226782e-01, 2.20925515e-01, 7.18135040e-01, 9.82395577e-01, 8.39160243e-01, - 9.08058083e-01, 6.88010677e-01, 8.14271847e-01, 5.12460821e-01, 1.17311345e-01, 5.96075228e-01, - 9.17455497e-01, 2.12052706e-01, 7.04074603e-01, 8.72872565e-02, 8.76047818e-01, 6.96235046e-01, - 8.54801557e-01, 2.49729159e-01, 9.76594604e-01, 2.87386363e-01, 2.36461559e-02, 9.94075254e-01, - 4.25193986e-01, 7.61869994e-01, 5.13334255e-01, 6.44711165e-02, 8.92156689e-01, 3.55235167e-01, - 1.08154647e-01, 8.78446825e-01, 2.43833016e-01, 9.23071293e-01, 2.72724115e-01, 9.46631338e-01, - 3.74510294e-01, 4.08451278e-02, 9.78392777e-01, 3.65079221e-01, 6.37199516e-01, 5.51144906e-01, - 5.25978080e-01, 1.42803678e-01, 4.05451674e-01, 7.79788219e-01, 6.26009784e-01, 3.35249497e-01, - 1.43159543e-02, 1.80363779e-01, 5.05096904e-01, 2.82619947e-01, 5.83561392e-01, 3.10951324e-01, - 8.73223968e-01, 4.38545619e-01, 4.81348800e-01, 6.68497085e-01, 3.79345401e-01, 9.58832501e-01, - 1.89869550e-01, 2.34083070e-01, 2.94066207e-01, 5.74892667e-02, 6.92106828e-02, 9.61127686e-02, - 6.72650672e-02, 8.47345378e-01, 2.80916761e-01, 7.32177357e-03, 9.80785961e-01, 5.73192225e-02, - 8.48781331e-01, 8.83225408e-01, 7.34398275e-01, 7.70381941e-01, 6.20778343e-01, 8.96822048e-01, - 5.40732486e-01, 3.69704071e-01, 5.77305837e-01, 2.08221827e-01, 7.34275341e-01, 1.06110900e-01, - 3.49496706e-01, 8.34948910e-01, 1.56403291e-02, 6.78576376e-01, 8.96141268e-01, 5.94835119e-01, - 1.43943153e-01, 3.49618530e-01, 2.10440392e-01, 3.46585620e-01, 1.05153093e-01, 3.45446174e-01, - 2.72177079e-01, 7.07946300e-01, 4.33717726e-02, 3.31232203e-01, 3.91874320e-01, 4.76338141e-01, - 6.22777789e-01, 2.95989228e-02, 4.32855769e-01, 7.61049310e-01, 3.63279149e-01, 9.47210350e-01, - 6.43721247e-01, 6.58025802e-01, 1.05247633e-02, 5.29974442e-01, 7.30675767e-01, 4.30041079e-01, - 6.62634841e-01, 8.25936616e-01, 9.91253704e-01, 6.79399281e-01, 5.44177006e-01, 7.52876048e-01, - 3.32139049e-01, 7.98732398e-01, 7.38865223e-01, 9.16055132e-01, 6.11736493e-01, 9.63672879e-01, - 1.83778839e-01, 7.27558919e-02, 5.91602822e-01, 3.25235484e-01, 2.34741217e-01, 9.52346277e-01, - 9.18556407e-01, 9.35373324e-01, 6.89209070e-01, 2.56049054e-01, 6.17975395e-01, 7.82285691e-01, - 9.84983432e-01, 6.62322741e-01, 2.04144457e-01, 3.98446577e-01, 1.38918297e-01, 3.05919921e-01, - 3.14043787e-01, 5.91072666e-01, 7.44703771e-01, 8.92272567e-01, 9.78017873e-01, 9.01203161e-01, - 1.41526372e-01, 4.14878484e-01, 6.80683651e-01, 5.01733152e-02, 8.14635389e-01, 2.27926375e-01, - 9.03269815e-01, 8.68443745e-01, 9.86939190e-01, 7.40779486e-01, 2.61005311e-01, 3.19276232e-01, - 9.69509248e-01, 1.11908818e-01, 4.49198556e-01, 1.27056715e-01, 3.84064823e-01, 5.14591811e-01, - 2.10747488e-01, 9.53884090e-01, 8.43167950e-01, 4.51187972e-01, 3.75331782e-01, 6.23566461e-01, - 3.55290379e-01, 2.95705968e-01, 1.69622690e-01, 1.42981830e-01, 2.72180991e-01, 9.46468040e-01, - 3.70932500e-01, 9.94292830e-01, 4.62587505e-01, 7.14817405e-01, 2.45370540e-02, 3.00906377e-01, - 5.75768304e-01, 9.71448393e-01, 6.95574827e-02, 3.93693854e-01, 5.29306116e-01, 5.04694554e-01, - 6.73797120e-02, 6.76596969e-01, 5.50948898e-01, 3.24909641e-01, 7.70337719e-01, 6.51842631e-03, - 3.03264879e-01, 7.61037886e-03, 2.72289601e-01, 1.50502041e-01, 6.71103888e-02, 7.41503703e-01, - 1.92088941e-01, 2.19043977e-01, 9.09320161e-01, 2.37993569e-01, 6.18107973e-02, 8.31447852e-01, - 2.23355609e-01, 1.84789435e-01, 4.16104518e-01, 4.21573859e-01, 8.72446305e-02, 2.97294197e-01, - 4.50328256e-01, 8.72199917e-01, 2.51279916e-01, 4.86219272e-01, 7.57071329e-01, 4.85655942e-01, - 1.06187277e-01, 4.92341327e-01, 1.46017513e-01, 5.25421017e-01, 4.22637906e-01, 2.24685018e-01, - 8.72648431e-01, 5.54051490e-01, 1.80745062e-01, 2.12756336e-01, 5.20883169e-01, 7.60363654e-01, - 8.30254678e-01, 5.00003328e-01, 4.69017439e-01, 6.38105527e-01, 3.50638261e-02, 5.22217353e-02, - 9.06516882e-02, 8.52975842e-01, 1.19985883e-01, 3.74926753e-01, 6.50302066e-01, 1.98875727e-01, - 6.28362507e-02, 4.32693501e-01, 3.10500685e-01, 6.20732833e-01, 4.58503272e-01, 3.20790034e-01, - 7.91284868e-01, 7.93054570e-01, 2.93406765e-01, 8.95399023e-01, 1.06441034e-01, 7.53085241e-02, - 8.67523104e-01, 1.47963482e-01, 1.25584706e-01, 3.81545040e-02, 6.34338619e-01, 1.76368938e-02, - 5.75553531e-02, 5.31607516e-01, 2.63869588e-01, 9.41945823e-01, 9.24028838e-02, 5.21496463e-01, - 7.74866558e-01, 5.65210610e-01, 7.28015327e-02, 6.51963790e-01, 8.94727453e-01, 4.49571590e-01, - 1.29932405e-01, 8.64026259e-01, 9.92599934e-01, 7.43721560e-01, 8.87300215e-01, 1.06369925e-01, - 8.11335531e-01, 7.87734900e-01, 9.87344678e-01, 5.32502820e-01, 4.42612382e-01, 9.64041183e-01, - 1.66085871e-01, 1.12937664e-01, 5.24423470e-01, 6.54689333e-01, 4.59119726e-01, 5.22774091e-01, - 3.08722276e-02, 6.26979315e-01, 4.49754105e-01, 8.07495757e-01, 2.34199499e-01, 1.67765675e-01, - 9.22168418e-01, 3.73210378e-01, 8.04432575e-01, 5.61890354e-01, 4.47025593e-01, 6.43155678e-01, - 2.40407640e-01, 5.91631279e-01, 1.59369206e-01, 7.75799090e-01, 8.32067212e-01, 5.59791576e-02, - 6.39105224e-01, 4.85274738e-01, 2.12630838e-01, 2.81431312e-02, 7.16205363e-01, 6.83885011e-01, - 5.23869697e-01, 9.99418314e-01, 8.35331599e-01, 4.69877463e-02, 6.74712562e-01, 7.99273684e-01, - 2.77001890e-02, 5.75809742e-01, 2.78513031e-01, 8.36209905e-01, 7.25472379e-01, 4.87173943e-01, - 7.88311357e-01, 9.64676177e-01, 1.75752651e-01, 4.98112580e-01, 8.08850418e-02, 6.40981131e-01, - 4.06647450e-01, 8.46539387e-01, 2.12620694e-01, 9.11012851e-01, 8.25041445e-01, 8.90065575e-01, - 9.63626055e-01, 5.96689242e-01, 1.63372670e-01, 4.51640148e-01, 3.43026542e-01, 5.80658851e-01, - 2.82327625e-01, 4.75535418e-01, 6.27760926e-01, 8.46314115e-01, 9.61961932e-01, 3.19806094e-01, - 5.05508062e-01, 5.28102944e-01, 6.13045057e-01, 7.44714938e-01, 1.50586073e-01, 7.91878033e-01, - 4.89839179e-01, 3.10496849e-01, 8.82309038e-01, 2.86922314e-01, 4.84687559e-01, 5.20838630e-01, - 4.62955493e-01, 2.38185305e-01, 5.47259907e-02, 7.10916137e-01, 7.31887202e-01, 6.25602317e-01, - 8.77741168e-01, 4.19881322e-01, 4.81222328e-01, 1.28224501e-01, 2.46034010e-01, 3.34971854e-01, - 7.37216484e-01, 5.62134821e-02, 7.14089724e-01, 9.85549393e-01, 4.66295827e-01, 3.08722434e-03, - 4.70237690e-01, 2.66524167e-01, 7.93875484e-01, 4.54795911e-02, 8.09702944e-01, 1.47709735e-02, - 1.70082405e-01, 6.35905179e-01, 3.75379109e-01, 4.30315011e-01, 3.15788760e-01, 5.58065230e-01, - 2.24643800e-01, 2.42142981e-01, 6.57283636e-01, 3.34921891e-01, 1.26588975e-01, 7.68064155e-01, - 9.43856291e-01, 4.47518596e-01, 5.44453573e-01, 9.95764932e-01, 7.16444391e-01, 8.51019765e-01, - 1.01179183e-01, 4.45473958e-01, 4.60327322e-01, 4.96895844e-02, 4.72907738e-01, 5.58987444e-01, - 3.41027487e-01, 1.56175026e-01, 7.58283148e-01, 6.83600909e-01, 2.14623396e-01, 3.27348880e-01, - 3.92517893e-01, 6.70418431e-01, 5.16440832e-01, 8.63140348e-01, 5.73277464e-01, 3.46608058e-01, - 7.39396341e-01, 7.20852434e-01, 2.35653246e-02, 3.89935659e-01, 7.53783745e-01, 6.34563528e-01, - 8.79339335e-01, 7.41599159e-02, 5.62433904e-01, 6.15553852e-01, 4.56956324e-01, 5.20047447e-01, - 5.26845015e-02, 5.58471266e-01, 1.63632233e-01, 5.38936665e-02, 6.49593683e-01, 2.56838748e-01, - 8.99035326e-01, 7.20847756e-01, 5.68954684e-01, 7.43684755e-01, 5.70924238e-01, 3.82318724e-01, - 4.89328290e-01, 5.62208561e-01, 4.97540804e-02, 4.18011085e-01, 6.88041565e-01, 2.16234653e-01, - 7.89548214e-01, 8.46136387e-01, 8.46816189e-01, 1.73842353e-01, 6.11627842e-02, 8.44440559e-01, - 4.50646654e-01, 3.74785037e-01, 4.87196697e-01, 4.56276448e-01, 9.13284391e-01, 4.15715464e-01, - 7.13597697e-01, 1.23641270e-02, 5.10031271e-01, 4.74601930e-02, 2.55731159e-01, 3.22090006e-01, - 1.91165703e-01, 4.51170940e-01, 7.50843157e-01, 4.42420576e-01, 4.25380660e-01, 4.50667257e-01, - 6.55689206e-01, 9.68257670e-02, 1.96528793e-01, 8.97343028e-01, 4.99940904e-01, 6.65504083e-01, - 9.41828079e-01, 4.54397338e-01, 5.61893331e-01, 5.09839880e-01, 4.53117514e-01, 8.96804127e-02, - 1.74888861e-01, 6.65641378e-01, 2.81668336e-01, 1.89532742e-01, 5.61668382e-01, 8.68330157e-02, - 8.25092797e-01, 5.18106324e-01, 1.71904024e-01, 3.68385523e-01, 1.62005436e-01, 7.48507399e-01, - 9.30274827e-01, 2.38198517e-01, 9.52222901e-01, 5.23587800e-01, 6.94384557e-01, 1.09338652e-01, - 4.83356794e-01, 2.73050402e-01, 3.68027050e-01, 5.92366466e-01, 1.83192289e-01, 8.60376029e-01, - 7.13926203e-01, 8.16750052e-01, 1.57890291e-01, 6.25691951e-01, 5.24831646e-01, 1.73873797e-01, - 1.02429784e-01, 9.17488471e-01, 4.03584434e-01, 9.31170884e-01, 2.79386137e-01, 8.77745206e-01, - 2.45200576e-01, 1.28896951e-01, 3.15713052e-01, 5.27874291e-01, 2.16444335e-01, 7.03883817e-01, - 7.74738919e-02, 8.42422142e-01, 3.75598924e-01, 3.51002411e-01, 6.22752776e-01, 4.82407943e-01, - 7.43107867e-01, 9.46182666e-01, 9.44344819e-01, 3.28124763e-01, 1.06147431e-01, 1.65102684e-01, - 3.84060507e-01, 2.91057722e-01, 7.68173662e-02, 1.03543651e-01, 6.76698940e-01, 1.43141994e-01, - 7.21342202e-01, 6.69471294e-03, 9.07298311e-01, 5.57080171e-01, 8.10954489e-01, 4.11120526e-01, - 2.06407453e-01, 2.59590556e-01, 7.58512718e-01, 5.79873897e-01, 2.92875650e-01, 2.83686529e-01, - 2.42829343e-01, 9.19323719e-01, 3.46832864e-01, 3.58238858e-01, 7.42827585e-01, 2.05760059e-01, - 9.58438860e-01, 5.66326411e-01, 6.60292846e-01, 5.61095078e-02, 6.79465531e-01, 7.05118513e-01, - 4.44713264e-01, 2.09732933e-01, 5.22732436e-01, 1.74396512e-01, 5.29356748e-01, 4.38475687e-01, - 4.94036404e-01, 4.09785794e-01, 6.40025507e-01, 5.79371821e-01, 1.57726118e-01, 6.04572263e-01, - 5.41072639e-01, 5.18847173e-01, 1.97093284e-01, 8.91767002e-01, 4.29050835e-01, 8.25490570e-01, - 3.87699807e-01, 4.50705808e-01, 2.49371643e-01, 3.36074898e-01, 9.29925118e-01, 6.65393649e-01, - 9.07275994e-01, 3.73075859e-01, 4.14044139e-03, 2.37463702e-01, 2.25893784e-01, 2.46900245e-01, - 4.50350196e-01, 3.48618117e-01, 5.07193932e-01, 5.23435142e-01, 8.13611417e-01, 8.92715622e-01, - 1.02623450e-01, 3.06088345e-01, 7.80461650e-01, 2.21453645e-01, 2.01419652e-01, 2.84254457e-01, - 3.68286735e-01, 7.39358243e-01, 8.97879394e-01, 9.81599566e-01, 7.56526442e-01, 7.37645545e-01, - 4.23976657e-02, 8.25922012e-01, 2.60956996e-01, 2.90702065e-01, 8.98388344e-01, 3.03733299e-01, - 8.49071471e-01, 3.45835425e-01, 7.65458276e-01, 5.68094872e-01, 8.93770930e-01, 9.93161641e-01, - 5.63368667e-02, 4.26548945e-01, 5.46745780e-01, 5.75674571e-01, 7.94599487e-01, 7.18935553e-02, - 4.46492976e-01, 6.40240123e-01, 2.73246969e-01, 2.00465968e-01, 1.30718835e-01, 1.92492005e-01, - 1.96617189e-01, 6.61271644e-01, 8.12687657e-01, 8.66342445e-01 - - }, - {0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 4, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - 10, - false, - 5}}; - -typedef LinkageTest LinkageTestF_Int; -TEST_P(LinkageTestF_Int, Result) { EXPECT_TRUE(score == 1.0); } - -INSTANTIATE_TEST_CASE_P(LinkageTest, LinkageTestF_Int, ::testing::ValuesIn(linkage_inputsf2)); -} // end namespace raft diff --git a/cpp/test/cluster/spectral.cu b/cpp/test/cluster/spectral.cu deleted file mode 100644 index b8ee611f29..0000000000 --- a/cpp/test/cluster/spectral.cu +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" - -#include -#include -#include - -#include - -#include -#include - -namespace raft { -namespace cluster { - -/** - * Warning: There appears to be a CUDA 12.2 bug in cusparse that causes an - * alignment issue. We've fixed the bug in our code through a workaround - * (see raft/sparse/linalg/spmm.hpp for fix). This test is meant to fail - * in the case where the fix is accidentally reverted, so that it doesn't - * break any downstream libraries that depend on RAFT - */ -TEST(Raft, Spectral) -{ - raft::handle_t handle; - - std::vector h_offsets({0, 2, 4, 7, 10, 12, 14}); - std::vector h_indices({1, 2, 0, 2, 0, 1, 3, 2, 4, 5, 3, 5, 3, 4}); - std::vector h_values( - {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); - std::vector expected_clustering({1, 1, 1, 0, 0, 0}); - - int32_t n_clusters{2}; - int32_t n_eigenvectors{2}; - int32_t evs_max_it{100}; - int32_t kmean_max_it{100}; - int32_t restartIter_lanczos = 15 + n_eigenvectors; - float evs_tol{0.001}; - float kmean_tol{0.001}; - unsigned long long seed1{1234567}; - unsigned long long seed2{12345678}; - bool reorthog{false}; - - rmm::device_uvector offsets(h_offsets.size(), handle.get_stream()); - rmm::device_uvector indices(h_indices.size(), handle.get_stream()); - rmm::device_uvector values(h_indices.size(), handle.get_stream()); - rmm::device_uvector clustering(expected_clustering.size(), handle.get_stream()); - rmm::device_uvector eigenvalues(n_eigenvectors, handle.get_stream()); - rmm::device_uvector eigenvectors(n_eigenvectors * expected_clustering.size(), - handle.get_stream()); - - rmm::device_uvector exp_dev(expected_clustering.size(), handle.get_stream()); - - raft::update_device( - exp_dev.data(), expected_clustering.data(), expected_clustering.size(), handle.get_stream()); - - raft::update_device(offsets.data(), h_offsets.data(), h_offsets.size(), handle.get_stream()); - raft::update_device(indices.data(), h_indices.data(), h_indices.size(), handle.get_stream()); - raft::update_device(values.data(), h_values.data(), h_values.size(), handle.get_stream()); - - raft::spectral::matrix::sparse_matrix_t const matrix{ - handle, - offsets.data(), - indices.data(), - values.data(), - static_cast(offsets.size() - 1), - static_cast(indices.size())}; - - raft::spectral::eigen_solver_config_t eig_cfg{ - n_eigenvectors, evs_max_it, restartIter_lanczos, evs_tol, reorthog, seed1}; - raft::spectral::lanczos_solver_t eig_solver{eig_cfg}; - - raft::spectral::cluster_solver_config_t clust_cfg{ - n_clusters, kmean_max_it, kmean_tol, seed2}; - raft::spectral::kmeans_solver_t cluster_solver{clust_cfg}; - - raft::spectral::partition(handle, - matrix, - eig_solver, - cluster_solver, - clustering.data(), - eigenvalues.data(), - eigenvectors.data()); - - ASSERT_TRUE(devArrMatch(expected_clustering.data(), - exp_dev.data(), - exp_dev.size(), - 1, - raft::Compare(), - handle.get_stream())); -} - -} // namespace cluster -} // namespace raft \ No newline at end of file diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu deleted file mode 100644 index a22fb7b1f9..0000000000 --- a/cpp/test/distance/dist_adj.cu +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "dist_adj.cuh" - -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace raft { -namespace distance { - -template -RAFT_KERNEL naiveDistanceAdjKernel(uint8_t* dist, - const DataType* x, - const DataType* y, - int m, - int n, - int k, - DataType eps, - bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto diff = x[xidx] - y[yidx]; - acc += diff * diff; - } - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc <= eps; -} - -template -void naiveDistanceAdj(uint8_t* dist, - const DataType* x, - const DataType* y, - int m, - int n, - int k, - DataType eps, - bool isRowMajor, - cudaStream_t stream) -{ - static const dim3 TPB(16, 32, 1); - dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); - naiveDistanceAdjKernel<<>>(dist, x, y, m, n, k, eps, isRowMajor); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -struct DistanceAdjInputs { - DataType eps; - int m, n, k; - bool isRowMajor; - unsigned long long int seed; -}; - -template -::std::ostream& operator<<(::std::ostream& os, const DistanceAdjInputs& dims) -{ - return os; -} - -template -class DistanceAdjTest : public ::testing::TestWithParam> { - public: - DistanceAdjTest() - : params(::testing::TestWithParam>::GetParam()), - stream(resource::get_cuda_stream(handle)), - dist(params.m * params.n, stream), - dist_ref(params.m * params.n, stream) - { - } - - void SetUp() override - { - raft::random::RngState r(params.seed); - int m = params.m; - int n = params.n; - int k = params.k; - bool isRowMajor = params.isRowMajor; - - rmm::device_uvector x(m * k, stream); - rmm::device_uvector y(n * k, stream); - - uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); - uniform(handle, r, y.data(), n * k, DataType(-1.0), DataType(1.0)); - - DataType threshold = params.eps; - - naiveDistanceAdj(dist_ref.data(), x.data(), y.data(), m, n, k, threshold, isRowMajor, stream); - size_t worksize = raft::distance:: - getWorkspaceSize( - x.data(), y.data(), m, n, k); - rmm::device_uvector workspace(worksize, stream); - - using threshold_final_op_ = threshold_final_op; - threshold_final_op_ threshold_op(threshold); - - raft::distance::distance(handle, - x.data(), - y.data(), - dist.data(), - m, - n, - k, - workspace.data(), - worksize, - threshold_op, - isRowMajor); - resource::sync_stream(handle, stream); - } - - void TearDown() override {} - - protected: - DistanceAdjInputs params; - // We use uint8_t even if the output in this test is a bool because - // cutlass doesn't support bool as output buffer yet. In cuda - // sizeof(bool) is 1 byte hence it doesn't increase - // memory consumption if we use uint8_t instead of bool. - rmm::device_uvector dist_ref; - rmm::device_uvector dist; - raft::resources handle; - cudaStream_t stream; -}; - -const std::vector> inputsf = { - {0.01f, 1024, 1024, 32, true, 1234ULL}, - {0.1f, 1024, 1024, 32, true, 1234ULL}, - {1.0f, 1024, 1024, 32, true, 1234ULL}, - {10.0f, 1024, 1024, 32, true, 1234ULL}, - {0.01f, 1024, 1024, 32, false, 1234ULL}, - {0.1f, 1024, 1024, 32, false, 1234ULL}, - {1.0f, 1024, 1024, 32, false, 1234ULL}, - {10.0f, 1024, 1024, 32, false, 1234ULL}, -}; -typedef DistanceAdjTest DistanceAdjTestF; -TEST_P(DistanceAdjTestF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.01, 1024, 1024, 32, true, 1234ULL}, - {0.1, 1024, 1024, 32, true, 1234ULL}, - {1.0, 1024, 1024, 32, true, 1234ULL}, - {10.0, 1024, 1024, 32, true, 1234ULL}, - {0.01, 1024, 1024, 32, false, 1234ULL}, - {0.1, 1024, 1024, 32, false, 1234ULL}, - {1.0, 1024, 1024, 32, false, 1234ULL}, - {10.0, 1024, 1024, 32, false, 1234ULL}, -}; -typedef DistanceAdjTest DistanceAdjTestD; -TEST_P(DistanceAdjTestD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare(), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestD, ::testing::ValuesIn(inputsd)); - -} // namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_adj.cuh b/cpp/test/distance/dist_adj.cuh deleted file mode 100644 index 2861cb33de..0000000000 --- a/cpp/test/distance/dist_adj.cuh +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dist_adj_threshold.cuh" - -#include - -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ - extern template void raft::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - FinalLambda fin_op, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, - float, - float, - uint8_t, - raft::distance::threshold_float, - int); - -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, - double, - double, - uint8_t, - raft::distance::threshold_double, - int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ - extern template size_t raft::distance::getWorkspaceSize( \ - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) - -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ - extern template size_t raft::distance::getWorkspaceSize( \ - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) - -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); - -#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_distance_instance.cu b/cpp/test/distance/dist_adj_distance_instance.cu deleted file mode 100644 index 158a5986c2..0000000000 --- a/cpp/test/distance/dist_adj_distance_instance.cu +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY - -#include "dist_adj_threshold.cuh" - -#include - -#include - -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ - template void raft::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - FinalLambda fin_op, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, - float, - float, - uint8_t, - raft::distance::threshold_float, - int); - -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, - double, - double, - uint8_t, - raft::distance::threshold_double, - int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ - template size_t raft::distance::getWorkspaceSize( \ - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) - -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); - -#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_threshold.cuh b/cpp/test/distance/dist_adj_threshold.cuh deleted file mode 100644 index 78663b3cd1..0000000000 --- a/cpp/test/distance/dist_adj_threshold.cuh +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // uint8_t - -namespace raft::distance { - -template -struct threshold_final_op { - DataT threshold_val; - - __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} - __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} - __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept - { - return d_val <= threshold_val; - } -}; - -using threshold_float = threshold_final_op; -using threshold_double = threshold_final_op; - -} // namespace raft::distance diff --git a/cpp/test/distance/dist_canberra.cu b/cpp/test/distance/dist_canberra.cu deleted file mode 100644 index 9b8b6c016b..0000000000 --- a/cpp/test/distance/dist_canberra.cu +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceCanberra : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceCanberra DistanceCanberraF; -TEST_P(DistanceCanberraF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceCanberra DistanceCanberraD; -TEST_P(DistanceCanberraD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraD, ::testing::ValuesIn(inputsd)); - -class BigMatrixCanberra : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixCanberra, Result) {} - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu deleted file mode 100644 index aa2866483a..0000000000 --- a/cpp/test/distance/dist_correlation.cu +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceCorrelation - : public DistanceTest {}; - -template -class DistanceCorrelationXequalY - : public DistanceTestSameBuffer {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceCorrelation DistanceCorrelationF; -TEST_P(DistanceCorrelationF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf)); - -typedef DistanceCorrelationXequalY DistanceCorrelationXequalYF; -TEST_P(DistanceCorrelationXequalYF, Result) -{ - int m = params.m; - ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), - dist[0].data(), - m, - m, - raft::CompareApprox(params.tolerance), - stream)); - ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), - dist[1].data(), - m / 2, - m, - raft::CompareApprox(params.tolerance), - stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationXequalYF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceCorrelation DistanceCorrelationD; -TEST_P(DistanceCorrelationD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationD, ::testing::ValuesIn(inputsd)); - -class BigMatrixCorrelation - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixCorrelation, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu deleted file mode 100644 index b792ec4039..0000000000 --- a/cpp/test/distance/dist_cos.cu +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceExpCos : public DistanceTest { -}; - -template -class DistanceExpCosXequalY - : public DistanceTestSameBuffer {}; - -const std::vector> inputsf = { - {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; - -const std::vector> inputsXeqYf = { - {0.01f, 1024, 1024, 32, true, 1234ULL}, - {0.01f, 1024, 32, 1024, true, 1234ULL}, - {0.01f, 32, 1024, 1024, true, 1234ULL}, - {0.03f, 1024, 1024, 1024, true, 1234ULL}, - {0.01f, 1024, 1024, 32, false, 1234ULL}, - {0.01f, 1024, 32, 1024, false, 1234ULL}, - {0.01f, 32, 1024, 1024, false, 1234ULL}, - {0.03f, 1024, 1024, 1024, false, 1234ULL}, -}; - -typedef DistanceExpCos DistanceExpCosF; -TEST_P(DistanceExpCosF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf)); - -typedef DistanceExpCosXequalY DistanceExpCosXequalYF; -TEST_P(DistanceExpCosXequalYF, Result) -{ - int m = params.m; - int n = params.m; - ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), - dist[0].data(), - m, - n, - raft::CompareApprox(params.tolerance), - stream)); - n = params.isRowMajor ? m : m / 2; - m = params.isRowMajor ? m / 2 : m; - - ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), - dist[1].data(), - m, - n, - raft::CompareApprox(params.tolerance), - stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYF, ::testing::ValuesIn(inputsXeqYf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceExpCos DistanceExpCosD; -TEST_P(DistanceExpCosD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosD, ::testing::ValuesIn(inputsd)); - -class BigMatrixCos : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixCos, Result) {} - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_dice.cu b/cpp/test/distance/dist_dice.cu deleted file mode 100644 index e127659dc6..0000000000 --- a/cpp/test/distance/dist_dice.cu +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceExpDice : public DistanceTest { -}; - -template -class DistanceExpDiceXequalY - : public DistanceTestSameBuffer {}; - -const std::vector> inputsf = { - {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; - -const std::vector> inputsXeqYf = { - {0.01f, 1024, 1024, 32, true, 1234ULL}, - {0.01f, 1024, 32, 1024, true, 1234ULL}, - {0.01f, 32, 1024, 1024, true, 1234ULL}, - {0.03f, 1024, 1024, 1024, true, 1234ULL}, - {0.01f, 1024, 1024, 32, false, 1234ULL}, - {0.01f, 1024, 32, 1024, false, 1234ULL}, - {0.01f, 32, 1024, 1024, false, 1234ULL}, - {0.03f, 1024, 1024, 1024, false, 1234ULL}, -}; - -typedef DistanceExpDice DistanceExpDiceF; -TEST_P(DistanceExpDiceF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApproxNaN(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceF, ::testing::ValuesIn(inputsf)); - -typedef DistanceExpDiceXequalY DistanceExpDiceXequalYF; -TEST_P(DistanceExpDiceXequalYF, Result) -{ - int m = params.m; - int n = params.m; - ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), - dist[0].data(), - m, - n, - raft::CompareApproxNaN(params.tolerance), - stream)); - n = params.isRowMajor ? m : m / 2; - m = params.isRowMajor ? m / 2 : m; - - ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), - dist[1].data(), - m, - n, - raft::CompareApproxNaN(params.tolerance), - stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceXequalYF, ::testing::ValuesIn(inputsXeqYf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceExpDice DistanceExpDiceD; -TEST_P(DistanceExpDiceD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApproxNaN(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceD, ::testing::ValuesIn(inputsd)); - -class BigMatrixDice : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixDice, Result) {} - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_hamming.cu b/cpp/test/distance/dist_hamming.cu deleted file mode 100644 index 9529ec2eaa..0000000000 --- a/cpp/test/distance/dist_hamming.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceHamming - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceHamming DistanceHammingF; -TEST_P(DistanceHammingF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceHamming DistanceHammingD; -TEST_P(DistanceHammingD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingD, ::testing::ValuesIn(inputsd)); - -class BigMatrixHamming - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixHamming, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_hellinger.cu b/cpp/test/distance/dist_hellinger.cu deleted file mode 100644 index 93d6101a18..0000000000 --- a/cpp/test/distance/dist_hellinger.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceHellingerExp - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceHellingerExp DistanceHellingerExpF; -TEST_P(DistanceHellingerExpF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceHellingerExp DistanceHellingerExpD; -TEST_P(DistanceHellingerExpD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpD, ::testing::ValuesIn(inputsd)); - -class BigMatrixHellingerExp - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixHellingerExp, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_inner_product.cu b/cpp/test/distance/dist_inner_product.cu deleted file mode 100644 index 8dd7ef0874..0000000000 --- a/cpp/test/distance/dist_inner_product.cu +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceInnerProduct - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 10, 5, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceInnerProduct DistanceInnerProductF; -TEST_P(DistanceInnerProductF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceInnerProduct DistanceInnerProductD; -TEST_P(DistanceInnerProductD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductD, ::testing::ValuesIn(inputsd)); - -class BigMatrixInnerProduct - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixInnerProduct, Result) {} - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu deleted file mode 100644 index e0e256c925..0000000000 --- a/cpp/test/distance/dist_jensen_shannon.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceJensenShannon - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceJensenShannon DistanceJensenShannonF; -TEST_P(DistanceJensenShannonF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceJensenShannon DistanceJensenShannonD; -TEST_P(DistanceJensenShannonD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonD, ::testing::ValuesIn(inputsd)); - -class BigMatrixJensenShannon - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixJensenShannon, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu deleted file mode 100644 index 1f79ebcad4..0000000000 --- a/cpp/test/distance/dist_kl_divergence.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceKLDivergence - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceKLDivergence DistanceKLDivergenceF; -TEST_P(DistanceKLDivergenceF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceKLDivergence DistanceKLDivergenceD; -TEST_P(DistanceKLDivergenceD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceD, ::testing::ValuesIn(inputsd)); - -class BigMatrixKLDivergence - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixKLDivergence, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_l1.cu b/cpp/test/distance/dist_l1.cu deleted file mode 100644 index ce62a4aeec..0000000000 --- a/cpp/test/distance/dist_l1.cu +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceUnexpL1 : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceUnexpL1 DistanceUnexpL1F; -TEST_P(DistanceUnexpL1F, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1F, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceUnexpL1 DistanceUnexpL1D; -TEST_P(DistanceUnexpL1D, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1D, ::testing::ValuesIn(inputsd)); - -class BigMatrixUnexpL1 : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixUnexpL1, Result) {} - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu deleted file mode 100644 index 0203d9ed9d..0000000000 --- a/cpp/test/distance/dist_l2_exp.cu +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceEucExpTest : public DistanceTest { -}; - -template -class DistanceEucExpTestXequalY - : public DistanceTestSameBuffer {}; - -const std::vector> inputsf = { - {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, - {0.001f, 2048, 4096, 128, true, 1234ULL}, - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.003f, 1021, 1021, 1021, true, 1234ULL}, - {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, - {0.003f, 1021, 1021, 1021, false, 1234ULL}, -}; - -const std::vector> inputsXeqYf = { - {0.01f, 2048, 4096, 128, true, 1234ULL}, - {0.01f, 1024, 1024, 32, true, 1234ULL}, - {0.01f, 1024, 32, 1024, true, 1234ULL}, - {0.01f, 32, 1024, 1024, true, 1234ULL}, - {0.03f, 1024, 1024, 1024, true, 1234ULL}, - {0.03f, 1021, 1021, 1021, true, 1234ULL}, - {0.01f, 1024, 1024, 32, false, 1234ULL}, - {0.01f, 1024, 32, 1024, false, 1234ULL}, - {0.01f, 32, 1024, 1024, false, 1234ULL}, - {0.03f, 1024, 1024, 1024, false, 1234ULL}, - {0.03f, 1021, 1021, 1021, false, 1234ULL}, -}; - -typedef DistanceEucExpTest DistanceEucExpTestF; -TEST_P(DistanceEucExpTestF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf)); - -typedef DistanceEucExpTestXequalY DistanceEucExpTestXequalYF; -TEST_P(DistanceEucExpTestXequalYF, Result) -{ - int m = params.m; - ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), - dist[0].data(), - m, - m, - raft::CompareApprox(params.tolerance), - stream)); - ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), - dist[1].data(), - m / 2, - m, - raft::CompareApprox(params.tolerance), - stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, - DistanceEucExpTestXequalYF, - ::testing::ValuesIn(inputsXeqYf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceEucExpTest DistanceEucExpTestD; -TEST_P(DistanceEucExpTestD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestD, ::testing::ValuesIn(inputsd)); - -class BigMatrixEucExp : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixEucExp, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_l2_sqrt_exp.cu b/cpp/test/distance/dist_l2_sqrt_exp.cu deleted file mode 100644 index 5bccabcc3f..0000000000 --- a/cpp/test/distance/dist_l2_sqrt_exp.cu +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceEucSqrtExpTest - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 2048, 4096, 128, true, 1234ULL}, - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.003f, 1021, 1021, 1021, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, - {0.003f, 1021, 1021, 1021, false, 1234ULL}, -}; -typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestF; -TEST_P(DistanceEucSqrtExpTestF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestD; -TEST_P(DistanceEucSqrtExpTestD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestD, ::testing::ValuesIn(inputsd)); - -class BigMatrixEucSqrtExp - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixEucSqrtExp, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_l2_unexp.cu b/cpp/test/distance/dist_l2_unexp.cu deleted file mode 100644 index 19b0ff6dbf..0000000000 --- a/cpp/test/distance/dist_l2_unexp.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceEucUnexpTest - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceEucUnexpTest DistanceEucUnexpTestF; -TEST_P(DistanceEucUnexpTestF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceEucUnexpTest DistanceEucUnexpTestD; -TEST_P(DistanceEucUnexpTestD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestD, ::testing::ValuesIn(inputsd)); - -class BigMatrixEucUnexp : public BigMatrixDistanceTest { -}; -TEST_F(BigMatrixEucUnexp, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_l_inf.cu b/cpp/test/distance/dist_l_inf.cu deleted file mode 100644 index 223d186a8d..0000000000 --- a/cpp/test/distance/dist_l_inf.cu +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceLinf : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceLinf DistanceLinfF; -TEST_P(DistanceLinfF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceLinf DistanceLinfD; -TEST_P(DistanceLinfD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfD, ::testing::ValuesIn(inputsd)); - -class BigMatrixLinf : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixLinf, Result) {} - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_lp_unexp.cu b/cpp/test/distance/dist_lp_unexp.cu deleted file mode 100644 index 9d6f5921a7..0000000000 --- a/cpp/test/distance/dist_lp_unexp.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceLpUnexp : public DistanceTest { -}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL, 4.0f}, - {0.001f, 1024, 32, 1024, true, 1234ULL, 3.0f}, - {0.001f, 32, 1024, 1024, true, 1234ULL, 4.0f}, - {0.003f, 1024, 1024, 1024, true, 1234ULL, 3.0f}, - {0.001f, 1024, 1024, 32, false, 1234ULL, 4.0f}, - {0.001f, 1024, 32, 1024, false, 1234ULL, 3.0f}, - {0.001f, 32, 1024, 1024, false, 1234ULL, 4.0f}, - {0.003f, 1024, 1024, 1024, false, 1234ULL, 3.0f}, -}; -typedef DistanceLpUnexp DistanceLpUnexpF; -TEST_P(DistanceLpUnexpF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL, 4.0}, - {0.001, 1024, 32, 1024, true, 1234ULL, 3.0}, - {0.001, 32, 1024, 1024, true, 1234ULL, 4.0}, - {0.003, 1024, 1024, 1024, true, 1234ULL, 3.0}, - {0.001, 1024, 1024, 32, false, 1234ULL, 4.0}, - {0.001, 1024, 32, 1024, false, 1234ULL, 3.0}, - {0.001, 32, 1024, 1024, false, 1234ULL, 4.0}, - {0.003, 1024, 1024, 1024, false, 1234ULL, 3.0}, -}; -typedef DistanceLpUnexp DistanceLpUnexpD; -TEST_P(DistanceLpUnexpD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpD, ::testing::ValuesIn(inputsd)); - -class BigMatrixLpUnexp : public BigMatrixDistanceTest { -}; -TEST_F(BigMatrixLpUnexp, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/dist_russell_rao.cu b/cpp/test/distance/dist_russell_rao.cu deleted file mode 100644 index 73cf4b33a4..0000000000 --- a/cpp/test/distance/dist_russell_rao.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "distance_base.cuh" - -namespace raft { -namespace distance { - -template -class DistanceRussellRao - : public DistanceTest {}; - -const std::vector> inputsf = { - {0.001f, 1024, 1024, 32, true, 1234ULL}, - {0.001f, 1024, 32, 1024, true, 1234ULL}, - {0.001f, 32, 1024, 1024, true, 1234ULL}, - {0.003f, 1024, 1024, 1024, true, 1234ULL}, - {0.001f, 1024, 1024, 32, false, 1234ULL}, - {0.001f, 1024, 32, 1024, false, 1234ULL}, - {0.001f, 32, 1024, 1024, false, 1234ULL}, - {0.003f, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceRussellRao DistanceRussellRaoF; -TEST_P(DistanceRussellRaoF, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.001, 1024, 1024, 32, true, 1234ULL}, - {0.001, 1024, 32, 1024, true, 1234ULL}, - {0.001, 32, 1024, 1024, true, 1234ULL}, - {0.003, 1024, 1024, 1024, true, 1234ULL}, - {0.001, 1024, 1024, 32, false, 1234ULL}, - {0.001, 1024, 32, 1024, false, 1234ULL}, - {0.001, 32, 1024, 1024, false, 1234ULL}, - {0.003, 1024, 1024, 1024, false, 1234ULL}, -}; -typedef DistanceRussellRao DistanceRussellRaoD; -TEST_P(DistanceRussellRaoD, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_TRUE(raft::devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoD, ::testing::ValuesIn(inputsd)); - -class BigMatrixRussellRao - : public BigMatrixDistanceTest {}; -TEST_F(BigMatrixRussellRao, Result) {} -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh deleted file mode 100644 index f44fb18519..0000000000 --- a/cpp/test/distance/distance_base.cuh +++ /dev/null @@ -1,708 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" - -#include // common::nvtx::range -#include // make_device_matrix_view -#include // raft::sqrt -#include -#include // raft::resources -#include -#include // raft::distance::DistanceType -#include - -#include // rmm::device_uvector - -#include - -namespace raft { -namespace distance { - -template -RAFT_KERNEL naiveDistanceKernel(DataType* dist, - const DataType* x, - const DataType* y, - int m, - int n, - int k, - raft::distance::DistanceType type, - bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto diff = x[xidx] - y[yidx]; - acc += diff * diff; - } - if (type == raft::distance::DistanceType::L2SqrtExpanded || - type == raft::distance::DistanceType::L2SqrtUnexpanded) - acc = raft::sqrt(acc); - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(DataType* dist, - const DataType* x, - const DataType* y, - int m, - int n, - int k, - raft::distance::DistanceType type, - bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) { return; } - - DataType acc = DataType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - auto diff = (a > b) ? (a - b) : (b - a); - if (type == raft::distance::DistanceType::Linf) { - acc = raft::max(acc, diff); - } else if (type == raft::distance::DistanceType::Canberra) { - const auto add = raft::abs(a) + raft::abs(b); - // deal with potential for 0 in denominator by - // forcing 1/0 instead - acc += ((add != 0) * diff / (add + (add == 0))); - } else { - acc += diff; - } - } - - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -RAFT_KERNEL naiveDiceDistanceKernel( - DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) { return; } - - DataType acc_a = DataType(0); - DataType acc_b = DataType(0); - DataType acc_ab = DataType(0); - - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - acc_a += a; - acc_b += b; - acc_ab += a * b; - } - - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - - // Use 1.0 - (dice dissimilarity) to calc the distance - dist[outidx] = (DataType)1.0 - (2 * acc_ab / ((acc_a) + (acc_b))); -} - -template -RAFT_KERNEL naiveCosineDistanceKernel( - DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) { return; } - - DataType acc_a = DataType(0); - DataType acc_b = DataType(0); - DataType acc_ab = DataType(0); - - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - acc_a += a * a; - acc_b += b * b; - acc_ab += a * b; - } - - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - - // Use 1.0 - (cosine similarity) to calc the distance - dist[outidx] = (DataType)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); -} - -template -RAFT_KERNEL naiveInnerProductKernel( - DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) { return; } - - DataType acc_ab = DataType(0); - - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - acc_ab += a * b; - } - - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc_ab; -} - -template -RAFT_KERNEL naiveHellingerDistanceKernel( - DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) { return; } - - DataType acc_ab = DataType(0); - - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - acc_ab += raft::sqrt(a) * raft::sqrt(b); - } - - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - - // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - acc_ab = 1 - acc_ab; - auto rectifier = (!signbit(acc_ab)); - dist[outidx] = raft::sqrt(rectifier * acc_ab); -} - -template -RAFT_KERNEL naiveLpUnexpDistanceKernel(DataType* dist, - const DataType* x, - const DataType* y, - int m, - int n, - int k, - bool isRowMajor, - DataType p) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - auto diff = raft::abs(a - b); - acc += raft::pow(diff, p); - } - auto one_over_p = 1 / p; - acc = raft::pow(acc, one_over_p); - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -RAFT_KERNEL naiveHammingDistanceKernel( - DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - acc += (a != b); - } - acc = acc / k; - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -RAFT_KERNEL naiveJensenShannonDistanceKernel( - DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - DataType acc = DataType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - - DataType m = 0.5f * (a + b); - bool a_zero = a == 0; - bool b_zero = b == 0; - - DataType p = (!a_zero * m) / (a_zero + a); - DataType q = (!b_zero * m) / (b_zero + b); - - bool p_zero = p == 0; - bool q_zero = q == 0; - - acc += (-a * (!p_zero * log(p + p_zero))) + (-b * (!q_zero * log(q + q_zero))); - } - acc = raft::sqrt(0.5f * acc); - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -RAFT_KERNEL naiveRussellRaoDistanceKernel( - OutType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - OutType acc = OutType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - acc += (a * b); - } - acc = (k - acc) / k; - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -RAFT_KERNEL naiveKLDivergenceDistanceKernel( - OutType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - OutType acc = OutType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - bool b_zero = (b == 0); - bool a_zero = (a == 0); - acc += a * (log(a + a_zero) - log(b + b_zero)); - } - acc = 0.5f * acc; - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -RAFT_KERNEL naiveCorrelationDistanceKernel( - OutType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) -{ - int midx = threadIdx.x + blockIdx.x * blockDim.x; - int nidx = threadIdx.y + blockIdx.y * blockDim.y; - if (midx >= m || nidx >= n) return; - OutType acc = OutType(0); - auto a_norm = DataType(0); - auto b_norm = DataType(0); - auto a_sq_norm = DataType(0); - auto b_sq_norm = DataType(0); - for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - a_norm += a; - b_norm += b; - a_sq_norm += (a * a); - b_sq_norm += (b * b); - acc += (a * b); - } - - auto numer = k * acc - (a_norm * b_norm); - auto Q_denom = k * a_sq_norm - (a_norm * a_norm); - auto R_denom = k * b_sq_norm - (b_norm * b_norm); - - acc = 1 - (numer / raft::sqrt(Q_denom * R_denom)); - - int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - dist[outidx] = acc; -} - -template -void naiveDistance(DataType* dist, - const DataType* x, - const DataType* y, - int m, - int n, - int k, - raft::distance::DistanceType type, - bool isRowMajor, - DataType metric_arg = 2.0f, - cudaStream_t stream = 0) -{ - static const dim3 TPB(4, 256, 1); - dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); - - switch (type) { - case raft::distance::DistanceType::Canberra: - case raft::distance::DistanceType::Linf: - case raft::distance::DistanceType::L1: - naiveL1_Linf_CanberraDistanceKernel - <<>>(dist, x, y, m, n, k, type, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Expanded: - naiveDistanceKernel - <<>>(dist, x, y, m, n, k, type, isRowMajor); - break; - case raft::distance::DistanceType::CosineExpanded: - naiveCosineDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::HellingerExpanded: - naiveHellingerDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::LpUnexpanded: - naiveLpUnexpDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::HammingUnexpanded: - naiveHammingDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::InnerProduct: - naiveInnerProductKernel<<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::JensenShannon: - naiveJensenShannonDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::RusselRaoExpanded: - naiveRussellRaoDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::KLDivergence: - naiveKLDivergenceDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::CorrelationExpanded: - naiveCorrelationDistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); - break; - case raft::distance::DistanceType::DiceExpanded: - naiveDiceDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); - break; - default: FAIL() << "should be here\n"; - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -struct DistanceInputs { - DataType tolerance; - int m, n, k; - bool isRowMajor; - unsigned long long int seed; - DataType metric_arg = 2.0f; -}; - -template -::std::ostream& operator<<(::std::ostream& os, const DistanceInputs& dims) -{ - return os; -} - -// TODO: Remove when mdspan-based raft::runtime::distance::pairwise_distance is -// implemented. -// -// Context: -// https://github.com/rapidsai/raft/issues/1338 -template -constexpr bool layout_to_row_major(); - -template <> -constexpr bool layout_to_row_major() -{ - return true; -} -template <> -constexpr bool layout_to_row_major() -{ - return false; -} - -template -void distanceLauncher(raft::resources const& handle, - DataType* x, - DataType* y, - DataType* dist, - DataType* dist2, - int m, - int n, - int k, - DistanceInputs& params, - DataType threshold, - DataType metric_arg = 2.0f) -{ - auto x_v = make_device_matrix_view(x, m, k); - auto y_v = make_device_matrix_view(y, n, k); - auto dist_v = make_device_matrix_view(dist, m, n); - - raft::distance::distance( - handle, x_v, y_v, dist_v, metric_arg); -} - -template -class DistanceTest : public ::testing::TestWithParam> { - public: - DistanceTest() - : params(::testing::TestWithParam>::GetParam()), - stream(resource::get_cuda_stream(handle)), - x(params.m * params.k, stream), - y(params.n * params.k, stream), - dist_ref(params.m * params.n, stream), - dist(params.m * params.n, stream), - dist2(params.m * params.n, stream) - { - } - - void SetUp() override - { - auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); - common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); - - raft::random::RngState r(params.seed); - int m = params.m; - int n = params.n; - int k = params.k; - DataType metric_arg = params.metric_arg; - bool isRowMajor = params.isRowMajor; - if (distanceType == raft::distance::DistanceType::HellingerExpanded || - distanceType == raft::distance::DistanceType::JensenShannon || - distanceType == raft::distance::DistanceType::KLDivergence) { - // Hellinger works only on positive numbers - uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); - uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); - } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded || - distanceType == raft::distance::DistanceType::DiceExpanded) { - uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); - uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); - // Russel rao works on boolean values. - bernoulli(handle, r, x.data(), m * k, 0.5f); - bernoulli(handle, r, y.data(), n * k, 0.5f); - } else { - uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); - uniform(handle, r, y.data(), n * k, DataType(-1.0), DataType(1.0)); - } - naiveDistance( - dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); - - DataType threshold = -10000.f; - - if (isRowMajor) { - distanceLauncher(handle, - x.data(), - y.data(), - dist.data(), - dist2.data(), - m, - n, - k, - params, - threshold, - metric_arg); - - } else { - distanceLauncher(handle, - x.data(), - y.data(), - dist.data(), - dist2.data(), - m, - n, - k, - params, - threshold, - metric_arg); - } - resource::sync_stream(handle, stream); - } - - protected: - raft::resources handle; - cudaStream_t stream; - - DistanceInputs params; - rmm::device_uvector x, y, dist_ref, dist, dist2; -}; - -/* - * This test suite verifies the path when X and Y are same buffer, - * distance metrics which requires norms like L2 expanded/cosine/correlation - * takes a more optimal path in such case to skip norm calculation for Y buffer. - * It may happen that though both X and Y are same buffer but user passes - * different dimensions for them like in case of tiled_brute_force_knn. - */ -template -class DistanceTestSameBuffer : public ::testing::TestWithParam> { - public: - using dev_vector = rmm::device_uvector; - DistanceTestSameBuffer() - : params(::testing::TestWithParam>::GetParam()), - stream(resource::get_cuda_stream(handle)), - x(params.m * params.k, stream), - dist_ref({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), - dist({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), - dist2({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}) - { - } - - void SetUp() override - { - auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); - common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); - - raft::random::RngState r(params.seed); - int m = params.m; - int n = params.m; - int k = params.k; - DataType metric_arg = params.metric_arg; - bool isRowMajor = params.isRowMajor; - if (distanceType == raft::distance::DistanceType::HellingerExpanded || - distanceType == raft::distance::DistanceType::JensenShannon || - distanceType == raft::distance::DistanceType::KLDivergence) { - // Hellinger works only on positive numbers - uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); - } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded || - distanceType == raft::distance::DistanceType::DiceExpanded) { - uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); - // Russel rao works on boolean values. - bernoulli(handle, r, x.data(), m * k, 0.5f); - } else { - uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); - } - - for (int i = 0; i < 2; i++) { - // both X and Y are same buffer but when i = 1 - // different dimensions for x & y is passed. - m = m / (i + 1); - naiveDistance(dist_ref[i].data(), - x.data(), - x.data(), - m, - n, - k, - distanceType, - isRowMajor, - metric_arg, - stream); - - DataType threshold = -10000.f; - - if (isRowMajor) { - distanceLauncher(handle, - x.data(), - x.data(), - dist[i].data(), - dist2[i].data(), - m, - n, - k, - params, - threshold, - metric_arg); - - } else { - distanceLauncher(handle, - x.data(), - x.data(), - dist[i].data(), - dist2[i].data(), - m, - n, - k, - params, - threshold, - metric_arg); - } - } - resource::sync_stream(handle, stream); - } - - protected: - raft::resources handle; - cudaStream_t stream; - - DistanceInputs params; - dev_vector x; - static const int N = 2; - std::array dist_ref, dist, dist2; -}; - -template -class BigMatrixDistanceTest : public ::testing::Test { - public: - BigMatrixDistanceTest() - : x(m * k, resource::get_cuda_stream(handle)), - dist(std::size_t(m) * m, resource::get_cuda_stream(handle)){}; - void SetUp() override - { - auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); - common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); - - void pairwise_distance(raft::resources const& handle, - float* x, - float* y, - float* dists, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); - constexpr bool row_major = true; - constexpr float metric_arg = 0.0f; - raft::distance::distance( - handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); - RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle))); - } - - protected: - raft::resources handle; - int m = 48000; - int n = 48000; - int k = 1; - rmm::device_uvector x, dist; -}; -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu deleted file mode 100644 index d4d632e1dc..0000000000 --- a/cpp/test/distance/fused_cosine_nn.cu +++ /dev/null @@ -1,420 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation - -#include "../test_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace distance { - -template -struct RaftKVPMinReduce { - typedef raft::KeyValuePair KVP; - - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - -}; // KVPMinReduce - -template -__global__ void naiveCosKernel(raft::KeyValuePair* min, - DataT* x, - DataT* y, - int m, - int n, - int k, - int* workspace, - DataT maxVal) -{ - int midx = threadIdx.y + blockIdx.y * blockDim.y; - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - DataT acc_a = DataT(0); - DataT acc_b = DataT(0); - DataT acc_ab = DataT(0); - // if (midx >= m || nidx >= n) { return; } - - for (int i = 0; i < k; ++i) { - int xidx = i + midx * k; - int yidx = i + nidx * k; - auto a = x[xidx]; - auto b = y[yidx]; - acc_a += a * a; - acc_b += b * b; - acc_ab += a * b; - } - - // Use 1.0 - (cosine similarity) to calc the distance - DataT acc = maxVal; - if (midx < m || nidx < n) { acc = (DataT)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } - - ReduceOpT redOp; - typedef cub::WarpReduce> WarpReduce; - __shared__ typename WarpReduce::TempStorage temp[NWARPS]; - int warpId = threadIdx.x / raft::WarpSize; - raft::KeyValuePair tmp; - tmp.key = nidx; - tmp.value = midx >= m || nidx >= n ? maxVal : acc; - tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); - if (threadIdx.x % raft::WarpSize == 0 && midx < m) { - while (atomicCAS(workspace + midx, 0, 1) == 1) - ; - __threadfence(); - redOp(midx, min + midx, tmp); - __threadfence(); - atomicCAS(workspace + midx, 1, 0); - } -} - -template -void naive(raft::KeyValuePair* min, - DataT* x, - DataT* y, - int m, - int n, - int k, - int* workspace, - cudaStream_t stream) -{ - static const dim3 TPB(32, 16, 1); - dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1); - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - auto blks = raft::ceildiv(m, 256); - MinAndDistanceReduceOp op; - detail::initKernel, int> - <<>>(min, m, std::numeric_limits::max(), op); - RAFT_CUDA_TRY(cudaGetLastError()); - naiveCosKernel, 16> - <<>>(min, x, y, m, n, k, workspace, std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -struct Inputs { - DataT tolerance; - int m, n, k; - unsigned long long int seed; - - friend std::ostream& operator<<(std::ostream& os, const Inputs& p) - { - return os << "m: " << p.m - << ", " - "n: " - << p.n - << ", " - "k: " - << p.k - << ", " - "seed: " - << p.seed - << ", " - "tol: " - << p.tolerance; - } -}; - -template -class FusedCosineNNTest : public ::testing::TestWithParam> { - public: - FusedCosineNNTest() - : params(::testing::TestWithParam>::GetParam()), - stream(resource::get_cuda_stream(handle)), - x(params.m * params.k, stream), - y(params.n * params.k, stream), - xn(params.m, stream), - yn(params.n, stream), - min(params.m, stream), - min_ref(params.m, stream), - workspace(params.m * sizeof(int), stream) - { - } - - protected: - void SetUp() override - { - raft::random::RngState r(params.seed); - int m = params.m; - int n = params.n; - int k = params.k; - uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); - uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); - generateGoldenResult(); - raft::linalg::rowNorm( - xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); - raft::linalg::rowNorm( - yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - protected: - raft::resources handle; - cudaStream_t stream; - Inputs params; - rmm::device_uvector x; - rmm::device_uvector y; - rmm::device_uvector xn; - rmm::device_uvector yn; - rmm::device_uvector> min; - rmm::device_uvector> min_ref; - rmm::device_uvector workspace; - - virtual void generateGoldenResult() - { - int m = params.m; - int n = params.n; - int k = params.k; - naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream); - } - - void runTest(raft::KeyValuePair* out) - { - int m = params.m; - int n = params.n; - int k = params.k; - raft::distance::DistanceType metric = raft::distance::DistanceType::CosineExpanded; - constexpr bool init_out_buffer = true; - fusedDistanceNNMinReduce, int>(out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - false, - init_out_buffer, - true, - metric, - 0.0f, - stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - } -}; - -template -struct CompareApproxAbsKVP { - typedef typename raft::KeyValuePair KVP; - CompareApproxAbsKVP(T eps_) : eps(eps_) {} - bool operator()(const KVP& a, const KVP& b) const - { - T diff = std::abs(std::abs(a.value) - std::abs(b.value)); - T m = std::max(std::abs(a.value), std::abs(b.value)); - T ratio = m >= eps ? diff / m : diff; - return (ratio <= eps); - } - - private: - T eps; -}; - -template -struct CompareExactKVP { - typedef typename raft::KeyValuePair KVP; - bool operator()(const KVP& a, const KVP& b) const - { - if (a.value != b.value) return false; - return true; - } -}; - -template -::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, - const raft::KeyValuePair* actual, - size_t size, - L eq_compare, - cudaStream_t stream = 0) -{ - typedef typename raft::KeyValuePair KVP; - std::shared_ptr exp_h(new KVP[size]); - std::shared_ptr act_h(new KVP[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto exp = exp_h.get()[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - return ::testing::AssertionFailure() - << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," - << exp.value << " @" << i; - } - } - return ::testing::AssertionSuccess(); -} - -const std::vector> inputsf = { - {0.001f, 32, 32, 32, 1234ULL}, - {0.001f, 32, 64, 32, 1234ULL}, - {0.001f, 64, 32, 32, 1234ULL}, - {0.001f, 64, 64, 32, 1234ULL}, - {0.001f, 128, 32, 32, 1234ULL}, - {0.001f, 128, 64, 32, 1234ULL}, - {0.001f, 128, 128, 64, 1234ULL}, - {0.001f, 64, 128, 128, 1234ULL}, - - {0.001f, 32, 32, 34, 1234ULL}, - {0.001f, 32, 64, 34, 1234ULL}, - {0.001f, 64, 32, 34, 1234ULL}, - {0.001f, 64, 64, 34, 1234ULL}, - {0.001f, 128, 32, 34, 1234ULL}, - {0.001f, 128, 64, 34, 1234ULL}, - {0.001f, 128, 128, 66, 1234ULL}, - {0.001f, 64, 128, 130, 1234ULL}, - - {0.001f, 32, 32, 33, 1234ULL}, - {0.001f, 32, 64, 33, 1234ULL}, - {0.001f, 64, 32, 33, 1234ULL}, - {0.001f, 64, 64, 33, 1234ULL}, - {0.001f, 128, 32, 33, 1234ULL}, - {0.001f, 128, 64, 33, 1234ULL}, - {0.001f, 128, 128, 65, 1234ULL}, - {0.001f, 64, 128, 129, 1234ULL}, - {0.006f, 1805, 134, 2, 1234ULL}, - {0.006f, 8192, 1024, 64, 1234ULL}, - {0.006f, 8192, 1025, 64, 1234ULL}, - - // Repeat with smaller values of k - {0.006f, 32, 32, 1, 1234ULL}, - {0.001f, 32, 64, 2, 1234ULL}, - {0.001f, 64, 32, 3, 1234ULL}, - {0.001f, 64, 64, 4, 1234ULL}, - {0.001f, 128, 32, 5, 1234ULL}, - {0.001f, 128, 64, 6, 1234ULL}, - {0.001f, 128, 128, 7, 1234ULL}, - {0.001f, 64, 128, 8, 1234ULL}, - - {0.001f, 32, 32, 9, 1234ULL}, - {0.001f, 32, 64, 10, 1234ULL}, - {0.001f, 64, 32, 11, 1234ULL}, - {0.001f, 64, 64, 12, 1234ULL}, - {0.001f, 128, 32, 13, 1234ULL}, - {0.001f, 128, 64, 14, 1234ULL}, - {0.001f, 128, 128, 15, 1234ULL}, - {0.001f, 64, 128, 16, 1234ULL}, - - {0.001f, 32, 32, 17, 1234ULL}, - {0.001f, 32, 64, 18, 1234ULL}, - {0.001f, 64, 32, 19, 1234ULL}, - {0.001f, 64, 64, 20, 1234ULL}, - {0.001f, 128, 32, 21, 1234ULL}, - {0.001f, 128, 64, 22, 1234ULL}, - {0.001f, 128, 128, 23, 1234ULL}, - {0.00001, 64, 128, 24, 1234ULL}, - {0.001f, 1805, 134, 25, 1234ULL}, - {0.006f, 8192, 1024, 25, 1234ULL}, - {0.006f, 8192, 1024, 66, 1234ULL}, -}; -typedef FusedCosineNNTest FusedCosineNNTestF; -TEST_P(FusedCosineNNTestF, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestF, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.00001, 32, 32, 32, 1234ULL}, {0.00001, 32, 64, 32, 1234ULL}, - {0.00001, 64, 32, 32, 1234ULL}, {0.00001, 64, 64, 32, 1234ULL}, - {0.00001, 128, 32, 32, 1234ULL}, {0.00001, 128, 64, 32, 1234ULL}, - {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL}, - - {0.00001, 32, 32, 34, 1234ULL}, {0.00001, 32, 64, 34, 1234ULL}, - {0.00001, 64, 32, 34, 1234ULL}, {0.00001, 64, 64, 34, 1234ULL}, - {0.00001, 128, 32, 34, 1234ULL}, {0.00001, 128, 64, 34, 1234ULL}, - {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL}, - - {0.00001, 32, 32, 33, 1234ULL}, {0.00001, 32, 64, 33, 1234ULL}, - {0.00001, 64, 32, 33, 1234ULL}, {0.00001, 64, 64, 33, 1234ULL}, - {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, - {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, - - {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL}, -}; -typedef FusedCosineNNTest FusedCosineNNTestD; -TEST_P(FusedCosineNNTestD, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestD, ::testing::ValuesIn(inputsd)); - -/// This is to test output determinism of the prim -template -class FusedCosineNNDetTest : public FusedCosineNNTest { - public: - FusedCosineNNDetTest() : stream(resource::get_cuda_stream(handle)), min1(0, stream) {} - - void SetUp() override - { - FusedCosineNNTest::SetUp(); - int m = this->params.m; - min1.resize(m, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - void TearDown() override { FusedCosineNNTest::TearDown(); } - - protected: - raft::resources handle; - cudaStream_t stream; - - rmm::device_uvector> min1; - - static const int NumRepeats = 3; - - void generateGoldenResult() override {} -}; - -typedef FusedCosineNNDetTest FusedCosineNNDetTestF; -TEST_P(FusedCosineNNDetTestF, Result) -{ - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - cudaMemsetAsync(min1.data(), 0, sizeof(*min.data()) * params.m, stream); - } -} -INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestF, ::testing::ValuesIn(inputsf)); - -typedef FusedCosineNNDetTest FusedCosineNNDetTestD; -TEST_P(FusedCosineNNDetTestD, Result) -{ - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } -} -INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestD, ::testing::ValuesIn(inputsd)); - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu deleted file mode 100644 index 6fd8f15808..0000000000 --- a/cpp/test/distance/fused_l2_nn.cu +++ /dev/null @@ -1,437 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace distance { - -template -struct RaftKVPMinReduce { - typedef raft::KeyValuePair KVP; - - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - -}; // KVPMinReduce - -template -RAFT_KERNEL naiveKernel(raft::KeyValuePair* min, - DataT* x, - DataT* y, - int m, - int n, - int k, - int* workspace, - DataT maxVal) -{ - int midx = threadIdx.y + blockIdx.y * blockDim.y; - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - DataT acc = DataT(0); - for (int i = 0; i < k; ++i) { - int xidx = i + midx * k; - int yidx = i + nidx * k; - auto diff = midx >= m || nidx >= n ? DataT(0) : x[xidx] - y[yidx]; - acc += diff * diff; - } - - if (Sqrt) { acc = raft::sqrt(acc); } - ReduceOpT redOp; - typedef cub::WarpReduce> WarpReduce; - __shared__ typename WarpReduce::TempStorage temp[NWARPS]; - int warpId = threadIdx.x / raft::WarpSize; - raft::KeyValuePair tmp; - tmp.key = nidx; - tmp.value = midx >= m || nidx >= n ? maxVal : acc; - tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); - if (threadIdx.x % raft::WarpSize == 0 && midx < m) { - while (atomicCAS(workspace + midx, 0, 1) == 1) - ; - __threadfence(); - redOp(midx, min + midx, tmp); - __threadfence(); - atomicCAS(workspace + midx, 1, 0); - } -} - -template -void naive(raft::KeyValuePair* min, - DataT* x, - DataT* y, - int m, - int n, - int k, - int* workspace, - cudaStream_t stream) -{ - static const dim3 TPB(32, 16, 1); - dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1); - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - auto blks = raft::ceildiv(m, 256); - MinAndDistanceReduceOp op; - detail::initKernel, int> - <<>>(min, m, std::numeric_limits::max(), op); - RAFT_CUDA_TRY(cudaGetLastError()); - naiveKernel, 16> - <<>>(min, x, y, m, n, k, workspace, std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -struct Inputs { - DataT tolerance; - int m, n, k; - unsigned long long int seed; - - friend std::ostream& operator<<(std::ostream& os, const Inputs& p) - { - return os << "m: " << p.m - << ", " - "n: " - << p.n - << ", " - "k: " - << p.k - << ", " - "seed: " - << p.seed - << ", " - "tol: " - << p.tolerance; - } -}; - -template -class FusedL2NNTest : public ::testing::TestWithParam> { - public: - FusedL2NNTest() - : params(::testing::TestWithParam>::GetParam()), - stream(resource::get_cuda_stream(handle)), - x(params.m * params.k, stream), - y(params.n * params.k, stream), - xn(params.m, stream), - yn(params.n, stream), - min(params.m, stream), - min_ref(params.m, stream), - workspace(params.m * sizeof(int), stream) - { - } - - protected: - void SetUp() override - { - raft::random::RngState r(params.seed); - int m = params.m; - int n = params.n; - int k = params.k; - uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); - uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); - generateGoldenResult(); - raft::linalg::rowNorm(xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm(yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - protected: - raft::resources handle; - cudaStream_t stream; - Inputs params; - rmm::device_uvector x; - rmm::device_uvector y; - rmm::device_uvector xn; - rmm::device_uvector yn; - rmm::device_uvector> min; - rmm::device_uvector> min_ref; - rmm::device_uvector workspace; - - virtual void generateGoldenResult() - { - int m = params.m; - int n = params.n; - int k = params.k; - naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream); - } - - void runTest(raft::KeyValuePair* out) - { - int m = params.m; - int n = params.n; - int k = params.k; - - const bool init_out_buffer = true; - fusedL2NNMinReduce, int>(out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - Sqrt, - init_out_buffer, - stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - } -}; - -template -struct CompareApproxAbsKVP { - typedef typename raft::KeyValuePair KVP; - CompareApproxAbsKVP(T eps_) : eps(eps_) {} - bool operator()(const KVP& a, const KVP& b) const - { - T diff = std::abs(std::abs(a.value) - std::abs(b.value)); - T m = std::max(std::abs(a.value), std::abs(b.value)); - T ratio = m >= eps ? diff / m : diff; - return (ratio <= eps); - } - - private: - T eps; -}; - -template -struct CompareExactKVP { - typedef typename raft::KeyValuePair KVP; - bool operator()(const KVP& a, const KVP& b) const - { - if (a.value != b.value) return false; - return true; - } -}; - -template -::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, - const raft::KeyValuePair* actual, - size_t size, - L eq_compare, - cudaStream_t stream = 0) -{ - typedef typename raft::KeyValuePair KVP; - std::shared_ptr exp_h(new KVP[size]); - std::shared_ptr act_h(new KVP[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto exp = exp_h.get()[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - return ::testing::AssertionFailure() - << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," - << exp.value << " @" << i; - } - } - return ::testing::AssertionSuccess(); -} - -const std::vector> inputsf = { - {0.001f, 32, 32, 32, 1234ULL}, - {0.001f, 32, 64, 32, 1234ULL}, - {0.001f, 64, 32, 32, 1234ULL}, - {0.001f, 64, 64, 32, 1234ULL}, - {0.001f, 128, 32, 32, 1234ULL}, - {0.001f, 128, 64, 32, 1234ULL}, - {0.001f, 128, 128, 64, 1234ULL}, - {0.001f, 64, 128, 128, 1234ULL}, - - {0.001f, 32, 32, 34, 1234ULL}, - {0.001f, 32, 64, 34, 1234ULL}, - {0.001f, 64, 32, 34, 1234ULL}, - {0.001f, 64, 64, 34, 1234ULL}, - {0.001f, 128, 32, 34, 1234ULL}, - {0.001f, 128, 64, 34, 1234ULL}, - {0.001f, 128, 128, 66, 1234ULL}, - {0.001f, 64, 128, 130, 1234ULL}, - - {0.001f, 32, 32, 33, 1234ULL}, - {0.001f, 32, 64, 33, 1234ULL}, - {0.001f, 64, 32, 33, 1234ULL}, - {0.001f, 64, 64, 33, 1234ULL}, - {0.001f, 128, 32, 33, 1234ULL}, - {0.001f, 128, 64, 33, 1234ULL}, - {0.001f, 128, 128, 65, 1234ULL}, - {0.001f, 64, 128, 129, 1234ULL}, - {0.006f, 1805, 134, 2, 1234ULL}, - {0.006f, 8192, 1024, 64, 1234ULL}, - {0.006f, 8192, 1025, 64, 1234ULL}, - - // Repeat with smaller values of k - {0.006f, 32, 32, 1, 1234ULL}, - {0.001f, 32, 64, 2, 1234ULL}, - {0.001f, 64, 32, 3, 1234ULL}, - {0.001f, 64, 64, 4, 1234ULL}, - {0.001f, 128, 32, 5, 1234ULL}, - {0.001f, 128, 64, 6, 1234ULL}, - {0.001f, 128, 128, 7, 1234ULL}, - {0.001f, 64, 128, 8, 1234ULL}, - - {0.001f, 32, 32, 9, 1234ULL}, - {0.001f, 32, 64, 10, 1234ULL}, - {0.001f, 64, 32, 11, 1234ULL}, - {0.001f, 64, 64, 12, 1234ULL}, - {0.001f, 128, 32, 13, 1234ULL}, - {0.001f, 128, 64, 14, 1234ULL}, - {0.001f, 128, 128, 15, 1234ULL}, - {0.001f, 64, 128, 16, 1234ULL}, - - {0.001f, 32, 32, 17, 1234ULL}, - {0.001f, 32, 64, 18, 1234ULL}, - {0.001f, 64, 32, 19, 1234ULL}, - {0.001f, 64, 64, 20, 1234ULL}, - {0.001f, 128, 32, 21, 1234ULL}, - {0.001f, 128, 64, 22, 1234ULL}, - {0.001f, 128, 128, 23, 1234ULL}, - {0.00001, 64, 128, 24, 1234ULL}, - {0.001f, 1805, 134, 25, 1234ULL}, - {0.006f, 8192, 1024, 25, 1234ULL}, - {0.006f, 8192, 1024, 66, 1234ULL}, -}; -typedef FusedL2NNTest FusedL2NNTestF_Sq; -TEST_P(FusedL2NNTestF_Sq, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestF_Sq, ::testing::ValuesIn(inputsf)); -typedef FusedL2NNTest FusedL2NNTestF_Sqrt; -TEST_P(FusedL2NNTestF_Sqrt, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestF_Sqrt, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.00001, 32, 32, 32, 1234ULL}, {0.00001, 32, 64, 32, 1234ULL}, - {0.00001, 64, 32, 32, 1234ULL}, {0.00001, 64, 64, 32, 1234ULL}, - {0.00001, 128, 32, 32, 1234ULL}, {0.00001, 128, 64, 32, 1234ULL}, - {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL}, - - {0.00001, 32, 32, 34, 1234ULL}, {0.00001, 32, 64, 34, 1234ULL}, - {0.00001, 64, 32, 34, 1234ULL}, {0.00001, 64, 64, 34, 1234ULL}, - {0.00001, 128, 32, 34, 1234ULL}, {0.00001, 128, 64, 34, 1234ULL}, - {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL}, - - {0.00001, 32, 32, 33, 1234ULL}, {0.00001, 32, 64, 33, 1234ULL}, - {0.00001, 64, 32, 33, 1234ULL}, {0.00001, 64, 64, 33, 1234ULL}, - {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, - {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, - - {0.00001, 1805, 134, 2, 1234ULL}, //{0.00001, 8192, 1024, 25, 1234ULL}, -}; -typedef FusedL2NNTest FusedL2NNTestD_Sq; -TEST_P(FusedL2NNTestD_Sq, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestD_Sq, ::testing::ValuesIn(inputsd)); -typedef FusedL2NNTest FusedL2NNTestD_Sqrt; -TEST_P(FusedL2NNTestD_Sqrt, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestD_Sqrt, ::testing::ValuesIn(inputsd)); - -/// This is to test output determinism of the prim -template -class FusedL2NNDetTest : public FusedL2NNTest { - public: - FusedL2NNDetTest() : stream(resource::get_cuda_stream(handle)), min1(0, stream) {} - - void SetUp() override - { - FusedL2NNTest::SetUp(); - int m = this->params.m; - min1.resize(m, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - void TearDown() override { FusedL2NNTest::TearDown(); } - - protected: - raft::resources handle; - cudaStream_t stream; - - rmm::device_uvector> min1; - - static const int NumRepeats = 3; - - void generateGoldenResult() override {} -}; - -typedef FusedL2NNDetTest FusedL2NNDetTestF_Sq; -TEST_P(FusedL2NNDetTestF_Sq, Result) -{ - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } -} -INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestF_Sq, ::testing::ValuesIn(inputsf)); -typedef FusedL2NNDetTest FusedL2NNDetTestF_Sqrt; -TEST_P(FusedL2NNDetTestF_Sqrt, Result) -{ - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } -} -INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestF_Sqrt, ::testing::ValuesIn(inputsf)); - -typedef FusedL2NNDetTest FusedL2NNDetTestD_Sq; -TEST_P(FusedL2NNDetTestD_Sq, Result) -{ - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } -} -INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestD_Sq, ::testing::ValuesIn(inputsd)); -typedef FusedL2NNDetTest FusedL2NNDetTestD_Sqrt; -TEST_P(FusedL2NNDetTestD_Sqrt, Result) -{ - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } -} -INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); - -} // end namespace distance -} // end namespace raft diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu deleted file mode 100644 index e911a25ff1..0000000000 --- a/cpp/test/distance/gram.cu +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "gram_base.cuh" - -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -namespace raft::distance::kernels { - -struct GramMatrixInputs { - int n1; // feature vectors in matrix 1 - int n2; // featuer vectors in matrix 2 - int n_cols; // number of elements in a feature vector - bool is_row_major; - KernelParams kernel; - int ld1; - int ld2; - int ld_out; - // We will generate random input using the dimensions given here. - // The reference output is calculated by a custom kernel. -}; - -std::ostream& operator<<(std::ostream& os, const GramMatrixInputs& p) -{ - std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; - os << "/" << p.n1 << "x" << p.n2 << "x" << p.n_cols << "/" - << (p.is_row_major ? "RowMajor/" : "ColMajor/") << kernel_names[p.kernel.kernel] << "/ld_" - << p.ld1 << "x" << p.ld2 << "x" << p.ld_out; - return os; -} - -const std::vector inputs = { - {42, 137, 2, false, {KernelType::LINEAR}}, - {42, 137, 2, true, {KernelType::LINEAR}}, - {42, 137, 2, false, {KernelType::LINEAR}, 64, 179, 181}, - {42, 137, 2, true, {KernelType::LINEAR}, 64, 179, 181}, - {137, 42, 2, false, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}}, - {137, 42, 2, true, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}}, - {137, 42, 2, false, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}, 159, 73, 144}, - {137, 42, 2, true, {KernelType::POLYNOMIAL, 2, 0.5, 2.4}, 159, 73, 144}, - {42, 137, 2, false, {KernelType::TANH, 0, 0.5, 2.4}}, - {42, 137, 2, true, {KernelType::TANH, 0, 0.5, 2.4}}, - {42, 137, 2, false, {KernelType::TANH, 0, 0.5, 2.4}, 64, 155, 49}, - {42, 137, 2, true, {KernelType::TANH, 0, 0.5, 2.4}, 64, 155, 143}, - {3, 4, 2, false, {KernelType::RBF, 0, 0.5}}, - {42, 137, 2, false, {KernelType::RBF, 0, 0.5}}, - {42, 137, 2, true, {KernelType::RBF, 0, 0.5}}, - // Distance kernel does not support LD parameter yet. - //{42, 137, 2, false, {KernelType::RBF, 0, 0.5}, 64, 155, 49}, - // {42, 137, 2, true, {KernelType::RBF, 0, 0.5}, 64, 155, 143}, -}; - -template -class GramMatrixTest : public ::testing::TestWithParam { - protected: - GramMatrixTest() - : params(GetParam()), - handle(), - x1(0, resource::get_cuda_stream(handle)), - x2(0, resource::get_cuda_stream(handle)), - gram(0, resource::get_cuda_stream(handle)), - gram_host(0) - { - auto stream = resource::get_cuda_stream(handle); - - if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; } - if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; } - if (params.ld_out == 0) { params.ld_out = params.is_row_major ? params.n2 : params.n1; } - // Derive the size of the output from the offset of the last element. - size_t size = get_offset(params.n1 - 1, params.n_cols - 1, params.ld1, params.is_row_major) + 1; - x1.resize(size, stream); - size = get_offset(params.n2 - 1, params.n_cols - 1, params.ld2, params.is_row_major) + 1; - x2.resize(size, stream); - size = get_offset(params.n1 - 1, params.n2 - 1, params.ld_out, params.is_row_major) + 1; - - gram.resize(size, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(gram.data(), 0, gram.size() * sizeof(math_t), stream)); - gram_host.resize(gram.size()); - std::fill(gram_host.begin(), gram_host.end(), 0); - - raft::random::RngState rng(42137ULL); - raft::random::uniform(handle, rng, x1.data(), x1.size(), math_t(0), math_t(1)); - raft::random::uniform(handle, rng, x2.data(), x2.size(), math_t(0), math_t(1)); - } - - ~GramMatrixTest() override {} - - void runTest() - { - std::unique_ptr> kernel = - std::unique_ptr>(KernelFactory::create(params.kernel)); - - auto x1_span = - params.is_row_major - ? raft::make_device_strided_matrix_view( - x1.data(), params.n1, params.n_cols, params.ld1) - : raft::make_device_strided_matrix_view( - x1.data(), params.n1, params.n_cols, params.ld1); - auto x2_span = - params.is_row_major - ? raft::make_device_strided_matrix_view( - x2.data(), params.n2, params.n_cols, params.ld2) - : raft::make_device_strided_matrix_view( - x2.data(), params.n2, params.n_cols, params.ld2); - auto out_span = - params.is_row_major - ? raft::make_device_strided_matrix_view( - gram.data(), params.n1, params.n2, params.ld_out) - : raft::make_device_strided_matrix_view( - gram.data(), params.n1, params.n2, params.ld_out); - - (*kernel)(handle, x1_span, x2_span, out_span); - - auto stream = resource::get_cuda_stream(handle); - naiveGramMatrixKernel(params.n1, - params.n2, - params.n_cols, - x1, - x2, - gram_host.data(), - params.ld1, - params.ld2, - params.ld_out, - params.is_row_major, - params.kernel, - stream, - handle); - - ASSERT_TRUE(raft::devArrMatchHost( - gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f), stream)); - } - - GramMatrixInputs params; - raft::resources handle; - - rmm::device_uvector x1; - rmm::device_uvector x2; - rmm::device_uvector gram; - - std::vector gram_host; -}; - -typedef GramMatrixTest GramMatrixTestFloat; -typedef GramMatrixTest GramMatrixTestDouble; - -TEST_P(GramMatrixTestFloat, Gram) { runTest(); } - -INSTANTIATE_TEST_SUITE_P(GramMatrixTests, GramMatrixTestFloat, ::testing::ValuesIn(inputs)); -}; // end namespace raft::distance::kernels diff --git a/cpp/test/distance/gram_base.cuh b/cpp/test/distance/gram_base.cuh deleted file mode 100644 index 170bcb76c1..0000000000 --- a/cpp/test/distance/gram_base.cuh +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include - -#include - -#include -#include - -namespace raft { -namespace distance { -namespace kernels { - -// Get the offset of element [i,k]. -HDI int get_offset(int i, int k, int ld, bool is_row_major) -{ - return is_row_major ? i * ld + k : i + k * ld; -} - -// Calculate the Gram matrix on the host. -template -void naiveGramMatrixKernel(int n1, - int n2, - int n_cols, - const rmm::device_uvector& x1, - const rmm::device_uvector& x2, - math_t* gram_host, - int ld1, - int ld2, - int ld_out, - bool is_row_major, - KernelParams kernel, - cudaStream_t stream, - const raft::resources& handle) -{ - std::vector x1_host(x1.size()); - raft::update_host(x1_host.data(), x1.data(), x1.size(), stream); - std::vector x2_host(x2.size()); - raft::update_host(x2_host.data(), x2.data(), x2.size(), stream); - resource::sync_stream(handle, stream); - - for (int i = 0; i < n1; i++) { - for (int j = 0; j < n2; j++) { - float d = 0; - for (int k = 0; k < n_cols; k++) { - if (kernel.kernel == KernelType::RBF) { - math_t diff = x1_host[get_offset(i, k, ld1, is_row_major)] - - x2_host[get_offset(j, k, ld2, is_row_major)]; - d += diff * diff; - } else { - d += x1_host[get_offset(i, k, ld1, is_row_major)] * - x2_host[get_offset(j, k, ld2, is_row_major)]; - } - } - int idx = get_offset(i, j, ld_out, is_row_major); - math_t v = 0; - switch (kernel.kernel) { - case (KernelType::LINEAR): gram_host[idx] = d; break; - case (KernelType::POLYNOMIAL): - v = kernel.gamma * d + kernel.coef0; - gram_host[idx] = std::pow(v, kernel.degree); - break; - case (KernelType::TANH): gram_host[idx] = std::tanh(kernel.gamma * d + kernel.coef0); break; - case (KernelType::RBF): gram_host[idx] = exp(-kernel.gamma * d); break; - } - } - } -} - -} // namespace kernels -} // namespace distance -} // namespace raft diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu deleted file mode 100644 index 34fa07c45b..0000000000 --- a/cpp/test/distance/masked_nn.cu +++ /dev/null @@ -1,438 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace raft::distance::masked_nn { - -// The adjacency pattern determines what distances get computed. -enum AdjacencyPattern { - checkerboard = 0, // adjacency matrix looks like a checkerboard (half the distances are computed) - checkerboard_4 = 1, // checkerboard with tiles of size 4x4 - checkerboard_64 = 2, // checkerboard with tiles of size 64x64 - all_true = 3, // no distance computations can be skipped - all_false = 4 // all distance computations can be skipped -}; - -// Kernels: -// - init_adj: to initialize the adjacency kernel with a specific adjacency pattern -// - referenceKernel: to produce the ground-truth output - -RAFT_KERNEL init_adj(AdjacencyPattern pattern, - int n, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs) -{ - int m = adj.extent(0); - int num_groups = adj.extent(1); - - for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; - idx_m += blockDim.y * gridDim.y) { - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; - idx_g += blockDim.x * gridDim.x) { - switch (pattern) { - case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; - case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; - case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; - case all_true: adj(idx_m, idx_g) = true; break; - case all_false: adj(idx_m, idx_g) = false; break; - default: assert(false && "unknown pattern"); - } - } - } - // Each group is of size n / num_groups. - // - // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive - // scan of the group lengths) - // - // - The first group always starts at index zero, so we do not store it. - // - // - The group_idxs[num_groups - 1] should always equal n. - - if (blockIdx.y == 0 && threadIdx.y == 0) { - const int g_stride = blockDim.x * gridDim.x; - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { - group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); - } - group_idxs(num_groups - 1) = n; - } -} - -template -__launch_bounds__(32 * NWARPS, 2) RAFT_KERNEL referenceKernel(raft::KeyValuePair* min, - DataT* x, - DataT* y, - bool* adj, - int* group_idxs, - int m, - int n, - int k, - int num_groups, - bool sqrt, - int* workspace, - DataT maxVal) -{ - const int m_stride = blockDim.y * gridDim.y; - const int m_offset = threadIdx.y + blockIdx.y * blockDim.y; - const int n_stride = blockDim.x * gridDim.x; - const int n_offset = threadIdx.x + blockIdx.x * blockDim.x; - - for (int m_grid = 0; m_grid < m; m_grid += m_stride) { - for (int n_grid = 0; n_grid < n; n_grid += n_stride) { - int midx = m_grid + m_offset; - int nidx = n_grid + n_offset; - - // Do a reverse linear search to determine the group index. - int group_idx = 0; - for (int i = num_groups; 0 <= i; --i) { - if (nidx < group_idxs[i]) { group_idx = i; } - } - const bool include_dist = adj[midx * num_groups + group_idx] && midx < m && nidx < n; - - // Compute L2 metric. - DataT acc = DataT(0); - for (int i = 0; i < k; ++i) { - int xidx = i + midx * k; - int yidx = i + nidx * k; - auto diff = x[xidx] - y[yidx]; - acc += diff * diff; - } - if (sqrt) { acc = raft::sqrt(acc); } - ReduceOpT redOp; - typedef cub::WarpReduce> WarpReduce; - __shared__ typename WarpReduce::TempStorage temp[NWARPS]; - int warpId = threadIdx.x / raft::WarpSize; - raft::KeyValuePair tmp; - tmp.key = include_dist ? nidx : -1; - tmp.value = include_dist ? acc : maxVal; - tmp = WarpReduce(temp[warpId]).Reduce(tmp, raft::distance::KVPMinReduce{}); - if (threadIdx.x % raft::WarpSize == 0 && midx < m) { - while (atomicCAS(workspace + midx, 0, 1) == 1) - ; - __threadfence(); - redOp(midx, min + midx, tmp); - __threadfence(); - atomicCAS(workspace + midx, 1, 0); - } - __syncthreads(); - } - } -} - -// Structs -// - Params: holds parameters for test case -// - Inputs: holds the inputs to the functions under test (x, y, adj, group_idxs). Is generated from -// the inputs. -struct Params { - double tolerance; - int m, n, k, num_groups; - bool sqrt; - unsigned long long int seed; - AdjacencyPattern pattern; -}; - -inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& -{ - os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", num_groups: " << p.num_groups - << ", sqrt: " << p.sqrt << ", seed: " << p.seed << ", tol: " << p.tolerance; - return os; -} - -template -struct Inputs { - using IdxT = int; - - raft::device_matrix x, y; - raft::device_matrix adj; - raft::device_vector group_idxs; - - Inputs(const raft::handle_t& handle, const Params& p) - : x{raft::make_device_matrix(handle, p.m, p.k)}, - y{raft::make_device_matrix(handle, p.n, p.k)}, - adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, - group_idxs{raft::make_device_vector(handle, p.num_groups)} - { - // Initialize x, y - raft::random::RngState r(p.seed); - uniform(handle, r, x.data_handle(), p.m * p.k, DataT(-1.0), DataT(1.0)); - uniform(handle, r, y.data_handle(), p.n * p.k, DataT(-1.0), DataT(1.0)); - - // Initialize adj, group_idxs. - dim3 block(32, 32); - dim3 grid(10, 10); - init_adj<<>>( - p.pattern, p.n, adj.view(), group_idxs.view()); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -template > -auto reference(const raft::handle_t& handle, Inputs inp, const Params& p) - -> raft::device_vector -{ - int m = inp.x.extent(0); - int n = inp.y.extent(0); - int k = inp.x.extent(1); - int num_groups = inp.group_idxs.extent(0); - - if (m == 0 || n == 0 || k == 0 || num_groups == 0) { - return raft::make_device_vector(handle, 0); - } - - // Initialize workspace - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector workspace(p.m * sizeof(int), stream); - RAFT_CUDA_TRY(cudaMemsetAsync(workspace.data(), 0, sizeof(int) * m, stream)); - - // Initialize output - auto out = raft::make_device_vector(handle, m); - auto blks = raft::ceildiv(m, 256); - MinAndDistanceReduceOp op; - raft::distance::detail::initKernel, int> - <<>>(out.data_handle(), m, std::numeric_limits::max(), op); - RAFT_CUDA_TRY(cudaGetLastError()); - - // Launch reference kernel - const int nwarps = 16; - static const dim3 TPB(32, nwarps, 1); - dim3 nblks(1, 200, 1); - referenceKernel - <<>>(out.data_handle(), - inp.x.data_handle(), - inp.y.data_handle(), - inp.adj.data_handle(), - inp.group_idxs.data_handle(), - m, - n, - k, - num_groups, - p.sqrt, - (int*)workspace.data(), - std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaGetLastError()); - - return out; -} - -template > -auto run_masked_nn(const raft::handle_t& handle, Inputs inp, const Params& p) - -> raft::device_vector -{ - // Compute norms: - auto x_norm = raft::make_device_vector(handle, p.m); - auto y_norm = raft::make_device_vector(handle, p.n); - - raft::linalg::norm(handle, - std::as_const(inp.x).view(), - x_norm.view(), - raft::linalg::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - raft::linalg::norm(handle, - std::as_const(inp.y).view(), - y_norm.view(), - raft::linalg::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - - // Create parameters for masked_l2_nn - using IdxT = int; - using RedOpT = MinAndDistanceReduceOp; - using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = raft::distance::masked_l2_nn_params; - - bool init_out = true; - ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, p.sqrt, init_out}; - - // Create output - auto out = raft::make_device_vector(handle, p.m); - - // Launch kernel - raft::distance::masked_l2_nn(handle, - masked_l2_params, - inp.x.view(), - inp.y.view(), - x_norm.view(), - y_norm.view(), - inp.adj.view(), - inp.group_idxs.view(), - out.view()); - - resource::sync_stream(handle); - - return out; -} - -template -struct CompareApproxAbsKVP { - typedef typename raft::KeyValuePair KVP; - CompareApproxAbsKVP(T eps_) : eps(eps_) {} - bool operator()(const KVP& a, const KVP& b) const - { - T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); - T m = std::max(raft::abs(a.value), raft::abs(b.value)); - T ratio = m >= eps ? diff / m : diff; - return (ratio <= eps); - } - - private: - T eps; -}; - -template -::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, - const raft::KeyValuePair* actual, - size_t size, - L eq_compare, - cudaStream_t stream = 0) -{ - typedef typename raft::KeyValuePair KVP; - std::shared_ptr exp_h(new KVP[size]); - std::shared_ptr act_h(new KVP[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (size_t i(0); i < size; ++i) { - auto exp = exp_h.get()[i]; - auto act = act_h.get()[i]; - if (!eq_compare(exp, act)) { - return ::testing::AssertionFailure() - << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," - << exp.value << " @" << i; - } - } - return ::testing::AssertionSuccess(); -} - -inline auto gen_params() -> std::vector -{ - // Regular powers of two - auto regular = raft::util::itertools::product({0.001f}, // tolerance - {32, 64, 512}, // m - {32, 64, 512}, // n - {8, 32}, // k - {2, 32}, // num_groups - {true, false}, // sqrt - {1234ULL}, // seed - {AdjacencyPattern::all_true, - AdjacencyPattern::checkerboard, - AdjacencyPattern::checkerboard_64, - AdjacencyPattern::all_false}); - - // Irregular sizes to check tiling and bounds checking - auto irregular = raft::util::itertools::product({0.001f}, // tolerance - {511, 512, 513}, // m - {127, 128, 129}, // n - {5}, // k - {3, 9}, // num_groups - {true, false}, // sqrt - {1234ULL}, // seed - {AdjacencyPattern::all_true, - AdjacencyPattern::checkerboard, - AdjacencyPattern::checkerboard_64}); - - regular.insert(regular.end(), irregular.begin(), irregular.end()); - - return regular; -} - -class MaskedL2NNTest : public ::testing::TestWithParam { - // Empty. -}; - -// -TEST_P(MaskedL2NNTest, ReferenceCheckFloat) -{ - using DataT = float; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out_reference = reference(handle, inputs, p); - auto out_fast = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out_reference.data_handle(), - out_fast.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - resource::get_cuda_stream(handle))); -} - -// This test checks whether running the masked_l2_nn twice returns the same -// output. -TEST_P(MaskedL2NNTest, DeterminismCheck) -{ - using DataT = float; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out1 = run_masked_nn(handle, inputs, p); - auto out2 = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out1.data_handle(), - out2.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - resource::get_cuda_stream(handle))); -} - -TEST_P(MaskedL2NNTest, ReferenceCheckDouble) -{ - using DataT = double; - - // Get parameters; create handle and input data. - Params p = GetParam(); - raft::handle_t handle{}; - Inputs inputs{handle, p}; - - // Calculate reference and test output - auto out_reference = reference(handle, inputs, p); - auto out_fast = run_masked_nn(handle, inputs, p); - - // Check for differences. - ASSERT_TRUE(devArrMatch(out_reference.data_handle(), - out_fast.data_handle(), - p.m, - CompareApproxAbsKVP(p.tolerance), - resource::get_cuda_stream(handle))); -} - -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTest, ::testing::ValuesIn(gen_params())); - -} // end namespace raft::distance::masked_nn diff --git a/cpp/test/distance/masked_nn_compress_to_bits.cu b/cpp/test/distance/masked_nn_compress_to_bits.cu deleted file mode 100644 index 2512af5e4f..0000000000 --- a/cpp/test/distance/masked_nn_compress_to_bits.cu +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "../test_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include - -namespace raft::distance::masked_nn::compress_to_bits { - -/** - * @brief Transpose and decompress 2D bitfield to boolean matrix - * - * Inverse operation of compress_to_bits - * - * @tparam T - * - * @parameter[in] in An `m x n` bitfield matrix. Row major. - * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `(m * bits_per_elem) x n` boolean matrix. - */ -template ::value>> -RAFT_KERNEL decompress_bits_kernel(const T* in, int in_rows, int in_cols, bool* out) -{ - constexpr int bits_per_element = 8 * sizeof(T); - - const size_t i = threadIdx.y + blockIdx.y * blockDim.y; - const size_t j = threadIdx.x + blockIdx.x * blockDim.x; - - if (in_rows <= i || in_cols <= j) { return; } - - const size_t out_rows = in_rows * bits_per_element; - const size_t out_cols = in_cols; - const size_t out_i = i * bits_per_element; - const size_t out_j = j; - - if (out_rows <= out_i && out_cols <= out_j) { return; } - - T bitfield = in[i * in_cols + j]; - for (int bitpos = 0; bitpos < bits_per_element; ++bitpos) { - bool bit = ((T(1) << bitpos) & bitfield) != 0; - out[(out_i + bitpos) * out_cols + out_j] = bit; - } -} - -/** - * @brief Transpose and decompress 2D bitfield to boolean matrix - * - * Inverse operation of compress_to_bits - * - * @tparam T - * - * @parameter[in] in An `m x n` bitfield matrix. Row major. - * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `n x (m * bits_per_elem)` boolean matrix. - */ -template ::value>> -void decompress_bits(const raft::handle_t& handle, const T* in, int in_rows, int in_cols, bool* out) -{ - auto stream = resource::get_cuda_stream(handle); - dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); - dim3 block(32, 32); - decompress_bits_kernel<<>>(in, in_rows, in_cols, out); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -// Params holds parameters for test case -struct Params { - int m, n; -}; - -inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& -{ - return os << "m: " << p.m << ", n: " << p.n; -} - -// Check that the following holds -// -// decompress(compress(x)) == x -// -// for 2D boolean matrices x. -template -void check_invertible(const Params& p) -{ - using raft::distance::detail::compress_to_bits; - constexpr int bits_per_elem = sizeof(T) * 8; - - // Make m and n that are safe to ceildiv. - int m = raft::round_up_safe(p.m, bits_per_elem); - int n = p.n; - - // Generate random input - raft::handle_t handle{}; - raft::random::RngState r(1ULL); - auto in = raft::make_device_matrix(handle, m, n); - raft::random::bernoulli(handle, r, in.data_handle(), m * n, 0.5f); - - int tmp_m = raft::ceildiv(m, bits_per_elem); - int out_m = tmp_m * bits_per_elem; - - auto tmp = raft::make_device_matrix(handle, tmp_m, n); - auto out = raft::make_device_matrix(handle, out_m, n); - - resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaGetLastError()); - - ASSERT_EQ(in.extent(0), out.extent(0)) << "M does not match"; - ASSERT_EQ(in.extent(1), out.extent(1)) << "N does not match"; - - compress_to_bits(handle, in.view(), tmp.view()); - resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaGetLastError()); - - decompress_bits(handle, tmp.data_handle(), tmp.extent(0), tmp.extent(1), out.data_handle()); - resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaGetLastError()); - - // Check for differences. - ASSERT_TRUE(raft::devArrMatch(in.data_handle(), - out.data_handle(), - in.extent(0) * in.extent(1), - raft::Compare(), - resource::get_cuda_stream(handle))); - resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -void check_all_true(const Params& p) -{ - using raft::distance::detail::compress_to_bits; - using T = uint64_t; - constexpr int bits_per_elem = sizeof(T) * 8; - - // Make m and n that are safe to ceildiv. - int m = raft::round_up_safe(p.m, bits_per_elem); - int n = p.n; - - raft::handle_t handle{}; - raft::random::RngState r(1ULL); - auto in = raft::make_device_matrix(handle, m, n); - raft::matrix::fill(handle, in.view(), true); - - int tmp_m = raft::ceildiv(m, bits_per_elem); - auto tmp = raft::make_device_matrix(handle, tmp_m, n); - resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaGetLastError()); - - compress_to_bits(handle, in.view(), tmp.view()); - resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaGetLastError()); - - auto expected = raft::make_device_matrix(handle, tmp_m, n); - raft::matrix::fill(handle, expected.view(), ~T(0)); - - // Check for differences. - ASSERT_TRUE(raft::devArrMatch(expected.data_handle(), - tmp.data_handle(), - tmp.extent(0) * tmp.extent(1), - raft::Compare(), - resource::get_cuda_stream(handle))); - resource::sync_stream(handle); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -class CompressToBitsTest : public ::testing::TestWithParam { - // Empty. -}; - -TEST_P(CompressToBitsTest, CheckTrue64) { check_all_true(GetParam()); } - -TEST_P(CompressToBitsTest, CheckInvertible64) -{ - using T = uint64_t; - check_invertible(GetParam()); -} - -TEST_P(CompressToBitsTest, CheckInvertible32) -{ - using T = uint32_t; - check_invertible(GetParam()); -} - -std::vector params = raft::util::itertools::product( - {1, 3, 32, 33, 63, 64, 65, 128, 10013}, {1, 3, 32, 33, 63, 64, 65, 13001}); - -INSTANTIATE_TEST_CASE_P(CompressToBits, CompressToBitsTest, ::testing::ValuesIn(params)); - -} // namespace raft::distance::masked_nn::compress_to_bits \ No newline at end of file diff --git a/cpp/test/neighbors/ann_brute_force.cuh b/cpp/test/neighbors/ann_brute_force.cuh deleted file mode 100644 index 6370c5ee83..0000000000 --- a/cpp/test/neighbors/ann_brute_force.cuh +++ /dev/null @@ -1,253 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../test_utils.cuh" -#include "ann_utils.cuh" -#include "knn_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::brute_force { - -template -struct AnnBruteForceInputs { - IdxT num_queries; - IdxT num_db_vecs; - IdxT dim; - IdxT k; - raft::distance::DistanceType metric; - bool host_dataset; -}; - -template -::std::ostream& operator<<(::std::ostream& os, const AnnBruteForceInputs& p) -{ - os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " - << static_cast(p.metric) << ", " << p.host_dataset << '}' << std::endl; - return os; -} - -template -class AnnBruteForceTest : public ::testing::TestWithParam> { - public: - AnnBruteForceTest() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam>::GetParam()), - database(0, stream_), - search_queries(0, stream_) - { - } - - void testBruteForce() - { - size_t queries_size = ps.num_queries * ps.k; - - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.metric); - resource::sync_stream(handle_); - - { - // Require exact result for brute force - rmm::device_uvector distances_bruteforce_dev(queries_size, stream_); - rmm::device_uvector indices_bruteforce_dev(queries_size, stream_); - brute_force::index_params index_params{}; - brute_force::search_params search_params{}; - index_params.metric = ps.metric; - index_params.metric_arg = 0; - - auto device_dataset = std::optional>{}; - auto idx = [this, &index_params]() { - if (ps.host_dataset) { - auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); - raft::copy( - host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); - return brute_force::build( - handle_, index_params, raft::make_const_mdspan(host_database.view())); - } else { - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - return brute_force::build(handle_, index_params, database_view); - } - }(); - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.num_queries, ps.dim); - auto indices_out_view = raft::make_device_matrix_view( - indices_bruteforce_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_bruteforce_dev.data(), ps.num_queries, ps.k); - brute_force::serialize(handle_, std::string{"brute_force_index"}, idx); - - auto index_loaded = - brute_force::deserialize(handle_, std::string{"brute_force_index"}); - ASSERT_EQ(idx.size(), index_loaded.size()); - - brute_force::search(handle_, - search_params, - index_loaded, - search_queries_view, - indices_out_view, - dists_out_view); - - resource::sync_stream(handle_); - - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices_naive_dev.data(), - indices_bruteforce_dev.data(), - distances_naive_dev.data(), - distances_bruteforce_dev.data(), - ps.num_queries, - ps.k, - 0.001f, - stream_, - true)); - brute_force::serialize(handle_, std::string{"brute_force_index"}, idx, false); - index_loaded = brute_force::deserialize(handle_, std::string{"brute_force_index"}); - index_loaded.update_dataset(handle_, idx.dataset()); - ASSERT_EQ(idx.size(), index_loaded.size()); - - brute_force::search(handle_, - search_params, - index_loaded, - search_queries_view, - indices_out_view, - dists_out_view); - - resource::sync_stream(handle_); - - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices_naive_dev.data(), - indices_bruteforce_dev.data(), - distances_naive_dev.data(), - distances_bruteforce_dev.data(), - ps.num_queries, - ps.k, - 0.001f, - stream_, - true)); - } - } - - void SetUp() override - { - database.resize(ps.num_db_vecs * ps.dim, stream_); - search_queries.resize(ps.num_queries * ps.dim, stream_); - - raft::random::RngState r(1234ULL); - if constexpr (std::is_same{}) { - raft::random::uniform( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); - raft::random::uniform( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); - raft::random::uniformInt( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); - } - resource::sync_stream(handle_); - } - - void TearDown() override - { - resource::sync_stream(handle_); - database.resize(0, stream_); - search_queries.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - AnnBruteForceInputs ps; - rmm::device_uvector database; - rmm::device_uvector search_queries; -}; - -const std::vector> inputs = { - // test various dims (aligned and not aligned to vector sizes) - {1000, 10000, 1, 16, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 2, 16, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 3, 16, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 4, 16, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 5, 16, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 8, 16, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 5, 16, raft::distance::DistanceType::L2SqrtExpanded, true}, - {1000, 10000, 8, 16, raft::distance::DistanceType::L2SqrtExpanded, true}, - - // test dims that do not fit into kernel shared memory limits - {1000, 10000, 2048, 16, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 2049, 16, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 2050, 16, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 2051, 16, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 2052, 16, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 2053, 16, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 2056, 16, raft::distance::DistanceType::L2Expanded, true}, - - // host input data - {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, - {100, 10000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, - {20, 100000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, - {1000, 100000, 16, 10, raft::distance::DistanceType::L2Expanded, false}, - {10000, 131072, 8, 10, raft::distance::DistanceType::L2Expanded, false}, - - {1000, 10000, 16, 10, raft::distance::DistanceType::InnerProduct, false}}; -} // namespace raft::neighbors::brute_force diff --git a/cpp/test/neighbors/ann_brute_force/test_float.cu b/cpp/test/neighbors/ann_brute_force/test_float.cu deleted file mode 100644 index f157b5f65c..0000000000 --- a/cpp/test/neighbors/ann_brute_force/test_float.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_brute_force.cuh" - -#include - -namespace raft::neighbors::brute_force { - -using AnnBruteForceTest_float = AnnBruteForceTest; -TEST_P(AnnBruteForceTest_float, AnnBruteForce) { this->testBruteForce(); } - -INSTANTIATE_TEST_CASE_P(AnnBruteForceTest, AnnBruteForceTest_float, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::brute_force diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh deleted file mode 100644 index cc787d3e57..0000000000 --- a/cpp/test/neighbors/ann_cagra.cuh +++ /dev/null @@ -1,949 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation - -#include "../test_utils.cuh" -#include "ann_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -#include - -#include -#include -#include -#include - -namespace raft::neighbors::cagra { -namespace { - -/* A filter that excludes all indices below `offset`. */ -struct test_cagra_sample_filter { - static constexpr unsigned offset = 300; - inline _RAFT_HOST_DEVICE auto operator()( - // query index - const uint32_t query_ix, - // the index of the current sample inside the current inverted list - const uint32_t sample_ix) const - { - return sample_ix >= offset; - } -}; - -// For sort_knn_graph test -template -void RandomSuffle(raft::host_matrix_view index) -{ - for (IdxT i = 0; i < index.extent(0); i++) { - uint64_t rand = i; - IdxT* const row_ptr = index.data_handle() + i * index.extent(1); - for (unsigned j = 0; j < index.extent(1); j++) { - // Swap two indices at random - rand = raft::neighbors::cagra::detail::device::xorshift64(rand); - const auto i0 = rand % index.extent(1); - rand = raft::neighbors::cagra::detail::device::xorshift64(rand); - const auto i1 = rand % index.extent(1); - - const auto tmp = row_ptr[i0]; - row_ptr[i0] = row_ptr[i1]; - row_ptr[i1] = tmp; - } - } -} - -template -testing::AssertionResult CheckOrder(raft::host_matrix_view index_test, - raft::host_matrix_view dataset, - raft::distance::DistanceType metric) -{ - for (IdxT i = 0; i < index_test.extent(0); i++) { - const DatatT* const base_vec = dataset.data_handle() + i * dataset.extent(1); - const IdxT* const index_row = index_test.data_handle() + i * index_test.extent(1); - DistanceT prev_distance = metric == raft::distance::DistanceType::L2Expanded - ? 0 - : std::numeric_limits::max(); - for (unsigned j = 0; j < index_test.extent(1) - 1; j++) { - const DatatT* const target_vec = dataset.data_handle() + index_row[j] * dataset.extent(1); - DistanceT distance = 0; - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - for (unsigned l = 0; l < dataset.extent(1); l++) { - const auto diff = - static_cast(target_vec[l]) - static_cast(base_vec[l]); - distance += diff * diff; - } - if (prev_distance > distance) { - return testing::AssertionFailure() - << "Wrong index order (row = " << i << ", neighbor_id = " << j - << "). (distance[neighbor_id-1] = " << prev_distance - << "should be lesser than distance[neighbor_id] = " << distance << ")"; - } - break; - case raft::distance::DistanceType::InnerProduct: - for (unsigned l = 0; l < dataset.extent(1); l++) { - const auto prod = - static_cast(target_vec[l]) * static_cast(base_vec[l]); - distance += prod; - } - if (prev_distance < distance) { - return testing::AssertionFailure() - << "Wrong index order (row = " << i << ", neighbor_id = " << j - << "). (distance[neighbor_id-1] = " << prev_distance - << "should be greater than distance[neighbor_id] = " << distance << ")"; - } - break; - default: - return testing::AssertionFailure() - << "Distance metric " << metric - << " not supported. Only L2Expanded and InnerProduct are supported"; - } - prev_distance = distance; - } - } - return testing::AssertionSuccess(); -} - -template -struct fpi_mapper {}; - -template <> -struct fpi_mapper { - using type = int64_t; - static constexpr int kBitshiftBase = 53; -}; - -template <> -struct fpi_mapper { - using type = int32_t; - static constexpr int kBitshiftBase = 24; -}; - -template <> -struct fpi_mapper { - using type = int16_t; - static constexpr int kBitshiftBase = 11; -}; - -// Generate dataset to ensure no rounding error occurs in the norm computation of any two vectors. -// When testing the CAGRA index sorting function, rounding errors can affect the norm and alter the -// order of the index. To ensure the accuracy of the test, we utilize the dataset. The generation -// method is based on the error-free transformation (EFT) method. -template -RAFT_KERNEL GenerateRoundingErrorFreeDataset_kernel(T* const ptr, - const uint32_t size, - const typename fpi_mapper::type resolution) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= size) { return; } - - const float u32 = *reinterpret_cast::type*>(ptr + tid); - ptr[tid] = u32 / resolution; -} - -template -void GenerateRoundingErrorFreeDataset( - const raft::resources& handle, - T* const ptr, - const uint32_t n_row, - const uint32_t dim, - raft::random::RngState& rng, - const bool diff_flag // true if compute the norm between two vectors -) -{ - using mapper_type = fpi_mapper; - using int_type = typename mapper_type::type; - auto cuda_stream = resource::get_cuda_stream(handle); - const uint32_t size = n_row * dim; - const uint32_t block_size = 256; - const uint32_t grid_size = (size + block_size - 1) / block_size; - - const auto bitshift = (mapper_type::kBitshiftBase - std::log2(dim) - (diff_flag ? 1 : 0)) / 2; - // Skip the test when `dim` is too big for type `T` to allow rounding error-free test. - if (bitshift <= 1) { GTEST_SKIP(); } - const int_type resolution = int_type{1} << static_cast(std::floor(bitshift)); - raft::random::uniformInt( - handle, rng, reinterpret_cast(ptr), size, -resolution, resolution - 1); - - GenerateRoundingErrorFreeDataset_kernel - <<>>(ptr, size, resolution); -} - -template -void InitDataset(const raft::resources& handle, - DataT* const datatset_ptr, - std::uint32_t size, - std::uint32_t dim, - raft::distance::DistanceType metric, - raft::random::RngState& r) -{ - if constexpr (std::is_same_v || std::is_same_v) { - GenerateRoundingErrorFreeDataset(handle, datatset_ptr, size, dim, r, true); - - if (metric == raft::distance::InnerProduct) { - auto dataset_view = raft::make_device_matrix_view(datatset_ptr, size, dim); - raft::linalg::row_normalize( - handle, raft::make_const_mdspan(dataset_view), dataset_view, raft::linalg::L2Norm); - } - } else if constexpr (std::is_same_v || std::is_same_v) { - if constexpr (std::is_same_v) { - raft::random::uniformInt(handle, r, datatset_ptr, size * dim, DataT(-10), DataT(10)); - } else { - raft::random::uniformInt(handle, r, datatset_ptr, size * dim, DataT(1), DataT(20)); - } - - if (metric == raft::distance::InnerProduct) { - // TODO (enp1s0): Change this once row_normalize supports (u)int8 matrices. - // https://github.com/rapidsai/raft/issues/2291 - - using ComputeT = float; - auto dataset_view = raft::make_device_matrix_view(datatset_ptr, size, dim); - auto dev_row_norm = raft::make_device_vector(handle, size); - const auto normalized_norm = - (std::is_same_v ? 40 : 20) * std::sqrt(static_cast(dim)); - - raft::linalg::reduce(dev_row_norm.data_handle(), - datatset_ptr, - dim, - size, - 0.f, - true, - true, - resource::get_cuda_stream(handle), - false, - raft::sq_op(), - raft::add_op(), - raft::sqrt_op()); - raft::linalg::matrix_vector_op( - handle, - raft::make_const_mdspan(dataset_view), - raft::make_const_mdspan(dev_row_norm.view()), - dataset_view, - raft::linalg::Apply::ALONG_COLUMNS, - [normalized_norm] __device__(DataT elm, ComputeT norm) { - const ComputeT v = elm / norm * normalized_norm; - const ComputeT max_v_range = std::numeric_limits::max(); - const ComputeT min_v_range = std::numeric_limits::min(); - return static_cast(std::min(max_v_range, std::max(min_v_range, v))); - }); - } - } -} -} // namespace - -struct AnnCagraInputs { - int n_queries; - int n_rows; - int dim; - int k; - graph_build_algo build_algo; - search_algo algo; - int max_queries; - int team_size; - int itopk_size; - int search_width; - raft::distance::DistanceType metric; - bool host_dataset; - bool include_serialized_dataset; - // std::optional - double min_recall; // = std::nullopt; -}; - -inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) -{ - std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; - std::vector build_algo = {"IVF_PQ", "NN_DESCENT"}; - os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim - << ", k=" << p.k << ", " << algo.at((int)p.algo) << ", max_queries=" << p.max_queries - << ", itopk_size=" << p.itopk_size << ", search_width=" << p.search_width - << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") - << ", build_algo=" << build_algo.at((int)p.build_algo) << '}' << std::endl; - return os; -} - -template -class AnnCagraTest : public ::testing::TestWithParam { - public: - AnnCagraTest() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_), - search_queries(0, stream_) - { - } - - protected: - void testCagra() - { - // TODO (tarang-jain): remove when NN Descent index building support InnerProduct. Reference - // issue: https://github.com/rapidsai/raft/issues/2276 - if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) - GTEST_SKIP(); - - size_t queries_size = ps.n_queries * ps.k; - std::vector indices_Cagra(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_Cagra(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.n_queries, - ps.n_rows, - ps.dim, - ps.k, - ps.metric); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - { - rmm::device_uvector distances_dev(queries_size, stream_); - rmm::device_uvector indices_dev(queries_size, stream_); - - { - cagra::index_params index_params; - index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is - // not used for knn_graph building. - index_params.build_algo = ps.build_algo; - cagra::search_params search_params; - search_params.algo = ps.algo; - search_params.max_queries = ps.max_queries; - search_params.team_size = ps.team_size; - search_params.itopk_size = ps.itopk_size; - - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.n_rows, ps.dim); - - { - cagra::index index(handle_); - if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); - auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - index = cagra::build(handle_, index_params, database_host_view); - } else { - index = cagra::build(handle_, index_params, database_view); - }; - cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset); - } - - auto index = cagra::deserialize(handle_, "cagra_index"); - if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.n_queries, ps.dim); - auto indices_out_view = - raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_dev.data(), ps.n_queries, ps.k); - - cagra::search( - handle_, search_params, index, search_queries_view, indices_out_view, dists_out_view); - update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); - update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - // for (int i = 0; i < min(ps.n_queries, 10); i++) { - // // std::cout << "query " << i << std::end; - // print_vector("T", indices_naive.data() + i * ps.k, ps.k, std::cout); - // print_vector("C", indices_Cagra.data() + i * ps.k, ps.k, std::cout); - // print_vector("T", distances_naive.data() + i * ps.k, ps.k, std::cout); - // print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout); - // } - - double min_recall = ps.min_recall; - EXPECT_TRUE(eval_neighbours(indices_naive, - indices_Cagra, - distances_naive, - distances_Cagra, - ps.n_queries, - ps.k, - 0.003, - min_recall)); - EXPECT_TRUE(eval_distances(handle_, - database.data(), - search_queries.data(), - indices_dev.data(), - distances_dev.data(), - ps.n_rows, - ps.dim, - ps.n_queries, - ps.k, - ps.metric, - 1.0e-4)); - } - } - - void SetUp() override - { - database.resize(((size_t)ps.n_rows) * ps.dim, stream_); - search_queries.resize(ps.n_queries * ps.dim, stream_); - raft::random::RngState r(1234ULL); - InitDataset(handle_, database.data(), ps.n_rows, ps.dim, ps.metric, r); - InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r); - resource::sync_stream(handle_); - } - - void TearDown() override - { - resource::sync_stream(handle_); - database.resize(0, stream_); - search_queries.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - AnnCagraInputs ps; - rmm::device_uvector database; - rmm::device_uvector search_queries; -}; - -template -class AnnCagraSortTest : public ::testing::TestWithParam { - public: - AnnCagraSortTest() - : ps(::testing::TestWithParam::GetParam()), database(0, handle_.get_stream()) - { - } - - protected: - void testCagraSort() - { - if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) - GTEST_SKIP(); - - { - // Step 1: Build a sorted KNN graph by CAGRA knn build - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.n_rows, ps.dim); - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy( - database_host.data_handle(), database.data(), database.size(), handle_.get_stream()); - auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - - cagra::index_params index_params; - auto knn_graph = - raft::make_host_matrix(ps.n_rows, index_params.intermediate_graph_degree); - - if (ps.build_algo == graph_build_algo::IVF_PQ) { - auto build_params = ivf_pq::index_params::from_dataset(database_view, ps.metric); - if (ps.host_dataset) { - cagra::build_knn_graph( - handle_, database_host_view, knn_graph.view(), 2, build_params); - } else { - cagra::build_knn_graph( - handle_, database_view, knn_graph.view(), 2, build_params); - } - } else { - auto nn_descent_idx_params = experimental::nn_descent::index_params{}; - nn_descent_idx_params.graph_degree = index_params.intermediate_graph_degree; - nn_descent_idx_params.intermediate_graph_degree = index_params.intermediate_graph_degree; - - if (ps.host_dataset) { - cagra::build_knn_graph( - handle_, database_host_view, knn_graph.view(), nn_descent_idx_params); - } else { - cagra::build_knn_graph( - handle_, database_host_view, knn_graph.view(), nn_descent_idx_params); - } - } - - handle_.sync_stream(); - ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view(), ps.metric)); - - if (ps.metric != raft::distance::DistanceType::InnerProduct) { - RandomSuffle(knn_graph.view()); - - cagra::sort_knn_graph(handle_, database_view, knn_graph.view()); - handle_.sync_stream(); - - ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view(), ps.metric)); - } - } - } - - void SetUp() override - { - database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream()); - raft::random::RngState r(1234ULL); - if constexpr (std::is_same_v || std::is_same_v) { - GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, false); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); - } - handle_.sync_stream(); - } - - void TearDown() override - { - handle_.sync_stream(); - database.resize(0, handle_.get_stream()); - } - - private: - raft::device_resources handle_; - AnnCagraInputs ps; - rmm::device_uvector database; -}; - -template -class AnnCagraFilterTest : public ::testing::TestWithParam { - public: - AnnCagraFilterTest() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_), - search_queries(0, stream_) - { - } - - protected: - void testCagraFilter() - { - if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) - GTEST_SKIP(); - - size_t queries_size = ps.n_queries * ps.k; - std::vector indices_Cagra(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_Cagra(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database_filtered_ptr, - ps.n_queries, - ps.n_rows - test_cagra_sample_filter::offset, - ps.dim, - ps.k, - ps.metric); - raft::linalg::addScalar(indices_naive_dev.data(), - indices_naive_dev.data(), - IdxT(test_cagra_sample_filter::offset), - queries_size, - stream_); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - { - rmm::device_uvector distances_dev(queries_size, stream_); - rmm::device_uvector indices_dev(queries_size, stream_); - - { - cagra::index_params index_params; - index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is - // not used for knn_graph building. - index_params.nn_descent_niter = 50; - cagra::search_params search_params; - search_params.algo = ps.algo; - search_params.max_queries = ps.max_queries; - search_params.team_size = ps.team_size; - search_params.hashmap_mode = cagra::hash_mode::HASH; - - // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for - // k>1024 skip these tests until fixed - if (ps.k >= 1024) { GTEST_SKIP(); } - // search_params.itopk_size = ps.itopk_size; - - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.n_rows, ps.dim); - - cagra::index index(handle_); - if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); - auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - index = cagra::build(handle_, index_params, database_host_view); - } else { - index = cagra::build(handle_, index_params, database_view); - } - - if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.n_queries, ps.dim); - auto indices_out_view = - raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_dev.data(), ps.n_queries, ps.k); - - cagra::search_with_filtering(handle_, - search_params, - index, - search_queries_view, - indices_out_view, - dists_out_view, - test_cagra_sample_filter()); - update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); - update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - // Test filter - bool unacceptable_node = false; - for (int q = 0; q < ps.n_queries; q++) { - for (int i = 0; i < ps.k; i++) { - const auto n = indices_Cagra[q * ps.k + i]; - unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n); - } - } - EXPECT_FALSE(unacceptable_node); - - double min_recall = ps.min_recall; - // TODO(mfoerster): re-enable uniquenes test - EXPECT_TRUE(eval_neighbours(indices_naive, - indices_Cagra, - distances_naive, - distances_Cagra, - ps.n_queries, - ps.k, - 0.003, - min_recall, - false)); - EXPECT_TRUE(eval_distances(handle_, - database.data(), - search_queries.data(), - indices_dev.data(), - distances_dev.data(), - ps.n_rows, - ps.dim, - ps.n_queries, - ps.k, - ps.metric, - 1.0e-4)); - } - } - - void testCagraRemoved() - { - if (ps.metric == distance::InnerProduct && ps.build_algo == graph_build_algo::NN_DESCENT) - GTEST_SKIP(); - - size_t queries_size = ps.n_queries * ps.k; - std::vector indices_Cagra(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_Cagra(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database_filtered_ptr, - ps.n_queries, - ps.n_rows - test_cagra_sample_filter::offset, - ps.dim, - ps.k, - ps.metric); - raft::linalg::addScalar(indices_naive_dev.data(), - indices_naive_dev.data(), - IdxT(test_cagra_sample_filter::offset), - queries_size, - stream_); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - { - rmm::device_uvector distances_dev(queries_size, stream_); - rmm::device_uvector indices_dev(queries_size, stream_); - - { - cagra::index_params index_params; - index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is - // not used for knn_graph building. - index_params.nn_descent_niter = 50; - cagra::search_params search_params; - search_params.algo = ps.algo; - search_params.max_queries = ps.max_queries; - search_params.team_size = ps.team_size; - search_params.hashmap_mode = cagra::hash_mode::HASH; - - // TODO: setting search_params.itopk_size here breaks the filter tests, but is required for - // k>1024 skip these tests until fixed - if (ps.k >= 1024) { GTEST_SKIP(); } - // search_params.itopk_size = ps.itopk_size; - - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.n_rows, ps.dim); - - cagra::index index(handle_); - if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); - auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - index = cagra::build(handle_, index_params, database_host_view); - } else { - index = cagra::build(handle_, index_params, database_view); - } - - if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.n_queries, ps.dim); - auto indices_out_view = - raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_dev.data(), ps.n_queries, ps.k); - auto removed_indices = - raft::make_device_vector(handle_, test_cagra_sample_filter::offset); - thrust::sequence( - resource::get_thrust_policy(handle_), - thrust::device_pointer_cast(removed_indices.data_handle()), - thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); - resource::sync_stream(handle_); - raft::core::bitset removed_indices_bitset( - handle_, removed_indices.view(), ps.n_rows); - cagra::search_with_filtering( - handle_, - search_params, - index, - search_queries_view, - indices_out_view, - dists_out_view, - raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); - update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); - update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - double min_recall = ps.min_recall; - // TODO(mfoerster): re-enable uniquenes test - EXPECT_TRUE(eval_neighbours(indices_naive, - indices_Cagra, - distances_naive, - distances_Cagra, - ps.n_queries, - ps.k, - 0.003, - min_recall, - false)); - EXPECT_TRUE(eval_distances(handle_, - database.data(), - search_queries.data(), - indices_dev.data(), - distances_dev.data(), - ps.n_rows, - ps.dim, - ps.n_queries, - ps.k, - ps.metric, - 1.0e-4)); - } - } - - void SetUp() override - { - database.resize(((size_t)ps.n_rows) * ps.dim, stream_); - search_queries.resize(ps.n_queries * ps.dim, stream_); - raft::random::RngState r(1234ULL); - InitDataset(handle_, database.data(), ps.n_rows, ps.dim, ps.metric, r); - InitDataset(handle_, search_queries.data(), ps.n_queries, ps.dim, ps.metric, r); - resource::sync_stream(handle_); - } - - void TearDown() override - { - resource::sync_stream(handle_); - database.resize(0, stream_); - search_queries.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - AnnCagraInputs ps; - rmm::device_uvector database; - rmm::device_uvector search_queries; -}; - -inline std::vector generate_inputs() -{ - // TODO(tfeher): test MULTI_CTA kernel with search_width > 1 to allow multiple CTA per queries - std::vector inputs = raft::util::itertools::product( - {100}, - {1000}, - {1, 8, 17, 1599}, - {16}, // k - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, - {0, 1, 10, 100}, // query size - {0}, - {256}, - {1}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false}, - {true}, - {0.995}); - - auto inputs2 = raft::util::itertools::product( - {100}, - {1000}, - {1, 8, 17, 1599}, - {1}, // k - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, - {0, 1, 10, 100}, // query size - {0}, - {256}, - {1}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false}, - {true}, - {99. / 100} - // smaller threshould than the other test cases because it is too strict for Top-1 search - ); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = raft::util::itertools::product( - {100}, - {1000}, - {1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim - {16}, // k - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false}, - {true}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = raft::util::itertools::product( - {100}, - {1000}, - {64}, - {16}, - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0, 4, 8, 16, 32}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false}, - {false}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = raft::util::itertools::product( - {100}, - {1000}, - {64}, - {16}, - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {32, 64, 128, 256, 512, 768}, - {1}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false}, - {true}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = raft::util::itertools::product( - {100}, - {10000, 20000}, - {32}, - {10}, - {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false, true}, - {false}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = raft::util::itertools::product( - {100}, - {20000}, - {32}, - {2048}, // k - {graph_build_algo::NN_DESCENT}, - {search_algo::AUTO}, - {10}, - {0}, - {4096}, // itopk_size - {1}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false}, - {false}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - return inputs; -} - -const std::vector inputs = generate_inputs(); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh deleted file mode 100644 index 412e71bff1..0000000000 --- a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // none_cagra_sample_filter -#include // RAFT_EXPLICIT - -namespace raft::neighbors::cagra::detail { - -namespace multi_cta_search { -#define instantiate_kernel_selection( \ - DATASET_DESCRIPTOR, TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::DATASET_DESCRIPTOR dataset_desc, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_kernel_selection(standard_dataset_descriptor_t, - 32, - 1024, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection(standard_dataset_descriptor_t, - 8, - 128, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection(standard_dataset_descriptor_t, - 16, - 256, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection(standard_dataset_descriptor_t, - 32, - 512, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection -} // namespace multi_cta_search - -namespace single_cta_search { - -#define instantiate_single_cta_select_and_run( \ - DATASET_DESCRIPTOR, TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::DATASET_DESCRIPTOR dataset_desc, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, - 32, - 1024, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, - 8, - 128, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, - 16, - 256, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, - 32, - 512, - float, - uint64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu deleted file mode 100644 index ff7e839abf..0000000000 --- a/cpp/test/neighbors/ann_cagra/test_float_int64_t.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_cagra.cuh" -#include "search_kernel_uint64_t.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraTest AnnCagraTestF_I64; -TEST_P(AnnCagraTestF_I64, AnnCagra) { this->testCagra(); } - -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_I64, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu deleted file mode 100644 index 7d29ce4f99..0000000000 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_cagra.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraTest AnnCagraTestF_U32; -TEST_P(AnnCagraTestF_U32, AnnCagra) { this->testCagra(); } - -typedef AnnCagraSortTest AnnCagraSortTestF_U32; -TEST_P(AnnCagraSortTestF_U32, AnnCagraSort) { this->testCagraSort(); } - -typedef AnnCagraFilterTest AnnCagraFilterTestF_U32; -TEST_P(AnnCagraFilterTestF_U32, AnnCagraFilter) -{ - this->testCagraFilter(); - this->testCagraRemoved(); -} - -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestF_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_half_int64_t.cu b/cpp/test/neighbors/ann_cagra/test_half_int64_t.cu deleted file mode 100644 index bcdd95bece..0000000000 --- a/cpp/test/neighbors/ann_cagra/test_half_int64_t.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_cagra.cuh" -#include "search_kernel_uint64_t.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraTest AnnCagraTestH_I64; -TEST_P(AnnCagraTestH_I64, AnnCagra) { this->testCagra(); } - -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestH_I64, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu deleted file mode 100644 index ec7144f8d0..0000000000 --- a/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_cagra.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraTest AnnCagraTestH_U32; -TEST_P(AnnCagraTestH_U32, AnnCagra) { this->testCagra(); } - -typedef AnnCagraSortTest AnnCagraSortTestH_U32; -TEST_P(AnnCagraSortTestH_U32, AnnCagraSort) { this->testCagraSort(); } - -typedef AnnCagraFilterTest AnnCagraFilterTestH_U32; -TEST_P(AnnCagraFilterTestH_U32, AnnCagraFilter) -{ - this->testCagraFilter(); - this->testCagraRemoved(); -} - -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestH_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestH_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestH_U32, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu deleted file mode 100644 index b2242d89b1..0000000000 --- a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_cagra.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraTest AnnCagraTestI8_U32; -TEST_P(AnnCagraTestI8_U32, AnnCagra) { this->testCagra(); } -typedef AnnCagraSortTest AnnCagraSortTestI8_U32; -TEST_P(AnnCagraSortTestI8_U32, AnnCagraSort) { this->testCagraSort(); } -typedef AnnCagraFilterTest AnnCagraFilterTestI8_U32; -TEST_P(AnnCagraFilterTestI8_U32, AnnCagraFilter) -{ - this->testCagraFilter(); - this->testCagraRemoved(); -} - -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestI8_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestI8_U32, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu deleted file mode 100644 index 302b2bec18..0000000000 --- a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_cagra.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraTest AnnCagraTestU8_U32; -TEST_P(AnnCagraTestU8_U32, AnnCagra) { this->testCagra(); } - -typedef AnnCagraSortTest AnnCagraSortTestU8_U32; -TEST_P(AnnCagraSortTestU8_U32, AnnCagraSort) { this->testCagraSort(); } - -typedef AnnCagraFilterTest AnnCagraFilterTestU8_U32; -TEST_P(AnnCagraFilterTestU8_U32, AnnCagraSort) -{ - this->testCagraFilter(); - this->testCagraRemoved(); -} - -INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraSortTest, AnnCagraSortTestU8_U32, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestU8_U32, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra_vpq.cuh b/cpp/test/neighbors/ann_cagra_vpq.cuh deleted file mode 100644 index 6b24bca921..0000000000 --- a/cpp/test/neighbors/ann_cagra_vpq.cuh +++ /dev/null @@ -1,336 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../test_utils.cuh" -#include "ann_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -#include - -#include -#include -#include -#include -#include - -namespace { -template -void GenerateDataset(T* const dataset_ptr, - T* const query_ptr, - const std::size_t dataset_size, - const std::size_t query_size, - const std::size_t dim, - const std::size_t num_centers, - cudaStream_t cuda_stream) -{ - auto center_list = raft::make_host_matrix(num_centers, dim); - auto host_dataset = raft::make_host_matrix(std::max(dataset_size, query_size), dim); - - std::normal_distribution dist(0, 1); - std::mt19937 mt(0); - for (std::size_t i = 0; i < center_list.size(); i++) { - center_list.data_handle()[i] = dist(mt); - } - - std::uniform_int_distribution i_dist(0, num_centers - 1); - for (std::size_t i = 0; i < dataset_size; i++) { - const auto center_index = i_dist(mt); - for (std::size_t j = 0; j < dim; j++) { - host_dataset.data_handle()[i * dim + j] = - center_list.data_handle()[center_index + j] + dist(mt) * 1e-1; - } - } - raft::copy(dataset_ptr, host_dataset.data_handle(), dataset_size * dim, cuda_stream); - - for (std::size_t i = 0; i < query_size; i++) { - const auto center_index = i_dist(mt); - for (std::size_t j = 0; j < dim; j++) { - host_dataset.data_handle()[i * dim + j] = - center_list.data_handle()[center_index + j] + dist(mt) * 1e-1; - } - } - raft::copy(query_ptr, host_dataset.data_handle(), query_size * dim, cuda_stream); -} -} // namespace - -namespace raft::neighbors::cagra { -struct AnnCagraVpqInputs { - int n_queries; - int n_rows; - int dim; - int k; - int pq_len; - int pq_bits; - graph_build_algo build_algo; - search_algo algo; - int max_queries; - int team_size; - int itopk_size; - int search_width; - raft::distance::DistanceType metric; - bool host_dataset; - bool include_serialized_dataset; - // std::optional - double min_recall; // = std::nullopt; -}; - -inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraVpqInputs& p) -{ - std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; - std::vector build_algo = {"IVF_PQ", "NN_DESCENT"}; - os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim - << ", k=" << p.k << ", pq_bits=" << p.pq_bits << ", pq_len=" << p.pq_len << ", " - << algo.at((int)p.algo) << ", max_queries=" << p.max_queries << ", itopk_size=" << p.itopk_size - << ", search_width=" << p.search_width << ", metric=" << static_cast(p.metric) - << (p.host_dataset ? ", host" : ", device") - << ", build_algo=" << build_algo.at((int)p.build_algo) << '}' << std::endl; - return os; -} - -template -class AnnCagraVpqTest : public ::testing::TestWithParam { - public: - AnnCagraVpqTest() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_), - search_queries(0, stream_) - { - } - - protected: - void testCagra() - { - size_t queries_size = ps.n_queries * ps.k; - std::vector indices_Cagra(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_Cagra(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.n_queries, - ps.n_rows, - ps.dim, - ps.k, - ps.metric); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - const auto vpq_k = ps.k * 4; - { - rmm::device_uvector distances_dev(vpq_k * ps.n_queries, stream_); - rmm::device_uvector indices_dev(vpq_k * ps.n_queries, stream_); - - { - if ((ps.dim % ps.pq_len) != 0) { - // TODO: remove this requirement in the algorithm. - GTEST_SKIP() << "(TODO) At the moment dim, (" << ps.dim - << ") must be a multiple of pq_len (" << ps.pq_len << ")"; - } - cagra::index_params index_params; - index_params.compression = vpq_params{.pq_bits = static_cast(ps.pq_bits), - .pq_dim = static_cast(ps.dim / ps.pq_len)}; - index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is - // not used for knn_graph building. - index_params.build_algo = ps.build_algo; - cagra::search_params search_params; - search_params.algo = ps.algo; - search_params.max_queries = ps.max_queries; - search_params.team_size = ps.team_size; - search_params.itopk_size = ps.itopk_size; - - auto database_view = - raft::make_device_matrix_view(database.data(), ps.n_rows, ps.dim); - - { - cagra::index index(handle_); - if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); - auto database_host_view = raft::make_host_matrix_view( - database_host.data_handle(), ps.n_rows, ps.dim); - index = cagra::build(handle_, index_params, database_host_view); - } else { - index = cagra::build(handle_, index_params, database_view); - }; - cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset); - } - - auto index = cagra::deserialize(handle_, "cagra_index"); - if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } - - // CAGRA-Q sanity check: we've built the right index type - auto* vpq_dataset = - dynamic_cast*>(&index.data()); - EXPECT_NE(vpq_dataset, nullptr) - << "Expected VPQ dataset, because we're testing CAGRA-Q here."; - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.n_queries, ps.dim); - auto indices_out_view = - raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, vpq_k); - auto dists_out_view = raft::make_device_matrix_view( - distances_dev.data(), ps.n_queries, vpq_k); - - cagra::search( - handle_, search_params, index, search_queries_view, indices_out_view, dists_out_view); - - { - auto host_dataset = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy( - host_dataset.data_handle(), (const DataT*)database.data(), ps.n_rows * ps.dim, stream_); - - auto host_queries = raft::make_host_matrix(ps.n_queries, ps.dim); - raft::copy(host_queries.data_handle(), - (const DataT*)search_queries_view.data_handle(), - ps.n_queries * ps.dim, - stream_); - - auto host_index_candidate = raft::make_host_matrix(ps.n_queries, vpq_k); - raft::copy(host_index_candidate.data_handle(), - indices_out_view.data_handle(), - ps.n_queries * vpq_k, - stream_); - - auto host_indices_Cagra_view = - raft::make_host_matrix_view(indices_Cagra.data(), ps.n_queries, ps.k); - - auto host_dists_Cagra_view = - raft::make_host_matrix_view(distances_Cagra.data(), ps.n_queries, ps.k); - - resource::sync_stream(handle_); - - raft::neighbors::refine(handle_, - raft::make_const_mdspan(host_dataset.view()), - raft::make_const_mdspan(host_queries.view()), - raft::make_const_mdspan(host_index_candidate.view()), - host_indices_Cagra_view, - host_dists_Cagra_view, - ps.metric); - - raft::copy(indices_dev.data(), - host_indices_Cagra_view.data_handle(), - ps.k * ps.n_queries, - stream_); - raft::copy(distances_dev.data(), - host_dists_Cagra_view.data_handle(), - ps.k * ps.n_queries, - stream_); - resource::sync_stream(handle_); - } - } - - double min_recall = ps.min_recall; - EXPECT_TRUE(eval_neighbours(indices_naive, - indices_Cagra, - distances_naive, - distances_Cagra, - ps.n_queries, - ps.k, - 0.003, - min_recall)); - EXPECT_TRUE(eval_distances(handle_, - database.data(), - search_queries.data(), - indices_dev.data(), - distances_dev.data(), - ps.n_rows, - ps.dim, - ps.n_queries, - ps.k, - ps.metric, - 1.0e-4)); - } - } - - void SetUp() override - { - database.resize(((size_t)ps.n_rows) * ps.dim, stream_); - search_queries.resize(ps.n_queries * ps.dim, stream_); - GenerateDataset(database.data(), - search_queries.data(), - ps.n_rows, - ps.n_queries, - ps.dim, - static_cast(std::sqrt(ps.n_rows)), - stream_); - resource::sync_stream(handle_); - } - - void TearDown() override - { - resource::sync_stream(handle_); - database.resize(0, stream_); - search_queries.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - AnnCagraVpqInputs ps; - rmm::device_uvector database; - rmm::device_uvector search_queries; -}; - -const std::vector vpq_inputs = raft::util::itertools::product( - {100}, // n_queries - {1000, 10000}, // n_rows - {128, 132, 192, 256, 512, 768}, // dim - {8, 12}, // k - {2, 4}, // pq_len - {8}, // pq_bits - {graph_build_algo::NN_DESCENT}, // build_algo - {search_algo::SINGLE_CTA, search_algo::MULTI_CTA}, // algo - {0}, // max_queries - {0}, // team_size - {512}, // itopk_size - {1}, // search_width - {raft::distance::DistanceType::L2Expanded}, // metric - {false}, // host_dataset - {true}, // include_serialized_dataset - {0.8} // min_recall -); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu deleted file mode 100644 index f60edb5ed6..0000000000 --- a/cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "../ann_cagra_vpq.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraVpqTest AnnCagraVpqTestF_I64; -TEST_P(AnnCagraVpqTestF_I64, AnnCagraVpq) { this->testCagra(); } - -INSTANTIATE_TEST_CASE_P(AnnCagraVpqTest, AnnCagraVpqTestF_I64, ::testing::ValuesIn(vpq_inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu deleted file mode 100644 index 19d3f32250..0000000000 --- a/cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_cagra_vpq.cuh" - -#include - -namespace raft::neighbors::cagra { - -typedef AnnCagraVpqTest AnnCagraVpqTestF_U32; -TEST_P(AnnCagraVpqTestF_U32, AnnCagraVpq) { this->testCagra(); } - -INSTANTIATE_TEST_CASE_P(AnnCagraVpqTest, AnnCagraVpqTestF_U32, ::testing::ValuesIn(vpq_inputs)); - -} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh deleted file mode 100644 index de6af589fa..0000000000 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ /dev/null @@ -1,675 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../test_utils.cuh" -#include "ann_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::ivf_flat { - -struct test_ivf_sample_filter { - static constexpr unsigned offset = 300; -}; - -template -struct AnnIvfFlatInputs { - IdxT num_queries; - IdxT num_db_vecs; - IdxT dim; - IdxT k; - IdxT nprobe; - IdxT nlist; - raft::distance::DistanceType metric; - bool adaptive_centers; - bool host_dataset; -}; - -template -::std::ostream& operator<<(::std::ostream& os, const AnnIvfFlatInputs& p) -{ - os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " - << p.nprobe << ", " << p.nlist << ", " << static_cast(p.metric) << ", " - << p.adaptive_centers << ", " << p.host_dataset << '}' << std::endl; - return os; -} - -template -class AnnIVFFlatTest : public ::testing::TestWithParam> { - public: - AnnIVFFlatTest() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam>::GetParam()), - database(0, stream_), - search_queries(0, stream_) - { - } - - void testIVFFlat() - { - size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivfflat(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_ivfflat(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.metric); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - { - // unless something is really wrong with clustering, this could serve as a lower bound on - // recall - double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); - - rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); - rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); - - { - // legacy interface - raft::spatial::knn::IVFFlatParam ivfParams; - ivfParams.nprobe = ps.nprobe; - ivfParams.nlist = ps.nlist; - raft::spatial::knn::knnIndex index; - - approx_knn_build_index(handle_, - &index, - dynamic_cast(&ivfParams), - ps.metric, - (IdxT)0, - database.data(), - ps.num_db_vecs, - ps.dim); - - resource::sync_stream(handle_); - approx_knn_search(handle_, - distances_ivfflat_dev.data(), - indices_ivfflat_dev.data(), - &index, - ps.k, - search_queries.data(), - ps.num_queries); - - update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); - update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - ASSERT_TRUE(eval_neighbours(indices_naive, - indices_ivfflat, - distances_naive, - distances_ivfflat, - ps.num_queries, - ps.k, - 0.001, - min_recall)); - { - ivf_flat::index_params index_params; - ivf_flat::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; - index_params.adaptive_centers = ps.adaptive_centers; - search_params.n_probes = ps.nprobe; - - index_params.add_data_on_build = false; - index_params.kmeans_trainset_fraction = 0.5; - index_params.metric_arg = 0; - - ivf_flat::index idx(handle_, index_params, ps.dim); - ivf_flat::index index_2(handle_, index_params, ps.dim); - - if (!ps.host_dataset) { - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - idx = ivf_flat::build(handle_, index_params, database_view); - rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); - thrust::sequence(resource::get_thrust_policy(handle_), - thrust::device_pointer_cast(vector_indices.data()), - thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); - resource::sync_stream(handle_); - - IdxT half_of_data = ps.num_db_vecs / 2; - - auto half_of_data_view = raft::make_device_matrix_view( - (const DataT*)database.data(), half_of_data, ps.dim); - - const std::optional> no_opt = std::nullopt; - index_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); - - auto new_half_of_data_view = raft::make_device_matrix_view( - database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); - - auto new_half_of_data_indices_view = raft::make_device_vector_view( - vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); - - ivf_flat::extend(handle_, - new_half_of_data_view, - std::make_optional>( - new_half_of_data_indices_view), - &index_2); - - } else { - auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); - raft::copy( - host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); - idx = - ivf_flat::build(handle_, index_params, raft::make_const_mdspan(host_database.view())); - - auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); - std::iota(vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, 0); - - IdxT half_of_data = ps.num_db_vecs / 2; - - auto half_of_data_view = raft::make_host_matrix_view( - (const DataT*)host_database.data_handle(), half_of_data, ps.dim); - - const std::optional> no_opt = std::nullopt; - index_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); - - auto new_half_of_data_view = raft::make_host_matrix_view( - host_database.data_handle() + half_of_data * ps.dim, - IdxT(ps.num_db_vecs) - half_of_data, - ps.dim); - auto new_half_of_data_indices_view = raft::make_host_vector_view( - vector_indices.data_handle() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); - ivf_flat::extend(handle_, - new_half_of_data_view, - std::make_optional>( - new_half_of_data_indices_view), - &index_2); - } - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.num_queries, ps.dim); - auto indices_out_view = raft::make_device_matrix_view( - indices_ivfflat_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_ivfflat_dev.data(), ps.num_queries, ps.k); - ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2); - - auto index_loaded = ivf_flat::detail::deserialize(handle_, "ivf_flat_index"); - ASSERT_EQ(index_2.size(), index_loaded.size()); - - ivf_flat::search(handle_, - search_params, - index_loaded, - search_queries_view, - indices_out_view, - dists_out_view); - - update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); - update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - - // Test the centroid invariants - if (index_2.adaptive_centers()) { - // The centers must be up-to-date with the corresponding data - std::vector list_sizes(index_2.n_lists()); - std::vector list_indices(index_2.n_lists()); - rmm::device_uvector centroid(ps.dim, stream_); - raft::copy( - list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); - raft::copy( - list_indices.data(), index_2.inds_ptrs().data_handle(), index_2.n_lists(), stream_); - resource::sync_stream(handle_); - for (uint32_t l = 0; l < index_2.n_lists(); l++) { - if (list_sizes[l] == 0) continue; - rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); - raft::spatial::knn::detail::utils::copy_selected((IdxT)list_sizes[l], - (IdxT)ps.dim, - database.data(), - list_indices[l], - (IdxT)ps.dim, - cluster_data.data(), - (IdxT)ps.dim, - stream_); - raft::stats::mean( - centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_); - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle() + ps.dim * l, - centroid.data(), - ps.dim, - raft::CompareApprox(0.001), - stream_)); - } - } else { - // The centers must be immutable - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle(), - idx.centers().data_handle(), - index_2.centers().size(), - raft::Compare(), - stream_)); - } - } - ASSERT_TRUE(eval_neighbours(indices_naive, - indices_ivfflat, - distances_naive, - distances_ivfflat, - ps.num_queries, - ps.k, - 0.001, - min_recall)); - } - } - - void testPacker() - { - ivf_flat::index_params index_params; - ivf_flat::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; - index_params.adaptive_centers = false; - search_params.n_probes = ps.nprobe; - - index_params.add_data_on_build = false; - index_params.kmeans_trainset_fraction = 1.0; - index_params.metric_arg = 0; - - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - - auto idx = ivf_flat::build(handle_, index_params, database_view); - - const std::optional> no_opt = std::nullopt; - index extend_index = ivf_flat::extend(handle_, database_view, no_opt, idx); - - auto list_sizes = raft::make_host_vector(idx.n_lists()); - update_host(list_sizes.data_handle(), - extend_index.list_sizes().data_handle(), - extend_index.n_lists(), - stream_); - resource::sync_stream(handle_); - - auto& lists = idx.lists(); - - // conservative memory allocation for codepacking - auto list_device_spec = list_spec{idx.dim(), false}; - - for (uint32_t label = 0; label < idx.n_lists(); label++) { - uint32_t list_size = list_sizes.data_handle()[label]; - - ivf::resize_list(handle_, lists[label], list_device_spec, list_size, 0); - } - - helpers::recompute_internal_state(handle_, &idx); - - using interleaved_group = Pow2; - - for (uint32_t label = 0; label < idx.n_lists(); label++) { - uint32_t list_size = list_sizes.data_handle()[label]; - - if (list_size > 0) { - uint32_t padded_list_size = interleaved_group::roundUp(list_size); - uint32_t n_elems = padded_list_size * idx.dim(); - auto list_data = lists[label]->data; - auto list_inds = extend_index.lists()[label]->indices; - - // fetch the flat codes - auto flat_codes = make_device_matrix(handle_, list_size, idx.dim()); - - matrix::gather( - handle_, - make_device_matrix_view( - (const DataT*)database.data(), static_cast(ps.num_db_vecs), idx.dim()), - make_device_vector_view((const IdxT*)list_inds.data_handle(), - list_size), - flat_codes.view()); - - helpers::codepacker::pack( - handle_, make_const_mdspan(flat_codes.view()), idx.veclen(), 0, list_data.view()); - - { - auto mask = make_device_vector(handle_, n_elems); - - linalg::map_offset(handle_, - mask.view(), - [dim = idx.dim(), - list_size, - padded_list_size, - chunk_size = util::FastIntDiv(idx.veclen())] __device__(auto i) { - uint32_t max_group_offset = interleaved_group::roundDown(list_size); - if (i < max_group_offset * dim) { return true; } - uint32_t surplus = (i - max_group_offset * dim); - uint32_t ingroup_id = interleaved_group::mod(surplus / chunk_size); - return ingroup_id < (list_size - max_group_offset); - }); - - // ensure that the correct number of indices are masked out - ASSERT_TRUE(thrust::reduce(resource::get_thrust_policy(handle_), - mask.data_handle(), - mask.data_handle() + n_elems, - 0) == list_size * ps.dim); - - auto packed_list_data = make_device_vector(handle_, n_elems); - - linalg::map_offset(handle_, - packed_list_data.view(), - [mask = mask.data_handle(), - list_data = list_data.data_handle()] __device__(uint32_t i) { - if (mask[i]) return list_data[i]; - return DataT{0}; - }); - - auto extend_data = extend_index.lists()[label]->data; - auto extend_data_filtered = make_device_vector(handle_, n_elems); - linalg::map_offset(handle_, - extend_data_filtered.view(), - [mask = mask.data_handle(), - extend_data = extend_data.data_handle()] __device__(uint32_t i) { - if (mask[i]) return extend_data[i]; - return DataT{0}; - }); - - ASSERT_TRUE(raft::devArrMatch(packed_list_data.data_handle(), - extend_data_filtered.data_handle(), - n_elems, - raft::Compare(), - stream_)); - } - - auto unpacked_flat_codes = - make_device_matrix(handle_, list_size, idx.dim()); - - helpers::codepacker::unpack( - handle_, list_data.view(), idx.veclen(), 0, unpacked_flat_codes.view()); - - ASSERT_TRUE(raft::devArrMatch(flat_codes.data_handle(), - unpacked_flat_codes.data_handle(), - list_size * ps.dim, - raft::Compare(), - stream_)); - } - } - } - - void testFilter() - { - size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivfflat(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_ivfflat(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - auto* database_filtered_ptr = database.data() + test_ivf_sample_filter::offset * ps.dim; - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database_filtered_ptr, - ps.num_queries, - ps.num_db_vecs - test_ivf_sample_filter::offset, - ps.dim, - ps.k, - ps.metric); - raft::linalg::addScalar(indices_naive_dev.data(), - indices_naive_dev.data(), - IdxT(test_ivf_sample_filter::offset), - queries_size, - stream_); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - { - // unless something is really wrong with clustering, this could serve as a lower bound on - // recall - double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); - - auto distances_ivfflat_dev = raft::make_device_matrix(handle_, ps.num_queries, ps.k); - auto indices_ivfflat_dev = - raft::make_device_matrix(handle_, ps.num_queries, ps.k); - - { - ivf_flat::index_params index_params; - ivf_flat::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; - index_params.adaptive_centers = ps.adaptive_centers; - search_params.n_probes = ps.nprobe; - - index_params.add_data_on_build = true; - index_params.kmeans_trainset_fraction = 0.5; - index_params.metric_arg = 0; - - // Create IVF Flat index - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_flat::build(handle_, index_params, database_view); - - // Create Bitset filter - auto removed_indices = - raft::make_device_vector(handle_, test_ivf_sample_filter::offset); - thrust::sequence(resource::get_thrust_policy(handle_), - thrust::device_pointer_cast(removed_indices.data_handle()), - thrust::device_pointer_cast(removed_indices.data_handle() + - test_ivf_sample_filter::offset)); - resource::sync_stream(handle_); - - raft::core::bitset removed_indices_bitset( - handle_, removed_indices.view(), ps.num_db_vecs); - - // Search with the filter - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.num_queries, ps.dim); - ivf_flat::search_with_filtering( - handle_, - search_params, - index, - search_queries_view, - indices_ivfflat_dev.view(), - distances_ivfflat_dev.view(), - raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); - - update_host( - distances_ivfflat.data(), distances_ivfflat_dev.data_handle(), queries_size, stream_); - update_host( - indices_ivfflat.data(), indices_ivfflat_dev.data_handle(), queries_size, stream_); - resource::sync_stream(handle_); - } - ASSERT_TRUE(eval_neighbours(indices_naive, - indices_ivfflat, - distances_naive, - distances_ivfflat, - ps.num_queries, - ps.k, - 0.001, - min_recall)); - } - } - - void SetUp() override - { - database.resize(ps.num_db_vecs * ps.dim, stream_); - search_queries.resize(ps.num_queries * ps.dim, stream_); - - raft::random::RngState r(1234ULL); - if constexpr (std::is_same{}) { - raft::random::uniform( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); - raft::random::uniform( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); - raft::random::uniformInt( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); - } - resource::sync_stream(handle_); - } - - void TearDown() override - { - resource::sync_stream(handle_); - database.resize(0, stream_); - search_queries.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - AnnIvfFlatInputs ps; - rmm::device_uvector database; - rmm::device_uvector search_queries; -}; - -const std::vector> inputs = { - // test various dims (aligned and not aligned to vector sizes) - {1000, 10000, 1, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 2, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 3, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 4, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 5, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 10000, 8, 16, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, - - // test dims that do not fit into kernel shared memory limits - {1000, 10000, 2048, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 2049, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 2050, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 2051, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 2052, 16, 40, 1024, raft::distance::DistanceType::InnerProduct, false}, - {1000, 10000, 2053, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, - {1000, 10000, 2056, 16, 40, 1024, raft::distance::DistanceType::L2Expanded, true}, - - // various random combinations - {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false}, - {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::L2Expanded, false}, - {100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::L2Expanded, false}, - {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, - {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, - {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false}, - - // various combinations with k>raft::matrix::detail::select::warpsort::kMaxCapacity - {1000, 10000, 16, 1024, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, - {1000, 10000, 2053, 512, 50, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 10000, 2049, 2048, 70, 1024, raft::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false}, - {10, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false}, - {10, 10000, 16, 4000, 120, 2048, raft::distance::DistanceType::L2SqrtExpanded, true}, - {20, 100000, 16, 257, 20, 1024, raft::distance::DistanceType::L2SqrtExpanded, true}, - {1000, 100000, 16, 259, 20, 1024, raft::distance::DistanceType::L2Expanded, true, true}, - {10000, 131072, 8, 280, 20, 1024, raft::distance::DistanceType::InnerProduct, false}, - {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2Expanded, false}, - {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2SqrtExpanded, false}, - {100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::InnerProduct, false}, - {100000, 1024, 16, 300, 20, 60, raft::distance::DistanceType::L2Expanded, false}, - {100000, 1024, 16, 500, 20, 60, raft::distance::DistanceType::L2SqrtExpanded, false}, - {100000, 1024, 16, 700, 20, 60, raft::distance::DistanceType::InnerProduct, false}, - - // host input data - {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false, true}, - {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false, true}, - {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::L2Expanded, false, true}, - {100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::L2Expanded, false, true}, - {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true}, - {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true}, - {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true}, - - {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::InnerProduct, true}, - {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::InnerProduct, false}, - {100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::InnerProduct, true}, - {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct, true}, - {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::InnerProduct, false}, - {10000, 131072, 8, 10, 50, 1024, raft::distance::DistanceType::InnerProduct, true}, - - {1000, 10000, 4096, 20, 50, 1024, raft::distance::DistanceType::InnerProduct, false}, - - // test splitting the big query batches (> max gridDim.y) into smaller batches - {100000, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, false}, - {1000000, 1024, 32, 10, 256, 256, raft::distance::DistanceType::InnerProduct, false}, - {98306, 1024, 32, 10, 64, 64, raft::distance::DistanceType::InnerProduct, true}, - - // test radix_sort for getting the cluster selection - {1000, - 10000, - 16, - 10, - raft::matrix::detail::select::warpsort::kMaxCapacity * 2, - raft::matrix::detail::select::warpsort::kMaxCapacity * 4, - raft::distance::DistanceType::L2Expanded, - false}, - {1000, - 10000, - 16, - 10, - raft::matrix::detail::select::warpsort::kMaxCapacity * 4, - raft::matrix::detail::select::warpsort::kMaxCapacity * 4, - raft::distance::DistanceType::InnerProduct, - false}, - - // The following two test cases should show very similar recall. - // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers - {20000, 8712, 3, 10, 51, 66, raft::distance::DistanceType::L2Expanded, false}, - {100000, 8712, 3, 10, 51, 66, raft::distance::DistanceType::L2Expanded, false}}; - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu deleted file mode 100644 index 0e1036e566..0000000000 --- a/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter -#include "../ann_ivf_flat.cuh" - -namespace raft::neighbors::ivf_flat { - -typedef AnnIVFFlatTest AnnIVFFlatFilterTestF; -TEST_P(AnnIVFFlatFilterTestF, AnnIVFFlatFilter) { this->testFilter(); } - -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatFilterTestF, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu deleted file mode 100644 index 2ff17b8536..0000000000 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_flat.cuh" - -#include - -namespace raft::neighbors::ivf_flat { - -typedef AnnIVFFlatTest AnnIVFFlatTestF; -TEST_P(AnnIVFFlatTestF, AnnIVFFlat) -{ - this->testIVFFlat(); - this->testPacker(); -} - -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu deleted file mode 100644 index 6fe12506aa..0000000000 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_flat.cuh" - -#include - -namespace raft::neighbors::ivf_flat { - -typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; -TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) { this->testIVFFlat(); } - -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_int8, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu deleted file mode 100644 index ab6001c71b..0000000000 --- a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_flat.cuh" - -#include - -namespace raft::neighbors::ivf_flat { - -typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; -TEST_P(AnnIVFFlatTestF_uint8, AnnIVFFlat) { this->testIVFFlat(); } - -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_uint8, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh deleted file mode 100644 index 4ebe02027f..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ /dev/null @@ -1,1095 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../test_utils.cuh" -#include "ann_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include -#include -#include - -namespace raft::neighbors::ivf_pq { - -struct test_ivf_sample_filter { - static constexpr unsigned offset = 1500; -}; - -struct ivf_pq_inputs { - uint32_t num_db_vecs = 4096; - uint32_t num_queries = 1024; - uint32_t dim = 64; - uint32_t k = 32; - std::optional min_recall = std::nullopt; - - ivf_pq::index_params index_params; - ivf_pq::search_params search_params; - - // Set some default parameters for tests - ivf_pq_inputs() - { - index_params.n_lists = max(32u, min(1024u, num_db_vecs / 128u)); - index_params.kmeans_trainset_fraction = 1.0; - } -}; - -inline auto operator<<(std::ostream& os, const ivf_pq::codebook_gen& p) -> std::ostream& -{ - switch (p) { - case ivf_pq::codebook_gen::PER_CLUSTER: os << "codebook_gen::PER_CLUSTER"; break; - case ivf_pq::codebook_gen::PER_SUBSPACE: os << "codebook_gen::PER_SUBSPACE"; break; - default: RAFT_FAIL("unreachable code"); - } - return os; -} - -inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream& -{ - ivf_pq_inputs dflt; - bool need_comma = false; -#define PRINT_DIFF_V(spec, val) \ - do { \ - if (dflt spec != p spec) { \ - if (need_comma) { os << ", "; } \ - os << #spec << " = " << val; \ - need_comma = true; \ - } \ - } while (0) -#define PRINT_DIFF(spec) PRINT_DIFF_V(spec, p spec) - - os << "ivf_pq_inputs {"; - PRINT_DIFF(.num_db_vecs); - PRINT_DIFF(.num_queries); - PRINT_DIFF(.dim); - PRINT_DIFF(.k); - PRINT_DIFF_V(.min_recall, p.min_recall.value_or(0)); - PRINT_DIFF_V(.index_params.metric, print_metric{p.index_params.metric}); - PRINT_DIFF(.index_params.metric_arg); - PRINT_DIFF(.index_params.add_data_on_build); - PRINT_DIFF(.index_params.n_lists); - PRINT_DIFF(.index_params.kmeans_n_iters); - PRINT_DIFF(.index_params.kmeans_trainset_fraction); - PRINT_DIFF(.index_params.pq_bits); - PRINT_DIFF(.index_params.pq_dim); - PRINT_DIFF(.index_params.codebook_kind); - PRINT_DIFF(.index_params.force_random_rotation); - PRINT_DIFF(.search_params.n_probes); - PRINT_DIFF_V(.search_params.lut_dtype, print_dtype{p.search_params.lut_dtype}); - PRINT_DIFF_V(.search_params.internal_distance_dtype, - print_dtype{p.search_params.internal_distance_dtype}); - os << "}"; - return os; -} - -template -void compare_vectors_l2( - const raft::resources& res, T a, T b, uint32_t label, double compression_ratio, double eps) -{ - auto n_rows = a.extent(0); - auto dim = a.extent(1); - rmm::mr::managed_memory_resource managed_memory; - auto dist = make_device_mdarray(res, &managed_memory, make_extents(n_rows)); - linalg::map_offset(res, dist.view(), [a, b, dim] __device__(uint32_t i) { - spatial::knn::detail::utils::mapping f{}; - double d = 0.0f; - for (uint32_t j = 0; j < dim; j++) { - double t = f(a(i, j)) - f(b(i, j)); - d += t * t; - } - return sqrt(d / double(dim)); - }); - resource::sync_stream(res); - for (uint32_t i = 0; i < n_rows; i++) { - double d = dist(i); - // The theoretical estimate of the error is hard to come up with, - // the estimate below is based on experimentation + curse of dimensionality - ASSERT_LE(d, 1.2 * eps * std::pow(2.0, compression_ratio)) - << " (label = " << label << ", ix = " << i << ", eps = " << eps << ")"; - } -} - -template -auto min_output_size(const raft::resources& handle, - const ivf_pq::index& index, - uint32_t n_probes) -> IdxT -{ - auto acc_sizes = index.accum_sorted_sizes(); - uint32_t last_nonzero = index.n_lists(); - while (last_nonzero > 0 && acc_sizes(last_nonzero - 1) == acc_sizes(last_nonzero)) { - last_nonzero--; - } - return acc_sizes(last_nonzero) - acc_sizes(last_nonzero - std::min(last_nonzero, n_probes)); -} - -template -class ivf_pq_test : public ::testing::TestWithParam { - public: - ivf_pq_test() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_), - search_queries(0, stream_) - { - } - - void gen_data() - { - database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_); - search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); - - raft::random::RngState r(1234ULL); - if constexpr (std::is_same{}) { - raft::random::uniform( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); - raft::random::uniform( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); - raft::random::uniformInt( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); - } - resource::sync_stream(handle_); - } - - void calc_ref() - { - size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data(), - ps.num_queries, - ps.num_db_vecs, - ps.dim, - ps.k, - ps.index_params.metric); - distances_ref.resize(queries_size); - update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); - indices_ref.resize(queries_size); - update_host(indices_ref.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - auto build_only() - { - auto ipams = ps.index_params; - ipams.add_data_on_build = true; - - auto index_view = - raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - return ivf_pq::build(handle_, ipams, index_view); - } - - auto build_2_extends() - { - auto db_indices = make_device_vector(handle_, ps.num_db_vecs); - linalg::map_offset(handle_, db_indices.view(), identity_op{}); - resource::sync_stream(handle_); - auto size_1 = IdxT(ps.num_db_vecs) / 2; - auto size_2 = IdxT(ps.num_db_vecs) - size_1; - auto vecs_1 = database.data(); - auto vecs_2 = database.data() + size_t(size_1) * size_t(ps.dim); - auto inds_1 = db_indices.data_handle(); - auto inds_2 = db_indices.data_handle() + size_t(size_1); - - auto ipams = ps.index_params; - ipams.add_data_on_build = false; - - auto database_view = - raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - auto idx = ivf_pq::build(handle_, ipams, database_view); - - auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); - auto inds_2_view = raft::make_device_vector_view(inds_2, size_2); - ivf_pq::extend(handle_, vecs_2_view, inds_2_view, &idx); - - auto vecs_1_view = - raft::make_device_matrix_view(vecs_1, size_1, ps.dim); - auto inds_1_view = raft::make_device_vector_view(inds_1, size_1); - ivf_pq::extend(handle_, vecs_1_view, inds_1_view, &idx); - return idx; - } - - auto build_serialize() - { - ivf_pq::serialize(handle_, "ivf_pq_index", build_only()); - return ivf_pq::deserialize(handle_, "ivf_pq_index"); - } - - void check_reconstruction(const index& index, - double compression_ratio, - uint32_t label, - uint32_t n_take, - uint32_t n_skip) - { - auto& rec_list = index.lists()[label]; - auto dim = index.dim(); - n_take = std::min(n_take, rec_list->size.load()); - n_skip = std::min(n_skip, rec_list->size.load() - n_take); - - if (n_take == 0) { return; } - - auto rec_data = make_device_matrix(handle_, n_take, dim); - auto orig_data = make_device_matrix(handle_, n_take, dim); - - ivf_pq::helpers::reconstruct_list_data(handle_, index, rec_data.view(), label, n_skip); - - matrix::gather(database.data(), - IdxT{dim}, - IdxT{n_take}, - rec_list->indices.data_handle() + n_skip, - IdxT{n_take}, - orig_data.data_handle(), - stream_); - - compare_vectors_l2(handle_, rec_data.view(), orig_data.view(), label, compression_ratio, 0.06); - } - - void check_reconstruct_extend(index* index, double compression_ratio, uint32_t label) - { - // NB: this is not reference, the list is retained; the index will have to create a new list on - // `erase_list` op. - auto old_list = index->lists()[label]; - auto n_rows = old_list->size.load(); - if (n_rows == 0) { return; } - - auto vectors_1 = make_device_matrix(handle_, n_rows, index->dim()); - auto indices = make_device_vector(handle_, n_rows); - copy(indices.data_handle(), old_list->indices.data_handle(), n_rows, stream_); - - ivf_pq::helpers::reconstruct_list_data(handle_, *index, vectors_1.view(), label, 0); - ivf_pq::helpers::erase_list(handle_, index, label); - // NB: passing the type parameter because const->non-const implicit conversion of the mdspans - // breaks type inference - ivf_pq::helpers::extend_list( - handle_, index, vectors_1.view(), indices.view(), label); - - auto& new_list = index->lists()[label]; - ASSERT_NE(old_list.get(), new_list.get()) - << "The old list should have been shared and retained after ivf_pq index has erased the " - "corresponding cluster."; - - auto vectors_2 = make_device_matrix(handle_, n_rows, index->dim()); - ivf_pq::helpers::reconstruct_list_data(handle_, *index, vectors_2.view(), label, 0); - // The code search is unstable, and there's high chance of repeating values of the lvl-2 codes. - // Hence, encoding-decoding chain often leads to altering both the PQ codes and the - // reconstructed data. - compare_vectors_l2( - handle_, vectors_1.view(), vectors_2.view(), label, compression_ratio, 0.04); // 0.025); - } - - void check_packing(index* index, uint32_t label) - { - auto old_list = index->lists()[label]; - auto n_rows = old_list->size.load(); - - if (n_rows == 0) { return; } - - auto codes = make_device_matrix(handle_, n_rows, index->pq_dim()); - auto indices = make_device_vector(handle_, n_rows); - copy(indices.data_handle(), old_list->indices.data_handle(), n_rows, stream_); - - ivf_pq::helpers::unpack_list_data(handle_, *index, codes.view(), label, 0); - ivf_pq::helpers::erase_list(handle_, index, label); - ivf_pq::helpers::extend_list_with_codes( - handle_, index, codes.view(), indices.view(), label); - - auto& new_list = index->lists()[label]; - ASSERT_NE(old_list.get(), new_list.get()) - << "The old list should have been shared and retained after ivf_pq index has erased the " - "corresponding cluster."; - auto list_data_size = (n_rows / ivf_pq::kIndexGroupSize) * new_list->data.extent(1) * - new_list->data.extent(2) * new_list->data.extent(3); - - ASSERT_TRUE(old_list->data.size() >= list_data_size); - ASSERT_TRUE(new_list->data.size() >= list_data_size); - ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), - new_list->data.data_handle(), - list_data_size, - Compare{})); - - // Pack a few vectors back to the list. - int row_offset = 9; - int n_vec = 3; - ASSERT_TRUE(row_offset + n_vec < n_rows); - size_t offset = row_offset * index->pq_dim(); - auto codes_to_pack = make_device_matrix_view( - codes.data_handle() + offset, n_vec, index->pq_dim()); - ivf_pq::helpers::pack_list_data(handle_, index, codes_to_pack, label, row_offset); - ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), - new_list->data.data_handle(), - list_data_size, - Compare{})); - - // Another test with the API that take list_data directly - auto list_data = index->lists()[label]->data.view(); - uint32_t n_take = 4; - ASSERT_TRUE(row_offset + n_take < n_rows); - auto codes2 = raft::make_device_matrix(handle_, n_take, index->pq_dim()); - ivf_pq::helpers::codepacker::unpack( - handle_, list_data, index->pq_bits(), row_offset, codes2.view()); - - // Write it back - ivf_pq::helpers::codepacker::pack( - handle_, make_const_mdspan(codes2.view()), index->pq_bits(), row_offset, list_data); - ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), - new_list->data.data_handle(), - list_data_size, - Compare{})); - } - void check_packing_contiguous(index* index, uint32_t label) - { - auto old_list = index->lists()[label]; - auto n_rows = old_list->size.load(); - - if (n_rows == 0) { return; } - - auto codes = make_device_matrix(handle_, n_rows, index->pq_dim()); - auto indices = make_device_vector(handle_, n_rows); - copy(indices.data_handle(), old_list->indices.data_handle(), n_rows, stream_); - - uint32_t code_size = ceildiv(index->pq_dim() * index->pq_bits(), 8); - - auto codes_compressed = make_device_matrix(handle_, n_rows, code_size); - - ivf_pq::helpers::unpack_contiguous_list_data( - handle_, *index, codes_compressed.data_handle(), n_rows, label, 0); - ivf_pq::helpers::erase_list(handle_, index, label); - ivf_pq::detail::extend_list_prepare(handle_, index, make_const_mdspan(indices.view()), label); - ivf_pq::helpers::pack_contiguous_list_data( - handle_, index, codes_compressed.data_handle(), n_rows, label, 0); - ivf_pq::helpers::recompute_internal_state(handle_, index); - - auto& new_list = index->lists()[label]; - ASSERT_NE(old_list.get(), new_list.get()) - << "The old list should have been shared and retained after ivf_pq index has erased the " - "corresponding cluster."; - auto list_data_size = (n_rows / ivf_pq::kIndexGroupSize) * new_list->data.extent(1) * - new_list->data.extent(2) * new_list->data.extent(3); - - ASSERT_TRUE(old_list->data.size() >= list_data_size); - ASSERT_TRUE(new_list->data.size() >= list_data_size); - ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), - new_list->data.data_handle(), - list_data_size, - Compare{})); - - // Pack a few vectors back to the list. - uint32_t row_offset = 9; - uint32_t n_vec = 3; - ASSERT_TRUE(row_offset + n_vec < n_rows); - size_t offset = row_offset * code_size; - auto codes_to_pack = make_device_matrix_view( - codes_compressed.data_handle() + offset, n_vec, index->pq_dim()); - ivf_pq::helpers::pack_contiguous_list_data( - handle_, index, codes_to_pack.data_handle(), n_vec, label, row_offset); - ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), - new_list->data.data_handle(), - list_data_size, - Compare{})); - - // // Another test with the API that take list_data directly - auto list_data = index->lists()[label]->data.view(); - uint32_t n_take = 4; - ASSERT_TRUE(row_offset + n_take < n_rows); - auto codes2 = raft::make_device_matrix(handle_, n_take, code_size); - ivf_pq::helpers::codepacker::unpack_contiguous(handle_, - list_data, - index->pq_bits(), - row_offset, - n_take, - index->pq_dim(), - codes2.data_handle()); - - // Write it back - ivf_pq::helpers::codepacker::pack_contiguous(handle_, - codes2.data_handle(), - n_vec, - index->pq_dim(), - index->pq_bits(), - row_offset, - list_data); - ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), - new_list->data.data_handle(), - list_data_size, - Compare{})); - } - - template - void run(BuildIndex build_index) - { - index index = build_index(); - - double compression_ratio = - static_cast(ps.dim * 8) / static_cast(index.pq_dim() * index.pq_bits()); - - for (uint32_t label = 0; label < index.n_lists(); label++) { - switch (label % 3) { - case 0: { - // Reconstruct and re-write vectors for one label - check_reconstruct_extend(&index, compression_ratio, label); - } break; - case 1: { - // Dump and re-write codes for one label - check_packing(&index, label); - check_packing_contiguous(&index, label); - } break; - default: { - // check a small subset of data in a randomly chosen cluster to see if the data - // reconstruction works well. - check_reconstruction(index, compression_ratio, label, 100, 7); - } - } - } - - size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivf_pq(queries_size); - std::vector distances_ivf_pq(queries_size); - - rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); - rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); - - auto query_view = - raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); - auto inds_view = raft::make_device_matrix_view( - indices_ivf_pq_dev.data(), ps.num_queries, ps.k); - auto dists_view = raft::make_device_matrix_view( - distances_ivf_pq_dev.data(), ps.num_queries, ps.k); - - ivf_pq::search( - handle_, ps.search_params, index, query_view, inds_view, dists_view); - - update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); - update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - - // A very conservative lower bound on recall - double min_recall = - static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists); - // Using a heuristic to lower the required recall due to code-packing errors - min_recall = - std::min(std::erfc(0.05 * compression_ratio / std::max(min_recall, 0.5)), min_recall); - // Use explicit per-test min recall value if provided. - min_recall = ps.min_recall.value_or(min_recall); - - ASSERT_TRUE(eval_neighbours(indices_ref, - indices_ivf_pq, - distances_ref, - distances_ivf_pq, - ps.num_queries, - ps.k, - 0.0001 * compression_ratio, - min_recall)) - << ps; - - // Test a few extra invariants - IdxT min_results = min_output_size(handle_, index, ps.search_params.n_probes); - IdxT max_oob = ps.k <= min_results ? 0 : ps.k - min_results; - IdxT found_oob = 0; - for (uint32_t query_ix = 0; query_ix < ps.num_queries; query_ix++) { - for (uint32_t k = 0; k < ps.k; k++) { - auto flat_i = query_ix * ps.k + k; - auto found_ix = indices_ivf_pq[flat_i]; - if (found_ix == ivf_pq::kOutOfBoundsRecord) { - found_oob++; - continue; - } - ASSERT_NE(found_ix, ivf::kInvalidRecord) - << "got an invalid record at query_ix = " << query_ix << ", k = " << k - << " (distance = " << distances_ivf_pq[flat_i] << ")"; - ASSERT_LT(found_ix, ps.num_db_vecs) - << "got an impossible index = " << found_ix << " at query_ix = " << query_ix - << ", k = " << k << " (distance = " << distances_ivf_pq[flat_i] << ")"; - } - } - ASSERT_LE(found_oob, max_oob) - << "got too many records out-of-bounds (see ivf_pq::kOutOfBoundsRecord)."; - if (found_oob > 0) { - RAFT_LOG_WARN( - "Got %zu results out-of-bounds because of large top-k (%zu) and small n_probes (%u) and " - "small DB size/n_lists ratio (%zu / %u)", - size_t(found_oob), - size_t(ps.k), - ps.search_params.n_probes, - size_t(ps.num_db_vecs), - ps.index_params.n_lists); - } - } - - void SetUp() override // NOLINT - { - gen_data(); - calc_ref(); - } - - void TearDown() override // NOLINT - { - cudaGetLastError(); - resource::sync_stream(handle_); - database.resize(0, stream_); - search_queries.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - ivf_pq_inputs ps; // NOLINT - rmm::device_uvector database; // NOLINT - rmm::device_uvector search_queries; // NOLINT - std::vector indices_ref; // NOLINT - std::vector distances_ref; // NOLINT -}; - -template -class ivf_pq_filter_test : public ::testing::TestWithParam { - public: - ivf_pq_filter_test() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_), - search_queries(0, stream_) - { - } - - void gen_data() - { - database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_); - search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); - - raft::random::RngState r(1234ULL); - if constexpr (std::is_same{}) { - raft::random::uniform( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); - raft::random::uniform( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); - raft::random::uniformInt( - handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); - } - resource::sync_stream(handle_); - } - - void calc_ref() - { - size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database.data() + test_ivf_sample_filter::offset * ps.dim, - ps.num_queries, - ps.num_db_vecs - test_ivf_sample_filter::offset, - ps.dim, - ps.k, - ps.index_params.metric); - raft::linalg::addScalar(indices_naive_dev.data(), - indices_naive_dev.data(), - IdxT(test_ivf_sample_filter::offset), - queries_size, - stream_); - distances_ref.resize(queries_size); - update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); - indices_ref.resize(queries_size); - update_host(indices_ref.data(), indices_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - auto build_only() - { - auto ipams = ps.index_params; - ipams.add_data_on_build = true; - - auto index_view = - raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - return ivf_pq::build(handle_, ipams, index_view); - } - - template - void run(BuildIndex build_index) - { - index index = build_index(); - - double compression_ratio = - static_cast(ps.dim * 8) / static_cast(index.pq_dim() * index.pq_bits()); - size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivf_pq(queries_size); - std::vector distances_ivf_pq(queries_size); - - rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); - rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); - - auto query_view = - raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); - auto inds_view = raft::make_device_matrix_view( - indices_ivf_pq_dev.data(), ps.num_queries, ps.k); - auto dists_view = raft::make_device_matrix_view( - distances_ivf_pq_dev.data(), ps.num_queries, ps.k); - - // Create Bitset filter - auto removed_indices = - raft::make_device_vector(handle_, test_ivf_sample_filter::offset); - thrust::sequence( - resource::get_thrust_policy(handle_), - thrust::device_pointer_cast(removed_indices.data_handle()), - thrust::device_pointer_cast(removed_indices.data_handle() + test_ivf_sample_filter::offset)); - resource::sync_stream(handle_); - - raft::core::bitset removed_indices_bitset( - handle_, removed_indices.view(), ps.num_db_vecs); - ivf_pq::search_with_filtering( - handle_, - ps.search_params, - index, - query_view, - inds_view, - dists_view, - raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); - - update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); - update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - - // A very conservative lower bound on recall - double min_recall = - static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists); - // Using a heuristic to lower the required recall due to code-packing errors - min_recall = - std::min(std::erfc(0.05 * compression_ratio / std::max(min_recall, 0.5)), min_recall); - // Use explicit per-test min recall value if provided. - min_recall = ps.min_recall.value_or(min_recall); - - ASSERT_TRUE(eval_neighbours(indices_ref, - indices_ivf_pq, - distances_ref, - distances_ivf_pq, - ps.num_queries, - ps.k, - 0.0001 * compression_ratio, - min_recall)) - << ps; - } - - void SetUp() override // NOLINT - { - gen_data(); - calc_ref(); - } - - void TearDown() override // NOLINT - { - cudaGetLastError(); - resource::sync_stream(handle_); - database.resize(0, stream_); - search_queries.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - ivf_pq_inputs ps; // NOLINT - rmm::device_uvector database; // NOLINT - rmm::device_uvector search_queries; // NOLINT - std::vector indices_ref; // NOLINT - std::vector distances_ref; // NOLINT -}; - -/* Test cases */ -using test_cases_t = std::vector; - -// concatenate parameter sets for different type -template -auto operator+(const std::vector& a, const std::vector& b) -> std::vector -{ - std::vector res = a; - res.insert(res.end(), b.begin(), b.end()); - return res; -} - -inline auto defaults() -> test_cases_t { return {ivf_pq_inputs{}}; } - -template -auto map(const std::vector& xs, F f) -> std::vector -{ - std::vector ys(xs.size()); - std::transform(xs.begin(), xs.end(), ys.begin(), f); - return ys; -} - -inline auto with_dims(const std::vector& dims) -> test_cases_t -{ - return map(dims, [](uint32_t d) { - ivf_pq_inputs x; - x.dim = d; - return x; - }); -} - -/** These will surely trigger the fastest kernel available. */ -inline auto small_dims() -> test_cases_t { return with_dims({1, 2, 3, 4, 5, 8, 15, 16, 17}); } - -inline auto small_dims_per_cluster() -> test_cases_t -{ - return map(small_dims(), [](const ivf_pq_inputs& x) { - ivf_pq_inputs y(x); - y.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; - return y; - }); -} - -inline auto big_dims() -> test_cases_t -{ - // with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288, 16384}); - auto xs = with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144}); - return map(xs, [](const ivf_pq_inputs& x) { - ivf_pq_inputs y(x); - uint32_t pq_len = 2; - y.index_params.pq_dim = div_rounding_up_safe(x.dim, pq_len); - // This comes from pure experimentation, also the recall depens a lot on pq_len. - y.min_recall = 0.48 + 0.028 * std::log2(x.dim); - return y; - }); -} - -/** These will surely trigger no-smem-lut kernel. */ -inline auto big_dims_moderate_lut() -> test_cases_t -{ - return map(big_dims(), [](const ivf_pq_inputs& x) { - ivf_pq_inputs y(x); - uint32_t pq_len = 2; - y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u); - y.index_params.pq_bits = 6; - y.search_params.lut_dtype = CUDA_R_16F; - y.min_recall = 0.69; - return y; - }); -} - -/** Some of these should trigger no-basediff kernel. */ -inline auto big_dims_small_lut() -> test_cases_t -{ - return map(big_dims(), [](const ivf_pq_inputs& x) { - ivf_pq_inputs y(x); - uint32_t pq_len = 8; - y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u); - y.index_params.pq_bits = 6; - y.search_params.lut_dtype = CUDA_R_8U; - y.min_recall = 0.21; - return y; - }); -} - -/** - * A minimal set of tests to check various enum-like parameters. - */ -inline auto enum_variety() -> test_cases_t -{ - test_cases_t xs; -#define ADD_CASE(f) \ - do { \ - xs.push_back({}); \ - ([](ivf_pq_inputs & x) f)(xs[xs.size() - 1]); \ - } while (0); - - ADD_CASE({ - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; - x.min_recall = 0.86; - }); - ADD_CASE({ - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; - x.min_recall = 0.86; - }); - ADD_CASE({ - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; - x.index_params.pq_bits = 4; - x.min_recall = 0.79; - }); - ADD_CASE({ - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; - x.index_params.pq_bits = 5; - x.min_recall = 0.83; - }); - - ADD_CASE({ - x.index_params.pq_bits = 6; - x.min_recall = 0.84; - }); - ADD_CASE({ - x.index_params.pq_bits = 7; - x.min_recall = 0.85; - }); - ADD_CASE({ - x.index_params.pq_bits = 8; - x.min_recall = 0.86; - }); - - ADD_CASE({ - x.index_params.force_random_rotation = true; - x.min_recall = 0.86; - }); - ADD_CASE({ - x.index_params.force_random_rotation = false; - x.min_recall = 0.86; - }); - - ADD_CASE({ - x.search_params.lut_dtype = CUDA_R_32F; - x.min_recall = 0.86; - }); - ADD_CASE({ - x.search_params.lut_dtype = CUDA_R_16F; - x.min_recall = 0.86; - }); - ADD_CASE({ - x.search_params.lut_dtype = CUDA_R_8U; - x.min_recall = 0.84; - }); - - ADD_CASE({ - x.search_params.internal_distance_dtype = CUDA_R_32F; - x.min_recall = 0.86; - }); - ADD_CASE({ - x.search_params.internal_distance_dtype = CUDA_R_16F; - x.search_params.lut_dtype = CUDA_R_16F; - x.min_recall = 0.86; - }); - - return xs; -} - -inline auto enum_variety_l2() -> test_cases_t -{ - return map(enum_variety(), [](const ivf_pq_inputs& x) { - ivf_pq_inputs y(x); - y.index_params.metric = distance::DistanceType::L2Expanded; - return y; - }); -} - -inline auto enum_variety_ip() -> test_cases_t -{ - return map(enum_variety(), [](const ivf_pq_inputs& x) { - ivf_pq_inputs y(x); - if (y.min_recall.has_value()) { - if (y.search_params.lut_dtype == CUDA_R_8U) { - // InnerProduct score is signed, - // thus we're forced to used signed 8-bit representation, - // thus we have one bit less precision - y.min_recall = y.min_recall.value() * 0.90; - } else { - // In other cases it seems to perform a little bit better, still worse than L2 - y.min_recall = y.min_recall.value() * 0.94; - } - } - y.index_params.metric = distance::DistanceType::InnerProduct; - return y; - }); -} - -inline auto enum_variety_l2sqrt() -> test_cases_t -{ - return map(enum_variety(), [](const ivf_pq_inputs& x) { - ivf_pq_inputs y(x); - y.index_params.metric = distance::DistanceType::L2SqrtExpanded; - return y; - }); -} - -/** - * Try different number of n_probes, some of which may trigger the non-fused version of the search - * kernel. - */ -inline auto var_n_probes() -> test_cases_t -{ - ivf_pq_inputs dflt; - std::vector xs; - for (auto x = dflt.index_params.n_lists; x >= 1; x /= 2) { - xs.push_back(x); - } - return map(xs, [](uint32_t n_probes) { - ivf_pq_inputs x; - x.search_params.n_probes = n_probes; - return x; - }); -} - -/** - * Try different number of nearest neighbours. - * Values smaller than 32 test if the code behaves well when Capacity (== 32) does not change, - * but `k <= Capacity` changes. - * - * Values between `32 and ivf_pq::detail::kMaxCapacity` test various instantiations of the - * main kernel (Capacity-templated) - * - * Values above ivf_pq::detail::kMaxCapacity should trigger the non-fused version of the kernel - * (manage_local_topk = false). - * - * Also we test here various values that are close-but-not-power-of-two to catch any problems - * related to rounding/alignment. - * - * Note, we cannot control explicitly which instance of the search kernel to choose, hence it's - * important to try a variety of different values of `k` to make sure all paths are triggered. - * - * Set the log level to DEBUG (5) or above to inspect the selected kernel instances. - */ -inline auto var_k() -> test_cases_t -{ - return map( - {1, 2, 3, 5, 8, 15, 16, 32, 63, 65, 127, 128, 256, 257, 1023, 2048, 2049}, [](uint32_t k) { - ivf_pq_inputs x; - x.k = k; - // when there's not enough data, try more cluster probes - x.search_params.n_probes = max(x.search_params.n_probes, min(x.index_params.n_lists, k)); - return x; - }); -} - -/** - * Cases brought up from downstream projects. - */ -inline auto special_cases() -> test_cases_t -{ - test_cases_t xs; - -#define ADD_CASE(f) \ - do { \ - xs.push_back({}); \ - ([](ivf_pq_inputs & x) f)(xs[xs.size() - 1]); \ - } while (0); - - ADD_CASE({ - x.num_db_vecs = 1183514; - x.dim = 100; - x.num_queries = 10000; - x.k = 10; - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; - x.index_params.pq_dim = 10; - x.index_params.pq_bits = 8; - x.index_params.n_lists = 1024; - x.search_params.n_probes = 50; - }); - - ADD_CASE({ - x.num_db_vecs = 10000; - x.dim = 16; - x.num_queries = 500; - x.k = 128; - x.index_params.metric = distance::DistanceType::L2Expanded; - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; - x.index_params.pq_bits = 8; - x.index_params.n_lists = 100; - x.search_params.n_probes = 100; - }); - - ADD_CASE({ - x.num_db_vecs = 10000; - x.dim = 16; - x.num_queries = 500; - x.k = 129; - x.index_params.metric = distance::DistanceType::L2Expanded; - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; - x.index_params.pq_bits = 8; - x.index_params.n_lists = 100; - x.search_params.n_probes = 100; - }); - - ADD_CASE({ - x.num_db_vecs = 4335; - x.dim = 4; - x.num_queries = 100000; - x.k = 12; - x.index_params.metric = distance::DistanceType::L2Expanded; - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; - x.index_params.pq_dim = 2; - x.index_params.pq_bits = 8; - x.index_params.n_lists = 69; - x.search_params.n_probes = 69; - }); - - ADD_CASE({ - x.num_db_vecs = 4335; - x.dim = 4; - x.num_queries = 100000; - x.k = 12; - x.index_params.metric = distance::DistanceType::L2Expanded; - x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; - x.index_params.pq_dim = 2; - x.index_params.pq_bits = 8; - x.index_params.n_lists = 69; - x.search_params.n_probes = 69; - }); - - return xs; -} - -/* Test instantiations */ - -#define TEST_BUILD_SEARCH(type) \ - TEST_P(type, build_search) /* NOLINT */ \ - { \ - this->run([this]() { return this->build_only(); }); \ - } - -#define TEST_BUILD_EXTEND_SEARCH(type) \ - TEST_P(type, build_extend_search) /* NOLINT */ \ - { \ - this->run([this]() { return this->build_2_extends(); }); \ - } - -#define TEST_BUILD_SERIALIZE_SEARCH(type) \ - TEST_P(type, build_serialize_search) /* NOLINT */ \ - { \ - this->run([this]() { return this->build_serialize(); }); \ - } - -#define INSTANTIATE(type, vals) \ - INSTANTIATE_TEST_SUITE_P(IvfPq, type, ::testing::ValuesIn(vals)); /* NOLINT */ - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu deleted file mode 100644 index 5ba21c3c2f..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include // raft::neighbors::ivf_pq::index -#include - -#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ - template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset); \ - \ - template auto raft::neighbors::ivf_pq::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index; - -instantiate_raft_neighbors_ivf_pq_build(float, uint32_t); - -#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh deleted file mode 100644 index cd5435ab2e..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include // raft::neighbors::ivf_pq::index -#include - -#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ - extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset); \ - \ - extern template auto raft::neighbors::ivf_pq::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index; - -instantiate_raft_neighbors_ivf_pq_build(float, uint32_t); - -#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu deleted file mode 100644 index 00baa59f58..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include // raft::neighbors::ivf_pq::index -#include - -#include - -#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances) - -instantiate_raft_neighbors_ivf_pq_search(float, uint32_t); - -#undef instantiate_raft_neighbors_ivf_pq_search - -#define instantiate_raft_neighbors_ivf_pq_search_with_filtering(T, IdxT, FilterT) \ - template void raft::neighbors::ivf_pq::search_with_filtering( \ - raft::resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - FilterT sample_filter) - -#define COMMA , -instantiate_raft_neighbors_ivf_pq_search_with_filtering( - float, uint32_t, raft::neighbors::filtering::bitset_filter); - -instantiate_raft_neighbors_ivf_pq_search_with_filtering( - int8_t, int64_t, raft::neighbors::filtering::bitset_filter); - -instantiate_raft_neighbors_ivf_pq_search_with_filtering( - float, uint32_t, raft::neighbors::filtering::none_ivf_sample_filter); - -#undef COMMA -#undef instantiate_raft_neighbors_ivf_pq_search_with_filtering diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu deleted file mode 100644 index 70d5d8761f..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_pq.cuh" - -#include -#include - -namespace raft::neighbors::ivf_pq { - -using f32_f32_i64_filter = ivf_pq_filter_test; - -TEST_BUILD_SEARCH(f32_f32_i64_filter) -INSTANTIATE(f32_f32_i64_filter, defaults() + big_dims_moderate_lut()); -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu deleted file mode 100644 index ba96a8db0b..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_pq.cuh" - -#include -#include - -namespace raft::neighbors::ivf_pq { - -using f32_i08_i64_filter = ivf_pq_filter_test; - -TEST_BUILD_SEARCH(f32_i08_i64_filter) -INSTANTIATE(f32_i08_i64_filter, big_dims()); - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu deleted file mode 100644 index 9859061d70..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_pq.cuh" - -namespace raft::neighbors::ivf_pq { - -using f32_f32_i64 = ivf_pq_test; - -TEST_BUILD_EXTEND_SEARCH(f32_f32_i64) -TEST_BUILD_SERIALIZE_SEARCH(f32_f32_i64) -INSTANTIATE(f32_f32_i64, defaults() + small_dims() + big_dims_moderate_lut()); - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu deleted file mode 100644 index b8ada2249a..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_pq.cuh" -#include "ivf_pq_build_test-ext.cuh" - -#include -#include - -namespace raft::neighbors::ivf_pq { - -using f32_f32_u32 = ivf_pq_test; -using f32_f32_u32_filter = ivf_pq_filter_test; - -TEST_BUILD_SEARCH(f32_f32_u32) -TEST_BUILD_SERIALIZE_SEARCH(f32_f32_u32) -INSTANTIATE(f32_f32_u32, defaults() + var_n_probes() + var_k() + special_cases()); - -TEST_BUILD_SEARCH(f32_f32_u32_filter) -INSTANTIATE(f32_f32_u32_filter, defaults()); -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu deleted file mode 100644 index 970bdd6a12..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_pq.cuh" - -#include -namespace raft::neighbors::ivf_pq { - -using f32_i08_i64 = ivf_pq_test; - -TEST_BUILD_SEARCH(f32_i08_i64) -TEST_BUILD_SERIALIZE_SEARCH(f32_i08_i64) -INSTANTIATE(f32_i08_i64, defaults() + big_dims() + var_k()); - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu deleted file mode 100644 index e949c2f7ed..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_pq.cuh" - -namespace raft::neighbors::ivf_pq { - -using f32_u08_i64 = ivf_pq_test; - -TEST_BUILD_SEARCH(f32_u08_i64) -TEST_BUILD_EXTEND_SEARCH(f32_u08_i64) -INSTANTIATE(f32_u08_i64, small_dims_per_cluster() + enum_variety()); - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh deleted file mode 100644 index 5070d83b15..0000000000 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "ann_utils.cuh" - -#include -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include - -namespace raft::neighbors::experimental::nn_descent { - -struct AnnNNDescentInputs { - int n_rows; - int dim; - int graph_degree; - raft::distance::DistanceType metric; - bool host_dataset; - double min_recall; -}; - -struct AnnNNDescentBatchInputs { - std::pair recall_cluster; - int n_rows; - int dim; - int graph_degree; - raft::distance::DistanceType metric; - bool host_dataset; -}; - -inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& p) -{ - os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree - << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") - << std::endl; - return os; -} - -inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentBatchInputs& p) -{ - os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree - << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") - << ", clusters=" << p.recall_cluster.second << std::endl; - return os; -} - -template -class AnnNNDescentTest : public ::testing::TestWithParam { - public: - AnnNNDescentTest() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_) - { - } - - protected: - void testNNDescent() - { - size_t queries_size = ps.n_rows * ps.graph_degree; - std::vector indices_NNDescent(queries_size); - std::vector distances_NNDescent(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - database.data(), - database.data(), - ps.n_rows, - ps.n_rows, - ps.dim, - ps.graph_degree, - ps.metric); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - { - { - nn_descent::index_params index_params; - index_params.metric = ps.metric; - index_params.graph_degree = ps.graph_degree; - index_params.intermediate_graph_degree = 2 * ps.graph_degree; - index_params.max_iterations = 100; - index_params.return_distances = true; - - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.n_rows, ps.dim); - - { - if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); - auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - index index{handle_, ps.n_rows, static_cast(ps.graph_degree), true}; - nn_descent::build( - handle_, index_params, database_host_view, index, DistEpilogue()); - raft::copy( - indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - if (index.distances().has_value()) { - raft::copy(distances_NNDescent.data(), - index.distances().value().data_handle(), - queries_size, - stream_); - } - - } else { - index index{handle_, ps.n_rows, static_cast(ps.graph_degree), true}; - nn_descent::build( - handle_, index_params, database_view, index, DistEpilogue()); - raft::copy( - indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - if (index.distances().has_value()) { - raft::copy(distances_NNDescent.data(), - index.distances().value().data_handle(), - queries_size, - stream_); - } - }; - } - resource::sync_stream(handle_); - } - - double min_recall = ps.min_recall; - EXPECT_TRUE(eval_neighbours(indices_naive, - indices_NNDescent, - distances_naive, - distances_NNDescent, - ps.n_rows, - ps.graph_degree, - 0.001, - min_recall)); - } - } - - void SetUp() override - { - database.resize(((size_t)ps.n_rows) * ps.dim, stream_); - raft::random::RngState r(1234ULL); - if constexpr (std::is_same{}) { - raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); - } - resource::sync_stream(handle_); - } - - void TearDown() override - { - resource::sync_stream(handle_); - database.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - AnnNNDescentInputs ps; - rmm::device_uvector database; -}; - -template -class AnnNNDescentBatchTest : public ::testing::TestWithParam { - public: - AnnNNDescentBatchTest() - : stream_(resource::get_cuda_stream(handle_)), - ps(::testing::TestWithParam::GetParam()), - database(0, stream_) - { - } - - void testNNDescentBatch() - { - size_t queries_size = ps.n_rows * ps.graph_degree; - std::vector indices_NNDescent(queries_size); - std::vector distances_NNDescent(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - naive_knn(handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - database.data(), - database.data(), - ps.n_rows, - ps.n_rows, - ps.dim, - ps.graph_degree, - ps.metric); - update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - resource::sync_stream(handle_); - } - - { - { - nn_descent::index_params index_params; - index_params.metric = ps.metric; - index_params.graph_degree = ps.graph_degree; - index_params.intermediate_graph_degree = 2 * ps.graph_degree; - index_params.max_iterations = 10; - index_params.return_distances = true; - index_params.n_clusters = ps.recall_cluster.second; - - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.n_rows, ps.dim); - - { - if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); - auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - auto index = nn_descent::build( - handle_, index_params, database_host_view, DistEpilogue()); - raft::copy( - indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - if (index.distances().has_value()) { - raft::copy(distances_NNDescent.data(), - index.distances().value().data_handle(), - queries_size, - stream_); - } - - } else { - auto index = nn_descent::build( - handle_, index_params, database_view, DistEpilogue()); - raft::copy( - indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - if (index.distances().has_value()) { - raft::copy(distances_NNDescent.data(), - index.distances().value().data_handle(), - queries_size, - stream_); - } - }; - } - resource::sync_stream(handle_); - } - double min_recall = ps.recall_cluster.first; - EXPECT_TRUE(eval_neighbours(indices_naive, - indices_NNDescent, - distances_naive, - distances_NNDescent, - ps.n_rows, - ps.graph_degree, - 0.01, - min_recall, - true, - static_cast(ps.graph_degree * 0.1))); - } - } - - void SetUp() override - { - database.resize(((size_t)ps.n_rows) * ps.dim, stream_); - raft::random::RngState r(1234ULL); - if constexpr (std::is_same{}) { - raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); - } else { - raft::random::uniformInt( - handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); - } - resource::sync_stream(handle_); - } - - void TearDown() override - { - resource::sync_stream(handle_); - database.resize(0, stream_); - } - - private: - raft::resources handle_; - rmm::cuda_stream_view stream_; - AnnNNDescentBatchInputs ps; - rmm::device_uvector database; -}; - -const std::vector inputs = raft::util::itertools::product( - {1000, 2000}, // n_rows - {3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim - {32, 64}, // graph_degree - {raft::distance::DistanceType::L2Expanded}, - {false, true}, - {0.90}); - -// TODO: Investigate why this test is failing -// Reference issue https://github.com/rapidsai/raft/issues/2450 -// const std::vector inputsBatch = -// raft::util::itertools::product( -// {std::make_pair(0.9, 3lu), std::make_pair(0.9, 2lu)}, // min_recall, n_clusters -// {4000, 5000}, // n_rows -// {192, 512}, // dim -// {32, 64}, // graph_degree -// {raft::distance::DistanceType::L2Expanded}, -// {false, true}); - -} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu deleted file mode 100644 index c6f56e8c39..0000000000 --- a/cpp/test/neighbors/ann_nn_descent/test_batch_float_uint32_t.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_nn_descent.cuh" - -#include - -namespace raft::neighbors::experimental::nn_descent { - -typedef AnnNNDescentBatchTest AnnNNDescentBatchTestF_U32; -TEST_P(AnnNNDescentBatchTestF_U32, AnnNNDescentBatch) { this->testNNDescentBatch(); } - -INSTANTIATE_TEST_CASE_P(AnnNNDescentBatchTest, - AnnNNDescentBatchTestF_U32, - ::testing::ValuesIn(inputsBatch)); - -} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu deleted file mode 100644 index ec6d04ad12..0000000000 --- a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_nn_descent.cuh" - -#include - -namespace raft::neighbors::experimental::nn_descent { - -typedef AnnNNDescentTest AnnNNDescentTestF_U32; -TEST_P(AnnNNDescentTestF_U32, AnnNNDescent) { this->testNNDescent(); } - -INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu deleted file mode 100644 index 27fa42d636..0000000000 --- a/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_nn_descent.cuh" - -#include - -namespace raft::neighbors::experimental::nn_descent { - -typedef AnnNNDescentTest AnnNNDescentTestI8_U32; -TEST_P(AnnNNDescentTestI8_U32, AnnNNDescent) { this->testNNDescent(); } - -INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestI8_U32, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu deleted file mode 100644 index 3afe79dcc4..0000000000 --- a/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_nn_descent.cuh" - -#include - -namespace raft::neighbors::experimental::nn_descent { - -typedef AnnNNDescentTest AnnNNDescentTestUI8_U32; -TEST_P(AnnNNDescentTestUI8_U32, AnnNNDescent) { this->testNNDescent(); } - -INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestUI8_U32, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh deleted file mode 100644 index 82e3ace9da..0000000000 --- a/cpp/test/neighbors/ann_utils.cuh +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../test_utils.cuh" - -#include // raft::make_device_matrix -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -namespace raft::neighbors { - -struct print_dtype { - cudaDataType_t value; -}; - -inline auto operator<<(std::ostream& os, const print_dtype& p) -> std::ostream& -{ - switch (p.value) { - case CUDA_R_16F: os << "CUDA_R_16F"; break; - case CUDA_C_16F: os << "CUDA_C_16F"; break; - case CUDA_R_16BF: os << "CUDA_R_16BF"; break; - case CUDA_C_16BF: os << "CUDA_C_16BF"; break; - case CUDA_R_32F: os << "CUDA_R_32F"; break; - case CUDA_C_32F: os << "CUDA_C_32F"; break; - case CUDA_R_64F: os << "CUDA_R_64F"; break; - case CUDA_C_64F: os << "CUDA_C_64F"; break; - case CUDA_R_4I: os << "CUDA_R_4I"; break; - case CUDA_C_4I: os << "CUDA_C_4I"; break; - case CUDA_R_4U: os << "CUDA_R_4U"; break; - case CUDA_C_4U: os << "CUDA_C_4U"; break; - case CUDA_R_8I: os << "CUDA_R_8I"; break; - case CUDA_C_8I: os << "CUDA_C_8I"; break; - case CUDA_R_8U: os << "CUDA_R_8U"; break; - case CUDA_C_8U: os << "CUDA_C_8U"; break; - case CUDA_R_16I: os << "CUDA_R_16I"; break; - case CUDA_C_16I: os << "CUDA_C_16I"; break; - case CUDA_R_16U: os << "CUDA_R_16U"; break; - case CUDA_C_16U: os << "CUDA_C_16U"; break; - case CUDA_R_32I: os << "CUDA_R_32I"; break; - case CUDA_C_32I: os << "CUDA_C_32I"; break; - case CUDA_R_32U: os << "CUDA_R_32U"; break; - case CUDA_C_32U: os << "CUDA_C_32U"; break; - case CUDA_R_64I: os << "CUDA_R_64I"; break; - case CUDA_C_64I: os << "CUDA_C_64I"; break; - case CUDA_R_64U: os << "CUDA_R_64U"; break; - case CUDA_C_64U: os << "CUDA_C_64U"; break; - default: RAFT_FAIL("unreachable code"); - } - return os; -} - -struct print_metric { - raft::distance::DistanceType value; -}; - -inline auto operator<<(std::ostream& os, const print_metric& p) -> std::ostream& -{ - switch (p.value) { - case raft::distance::L2Expanded: os << "distance::L2Expanded"; break; - case raft::distance::L2SqrtExpanded: os << "distance::L2SqrtExpanded"; break; - case raft::distance::CosineExpanded: os << "distance::CosineExpanded"; break; - case raft::distance::L1: os << "distance::L1"; break; - case raft::distance::L2Unexpanded: os << "distance::L2Unexpanded"; break; - case raft::distance::L2SqrtUnexpanded: os << "distance::L2SqrtUnexpanded"; break; - case raft::distance::InnerProduct: os << "distance::InnerProduct"; break; - case raft::distance::Linf: os << "distance::Linf"; break; - case raft::distance::Canberra: os << "distance::Canberra"; break; - case raft::distance::LpUnexpanded: os << "distance::LpUnexpanded"; break; - case raft::distance::CorrelationExpanded: os << "distance::CorrelationExpanded"; break; - case raft::distance::JaccardExpanded: os << "distance::JaccardExpanded"; break; - case raft::distance::HellingerExpanded: os << "distance::HellingerExpanded"; break; - case raft::distance::Haversine: os << "distance::Haversine"; break; - case raft::distance::BrayCurtis: os << "distance::BrayCurtis"; break; - case raft::distance::JensenShannon: os << "distance::JensenShannon"; break; - case raft::distance::HammingUnexpanded: os << "distance::HammingUnexpanded"; break; - case raft::distance::KLDivergence: os << "distance::KLDivergence"; break; - case raft::distance::RusselRaoExpanded: os << "distance::RusselRaoExpanded"; break; - case raft::distance::DiceExpanded: os << "distance::DiceExpanded"; break; - case raft::distance::Precomputed: os << "distance::Precomputed"; break; - default: RAFT_FAIL("unreachable code"); - } - return os; -} - -template -struct idx_dist_pair { - IdxT idx; - DistT dist; - CompareDist eq_compare; - auto operator==(const idx_dist_pair& a) const -> bool - { - if (idx == a.idx) return true; - if (eq_compare(dist, a.dist)) return true; - return false; - } - idx_dist_pair(IdxT x, DistT y, CompareDist op) : idx(x), dist(y), eq_compare(op) {} -}; - -/** Calculate recall value using only neighbor indices - */ -template -auto calc_recall(const std::vector& expected_idx, - const std::vector& actual_idx, - size_t rows, - size_t cols) -{ - size_t match_count = 0; - size_t total_count = static_cast(rows) * static_cast(cols); - for (size_t i = 0; i < rows; ++i) { - for (size_t k = 0; k < cols; ++k) { - size_t idx_k = i * cols + k; // row major assumption! - auto act_idx = actual_idx[idx_k]; - for (size_t j = 0; j < cols; ++j) { - size_t idx = i * cols + j; // row major assumption! - auto exp_idx = expected_idx[idx]; - if (act_idx == exp_idx) { - match_count++; - break; - } - } - } - } - return std::make_tuple( - static_cast(match_count) / static_cast(total_count), match_count, total_count); -} - -/** check uniqueness of indices - */ -template -auto check_unique_indices(const std::vector& actual_idx, - size_t rows, - size_t cols, - size_t max_duplicates) -{ - size_t max_count; - size_t dup_count = 0lu; - std::set unique_indices; - for (size_t i = 0; i < rows; ++i) { - unique_indices.clear(); - max_count = 0; - for (size_t k = 0; k < cols; ++k) { - size_t idx_k = i * cols + k; // row major assumption! - auto act_idx = actual_idx[idx_k]; - if (act_idx == std::numeric_limits::max()) { - max_count++; - } else if (unique_indices.find(act_idx) == unique_indices.end()) { - unique_indices.insert(act_idx); - } else { - dup_count++; - if (dup_count > max_duplicates) { - return testing::AssertionFailure() - << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; - } - } - } - } - return testing::AssertionSuccess(); -} - -template -auto eval_recall(const std::vector& expected_idx, - const std::vector& actual_idx, - size_t rows, - size_t cols, - double eps, - double min_recall, - bool test_unique = true) -> testing::AssertionResult -{ - auto [actual_recall, match_count, total_count] = - calc_recall(expected_idx, actual_idx, rows, cols); - double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); - RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).", - actual_recall, - match_count, - total_count, - std::abs(error_margin * 100.0), - error_margin < 0 ? "above" : "below", - eps); - if (actual_recall < min_recall - eps) { - return testing::AssertionFailure() - << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" - << min_recall << "); eps = " << eps << ". "; - } - if (test_unique) - return check_unique_indices(actual_idx, rows, cols); - else - return testing::AssertionSuccess(); -} - -/** Overload of calc_recall to account for distances - */ -template -auto calc_recall(const std::vector& expected_idx, - const std::vector& actual_idx, - const std::vector& expected_dist, - const std::vector& actual_dist, - size_t rows, - size_t cols, - double eps) -{ - size_t match_count = 0; - size_t total_count = static_cast(rows) * static_cast(cols); - for (size_t i = 0; i < rows; ++i) { - for (size_t k = 0; k < cols; ++k) { - size_t idx_k = i * cols + k; // row major assumption! - auto act_idx = actual_idx[idx_k]; - auto act_dist = actual_dist[idx_k]; - for (size_t j = 0; j < cols; ++j) { - size_t idx = i * cols + j; // row major assumption! - auto exp_idx = expected_idx[idx]; - auto exp_dist = expected_dist[idx]; - idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox(eps)); - idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox(eps)); - if (exp_kvp == act_kvp) { - match_count++; - break; - } - } - } - } - return std::make_tuple( - static_cast(match_count) / static_cast(total_count), match_count, total_count); -} - -/** same as eval_recall, but in case indices do not match, - * then check distances as well, and accept match if actual dist is equal to expected_dist */ -template -auto eval_neighbours(const std::vector& expected_idx, - const std::vector& actual_idx, - const std::vector& expected_dist, - const std::vector& actual_dist, - size_t rows, - size_t cols, - double eps, - double min_recall, - bool test_unique = true, - size_t max_duplicates = 0) -> testing::AssertionResult -{ - auto [actual_recall, match_count, total_count] = - calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); - double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); - RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).", - actual_recall, - match_count, - total_count, - std::abs(error_margin * 100.0), - error_margin < 0 ? "above" : "below", - eps); - if (actual_recall < min_recall - eps) { - return testing::AssertionFailure() - << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" - << min_recall << "); eps = " << eps << ". "; - } - if (test_unique) - return check_unique_indices(actual_idx, rows, cols, max_duplicates); - else - return testing::AssertionSuccess(); -} - -template -auto eval_distances(raft::resources const& handle, - const T* x, // dataset, n_rows * n_cols - const T* queries, // n_queries * n_cols - const IdxT* neighbors, // n_queries * k - const DistT* distances, // n_queries *k - size_t n_rows, - size_t n_cols, - size_t n_queries, - uint32_t k, - raft::distance::DistanceType metric, - double eps) -> testing::AssertionResult -{ - // for each vector, we calculate the actual distance to the k neighbors - - for (size_t i = 0; i < n_queries; i++) { - auto y = raft::make_device_matrix(handle, k, n_cols); - auto naive_dist = raft::make_device_matrix(handle, 1, k); - - raft::matrix::copy_rows( - handle, - make_device_matrix_view(x, n_rows, n_cols), - y.view(), - make_device_vector_view(neighbors + i * k, k)); - - dim3 block_dim(16, 32, 1); - auto grid_y = - static_cast(std::min(raft::ceildiv(k, block_dim.y), 32768)); - dim3 grid_dim(raft::ceildiv(n_rows, block_dim.x), grid_y, 1); - - naive_distance_kernel - <<>>( - naive_dist.data_handle(), queries + i * n_cols, y.data_handle(), 1, k, n_cols, metric); - - if (!devArrMatch(distances + i * k, - naive_dist.data_handle(), - naive_dist.size(), - CompareApprox(eps))) { - std::cout << n_rows << "x" << n_cols << ", " << k << std::endl; - std::cout << "query " << i << std::endl; - print_vector(" indices", neighbors + i * k, k, std::cout); - print_vector("n dist", distances + i * k, k, std::cout); - print_vector("c dist", naive_dist.data_handle(), naive_dist.size(), std::cout); - - return testing::AssertionFailure(); - } - } - return testing::AssertionSuccess(); -} -} // namespace raft::neighbors diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu deleted file mode 100644 index e4d018aff0..0000000000 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "./knn_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -struct FusedL2KNNInputs { - int num_queries; - int num_db_vecs; - int dim; - int k; - raft::distance::DistanceType metric_; -}; - -template -class FusedL2KNNTest : public ::testing::TestWithParam { - public: - FusedL2KNNTest() - : stream_(resource::get_cuda_stream(handle_)), - params_(::testing::TestWithParam::GetParam()), - database(params_.num_db_vecs * params_.dim, stream_), - search_queries(params_.num_queries * params_.dim, stream_), - raft_indices_(params_.num_queries * params_.k, stream_), - raft_distances_(params_.num_queries * params_.k, stream_), - ref_indices_(params_.num_queries * params_.k, stream_), - ref_distances_(params_.num_queries * params_.k, stream_) - { - RAFT_CUDA_TRY(cudaMemsetAsync(database.data(), 0, database.size() * sizeof(T), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(search_queries.data(), 0, search_queries.size() * sizeof(T), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(raft_indices_.data(), 0, raft_indices_.size() * sizeof(int64_t), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(raft_distances_.data(), 0, raft_distances_.size() * sizeof(T), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(ref_indices_.data(), 0, ref_indices_.size() * sizeof(int64_t), stream_)); - RAFT_CUDA_TRY( - cudaMemsetAsync(ref_distances_.data(), 0, ref_distances_.size() * sizeof(T), stream_)); - } - - protected: - void testBruteForce() - { - // calculate the naive knn, by calculating the full pairwise distances and doing a k-select - rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); - distance::pairwise_distance( - handle_, - raft::make_device_matrix_view(search_queries.data(), num_queries, dim), - raft::make_device_matrix_view(database.data(), num_db_vecs, dim), - raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), - metric); - - matrix::select_k( - handle_, - make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), - std::nullopt, - make_device_matrix_view(ref_distances_.data(), num_queries, k_), - make_device_matrix_view(ref_indices_.data(), num_queries, k_), - true, - true); - - auto index_view = - raft::make_device_matrix_view(database.data(), num_db_vecs, dim); - auto query_view = - raft::make_device_matrix_view(search_queries.data(), num_queries, dim); - auto out_indices_view = - raft::make_device_matrix_view(raft_indices_.data(), num_queries, k_); - auto out_dists_view = - raft::make_device_matrix_view(raft_distances_.data(), num_queries, k_); - raft::neighbors::brute_force::fused_l2_knn( - handle_, index_view, query_view, out_indices_view, out_dists_view, metric); - - // verify. - ASSERT_TRUE(devArrMatchKnnPair(ref_indices_.data(), - raft_indices_.data(), - ref_distances_.data(), - raft_distances_.data(), - num_queries, - k_, - float(0.001), - stream_)); - } - - void SetUp() override - { - num_queries = params_.num_queries; - num_db_vecs = params_.num_db_vecs; - dim = params_.dim; - k_ = params_.k; - metric = params_.metric_; - - unsigned long long int seed = 1234ULL; - raft::random::RngState r(seed); - uniform(handle_, r, database.data(), num_db_vecs * dim, T(-1.0), T(1.0)); - uniform(handle_, r, search_queries.data(), num_queries * dim, T(-1.0), T(1.0)); - } - - private: - raft::resources handle_; - cudaStream_t stream_ = 0; - FusedL2KNNInputs params_; - int num_queries; - int num_db_vecs; - int dim; - rmm::device_uvector database; - rmm::device_uvector search_queries; - rmm::device_uvector raft_indices_; - rmm::device_uvector raft_distances_; - rmm::device_uvector ref_indices_; - rmm::device_uvector ref_distances_; - int k_; - raft::distance::DistanceType metric; -}; - -const std::vector inputs = { - {100, 1000, 16, 10, raft::distance::DistanceType::L2Expanded}, - {256, 256, 30, 10, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, - {100, 1000, 16, 50, raft::distance::DistanceType::L2Expanded}, - {20, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 16, 50, raft::distance::DistanceType::L2Expanded}, - {1000, 10000, 32, 50, raft::distance::DistanceType::L2Expanded}, - {10000, 40000, 32, 30, raft::distance::DistanceType::L2Expanded}, - // L2 unexpanded - {100, 1000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, - {1000, 10000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, - {100, 1000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, - {20, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, - {1000, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, - {1000, 10000, 32, 50, raft::distance::DistanceType::L2Unexpanded}, - {10000, 40000, 32, 30, raft::distance::DistanceType::L2Unexpanded}, -}; - -typedef FusedL2KNNTest FusedL2KNNTestF; -TEST_P(FusedL2KNNTestF, FusedBruteForce) { this->testBruteForce(); } - -INSTANTIATE_TEST_CASE_P(FusedL2KNNTest, FusedL2KNNTestF, ::testing::ValuesIn(inputs)); - -} // namespace knn -} // namespace spatial -} // namespace raft diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu deleted file mode 100644 index a98de85bda..0000000000 --- a/cpp/test/neighbors/knn.cu +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" - -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::brute_force { -struct KNNInputs { - std::vector> input; - int k; - std::vector labels; -}; - -template -RAFT_KERNEL build_actual_output( - int* output, int n_rows, int k, const int* idx_labels, const IdxT* indices) -{ - int element = threadIdx.x + blockDim.x * blockIdx.x; - if (element >= n_rows * k) return; - - output[element] = idx_labels[indices[element]]; -} - -RAFT_KERNEL build_expected_output(int* output, int n_rows, int k, const int* labels) -{ - int row = threadIdx.x + blockDim.x * blockIdx.x; - if (row >= n_rows) return; - - int cur_label = labels[row]; - for (int i = 0; i < k; i++) { - output[row * k + i] = cur_label; - } -} - -template -class KNNTest : public ::testing::TestWithParam { - public: - KNNTest() - : params_(::testing::TestWithParam::GetParam()), - stream(resource::get_cuda_stream(handle)), - actual_labels_(0, stream), - expected_labels_(0, stream), - input_(0, stream), - search_data_(0, stream), - indices_(0, stream), - distances_(0, stream), - search_labels_(0, stream) - { - } - - protected: - void testBruteForce() - { - // #if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) - raft::print_device_vector("Input array: ", input_.data(), rows_ * cols_, std::cout); - std::cout << "K: " << k_ << std::endl; - raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); - // #endif - - std::vector> index = { - make_device_matrix_view((const T*)(input_.data()), rows_, cols_)}; - auto search = raft::make_device_matrix_view( - (const T*)(search_data_.data()), rows_, cols_); - - auto indices = raft::make_device_matrix_view(indices_.data(), rows_, k_); - auto distances = - raft::make_device_matrix_view(distances_.data(), rows_, k_); - - auto metric = raft::distance::DistanceType::L2Unexpanded; - knn(handle, index, search, indices, distances, metric, std::make_optional(0)); - - build_actual_output<<>>( - actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); - - build_expected_output<<>>( - expected_labels_.data(), rows_, k_, search_labels_.data()); - - ASSERT_TRUE(devArrMatch( - expected_labels_.data(), actual_labels_.data(), rows_ * k_, raft::Compare(), stream)); - } - - void SetUp() override - { - rows_ = params_.input.size(); - cols_ = params_.input[0].size(); - k_ = params_.k; - - actual_labels_.resize(rows_ * k_, stream); - expected_labels_.resize(rows_ * k_, stream); - input_.resize(rows_ * cols_, stream); - search_data_.resize(rows_ * cols_, stream); - indices_.resize(rows_ * k_, stream); - distances_.resize(rows_ * k_, stream); - search_labels_.resize(rows_, stream); - - RAFT_CUDA_TRY( - cudaMemsetAsync(actual_labels_.data(), 0, actual_labels_.size() * sizeof(int), stream)); - RAFT_CUDA_TRY( - cudaMemsetAsync(expected_labels_.data(), 0, expected_labels_.size() * sizeof(int), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(input_.data(), 0, input_.size() * sizeof(float), stream)); - RAFT_CUDA_TRY( - cudaMemsetAsync(search_data_.data(), 0, search_data_.size() * sizeof(float), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(indices_.data(), 0, indices_.size() * sizeof(IdxT), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(distances_.data(), 0, distances_.size() * sizeof(float), stream)); - RAFT_CUDA_TRY( - cudaMemsetAsync(search_labels_.data(), 0, search_labels_.size() * sizeof(int), stream)); - - std::vector row_major_input; - for (std::size_t i = 0; i < params_.input.size(); ++i) { - for (std::size_t j = 0; j < params_.input[i].size(); ++j) { - row_major_input.push_back(params_.input[i][j]); - } - } - rmm::device_buffer input_d = - rmm::device_buffer(row_major_input.data(), row_major_input.size() * sizeof(float), stream); - float* input_ptr = static_cast(input_d.data()); - - rmm::device_buffer labels_d = - rmm::device_buffer(params_.labels.data(), params_.labels.size() * sizeof(int), stream); - int* labels_ptr = static_cast(labels_d.data()); - - raft::copy(input_.data(), input_ptr, rows_ * cols_, stream); - raft::copy(search_data_.data(), input_ptr, rows_ * cols_, stream); - raft::copy(search_labels_.data(), labels_ptr, rows_, stream); - resource::sync_stream(handle, stream); - } - - private: - raft::resources handle; - cudaStream_t stream; - - KNNInputs params_; - int rows_; - int cols_; - rmm::device_uvector input_; - rmm::device_uvector search_data_; - rmm::device_uvector indices_; - rmm::device_uvector distances_; - int k_; - - rmm::device_uvector search_labels_; - rmm::device_uvector actual_labels_; - rmm::device_uvector expected_labels_; -}; - -const std::vector inputs = { - // 2D - {{ - {2.7810836, 2.550537003}, - {1.465489372, 2.362125076}, - {3.396561688, 4.400293529}, - {1.38807019, 1.850220317}, - {3.06407232, 3.005305973}, - {7.627531214, 2.759262235}, - {5.332441248, 2.088626775}, - {6.922596716, 1.77106367}, - {8.675418651, -0.242068655}, - {7.673756466, 3.508563011}, - }, - 2, - {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}}}; - -typedef KNNTest KNNTestFint32_t; -TEST_P(KNNTestFint32_t, BruteForce) { this->testBruteForce(); } -typedef KNNTest KNNTestFuint32_t; -TEST_P(KNNTestFuint32_t, BruteForce) { this->testBruteForce(); } - -INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint32_t, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFuint32_t, ::testing::ValuesIn(inputs)); - -} // namespace raft::neighbors::brute_force diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu deleted file mode 100644 index 05e6048e56..0000000000 --- a/cpp/test/neighbors/refine.cu +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "ann_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -#include - -namespace raft::neighbors { - -template -class RefineTest : public ::testing::TestWithParam> { - public: - RefineTest() - : stream_(resource::get_cuda_stream(handle_)), - data(handle_, ::testing::TestWithParam>::GetParam()) - { - } - - protected: - public: // tamas remove - void testRefine() - { - std::vector indices(data.p.n_queries * data.p.k); - std::vector distances(data.p.n_queries * data.p.k); - - if (data.p.host_data) { - raft::neighbors::refine(handle_, - data.dataset_host.view(), - data.queries_host.view(), - data.candidates_host.view(), - data.refined_indices_host.view(), - data.refined_distances_host.view(), - data.p.metric); - raft::copy(indices.data(), - data.refined_indices_host.data_handle(), - data.refined_indices_host.size(), - stream_); - raft::copy(distances.data(), - data.refined_distances_host.data_handle(), - data.refined_distances_host.size(), - stream_); - - } else { - raft::neighbors::refine(handle_, - data.dataset.view(), - data.queries.view(), - data.candidates.view(), - data.refined_indices.view(), - data.refined_distances.view(), - data.p.metric); - update_host(distances.data(), - data.refined_distances.data_handle(), - data.refined_distances.size(), - stream_); - update_host( - indices.data(), data.refined_indices.data_handle(), data.refined_indices.size(), stream_); - } - resource::sync_stream(handle_); - - double min_recall = 1; - - ASSERT_TRUE(raft::neighbors::eval_neighbours(data.true_refined_indices_host, - indices, - data.true_refined_distances_host, - distances, - data.p.n_queries, - data.p.k, - 0.001, - min_recall)); - } - - public: - raft::resources handle_; - rmm::cuda_stream_view stream_; - RefineHelper data; -}; - -const std::vector> inputs = - raft::util::itertools::product>( - {static_cast(137)}, - {static_cast(1000)}, - {static_cast(16)}, - {static_cast(1), static_cast(10), static_cast(33)}, - {static_cast(33)}, - {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, - {false, true}); - -typedef RefineTest RefineTestF; -TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } - -INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); - -typedef RefineTest RefineTestF_uint8; -TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } -INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); - -typedef RefineTest RefineTestF_int8; -TEST_P(RefineTestF_int8, AnnRefine) { this->testRefine(); } -INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_int8, ::testing::ValuesIn(inputs)); -} // namespace raft::neighbors diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu deleted file mode 100644 index d378100711..0000000000 --- a/cpp/test/neighbors/tiled_knn.cu +++ /dev/null @@ -1,352 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" -#include "./ann_utils.cuh" -#include "./knn_utils.cuh" - -#include -#include -#include -#include // raft::distance::pairwise_distance -#include -#include -#include -#include -#include // raft::neighbors::detail::brute_force_knn_impl - -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::brute_force { - -struct TiledKNNInputs { - int num_queries; - int num_db_vecs; - int dim; - int k; - int row_tiles; - int col_tiles; - raft::distance::DistanceType metric; - bool row_major; -}; - -std::ostream& operator<<(std::ostream& os, const TiledKNNInputs& input) -{ - return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs - << " dim:" << input.dim << " k:" << input.k << " row_tiles:" << input.row_tiles - << " col_tiles:" << input.col_tiles << " metric:" << print_metric{input.metric} - << " row_major:" << input.row_major; -} - -template -class TiledKNNTest : public ::testing::TestWithParam { - public: - TiledKNNTest() - : stream_(resource::get_cuda_stream(handle_)), - params_(::testing::TestWithParam::GetParam()), - database(params_.num_db_vecs * params_.dim, stream_), - search_queries(params_.num_queries * params_.dim, stream_), - raft_indices_(params_.num_queries * params_.k, stream_), - raft_distances_(params_.num_queries * params_.k, stream_), - ref_indices_(params_.num_queries * params_.k, stream_), - ref_distances_(params_.num_queries * params_.k, stream_) - { - raft::matrix::fill( - handle_, - raft::make_device_matrix_view(database.data(), params_.num_db_vecs, params_.dim), - T{0.0}); - raft::matrix::fill( - handle_, - raft::make_device_matrix_view(search_queries.data(), params_.num_queries, params_.dim), - T{0.0}); - raft::matrix::fill( - handle_, - raft::make_device_matrix_view(raft_indices_.data(), params_.num_queries, params_.k), - 0); - raft::matrix::fill( - handle_, - raft::make_device_matrix_view(raft_distances_.data(), params_.num_queries, params_.k), - T{0.0}); - raft::matrix::fill( - handle_, - raft::make_device_matrix_view(ref_indices_.data(), params_.num_queries, params_.k), - 0); - raft::matrix::fill( - handle_, - raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), - T{0.0}); - } - - protected: - void testBruteForce() - { - float metric_arg = 3.0; - - // calculate the naive knn, by calculating the full pairwise distances and doing a k-select - rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); - rmm::device_uvector workspace(0, stream_); - distance::pairwise_distance(handle_, - search_queries.data(), - database.data(), - temp_distances.data(), - num_queries, - num_db_vecs, - dim, - workspace, - metric, - params_.row_major, - metric_arg); - - // setting the 'isRowMajor' flag in the pairwise distances api, not only sets - // the inputs as colmajor - but also the output. this means we have to transpose in this - // case - auto temp_dist = temp_distances.data(); - rmm::device_uvector temp_row_major_dist(num_db_vecs * num_queries, stream_); - if (!params_.row_major) { - raft::linalg::transpose( - handle_, temp_dist, temp_row_major_dist.data(), num_queries, num_db_vecs, stream_); - temp_dist = temp_row_major_dist.data(); - } - - matrix::select_k( - handle_, - raft::make_device_matrix_view(temp_dist, num_queries, num_db_vecs), - std::nullopt, - raft::make_device_matrix_view(ref_distances_.data(), params_.num_queries, params_.k), - raft::make_device_matrix_view(ref_indices_.data(), params_.num_queries, params_.k), - raft::distance::is_min_close(metric), - true); - - if ((params_.row_tiles == 0) && (params_.col_tiles == 0)) { - std::vector input{database.data()}; - std::vector sizes{static_cast(num_db_vecs)}; - neighbors::detail::brute_force_knn_impl(handle_, - input, - sizes, - dim, - const_cast(search_queries.data()), - num_queries, - raft_indices_.data(), - raft_distances_.data(), - k_, - params_.row_major, - params_.row_major, - nullptr, - metric, - metric_arg); - } else { - neighbors::detail::tiled_brute_force_knn(handle_, - search_queries.data(), - database.data(), - num_queries, - num_db_vecs, - dim, - k_, - raft_distances_.data(), - raft_indices_.data(), - metric, - metric_arg, - params_.row_tiles, - params_.col_tiles); - } - - // verify. - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(ref_indices_.data(), - raft_indices_.data(), - ref_distances_.data(), - raft_distances_.data(), - num_queries, - k_, - float(0.001), - stream_, - true)); - - // Also test out the 'index' api - where we can use precomputed norms - if (params_.row_major) { - auto idx = - raft::neighbors::brute_force::build(handle_, - raft::make_device_matrix_view( - database.data(), params_.num_db_vecs, params_.dim), - metric, - metric_arg); - - auto query_view = raft::make_device_matrix_view( - search_queries.data(), params_.num_queries, params_.dim); - - raft::neighbors::brute_force::search( - handle_, - idx, - query_view, - raft::make_device_matrix_view( - raft_indices_.data(), params_.num_queries, params_.k), - raft::make_device_matrix_view( - raft_distances_.data(), params_.num_queries, params_.k)); - - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(ref_indices_.data(), - raft_indices_.data(), - ref_distances_.data(), - raft_distances_.data(), - num_queries, - k_, - float(0.001), - stream_, - true)); - // also test out the batch api. First get new reference results (all k, up to a certain - // max size) - auto all_size = std::min(params_.num_db_vecs, 1024); - auto all_indices = raft::make_device_matrix(handle_, num_queries, all_size); - auto all_distances = raft::make_device_matrix(handle_, num_queries, all_size); - raft::neighbors::brute_force::search( - handle_, idx, query_view, all_indices.view(), all_distances.view()); - - int64_t offset = 0; - auto query = make_batch_k_query(handle_, idx, query_view, k_); - for (auto batch : *query) { - auto batch_size = batch.batch_size(); - auto indices = raft::make_device_matrix(handle_, num_queries, batch_size); - auto distances = raft::make_device_matrix(handle_, num_queries, batch_size); - - matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; - - matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords); - matrix::slice( - handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords); - - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(), - batch.indices().data_handle(), - distances.data_handle(), - batch.distances().data_handle(), - num_queries, - batch_size, - float(0.001), - stream_, - true)); - - offset += batch_size; - if (offset + batch_size > all_size) break; - } - - // also test out with variable batch sizes - offset = 0; - int64_t batch_size = k_; - query = make_batch_k_query(handle_, idx, query_view, batch_size); - for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { - // batch_size could be less than requested (in the case of final batch). handle. - ASSERT_TRUE(it->indices().extent(1) <= batch_size); - batch_size = it->indices().extent(1); - - auto indices = raft::make_device_matrix(handle_, num_queries, batch_size); - auto distances = raft::make_device_matrix(handle_, num_queries, batch_size); - - matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; - matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords); - matrix::slice( - handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords); - - ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(), - it->indices().data_handle(), - distances.data_handle(), - it->distances().data_handle(), - num_queries, - batch_size, - float(0.001), - stream_, - true)); - - offset += batch_size; - if (offset + batch_size > all_size) break; - - batch_size += 2; - } - } - } - - void SetUp() override - { - num_queries = params_.num_queries; - num_db_vecs = params_.num_db_vecs; - dim = params_.dim; - k_ = params_.k; - metric = params_.metric; - - unsigned long long int seed = 1234ULL; - raft::random::RngState r(seed); - - // JensenShannon distance requires positive values - T min_val = metric == raft::distance::DistanceType::JensenShannon ? T(0.0) : T(-1.0); - uniform(handle_, r, database.data(), num_db_vecs * dim, min_val, T(1.0)); - uniform(handle_, r, search_queries.data(), num_queries * dim, min_val, T(1.0)); - } - - private: - raft::resources handle_; - cudaStream_t stream_ = 0; - TiledKNNInputs params_; - int num_queries; - int num_db_vecs; - int dim; - rmm::device_uvector database; - rmm::device_uvector search_queries; - rmm::device_uvector raft_indices_; - rmm::device_uvector raft_distances_; - rmm::device_uvector ref_indices_; - rmm::device_uvector ref_distances_; - int k_; - raft::distance::DistanceType metric; -}; - -const std::vector random_inputs = { - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Expanded, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2Unexpanded, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtUnexpanded, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L1, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Linf, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::InnerProduct, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CorrelationExpanded, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::CosineExpanded, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::LpUnexpanded, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::JensenShannon, true}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::L2SqrtExpanded, true}, - // BrayCurtis isn't currently supported by pairwise_distance api - // {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::BrayCurtis}, - {256, 512, 16, 8, 16, 8, raft::distance::DistanceType::Canberra, true}, - {10000, 40000, 32, 30, 512, 1024, raft::distance::DistanceType::L2Expanded, true}, - {345, 1023, 16, 128, 512, 1024, raft::distance::DistanceType::CosineExpanded, true}, - {789, 20516, 64, 256, 512, 4096, raft::distance::DistanceType::L2SqrtExpanded, true}, - // Test where the final column tile has < K items: - {4, 12, 32, 6, 4, 8, raft::distance::DistanceType::L2Expanded, true}, - // Test where passing column_tiles < K - {1, 40, 32, 30, 1, 8, raft::distance::DistanceType::L2Expanded, true}, - // Passing tile sizes of 0 means to use brute_force_knn_impl (instead of the - // tiled_brute_force_knn api). - {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, true}, - {1000, 500000, 128, 128, 0, 0, raft::distance::DistanceType::L2Expanded, false}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::LpUnexpanded, true}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::L2SqrtExpanded, false}, - {1000, 5000, 128, 128, 0, 0, raft::distance::DistanceType::InnerProduct, false}}; - -typedef TiledKNNTest TiledKNNTestF; -TEST_P(TiledKNNTestF, BruteForce) { this->testBruteForce(); } - -INSTANTIATE_TEST_CASE_P(TiledKNNTest, TiledKNNTestF, ::testing::ValuesIn(random_inputs)); -} // namespace raft::neighbors::brute_force diff --git a/cpp/test/sparse/gram.cu b/cpp/test/sparse/gram.cu deleted file mode 100644 index 3505a3ddf5..0000000000 --- a/cpp/test/sparse/gram.cu +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#if defined RAFT_DISTANCE_COMPILED -#include -#include -#endif - -#include "../distance/gram_base.cuh" -#include "../test_utils.cuh" - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -namespace raft::distance::kernels { - -/** - * Structure to describe structure of the input matrices: - * - DENSE: dense, dense - * - MIX: CSR, dense - * - CSR: CSR, CSR - */ -enum SparseType { DENSE, MIX, CSR }; - -struct GramMatrixInputs { - int n1; // feature vectors in matrix 1 - int n2; // featuer vectors in matrix 2 - int n_cols; // number of elements in a feature vector - bool is_row_major; - SparseType sparse_input; - KernelParams kernel; - int ld1; - int ld2; - int ld_out; - // We will generate random input using the dimensions given here. - // The reference output is calculated by a custom kernel. -}; - -std::ostream& operator<<(std::ostream& os, const GramMatrixInputs& p) -{ - std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; - os << "/" << p.n1 << "x" << p.n2 << "x" << p.n_cols << "/" - << (p.is_row_major ? "RowMajor/" : "ColMajor/") - << (p.sparse_input == SparseType::DENSE - ? "DenseDense/" - : (p.sparse_input == SparseType::MIX ? "CsrDense/" : "CsrCsr/")) - << kernel_names[p.kernel.kernel] << "/ld_" << p.ld1 << "x" << p.ld2 << "x" << p.ld_out; - return os; -} - -/*struct KernelParams { - // Kernel function parameters - KernelType kernel; //!< Type of the kernel function - int degree; //!< Degree of polynomial kernel (ignored by others) - double gamma; //!< multiplier in the - double coef0; //!< additive constant in poly and tanh kernels -};*/ - -// const KernelParams linear_kernel_params{.kernel=KernelType::LINEAR}; - -// {KernelType::POLYNOMIAL, 2, 0.5, 2.4}, {KernelType::TANH, 0, 0.5, 2.4}, {KernelType::RBF, 0, 0.5} -const std::vector inputs = raft::util::itertools::product( - {42}, - {137}, - {2}, - {true, false}, - {SparseType::DENSE, SparseType::MIX, SparseType::CSR}, - {KernelParams{KernelType::LINEAR}, - KernelParams{KernelType::POLYNOMIAL, 2, 0.5, 2.4}, - KernelParams{KernelType::TANH, 0, 0.5, 2.4}, - KernelParams{KernelType::RBF, 0, 0.5}}); - -// (ld_1, ld_2, ld_out) not supported by RBF and CSR -const std::vector inputs_ld = raft::util::itertools::product( - {137}, - {42}, - {2}, - {true, false}, - {SparseType::DENSE, SparseType::MIX}, - {KernelParams{KernelType::LINEAR}, - KernelParams{KernelType::POLYNOMIAL, 2, 0.5, 2.4}, - KernelParams{KernelType::TANH, 0, 0.5, 2.4}}, - {159}, - {73}, - {144}); - -// (ld_1, ld_2) are supported by CSR -const std::vector inputs_ld_csr = - raft::util::itertools::product( - {42}, - {137}, - {2}, - {true, false}, - {SparseType::CSR, SparseType::MIX}, - {KernelParams{KernelType::LINEAR}, - KernelParams{KernelType::POLYNOMIAL, 2, 0.5, 2.4}, - KernelParams{KernelType::TANH, 0, 0.5, 2.4}}, - {64}, - {155}, - {0}); - -template -class GramMatrixTest : public ::testing::TestWithParam { - protected: - GramMatrixTest() - : params(GetParam()), - stream(resource::get_cuda_stream(handle)), - x1(0, stream), - x2(0, stream), - x1_csr_indptr(0, stream), - x1_csr_indices(0, stream), - x1_csr_data(0, stream), - x2_csr_indptr(0, stream), - x2_csr_indices(0, stream), - x2_csr_data(0, stream), - gram(0, stream), - gram_host(0) - { - if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; } - if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; } - if (params.ld_out == 0) { params.ld_out = params.is_row_major ? params.n2 : params.n1; } - // Derive the size of the output from the offset of the last element. - size_t size = get_offset(params.n1 - 1, params.n_cols - 1, params.ld1, params.is_row_major) + 1; - x1.resize(size, stream); - size = get_offset(params.n2 - 1, params.n_cols - 1, params.ld2, params.is_row_major) + 1; - x2.resize(size, stream); - size = get_offset(params.n1 - 1, params.n2 - 1, params.ld_out, params.is_row_major) + 1; - - gram.resize(size, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(gram.data(), 0, gram.size() * sizeof(math_t), stream)); - gram_host.resize(gram.size()); - std::fill(gram_host.begin(), gram_host.end(), 0); - - raft::random::RngState r(42137ULL); - raft::random::uniform(handle, r, x1.data(), x1.size(), math_t(0), math_t(1)); - raft::random::uniform(handle, r, x2.data(), x2.size(), math_t(0), math_t(1)); - - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - } - - ~GramMatrixTest() override {} - - int prepareCsr(math_t* dense, int n_rows, int ld, int* indptr, int* indices, math_t* data) - { - int nnz = 0; - double eps = 1e-6; - int n_cols = params.n_cols; - bool is_row_major = params.is_row_major; - size_t dense_size = get_offset(n_rows - 1, n_cols - 1, ld, is_row_major) + 1; - - std::vector dense_host(dense_size); - raft::update_host(dense_host.data(), dense, dense_size, stream); - resource::sync_stream(handle, stream); - - std::vector indptr_host(n_rows + 1); - std::vector indices_host(n_rows * n_cols); - std::vector data_host(n_rows * n_cols); - - // create csr matrix from dense (with threshold) - for (int i = 0; i < n_rows; ++i) { - indptr_host[i] = nnz; - for (int j = 0; j < n_cols; ++j) { - math_t value = dense_host[get_offset(i, j, ld, is_row_major)]; - if (value > eps) { - indices_host[nnz] = j; - data_host[nnz] = value; - nnz++; - } - } - } - indptr_host[n_rows] = nnz; - - // fill back dense matrix from CSR - std::fill(dense_host.data(), dense_host.data() + dense_size, 0); - for (int i = 0; i < n_rows; ++i) { - for (int idx = indptr_host[i]; idx < indptr_host[i + 1]; ++idx) { - dense_host[get_offset(i, indices_host[idx], ld, is_row_major)] = data_host[idx]; - } - } - - raft::update_device(dense, dense_host.data(), dense_size, stream); - raft::update_device(indptr, indptr_host.data(), n_rows + 1, stream); - raft::update_device(indices, indices_host.data(), nnz, stream); - raft::update_device(data, data_host.data(), nnz, stream); - resource::sync_stream(handle, stream); - return nnz; - } - - void runTest() - { - std::unique_ptr> kernel = - std::unique_ptr>(KernelFactory::create(params.kernel)); - - auto x1_span = - params.is_row_major - ? raft::make_device_strided_matrix_view( - x1.data(), params.n1, params.n_cols, params.ld1) - : raft::make_device_strided_matrix_view( - x1.data(), params.n1, params.n_cols, params.ld1); - auto x2_span = - params.is_row_major - ? raft::make_device_strided_matrix_view( - x2.data(), params.n2, params.n_cols, params.ld2) - : raft::make_device_strided_matrix_view( - x2.data(), params.n2, params.n_cols, params.ld2); - auto out_span = - params.is_row_major - ? raft::make_device_strided_matrix_view( - gram.data(), params.n1, params.n2, params.ld_out) - : raft::make_device_strided_matrix_view( - gram.data(), params.n1, params.n2, params.ld_out); - - if (params.sparse_input == SparseType::DENSE) { - (*kernel)(handle, x1_span, x2_span, out_span); - } else { - x1_csr_indptr.reserve(params.n1 + 1, stream); - x1_csr_indices.reserve(params.n1 * params.n_cols, stream); - x1_csr_data.reserve(params.n1 * params.n_cols, stream); - int x1_nnz = prepareCsr(x1.data(), - params.n1, - params.ld1, - x1_csr_indptr.data(), - x1_csr_indices.data(), - x1_csr_data.data()); - - auto x1_csr_structure = raft::make_device_compressed_structure_view( - x1_csr_indptr.data(), x1_csr_indices.data(), params.n1, params.n_cols, x1_nnz); - auto x1_csr = raft::device_csr_matrix_view( - raft::device_span(x1_csr_data.data(), x1_csr_structure.get_nnz()), - x1_csr_structure); - - if (params.sparse_input == SparseType::MIX) { - (*kernel)(handle, x1_csr, x2_span, out_span); - } else { - x2_csr_indptr.reserve(params.n2 + 1, stream); - x2_csr_indices.reserve(params.n2 * params.n_cols, stream); - x2_csr_data.reserve(params.n2 * params.n_cols, stream); - int x2_nnz = prepareCsr(x2.data(), - params.n2, - params.ld2, - x2_csr_indptr.data(), - x2_csr_indices.data(), - x2_csr_data.data()); - - auto x2_csr_structure = raft::make_device_compressed_structure_view( - x2_csr_indptr.data(), x2_csr_indices.data(), params.n2, params.n_cols, x2_nnz); - auto x2_csr = raft::device_csr_matrix_view( - raft::device_span(x2_csr_data.data(), x2_csr_structure.get_nnz()), - x2_csr_structure); - - (*kernel)(handle, x1_csr, x2_csr, out_span); - } - } - // Something in gram is executing not on the 'stream' and therefore - // a full device sync is required - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - naiveGramMatrixKernel(params.n1, - params.n2, - params.n_cols, - x1, - x2, - gram_host.data(), - params.ld1, - params.ld2, - params.ld_out, - params.is_row_major, - params.kernel, - stream, - handle); - resource::sync_stream(handle, stream); - - ASSERT_TRUE(raft::devArrMatchHost( - gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f), stream)); - } - - raft::resources handle; - cudaStream_t stream = 0; - GramMatrixInputs params; - - rmm::device_uvector x1; - rmm::device_uvector x2; - - rmm::device_uvector x1_csr_indptr; - rmm::device_uvector x1_csr_indices; - rmm::device_uvector x1_csr_data; - rmm::device_uvector x2_csr_indptr; - rmm::device_uvector x2_csr_indices; - rmm::device_uvector x2_csr_data; - - rmm::device_uvector gram; - std::vector gram_host; -}; - -typedef GramMatrixTest GramMatrixTestFloatStandard; -typedef GramMatrixTest GramMatrixTestFloatLd; -typedef GramMatrixTest GramMatrixTestFloatLdCsr; - -TEST_P(GramMatrixTestFloatStandard, Gram) { runTest(); } -TEST_P(GramMatrixTestFloatLd, Gram) { runTest(); } -TEST_P(GramMatrixTestFloatLdCsr, Gram) { runTest(); } - -INSTANTIATE_TEST_SUITE_P(GramMatrixTests, GramMatrixTestFloatStandard, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_SUITE_P(GramMatrixTests, GramMatrixTestFloatLd, ::testing::ValuesIn(inputs_ld)); -INSTANTIATE_TEST_SUITE_P(GramMatrixTests, - GramMatrixTestFloatLdCsr, - ::testing::ValuesIn(inputs_ld_csr)); -}; // end namespace raft::distance::kernels diff --git a/cpp/test/stats/neighborhood_recall.cu b/cpp/test/stats/neighborhood_recall.cu deleted file mode 100644 index 1a76154e2e..0000000000 --- a/cpp/test/stats/neighborhood_recall.cu +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../neighbors/ann_utils.cuh" -#include "../test_utils.h" - -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace raft::stats { - -struct NeighborhoodRecallInputs { - int n_rows; - int n_cols; - int k; -}; - -template -class NeighborhoodRecallTest : public ::testing::TestWithParam { - public: - NeighborhoodRecallTest() - : ps{::testing::TestWithParam::GetParam()}, - data_1{raft::make_device_matrix(res, ps.n_rows, ps.n_cols)}, - data_2{raft::make_device_matrix(res, ps.n_rows, ps.n_cols)} - { - } - - protected: - void test_recall() - { - size_t queries_size = ps.n_rows * ps.k; - - // calculate nn for dataset 1 - auto distances_1 = raft::make_device_matrix(res, ps.n_rows, ps.k); - auto indices_1 = raft::make_device_matrix(res, ps.n_rows, ps.k); - raft::neighbors::naive_knn( - res, - distances_1.data_handle(), - indices_1.data_handle(), - data_1.data_handle(), - data_1.data_handle(), - ps.n_rows, - ps.n_rows, - ps.n_cols, - ps.k, - raft::distance::DistanceType::L2Expanded); - std::vector distances_1_h(queries_size); - std::vector indices_1_h(queries_size); - raft::copy(distances_1_h.data(), - distances_1.data_handle(), - ps.n_rows * ps.k, - raft::resource::get_cuda_stream(res)); - raft::copy(indices_1_h.data(), - indices_1.data_handle(), - ps.n_rows * ps.k, - raft::resource::get_cuda_stream(res)); - - // calculate nn for dataset 2 - auto distances_2 = raft::make_device_matrix(res, ps.n_rows, ps.k); - auto indices_2 = raft::make_device_matrix(res, ps.n_rows, ps.k); - raft::neighbors::naive_knn( - res, - distances_2.data_handle(), - indices_2.data_handle(), - data_2.data_handle(), - data_2.data_handle(), - ps.n_rows, - ps.n_rows, - ps.n_cols, - ps.k, - raft::distance::DistanceType::L2Expanded); - std::vector distances_2_h(queries_size); - std::vector indices_2_h(queries_size); - raft::copy(distances_2_h.data(), - distances_2.data_handle(), - ps.n_rows * ps.k, - raft::resource::get_cuda_stream(res)); - raft::copy(indices_2_h.data(), - indices_2.data_handle(), - ps.n_rows * ps.k, - raft::resource::get_cuda_stream(res)); - - raft::resource::sync_stream(res); - - // find CPU recall scores - [[maybe_unused]] auto [indices_only_recall_h, mc1, tc1] = - raft::neighbors::calc_recall(indices_1_h, indices_2_h, ps.n_rows, ps.k); - [[maybe_unused]] auto [recall_h, mc2, tc2] = raft::neighbors::calc_recall( - indices_1_h, indices_2_h, distances_1_h, distances_2_h, ps.n_rows, ps.k, 0.001); - - // find GPU recall scores - auto s1 = 0; - auto indices_only_recall_scalar = raft::make_host_scalar(s1); - neighborhood_recall(res, - raft::make_const_mdspan(indices_1.view()), - raft::make_const_mdspan(indices_2.view()), - indices_only_recall_scalar.view()); - - auto s2 = 0; - auto recall_scalar = raft::make_host_scalar(s2); - DistanceT s3 = 0.001; - auto eps_mda = raft::make_host_scalar(s3); - - neighborhood_recall(res, - raft::make_const_mdspan(indices_1.view()), - raft::make_const_mdspan(indices_2.view()), - recall_scalar.view(), - raft::make_const_mdspan(distances_1.view()), - raft::make_const_mdspan(distances_2.view())); - - // assert correctness - ASSERT_TRUE(raft::match(indices_only_recall_h, - *indices_only_recall_scalar.data_handle(), - raft::CompareApprox(0.01))); - ASSERT_TRUE( - raft::match(recall_h, *recall_scalar.data_handle(), raft::CompareApprox(0.01))); - } - - void SetUp() override - { - // form two random datasets - raft::random::Rng r1(1234ULL); - r1.normal(data_1.data_handle(), - ps.n_rows * ps.n_cols, - DistanceT(0.1), - DistanceT(2.0), - raft::resource::get_cuda_stream(res)); - raft::random::Rng r2(21111ULL); - r2.normal(data_2.data_handle(), - ps.n_rows * ps.n_cols, - DistanceT(0.1), - DistanceT(2.0), - raft::resource::get_cuda_stream(res)); - resource::sync_stream(res); - } - - private: - raft::resources res; - NeighborhoodRecallInputs ps; - raft::device_matrix data_1; - raft::device_matrix data_2; -}; - -const std::vector inputs = - raft::util::itertools::product({10, 50, 100}, // n_rows - {80, 100}, // n_cols - {32, 64}); // k - -using NeighborhoodRecallTestF_U32 = NeighborhoodRecallTest; -TEST_P(NeighborhoodRecallTestF_U32, AnnCagra) { this->test_recall(); } - -INSTANTIATE_TEST_CASE_P(NeighborhoodRecallTest, - NeighborhoodRecallTestF_U32, - ::testing::ValuesIn(inputs)); - -} // end namespace raft::stats diff --git a/cpp/test/stats/silhouette_score.cu b/cpp/test/stats/silhouette_score.cu deleted file mode 100644 index ad080f5894..0000000000 --- a/cpp/test/stats/silhouette_score.cu +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "../test_utils.cuh" - -#include -#include -#include -#include - -#include - -#include - -#include -#include -#include - -namespace raft { -namespace stats { - -// parameter structure definition -struct silhouetteScoreParam { - int nRows; - int nCols; - int nLabels; - raft::distance::DistanceType metric; - int chunk; - double tolerance; -}; - -// test fixture class -template -class silhouetteScoreTest : public ::testing::TestWithParam { - protected: - silhouetteScoreTest() - : d_X(0, resource::get_cuda_stream(handle)), - sampleSilScore(0, resource::get_cuda_stream(handle)), - d_labels(0, resource::get_cuda_stream(handle)) - { - } - - void host_silhouette_score() - { - // generating random value test input - std::vector h_X(nElements, 0.0); - std::vector h_labels(nRows, 0); - std::random_device rd; - std::default_random_engine dre(nElements * nLabels); - std::uniform_int_distribution intGenerator(0, nLabels - 1); - std::uniform_real_distribution realGenerator(0, 100); - - std::generate(h_X.begin(), h_X.end(), [&]() { return realGenerator(dre); }); - std::generate(h_labels.begin(), h_labels.end(), [&]() { return intGenerator(dre); }); - - // allocating and initializing memory to the GPU - auto stream = resource::get_cuda_stream(handle); - d_X.resize(nElements, stream); - d_labels.resize(nElements, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(d_X.data(), 0, d_X.size() * sizeof(DataT), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_labels.data(), 0, d_labels.size() * sizeof(LabelT), stream)); - sampleSilScore.resize(nElements, stream); - - raft::update_device(d_X.data(), &h_X[0], (int)nElements, stream); - raft::update_device(d_labels.data(), &h_labels[0], (int)nElements, stream); - - // finding the distance matrix - - rmm::device_uvector d_distanceMatrix(nRows * nRows, stream); - double* h_distanceMatrix = (double*)malloc(nRows * nRows * sizeof(double*)); - - raft::distance::pairwise_distance( - handle, d_X.data(), d_X.data(), d_distanceMatrix.data(), nRows, nRows, nCols, params.metric); - - resource::sync_stream(handle, stream); - - raft::update_host(h_distanceMatrix, d_distanceMatrix.data(), nRows * nRows, stream); - - // finding the bincount array - - double* binCountArray = (double*)malloc(nLabels * sizeof(double*)); - memset(binCountArray, 0, nLabels * sizeof(double)); - - for (int i = 0; i < nRows; ++i) { - binCountArray[h_labels[i]] += 1; - } - - // finding the average intra cluster distance for every element - - double* a = (double*)malloc(nRows * sizeof(double*)); - - for (int i = 0; i < nRows; ++i) { - int myLabel = h_labels[i]; - double sumOfIntraClusterD = 0; - - for (int j = 0; j < nRows; ++j) { - if (h_labels[j] == myLabel) { sumOfIntraClusterD += h_distanceMatrix[i * nRows + j]; } - } - - if (binCountArray[myLabel] <= 1) - a[i] = -1; - else - a[i] = sumOfIntraClusterD / (binCountArray[myLabel] - 1); - } - - // finding the average inter cluster distance for every element - - double* b = (double*)malloc(nRows * sizeof(double*)); - - for (int i = 0; i < nRows; ++i) { - int myLabel = h_labels[i]; - double minAvgInterCD = ULLONG_MAX; - - for (int j = 0; j < nLabels; ++j) { - int curClLabel = j; - if (curClLabel == myLabel) continue; - double avgInterCD = 0; - - for (int k = 0; k < nRows; ++k) { - if (h_labels[k] == curClLabel) { avgInterCD += h_distanceMatrix[i * nRows + k]; } - } - - if (binCountArray[curClLabel]) - avgInterCD /= binCountArray[curClLabel]; - else - avgInterCD = ULLONG_MAX; - minAvgInterCD = min(minAvgInterCD, avgInterCD); - } - - b[i] = minAvgInterCD; - } - - // finding the silhouette score for every element - - double* truthSampleSilScore = (double*)malloc(nRows * sizeof(double*)); - for (int i = 0; i < nRows; ++i) { - if (a[i] == -1) - truthSampleSilScore[i] = 0; - else if (a[i] == 0 && b[i] == 0) - truthSampleSilScore[i] = 0; - else - truthSampleSilScore[i] = (b[i] - a[i]) / max(a[i], b[i]); - truthSilhouetteScore += truthSampleSilScore[i]; - } - - truthSilhouetteScore /= nRows; - } - - // the constructor - void SetUp() override - { - // getting the parameters - params = ::testing::TestWithParam::GetParam(); - - nRows = params.nRows; - nCols = params.nCols; - nLabels = params.nLabels; - chunk = params.chunk; - nElements = nRows * nCols; - - host_silhouette_score(); - - // calling the silhouette_score CUDA implementation - computedSilhouetteScore = raft::stats::silhouette_score( - handle, - raft::make_device_matrix_view(d_X.data(), nRows, nCols), - raft::make_device_vector_view(d_labels.data(), nRows), - std::make_optional(raft::make_device_vector_view(sampleSilScore.data(), nRows)), - nLabels, - params.metric); - - batchedSilhouetteScore = raft::stats::silhouette_score_batched( - handle, - raft::make_device_matrix_view(d_X.data(), nRows, nCols), - raft::make_device_vector_view(d_labels.data(), nRows), - std::make_optional(raft::make_device_vector_view(sampleSilScore.data(), nRows)), - nLabels, - chunk, - params.metric); - } - - // declaring the data values - raft::resources handle; - silhouetteScoreParam params; - int nLabels; - rmm::device_uvector d_X; - rmm::device_uvector sampleSilScore; - rmm::device_uvector d_labels; - int nRows; - int nCols; - int nElements; - double truthSilhouetteScore = 0; - double computedSilhouetteScore = 0; - double batchedSilhouetteScore = 0; - int chunk; -}; - -// setting test parameter values -const std::vector inputs = { - {4, 2, 3, raft::distance::DistanceType::L2Expanded, 4, 0.00001}, - {4, 2, 2, raft::distance::DistanceType::L2SqrtUnexpanded, 2, 0.00001}, - {8, 8, 3, raft::distance::DistanceType::L2Unexpanded, 4, 0.00001}, - {11, 2, 5, raft::distance::DistanceType::L2Expanded, 3, 0.00001}, - {40, 2, 8, raft::distance::DistanceType::L2Expanded, 10, 0.00001}, - {12, 7, 3, raft::distance::DistanceType::CosineExpanded, 8, 0.00001}, - {7, 5, 5, raft::distance::DistanceType::L1, 2, 0.00001}}; - -// writing the test suite -typedef silhouetteScoreTest silhouetteScoreTestClass; -TEST_P(silhouetteScoreTestClass, Result) -{ - ASSERT_NEAR(computedSilhouetteScore, truthSilhouetteScore, params.tolerance); - ASSERT_NEAR(batchedSilhouetteScore, truthSilhouetteScore, params.tolerance); -} -INSTANTIATE_TEST_CASE_P(silhouetteScore, silhouetteScoreTestClass, ::testing::ValuesIn(inputs)); - -} // end namespace stats -} // end namespace raft diff --git a/cpp/test/stats/trustworthiness.cu b/cpp/test/stats/trustworthiness.cu deleted file mode 100644 index 846c192022..0000000000 --- a/cpp/test/stats/trustworthiness.cu +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../test_utils.cuh" - -#include -#include -#include -#include - -#include - -#include -#include - -namespace raft { -namespace stats { - -class TrustworthinessScoreTest : public ::testing::Test { - public: - TrustworthinessScoreTest() - : d_X(0, resource::get_cuda_stream(handle)), d_X_embedded(0, resource::get_cuda_stream(handle)) - { - } - - protected: - void basicTest() - { - std::vector X = { - 5.6142087, 8.59787, -4.382763, -3.6452143, -5.8816037, -0.6330313, 4.6920023, - -0.79210913, 0.6106314, 2.1210914, 5.919943, -8.43784, -6.4819884, 0.41001374, - -6.1052523, -4.0825715, -5.314755, -2.834671, 5.751696, -6.5012555, -0.4719201, - -7.53353, 7.6789393, -1.4959852, -5.5977287, -9.564147, 1.2902534, 3.559834, - -6.7659483, 8.265964, 4.595404, 9.133477, -6.1553917, -6.319754, -2.9039452, - 4.4150834, -3.094395, -4.426273, 9.584571, -5.64133, 6.6209483, 7.4044604, - 3.9620576, 5.639907, 10.33007, -0.8792053, 5.143776, -7.464049, 1.2448754, - -5.6300974, 5.4518576, 4.119535, 6.749645, 7.627064, -7.2298336, 1.9681473, - -6.9083176, 6.404673, 0.07186685, 9.0994835, 8.51037, -8.986389, 0.40534487, - 2.115397, 4.086756, 1.2284287, -2.6272132, 0.06527536, -9.587425, -7.206078, - 7.864875, 7.4397306, -6.9233336, -2.6643622, 3.3466153, 7.0408177, -3.6069896, - -9.971769, 4.4075623, 7.9063697, 2.559074, 4.323717, 1.6867131, -1.1576937, - -9.893141, -3.251416, -7.4889135, -4.0588717, -2.73338, -7.4852257, 3.4460473, - 9.759119, -5.4680476, -4.722435, -8.032619, -1.4598992, 4.227361, 3.135568, - 1.1950601, 1.1982028, 6.998856, -6.131138, -6.6921015, 0.5361224, -7.1213965, - -5.6104236, -7.2212887, -2.2710054, 8.544764, -6.0254574, 1.4582269, -5.5587835, - 8.031556, -0.26328218, -5.2591386, -9.262641, 2.8691363, 5.299787, -9.209455, - 8.523085, 5.180329, 10.655528, -5.7171874, -6.7739563, -3.6306462, 4.067106, - -1.5912259, -3.2345476, 8.042973, -3.6364832, 4.1242137, 9.886953, 5.4743724, - 6.3058076, 9.369645, -0.5175337, 4.9859877, -7.879498, 1.358422, -4.147944, - 3.8984218, 5.894656, 6.4903927, 8.702036, -8.023722, 2.802145, -7.748032, - 5.8461113, -0.34215945, 11.298865, 1.4107164, -9.949621, -1.6257563, -10.655836, - 2.4528909, 1.1570255, 5.170669, 2.8398793, 7.1838694, 9.088459, 2.631155, - 3.964414, 2.8769252, 0.04198391, -0.16993195, 3.6747139, -2.8377378, 6.1782537, - 10.759618, -4.5642614, -8.522967, 0.8614642, 6.623416, -1.029324, 5.5488334, - -7.804511, 2.128833, 7.9042315, 7.789576, -2.7944536, 0.72271067, -10.511495, - -0.78634536, -10.661714, 2.9376361, 1.9148129, 6.22859, 0.26264945, 8.028384, - 6.8743043, 0.9351067, 7.0690722, 4.2846055, 1.4134506, -0.18144785, 5.2778087, - -1.7140163, 9.217541, 8.602799, -2.6537218, -7.8377395, 1.1244944, 5.4540544, - -0.38506773, 3.9885726, -10.76455, 1.4440702, 9.136163, 6.664117, -5.7046547, - 8.038592, -9.229767, -0.2799413, 3.6064725, 4.187257, 1.0516582, -2.0707326, - -0.7615968, -8.561018, -3.7831352, 10.300297, 5.332594, -6.5880876, -4.2508664, - 1.7985519, 5.7226253, -4.1223383, -9.6697855, 1.4885283, 7.524974, 1.7206005, - 4.890457, 3.7264557, 0.4428284, -9.922455, -4.250455, -6.4410596, -2.107994, - -1.4109765, -6.1325397, 0.32883006, 6.0489736, 7.7257385, -8.281174, 1.0129383, - -10.792166, 8.378851, 10.802716, 9.848448, -9.188757, 1.3151443, 1.9971865, - -2.521849, 4.3268294, -7.775683, -2.2902298, 3.0824065, -7.17559, 9.6100855, - 7.3965735, -10.476525, 5.895973, -3.6974669, -7.6688933, 1.7354839, -7.4045196, - -1.7992063, -4.0394845, 5.2471714, -2.250571, 2.528036, -8.343515, -2.2374575, - -10.019771, 0.73371273, 3.1853926, 2.7994921, 2.6637669, 7.620401, 7.515571, - 0.68636256, 5.834537, 4.650282, -1.0362619, 0.4461701, 3.7870514, -4.1340904, - 7.202998, 9.736904, -3.005512, -8.920467, 1.1228397, 6.2598724, 1.2812365, - 4.5442104, -8.791537, 0.92113096, 8.464749, 8.359035, -4.3923397, 1.2252625, - -10.1986475, -1.4409319, -10.013967, 3.9071581, 1.683064, 4.877419, 1.6570637, - 9.559105, 7.3546534, 0.36635467, 5.220211, 4.6303267, 0.6601065, 0.16149978, - 3.8818731, -3.4438233, 8.42085, 8.659159, -3.0935583, -8.039611, 2.3060374, - 5.134666, 1.0458113, 6.0190983, -9.143728, 0.99048865, 9.210842, 6.670241, - -5.9614363, 0.8747396, 7.078824, 8.067469, -10.314754, 0.45977542, -9.28306, - 9.1838665, 9.318644, 7.189082, -11.092555, 1.0320464, 3.882163, 0.10953151, - 7.9029684, -6.9068265, -1.3526366, 5.3996363, -8.430931, 11.452577, 6.39663, - -11.090514, 4.6662245, -3.1268113, -8.357452, 2.2276728, -10.357126, -0.9291848, - -3.4193344, 3.1289792, -2.5030103, 6.772719, 11.457757, -4.2125936, -6.684548, - -4.7611327, 3.6960156, -2.3030636, -3.0591488, 10.452471, -4.1267314, 5.66614, - 7.501461, 5.072407, 6.636537, 8.990381, -0.2559256, 4.737867, -6.2149944, - 2.535682, -5.5484023, 5.7113924, 3.4742818, 7.9915137, 7.0052586, -7.156467, - 1.4354781, -8.286235, 5.7523417, -2.4175215, 9.678009, 0.05066403, -9.645226, - -2.2658763, -9.518178, 4.493372, 2.3232365, 2.1659086, 0.42507997, 8.360246, - 8.23535, 2.6878164, 5.236947, 3.4924245, -0.6089895, 0.8884741, 4.359464, - -4.6073823, 7.83441, 8.958755, -3.4690795, -9.182282, 1.2478025, 5.6311107, - -1.2408862, 3.6316886, -8.684654, 2.1078515, 7.2813864, 7.9265943, -3.6135032, - 0.4571511, 8.493568, 10.496853, -7.432897, 0.8625995, -9.607528, 7.2899456, - 8.83158, 8.908199, -10.300263, 1.1451302, 3.7871468, -0.97040755, 5.7664757, - -8.9688, -2.146672, 5.9641485, -6.2908535, 10.126465, 6.1553903, -12.066902, - 6.301596, -5.0419583, -8.228695, 2.4879954, -8.918582, -3.7434099, -4.1593685, - 3.7431836, -1.1704745, 0.5524103, 9.109399, 9.571567, -11.209955, 1.2462777, - -9.554555, 9.091726, 11.477966, 7.630937, -10.450911, 1.9205878, 5.358983, - -0.44546837, 6.7611346, -9.74753, -0.5939732, 3.8892255, -6.437991, 10.294727, - 5.6723895, -10.7883, 6.192348, -5.293862, -10.811491, 1.0194173, -7.074576, - -3.192368, -2.5231771, 4.2791643, -0.53309685, 0.501366, 9.636625, 7.710316, - -6.4219728, 1.0975566, -8.218886, 6.9011984, 9.873679, 8.903804, -9.316832, - 1.2404599, 4.9039655, 1.2272617, 4.541515, -5.2753224, -3.2196746, 3.1303136, - -7.285681, 9.041425, 5.6417427, -9.93667, 5.7548947, -5.113397, -8.544622, - 4.182665, -7.7709813, -3.2810235, -3.312072, 3.8900535, -2.0604856, 6.709082, - -8.461194, 1.2666026, 4.8770437, 2.6955879, 3.0340345, -1.1614609, -3.536341, - -7.090382, -5.36146, 9.072544, 6.4554095, -4.4728956, -1.88395, 3.1095037, - 8.782348, -3.316743, -8.65248, 1.6802986, 8.186188, 2.1783829, 4.931278, - 4.158475, 1.4033595, -11.320101, -3.7084908, -6.740436, -2.5555193, -1.0451177, - -6.5569925, 0.82810307, 8.505919, 8.332857, -9.488569, -0.21588463, -8.056692, - 8.493993, 7.6401625, 8.812983, -9.377281, 2.4369764, 3.1766508, 0.6300803, - 5.6666765, -7.913654, -0.42301777, 4.506412, -7.8954244, 10.904591, 5.042256, - -9.626183, 8.347351, -3.605006, -7.923387, 1.1024277, -8.705793, -2.5151258, - -2.5066147, 4.0515003, -2.060757, 6.2635093, 8.286584, -6.0509276, -6.76452, - -3.1158175, 1.6578803, -1.4608748, -1.24211, 8.151246, -4.2970877, 6.093071, - 7.4911637, 4.51018, 4.8425875, 9.211085, -2.4386222, 4.5830803, -5.6079445, - 2.3713675, -4.0707507, 3.1787417, 5.462342, 6.915912, 6.3928423, -7.2970796, - 5.0112796, -9.140893, 4.9990606, 0.38391754, 7.7088532, 1.9340848, 8.18833, - 8.16617, -9.42086, -0.3388326, -9.659727, 8.243045, 8.099073, 8.439428, - -7.038694, 2.1077902, 3.3866816, -1.9975324, 7.4972878, -7.2525196, -1.553731, - 4.08758, -6.6922374, 9.50525, 4.026735, -9.243538, 7.2740564, -3.9319072, - -6.3228955, 1.6693478, -7.923119, -3.7423058, -2.2813146, 5.3469067, -1.8285407, - 3.3118162, 8.826356, -4.4641976, -6.4751124, -9.200089, -2.519147, 4.225298, - 2.4105988, -0.4344186, 0.53441775, 5.2836394, -8.2816105, -4.996147, -1.6870759, - -7.8543897, -3.9788852, -7.0346904, -3.1289773, 7.4567637, -5.6227813, 1.0709786, - -8.866012, 8.427324, -1.1755563, -5.789216, -8.197835, 5.3342214, 6.0646234, - -6.8975716, 7.717031, 3.480355, 8.312151, -3.6645212, -3.0976524, -8.090359, - -1.9176173, 2.4257212, 1.9700835, 0.4098958, 2.1341088, 7.652741, -9.9595585, - -5.989757, 0.10119354, -7.935407, -5.792786, -5.22783, -4.318978, 5.414037, - -6.4621663, 1.670883, -6.9224787, 8.696932, -2.0214002, -6.6681314, -8.326418, - 4.9049683, 5.4442496, -6.403739, 7.5822453, 7.0972915, -9.072851, -0.23897195, - 1.7662339, 5.3096304, 1.983179, -2.222645, -0.34700772, -9.094717, -6.107907, - 9.525174, 8.1550665, -5.6940084, -4.1636486, 1.7360662, 8.528821, -3.7299833, - -9.341266, 2.608542, 9.108706, 0.7978509, 4.2488184, 2.454484, 0.9446999, - -10.106636, -3.8973773, -6.6566644, -4.5647273, -0.99837756, -6.568582, 9.324853, - -7.9020953, 2.0910501, 2.2896829, 1.6790711, 1.3159255, -3.5258796, 1.8898442, - -8.105812, -4.924962, 8.771129, 7.1202874, -5.991957, -3.4106019, 2.4450088, - 7.796387, -3.055946, -7.8971434, 1.9856719, 9.001636, 1.8511922, 3.019749, - 3.1227696, 0.4822102, -10.021213, -3.530504, -6.225959, -3.0029628, -1.7881511, - -7.3879776, 1.3925704, 9.499782, -3.7318087, -3.7074296, -7.7466836, -1.5284524, - 4.0535855, 3.112011, 0.10340207, -0.5429599, 6.67026, -9.155924, -4.924038, - 0.64248866, -10.0103655, -3.2742946, -4.850029, -3.6707063, 8.586258, -5.855605, - 4.906918, -6.7813993, 7.9938135, -2.5473144, -5.688948, -7.822478, 2.1421318, - 4.66659, -9.701272, 9.549149, 0.8998125, -8.651497, -0.56899565, -8.639817, - 2.3088377, 2.1264515, 3.2764478, 2.341989, 8.594338, 8.630639, 2.8440373, - 6.2043204, 4.433932, 0.6320018, -1.8179281, 5.09452, -1.5741565, 8.153934, - 8.744339, -3.6945698, -8.883078, 1.5329908, 5.2745943, 0.44716078, 4.8809066, - -7.9594903, 1.134374, 9.233994, 6.5528665, -4.520542, 9.477355, -8.622195, - -0.23191702, 2.0485356, 3.9379985, 1.5916302, -1.4516805, -0.0843819, -7.8554378, - -5.88308, 7.999766, 6.2572145, -5.585321, -4.0097756, 0.42382592, 6.160884, - -3.631315, -8.333449, 2.770595, 7.8495173, 3.3331623, 4.940415, 3.6207345, - -0.037517, -11.034698, -3.185103, -6.614664, -3.2177854, -2.0792234, -6.8879867, - 7.821685, -8.455084, 1.0784642, 4.0033927, 2.7343264, 2.6052725, -4.1224284, - -0.89305353, -6.8267674, -4.9715133, 8.880253, 5.6994023, -5.9695024, -4.9181266, - 1.3017995, 7.972617, -3.9452884, -10.424556, 2.4504194, 6.21529, 0.93840516, - 4.2070026, 6.159839, 0.91979957, -8.706724, -4.317946, -6.6823545, -3.0388, - -2.464262, -7.3716645, 1.3926703, 6.544412, -5.6251183, -5.122411, -8.622049, - -2.3905911, 3.9138813, 1.9779967, -0.05011125, 0.13310997, 7.229751, -9.742043, - -8.08724, 1.2426697, -7.9230795, -3.3162494, -7.129571, -3.5488048, 7.4701195, - -5.2357526, 0.5917681, -6.272206, 6.342328, -2.909731, -4.991607, -8.845513, - 3.3228495, 7.033246, -7.8180246, 8.214469, 6.3910093, 9.185153, -6.20472, - -7.713809, -3.8481297, 3.5579286, 0.7078448, -3.2893546, 7.384514, -4.448121, - 3.0104196, 9.492943, 8.024847, 4.9114385, 9.965594, -3.014036, 5.182494, - -5.8806014, 2.5312455, -5.9926524, 4.474469, 6.3717875, 6.993105, 6.493093, - -8.935534, 3.004074, -8.055647, 8.315765, -1.3026813, 8.250377, 0.02606229, - 6.8508425, 9.655665, -7.0116496, -0.41060972, -10.049198, 7.897801, 6.7791023, - 8.3362, -9.821014, 2.491157, 3.5160472, -1.6228812, 7.398063, -8.769123, - -3.1743705, 3.2827861, -6.497855, 10.831924, 5.2761307, -9.704417, 4.3817043, - -3.9841619, -8.111647, 1.1883026, -8.115312, -2.9240117, -5.8879666, 4.20928, - -0.3587938, 6.935672, -10.177582, 0.48819053, 3.1250648, 2.9306343, 3.082544, - -3.477687, -1.3768549, -7.4922366, -3.756631, 10.039836, 3.6670392, -5.9761434, - -4.4728765, 3.244255, 7.027899, -2.3806512, -10.4100685, 1.605716, 7.7953773, - 0.5408159, 1.7156523, 3.824097, -1.0604783, -10.142124, -5.246805, -6.5283823, - -4.579547, -2.42714, -6.709197, 2.7782338, 7.33353, -6.454507, -2.9929368, - -7.8362985, -2.695445, 2.4900775, 1.6682367, 0.4641757, -1.0495365, 6.9631333, - -9.291356, -8.23837, -0.34263706, -8.275113, -2.8454232, -5.0864096, -2.681942, - 7.5450225, -6.2517986, 0.06810654, -6.470652, 4.9042645, -1.8369255, -6.6937943, - -7.9625087, 2.8510258, 6.180508, -8.282598, 7.919079, 1.4897474, 6.7217417, - -4.2459426, -4.114431, -8.375707, -2.143264, 5.6972933, 1.5574739, 0.39375135, - 1.7930849, 5.1737595, -7.826241, -5.160268, -0.80433255, -7.839536, -5.2620406, - -5.4643164, -3.185536, 6.620315, -7.065227, 1.0524757, -6.125088, 5.7126627, - -1.6161644, -3.852159, -9.164279, 2.7005782, 5.946544, -8.468236, 8.2145405, - 1.1035942, 6.590157, -4.0461283, -4.8090615, -7.6702685, -2.1121511, 5.1147075, - 1.6128504, 2.0064135, 1.0544407, 6.0038295, -7.8282537, -4.801278, 0.32349443, - -8.0649805, -4.372714, -5.61336, -5.21394, 8.176595, -5.4753284, 1.7800134, - -8.267283, 7.2133374, -0.16594432, -6.317046, -9.490406, 4.1261597, 5.473317, - -7.7551675, 7.007468, 7.478628, -8.801905, 0.10975724, 3.5478222, 4.797803, - 1.3825226, -3.357369, 0.99262005, -6.94877, -5.4781394, 9.632604, 5.7492557, - -5.9014316, -3.1632116, 2.340859, 8.708098, -3.1255999, -8.848661, 4.5612836, - 8.455157, 0.73460823, 4.112301, 4.392744, -0.30759293, -6.8036823, -3.0331545, - -8.269506, -2.82415, -0.9411246, -5.993506, 2.1618164, -8.716055, -0.7432543, - -10.255819, 3.095418, 2.5131428, 4.752442, 0.9907621, 7.8279433, 7.85814, - 0.50430876, 5.2840405, 4.457291, 0.03330028, -0.40692952, 3.9244103, -2.117118, - 7.6977615, 8.759009, -4.2157164, -9.136053, 3.247858, 4.668686, 0.76162136, - 5.3833632, -9.231471, 0.44309422, 8.380872, 6.7211227, -3.091507, 2.173508, - -9.038242, -1.3666698, -9.819077, 0.37825826, 2.3898845, 4.2440815, 1.9161536, - 7.24787, 6.9124637, 1.6238527, 5.1140285, 3.1935842, 1.02845, -1.1273454, - 5.638998, -2.497932, 8.342559, 8.586319, -2.9069402, -7.6387944, 3.5975037, - 4.4115705, 0.41506064, 4.9078383, -9.68327, 1.8159529, 9.744613, 8.40622, - -4.495336, 9.244892, -8.789869, 1.3158468, 4.018167, 3.3922846, 2.652022, - -2.7495477, 0.2528986, -8.268324, -6.004913, 10.428784, 6.6580734, -5.537176, - -1.7177434, 2.7504628, 6.7735, -2.4454272, -9.998361, 2.9483433, 6.8266654, - 2.3787718, 4.472637, 2.5871701, 0.7355365, -7.7027745, -4.1879907, -7.172832, - -4.1843605, -0.03646783, -5.419406, 6.958486, 11.011111, -7.1821184, -7.956423, - -3.408451, 4.6850276, -2.348787, -4.398289, 6.9787564, -3.8324208, 5.967827, - 8.433518, 4.660108, 5.5657144, 9.964243, -1.3515275, 6.404833, -6.4805903, - 2.4379845, -6.0816774, 1.752272, 5.3771873, 6.9613523, 6.9788294, -6.3894596, - 3.7521114, -6.8034263, 6.4458385, -0.7233525, 10.512529, 4.362273, 9.231461, - -6.3382263, -7.659, -3.461823, 4.71463, 0.17817476, -3.685746, 7.2962036, - -4.6489477, 5.218017, 11.546999, 4.7218375, 6.8498397, 9.281103, -3.900459, - 6.844054, -7.0886965, -0.05019227, -8.233724, 5.5808983, 6.374517, 8.321048, - 7.969449, -7.3478637, 1.4917561, -8.003144, 4.780668, -1.1981848, 7.753739, - 2.0260844, -8.880096, -3.4258451, -7.141975, 1.9637157, 1.814725, 5.311151, - 1.4831505, 7.8483663, 7.257948, 1.395786, 6.417756, 5.376912, 0.59505713, - 0.00062552, 3.6634305, -4.159713, 7.3571978, 10.966816, -2.5419605, -8.466229, - 1.904205, 5.6338267, -0.52567476, 5.59736, -8.361799, 0.5009981, 8.460681, - 7.3891273, -3.5272243, 5.0552278, 9.921456, -7.69693, -7.286378, -1.9198836, - 3.1666567, -2.5832257, -2.2445817, 9.888111, -5.076563, 5.677401, 7.497946, - 5.662994, 5.414262, 8.566503, -2.5530663, 7.1032815, -6.0612082, 1.3419591, - -4.9595256, 4.3377542, 4.3790717, 6.793512, 8.383502, -7.1278043, 3.3240774, - -9.379446, 6.838661, -0.81241214, 8.694813, 0.79141915, 7.632467, 8.575382, - -8.533798, 0.28954387, -7.5675836, 5.8653326, 8.97235, 7.1649346, -10.575289, - 0.9359381, 5.02381, -0.5609511, 5.543464, -7.69131, -2.1792977, 2.4729247, - -6.1917787, 10.373678, 7.6549597, -8.809486, 5.5657206, -3.3169382, -8.042887, - 2.0874746, -7.079005, -3.33398, -3.6843317, 4.0172358, -2.0754814, 1.1726758, - 7.4618697, 6.9483604, -8.469206, 0.7401797, -10.318176, 8.384557, 10.5476265, - 9.146971, -9.250223, 0.6290606, 4.4941425, -0.7514017, 7.2271705, -8.309598, - -1.4761636, 4.0140634, -6.021102, 9.132852, 5.6610966, -11.249811, 8.359293, - -1.9445792, -7.7393436, -0.3931331, -8.824441, -2.5995944, -2.5714035, 4.140213, - -3.6863053, 5.517265, 9.020411, -4.9286127, -7.871219, -3.7446704, 2.5179656, - -1.4543481, -2.2703636, 7.010597, -3.6436229, 6.753862, 7.4129915, 7.1406755, - 5.653706, 9.5445175, 0.15698843, 4.761813, -7.698002, 1.6870106, -4.5410123, - 4.171763, 5.3747005, 6.341021, 7.456738, -8.231657, 2.763487, -9.208167, - 6.676799, -1.1957736, 10.062605, 4.0975976, 7.312957, -2.4981596, -2.9658387, - -8.150425, -2.1075552, 2.64375, 1.6636052, 1.1483809, 0.09276015, 5.8556347, - -7.8481026, -5.9913163, -0.02840613, -9.937289, -1.0486673, -5.2340155, -3.83912, - 7.7165728, -8.409944, 0.80863273, -6.9119215, 7.5712357, 0.36031485, -6.056131, - -8.470033, 1.8678337, 3.0121377, -7.3096333, 8.205484, 5.262654, 8.774514, - -4.7603083, -7.2096143, -4.437014, 3.6080024, -1.624254, -4.2787876, 8.880863, - -4.8984556, 5.1782074, 9.944454, 3.911282, 3.5396595, 8.867042, -1.2006199, - 5.393288, -5.6455317, 0.7829499, -4.0338907, 2.479272, 6.5080743, 8.582535, - 7.0097537, -6.9823785, 3.984318, -7.225381, 5.3135114, -1.0391048, 8.951443, - -0.70119005, -8.510742, -0.42949116, -10.9224825, 2.8176029, 1.6800792, 5.778404, - 1.7269998, 7.1975236, 7.7258267, 2.7632928, 5.3399253, 3.4650044, 0.01971426, - -1.6468811, 4.114996, -1.5110453, 6.8689218, 8.269899, -3.1568048, -7.0344677, - 1.2911975, 5.950357, 0.19028673, 4.657226, -8.199647, 2.246055, 8.989509, - 5.3101015, -4.2400866}; - - std::vector X_embedded = { - -0.41849962, -0.53906363, 0.46958843, -0.35832694, -0.23779503, -0.29751351, -0.01072748, - -0.21353109, -0.54769957, -0.55086273, 0.37093949, -0.12714292, -0.06639574, -0.36098689, - -0.13060696, -0.07362658, -1.01205945, -0.39285606, 0.2864089, -0.32031146, -0.19595343, - 0.08900568, -0.04813879, -0.06563424, -0.42655188, -0.69014251, 0.51459783, -0.1942696, - -0.07767916, -0.6119386, 0.04813685, -0.22557008, -0.56890118, -0.60293794, 0.43429622, - -0.09240723, -0.00624062, -0.25800395, -0.1886092, 0.01655941, -0.01961523, -0.14147359, - 0.41414487, -0.8512944, -0.61199242, -0.18586016, 0.14024924, -0.41635606, -0.02890144, - 0.1065347, 0.39700791, -1.14060664, -0.95313865, 0.14416681, 0.17306046, -0.53189689, - -0.98987544, -0.67918193, 0.41787854, -0.20878236, -0.06612862, 0.03502904, -0.03765266, - -0.0980606, -0.00971657, 0.29432917, 0.36575687, -1.1645509, -0.89094597, 0.03718805, - 0.2310573, -0.38345811, -0.10401925, -0.10653082, 0.38469055, -0.88302094, -0.80197543, - 0.03548668, 0.02775662, -0.54374295, 0.03379983, 0.00923623, 0.29320273, -1.05263519, - -0.93360096, 0.03778313, 0.12360487, -0.56437284, 0.0644429, 0.33432651, 0.36450726, - -1.22978747, -0.83822101, -0.18796451, 0.34888434, -0.3801491, -0.45327303, -0.59747899, - 0.39697698, -0.15616602, -0.06159166, -0.40301991, -0.11725303, -0.11913263, -0.12406619, - -0.11227967, 0.43083835, -0.90535849, -0.81646025, 0.10012121, -0.0141237, -0.63747931, - 0.04805023, 0.34190539, 0.50725192, -1.17861414, -0.74641538, -0.09333111, 0.27992678, - -0.56214809, 0.04970971, 0.36249384, 0.57705611, -1.16913795, -0.69849908, 0.10957897, - 0.27983218, -0.62088525, 0.0410459, 0.23973398, 0.40960434, -1.14183664, -0.83321381, - 0.02149482, 0.21720445, -0.49869928, -0.95655465, -0.51680422, 0.45761383, -0.08351214, - -0.12151554, 0.00819737, -0.20813803, -0.01055793, 0.25319234, 0.36154974, 0.1822421, - -1.15837133, -0.92209691, -0.0501582, 0.08535917, -0.54003763, -1.08675635, -1.04009593, - 0.09408128, 0.07009826, -0.01762833, -0.19180447, -0.18029785, -0.20342001, 0.04034991, - 0.1814747, 0.36906669, -1.13532007, -0.8852452, 0.0782818, 0.16825101, -0.50301319, - -0.29128098, -0.65341312, 0.51484352, -0.38758236, -0.22531103, -0.55021971, 0.10804344, - -0.3521522, -0.38849035, -0.74110794, 0.53761131, -0.25142813, -0.1118066, -0.47453368, - 0.06347904, -0.23796193, -1.02682328, -0.47594091, 0.39515916, -0.2782529, -0.16566519, - 0.08063579, 0.00810116, -0.06213913, -1.059654, -0.62496334, 0.53698546, -0.11806234, - 0.00356161, 0.11513405, -0.14213292, 0.04102662, -0.36622161, -0.73686272, 0.48323864, - -0.27338892, -0.14203401, -0.41736352, 0.03332564, -0.21907479, -0.06396769, 0.01831361, - 0.46263444, -1.01878166, -0.86486858, 0.17622118, -0.01249686, -0.74530888, -0.9354887, - -0.5027945, 0.38170099, -0.15547098, 0.00677824, -0.04677663, -0.13541745, 0.07253501, - -0.97933143, -0.58001202, 0.48235369, -0.18836913, -0.02430783, 0.07572441, -0.08101331, - 0.00630076, -0.16881248, -0.67989182, 0.46083611, -0.43910736, -0.29321918, -0.38735861, - 0.07669903, -0.29749861, -0.40047669, -0.56722462, 0.33168188, -0.13118173, -0.06672747, - -0.56856316, -0.26269144, -0.14236671, 0.10651901, 0.4962585, 0.38848072, -1.06653547, - -0.64079332, -0.47378591, 0.43195483, -0.04856951, -0.9840439, -0.70610428, 0.34028092, - -0.2089237, -0.05382041, 0.01625874, -0.02080803, -0.12535211, -0.04146428, -1.24533033, - 0.48944879, 0.0578458, 0.26708388, -0.90321028, 0.35377088, -0.36791429, -0.35382384, - -0.52748734, 0.42854419, -0.31744713, -0.19174226, -0.39073724, -0.03258846, -0.19978228, - -0.36185205, -0.57412046, 0.43681973, -0.25414538, -0.12904905, -0.46334973, -0.03123853, - -0.11303604, -0.87073672, -0.45441297, 0.41825858, -0.25303507, -0.21845073, 0.10248682, - -0.11045569, -0.10002795, -0.00572806, 0.16519061, 0.42651513, -1.11417019, -0.83789682, - 0.02995787, 0.16843079, -0.53874511, 0.03056994, 0.17877036, 0.49632853, -1.03276777, - -0.74778616, -0.03971953, 0.10907949, -0.67385727, -0.9523471, -0.56550741, 0.40409449, - -0.2703723, -0.10175014, 0.13605487, -0.06306008, -0.01768126, -0.4749442, -0.56964815, - 0.39389887, -0.19248079, -0.04161081, -0.38728487, -0.20341556, -0.12656988, -0.35949609, - -0.46137866, 0.28798422, -0.06603147, -0.04363992, -0.60343552, -0.23565227, -0.10242701, - -0.06792886, 0.09689897, 0.33259571, -0.98854214, -0.84444433, 0.00673901, 0.13457057, - -0.43145794, -0.51500046, -0.50821936, 0.38000089, 0.0132636, 0.0580942, -0.40157595, - -0.11967677, 0.02549113, -0.10350953, 0.22918226, 0.40411913, -1.05619383, -0.71218503, - -0.02197581, 0.26422262, -0.34765676, 0.06601537, 0.21712676, 0.34723559, -1.20982027, - -0.95646334, 0.00793948, 0.27620381, -0.43475035, -0.67326003, -0.6137197, 0.43724492, - -0.17666136, -0.06591748, -0.18937394, -0.07400128, -0.06881691, -0.5201112, -0.61088628, - 0.4225319, -0.18969463, -0.06921366, -0.33993208, -0.06990873, -0.10288513, -0.70659858, - -0.56003648, 0.46628812, -0.16090363, -0.0185108, -0.1431348, -0.1128775, -0.0078648, - -0.02323332, 0.04292452, 0.39291084, -0.94897962, -0.63863206, -0.16546988, 0.23698957, - -0.30633628}; - - auto stream = resource::get_cuda_stream(handle); - - d_X.resize(X.size(), stream); - d_X_embedded.resize(X_embedded.size(), stream); - raft::update_device(d_X.data(), X.data(), X.size(), stream); - raft::update_device(d_X_embedded.data(), X_embedded.data(), X_embedded.size(), stream); - auto n_sample = 50; - auto n_features_origin = 30; - auto n_features_embedded = 8; - - // euclidean test - score = trustworthiness_score( - handle, - raft::make_device_matrix_view(d_X.data(), n_sample, n_features_origin), - raft::make_device_matrix_view( - d_X_embedded.data(), n_sample, n_features_embedded), - 5); - } - - void SetUp() override { basicTest(); } - - void TearDown() override {} - - protected: - raft::resources handle; - - rmm::device_uvector d_X; - rmm::device_uvector d_X_embedded; - - double score; -}; - -typedef TrustworthinessScoreTest TrustworthinessScoreTestF; -TEST_F(TrustworthinessScoreTestF, Result) { ASSERT_TRUE(0.9375 < score && score < 0.9379); } -}; // namespace stats -}; // namespace raft